In [13]:
from hydra import initialize, compose
from omegaconf import OmegaConf
from pathlib import Path

import sys 
sys.path.insert(0, "../../src")

import numpy as np
import scipy.spatial.distance as spd
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns

from dataloader import SingleCellAndCodexDataset 
from model import FlowMatchingModelWrapper
from torch.utils.data import random_split
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

plt.rcParams['figure.figsize'] = (3, 3) 

In [2]:
adata_codex = sc.read_h5ad("/nfs/staff-hdd/pala/scportrait/data/codex/cellfeaturization_results_healthy_processed.h5ad")
adata_single_cell = sc.read_h5ad("/nfs/staff-hdd/pala/scportrait/data/scrnaseq/sce_converted_processed_discovery.h5ad")

In [3]:
(adata_codex.X[:, 0]==0).any()

np.False_

In [4]:
adata_single_cell.X

array([[2.87997273, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [5.13941058, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [2.2890183 , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [1.95297184, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [2.53255074, 0.        , 1.78594984, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]], shape=(263286, 55))

In [5]:
with initialize(config_path="../../configs/"):
    # Step 2: Compose the configuration
    config_dict = compose(config_name="train")  # replace 'config' with your actual config name

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path="../../configs/"):


In [6]:
dataset = SingleCellAndCodexDataset(config_dict.datamodule.rna_adata_path, 
                                        config_dict.datamodule.codex_adata_path, 
                                        config_dict.datamodule.label_columns, 
                                        config_dict.datamodule.obsm_key_rna, 
                                        config_dict.datamodule.obsm_key_codex) 

In [7]:
dataloader = torch.utils.data.DataLoader(dataset,
                                            batch_size=64,
                                            shuffle=True,
                                            num_workers=4)



In [18]:
distances = {}
for codex_cl in np.unique(adata_codex.obs.shared_leiden_cluster_id):
    distances[codex_cl] = {"RNA cell type": [], "Distance": []}
    for rna_cl in np.unique(adata_single_cell.obs.annotation_figure_1):
        distances[codex_cl]["RNA cell type"].append(rna_cl)
        X_codex = adata_codex[adata_codex.obs.shared_leiden_cluster_id==codex_cl].X
        X_rna = adata_single_cell[adata_single_cell.obs.annotation_figure_1==rna_cl].X
        ave_dist = np.matmul(X_codex, X_rna.T).mean()
        distances[codex_cl]["Distance"].append(ave_dist)

In [19]:
distances

{'0': {'RNA cell type': ['Activated NBC',
   'CD4 T',
   'CD8 T',
   'DC',
   'DN',
   'FDC',
   'GCBC',
   'Granulocytes',
   'ILC',
   'MBC',
   'Mast',
   'Mono/Macro',
   'NBC',
   'NK',
   'Naive CD4 T',
   'Naive CD8 T',
   'PC',
   'PDC',
   'cycling FDC',
   'cycling T',
   'cycling myeloid',
   'epithelial',
   'preB/T'],
  'Distance': [np.float64(0.14369722559737264),
   np.float64(0.22015643949299407),
   np.float64(0.19977377183045825),
   np.float64(0.15253882771839344),
   np.float64(0.20746984565572535),
   np.float64(0.11107734441493625),
   np.float64(0.14828930087261227),
   np.float64(0.0950939859959813),
   np.float64(0.18350662453760377),
   np.float64(0.1690949386629015),
   np.float64(0.11534250118241929),
   np.float64(0.12628834678898987),
   np.float64(0.1482847341235781),
   np.float64(0.1754566103955214),
   np.float64(0.18331112641057465),
   np.float64(0.16936823834708506),
   np.float64(0.06547655611748164),
   np.float64(0.13784242154134294),
   np.float