In [None]:
import glob
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from dredFISH.Utils import basicu
from dredFISH.Utils import designu
from dredFISH.Utils import powerplots
from dredFISH.Visualization import compile_tex

import importlib
config = importlib.import_module("dredfish_processing_config")

# Re-execute the in-development dredFISH helper modules so that edits made while the
# kernel is running are picked up without a restart. designu is included because it is
# imported alongside the others and used later (designu.plot_intn); basicu is reloaded
# first so modules that use it see the fresh version.
for _module in (basicu, designu, powerplots, compile_tex):
    importlib.reload(_module)

In [None]:
path_dataset = '/bigstore/GeneralStorage/Data/dredFISH/NN1'
path_fig = os.path.join(path_dataset, 'figures')
# makedirs(exist_ok=True) creates the folder if needed and is a no-op when it already exists
os.makedirs(path_fig, exist_ok=True)

# List file names relative to path_dataset. `root_dir` is glob's documented base-directory
# parameter; `dir_fd` expects an open directory file descriptor, not a path string.
files_mtx = np.sort(glob.glob('*_matrix.csv', root_dir=path_dataset))
# files_meta = np.sort(glob.glob('*_metadata_filtered.csv', root_dir=path_dataset))
files_meta = np.sort(glob.glob('*_metadata.csv', root_dir=path_dataset))

# Matrix and metadata files are later paired by sorted position (same section index),
# so verify that their name prefixes line up one-to-one.
prefixes_mtx = [f[:-len('_matrix.csv')] for f in files_mtx]
prefixes_meta = [f[:-len('_metadata.csv')] for f in files_meta]
assert prefixes_mtx == prefixes_meta, "matrix and metadata files do not pair up one-to-one"
files_mtx.shape, files_meta.shape

In [None]:
# Section index -> section name (matrix file name with the '_matrix.csv' suffix removed)
sections = dict(enumerate(
    fname.replace('_matrix.csv', '')  # .replace('DPNM1A__11A_12B_2022Jul28_Section_', '')
    for fname in files_mtx
))
sections

In [None]:
def plot_basis_box(ftrs_mat, output=None, ylabel='zscore', ylim=(-3, 3)):
    """Draw per-column box plots of `ftrs_mat` in two stacked panels.

    Both panels show the same seaborn box plot (one box per column). The top panel
    keeps the full y-range; the bottom panel is clipped to `ylim` so the bulk of
    each distribution is readable.

    Args:
        ftrs_mat: 2-D array-like; one box is drawn per column (called in this
            notebook with the normalized bit matrix, rows = filtered cells).
        output: optional file path; when given, the figure is saved with
            `powerplots.savefig_autodate` before being shown.
        ylabel: y-axis label for both panels (default 'zscore').
        ylim: y-limits applied to the bottom panel only.
    """
    fig, axs = plt.subplots(2, 1, figsize=(10, 3 * 2), sharex=True)
    for ax in axs:
        sns.boxplot(data=ftrs_mat, ax=ax)
        ax.set_ylabel(ylabel)  # honour the parameter instead of a hard-coded label
    # x is shared, so only the bottom panel carries the x label; it is also the zoomed view
    bottom = axs[-1]
    bottom.set_xlabel('basis')
    bottom.set_ylim(ylim)
    if output is not None:
        powerplots.savefig_autodate(fig, output)
    plt.show()

In [None]:
# Expected matrix column names, in bitmap order: the second field of each config.bitmap entry
col_orders = list(map(lambda bit_entry: bit_entry[1], config.bitmap))

In [None]:
sctn = 1
sctn_name = sections[sctn]
file_mtx = files_mtx[sctn]
file_meta = files_meta[sctn]
print(file_mtx, file_meta)  # report the files actually loaded for `sctn`

n_bits = 24  # only the first 24 matrix columns are kept and analysed
bit_cols = [f'br{i}' for i in range(n_bits)]
min_nuclei_signal = 1500  # QC threshold on the metadata 'nuclei_signal' column

mtx = pd.read_csv(os.path.join(path_dataset, file_mtx), sep=',', index_col=0)
# Column order must agree with the bitmap. array_equal returns False on a length
# mismatch instead of raising/warning inside an elementwise list-vs-array comparison.
assert np.array_equal(np.asarray(col_orders), mtx.columns.values), \
    f"{file_mtx}: columns do not match the config.bitmap order"
mtx = mtx.iloc[:, :n_bits]
mtx.columns = bit_cols
meta = pd.read_csv(os.path.join(path_dataset, file_meta), sep=',', index_col=0)
df = meta.join(mtx)  # left join on the shared index: meta rows without matrix values get NaN

# QC filter. .copy() makes df an independent frame so the b{i} columns added below
# do not trigger pandas' SettingWithCopyWarning on a filtered slice.
# cond = df['cytoplasm_size'] > 10
cond = df['nuclei_signal'] > min_nuclei_signal
df = df.loc[cond].copy()
mtx = df[bit_cols]

# norm
ftrs_mat = basicu.normalize_fishdata(mtx.values, norm_cell=True, norm_basis=True, allow_nan=True)
raw_mat = mtx.values
for i in range(ftrs_mat.shape[1]):
    df[f'b{i}'] = ftrs_mat[:, i]

# XY stage coordinates of the kept rows
XY = df[['stage_x', 'stage_y']].values
x = XY[:, 0]
y = XY[:, 1]

In [None]:
# Fig 1: stage-coordinate scatter of the rows kept by the nuclei_signal filter (saved)
output = os.path.join(path_fig, f"fig1_xy_sect{sctn}_{sctn_name}.pdf")
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(x, y, s=1, edgecolor='none', color='black', rasterized=True)
ax.set(title=sctn_name, aspect='equal')
powerplots.savefig_autodate(fig, output)
plt.show()

# Fig 2: per-basis box plots of the normalized features (display only; to save, set
# output to os.path.join(path_fig, f"fig2_basis_box_sect{sctn}_{sctn_name}.pdf"))
output = None
plot_basis_box(ftrs_mat, output=output)

# Fig 3: presumably spatial maps of the normalized b{i} columns over the stage
# coordinates (display only; to save, set output to
# os.path.join(path_fig, f"fig3_basis_xy_sect{sctn}_{sctn_name}.pdf"))
output = None
powerplots.plot_basis_spatial(
    df, xcol='stage_x', ycol='stage_y',
    vmin=-1, vmax=1, output=output,
)

In [None]:
# Overview plot of the raw (unnormalized, nuclei_signal-filtered) bit matrix.
# NOTE(review): what exactly is drawn is defined in designu.plot_intn, which is not visible here.
designu.plot_intn(raw_mat)

In [None]:
def plot_basis_box_v2(ftrs_mat, output=None, ylabel='zscore', ylim=None):
    """Draw gray per-column box plots of `ftrs_mat` on a linear and a log y-axis.

    Both panels show the same seaborn box plot (one box per column). The top panel
    uses a linear y-axis; the bottom panel uses a log y-axis (presumably meant for
    raw, positive intensities — it is called with the raw bit matrix in this notebook).

    Args:
        ftrs_mat: 2-D array-like; one box is drawn per column.
        output: optional file path; when given, the figure is saved with
            `powerplots.savefig_autodate` before being shown.
        ylabel: y-axis label for both panels.
        ylim: optional y-limits for the bottom (log-scale) panel. The previous
            default `[-3, 3]` was never applied (and is invalid on a log axis), so
            the default is now None = autoscale, which matches the old behavior.
    """
    fig, axs = plt.subplots(2, 1, figsize=(10, 3 * 2), sharex=True)
    for ax in axs:
        sns.boxplot(data=ftrs_mat, color='gray', ax=ax)
        ax.set_ylabel(ylabel)
    # x is shared, so only the bottom panel carries the x label; it holds the log view
    bottom = axs[-1]
    bottom.set_xlabel('basis')
    bottom.set_yscale('log')
    if ylim is not None:
        bottom.set_ylim(ylim)
    if output is not None:
        powerplots.savefig_autodate(fig, output)
    plt.show()

In [None]:
# Box plots of the raw (unnormalized) bit values; the bottom panel is log-scaled.
# Display only — nothing is written to disk.
output = None
plot_basis_box_v2(raw_mat, ylabel="Vector", output=output)


In [None]:
# Per-bit 50th / 10th / 90th percentiles of the raw values. nanpercentile is used
# because the left join of meta with the matrix can leave NaN rows, which would
# otherwise turn a whole bit's percentile into NaN.
pctl_labels = ['median', '10 perctl', '90 perctl']
raw_pctls = np.nanpercentile(raw_mat, [50, 10, 90], axis=0)  # one row per percentile

fig, ax = plt.subplots()
for values, label in zip(raw_pctls, pctl_labels):
    ax.plot(values, label=label)
ax.set_xlabel('bit index')
ax.set_ylabel('raw value')
ax.legend(bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# Same percentiles as above, with bits ordered by decreasing median. The percentiles
# are computed once (NaN-aware) instead of re-computing the median for the sort.
raw_pctls = np.nanpercentile(raw_mat, [50, 10, 90], axis=0)  # rows: median, p10, p90
order = np.argsort(raw_pctls[0])[::-1]  # bit indices by decreasing median

fig, ax = plt.subplots()
for values, label in zip(raw_pctls, ['median', '10 perctl', '90 perctl']):
    ax.plot(values[order], '-o', markersize=5, label=label)
ax.set_yscale('log')
ax.set_xticks(range(len(order)))
ax.set_xticklabels(order)  # show which bit sits at each rank
ax.set_xlabel('bit (sorted by decreasing median)')
ax.set_ylabel('raw value')
ax.legend(bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# Distribution of the per-row total raw signal (sum over all bits), on a log10(x + 1) scale
fig, ax = plt.subplots()
sns.histplot(np.log10(raw_mat.sum(axis=1) + 1), bins=np.linspace(2, 5, 50), ax=ax)
ax.set_xlabel('log10(total raw value per cell + 1)')
plt.show()

In [None]:
# One histogram of log10(raw + 1) per bit, laid out on a 6-column grid.
log_raw = np.log10(raw_mat + 1)  # hoisted: previously recomputed for every panel
hist_bins = np.linspace(1, 3, 50)
n_panels = log_raw.shape[1]
grid_cols = 6
grid_rows = -(-n_panels // grid_cols)  # ceiling division; 24 bits -> 4 rows

fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols * 5, grid_rows * 4), squeeze=False)
for i, ax in enumerate(axs.flat):
    if i < n_panels:
        sns.histplot(log_raw[:, i], bins=hist_bins, ax=ax)
        ax.set_title(f'br{i}')
    else:
        ax.set_visible(False)  # hide grid cells beyond the last bit
plt.show()

# TODO: push through and ask more questions