# Visualizing the output of WDM on zebrafish (zf)

## imports

In [1]:
# Mount Google Drive. The relative 'drive/...' paths used later resolve
# against the Colab working directory, which appears to be /content
# (module paths are reported as /content/drive/... when imported).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install scanpy -q

In [3]:
!pip install -U plotly -q   # for making Sankey diagram

In [4]:
!pip install -U kaleido -q   # for saving Sankey diagram

In [5]:
import pandas as pd
import scanpy as sc
import numpy as np
import torch
import sys
from scipy.spatial.distance import cdist
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
import sys
import importlib

import matplotlib.pyplot as plt
import plotly.graph_objects as go

In [6]:
# Drive locations, relative to the Colab working directory:
#   filehandle_wdm        -> WDM source code
#   filehandle_load_peter -> saved WDM outputs (Q*.npy are loaded from here)
#   filehandle_zf         -> zebrafish slice-pair data (pair<k>/ subfolders)
#   filehandle_ctd        -> celltypediscovery repo (ground-truth type labels)
filehandle_wdm = 'drive/Othercomputers/numac/GitHub/WDM/'
filehandle_load_peter = 'drive/Othercomputers/numac/GitHub/celltypediscovery/_wdm_peter/'
filehandle_zf = 'drive/MyDrive/DX/_data/zebrafish/cleaned_common_pca/'
filehandle_ctd = 'drive/Othercomputers/numac/GitHub/celltypediscovery/'

# Prepend all four directories to the module search path in one step.
# Resulting priority (highest first): ctd, zf, load_peter, wdm -- the same
# order produced by sys.path.insert(0, ...) on wdm, load_peter, zf, ctd in turn.
sys.path[:0] = [filehandle_ctd, filehandle_zf, filehandle_load_peter, filehandle_wdm]

# Project-local modules found via the paths above.
import clustering
import util_LR
import util_zf
import FRLC_LRDist

## reload

In [7]:
# Re-execute the project modules so edits to their .py files on Drive take
# effect without restarting the kernel. The cell displays the module object
# returned by the last reload.
importlib.reload(clustering)
importlib.reload(util_LR)
importlib.reload(util_zf)
importlib.reload(FRLC_LRDist)

<module 'FRLC_LRDist' from '/content/drive/Othercomputers/numac/GitHub/WDM/FRLC_LRDist.py'>

## device

In [8]:
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'On device: {device}')
# NOTE(review): the GT Q/R tensors built later are cast with .float()
# (float32), so this float64 setting is not applied to them.
dtype = torch.float64

On device: cuda


## Load zebrafish (zf) data: spatial coordinates, and ground-truth annotation hard clusterings as $Q$ matrices

In [9]:
# Slice names in temporal order.
zf_list = ['zf3', 'zf5', 'zf10', 'zf12', 'zf18', 'zf24']
zf_names = zf_list
# One directory per slice index k: <filehandle_zf>pair<k>/
filehandles_zf = [f'{filehandle_zf}pair{k}/' for k in range(len(zf_list))]
N = len(filehandles_zf)

spatial_list = []
# Per-slice row index of a single cell to drop (None = keep every cell).
exclude_rows = [1099, None, None, 325, None, None]

for i, slice_name in enumerate(zf_names):
    # Slice k is read from pair<k>/, except the last slice, which is read
    # from the preceding pair's directory (it is only the second member of
    # the final pair).
    pair_dir = filehandles_zf[i - 1] if i == len(zf_list) - 1 else filehandles_zf[i]
    spatial = np.load(pair_dir + slice_name + '_spatial.npy')
    drop = exclude_rows[i]
    if drop is not None:
        spatial = np.concatenate((spatial[:drop, :], spatial[drop + 1:, :]))
    spatial_list.append(spatial)

In [10]:
def _onehot_coupling(cell_labels, exclude_row=None):
    """Turn per-cell labels into a normalized one-hot (hard-clustering) matrix.

    Parameters
    ----------
    cell_labels : array-like, length n_cells
        Annotation label of every cell (e.g. ``adata.obs[key]``).
    exclude_row : int or None
        Row of a single cell to drop, mirroring the rows removed from
        ``spatial_list`` via ``exclude_rows``; ``None`` keeps all rows.

    Returns
    -------
    coupling : np.ndarray, shape (n_kept, n_categories)
        One-hot matrix divided by its total, so all entries sum to 1.
    categories : list
        Column labels, in the order chosen by ``OneHotEncoder`` (sorted).

    Raises
    ------
    IndexError
        If ``exclude_row`` is out of range (the previous slicing silently
        ignored a bad index).
    """
    encoder = OneHotEncoder(sparse_output=False)
    onehot = encoder.fit_transform(np.asarray(cell_labels).reshape(-1, 1))
    if exclude_row is not None:
        # Drop the cell BEFORE normalizing. Dropping it afterwards left a
        # matrix whose entries summed to (n-1)/n instead of 1.
        onehot = np.delete(onehot, exclude_row, axis=0)
    return onehot / onehot.sum(), list(encoder.categories_[0])


N = len(filehandles_zf)
Qs = [None] * (N - 1)
Rs = [None] * (N - 1)

key = 'bin_annotation'

# Category names per slice. The loop below runs backwards, so this ends up
# ordered from the last slice down to slice 0.
labels = []

for i in range(N - 2, -1, -1):
    s1_name = zf_names[i]
    s2_name = zf_names[i + 1]
    print(f'Slice pair {i}, aligning {s1_name} to {s2_name}')

    adata_pair = sc.read_h5ad(filehandle_zf + f'pair{i}/' + s1_name + '_' + s2_name + '.h5ad')
    # Cells with timepoint 1 are treated as slice i, timepoint 2 as slice i+1.
    adata1 = adata_pair[adata_pair.obs['timepoint'] == 1]
    adata2 = adata_pair[adata_pair.obs['timepoint'] == 2]

    _Q, categories1 = _onehot_coupling(adata1.obs[key], exclude_rows[i])
    _R, categories2 = _onehot_coupling(adata2.obs[key], exclude_rows[i + 1])

    labels.append(categories2)
    if i == 0:
        labels.append(categories1)

    # Cast to float32 on `device`, as before (not the notebook-level `dtype`).
    Qs[i] = torch.from_numpy(_Q).to(device).float()
    Rs[i] = torch.from_numpy(_R).to(device).float()

Slice pair 4, aligning zf18 to zf24
Slice pair 3, aligning zf12 to zf18
Slice pair 2, aligning zf10 to zf12
Slice pair 1, aligning zf5 to zf10
Slice pair 0, aligning zf3 to zf5


In [11]:
# Ground-truth hard clusterings for every slice. Slice k is the first member
# of pair k (Qs[k]); the final slice only appears as the second member of the
# last pair, so it comes from Rs[-1] rather than a hard-coded index.
Qs_gt = Qs + [Rs[-1]]
assert len(Qs_gt) == len(zf_names), 'expected one GT clustering per slice'
len(Qs_gt)

6

## Set ranks; set `i` to the first timepoint of the triple `i`, `i+1`, `i+2`

In [12]:
# Rank pair per adjacent slice pair; presumably ranks[k] = (rank for slice k,
# rank for slice k+1), since consecutive pairs share their middle value.
ranks = [(3, 7), (7, 7), (7, 11), (11, 14), (14, 19)]

# First timepoint of the triple (i, i+1, i+2) analysed below.
i = 2

r1, r2 = ranks[i]
r2_next, r3 = ranks[i + 1]
# The middle slice of the triple must get the same rank from both pairs;
# previously r2 was silently overwritten by ranks[i + 1][0].
assert r2 == r2_next, f'rank mismatch for slice {i + 1}: {r2} vs {r2_next}'

## load features and spatial coords specific to the triple

In [13]:
# Names of the three consecutive slices in the triple starting at i.
s0_name, s1_name, s2_name = zf_names[i], zf_names[i + 1], zf_names[i + 2]

# Slices i and i+1 are read from pair i; slice i+2 from pair i+1.
filehandle_pair1 = f'{filehandle_zf}pair{i}/'
filehandle_pair2 = f'{filehandle_zf}pair{i + 1}/'

X0, X1 = (np.load(f'{filehandle_pair1}{name}_feature.npy') for name in (s0_name, s1_name))
X2 = np.load(f'{filehandle_pair2}{s2_name}_feature.npy')

# Spatial coordinates (already row-filtered by exclude_rows) for the triple.
S0, S1, S2 = spatial_list[i], spatial_list[i + 1], spatial_list[i + 2]

## load `WDM` output

In [14]:
# WDM output matrices, one per slice (Q0.npy ... Q5.npy), presumably the
# predicted soft cluster assignments -- TODO confirm against the WDM code.
# Loaded in a single comprehension instead of six copy-pasted lines. The
# transport plans T01..T34.npy in the same folder are not used here; the
# dead commented-out loads (whose Ts_pred line wrongly listed the Q
# matrices) were removed.
Qs_pred = [np.load(filehandle_load_peter + f'Q{k}.npy') for k in range(6)]
Q0_pred, Q1_pred, Q2_pred, Q3_pred, Q4_pred, Q5_pred = Qs_pred

## Make maximum-likelihood (`ml`) clusterings

In [15]:
# Collapse each Q matrix into one hard label per row using the project's
# max_likelihood_clustering. The GT matrices are torch tensors on `device`,
# so they are first copied back to host NumPy arrays.
pred_clustering_list = clustering.max_likelihood_clustering(Qs_pred)
gt_clustering_list = clustering.max_likelihood_clustering(
    list(map(lambda q: q.cpu().numpy(), Qs_gt))
)

## load labels

In [16]:
# Ground-truth cell-type labels for each of the six slices, read from
# a_gt_types/slice<k>_types.npy and converted to Python lists.
(slice0_types, slice1_types, slice2_types,
 slice3_types, slice4_types, slice5_types) = (
    list(np.load(f'{filehandle_ctd}a_gt_types/slice{k}_types.npy')) for k in range(6)
)

In [17]:
# Distinct cell-type names, one list per slice.
# sorted() instead of list(set(...)): the set iteration order of strings
# depends on Python's per-process hash randomization, so list(set(...))
# produced a different label order after every kernel restart. Sorted order
# is also the column order OneHotEncoder uses for the GT one-hot matrices,
# assuming these names are the same values as `bin_annotation` and that
# make_sankey indexes labels by cluster id -- TODO confirm both.
ct_labels = [
    sorted(set(types))
    for types in (slice0_types, slice1_types, slice2_types,
                  slice3_types, slice4_types, slice5_types)
]

## Sankey diagrams for all time points

In [19]:
# Re-execute util_zf so the latest make_sankey on Drive is used below.
importlib.reload(util_zf)

<module 'util_zf' from '/content/drive/Othercomputers/numac/GitHub/WDM/util_zf.py'>

In [20]:
# One Sankey diagram per slice, comparing GT and predicted hard clusterings;
# each figure is saved as sankey_slice_<k>.jpg.
# A dedicated loop variable is used so the triple index `i` set earlier is not
# clobbered (re-running the triple-loading cell afterwards would otherwise see i=5).
for slice_idx in range(len(Qs_gt)):
    util_zf.make_sankey(
        gt_clustering_list[slice_idx],
        pred_clustering_list[slice_idx],
        ct_labels[slice_idx],
        save_format='jpg',
        save_name=f'sankey_slice_{slice_idx}',
        title=f'Sankey between predicted and GT, slice: {slice_idx}',
    )