In [1]:
import torch
import numpy as np
import umap
import hdbscan
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import scipy
from collections import defaultdict
from sklearn.metrics.pairwise import euclidean_distances

In [2]:
# Attention coefficients and the matching edge index saved from hidden layer 2.
# map_location='cpu' keeps the notebook runnable on machines without the device
# the tensors were saved from; later cells call `.numpy()`, which needs CPU tensors.
l2_coeff = torch.load('PBMC3k/attn_layers_data/pbmc3k_hidden_layer_2_attention_coeff.pt', map_location='cpu')
l2_ei = torch.load('PBMC3k/attn_layers_data/pbmc3k_hidden_layer_2_attention_edge_index.pt', map_location='cpu')

In [3]:
# Shape check: one row of coefficients per edge (rows are ranked and used to index the edge list later).
l2_coeff.shape

torch.Size([86251, 16])

In [4]:
# Shape check: 2 x num_edges; the second dimension should equal the number of rows in l2_coeff.
l2_ei.shape

torch.Size([2, 86251])

In [5]:
# Pick the k edges with the largest attention weight.
#   dim: column of l2_coeff to rank by (presumably one column per attention head —
#        TODO confirm), or None to rank by the mean over all columns.
#   k:   number of top-weighted edges to keep.
dim = None
k = 80

edges = l2_ei.T  # one row of two node indices per edge
# Compare against None explicitly: `not dim` is also true for dim = 0, which
# would silently average every column instead of selecting column 0.
if dim is None:
    w = l2_coeff.mean(dim=1)
else:
    w = l2_coeff[:, dim]
w = w.squeeze()
# top_indices are row positions into `w`, and therefore into `edges` as well.
top_values, top_indices = torch.topk(w, k)

In [6]:
# Weights of the k selected edges, largest first.
top_values

tensor([0.8296, 0.8181, 0.8174, 0.8171, 0.8141, 0.8018, 0.7811, 0.7683, 0.7664,
        0.7474, 0.7367, 0.7322, 0.6623, 0.6453, 0.6392, 0.6340, 0.6263, 0.5891,
        0.5548, 0.5268, 0.5240, 0.5213, 0.5008, 0.4999, 0.4984, 0.4851, 0.4840,
        0.4592, 0.4496, 0.4424, 0.4256, 0.4206, 0.4192, 0.4157, 0.4047, 0.4020,
        0.3964, 0.3787, 0.3765, 0.3759, 0.3758, 0.3690, 0.3651, 0.3648, 0.3310,
        0.3236, 0.3208, 0.3204, 0.3201, 0.3148, 0.3122, 0.3115, 0.3107, 0.3062,
        0.2969, 0.2937, 0.2925, 0.2882, 0.2864, 0.2858, 0.2818, 0.2805, 0.2779,
        0.2755, 0.2753, 0.2745, 0.2731, 0.2669, 0.2648, 0.2540, 0.2479, 0.2464,
        0.2437, 0.2429, 0.2422, 0.2370, 0.2366, 0.2358, 0.2357, 0.2353],
       grad_fn=<TopkBackward>)

In [7]:
# Row positions of the selected edges within `w`; used to index `edges`.
top_indices

tensor([19996, 20127, 20158, 23743, 58699, 17332, 16783, 58766, 56862, 58800,
        65252, 56831, 55254, 50618, 16978, 55330,  7193, 20263, 23609, 55349,
         8141, 55098,  8312, 55394,  8070,  8029,  8334, 55446, 65374, 79625,
         8323, 79442, 20203,  8434,  8435, 65175, 55124, 23690, 23949, 55252,
        23800, 79782, 55191, 55406, 32974, 32635, 48259, 17049, 55150, 55165,
        56659, 17414, 50552,  8408,  7999, 16706, 47750,  8307,  8508, 65025,
        33007, 62380, 26156, 48499, 48132, 50631, 79916, 32801, 65264, 55222,
        32781, 37054, 38718, 48623, 37128, 62686,  8288, 19971, 17300, 16964])

In [8]:
# Cluster annotation table; the 'Cluster' column holds one label per row.
seurat_df = pd.read_csv('PBMC3k/seurat_clusters.csv')
cluster_column = seurat_df['Cluster']
cell_types = cluster_column.values.tolist()

In [9]:
# Node-index pairs of the k highest-weighted edges, shape (k, 2).
top_edges = edges[top_indices]

In [10]:
# Precomputed CellVGAE outputs: node embeddings and (judging by the file name) a UMAP projection.
# NOTE(review): `u` is not referenced again in the visible cells — confirm it is still needed.
latent = np.load('PBMC3k/CellVGAE/cellvgae_node_embs.npy')
u = np.load('PBMC3k/CellVGAE/cellvgae_umap.npy')

In [11]:
# Shape check: rows are later indexed by the node ids taken from the edge list.
latent.shape

(2638, 50)

In [12]:
# Embedding rows for every distinct node that appears in a top-k edge.
# torch.unique returns the ids sorted, which fixes the row order used downstream.
touched_node_ids = torch.unique(top_edges).numpy()
selected_cells = latent[touched_node_ids]

In [13]:
# (number of distinct nodes touched by the top-k edges, embedding size)
selected_cells.shape

(89, 50)

In [14]:
# Pairwise Euclidean distances between the selected embeddings; ed[i, j] is the distance between rows i and j.
ed = euclidean_distances(selected_cells, selected_cells)

In [16]:
# Square matrix: one row and one column per selected cell.
ed.shape

(89, 89)

In [15]:
# Cluster label of each selected cell, in the same sorted-unique node-id order as
# `selected_cells`. The index tensor is converted to NumPy first (as the `latent`
# lookup does) instead of relying on NumPy accepting a torch tensor as an index.
selected_cells_types = seurat_df['Cluster'].values[torch.unique(top_edges).numpy()]

In [61]:
# Number of selected cells; should match the first dimension of `selected_cells`.
len(selected_cells_types)

89

In [68]:
# Row index -> cluster label. Only the 'Cluster' column is converted, rather than
# turning every column of the frame into a dict and discarding the rest.
cell_to_cluster = seurat_df['Cluster'].to_dict()

In [71]:
# Give every cell a unique display name "<cluster> <n>", where n counts how many
# cells of that cluster were seen before it (0-based, in table order).
cell_type_counter = defaultdict(int)
new_cell_names = []
for cluster_label in cell_to_cluster.values():
    seen_so_far = cell_type_counter[cluster_label]
    new_cell_names.append(cluster_label + ' ' + str(seen_so_far))
    cell_type_counter[cluster_label] = seen_so_far + 1

In [79]:
# Unique display names of the selected cells, in sorted-unique node-id order.
# The index tensor is converted to NumPy first (as the `latent` lookup does)
# instead of relying on NumPy accepting a torch tensor as an index.
selected_cells_names = np.array(new_cell_names)[torch.unique(top_edges).numpy()]

In [80]:
# Distance matrix as a table: one column per selected cell, labelled with its unique name.
df = pd.DataFrame(ed, columns=selected_cells_names)

In [81]:
# Add the row labels as a leading 'Cell' column. The guard keeps the cell
# re-runnable: DataFrame.insert raises ValueError if the column already exists.
if 'Cell' not in df.columns:
    df.insert(0, 'Cell', selected_cells_names)

In [103]:
# Full pairwise distance table between the selected cells.
df

Unnamed: 0,Cell,FCGR3A+ Mono 1,T 18,T 30,T 60,T 86,T 96,T 97,FCGR3A+ Mono 16,Platelet 0,...,T 1054,T 1059,T 1080,T 1099,FCGR3A+ Mono 149,CD8 T 262,Platelet 13,T 1155,T 1157,DC 31
0,FCGR3A+ Mono 1,0.000000,2.001639,2.108213,2.024708,2.025151,2.064373,2.006296,0.057816,2.204433,...,2.219748,2.258620,2.222086,2.080150,0.167265,1.928919,2.217709,2.213296,2.028470,2.157073
1,T 18,2.001639,0.000000,0.489999,0.232578,0.390233,0.408154,0.206148,1.949221,1.649520,...,0.849409,0.928125,0.900572,0.436520,1.925066,0.376843,1.710673,0.876135,0.374226,2.214857
2,T 30,2.108213,0.489999,0.000000,0.261048,0.128851,0.103925,0.303473,2.058736,2.128136,...,0.361603,0.439352,0.413897,0.082917,2.067943,0.849508,2.188628,0.390020,0.137417,2.220363
3,T 60,2.024708,0.232578,0.261048,0.000000,0.160305,0.176402,0.053797,1.973174,1.869785,...,0.618279,0.697320,0.669733,0.204431,1.965537,0.600286,1.930733,0.646442,0.145451,2.211995
4,T 86,2.025151,0.390233,0.128851,0.160305,0.000000,0.044255,0.193124,1.974759,2.014036,...,0.465340,0.544933,0.515488,0.068839,1.978331,0.750581,2.074533,0.493267,0.026408,2.196364
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
84,CD8 T 262,1.928919,0.376843,0.849508,0.600286,0.750581,0.774240,0.575558,1.877286,1.326404,...,1.205286,1.282322,1.253193,0.803325,1.828452,0.000000,1.384930,1.226497,0.734113,2.173274
85,Platelet 13,2.217709,1.710673,2.188628,1.930733,2.074533,2.098627,1.890777,2.176587,0.068751,...,2.535257,2.614046,2.582680,2.126901,2.055536,1.384930,0.000000,2.560737,2.062014,2.921122
86,T 1155,2.213296,0.876135,0.390020,0.646442,0.493267,0.476335,0.684649,2.168123,2.501199,...,0.052491,0.065942,0.037145,0.450277,2.201151,1.226497,2.560737,0.000000,0.506468,2.239413
87,T 1157,2.028470,0.374226,0.137417,0.145451,0.026408,0.057508,0.180758,1.978112,2.001473,...,0.479693,0.558992,0.529746,0.083341,1.980334,0.734113,2.062014,0.506468,0.000000,2.184808


In [87]:
# Translate each top edge from a pair of node indices into a pair of cell names.
# The name array is built once up front; it was previously rebuilt from the
# full list twice for every edge.
cell_names_arr = np.array(new_cell_names)
edges_to_name = np.array([(cell_names_arr[edge[0]], cell_names_arr[edge[1]]) for edge in top_edges])

In [93]:
# Top edges as (name, name) pairs, in descending attention-weight order.
edges_to_name

array([['Platelet 4', 'T 287'],
       ['Platelet 4', 'CD8 T 150'],
       ['Platelet 4', 'CD8 T 159'],
       ['Platelet 5', 'T 743'],
       ['Platelet 10', 'T 846'],
       ['Platelet 3', 'T 865'],
       ['Platelet 3', 'T 96'],
       ['Platelet 10', 'T 982'],
       ['Platelet 9', 'T 808'],
       ['Platelet 10', 'T 1054'],
       ['Platelet 12', 'T 768'],
       ['Platelet 9', 'T 712'],
       ['Platelet 8', 'T 676'],
       ['Platelet 7', 'T 623'],
       ['Platelet 3', 'T 352'],
       ['Platelet 8', 'T 879'],
       ['Platelet 0', 'T 30'],
       ['Platelet 4', 'CD8 T 235'],
       ['Platelet 5', 'T 444'],
       ['Platelet 8', 'T 914'],
       ['Platelet 2', 'FCGR3A+ Mono 40'],
       ['Platelet 8', 'T 225'],
       ['Platelet 2', 'FCGR3A+ Mono 87'],
       ['Platelet 8', 'T 1040'],
       ['Platelet 2', 'FCGR3A+ Mono 16'],
       ['Platelet 2', 'FCGR3A+ Mono 1'],
       ['Platelet 2', 'FCGR3A+ Mono 96'],
       ['Platelet 8', 'T 1155'],
       ['Platelet 12', 'CD8 T 241'],
 

In [94]:
# Cell-name columns of `df`, skipping the leading 'Cell' label column.
df.columns[1:].values

array(['FCGR3A+ Mono 1', 'T 18', 'T 30', 'T 60', 'T 86', 'T 96', 'T 97',
       'FCGR3A+ Mono 16', 'Platelet 0', 'Platelet 1', 'Platelet 2',
       'T 170', 'T 225', 'T 231', 'CD8 T 57', 'Platelet 3',
       'FCGR3A+ Mono 40', 'Platelet 4', 'T 287', 'T 295', 'CD8 T 68',
       'Platelet 5', 'T 352', 'NK 50', 'T 369', 'CD14+ Mono 144', 'T 400',
       'CD8 T 97', 'T 425', 'T 444', 'T 466', 'T 473', 'T 487', 'DC 13',
       'CD8 T 131', 'T 555', 'T 575', 'T 623', 'CD14+ Mono 255',
       'CD8 T 150', 'T 666', 'FCGR3A+ Mono 87', 'T 669', 'T 676', 'T 691',
       'FCGR3A+ Mono 94', 'Platelet 6', 'FCGR3A+ Mono 96', 'T 712',
       'Platelet 7', 'CD8 T 159', 'T 743', 'T 768', 'Platelet 8', 'T 808',
       'T 817', 'Platelet 9', 'T 846', 'Platelet 10', 'T 865', 'T 869',
       'T 879', 'T 881', 'T 884', 'FCGR3A+ Mono 112', 'NK 112',
       'Platelet 11', 'T 914', 'Platelet 12', 'FCGR3A+ Mono 120', 'T 959',
       'FCGR3A+ Mono 121', 'T 982', 'DC 22', 'CD8 T 223', 'CD8 T 235',
       'T 1040',

In [None]:
# Selected cell names in lexicographic order.
sorted_cell_names = list(selected_cells_names)
sorted_cell_names.sort()

In [101]:
# Map each selected cell name to its position among df's cell-name columns
# (the leading 'Cell' column is skipped). The positions are computed in a single
# pass; previously the column list was rebuilt and linearly searched for every key.
column_names = df.columns[1:].values.tolist()
first_position = {}
for position, column_name in enumerate(column_names):
    # setdefault keeps the first occurrence, matching list.index semantics.
    first_position.setdefault(column_name, position)
dct_order = {name: first_position[name] for name in sorted_cell_names}

In [106]:
# Convert every named edge into a (row, column) position in the distance matrix.
edges_to_df_positions = np.array(
    [(dct_order[first_name], dct_order[second_name]) for first_name, second_name in edges_to_name]
)

In [124]:
# Zero matrix with the same shape as the distance matrix; only top-edge entries get filled in.
zr = np.zeros(ed.shape)

In [140]:
# Copy the distance for each top edge into `zr`; every other entry stays 0.
# Only the (row, col) direction of each edge is written, not its mirror.
for row, col in edges_to_df_positions:
    zr[row, col] = ed[row, col]

In [142]:
# Top-edge-only distance matrix as a table, with one named column per selected cell.
df_zr = pd.DataFrame(zr, columns=selected_cells_names)

In [143]:
# Add the row labels as a leading 'Cell' column. The guard keeps the cell
# re-runnable: DataFrame.insert raises ValueError if the column already exists.
if 'Cell' not in df_zr.columns:
    df_zr.insert(0, 'Cell', selected_cells_names)

In [148]:
# Write the top-edge-only distance table to the working directory, without the integer row index.
df_zr.to_csv('attn_euclidean_distance_top_edges.csv', index=False)

In [149]:
# One more column than rows: the extra column is the 'Cell' label column.
df_zr.shape

(89, 90)

In [150]:
# Write the full pairwise distance table to the working directory, without the integer row index.
df.to_csv('attn_euclidean_distance.csv', index=False)

In [152]:
# Count of non-zero entries in `zr`. This equals the number of top edges only if every
# edge maps to a distinct (row, col) position and has a non-zero distance.
np.nonzero(zr)[0].shape

(80,)