In [None]:
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import scipy.sparse
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import keras as ks
import sklearn.metrics as metrics
import pandas as pd
import re
import json
import os

In [None]:
controls_final = "../Data/controls_final.h5ad"
granulomas_final = "../Data/granulomas_final.h5ad"
sc78_final = "../Data/sc78_final.h5ad"
sc92_final = "../Data/sc92_final.h5ad"
sc93_final = "../Data/sc93_final.h5ad"

In [None]:
# controls_final_anndata = sc.read_h5ad(controls_final) 
# granulomas_final_anndata = sc.read_h5ad(granulomas_final)
sc78_final_anndata = sc.read_h5ad(sc78_final)
sc92_final_anndata = sc.read_h5ad(sc92_final)
sc93_final_anndata = sc.read_h5ad(sc93_final)

def fix_and_print(anndata):
    anndata.uns['log1p']['base'] = None
    print(anndata.obs['sample'].value_counts())
    print('X matrix is sparse:', scipy.sparse.issparse(anndata.X))
    print('X size =', anndata.X.shape)
    print()

# fix_and_print(controls_final_anndata)
# fix_and_print(granulomas_final_anndata)
fix_and_print(sc78_final_anndata)
fix_and_print(sc92_final_anndata)
fix_and_print(sc93_final_anndata)

In [None]:
controls_final_annotation_dict = {
    '0': 'CAP1',
    '12': 'CAP2',
    '21': 'VEC',
    '17': 'AEC',
    '28': 'LEC',
    '14': 'Ciliated',
    '5': 'Secretory',
    '1': 'AT1',
    '2': 'AT2',
    '3': 'AF',
    '20': 'Pericyte',
    '26': 'SMC',
    '18': 'Mesothelial',
    '8': 'B1',
    '23b': 'Th1',
    '11': 'Tnaive',
    '24': 'NK',
    '10': 'AM',
    '15b': 'M-C1q',
    '25': 'iMon',
    '15': 'DC',
    '15c': 'pDC',
    '22': 'N1',
}

granulomas_final_annotation_dict = {
    '9': 'CAP1',
    '24': 'CAP2',
    '9b': 'VEC',
    '27': 'LEC',
    '17': 'Ciliated',
    '15': 'Secretory',
    '22': 'AT1',
    '6': 'AT2',
    '12': 'AT2-t1',
    '19': 'AT2-t2',
    '14': 'AF',
    '25': 'Pericyte',
    '20': 'Mesothelial',
    '3': 'B1',
    '3b': 'B2',
    '0': 'Th1',
    '8': 'Tnaive',
    '11': 'Tex',
    '77': 'Treg',
    '11b': 'NK',
    '4a': 'AM',
    '4': 'M-t1',
    '10': 'M-lc',
    '7': 'M-t2',
    '7b': 'M-C1q',
    '7c': 'iMon',
    '23': 'pDC',
    '13': 'DC',
    '5b': 'N1',
    '5': 'N2',
}

sc78_final_annotation_dict = {
    '3': 'CAP1',
    '3b': 'CAP2',
    '3c': 'VEC',
    '25': 'LEC',
    '16': 'Ciliated',
    '8': 'Secretory',
    '17': 'AT1',
    '7': 'AT2',
    '10': 'AT2-t1',
    '9': 'AT2-t2',
    '15': 'AF',
    '23': 'Pericyte',
    '21': 'Mesothelial',
    '13': 'B1',
    'Th1': 'Th1',
    'Tnaive': 'Tnaive',
    'Treg': 'Treg',
    'Tex': 'Tex',
    'NK': 'NK',
    '11': 'AM',
    '2b': 'M-C1q',
    '0': 'M-t1',
    '2': 'M-t2',
    '2c': 'iMon', 
    '18': 'DC',
    '18b': 'pDC',
    '22': 'N1',
    '5': 'N2'
}

sc92_final_annotation_dict = {
    '11': 'CAP1',
    '24': 'CAP2',
    '15': 'VEC',
    '28': 'LEC', 
    '18': 'Ciliated',
    '14': 'Secretory',
    '10': 'AT1',
    '4': 'AT2',
    '12': 'AT2-t1', 
    '8': 'AF',
    '25': 'Pericyte',
    '8c': 'SMC',
    '21b': 'Mesothelial',
    '6': 'B1',
    '26': 'B2',
    'Th1': 'Th1',
    'Treg': 'Treg',
    'Tex': 'Tex',
    'Tnaive': 'Tnaive',
    'Tnaive': 'Tnaive',
    'T': 'T',  
    '13': 'AM',
    '1': 'M-t1',
    '2': 'M-t2',
    '2b': 'M-C1q',
    '19b': 'pDC',
    '19': 'DC',
    '0': 'N2'
}

sc93_final_annotation_dict = {
    '9': 'CAP1',
    '9b': 'CAP2',
    '9c': 'VEC',
    '25': 'LEC',  
    '10': 'Ciliated',
    '13': 'Secretory',
    '11': 'AT1',
    '7': 'AT2',
    '16': 'AT2-t1',
    '3': 'AF',
    '28': 'Pericyte',
    '19': 'Mesothelial',
    '0': 'B1',
    '22': 'B2',
    '4': 'Th1',
    '4b': 'Treg',
    '2': 'Tnaive',
    '14': 'Tex',
    '14b': 'NK',
    '15': 'AM',
    '1': 'M-t1',
    '12': 'M-t2',
    '12b': 'M-C1q',
    '29': 'pDC',
    '21': 'DC',
    '26': 'N2',
}

In [None]:
def create_annotation(anndata, annotation_dict):
    anndata.obs['single_cell_types'] = [annotation_dict[clust] for clust in anndata.obs['my_clust_1']]
    dict_list = list(annotation_dict.keys())
    anndata_list = list(anndata.obs['my_clust_1'].unique())
    print('Keys in dictionary not in anndata:', [item for item in dict_list if item not in anndata_list])
    print('Keys in anndata not in dictionary:', [item for item in anndata_list if item not in dict_list])
    print()

In [None]:
# create_annotation(controls_final_anndata, controls_final_annotation_dict)
# create_annotation(granulomas_final_anndata, granulomas_final_annotation_dict)
create_annotation(sc78_final_anndata, sc78_final_annotation_dict)
create_annotation(sc92_final_anndata, sc92_final_annotation_dict)
create_annotation(sc93_final_anndata, sc93_final_annotation_dict)

In [None]:
def print_stats(anndata):
    unique_celltype_sub = anndata.obs['single_cell_types'].unique()
    print(unique_celltype_sub)
    num_unique_celltype_sub = anndata.obs['single_cell_types'].nunique()
    print(f"Number of unique sub cell types: {num_unique_celltype_sub}")
    print()

In [None]:
# print_stats(controls_final_anndata)
# print_stats(granulomas_final_anndata)
print_stats(sc78_final_anndata)
print_stats(sc92_final_anndata)
print_stats(sc93_final_anndata)

In [None]:
def holdout_subset(name, anndata, split, seed):
    train_anndata, test_anndata = train_test_split(anndata.obs.index, test_size=split, random_state=seed, stratify=anndata.obs['single_cell_types'].values)
    print(f'{name}_train_anndata shape', train_anndata.shape)
    print(f'{name}_test_anndata shape', test_anndata.shape)
    print()

    train_anndata = anndata[train_anndata].copy()
    test_anndata = anndata[test_anndata].copy()

    train_anndata.write(f"../Data/{name}_train_anndata.h5ad")
    test_anndata.write(f"../Data/{name}_test_anndata.h5ad")

In [None]:
seed = 8653
split = 0.2
# holdout_subset("controls_final", controls_final_anndata, split, seed)
# holdout_subset("granulomas_final", granulomas_final_anndata, split, seed)
holdout_subset("sc78_final", sc78_final_anndata, split, seed)
holdout_subset("sc92_final", sc92_final_anndata, split, seed)
holdout_subset("sc93_final", sc93_final_anndata, split, seed)

In [None]:
controls_final_anndata = sc.read_h5ad("../Data/controls_final_train_anndata.h5ad") 
granulomas_final_anndata = sc.read_h5ad("../Data/granulomas_final_train_anndata.h5ad")
sc78_final_anndata = sc.read_h5ad("../Data/sc78_final_train_anndata.h5ad")
sc92_final_anndata = sc.read_h5ad("../Data/sc92_final_train_anndata.h5ad")
sc93_final_anndata = sc.read_h5ad("../Data/sc93_final_train_anndata.h5ad")

In [None]:
def quick_look(anndata):
    anndata_hvg = anndata[:, anndata.var['highly_variable']].copy()
    print(anndata_hvg)
    print(anndata_hvg.obs['sample'].value_counts())
    print(list(anndata_hvg.obs['single_cell_types'].unique()))
    print(anndata_hvg.obs['single_cell_types'].nunique())
    print()

In [None]:
quick_look(controls_final_anndata)
quick_look(granulomas_final_anndata)
quick_look(sc78_final_anndata)
quick_look(sc92_final_anndata)
quick_look(sc93_final_anndata)

In [None]:
top_dict = {
    'Endothelial': ['CAP1','CAP2','VEC','AEC','LEC'],
    'Epithelial': ['Ciliated','Secretory','AT1','AT2','AT2-t1','AT2-t2'],
    'Mesenchyme': ['AF','Pericyte','SMC','Mesothelial'],
    'Immune': ['B1','B2','Th1','Tnaive','Treg','Tex','T','NK','AM','M-t1','M-t2','M-C1q','M-lc','iMon','DC','pDC','N1','N2'] # added 'T' b/c of sc92
}

second_dict = {
    'Blood vessels': ['CAP1','CAP2','VEC','AEC'],
    'Lymphatic EC': ['LEC'],
    'Airway epithelium': ['Ciliated','Secretory'],
    'Alveolar epithelium' : ['AT1','AT2','AT2-t1','AT2-t2'],
    'Stromal': ['AF','Pericyte','SMC'],
    'Mesothelial': ['Mesothelial'],
    'Lymphoid': ['B1','B2','Th1','Tnaive','Treg','Tex','T','NK'], # added 'T' b/c of sc92
    'Myeloid': ['AM','M-t1','M-t2','M-C1q','M-lc','iMon','DC','pDC','N1','N2']
}

third_dict = {
    'Blood vessels': ['CAP1','CAP2','VEC','AEC'],
    'Lymphatic EC': ['LEC'],
    'Airway epithelium': ['Ciliated','Secretory'],
    'Alveolar epithelium': ['AT1','AT2','AT2-t1','AT2-t2'],
    'Fibroblast': ['AF','Pericyte'],
    'Smooth muscle': ['SMC'],
    'Mesothelial': ['Mesothelial'],
    'B lineage': ['B1','B2'],
    'T lineage': ['Th1','Tnaive','Treg','Tex','T'], # added 'T' b/c of sc92
    'NK': ['NK'],
    'mononuclear broad': ['AM','M-t1','M-t2','M-C1q','M-lc','iMon','DC','pDC'], # originally two mononuclear
    'Neutrophil': ['N1','N2'] # originally polymorphonuclear=Neutrophil
}

fourth_dict = {
    'Blood vessels': ['CAP1','CAP2','VEC','AEC'],
    'Lymphatic EC': ['LEC'],
    'Airway epithelium': ['Ciliated','Secretory'],
    'Alveolar epithelium': ['AT1','AT2','AT2-t1','AT2-t2'],
    'Fibroblast': ['AF','Pericyte'],
    'Smooth muscle': ['SMC'],
    'Mesothelial': ['Mesothelial'],
    'B lineage': ['B1','B2'],
    'T lineage': ['Th1','Tnaive','Treg','Tex','T'], # added 'T' b/c of sc92
    'NK': ['NK'],
    'Macrophage': ['AM','M-t1','M-t2','M-C1q','M-lc'],        
    'mononuclear fine': ['iMon','DC','pDC'], # originally two mononuclear
    'Neutrophil': ['N1','N2']
}

L1_annotation = {}
for cell_type, cluster_num in top_dict.items():
    for x in cluster_num:
        L1_annotation[x] = cell_type

L2_annotation = {}
for cell_type, cluster_num in second_dict.items():
    for x in cluster_num:
        L2_annotation[x] = cell_type

L3_annotation = {}
for cell_type, cluster_num in third_dict.items():
    for x in cluster_num:
        L3_annotation[x] = cell_type

L4_annotation = {}
for cell_type, cluster_num in fourth_dict.items():
    for x in cluster_num:
        L4_annotation[x] = cell_type

def create_hierarchical_annotations(anndata):
    anndata.obs["top_level"] = anndata.obs['single_cell_types'].map(L1_annotation)
    anndata.obs["second_level"] = anndata.obs['single_cell_types'].map(L2_annotation)
    anndata.obs["third_level"] = anndata.obs['single_cell_types'].map(L3_annotation)
    anndata.obs["fourth_level"] = anndata.obs['single_cell_types'].map(L4_annotation)

In [None]:
# create_hierarchical_annotations(controls_final_anndata)
# create_hierarchical_annotations(granulomas_final_anndata)
create_hierarchical_annotations(sc78_final_anndata)
create_hierarchical_annotations(sc92_final_anndata)
create_hierarchical_annotations(sc93_final_anndata)

In [None]:
hierarchy = ['top_level', 'second_level', 'third_level', 'fourth_level', 'single_cell_types']

In [None]:
# controls_final_hierarchy_dict = {}
# granulomas_final_hierarchy_dict = {}
sc78_final_hierarchy_dict = {}
sc92_final_hierarchy_dict = {}
sc93_final_hierarchy_dict = {}

In [None]:
def add_path(root, path):
    node = root
    prev = None 
    for label in path:
        if (label == prev): continue
        node = node.setdefault(label, {})
        prev = label

def create_hierarchy_dict(anndata, hierarchy_dict):
    unique_paths = anndata.obs[hierarchy].drop_duplicates().values
    for path in unique_paths: add_path(hierarchy_dict, path)

In [None]:
# create_hierarchy_dict(controls_final_anndata, controls_final_hierarchy_dict)
# create_hierarchy_dict(granulomas_final_anndata, granulomas_final_hierarchy_dict)
create_hierarchy_dict(sc78_final_anndata, sc78_final_hierarchy_dict)
create_hierarchy_dict(sc92_final_anndata, sc92_final_hierarchy_dict)
create_hierarchy_dict(sc93_final_anndata, sc93_final_hierarchy_dict)

In [None]:
def save_hierarchy_dict(name, hierarchy_dict):
    with open(f"../Data/{name}_hierarchy_dict.json", "w") as file:
        json.dump(hierarchy_dict, file, indent=4)

In [None]:
# save_hierarchy_dict("controls_final", controls_final_hierarchy_dict)
# save_hierarchy_dict("granulomas_final", granulomas_final_hierarchy_dict)
save_hierarchy_dict("sc78_final", sc78_final_hierarchy_dict)
save_hierarchy_dict("sc92_final", sc92_final_hierarchy_dict)
save_hierarchy_dict("sc93_final", sc93_final_hierarchy_dict)

In [1]:
def create_name(input): return re.sub(r"[^A-Za-z0-9]+", "_", input).strip("_").lower()

In [2]:
def get_leaves(tree):
    res = []
    for key, value in tree.items():
        if value: res.extend(get_leaves(value))
        else: res.append(key)
    return res

In [None]:
def preprocess_node(cell_name, dataset_name, node, anndata, split, seed):
    sub_dict = {key: get_leaves(value) if value else [key] for key, value in node.items()}
    int_mapping = {key: idx for idx, key in enumerate(sub_dict)}

    cell_name = create_name(cell_name)
    with open(f"../Data/{dataset_name}_int_mapping_{cell_name}.json", "w") as file:
        json.dump(int_mapping, file, indent=4)

    reverse_mapping = {value: key for key, values in sub_dict.items() for value in values}
    
    finest_level = hierarchy[-1]
    anndata_subset = anndata[anndata.obs[finest_level].isin(reverse_mapping)].copy()
    anndata_subset.obs["cell_names"] = anndata_subset.obs[finest_level].map(reverse_mapping)
    anndata_subset.obs["cell_integers"] = anndata_subset.obs["cell_names"].map(int_mapping)
    anndata_subset_hvg = anndata_subset[:, anndata_subset.var['highly_variable']].copy()

    if scipy.sparse.issparse(anndata_subset_hvg.X):
        X = anndata_subset_hvg.X.toarray()
    else:
        X = anndata_subset_hvg.X
    
    y = anndata_subset_hvg.obs["cell_integers"].values

    train_features, val_features, train_labels, val_labels = train_test_split(X, y, test_size=split, random_state=seed, stratify=y)
    weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)

    train_features = np.array(train_features)
    val_features = np.array(val_features)
    train_labels = np.array(train_labels)
    val_labels = np.array(val_labels)
    weights = np.array(weights)

    np.save(f'../Arrays/{dataset_name}_train_features_hvg_{cell_name}.npy', train_features)
    np.save(f'../Arrays/{dataset_name}_val_features_hvg_{cell_name}.npy', val_features)
    np.save(f'../Arrays/{dataset_name}_train_labels_hvg_{cell_name}.npy', train_labels)
    np.save(f'../Arrays/{dataset_name}_val_labels_hvg_{cell_name}.npy', val_labels)
    np.save(f'../Arrays/{dataset_name}_weights_hvg_{cell_name}.npy', weights)

    anndata_subset_hvg.write(f"../Data/{dataset_name}_train_anndata_hvg_{cell_name}.h5ad")

In [None]:
def preprocess_leaf(cell_name, dataset_name, leaf, anndata, split, seed):
    int_mapping = {key: idx for idx, key in enumerate(leaf)}

    cell_name = create_name(cell_name)
    with open(f"../Data/{dataset_name}_int_mapping_{cell_name}.json", "w") as file:
        json.dump(int_mapping, file, indent=4)

    finest_level = hierarchy[-1]
    anndata_subset = anndata[anndata.obs[finest_level].isin(leaf)].copy()
    anndata_subset.obs["cell_integers"] = anndata_subset.obs[finest_level].map(int_mapping)
    anndata_subset_hvg = anndata_subset[:, anndata_subset.var['highly_variable']].copy()

    if scipy.sparse.issparse(anndata_subset_hvg.X):
        X = anndata_subset_hvg.X.toarray()
    else:
        X = anndata_subset_hvg.X
    
    y = anndata_subset_hvg.obs["cell_integers"].values

    train_features, val_features, train_labels, val_labels = train_test_split(X, y, test_size=split, random_state=seed, stratify=y)
    weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)

    train_features = np.array(train_features)
    val_features = np.array(val_features)
    train_labels = np.array(train_labels)
    val_labels = np.array(val_labels)
    weights = np.array(weights)

    np.save(f'../Arrays/{dataset_name}_train_features_hvg_{cell_name}.npy', train_features)
    np.save(f'../Arrays/{dataset_name}_val_features_hvg_{cell_name}.npy', val_features)
    np.save(f'../Arrays/{dataset_name}_train_labels_hvg_{cell_name}.npy', train_labels)
    np.save(f'../Arrays/{dataset_name}_val_labels_hvg_{cell_name}.npy', val_labels)
    np.save(f'../Arrays/{dataset_name}_weights_hvg_{cell_name}.npy', weights)

    anndata_subset_hvg.write(f"../Data/{dataset_name}_train_anndata_hvg_{cell_name}.h5ad")

In [None]:
def hierarchical_classification(cell_name, dataset_name, dict, anndata, split, seed):
    children = list(dict)
    if (not children): return

    if all(not dict[child] for child in children):
        preprocess_leaf(cell_name, dataset_name, children, anndata, split, seed)
        return
    
    preprocess_node(cell_name, dataset_name, dict, anndata, split, seed)
    for child in children: hierarchical_classification(child, dataset_name, dict[child], anndata, split, seed)

In [None]:
seed = 6296
split = 0.2
# hierarchical_classification("top_level", "controls_final", controls_final_hierarchy_dict, controls_final_anndata, split, seed)
# hierarchical_classification("top_level", "granulomas_final", granulomas_final_hierarchy_dict, granulomas_final_anndata, split, seed)
hierarchical_classification("top_level", "sc78_final", sc78_final_hierarchy_dict, sc78_final_anndata, split, seed)
hierarchical_classification("top_level", "sc92_final", sc92_final_hierarchy_dict, sc92_final_anndata, split, seed)
hierarchical_classification("top_level", "sc93_final", sc93_final_hierarchy_dict, sc93_final_anndata, split, seed)

In [None]:
# look at data to confirm

cell_name = create_name("lymphoid")
dataset_name = "sc93_final"

anndata = sc.read_h5ad(f"../Data/{dataset_name}_train_anndata_hvg_{cell_name}.h5ad")

print(anndata)
print()
print(anndata.obs["single_cell_types"].value_counts())
print(anndata.obs["single_cell_types"].unique())
print()
# print(anndata.obs["cell_names"].value_counts())
# print(anndata.obs["cell_names"].unique())
# print()
print(anndata.obs["cell_integers"].value_counts())
print(anndata.obs["cell_integers"].unique())
print()

train_features = np.load(f"../Arrays/{dataset_name}_train_features_hvg_{cell_name}.npy")
val_features = np.load(f"../Arrays/{dataset_name}_val_features_hvg_{cell_name}.npy")
train_labels = np.load(f"../Arrays/{dataset_name}_train_labels_hvg_{cell_name}.npy")
val_labels = np.load(f"../Arrays/{dataset_name}_val_labels_hvg_{cell_name}.npy")
weights = np.load(f"../Arrays/{dataset_name}_weights_hvg_{cell_name}.npy")

print('train features shape:', train_features.shape)
print('val features shape:', val_features.shape)
print('train labels shape:', train_labels.shape)
print('val labels shape:', val_labels.shape)
print('weights shape:', weights.shape)
print()

class_weights = dict(enumerate(weights))
print(class_weights)

In [None]:
def flat_classification(dataset_name, anndata, split, seed):
    finest_level = hierarchy[-1]
    cell_types = sorted(anndata.obs[finest_level].unique())

    int_mapping = {groups: i for i, groups in enumerate(cell_types)}
    with open(f"../Data/{dataset_name}_int_mapping_flat.json", "w") as file:
        json.dump(int_mapping, file, indent=4)

    anndata.obs["cell_integers"] = anndata.obs[finest_level].map(int_mapping)

    anndata_hvg = anndata[:, anndata.var['highly_variable']].copy()

    if scipy.sparse.issparse(anndata_hvg.X):
        X = anndata_hvg.X.toarray()
    else:
        X = anndata_hvg.X
    
    y = anndata_hvg.obs["cell_integers"].values

    train_features, val_features, train_labels, val_labels = train_test_split(X, y, test_size=split, random_state=seed, stratify=y)
    weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)

    train_features = np.array(train_features)
    val_features = np.array(val_features)
    train_labels = np.array(train_labels)
    val_labels = np.array(val_labels)
    weights = np.array(weights)

    np.save(f'../Arrays/{dataset_name}_train_features_hvg_flat.npy', train_features)
    np.save(f'../Arrays/{dataset_name}_val_features_hvg_flat.npy', val_features)
    np.save(f'../Arrays/{dataset_name}_train_labels_hvg_flat.npy', train_labels)
    np.save(f'../Arrays/{dataset_name}_val_labels_hvg_flat.npy', val_labels)
    np.save(f'../Arrays/{dataset_name}_weights_hvg_flat.npy', weights)

    anndata_hvg.write(f"../Data/{dataset_name}_train_anndata_hvg_flat.h5ad")

In [None]:
seed = 6296
split = 0.2
# flat_classification("controls_final", controls_final_anndata, split, seed)
# flat_classification("granulomas_final", granulomas_final_anndata, split, seed)
flat_classification("sc78_final", sc78_final_anndata, split, seed)
flat_classification("sc92_final", sc92_final_anndata, split, seed)
flat_classification("sc93_final", sc93_final_anndata, split, seed)

In [3]:
class Node:
    def __init__(self, name, model, index_to_child, children):
        self.name = name
        self.model = model
        self.index_to_child = index_to_child
        self.children = children

    def predict(self, gene):
        
        if (not self.children): return self.name

        if ((self.model == None) or (len(self.children) == 1)):
            child_node = next(iter(self.children.values()))
            return child_node.predict(gene)

        logits = self.model.predict(np.array([gene]), verbose=0)
        max_index = np.argmax(logits, axis=1)[0]
        child_name = self.index_to_child[max_index]

        return self.children[child_name].predict(gene)

In [4]:
def create_tree(dataset, hierarchy_dict):
    def recurse(node_name, subtree):

        name = create_name(node_name)
        model_path = f"../Models/{dataset}_hvg_{name}_jax_v1.keras"
        model = None
        
        if (os.path.exists(model_path)):
            model = ks.models.load_model(model_path, custom_objects={'LeakyReLU': ks.layers.LeakyReLU}, compile=False)
            
        index_to_child = None
        if subtree:
            children = list(subtree)
            int_mapping = f"../Data/{dataset}_int_mapping_{name}.json"

            if (os.path.exists(int_mapping)):
                with open(int_mapping) as file:
                    child_to_index = json.load(file)

                index_to_child = {int(value): key for key, value in child_to_index.items()}

            else: index_to_child = {i: child for i, child in enumerate(children)}
            child_nodes = {child: recurse(child, subtree[child]) for child in children}
        
        else: child_nodes = {}
        return Node(node_name, model, index_to_child, child_nodes)
    
    return recurse("top_level", hierarchy_dict)

In [None]:
def hierarchical_predictions(dataset, hierarchy_dict):
    test_anndata = sc.read_h5ad(f"../Data/{dataset}_test_anndata.h5ad")
    test_anndata_hvg = test_anndata[:, test_anndata.var['highly_variable'] ].copy()

    root = create_tree(dataset, hierarchy_dict)

    X = test_anndata_hvg.X

    if scipy.sparse.issparse(X):
        X = X.toarray()

    predictions = []
    for gene in X:
        predictions.append(root.predict(gene))

    test_anndata_hvg.obs['predicted_cell_type'] = predictions
    true = test_anndata_hvg.obs['single_cell_types'].values

    per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())
    
    overall = {
        'accuracy': metrics.accuracy_score(true, predictions),
        'balanced_accuracy': metrics.balanced_accuracy_score(true, predictions),
        'precision': metrics.precision_score(true, predictions, average='macro', zero_division=0),
        'recall': metrics.recall_score(true, predictions, average='macro', zero_division=0),
        'f1_score': metrics.f1_score(true, predictions, average='macro', zero_division=0),
        'AUPRC': 'N/A'
    }

    class_report = metrics.classification_report(true, predictions, digits=2, zero_division=0, output_dict=True)

    class_report_df = (pd.DataFrame(class_report).transpose().rename_axis("label").reset_index())

    df = pd.DataFrame(X, columns=test_anndata_hvg.var.index.to_list())
    df.insert(0, 'predicted_cell', predictions)
    df.insert(0, 'true_cell', true)

    xlsx_path = f"Results/{dataset}_hierarchical_classification.xlsx"
    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        pd.Series(overall, name="value").to_frame().to_excel(writer, sheet_name="overall_metrics")
        pd.Series(per_class_accuracy, name="accuracy").to_frame().to_excel(writer, sheet_name="per_cell_accuracy")
        class_report_df.to_excel(writer, sheet_name="classification_report", index=False)
        df.to_excel(writer, sheet_name="data", index=False)

    print(f"Excel written to `{xlsx_path}`")

    test_anndata_hvg.write(f"../Data/{dataset}_test_anndata_hierarchy_predictions.h5ad")

In [None]:
def load_hierarchy_dict(path):
    with open(path, "r") as file:
        return json.load(file)

In [None]:
# controls_final_hierarchy_dict = load_hierarchy_dict("../Data/controls_final_hierarchy_dict.json")
# granulomas_final_hierarchy_dict = load_hierarchy_dict("../Data/granulomas_final_hierarchy_dict.json")
sc78_final_hierarchy_dict = load_hierarchy_dict("../Data/sc78_final_hierarchy_dict.json")
sc92_final_hierarchy_dict = load_hierarchy_dict("../Data/sc92_final_hierarchy_dict.json")
sc93_final_hierarchy_dict = load_hierarchy_dict("../Data/sc93_final_hierarchy_dict.json")

In [None]:
# hierarchical_predictions("controls_final", controls_final_hierarchy_dict)
# hierarchical_predictions("granulomas_final", granulomas_final_hierarchy_dict)
hierarchical_predictions("sc78_final", sc78_final_hierarchy_dict)
hierarchical_predictions("sc92_final", sc92_final_hierarchy_dict)
hierarchical_predictions("sc93_final", sc93_final_hierarchy_dict)

In [None]:
def flat_prediction(dataset):
    test_anndata = sc.read_h5ad(f"../Data/{dataset}_test_anndata.h5ad")
    test_anndata_hvg = test_anndata[:, test_anndata.var['highly_variable'] ].copy()

    model_path = f"../Models/{dataset}_hvg_flat_jax_v1.keras"
    model = ks.models.load_model(model_path, custom_objects={'LeakyReLU': ks.layers.LeakyReLU}, compile=False)

    X = test_anndata_hvg.X

    if scipy.sparse.issparse(X):
        X = X.toarray()

    logits = model.predict(X)
    max_indices = np.argmax(logits, axis=1)

    path = f"../Data/{dataset}_int_mapping_flat.json"
    with open(path) as file:
        int_mapping = json.load(file)

    inverse_dict = {i: j for j, i in int_mapping.items()}
    predictions = [inverse_dict[i] for i in max_indices]

    test_anndata_hvg.obs['predicted_cell_type'] = predictions
    true = test_anndata_hvg.obs['single_cell_types'].values

    per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())
    
    overall = {
        'accuracy': metrics.accuracy_score(true, predictions),
        'balanced_accuracy': metrics.balanced_accuracy_score(true, predictions),
        'precision': metrics.precision_score(true, predictions, average='macro', zero_division=0),
        'recall': metrics.recall_score(true, predictions, average='macro', zero_division=0),
        'f1_score': metrics.f1_score(true, predictions, average='macro', zero_division=0),
        'average_precision': metrics.average_precision_score(true, logits, average='macro')
    }

    class_report = metrics.classification_report(true, predictions, digits=2, zero_division=0, output_dict=True)

    class_report_df = (pd.DataFrame(class_report).transpose().rename_axis("label").reset_index())

    df = pd.DataFrame(X, columns=test_anndata_hvg.var.index.to_list())
    df.insert(0, 'predicted_cell', predictions)
    df.insert(0, 'true_cell', true)

    xlsx_path = f"Results/{dataset}_flat_classification.xlsx"
    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        pd.Series(overall, name="value").to_frame().to_excel(writer, sheet_name="overall_metrics")
        pd.Series(per_class_accuracy, name="accuracy").to_frame().to_excel(writer, sheet_name="per_cell_accuracy")
        class_report_df.to_excel(writer, sheet_name="classification_report", index=False)
        df.to_excel(writer, sheet_name="data", index=False)

    print(f"Excel written to `{xlsx_path}`")

    test_anndata_hvg.write(f"../Data/{dataset}_test_anndata_flat_predictions.h5ad")

In [None]:
# flat_prediction("controls_final")
# flat_prediction("granulomas_final")
flat_prediction("sc78_final")
flat_prediction("sc92_final")
flat_prediction("sc93_final")

In [5]:
# ==== Cross-dataset label transfer: granulomas_30 -> granulomas_60/90 ====
# Assumes your earlier imports are present (numpy, scanpy as sc, pandas, sklearn.metrics, keras as ks, etc.)
# and that your preprocessing pipeline produced the usual files in ../Data and ../Models.

import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import scipy.sparse as sp
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import keras as ks
import sklearn.metrics as metrics
import pandas as pd
import re
import json
import os

SOURCE = "granulomas_final"         # 30-day dataset (source models and feature order)
TARGETS = ["sc92_final", "sc93_final"]  # 60-day, 90-day
RESULTS_DIR = "Results"

os.makedirs(RESULTS_DIR, exist_ok=True)

# ------------------------------
# Utilities
# ------------------------------
def _load_source_hvgs(source: str) -> pd.Index:
    """
    Load the *exact* HVG order used to train the SOURCE models.
    Prefer the flat training HVG file (created by flat_classification),
    fall back to the base train anndata HVG if needed.
    """
    flat_hvg_path = f"../Data/{source}_train_anndata_hvg_flat.h5ad"
    if os.path.exists(flat_hvg_path):
        ad = sc.read_h5ad(flat_hvg_path)
        return ad.var_names.copy()
    # # Fallback: use the train set and keep its HVG order
    # base_train_path = f"../Data/{source}_train_anndata.h5ad"
    # if not os.path.exists(base_train_path):
    #     raise FileNotFoundError(
    #         f"Could not find {flat_hvg_path} nor {base_train_path}. "
    #         "Please generate the source HVG training file first."
    #     )
    # ad = sc.read_h5ad(base_train_path)
    # hvgs = ad.var.index[ad.var['highly_variable']].copy()
    # return hvgs

def _load_source_flat_model(source: str):
    mpath = f"../Models/{source}_hvg_flat_jax_v1.keras"
    if not os.path.exists(mpath):
        return None
    return ks.models.load_model(mpath, custom_objects={'LeakyReLU': ks.layers.LeakyReLU}, compile=False)

def _load_source_hierarchy_dict(source: str):
    jpath = f"../Data/{source}_hierarchy_dict.json"
    if not os.path.exists(jpath):
        return None
    with open(jpath, "r") as f:
        return json.load(f)

def _load_source_flat_intmap(source: str):
    jpath = f"../Data/{source}_int_mapping_flat.json"
    if not os.path.exists(jpath):
        raise FileNotFoundError(f"Missing int mapping for source flat model: {jpath}")
    with open(jpath, "r") as f:
        return json.load(f)  # {cell_type: int}

def _align_to_source_genes(target_ad_full: sc.AnnData, source_hvgs: pd.Index) -> np.ndarray:
    """
    Create a dense (float32) array of shape (n_obs_target, len(source_hvgs)),
    where columns are in the *exact* order of source_hvgs.

    We slice from the *full* target matrix (not just its HVGs), so if a source gene
    exists in target.var_names (even if not HVG there), we use it. Otherwise we pad zeros.
    """
    n = target_ad_full.n_obs
    m = len(source_hvgs)
    # Use normalized/log1p X (the same space your models trained on).
    X_full = target_ad_full.X
    # if not sp.issparse(X_full) and not isinstance(X_full, np.ndarray):
    #     # AnnData may hold a backed array-like; force to CSR then work from there
    #     X_full = sp.csr_matrix(X_full)

    # Map target gene name -> column index in full matrix
    tgt_varnames = target_ad_full.var_names
    tgt_idx = {g: j for j, g in enumerate(tgt_varnames)}

    present_positions_in_source = []
    present_indices_in_target = []
    for pos, g in enumerate(source_hvgs):
        j = tgt_idx.get(g, None)
        if j is not None:
            present_positions_in_source.append(pos)
            present_indices_in_target.append(j)

    # Slice only present genes from sparse matrix -> dense
    if len(present_indices_in_target) > 0:
        X_present = X_full[:, present_indices_in_target]
        if sp.issparse(X_present):
            X_present = X_present.toarray()
        else:
            X_present = np.asarray(X_present)
    else:
        # None present? (extremely unlikely)
        X_present = np.zeros((n, 0), dtype=np.float32)

    # Allocate final and place present columns
    X_aligned = np.zeros((n, m), dtype=np.float32)
    if len(present_positions_in_source) > 0:
        X_aligned[:, present_positions_in_source] = X_present.astype(np.float32, copy=False)

    # Diagnostics
    print(f"[Gene alignment] Target: {target_ad_full.shape[0]} cells; "
          f"Source HVGs: {m}. Present in target: {len(present_indices_in_target)}. "
          f"Padded zeros: {m - len(present_indices_in_target)}.")
    return X_aligned

def _filter_unknown_celltypes(target_ad: sc.AnnData, allowed_types: set) -> sc.AnnData:
    """
    Keep only cells whose 'single_cell_types' are in allowed_types.
    Returns a *copy* to avoid chained assignment issues.
    """
    if 'single_cell_types' not in target_ad.obs.columns:
        raise KeyError("target AnnData missing obs['single_cell_types']")
    mask = target_ad.obs['single_cell_types'].isin(allowed_types)
    kept = int(mask.sum())
    dropped = int((~mask).sum())
    print(f"[Cell-type filtering] Kept {kept} cells; dropped {dropped} (unknown to source).")
    return target_ad[mask].copy()

def _save_results_xlsx(base_path: str, overall: dict, per_class_acc: dict, class_report: dict, X=None, true=None, pred=None, genes=None):
    os.makedirs(os.path.dirname(base_path), exist_ok=True)
    class_report_df = pd.DataFrame(class_report).transpose().rename_axis("label").reset_index()
    with pd.ExcelWriter(base_path, engine="openpyxl") as writer:
        pd.Series(overall, name="value").to_frame().to_excel(writer, sheet_name="overall_metrics")
        pd.Series(per_class_acc, name="accuracy").to_frame().to_excel(writer, sheet_name="per_cell_accuracy")
        class_report_df.to_excel(writer, sheet_name="classification_report", index=False)
        if X is not None and true is not None and pred is not None:
            df = pd.DataFrame(X, columns=list(genes) if genes is not None else None)
            df.insert(0, "predicted_cell", pred)
            df.insert(0, "true_cell", true)
            df.to_excel(writer, sheet_name="data", index=False)
    print(f"Excel written to `{base_path}`")

def _compute_metrics(true_labels: np.ndarray, pred_labels: np.ndarray):
    per_class_accuracy = (pd.Series(pred_labels == true_labels)
                            .groupby(pd.Series(true_labels))
                            .mean().to_dict())
    overall = {
        'accuracy': metrics.accuracy_score(true_labels, pred_labels),
        'balanced_accuracy': metrics.balanced_accuracy_score(true_labels, pred_labels),
        'precision': metrics.precision_score(true_labels, pred_labels, average='macro', zero_division=0),
        'recall': metrics.recall_score(true_labels, pred_labels, average='macro', zero_division=0),
        'f1_score': metrics.f1_score(true_labels, pred_labels, average='macro', zero_division=0),
        'AUPRC': 'N/A'
    }
    class_report = metrics.classification_report(true_labels, pred_labels, digits=2, zero_division=0, output_dict=True)
    return overall, per_class_accuracy, class_report

# ------------------------------
# Hierarchical prediction support (reuse your Node/create_tree with source dataset)
# ------------------------------
# class Node:
#     def __init__(self, name, model, index_to_child, children):
#         self.name = name
#         self.model = model
#         self.index_to_child = index_to_child
#         self.children = children

#     def predict_one(self, gene_vec: np.ndarray):
#         if not self.children:
#             return self.name
#         if (self.model is None) or (len(self.children) == 1):
#             child_node = next(iter(self.children.values()))
#             return child_node.predict_one(gene_vec)
#         logits = self.model.predict(np.array([gene_vec], dtype=np.float32), verbose=0)
#         max_index = int(np.argmax(logits, axis=1)[0])
#         child_name = self.index_to_child[max_index]
#         return self.children[child_name].predict_one(gene_vec)

# def _create_tree_for_source(source: str, hierarchy_dict: dict) -> Node:
#     def recurse(node_name, subtree):
#         name = re_sub_safe(node_name)
#         model_path = f"../Models/{source}_hvg_{name}_jax_v1.keras"
#         model = None
#         if os.path.exists(model_path):
#             model = ks.models.load_model(model_path, custom_objects={'LeakyReLU': ks.layers.LeakyReLU}, compile=False)

#         if subtree:
#             children = list(subtree)
#             int_map_path = f"../Data/{source}_int_mapping_{name}.json"
#             if os.path.exists(int_map_path):
#                 with open(int_map_path) as f:
#                     child_to_index = json.load(f)  # {child_name: int}
#                 index_to_child = {int(v): k for k, v in child_to_index.items()}
#             else:
#                 index_to_child = {i: child for i, child in enumerate(children)}
#             child_nodes = {child: recurse(child, subtree[child]) for child in children}
#         else:
#             index_to_child = None
#             child_nodes = {}
#         return Node(node_name, model, index_to_child, child_nodes)
#     # helper: mirror your create_name
#     import re as _re
#     def re_sub_safe(s): return _re.sub(r"[^A-Za-z0-9]+", "_", s).strip("_").lower()
#     return recurse("top_level", hierarchy_dict)

# ------------------------------
# Cross-dataset transfer: FLAT
# ------------------------------
def cross_dataset_flat_transfer(source: str, target: str):
    print(f"\n=== FLAT transfer {source} -> {target} ===")

    # Load source artifacts
    int_map = _load_source_flat_intmap(source)                # {cell_type: int}
    allowed_types = set(int_map.keys())
    inverse_int = {v: k for k, v in int_map.items()}

    source_hvgs = _load_source_hvgs(source)
    model = _load_source_flat_model(source)
    if model is None:
        print(f"[SKIP] Flat model not found for source: ../Models/{source}_hvg_flat_jax_v1.keras")
        return

    # # Load target full test set (all genes), then filter and align
    tgt_test_path = f"../Data/{target}_test_anndata.h5ad"
    if not os.path.exists(tgt_test_path):
        raise FileNotFoundError(f"Missing target test AnnData: {tgt_test_path}")
    ad_tgt = sc.read_h5ad(tgt_test_path)

    # # Filter out unknown cell types
    ad_tgt = _filter_unknown_celltypes(ad_tgt, allowed_types)
    if ad_tgt.n_obs == 0:
        print("[WARN] No cells left after filtering; nothing to evaluate.")
        return

    # # Align to source HVGs (use full var_names, not only HVGs)
    X_aligned = _align_to_source_genes(ad_tgt, source_hvgs)  # dense float32

    # # Predict
    logits = model.predict(X_aligned, verbose=0)
    max_idx = np.argmax(logits, axis=1)
    preds = [inverse_int[int(i)] for i in max_idx]
    true = ad_tgt.obs['single_cell_types'].values

    # # Metrics + save
    overall, per_class_acc, class_report = _compute_metrics(true, np.array(preds, dtype=object))
    base = os.path.join(RESULTS_DIR, f"{source}_to_{target}_flat.xlsx")
    _save_results_xlsx(base, overall, per_class_acc, class_report, X=X_aligned, true=true, pred=preds, genes=source_hvgs)

    # # Save a copy of the target AnnData with predictions for downstream exploration
    # ad_out = ad_tgt.copy()
    # ad_out.X = X_aligned  # aligned feature space
    # ad_out.var = pd.DataFrame(index=source_hvgs)  # reflect aligned genes
    # ad_out.obs['predicted_cell_type'] = preds
    # out_h5ad = f"../Data/{target}_test_anndata_flat_from_{source}_predictions.h5ad"
    # ad_out.write(out_h5ad)
    # print(f"H5AD with predictions written to `{out_h5ad}`")

# ------------------------------
# Cross-dataset transfer: HIERARCHICAL
# ------------------------------
def cross_dataset_hier_transfer(source: str, target: str):
    print(f"\n=== HIERARCHICAL transfer {source} -> {target} ===")

    # Load hierarchy dict and build tree
    hierarchy = _load_source_hierarchy_dict(source)
    if hierarchy is None:
        print(f"[SKIP] Hierarchy dict not found for source: ../Data/{source}_hierarchy_dict.json")
        return
    root = create_tree(source, hierarchy)

    # Allowed labels are the leaves of the source hierarchy (finest level)
    allowed_types = set(get_leaves(hierarchy))

    # Load source HVGs for alignment
    source_hvgs = _load_source_hvgs(source)

    # Load target test (full), filter, align
    tgt_test_path = f"../Data/{target}_test_anndata.h5ad"
    if not os.path.exists(tgt_test_path):
        raise FileNotFoundError(f"Missing target test AnnData: {tgt_test_path}")
    ad_tgt = sc.read_h5ad(tgt_test_path)
    ad_tgt = _filter_unknown_celltypes(ad_tgt, allowed_types)
    if ad_tgt.n_obs == 0:
        print("[WARN] No cells left after filtering; nothing to evaluate.")
        return

    X_aligned = _align_to_source_genes(ad_tgt, source_hvgs)

    # Predict by traversing the tree
    preds = []
    for i in range(X_aligned.shape[0]):
        preds.append(root.predict(X_aligned[i, :]))
    true = ad_tgt.obs['single_cell_types'].values

    # Metrics + save
    overall, per_class_acc, class_report = _compute_metrics(true, np.array(preds, dtype=object))
    base = os.path.join(RESULTS_DIR, f"{source}_to_{target}_hierarchical.xlsx")
    _save_results_xlsx(base, overall, per_class_acc, class_report,X=X_aligned, true=true, pred=preds, genes=source_hvgs)

    # Save H5AD with predictions
    # ad_out = ad_tgt.copy()
    # ad_out.X = X_aligned
    # ad_out.var = pd.DataFrame(index=source_hvgs)
    # ad_out.obs['predicted_cell_type'] = preds
    # out_h5ad = f"../Data/{target}_test_anndata_hier_from_{source}_predictions.h5ad"
    # ad_out.write(out_h5ad)
    # print(f"H5AD with predictions written to `{out_h5ad}`")

# ------------------------------
# Run transfers for requested targets
# ------------------------------
source_hvgs_preview = _load_source_hvgs(SOURCE)
print(f"[Source HVG] {SOURCE}: {len(source_hvgs_preview)} genes")

for tgt in TARGETS:
    # Flat transfer
    cross_dataset_flat_transfer(SOURCE, tgt)
    # Hierarchical transfer (if artifacts exist)
    cross_dataset_hier_transfer(SOURCE, tgt)

print("\nDone.")




[Source HVG] granulomas_final: 3475 genes

=== FLAT transfer granulomas_final -> sc92_final ===
[Cell-type filtering] Kept 6094 cells; dropped 291 (unknown to source).
[Gene alignment] Target: 6094 cells; Source HVGs: 3475. Present in target: 3475. Padded zeros: 0.


  .groupby(pd.Series(true_labels))


Excel written to `Results\granulomas_final_to_sc92_final_flat.xlsx`

=== HIERARCHICAL transfer granulomas_final -> sc92_final ===
[Cell-type filtering] Kept 6094 cells; dropped 291 (unknown to source).
[Gene alignment] Target: 6094 cells; Source HVGs: 3475. Present in target: 3475. Padded zeros: 0.


  .groupby(pd.Series(true_labels))


Excel written to `Results\granulomas_final_to_sc92_final_hierarchical.xlsx`

=== FLAT transfer granulomas_final -> sc93_final ===
[Cell-type filtering] Kept 4147 cells; dropped 0 (unknown to source).
[Gene alignment] Target: 4147 cells; Source HVGs: 3475. Present in target: 3475. Padded zeros: 0.


  .groupby(pd.Series(true_labels))


Excel written to `Results\granulomas_final_to_sc93_final_flat.xlsx`

=== HIERARCHICAL transfer granulomas_final -> sc93_final ===
[Cell-type filtering] Kept 4147 cells; dropped 0 (unknown to source).
[Gene alignment] Target: 4147 cells; Source HVGs: 3475. Present in target: 3475. Padded zeros: 0.


  .groupby(pd.Series(true_labels))


Excel written to `Results\granulomas_final_to_sc93_final_hierarchical.xlsx`

Done.
