In [2]:
import os
from os.path import join

import pandas as pd
import numpy as np

In [3]:
# Root directory of the dataset: `categorical_lookup/cell_type.parquet` is read
# from it and the `cell_type_hierarchy/` matrices are written below it.
# Set the MERLIN_DATA_PATH environment variable to run against another copy;
# the original cluster location remains the default.
DATA_PATH = os.environ.get(
    'MERLIN_DATA_PATH',
    '/lustre/scratch/users/felix.fischer/merlin_cxg_simple_norm_parquet',
)

# Compute lookup matrices 

In [4]:
# Lookup table of cell-type labels stored alongside the dataset.
cell_type_lookup_file = join(DATA_PATH, 'categorical_lookup/cell_type.parquet')
cell_type_mapping = pd.read_parquet(cell_type_lookup_file)

# Reverse lookup: indexed by label, with `idx` holding each label's row
# position in `cell_type_mapping`.
positions = range(len(cell_type_mapping))
inverse_mapping = cell_type_mapping.assign(idx=positions).set_index('label', drop=True)
inverse_mapping.head()

Unnamed: 0_level_0,idx
label,Unnamed: 1_level_1
B cell,0
"CD4-positive, alpha-beta T cell",1
"CD8-positive, alpha-beta T cell",2
capillary endothelial cell,3
classical monocyte,4


In [5]:
from sfaira.consts import OC

# Labels of every cell type in the lookup table, in table order.
celltypes = cell_type_mapping['label'].tolist()
# Cell-type ontology taken from sfaira's `OC` object.
celltype_ontology = OC.cell_type



In [6]:
def _observed_relatives(cell_type, traverse, known_labels):
    """Collect the relatives of `cell_type` that also occur in the dataset.

    Parameters
    ----------
    cell_type
        Label of the cell type whose relatives are looked up.
    traverse
        Ontology traversal callable (e.g. `celltype_ontology.get_descendants`);
        it is called with `cell_type` and its results are converted to names.
    known_labels
        Set of labels present in the dataset; relatives outside it are dropped.

    Returns
    -------
    tuple of (list of labels, list of their `idx` values from `inverse_mapping`)
    """
    names = [celltype_ontology.convert_to_name(node) for node in traverse(cell_type)]
    names = [name for name in names if name in known_labels]
    return names, inverse_mapping.loc[names].idx.tolist()


# Set membership avoids rescanning the full label list for every ontology node.
known_labels = set(celltypes)

parents, parents_idx = [], []
children, children_idxs = [], []

# NOTE: the traversal names look inverted, but this assignment is intentional.
# The rendered outputs in this notebook show `get_descendants('B cell')` giving
# the more general types ('leukocyte', 'lymphocyte') and `get_ancestors` giving
# the more specific ones ('plasmablast', 'memory B cell', ...). Presumably the
# ontology graph's edges point from child to parent -- verify against sfaira.
for cell_type in cell_type_mapping.label:
    names, idxs = _observed_relatives(cell_type, celltype_ontology.get_descendants, known_labels)
    parents.append(names)
    parents_idx.append(idxs)

    names, idxs = _observed_relatives(cell_type, celltype_ontology.get_ancestors, known_labels)
    children.append(names)
    children_idxs.append(idxs)

# Store the relatives as lists of row indices next to each label.
cell_type_mapping['children'] = children_idxs
cell_type_mapping['parents'] = parents_idx

In [7]:
# Preview the first rows, including the `children` and `parents` index lists.
cell_type_mapping.head()

Unnamed: 0,label,children,parents
0,B cell,"[98, 19, 50, 100, 53]","[48, 111]"
1,"CD4-positive, alpha-beta T cell","[25, 31, 102, 26, 96, 107, 115, 54]","[48, 69, 111, 30, 17]"
2,"CD8-positive, alpha-beta T cell","[28, 108, 127, 36, 103, 27, 126, 55]","[48, 69, 111, 30, 17]"
3,capillary endothelial cell,"[84, 83]","[37, 40, 33]"
4,classical monocyte,[67],"[48, 18]"


In [8]:
# Make sure the output directory for the hierarchy matrices exists
# (no error if it is already there).
hierarchy_dir = join(DATA_PATH, 'cell_type_hierarchy')
os.makedirs(hierarchy_dir, exist_ok=True)

In [9]:
# Square indicator matrix: row r has a 1 on the diagonal (the cell type itself)
# and a 1 in every column listed in that row's `parents` entry.
n_cell_types = len(cell_type_mapping)
parent_matrix = np.eye(n_cell_types)

for row, parent_idxs in enumerate(cell_type_mapping['parents']):
    parent_matrix[row, parent_idxs] = 1.

# Save the matrix next to the dataset.
parent_matrix_file = join(DATA_PATH, 'cell_type_hierarchy/parent_matrix.npy')
with open(parent_matrix_file, 'wb') as fh:
    np.save(fh, parent_matrix)


In [10]:
# Square indicator matrix: row r has a 1 on the diagonal (the cell type itself)
# and a 1 in every column listed in that row's `children` entry.
n_cell_types = len(cell_type_mapping)
child_matrix = np.eye(n_cell_types)

for row, child_idxs in enumerate(cell_type_mapping['children']):
    child_matrix[row, child_idxs] = 1.

# Save the matrix next to the dataset.
child_matrix_file = join(DATA_PATH, 'cell_type_hierarchy/child_matrix.npy')
with open(child_matrix_file, 'wb') as fh:
    np.save(fh, child_matrix)


# Sanity check lookup matrices

In [11]:
import torch

In [12]:
# Wrap each matrix as an embedding table so that looking up a cell-type index
# returns that cell type's indicator row.
parent_weights = torch.tensor(parent_matrix)
child_weights = torch.tensor(child_matrix)

parent_lookup = torch.nn.Embedding.from_pretrained(parent_weights)
child_lookup = torch.nn.Embedding.from_pretrained(child_weights)

In [13]:
# Index 0 is 'B cell'; fetch its indicator row from the child lookup.
child_row = child_lookup(torch.tensor([0])).numpy().squeeze()
# np.where returns positional indices, so rows are selected by position (.iloc)
# rather than by index label; both agree for a default integer index.
cell_type_mapping.iloc[np.where(child_row == 1.)[0]].label.tolist()

['B cell',
 'plasmablast',
 'memory B cell',
 'naive B cell',
 'immature B cell',
 'precursor B cell']

In [14]:
# Index 0 is 'B cell'; fetch its indicator row from the parent lookup.
parent_row = parent_lookup(torch.tensor([0])).numpy().squeeze()
# np.where returns positional indices, so rows are selected by position (.iloc)
# rather than by index label; both agree for a default integer index.
cell_type_mapping.iloc[np.where(parent_row == 1.)[0]].label.tolist()

['B cell', 'leukocyte', 'lymphocyte']