In [1]:
import os
from os.path import join as pj
import argparse
import sys
sys.path.append("modules")
import utils
import numpy as np
import torch as th
import scib
import scib.metrics as me
import anndata as ad
import scipy
import pandas as pd
import re
import itertools
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score
from scipy.stats import pearsonr
import copy
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from operator import itemgetter
from collections import Counter
from tqdm import tqdm

In [2]:
# Command-line options for this analysis; every option is a plain string.
# parse_known_args() is used instead of parse_args() so the cell also runs inside
# an interactive kernel, whose argv carries extra, unrelated flags.
_option_defaults = (
    ('task', 'atlas_tissues_cl'),
    ('real_task', 'atlas_tissues'),
    ('reference', 'atlas_new'),
    ('experiment', 'c_3'),
    ('real_experiment', 'offline'),
    ('model', 'default'),
    ('init_model', ''),
    ('init_model_real', 'sp_latest'),
    ('method', 'midas_embed'),
)
parser = argparse.ArgumentParser()
for _flag, _default in _option_defaults:
    parser.add_argument('--' + _flag, type=str, default=_default)
del _flag, _default, _option_defaults
o, _ = parser.parse_known_args()

In [3]:
# Load latent variables of reference data.
# Two runs are loaded: the one named by o.task / o.experiment, and the one named
# by o.real_task / o.real_experiment. For each, the first `dim_c` dimensions of
# z["joint"] are taken and a kNN graph is built over all cells.
cfg_task = re.sub("_atlas|_generalize|_transfer|_ref_.*", "", o.task)
data_config = utils.gen_data_config(cfg_task)
# NOTE(review): data_config_ref is assembled here but not read by any cell shown;
# kept in case later code depends on it.
data_config_ref = utils.gen_data_config(o.reference)
data_config_ref["raw_data_dirs"] += data_config["raw_data_dirs"]
data_config_ref["raw_data_frags"] += data_config["raw_data_frags"]
data_config_ref["combs"] = data_config["combs"]
data_config_ref["comb_ratios"] = data_config["comb_ratios"]
data_config_ref["s_joint"] = data_config["s_joint"]

# Copy the task's data settings and the model settings onto `o` as attributes.
data_config = utils.load_toml("configs/data.toml")[cfg_task]
for k, v in data_config.items():
    vars(o)[k] = v
model_config = utils.load_toml("configs/model.toml")["default"]
if o.model != "default":
    model_config.update(utils.load_toml("configs/model.toml")[o.model])
for k, v in model_config.items():
    vars(o)[k] = v
o.s_joint, o.combs, *_ = utils.gen_all_batch_ids(o.s_joint, o.combs)
o.pred_dir = pj("result", o.task, o.experiment, o.model, "predict", o.init_model)
pred = utils.load_predicted(o, group_by="subset")

# The first `subset_num` subsets are treated as the reference; the remaining ones
# as query ("orgs"). NOTE(review): hard-coded — confirm it matches the number of
# subsets in the reference dataset.
subset_num = 34

c = [v["z"]["joint"][:, :o.dim_c] for v in pred.values()]
c_ref = np.concatenate(c[:subset_num], axis=0)
c_orgs = np.concatenate(c[subset_num:], axis=0)
# Query rows first, then reference rows: the overlap cells below rely on the
# query cells occupying rows 0 .. len(c_orgs)-1.
c_all = np.concatenate([c_orgs, c_ref], axis=0)

o_real = copy.deepcopy(o)
o_real.task = o.real_task
o_real.pred_dir = pj("result", o_real.task, o.real_experiment, o.model, "predict", o.init_model_real)
pred_real = utils.load_predicted(o_real, group_by="subset")

c_real = [v["z"]["joint"][:, :o_real.dim_c] for v in pred_real.values()]
c_ref_real = np.concatenate(c_real[:subset_num], axis=0)
c_orgs_real = np.concatenate(c_real[subset_num:], axis=0)
# Same (query, reference) row order as c_all.
c_all_real = np.concatenate([c_orgs_real, c_ref_real], axis=0)

# Neighbourhoods are compared row-by-row between the two spaces, so both must
# describe the same cells in the same order. Fail fast before the expensive kNN.
assert c_all.shape[0] == c_all_real.shape[0], (c_all.shape, c_all_real.shape)
assert c_orgs.shape[0] == c_orgs_real.shape[0], (c_orgs.shape, c_orgs_real.shape)

# Must be >= the largest neighbourhood size analysed below. kneighbors() returns
# two (n_cells, N_NEIGHBORS) arrays per space (distances and indices). These are
# very large for ~273k cells; NOTE(review): pass return_distance=False if the
# distances are never needed.
N_NEIGHBORS = 15000
nbrs_real = NearestNeighbors(n_neighbors=N_NEIGHBORS, algorithm='ball_tree').fit(c_all_real)
distances_real, indices_real = nbrs_real.kneighbors(c_all_real)

nbrs_all = NearestNeighbors(n_neighbors=N_NEIGHBORS, algorithm='ball_tree').fit(c_all)
distances_all, indices_all = nbrs_all.kneighbors(c_all)
indices_all

100%|██████████| 29/29 [00:00<00:00, 380.38it/s]
100%|██████████| 24/24 [00:00<00:00, 365.49it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Loading predicted variables ...
Loading subset 0: z, joint
Loading subset 1: z, joint
Loading subset 2: z, joint


100%|██████████| 40/40 [00:00<00:00, 329.40it/s]
100%|██████████| 38/38 [00:00<00:00, 360.89it/s]
100%|██████████| 29/29 [00:00<00:00, 343.00it/s]
  0%|          | 0/26 [00:00<?, ?it/s]

Loading subset 3: z, joint
Loading subset 4: z, joint
Loading subset 5: z, joint


100%|██████████| 26/26 [00:00<00:00, 389.54it/s]
100%|██████████| 27/27 [00:00<00:00, 352.14it/s]
100%|██████████| 27/27 [00:00<00:00, 354.04it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

Loading subset 6: z, joint
Loading subset 7: z, joint
Loading subset 8: z, joint


100%|██████████| 28/28 [00:00<00:00, 108.26it/s]
100%|██████████| 24/24 [00:00<00:00, 355.53it/s]
100%|██████████| 29/29 [00:00<00:00, 397.72it/s]
  0%|          | 0/39 [00:00<?, ?it/s]

Loading subset 9: z, joint
Loading subset 10: z, joint
Loading subset 11: z, joint


100%|██████████| 39/39 [00:00<00:00, 348.34it/s]
100%|██████████| 38/38 [00:00<00:00, 373.51it/s]
  0%|          | 0/44 [00:00<?, ?it/s]

Loading subset 12: z, joint
Loading subset 13: z, joint


100%|██████████| 44/44 [00:00<00:00, 365.22it/s]
100%|██████████| 11/11 [00:00<00:00, 397.16it/s]
100%|██████████| 17/17 [00:00<00:00, 344.22it/s]
100%|██████████| 21/21 [00:00<00:00, 387.79it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Loading subset 14: z, joint
Loading subset 15: z, joint
Loading subset 16: z, joint
Loading subset 17: z, joint


100%|██████████| 20/20 [00:00<00:00, 359.73it/s]
100%|██████████| 15/15 [00:00<00:00, 353.73it/s]
100%|██████████| 25/25 [00:00<00:00, 382.22it/s]
100%|██████████| 24/24 [00:00<00:00, 363.28it/s]
  0%|          | 0/19 [00:00<?, ?it/s]

Loading subset 18: z, joint
Loading subset 19: z, joint
Loading subset 20: z, joint
Loading subset 21: z, joint


100%|██████████| 19/19 [00:00<00:00, 329.25it/s]
100%|██████████| 21/21 [00:00<00:00, 383.17it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

Loading subset 22: z, joint
Loading subset 23: z, joint


100%|██████████| 28/28 [00:00<00:00, 80.10it/s]
100%|██████████| 24/24 [00:00<00:00, 364.60it/s]
100%|██████████| 35/35 [00:00<00:00, 373.45it/s]
  0%|          | 0/35 [00:00<?, ?it/s]

Loading subset 24: z, joint
Loading subset 25: z, joint
Loading subset 26: z, joint


100%|██████████| 35/35 [00:00<00:00, 353.88it/s]
100%|██████████| 22/22 [00:00<00:00, 397.52it/s]
100%|██████████| 20/20 [00:00<00:00, 355.73it/s]
  0%|          | 0/47 [00:00<?, ?it/s]

Loading subset 27: z, joint
Loading subset 28: z, joint
Loading subset 29: z, joint


100%|██████████| 47/47 [00:00<00:00, 367.57it/s]
100%|██████████| 58/58 [00:00<00:00, 332.67it/s]
  0%|          | 0/39 [00:00<?, ?it/s]

Loading subset 30: z, joint
Loading subset 31: z, joint


100%|██████████| 39/39 [00:00<00:00, 349.93it/s]
100%|██████████| 52/52 [00:00<00:00, 366.45it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

Loading subset 32: z, joint
Loading subset 33: z, joint


100%|██████████| 28/28 [00:00<00:00, 362.10it/s]
100%|██████████| 22/22 [00:00<00:00, 386.77it/s]
100%|██████████| 42/42 [00:00<00:00, 341.71it/s]
  0%|          | 0/17 [00:00<?, ?it/s]

Loading subset 34: z, joint
Loading subset 35: z, joint
Loading subset 36: z, joint


100%|██████████| 17/17 [00:00<00:00, 375.52it/s]


Converting to numpy ...
Converting subset 0: s, joint
Converting subset 0: z, joint
Converting subset 1: s, joint
Converting subset 1: z, joint
Converting subset 2: s, joint
Converting subset 2: z, joint
Converting subset 3: s, joint
Converting subset 3: z, joint
Converting subset 4: s, joint
Converting subset 4: z, joint
Converting subset 5: s, joint
Converting subset 5: z, joint
Converting subset 6: s, joint
Converting subset 6: z, joint
Converting subset 7: s, joint
Converting subset 7: z, joint
Converting subset 8: s, joint
Converting subset 8: z, joint
Converting subset 9: s, joint
Converting subset 9: z, joint
Converting subset 10: s, joint
Converting subset 10: z, joint
Converting subset 11: s, joint
Converting subset 11: z, joint
Converting subset 12: s, joint
Converting subset 12: z, joint
Converting subset 13: s, joint
Converting subset 13: z, joint
Converting subset 14: s, joint
Converting subset 14: z, joint
Converting subset 15: s, joint
Converting subset 15: z, joint
Conv

Converting subset 22: s, joint
Converting subset 22: z, joint
Converting subset 23: s, joint
Converting subset 23: z, joint
Converting subset 24: s, joint
Converting subset 24: z, joint
Converting subset 25: s, joint
Converting subset 25: z, joint
Converting subset 26: s, joint
Converting subset 26: z, joint
Converting subset 27: s, joint
Converting subset 27: z, joint
Converting subset 28: s, joint
Converting subset 28: z, joint
Converting subset 29: s, joint
Converting subset 29: z, joint
Converting subset 30: s, joint
Converting subset 30: z, joint
Converting subset 31: s, joint
Converting subset 31: z, joint


 38%|███▊      | 11/29 [00:00<00:00, 84.14it/s]

Converting subset 32: s, joint
Converting subset 32: z, joint
Converting subset 33: s, joint
Converting subset 33: z, joint
Converting subset 34: s, joint
Converting subset 34: z, joint
Converting subset 35: s, joint
Converting subset 35: z, joint
Converting subset 36: s, joint
Converting subset 36: z, joint
Loading predicted variables ...
Loading subset 0: z, joint


100%|██████████| 29/29 [00:00<00:00, 166.19it/s]
100%|██████████| 24/24 [00:00<00:00, 357.84it/s]
100%|██████████| 40/40 [00:00<00:00, 351.14it/s]
  0%|          | 0/38 [00:00<?, ?it/s]

Loading subset 1: z, joint
Loading subset 2: z, joint
Loading subset 3: z, joint


100%|██████████| 38/38 [00:00<00:00, 356.02it/s]
100%|██████████| 29/29 [00:00<00:00, 353.16it/s]
100%|██████████| 26/26 [00:00<00:00, 354.63it/s]
  0%|          | 0/27 [00:00<?, ?it/s]

Loading subset 4: z, joint
Loading subset 5: z, joint
Loading subset 6: z, joint


100%|██████████| 27/27 [00:00<00:00, 351.37it/s]
100%|██████████| 27/27 [00:00<00:00, 382.05it/s]
100%|██████████| 28/28 [00:00<00:00, 356.51it/s]
  0%|          | 0/24 [00:00<?, ?it/s]

Loading subset 7: z, joint
Loading subset 8: z, joint
Loading subset 9: z, joint


100%|██████████| 24/24 [00:00<00:00, 350.73it/s]
100%|██████████| 29/29 [00:00<00:00, 350.11it/s]
100%|██████████| 39/39 [00:00<00:00, 352.65it/s]


Loading subset 10: z, joint
Loading subset 11: z, joint
Loading subset 12: z, joint


100%|██████████| 38/38 [00:00<00:00, 121.09it/s]
100%|██████████| 44/44 [00:00<00:00, 358.69it/s]
100%|██████████| 11/11 [00:00<00:00, 326.21it/s]
  0%|          | 0/17 [00:00<?, ?it/s]

Loading subset 13: z, joint
Loading subset 14: z, joint
Loading subset 15: z, joint


100%|██████████| 17/17 [00:00<00:00, 381.15it/s]
100%|██████████| 21/21 [00:00<00:00, 345.96it/s]
100%|██████████| 20/20 [00:00<00:00, 341.63it/s]
100%|██████████| 15/15 [00:00<00:00, 392.45it/s]
  0%|          | 0/25 [00:00<?, ?it/s]

Loading subset 16: z, joint
Loading subset 17: z, joint
Loading subset 18: z, joint
Loading subset 19: z, joint


100%|██████████| 25/25 [00:00<00:00, 353.42it/s]
100%|██████████| 24/24 [00:00<00:00, 373.16it/s]
100%|██████████| 19/19 [00:00<00:00, 352.25it/s]
100%|██████████| 21/21 [00:00<00:00, 355.04it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

Loading subset 20: z, joint
Loading subset 21: z, joint
Loading subset 22: z, joint
Loading subset 23: z, joint


100%|██████████| 28/28 [00:00<00:00, 385.18it/s]
100%|██████████| 24/24 [00:00<00:00, 346.57it/s]
100%|██████████| 35/35 [00:00<00:00, 366.21it/s]
  0%|          | 0/35 [00:00<?, ?it/s]

Loading subset 24: z, joint
Loading subset 25: z, joint
Loading subset 26: z, joint


100%|██████████| 35/35 [00:00<00:00, 356.98it/s]
100%|██████████| 22/22 [00:00<00:00, 220.69it/s]
100%|██████████| 20/20 [00:00<00:00, 340.69it/s]
  0%|          | 0/47 [00:00<?, ?it/s]

Loading subset 27: z, joint
Loading subset 28: z, joint
Loading subset 29: z, joint


100%|██████████| 47/47 [00:00<00:00, 96.93it/s] 
100%|██████████| 58/58 [00:00<00:00, 340.71it/s]
  0%|          | 0/39 [00:00<?, ?it/s]

Loading subset 30: z, joint
Loading subset 31: z, joint


100%|██████████| 39/39 [00:00<00:00, 304.41it/s]
100%|██████████| 52/52 [00:00<00:00, 374.04it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

Loading subset 32: z, joint
Loading subset 33: z, joint


100%|██████████| 28/28 [00:00<00:00, 346.53it/s]
100%|██████████| 22/22 [00:00<00:00, 350.84it/s]
100%|██████████| 42/42 [00:00<00:00, 357.89it/s]
  0%|          | 0/17 [00:00<?, ?it/s]

Loading subset 34: z, joint
Loading subset 35: z, joint
Loading subset 36: z, joint


100%|██████████| 17/17 [00:00<00:00, 324.39it/s]


Converting to numpy ...
Converting subset 0: s, joint
Converting subset 0: z, joint
Converting subset 1: s, joint
Converting subset 1: z, joint
Converting subset 2: s, joint
Converting subset 2: z, joint
Converting subset 3: s, joint
Converting subset 3: z, joint
Converting subset 4: s, joint
Converting subset 4: z, joint
Converting subset 5: s, joint
Converting subset 5: z, joint
Converting subset 6: s, joint
Converting subset 6: z, joint
Converting subset 7: s, joint
Converting subset 7: z, joint
Converting subset 8: s, joint
Converting subset 8: z, joint
Converting subset 9: s, joint
Converting subset 9: z, joint
Converting subset 10: s, joint
Converting subset 10: z, joint
Converting subset 11: s, joint
Converting subset 11: z, joint
Converting subset 12: s, joint
Converting subset 12: z, joint
Converting subset 13: s, joint
Converting subset 13: z, joint
Converting subset 14: s, joint
Converting subset 14: z, joint
Converting subset 15: s, joint
Converting subset 15: z, joint
Conv

array([[     0,   2755, 156305, ...,  48745, 124476, 127633],
       [     1,   2231,   1436, ..., 216293,  98148,   3475],
       [     2,   2075,  16500, ...,  36751, 167059, 110644],
       ...,
       [273247, 272576, 272969, ..., 158523,  25479, 100087],
       [273248, 272000, 271372, ...,  81099, 144553,  58367],
       [273249, 272504, 273088, ...,  61377, 131481,  62091]])

In [None]:
# Neighbourhood overlap between the two latent spaces, restricted to the query
# cells (the first len(c_orgs) rows of c_all / c_all_real).
# For size j, each cell's neighbourhood is columns 1..j-1 of its kNN row (column 0
# is taken to be the cell itself), so j-1 neighbours are compared per cell.
neigh_size = [5, 10, 100, 1000, 2500, 5000, 10000, 15000]
n_query = len(c_orgs)

# A neighbour is shared at size j iff its position (counted after column 0) is
# < j-1 in BOTH rows, i.e. max(pos_all, pos_real) < j-1. Computing that max once
# per row answers every size in a single pass, instead of rebuilding Python sets
# of up to 15k elements for every (row, size) pair. Assumes each kNN row holds
# distinct indices, as kneighbors() returns.
cutoffs = np.asarray(neigh_size) - 1
pos_all = np.arange(indices_all.shape[1] - 1)
pos_real = np.arange(indices_real.shape[1] - 1)
absent = np.iinfo(np.int64).max  # rank for "not among this row's real-space neighbours"
rank_real = np.full(max(indices_all.shape[0], indices_real.shape[0]), absent, dtype=np.int64)
overlap_counts = np.zeros(len(neigh_size), dtype=np.int64)
for i in tqdm(range(n_query)):
    nbr_all = indices_all[i, 1:]
    nbr_real = indices_real[i, 1:]
    rank_real[nbr_real] = pos_real
    joint_pos = np.maximum(pos_all, rank_real[nbr_all])
    overlap_counts += (joint_pos[:, None] < cutoffs).sum(axis=0)
    rank_real[nbr_real] = absent  # reset only the entries touched for this row

overlap_num_list = [int(n) for n in overlap_counts]
sum_list = [n_query * (j - 1) for j in neigh_size]
overlap_percent = [n / s for n, s in zip(overlap_num_list, sum_list)]

In [None]:
# Neighbourhood overlap between the two latent spaces over ALL cells
# (query + reference rows of c_all / c_all_real).
# For size j, each cell's neighbourhood is columns 1..j-1 of its kNN row (column 0
# is taken to be the cell itself), so j-1 neighbours are compared per cell.
# Results use an `_all` suffix so the query-only lists are not overwritten.
neigh_size = [5, 10, 100, 1000, 2500, 5000, 10000, 15000]
n_cells = len(c_all)

# A neighbour is shared at size j iff its position (counted after column 0) is
# < j-1 in BOTH rows, i.e. max(pos_all, pos_real) < j-1. One pass per row answers
# every size. Assumes each kNN row holds distinct indices, as kneighbors() returns.
cutoffs = np.asarray(neigh_size) - 1
pos_all = np.arange(indices_all.shape[1] - 1)
pos_real = np.arange(indices_real.shape[1] - 1)
absent = np.iinfo(np.int64).max  # rank for "not among this row's real-space neighbours"
rank_real = np.full(max(indices_all.shape[0], indices_real.shape[0]), absent, dtype=np.int64)
overlap_counts_all = np.zeros(len(neigh_size), dtype=np.int64)
for i in tqdm(range(n_cells)):
    nbr_all = indices_all[i, 1:]
    nbr_real = indices_real[i, 1:]
    rank_real[nbr_real] = pos_real
    joint_pos = np.maximum(pos_all, rank_real[nbr_all])
    overlap_counts_all += (joint_pos[:, None] < cutoffs).sum(axis=0)
    rank_real[nbr_real] = absent  # reset only the entries touched for this row

overlap_num_list_all = [int(n) for n in overlap_counts_all]
sum_list_all = [n_cells * (j - 1) for j in neigh_size]
overlap_percent_all = [n / s for n, s in zip(overlap_num_list_all, sum_list_all)]

In [None]:
# Plot neighbourhood overlap against neighbourhood size, for the query cells and
# for all cells. The title and output file follow o.task rather than being
# hard-coded. With the default task 'atlas_tissues_cl', this yields the title
# 'atlas_tissues_cl' and the file 'overlap_tissues_cl.png'.
fig, ax = plt.subplots()
for spine in ax.spines.values():
    spine.set_linewidth(1.0)
# tick_params(labelsize=...) persists for ticks created later, unlike xticks(fontsize=...).
ax.tick_params(axis="both", which="major", width=1, length=3, labelsize=12)
ax.plot(neigh_size, overlap_percent, label='Query Neighborhood overlap', linewidth=1.5)
ax.plot(neigh_size, overlap_percent_all, label='Overall Neighborhood overlap', linewidth=1.5)

ax.set_xlabel('Neighborhood size', fontsize=12)
ax.set_ylabel('Neighborhood overlap', fontsize=12)
ax.set_title(o.task, fontsize=12)
ax.legend(prop={'size': 8}, frameon=False)
fig.savefig("overlap_" + re.sub("^atlas_", "", o.task) + ".png")
plt.show()


In [None]:
# Persist both overlap curves (query-only and all cells) as single-column CSVs
# under analysis/overlap. Rows follow the order of `neigh_size`.
# NOTE(review): file names say "tissues" while the plot uses the full task suffix;
# confirm this naming is intended.
result_dir = pj("analysis", "overlap")
utils.mkdirs(result_dir, remove_old=False)
df, df1 = pd.DataFrame(overlap_percent), pd.DataFrame(overlap_percent_all)
for frame, file_name in ((df, "overlap_percent_tissues.csv"),
                         (df1, "overlap_percent_all_tissues.csv")):
    frame.to_csv(pj(result_dir, file_name), index=False)