In [None]:
import scanpy as sc
import anndata as ad
import numpy as np

import sys
import os
import importlib
# cuBLAS workspace setting. Presumably exported before `import torch` so that it
# is already in place when CUDA is initialised; PyTorch's reproducibility notes
# list ":4096:8" as one of the values required for deterministic cuBLAS
# behaviour — TODO confirm this ordering is intentional.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
# Add the parent directory to sys.path so that the `src.*` imports that follow
# can be resolved (assumes the notebook sits one level below the project root
# — TODO confirm).
sys.path.append(os.path.abspath(".."))

import src.HiddenMarkovOT as HiddenMarkovOT
import src.utils.util_LR as util_LR
from src.utils.util_LR import convert_adata
import src.plotting as plotting


In [None]:
import random

def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Parameters
    ----------
    seed : int, default 42
        Value passed to every seeding call below and written (as a string)
        to the ``PYTHONHASHSEED`` environment variable.

    Returns
    -------
    None
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every generator: Python, NumPy, torch (CPU), torch (CUDA).
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    # Ask cuDNN for deterministic kernels and turn off its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

In [None]:

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'On device: {device}')

# Notebook-level floating point precision.
dtype = torch.float64


In [None]:

# --- root directory that holds one sub-folder per slice pair
base = '/scratch/gpfs/ph3641/zebrafish/'

# --- zebrafish slice labels, in the order they are paired up
zf_names = [
    'zf3', 'zf5', 'zf10',
    'zf12', 'zf18', 'zf24',
]
N = len(zf_names)

# container for the AnnData objects of consecutive slice pairs
adata_pairs = []


In [None]:
print('Loading zebrafish AnnDatas')

# Start from an empty list inside this cell so that re-running it does not
# append a second copy of every pair.
adata_pairs = []

# NOTE: the per-pair object is called `adata`, not `ad`; `ad` is the alias of
# the `anndata` module imported at the top of the notebook and must not be
# rebound to a data object.
for i in range(N-1):           # 0,1,2,...,N-2
    s1, s2 = zf_names[i], zf_names[i+1]
    pair_dir = os.path.join(base, f'pair{i}')
    fname    = f'{s1}_{s2}.h5ad'

    print(f'Slice pair {i}: {s1} to {s2}')
    adata = sc.read_h5ad(os.path.join(pair_dir, fname))

    # Total-count normalisation -> log1p -> 30-component PCA.
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    sc.pp.pca(adata, n_comps=30)

    # free memory: drop the expression matrix and the 'count' layer (if any)
    # once the PCA has been computed
    adata.X = None
    if 'count' in adata.layers:
        del adata.layers['count']

    adata_pairs.append(adata)

print('PCA Finished!')

# save one file per consecutive slice pair
save_dir = '/scratch/gpfs/ph3641/zebrafish/pca_pairs/'
os.makedirs(save_dir, exist_ok=True)

for adata, (s1, s2) in zip(adata_pairs, zip(zf_names, zf_names[1:])):
    out_name = f'zebrafish_pair_{s1}_{s2}.h5ad'
    adata.write(os.path.join(save_dir, out_name))


In [None]:

# Directory holding the PCA-processed pair files.
load_dir = '/scratch/gpfs/ph3641/zebrafish/pca_pairs/'

adata_pairs = []

# Read the pairs back in slice order. The path is built from `load_dir`, so
# this cell runs on a fresh kernel without first executing the (expensive)
# PCA/save cell. The loaded object is called `adata` to avoid shadowing the
# `anndata` module alias `ad`.
for (s1, s2) in zip(zf_names, zf_names[1:]):
    pair_name = f'zebrafish_pair_{s1}_{s2}.h5ad'
    pair_path = os.path.join(load_dir, pair_name)
    adata = sc.read_h5ad(pair_path)
    adata_pairs.append(adata)



In [None]:
# Display the first loaded slice pair as a quick sanity check.
adata_pairs[0]

In [None]:
# Inspect the per-observation `time` column of the second slice pair.
adata_pairs[1].obs['time']

In [None]:

from src.utils.util_LR import convert_adata_pairwise

# The three consecutive timepoints analysed in this notebook.
timepoints = ['10hpf',
              '12hpf',
              '18hpf']

# Pairs at index 2 and 3; with the slice order zf3, zf5, zf10, zf12, zf18, zf24
# these are the zf10->zf12 and zf12->zf18 pairs.
# The selection is bound to a NEW name instead of overwriting `adata_pairs`:
# overwriting made this cell non-idempotent (a second run would index into the
# already-shortened two-element list and raise IndexError).
adata_pairs_selected = [adata_pairs[2], adata_pairs[3]]

# Set torch device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Convert AnnData in HM-OT ready factors
(C_factors_sequence,
 A_factors_sequence,
 Qs,
 labels,
 rank_list,
 spatial_list) = convert_adata_pairwise(
    adata_pairs_selected,
    timepoints=timepoints,
    timepoint_key='time',
    replicate_key='time',
    feature_key='X_pca',
    cell_type_key='bin_annotation',
    spatial=True,
    spatial_key=['spatial_x', 'spatial_y'],
    fallback_spatial_key='spatial',
    dist_eps=0.02,
    dist_rank=100,
    dist_rank_2=50,
    device=device,
    normalize=True,
)


In [None]:
import src.utils.clustering as clustering
from sklearn.metrics import adjusted_mutual_info_score as ami
from sklearn.metrics import adjusted_rand_score as ari

# Pick up any edits made to the HM-OT module since it was first imported.
importlib.reload(HiddenMarkovOT)

# Number of independently seeded restarts; the run with the lowest total cost
# is kept in `hmot_partial`.
num_iter = 2
running_min_cost = torch.inf
hmot_partial = None

for i in range(num_iter):

    # Set a replicable random seed (a different one per restart)
    gen = torch.Generator(device=device)
    gen.manual_seed(i + 42)

    # Initialize HM-OT
    # NOTE(review): float32 is passed here although the notebook-level `dtype`
    # is torch.float64 — confirm which precision is intended.
    hmot_p = HiddenMarkovOT.HM_OT(
        rank_list=rank_list,
        tau_in=1e-3,
        tau_out=1e-3,
        gamma=80,
        max_iter=100,
        min_iter=100,
        device=device,
        alpha=0.5,
        dtype=torch.float32,
        printCost=True,
        returnFull=False,
        initialization='Full',
        generator=gen,
    )

    # Run HM-OT with the first and third Qs supplied as initial conditions and
    # frozen; the middle one is left free (None / not frozen).
    hmot_p.gamma_smoothing(
        C_factors_sequence,
        A_factors_sequence,
        Qs_IC=[Qs[0], None, Qs[2]],
        Qs_freeze=[True, False, True],
        warmup=False,
    )

    # Evaluate the cost once and keep the best run. The running minimum has to
    # be updated here, otherwise every run compares against +inf and the last
    # run always wins regardless of its cost.
    cost = hmot_p.compute_total_cost(C_factors_sequence, A_factors_sequence)
    if cost < running_min_cost:
        running_min_cost = cost
        hmot_partial = hmot_p

# A NaN cost never compares as smaller, so guard against "no run selected"
# instead of failing later with a confusing error (or silently reusing a stale
# `hmot_partial` left in the kernel).
if hmot_partial is None:
    raise RuntimeError('No HM-OT run produced a finite, comparable total cost.')


In [None]:


# Extract outputs: move every factor to host memory as a NumPy array.
# (Each conversion is done once; the original cell repeated the two HM-OT
# conversions verbatim.)
Qs_ann = [Q.cpu().numpy() for Q in Qs]
Qs_hmot_partial = [Q.cpu().numpy() for Q in hmot_partial.Q_gammas]
Ts_hmot_partial = [T.cpu().numpy() for T in hmot_partial.T_gammas]

# Cluster assignments from the annotation-derived Qs and from the HM-OT output.
clusterings_ann = clustering.max_likelihood_clustering(Qs_ann, mode='standard')
clusterings_hmot = clustering.reference_clustering(
    Qs=Qs_hmot_partial,
    Ts=Ts_hmot_partial,
    reference_index=2,
)

# Agreement between annotated and predicted labels for every slice except the
# last one.
for i, gt_types in enumerate(clusterings_ann[:-1]):

    print(f'slice {i}')

    pred_types = clusterings_hmot[i]

    print(f"ami of predictions: {ami(gt_types, pred_types):.3f} for slice {i} (gamma)")
    print(f"ari of predictions: {ari(gt_types, pred_types):.3f} for slice {i} (gamma)")

# Visualize the inferred clusters on the slices.
# NOTE(review): the clustering above uses reference_index=2 while this plot
# uses reference_index=0 and passes labels[0] for all three slices — confirm
# that the mismatch is intended.
plotting.plot_clusters_from_QT(spatial_list, Qs_hmot_partial,
                               Ts_hmot_partial, [labels[0], labels[0], labels[0]], clustering_type='reference',
                               reference_index=0, flip=False, dotsize=40)
