In [16]:

import sys
import scanpy as sc
import os
sys.path.append(os.path.abspath(".."))

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
import src.FRLC as FRLC
import src.FRLC.FRLC_multimarginal as FRLC_multimarginal
import src.HiddenMarkovOT

import src.utils.util_LR as util_LR
from src.utils.util_LR import convert_adata


idxs = [1]

In [17]:
import random
import numpy as np

def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    return

seed_everything(42)

In [18]:
import pandas as pd
import importlib

import src.utils.clustering
import src.HiddenMarkovOT as HiddenMarkovOT
import src.plotting as plotting

import pickle
import numpy as np
import differentiation_map_validation as dmv
importlib.reload(dmv)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Filehandle for final differentiation map outputs
diffmap_dir = "/scratch/gpfs/ph3641/hm_ot/ME_unsupervised_diffmap/"
os.makedirs(diffmap_dir, exist_ok=True)


# Cell-type labels for all timepoints + replicates
df_cell = pd.read_csv("/scratch/gpfs/ph3641/hm_ot/df_cell.csv")
df_cell = df_cell.set_index("cell_id")

edge_dir = "/scratch/gpfs/ph3641/hm_ot/edges.txt"
node_dir = "/scratch/gpfs/ph3641/hm_ot/nodes.txt"

# Load the data
nodes_df = pd.read_csv(node_dir, sep="\t")
edges_df = pd.read_csv(edge_dir, sep="\t")
G, labels_G = dmv.yield_differentiation_graph(nodes_df, edges_df, plotting=False)

diffmap_dir = "/scratch/gpfs/ph3641/hm_ot/ME_supervised_diffmap/"
diffmap_dir_moscot = "/scratch/gpfs/ph3641/hm_ot/ME_supervised_diffmap_moscot/"


In [19]:
import os
import numpy as np
import pickle
import importlib
importlib.reload(plotting)

timepoints =  ['E8.5', 'E8.75', 'E9.0', 'E9.25', 'E9.5', 'E9.75']
replicates = ['embryo_11', 'embryo_14', 'embryo_16', 'embryo_20', 'embryo_24', 'embryo_28']

_Ts = []
_Ts_m = []

_labels = []

for i in range(len(timepoints) - 1):
    
    t1 = timepoints[i]
    t2 = timepoints[i + 1]
    print(f'loading {t1} to {t2}')
    
    # Load T matrices (diff map for HM-OT supervised)
    T12 = np.load(os.path.join(diffmap_dir, f"{t1}_{t2}_T.npy"))
    _Ts.append(T12)
    
    #  Load T matrices (moscot)
    T12 = np.load(os.path.join(diffmap_dir_moscot, f"{t1}_{t2}_T.npy"))
    _Ts_m.append(T12)
    
    # Load labels
    with open(os.path.join(diffmap_dir, f"{t1}_types.pkl"), 'rb') as f:
        Label_t1 = pickle.load(f)
    with open(os.path.join(diffmap_dir, f"{t2}_types.pkl"), 'rb') as f:
        Label_t2 = pickle.load(f)
    
    if i == 0:
        _labels.append(Label_t1)
    _labels.append(Label_t2)


loading E8.5 to E8.75
loading E8.75 to E9.0
loading E9.0 to E9.25
loading E9.25 to E9.5
loading E9.5 to E9.75


In [20]:


def score_triples(Qs, Ts, labels, times, G, labels_G, edges_df):
    T13 = Ts[0] @ np.diag( 1/ np.sum(Qs[1], axis=0) ) @ Ts[1]
    edge_scores, diagonal_edge_scores = dmv.score_from_graph([Qs[0], Qs[2]], [T13], [labels[0], labels[2]],[times[0], times[2]],
                                 G, labels_G, edges_df,
                                )
    npmis = []
    diagonal_npmis = []
    
    for key in edge_scores:
        npmis.append(edge_scores[key])
    for key in diagonal_edge_scores:
        diagonal_npmis.append(diagonal_edge_scores[key])
    
    print(f'Median NPMI (off-diagonal): {np.median(npmis)}')
    print(f'Median NPMI (diagonal): {np.median(diagonal_npmis)}')
    print(f'Mean NPMI (off-diagonal): {np.mean(npmis)}')
    print(f'Mean NPMI (diagonal): {np.mean(diagonal_npmis)}')
    
    return edge_scores, diagonal_edge_scores

def merge_dicts_of_lists(dict1, dict2):
    merged = dict1.copy()
    for key, value in dict2.items():
        if key in merged:
            merged[key].extend(value)
        else:
            merged[key] = value.copy()
    return merged


In [21]:

filehandle_ME = f'/scratch/gpfs/ph3641/hm_ot/adata_JAX_dataset_1.h5ad'
sys.path.insert(0, filehandle_ME)
adata = sc.read_h5ad(filehandle_ME, backed="r")

save_dir = '/scratch/gpfs/ph3641/mouse_embryo/SC_pca_pairs/'
os.makedirs(save_dir, exist_ok=True)


for i in range(1, len(timepoints) - 1):
    t1, t2, t3 = timepoints[i-1], timepoints[i], timepoints[i+1]
    r1, r2, r3 = replicates[i-1], replicates[i], replicates[i+1]
    
    fname = f"{save_dir}/subset_{t1}_{t2}_{t3}_r{r1}{r2}{r3}.h5ad"
    if not os.path.exists(fname):
        subset_adata = adata[
            (adata.obs['day'].isin([t1, t2, t3])) &
            (adata.obs['embryo_id'].isin([r1, r2, r3]))
        ]
        subset_adata = subset_adata.to_memory()
        subset_adata.obs = subset_adata.obs.set_index("cell_id")
        subset_adata.obs = subset_adata.obs.join(df_cell[['celltype_update']], how="left")
        print('-----Starting PCA!-----')
        sc.pp.normalize_total(subset_adata, target_sum=1e4)
        sc.pp.log1p(subset_adata)
        sc.pp.pca(subset_adata, n_comps=30)
        print('-----PCA done!-----')
        subset_adata.write(fname, compression="gzip")
        print("Wrote", fname)



In [None]:
importlib.reload(util_LR)
importlib.reload(HiddenMarkovOT)
importlib.reload(FRLC)
importlib.reload(FRLC.FRLC_multimarginal)
importlib.reload(dmv)

import src.utils.clustering as clustering
import differentiation_map_validation as dmv
from src.FRLC.FRLC_multimarginal import FRLC_LR_opt_multimarginal
from sklearn.metrics import adjusted_mutual_info_score as ami
from sklearn.metrics import adjusted_rand_score as ari
import copy

"""
Loading file / AnnData
"""
filehandle_ME = f'/scratch/gpfs/ph3641/hm_ot/adata_JAX_dataset_1.h5ad'
sys.path.insert(0, filehandle_ME)
adata = sc.read_h5ad(filehandle_ME, backed="r")

moscot_diag = {}
moscot_offdiag = {}
hmot_s_diag = {}
hmot_s_offdiag = {}
hmot_u_diag = {}
hmot_u_offdiag = {}

for i in range(1, len(timepoints) - 1):
    
    t1, t2, t3 = timepoints[i-1], timepoints[i], timepoints[i+1]
    r1, r2, r3 = replicates[i-1], replicates[i], replicates[i+1]
    
    fname = f"{save_dir}/subset_{t1}_{t2}_{t3}_r{r1}{r2}{r3}.h5ad"
    subset_adata = sc.read_h5ad(fname)
    
    """
    rank1 = subset_adata[subset_adata.obs['day'] == t1].obs['celltype_update'].nunique()
    rank2 = subset_adata[subset_adata.obs['day'] == t2].obs['celltype_update'].nunique()
    rank3 = subset_adata[subset_adata.obs['day'] == t3].obs['celltype_update'].nunique()
    """
    
    r_max = 150
    C_factors_sequence, A_factors_sequence, Qs, labels, rank_list, _spatial_list = convert_adata(subset_adata,
                                                                                                 timepoints=[t1, t2, t3],
                                                                                                 replicates=[r1, r2, r3],
                                                                                                 timepoint_key = 'day',
                                                                                                 replicate_key = 'embryo_id',
                                                                                                 feature_key = 'X_pca',
                                                                                                cell_type_key = 'celltype_update',
                                                                                                spatial = False,
                                                                                                dist_eps = 0.02,
                                                                                                dist_rank = r_max, 
                                                                                                device = device)
    
    proportions = [torch.sum(torch.tensor(Qs[i]).to(torch.float32), axis=0) for i in range(len(Qs))]
    
    # Annotation types
    Qs_ann = [Q.cpu().numpy() for Q in Qs]
    
    """
    Supervised HM-OT
    """
    hmot_sup = HiddenMarkovOT.HM_OT(rank_list, 
                                tau_in = 1e-3,
                                tau_out = 1e-3,
                                gamma= 7,
                                max_iter= 20,
                                min_iter= 20,
                                device=device,
                                dtype=torch.float32,
                                printCost=False,
                                returnFull=False,
                                alpha=0.0,
                                initialization='Full'
                               )
    
    hmot_sup.impute_annotated_transitions(C_factors_sequence, 
                                     A_factors_sequence, 
                                     copy.deepcopy(Qs))
    """
    Unsupervised HM-OT
    """
    hmot = HiddenMarkovOT.HM_OT(
                                rank_list, 
                                tau_in = 1e-3,
                                tau_out = 1e-3,
                                gamma= 80,
                                max_iter= 50,
                                min_iter= 50,
                                device=device,
                                dtype=torch.float32,
                                printCost=False,
                                returnFull=False,
                                alpha=0.0,
                                initialization='Full',
                                proportions = proportions,
                                max_inner_iters_B = 200,
                                max_inner_iters_R = 200,
                               )

    '''
    hmot.gamma_smoothing(C_factors_sequence, A_factors_sequence, 
                         Qs_freeze = [False, False, False],
                         Qs_IC = [Qs[0], Qs[1], Qs[2]], 
                         warmup = True)'''

    # Evaluate middle to interpret
    hmot.gamma_smoothing(C_factors_sequence, A_factors_sequence, 
                         Qs_IC = [Qs[0], Qs[1], Qs[2]],
                         Qs_freeze = [True, False, True],
                        warmup = False)
    
    # Unsupervised types and transitions
    Qs_u = [Q_.cpu().numpy() for Q_ in hmot.Q_gammas]
    Ts = [T.cpu().numpy() for T in hmot.T_gammas]
    
    # Moscot and hmot supervised transitions
    Ts_m = [T for T in _Ts_m[i-1:i+1]]
    Ts_s = [T.cpu().numpy() for T in hmot_sup.T_gammas]
    
    # Plot diffmap
    plotting.diffmap_from_QT([ Qs_ann[0], Qs_ann[1], Qs_ann[2] ],
                         Ts_s,
                         [labels[0], labels[1], labels[2]],
                         dsf=0.01,
                        fontsize=5,
                        linethick_factor=55,
                        title=f"Alignment from {timepoints[0]} to {timepoints[-1]}",
                        save_name=os.path.join(f'/scratch/gpfs/ph3641/hm_ot/ME_supervised_figs/supervised_{timepoints[0]}_{timepoints[-1]}.svg') )
    
    plotting.diffmap_from_QT([ Qs_ann[0], Qs_u[1], Qs_ann[2] ],
                         Ts,
                         [None, None, None],
                         dsf=0.01,
                        fontsize=5,
                        linethick_factor=55,
                        title=f"Alignment from {timepoints[0]} to {timepoints[-1]}",
                        save_name=os.path.join(f'/scratch/gpfs/ph3641/hm_ot/ME_supervised_figs/unsupervised_{timepoints[0]}_{timepoints[-1]}.svg') )
    
    plotting.diffmap_from_QT([ Qs_ann[0], Qs_ann[1], Qs_ann[2] ],
                         Ts_m,
                         [labels[0], labels[1], labels[2]],
                         dsf=0.01,
                        fontsize=5,
                        linethick_factor=55,
                        title=f"Alignment from {timepoints[0]} to {timepoints[-1]}",
                        save_name=os.path.join(f'/scratch/gpfs/ph3641/hm_ot/ME_supervised_figs/moscot_{timepoints[0]}_{timepoints[-1]}.svg') )
    
    clus_u = clustering.max_likelihood_clustering([Qs_u[1]], mode='standard')[0]
    clus_ann = clustering.max_likelihood_clustering([Qs_ann[1]], mode='standard')[0]
    
    print(f"ami of predictions: {ami(clus_u, clus_ann):.3f} for slice {i} (gamma)")
    print(f"ari of predictions: {ari(clus_u, clus_ann):.3f} for slice {i} (gamma)")
    
    

Computing low-rank distance matrix!
Computing low-rank distance matrix!


  proportions = [torch.sum(torch.tensor(Qs[i]).to(torch.float32), axis=0) for i in range(len(Qs))]


Iteration: 0
Iteration: 0
Iteration: 0
Iteration: 25
Iteration: 0
Iteration: 25
Iteration: 0
