In [2]:
from interpretable_ssl.immune.trainer import ImmuneTrainer

# Build the trainer with its defaults, then override the hyper-parameters used below.
# NOTE(review): the overrides are applied after ImmuneTrainer() returns; the checkpoint
# name printed in this cell's output (…hidden-128_bs-32…) was presumably produced by
# the constructor with its own defaults — confirm get_model_path() reflects these values.
trainer = ImmuneTrainer()
for attr_name, attr_value in (("batch_size", 64), ("hidden_dim", 64), ("num_prototypes", 32)):
    setattr(trainer, attr_name, attr_value)

  from .autonotebook import tqdm as notebook_tqdm


loading data
training with number of prototypes : num-prot-32_hidden-128_bs-32.pth


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [4]:
# load model
import torch

model = trainer.get_model()
# map_location lets a checkpoint saved on one device (e.g. GPU) load on another (e.g. CPU-only).
checkpoint = torch.load(trainer.get_model_path(), map_location=trainer.device)
# The model is only used for inference below: put any dropout / batch-norm layers in eval mode.
model.eval()
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [6]:
# decode each prototype back to the input (gene) space;
# this is inspection only, so skip building an autograd graph
with torch.no_grad():
    prototype_cells = model.decoder(model.prototype_vectors)
prototype_cells.shape

In [128]:
def calculate_mean_by_class(latent, y, num_classes=None):
    """Average the rows of ``latent`` per integer class label.

    Args:
        latent: 2-D tensor of shape (N, D), one row per sample.
        y: 1-D integer tensor of length N with non-negative class labels.
        num_classes: number of output rows; defaults to ``max(y) + 1``
            (0 when ``y`` is empty).

    Returns:
        Tensor of shape (num_classes, D) on ``latent``'s device and dtype.
        Row ``c`` is the mean of the rows of ``latent`` labelled ``c``; classes
        with no samples get an all-zero row.

    Raises:
        ValueError: if a label in ``y`` is >= ``num_classes``.
    """
    y = y.to(latent.device)
    if num_classes is None:
        num_classes = int(y.max()) + 1 if y.numel() else 0
    elif y.numel() and int(y.max()) >= num_classes:
        raise ValueError(f"label {int(y.max())} out of range for num_classes={num_classes}")

    # Per-class sums without materialising a dense (num_classes x N) one-hot matrix.
    sums = torch.zeros(num_classes, latent.shape[1], dtype=latent.dtype, device=latent.device)
    sums.index_add_(0, y, latent)
    counts = torch.bincount(y, minlength=num_classes).to(latent.dtype)
    # clamp avoids 0/0 for empty classes, which therefore stay all-zero.
    return sums / counts.clamp(min=1).unsqueeze(1)

In [116]:
# calculate all latent
adata = trainer.dataset.adata
X = adata.X
# adata.X may be a scipy sparse matrix or already a dense array; only densify when needed.
x = torch.as_tensor(X.toarray() if hasattr(X, "toarray") else X, device=trainer.device)
# Inference over every cell: skip the autograd graph, which would otherwise be kept
# alive for the whole dataset.
with torch.no_grad():
    latent, _, _ = model(x)

In [124]:
# Integer-encode each cell's type with the dataset's label encoder, on the trainer's device.
cell_type_names = list(adata.obs.cell_type.values)
y = torch.tensor(trainer.dataset.le.transform(cell_type_names), device=trainer.device)

In [129]:
# Mean latent vector per cell type, decoded back to the input space (inference only).
with torch.no_grad():
    class_mean_latent = calculate_mean_by_class(latent, y)
    class_mean_cells = model.decoder(class_mean_latent)
prototype_cells.shape, class_mean_cells.shape

In [141]:
prototype_cells = prototype_cells.to(trainer.device)
# Add broadcast axes: class means -> (C, 1, G), prototypes -> (1, P, G).
# Sizes come from the tensors instead of hard-coding 16 classes / 32 prototypes,
# and re-running this cell leaves already-reshaped tensors unchanged.
class_mean_cells = class_mean_cells.reshape(class_mean_cells.shape[0], 1, -1)
prototype_cells = prototype_cells.reshape(1, -1, prototype_cells.shape[-1])
assert class_mean_cells.shape[-1] == prototype_cells.shape[-1], "decoded feature sizes differ"
class_mean_cells.shape, prototype_cells.shape

In [145]:
# (C, 1, G) - (1, P, G) broadcasts to (C, P, G): one difference vector per
# (cell type, prototype) pair. Broadcasting replaces the explicit `repeat`, so
# `prototype_cells` is no longer overwritten and this cell can be re-run safely.
diff = class_mean_cells - prototype_cells
diff.shape

In [148]:
# For every (cell type, prototype) pair, take the indices of the k features where the
# class-mean decoding exceeds the prototype decoding the most.
k = 5
top_idx = torch.topk(diff, k=k, dim=2)[1]
top_idx.shape

In [153]:
from interpretable_ssl.evaluation.interpretation import downstream

genes = adata.var.index
# Move the index tensor to host memory once instead of once per (cell type, prototype) pair.
top_idx_host = top_idx.cpu().numpy()
# results[c][p] = downstream(...) of the top-k gene names for cell type c vs prototype p.
results = [
    [downstream(list(genes[gene_positions])) for gene_positions in class_row]
    for class_row in top_idx_host
]

Index(['LINC00115', 'KLHL17', 'HES4', 'ISG15', 'B3GALT6', 'PUSL1', 'TMEM52',
       'PRKCZ', 'PEX10', 'PANK4',
       ...
       'RRP1', 'PWP2', 'TRPM2', 'ITGB2', 'SLC19A1', 'LSS', 'MCM3AP-AS1',
       'YBEY', 'PCNT', 'S100B'],
      dtype='object', name='index', length=4000)

In [168]:
# NOTE(review): this recomputes exactly what the previous `downstream` cell (In [153])
# computed and overwrites `results`; one of the two cells should be removed.
results = []
for cell_row in top_idx:
    results.append([downstream(list(genes[prot.cpu()])) for prot in cell_row])

In [175]:
# Cell-type names for each row of `results` (one row per encoded class), instead of
# assuming exactly 16 classes.
index = trainer.dataset.le.inverse_transform(range(len(results)))

In [176]:
import pandas as pd

# One column per prototype; top_idx is (cell types, prototypes, k), so take the count
# from it rather than hard-coding 32.
num_prototypes = top_idx.shape[1]
columns = [f'prototype{i}' for i in range(1, num_prototypes + 1)]
df = pd.DataFrame(results, columns=columns, index=index)
df.head()

Unnamed: 0,prototype1,prototype2,prototype3,prototype4,prototype5,prototype6,prototype7,prototype8,prototype9,prototype10,...,prototype23,prototype24,prototype25,prototype26,prototype27,prototype28,prototype29,prototype30,prototype31,prototype32
CD10+ B cells,[],[{'description': 'Viral protein interaction wi...,"[{'description': '""H2O + L-glutamate 5-semiald...","[{'description': '""Any process that results in...","[{'description': '""A defense response against ...",[],[],[],[],[],...,[],[],[],"[{'description': '""Binding to a chemokine rece...",[],[],"[{'description': '""The process in which the an...",[{'description': 'Viral protein interaction wi...,"[{'description': '""A defense response against ...",[]
CD14+ Monocytes,[],"[{'description': '""Combining with the C-X-C mo...",[],[],[],[],[],"[{'description': 'Efferocytosis', 'effective_d...","[{'description': 'Efferocytosis', 'effective_d...",[],...,"[{'description': 'Efferocytosis', 'effective_d...","[{'description': 'Efferocytosis', 'effective_d...",[],[],"[{'description': '""Any process that activates ...",[],[],"[{'description': '""Combining with the C-X-C mo...","[{'description': '""Any process that activates ...",[]
CD16+ Monocytes,[],"[{'description': '""A secretory organelle found...","[{'description': '""A ribonucleoprotein granule...",[],"[{'description': '""The sequential process in w...","[{'description': '""Binding to a ribosomal RNA....",[],[],[],[],...,[],[],"[{'description': '""A ribonucleoprotein granule...",[],[],"[{'description': '""A series of progressive, ov...",[],"[{'description': '""A secretory organelle found...","[{'description': '""An acute inflammatory respo...","[{'description': '""Any process that results in..."
CD20+ B cells,[],"[{'description': '""Any immune system process t...",[],"[{'description': '""Any process that results in...",[],"[{'description': '""A cellular process that res...",[],[],[],[{'description': 'Virion - Human immunodeficie...,...,[],[],[],"[{'description': '""Binding to heparin, a membe...",[],[],[],"[{'description': '""Any immune system process t...","[{'description': '""Any process that results in...",[]
CD4+ T cells,[{'description': 'Virion - Human immunodeficie...,"[{'description': '""Any immune system process t...","[{'description': 'Viral myocarditis', 'effecti...","[{'description': 'African trypanosomiasis', 'e...",[],[],[],"[{'description': '""Any process involved in the...",[],[{'description': 'Virion - Human immunodeficie...,...,"[{'description': '""Any process involved in the...","[{'description': '""Any process involved in the...","[{'description': 'Viral myocarditis', 'effecti...",[],[],[],[{'description': 'Virion - Human immunodeficie...,"[{'description': '""Any immune system process t...",[],[]


In [178]:
import os

# Make sure the output directory exists before writing.
# NOTE(review): the multi-sheet export further down rewrites this same file.
os.makedirs('./results', exist_ok=True)
df.to_excel('./results/cell-type-prototype.xlsx')

In [179]:
# Per (cell type, prototype): the top-k gene names, and how many downstream entries
# were returned for them.
result_genes = []
result_cnts = []
for row_pos, cell_row in enumerate(top_idx):
    n_prototypes = len(cell_row)
    row_cnts = [len(results[row_pos][col_pos]) for col_pos in range(n_prototypes)]
    row_genes = [list(genes[prot.cpu()]) for prot in cell_row]
    result_genes.append(row_genes)
    result_cnts.append(row_cnts)

In [180]:
# Write the downstream results, top genes and downstream counts as three sheets of one workbook.
with pd.ExcelWriter('./results/cell-type-prototype.xlsx') as writer:
    df.to_excel(writer, sheet_name='biological process')
    gene_df = pd.DataFrame(result_genes, columns=columns, index=index)
    gene_df.to_excel(writer, sheet_name='genes')
    cnt_df = pd.DataFrame(result_cnts, columns=columns, index=index)
    cnt_df.to_excel(writer, sheet_name='biological process counts')

In [183]:
# Re-decode the prototypes: earlier cells reassigned `prototype_cells` to a reshaped
# 3-D tensor, and the distance computation below needs the plain 2-D decoding.
with torch.no_grad():
    prototype_cells = model.decoder(model.prototype_vectors)

In [184]:
prototype_cells.size()

torch.Size([32, 4000])

In [185]:
x.size()

torch.Size([33506, 4000])

In [186]:
# Euclidean distance from every decoded prototype to every cell:
# (prototypes, genes) x (cells, genes) -> (prototypes, cells).
with torch.no_grad():
    prot_dist = torch.cdist(prototype_cells, x)
# `.shape` is a torch.Size attribute, not a method; calling it raised TypeError.
prot_dist.shape

In [190]:
# Position (row in adata) of the cell closest to each prototype.
nearest_cell_index = torch.min(prot_dist, dim=1)[1]
nearest_cell_index.size()

torch.Size([32])

In [227]:
def get_row(i, data=None):
    """Collect the obs metadata of cell ``i``.

    Args:
        i: position of the cell, passed unchanged to ``data[i]``.
        data: AnnData-like object exposing ``data[i].obs``; defaults to the
            notebook-global ``adata``.

    Returns:
        dict with keys 'batch', 'cell_type', 'specie', 'study' (each a list of the
        corresponding obs column values of the selected row(s)) and 'index' (``i``).
    """
    if data is None:
        data = adata
    obs = data[i].obs
    # Convert explicitly instead of wrapping every value in a bare try/except,
    # which silently hid any real error.
    return {
        'batch': list(obs.batch),
        'cell_type': list(obs.cell_type),
        'specie': list(obs.species),
        'study': list(obs.study),
        'index': i,
    }
get_row(i=0)  # spot-check the metadata of the first cell

{'batch': ['Oetjen_A'],
 'cell_type': ['CD16+ Monocytes'],
 'specie': ['Human'],
 'study': ['Oetjen'],
 'index': 0}

In [228]:
# Metadata of the nearest real cell for each prototype, in prototype order.
df_list = []
for cell_pos in nearest_cell_index.tolist():
    df_list.append(get_row(cell_pos))

In [229]:
protoype_df = pd.DataFrame(df_list, index=columns)

In [230]:
protoype_df.head(n=5)

Unnamed: 0,batch,cell_type,specie,study,index
prototype1,[Freytag],[CD20+ B cells],[Human],[Freytag],10672
prototype2,[Freytag],[Plasmacytoid dendritic cells],[Human],[Freytag],11194
prototype3,[Freytag],[CD4+ T cells],[Human],[Freytag],9800
prototype4,[Freytag],[CD4+ T cells],[Human],[Freytag],9800
prototype5,[Freytag],[CD20+ B cells],[Human],[Freytag],10672


In [231]:
protoype_df.to_excel(excel_writer='./results/prototype-cells.xlsx')
