In [1]:
%config Completer.use_jedi = False

import sys
import os
import re
import json
import pickle
from collections import defaultdict, Counter, OrderedDict
from datetime import datetime

import h5py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from scipy.sparse import csc_matrix, csr_matrix

from glmnet import LogitNet

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [2]:
import scanpy as sc
sc.settings.set_figure_params(dpi=120, facecolor='white')

# Location of the local smartSeq helper repo (provides lib.reference). Override it
# with the SMARTSEQ_REPO environment variable instead of editing this cell; the
# default is the previously hard-coded path.
SMARTSEQ_REPO = os.environ.get("SMARTSEQ_REPO", "/home/liwang/repo/scRNAseq/smartSeq/")
sys.path.append(SMARTSEQ_REPO)
from lib.reference import get_gene_anno

In [3]:
def read_csc_matrix_h5(h5_file):
    """Read a CSC sparse matrix stored component-wise in an HDF5 file.

    The file must contain the datasets ``data``, ``indices``, ``indptr`` and
    ``shape`` (the components of a ``scipy.sparse.csc_matrix``).

    Args:
        h5_file: path to the HDF5 file.

    Returns:
        scipy.sparse.csc_matrix assembled from those datasets.
    """
    with h5py.File(h5_file, "r") as h5in:
        # Materialize every dataset while the file is open rather than handing
        # lazy h5py Dataset objects to scipy for implicit conversion.
        data = h5in["data"][()]
        indices = h5in["indices"][()]
        indptr = h5in["indptr"][()]
        shape = tuple(int(dim) for dim in h5in["shape"][()])

    return csc_matrix((data, indices, indptr), shape=shape)

def sum_sparse_matrix(matrix, axis=0):
    """Sum a sparse matrix along ``axis`` and return a flat 1-D ndarray.

    Args:
        matrix: scipy sparse matrix.
        axis: 0 for column sums, 1 for row sums, None for the grand total
            (returned as a length-1 array).

    Returns:
        1-D numpy array of sums.
    """
    # For axis 0/1 the sum is 2-D (1 x n or n x 1); for axis=None it is a scalar.
    # ravel() flattens every case. The previous reshape used np.prod(shape), which
    # is the float 1.0 for a 0-d result, so axis=None raised TypeError.
    return np.asarray(matrix.sum(axis=axis)).ravel()


def normalize_matrix(matrix, use_median=False, target_sum=1000):
    """Depth-normalize each barcode (column) and apply log2(1 + x).

    Args:
        matrix: scipy sparse matrix whose columns are barcodes (per-barcode
            totals are the column sums).
        use_median: scale every barcode to the median barcode total instead of
            to ``target_sum``.
        target_sum: total each barcode is scaled to when ``use_median`` is False.

    Returns:
        A new float64 sparse matrix (CSR input stays CSR; other formats are
        returned as CSC). The input matrix is not modified.
    """
    # Total counts per barcode as a flat 1-D array.
    counts_per_bc = np.asarray(matrix.sum(axis=0)).ravel()
    # Barcodes with fewer than 1 count are treated as 1 to avoid division by zero.
    denom = np.clip(counts_per_bc, 1.0, None)
    if use_median:
        # Median depth, floored at 1 so an all-empty matrix does not scale to 0.
        target = max(1.0, np.median(counts_per_bc))
    else:
        target = target_sum
    scaling_factors = target / denom

    # Scale each column by its factor on a float64 copy. This replaces the
    # previous call to sklearn's sparsefuncs, which was never imported.
    m = matrix.astype(np.float64, copy=True)
    if m.format == "csr":
        # CSR: each stored value's column index selects its factor.
        m.data *= scaling_factors.take(m.indices)
    else:
        # CSC (other formats converted): repeat each column's factor once per
        # stored value in that column.
        m = m.tocsc()
        m.data *= np.repeat(scaling_factors, np.diff(m.indptr))

    # Log-transform the stored (non-zero) values; implicit zeros stay zero.
    m.data = np.log2(1 + m.data)

    return m


## Load data

In [4]:
# Per-cell metadata. Donors P1-P4 form one batch; all other donors form the second.
BATCH1_DONORS = ["P1", "P2", "P3", "P4"]

df_cell_meta = pd.read_csv("data/tcell_metadata.csv")
df_cell_meta["batch"] = df_cell_meta["donor"].map(
    lambda donor: "donor1234" if donor in BATCH1_DONORS else "donor5678"
)
df_cell_meta.head()


Unnamed: 0,Barcode,lane,cell_bc,nCount_ADT,nFeature_ADT,nCount_RNA,nFeature_RNA,celltype.l1,celltype.l2,celltype.l3,donor,time,Phase,umap1,umap2,batch
0,L1_AAACCCAAGACATACA,L1,AAACCCAAGACATACA-9,5949,211,5864,1617,CD4 T,CD4 TCM,CD4 TCM_1,P1,7,G1,5.28692,5.635788,donor1234
1,L1_AAACCCACAACTGGTT,L1,AAACCCACAACTGGTT-9,6547,217,5067,1381,CD8 T,CD8 Naive,CD8 Naive,P4,3,S,11.907538,-4.530682,donor1234
2,L1_AAACCCACACGTACTA,L1,AAACCCACACGTACTA-9,3508,207,4786,1890,NK,NK,NK_2,P3,7,G1,2.371758,-8.360968,donor1234
3,L1_AAACCCACAGCATACT,L1,AAACCCACAGCATACT-9,6318,219,6505,1621,CD8 T,CD8 Naive,CD8 Naive,P4,7,G1,12.371468,-5.079568,donor1234
4,L1_AAACCCACATCAGTCA,L1,AAACCCACATCAGTCA-9,5195,213,4332,1633,CD8 T,CD8 TEM,CD8 TEM_1,P3,3,G1,5.819462,-3.657757,donor1234


In [5]:
# Cell Ranger filtered feature-barcode matrix. gex_only=False keeps every feature
# type, not only gene expression; gene-expression features are selected later.
FBM_H5 = "/home/liwang/yard/dataset/seurat_v4_cite-seq/cellranger_reanalyze/fbm.h5"
mat = sc.read_10x_h5(FBM_H5, gex_only=False)

Variable names are not unique. To make them unique, call `.var_names_make_unique`.


In [11]:
# Feature annotation table (gene_ids, feature_types, genome)
mat.var.head(n=5)

Unnamed: 0,gene_ids,feature_types,genome
AL627309.1,ENSG00000238009,Gene Expression,GRCh38
AL669831.5,AL669831.5,Gene Expression,GRCh38
LINC00115,ENSG00000225880,Gene Expression,GRCh38
FAM41C,ENSG00000230368,Gene Expression,GRCh38
NOC2L,ENSG00000188976,Gene Expression,GRCh38


#### Select gene features

In [39]:
# Gene annotation used by the Smart-seq counter (its "name" column is used below)
gene_anno = get_gene_anno()

# Seurat T-cell genes. `with` closes the file handle, and blank lines are skipped.
with open("data/CD4CD8_genes.txt") as gene_file:
    seurat_tcell_genes = [line.strip() for line in gene_file if line.strip()]

# Cell Ranger gene-expression features
is_gex = mat.var["feature_types"] == "Gene Expression"
cr_genes = mat.var_names[is_gex]

# Genes present in both references, and the Seurat T-cell subset of those
common_genes = sorted(set(gene_anno["name"]).intersection(cr_genes))
selected_genes = sorted(set(seurat_tcell_genes).intersection(common_genes))

# Drop "CD8 Naive_2" cells, keep gene-expression features, then the common genes
keep_bcs = df_cell_meta.loc[df_cell_meta["celltype.l3"] != "CD8 Naive_2", "cell_bc"]
mat_gex = mat[keep_bcs, is_gex]
mat_gex = mat_gex[:, common_genes]
mat.shape, mat_gex.shape

((87918, 20957), (87628, 20177))

In [40]:
# Attach per-cell metadata (keyed by barcode) to the expression matrix. Only
# columns not already present are joined, so re-running this cell does not
# fail with overlapping column names.
cell_meta_by_bc = df_cell_meta.set_index("cell_bc")
new_meta_cols = cell_meta_by_bc.columns.difference(mat_gex.obs.columns, sort=False)
mat_gex.obs = mat_gex.obs.join(cell_meta_by_bc[new_meta_cols])
mat_gex.obs.head()

Unnamed: 0,Barcode,lane,nCount_ADT,nFeature_ADT,nCount_RNA,nFeature_RNA,celltype.l1,celltype.l2,celltype.l3,donor,time,Phase,umap1,umap2,batch
AAACCCAAGACATACA-9,L1_AAACCCAAGACATACA,L1,5949,211,5864,1617,CD4 T,CD4 TCM,CD4 TCM_1,P1,7,G1,5.28692,5.635788,donor1234
AAACCCACAACTGGTT-9,L1_AAACCCACAACTGGTT,L1,6547,217,5067,1381,CD8 T,CD8 Naive,CD8 Naive,P4,3,S,11.907538,-4.530682,donor1234
AAACCCACACGTACTA-9,L1_AAACCCACACGTACTA,L1,3508,207,4786,1890,NK,NK,NK_2,P3,7,G1,2.371758,-8.360968,donor1234
AAACCCACAGCATACT-9,L1_AAACCCACAGCATACT,L1,6318,219,6505,1621,CD8 T,CD8 Naive,CD8 Naive,P4,7,G1,12.371468,-5.079568,donor1234
AAACCCACATCAGTCA-9,L1_AAACCCACATCAGTCA,L1,5195,213,4332,1633,CD8 T,CD8 TEM,CD8 TEM_1,P3,3,G1,5.819462,-3.657757,donor1234


In [41]:
%%time
sc.pp.normalize_total(mat_gex, target_sum=10000)
sc.pp.log1p(mat_gex)
sc.pp.scale(mat_gex, max_value=10)

mat_gex_filtered = mat_gex[:, selected_genes]
mat.shape, mat_gex.shape, mat_gex_filtered.shape

CPU times: user 20.6 s, sys: 7.15 s, total: 27.8 s
Wall time: 27.8 s


((87918, 20957), (87628, 20177), (87628, 488))

## Cell type label hierarchy

In [42]:
# Task table: one column per classification task (named "<label0>_<label1>_..."),
# one row per celltype.l3 value. Each entry is the index of the label that cell
# type maps to, or empty when the cell type is not part of the task.
df_hierarchy = pd.read_csv(
    "data/celltypes_hierarchy.csv",
    index_col="celltype",
)
df_hierarchy.head(n=5)

Unnamed: 0_level_0,nonProlif_Proliferating,nonNK_NK,CD4CD8_Treg_MAIT,Naive_Eff_Memory,CD4Naive_CD8Naive,CD4TEM_CD8TEM_CD4CTL,CD4TCM_CD8TCM
celltype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
CD4 CTL,0.0,0.0,0.0,1.0,,2.0,
CD4 Naive,0.0,0.0,0.0,0.0,0.0,,
CD4 Proliferating,1.0,,,,,,
CD4 TCM_1,0.0,0.0,0.0,2.0,,,0.0
CD4 TCM_2,0.0,0.0,0.0,2.0,,,0.0


In [43]:
class CelltypeDefException(Exception):
    """Raised when the cell-type hierarchy table contains an invalid definition."""


class CellTypeClassifier:
    """One classification task of the cell-type hierarchy.

    Attributes:
        task: task name, e.g. "Naive_Eff_Memory".
        labels: class labels of the task (as passed in; the hierarchy splits
            the task name on "_").
        members: defaultdict label -> set of cell types assigned to that label.
        members_flat: set of every cell type taking part in the task.
        from_label: dict cell type -> label (inverse of ``members``).
    """

    def __init__(self, task, labels, members):
        self.task = task
        self.labels = labels
        self.members, self.members_flat = self.add_members(labels, members)
        self.from_label = self.get_from_labels()

    @staticmethod
    def add_members(labels, member_dict):
        """Group cell types by their assigned label.

        Args:
            labels: list of class labels.
            member_dict: mapping cell type -> label index (NaN/None when the
                cell type is not part of this task).

        Returns:
            (members, members_flat): defaultdict(set) label -> cell types, and
            the set of all assigned cell types.

        Raises:
            CelltypeDefException: if an index falls outside range(len(labels)).
        """
        members = defaultdict(set)
        members_flat = set()
        for member, assign in member_dict.items():
            if pd.isna(assign):
                continue
            idx = int(assign)
            # Without this check, negative indices silently wrap to the end of
            # `labels` and too-large ones raise a bare IndexError.
            if not 0 <= idx < len(labels):
                raise CelltypeDefException(
                    "invalid label index {!r} for cell type {!r}; expected 0..{}".format(
                        assign, member, len(labels) - 1
                    )
                )
            members[labels[idx]].add(member)
            members_flat.add(member)

        return members, members_flat

    def get_members(self, label):
        """Return the cell types assigned to ``label`` (an empty set if none).

        ``.get`` is used so that querying an unknown label does not insert it
        into the ``members`` defaultdict.
        """
        return self.members.get(label, set())

    def get_from_labels(self):
        """Build the inverse mapping cell type -> label from ``self.members``."""
        return {
            member: label
            for label, label_members in self.members.items()
            for member in label_members
        }

    def lookup_label(self, member):
        """Return the label of cell type ``member``; KeyError if not in this task."""
        return self.from_label[member]


class CellTypeHierarchy:
    """All classification tasks defined in a hierarchy CSV.

    The CSV is indexed by a ``celltype`` column. Every other column is a task
    named "<label0>_<label1>_..." whose entries are label indices, or empty when
    the cell type is not part of that task. Tasks keep the CSV column order.
    """

    def __init__(self, hierarchy_csv):
        # hierarchy_csv: path or buffer accepted by pandas.read_csv
        self.src_table = pd.read_csv(hierarchy_csv, index_col="celltype")

        # Attribute name kept (misspelled) for existing callers; see `hierarchy`.
        self.hierachy = OrderedDict()
        for task, members in self.src_table.to_dict().items():
            labels = task.split("_")
            self.hierachy[task] = CellTypeClassifier(task, labels, members)

    @property
    def hierarchy(self):
        """Correctly spelled alias of ``hierachy``."""
        return self.hierachy


In [253]:
# Parse the task table into one CellTypeClassifier per task
cth = CellTypeHierarchy(hierarchy_csv="data/celltypes_hierarchy.csv")

In [254]:
# Task names, in the order the training loop and find_path iterate over them
cth.hierachy.keys()

odict_keys(['nonProlif_Proliferating', 'nonNK_NK', 'CD4CD8_Treg_MAIT', 'Naive_Eff_Memory', 'CD4Naive_CD8Naive', 'CD4TEM_CD8TEM_CD4CTL', 'CD4TCM_CD8TCM', 'CD8TermEff_CD8EffMem', 'CD8TCM1_CD8TCM2_CD8TCM3', 'CD8TEM2_CD8TEM4_CD8TEM5', 'CD8TEM1_CD8TEM3_CD8TEM6'])

In [51]:
# Class labels of one task (its name split on "_")
cth.hierachy["Naive_Eff_Memory"].labels

['Naive', 'Eff', 'Memory']

In [337]:
# Long table of (celltype label, task) pairs: one frame per task, stacked.
task_label_frames = []
for task_name, task_clf in cth.hierachy.items():
    task_label_frames.append(pd.DataFrame({"celltype": task_clf.labels, "task": task_name}))
pd_hierachy = pd.concat(task_label_frames, axis=0)
pd_hierachy.to_csv("results/cls_task_labels.csv", index=False)

## Prepare matrix for fitting

In [220]:
# Cells and labels for a single task, used for the exploratory cells below.
task = "CD4TEM_CD8TEM_CD4CTL"
task_clf = cth.hierachy[task]
labels = task_clf.labels

in_task = mat_gex.obs["celltype.l3"].isin(task_clf.members_flat)
mat_gex_filtered = mat_gex[in_task, selected_genes]

# Map each cell's celltype.l3 onto the task's class label
all_labels = mat_gex_filtered.obs["celltype.l3"].apply(task_clf.lookup_label)
Counter(all_labels), mat_gex_filtered.shape

(Counter({'CD8TEM': 11727, 'CD4CTL': 1736, 'CD4TEM': 4282}), (17745, 488))

#### Down-sample large classes so no single class dominates

In [221]:
def get_banlanced_labels(all_labels, obs=None, max_ratio=5, random_state=0):
    """Down-sample over-represented classes and return the barcodes to keep.

    A class with more than ``max_ratio`` times as many cells as the smallest
    class is down-sampled to about ``max_ratio * smallest`` cells. Sampling is
    stratified by ``obs["celltype.l3"]``, so each fine-grained cell type inside
    the class keeps its relative share. Smaller classes are kept whole.

    Args:
        all_labels: pd.Series of class labels indexed by cell barcode.
        obs: per-cell metadata with a "celltype.l3" column, indexed like
            ``all_labels``. Defaults to the notebook global
            ``mat_gex_filtered.obs`` (the previous implicit behaviour).
        max_ratio: largest allowed class size relative to the smallest class.
        random_state: seed for down-sampling, so the result is reproducible.

    Returns:
        pd.Index of selected barcodes, grouped by class.
    """
    if obs is None:
        # Backward compatibility: older versions always read this global.
        obs = mat_gex_filtered.obs

    count = Counter(all_labels)
    if not count:
        return obs.index[:0]
    cap = max_ratio * min(count.values())

    sampled = []
    # Iterate over the labels actually present. The old code iterated over the
    # global `labels` list and silently dropped any class missing from it.
    for label, n_cells in count.items():
        df_sample = obs.loc[all_labels == label, :]

        if n_cells > cap:
            df_sample = df_sample.groupby("celltype.l3", observed=True).sample(
                frac=cap / n_cells, random_state=random_state
            )

        sampled.append(df_sample)

    return pd.concat(sampled, axis=0).index


def rebanlance_data(data_mat, labels, max_ratio=5, random_state=0):
    """Return ``data_mat`` and ``labels`` restricted to a class-balanced subset.

    Args:
        data_mat: AnnData-like object with ``.obs`` and ``[rows, cols]``
            indexing by obs name.
        labels: pd.Series of class labels indexed by the same barcodes.
        max_ratio, random_state: forwarded to ``get_banlanced_labels``.

    Returns:
        (balanced matrix, balanced labels), in the same barcode order.
    """
    keep = get_banlanced_labels(
        labels, obs=data_mat.obs, max_ratio=max_ratio, random_state=random_state
    )
    return data_mat[keep, :], labels.loc[keep]
    

In [222]:
# Number of cells in the current task before balancing
all_labels.shape

(17745,)

In [223]:
# Down-sample the over-represented classes of the current task
mat_gex_balanced, balanced_labels = rebanlance_data(mat_gex_filtered, all_labels)
mat_gex_balanced.shape, balanced_labels.shape

((14698, 488), (14698,))

In [206]:
# Class sizes after balancing
Counter(balanced_labels)

Counter({'CD4TEM': 4282, 'CD8TEM': 8680, 'CD4CTL': 1736})

## Train classifiers

In [237]:
def get_model_nonzero_coef(model, feature_names):
    """Return the coefficients of the features a fitted linear model selected.

    A feature is kept when its coefficient is non-zero for at least one class,
    whatever the sign. The old ``sum > 0`` test dropped negatively weighted
    features and features whose class coefficients cancelled.

    Args:
        model: fitted classifier exposing ``coef_`` (one row per class, or a
            single row for binary models) and ``classes_``.
        feature_names: names of the model's features, in ``coef_`` column order.

    Returns:
        DataFrame indexed by feature name with one column per reported class.
        For binary models, the single row is negated and reported under
        ``classes_[0]``, matching ``get_model_coef``. (Previously the binary
        case raised TypeError because a scalar was passed as ``columns``.)
    """
    coef = np.asarray(model.coef_).T  # (n_features, n_coef_rows)
    if coef.shape[1] > 1:
        columns = list(model.classes_)
    else:
        columns = [model.classes_[0]]
        coef = -coef

    nonzero = (coef != 0).any(axis=1)
    return pd.DataFrame(coef[nonzero], columns=columns, index=pd.Index(feature_names)[nonzero])

In [62]:
# Multi-class
X = mat_gex_filtered.X
y = all_labels


X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

X_train.shape, y_train.shape, X_test.shape, y_test.shape

((38084, 488), (38084,), (25390, 488), (25390,))

In [105]:
%%time

# Fitting (1min 25s)
lambda_path = np.power(10, -np.linspace(1,5,50))
model = LogitNet(alpha=1, standardize=False, random_state=0, lambda_path=lambda_path, n_jobs=10)
model.fit(X_train, y_train)

CPU times: user 9min 27s, sys: 8min, total: 17min 28s
Wall time: 4min 24s


LogitNet(lambda_path=array([1.00000000e-01, 8.28642773e-02, 6.86648845e-02, 5.68986603e-02,
       4.71486636e-02, 3.90693994e-02, 3.23745754e-02, 2.68269580e-02,
       2.22299648e-02, 1.84206997e-02, 1.52641797e-02, 1.26485522e-02,
       1.04811313e-02, 8.68511374e-03, 7.19685673e-03, 5.96362332e-03,
       4.94171336e-03, 4.09491506e-03, 3.39322177e-03, 2.81176870e-03,
       2.32995181e-03, 1.93069...
       5.17947468e-04, 4.29193426e-04, 3.55648031e-04, 2.94705170e-04,
       2.44205309e-04, 2.02358965e-04, 1.67683294e-04, 1.38949549e-04,
       1.15139540e-04, 9.54095476e-05, 7.90604321e-05, 6.55128557e-05,
       5.42867544e-05, 4.49843267e-05, 3.72759372e-05, 3.08884360e-05,
       2.55954792e-05, 2.12095089e-05, 1.75751062e-05, 1.45634848e-05,
       1.20679264e-05, 1.00000000e-05]),
         n_jobs=10, random_state=0, standardize=False)

In [123]:
# Coefficient table of the exploratory model fitted above.
# NOTE(review): the saved output has Eff/Memory/Naive columns, i.e. it comes from
# a different task than the current `task` (CD4TEM_CD8TEM_CD4CTL); re-run in order.
df_feat_logit = get_model_nonzero_coef(model, mat_gex_filtered.var_names)
df_feat_logit

Unnamed: 0,Eff,Memory,Naive
A2M-AS1,0.045319,0.000000,0.000000
ABHD17A,0.020972,0.000000,0.000000
AC243960.1,0.000000,0.000000,0.048518
ACTN1,0.000000,0.000000,0.088092
ADGRE5,0.002184,0.000000,0.000000
...,...,...,...
VIM,0.000000,0.084371,0.000000
XCL1,0.000000,0.001574,0.000000
ZFP36L2,0.068557,0.000000,-0.005143
ZNF683,0.103623,0.000000,0.000000


# Run all classifiers

In [628]:
def get_time_now():
    """Current local wall-clock time as "HH:MM:SS", for progress logging."""
    now = datetime.now()
    return f"{now:%H:%M:%S}"

def get_model_coef(model, feature_names, task=None):
    """Return all coefficients of a fitted model in long format.

    Args:
        model: fitted classifier exposing ``coef_`` and ``classes_``.
        feature_names: feature names, in ``coef_`` column order.
        task: value written to the "task" column. Defaults to the notebook
            global ``task``, which is what the function used to read implicitly.

    Returns:
        DataFrame with columns gene_name, task, celltype, coeff: one row per
        (feature, reported class). Models with at most two classes report only
        ``classes_[0]``, using the negated single coefficient row.
    """
    if task is None:
        # Backward compatibility with callers that rely on the global `task`.
        task = globals()["task"]

    if len(model.classes_) > 2:
        df_coef = pd.DataFrame(model.coef_.T, columns=model.classes_)
    else:
        # Binary: coef_ has a single row, presumably scoring classes_[1];
        # negating it reports the coefficients for classes_[0].
        df_coef = pd.DataFrame({model.classes_[0]: -1 * model.coef_[0]})

    df_coef["gene_name"] = feature_names
    df_coef["task"] = task
    df_coef = df_coef.melt(id_vars=["gene_name", "task"], var_name="celltype", value_name="coeff")

    return df_coef
    

In [None]:
%%time
models = {}
df_coefs = []
for task, ct_obj in cth.hierachy.items():
    
    result = {}
    
    print("[{}] Building classifier for {}...".format(get_time_now(), task))
    mat_gex_filtered = mat_gex[mat_gex.obs["celltype.l3"].isin(cth.hierachy[task].members_flat), selected_genes]
    labels = cth.hierachy[task].labels
    all_labels = mat_gex_filtered.obs["celltype.l3"].apply(cth.hierachy[task].lookup_label)

    print("\t[{}] Rebalancing data...".format(get_time_now()))
    mat_gex_balanced, balanced_labels = rebanlance_data(mat_gex_filtered, all_labels)
    
    print("\t[{}] Preparing traning and testing data...".format(get_time_now()))
    X_train, X_test, y_train, y_test = train_test_split(mat_gex_balanced.X, 
                                                        balanced_labels, 
                                                        test_size=0.3, 
                                                        random_state=0, 
                                                        stratify=balanced_labels)
    
    print("\t[{}] Fitting the model...".format(get_time_now()))
    lambda_path = np.power(10, -np.linspace(1,5,50))
    model = LogitNet(alpha=1, standardize=False, random_state=0, lambda_path=lambda_path, n_jobs=10)
    model.fit(X_train, y_train)
    
    accu =  model.score(X_test, y_test)
    print("\t[{}] {}: {}".format(get_time_now(), task, accu))
    
    print("\t[{}] Extracting feature coeff...".format(get_time_now()))
    df_coef = get_model_coef(model, selected_genes)
    df_coefs.append(df_coef)
        
    print("\t[{}] Done!".format(get_time_now()))
    
    result = {
        "model": model, "accuracy": accu, "label_count": Counter(balanced_labels), 
        #"feat_coeff": df_feat_logit
    }
    models[task] = result

df_coefs = pd.concat(df_coefs, axis=0)

[14:35:27] Building classifier for nonProlif_Proliferating...
	[14:35:27] Rebalancing data...
	[14:35:28] Preparing traning and testing data...
	[14:35:28] Fitting the model...
	[14:35:29] nonProlif_Proliferating: 0.9925650557620818
	[14:35:29] Done!
[14:35:29] Building classifier for nonNK_NK...
	[14:35:29] Rebalancing data...
	[14:35:29] Preparing traning and testing data...
	[14:35:32] Fitting the model...
	[14:37:11] nonNK_NK: 0.9985037406483791
	[14:37:11] Done!
[14:37:11] Building classifier for CD4CD8_Treg_MAIT...
	[14:37:11] Rebalancing data...
	[14:37:11] Preparing traning and testing data...
	[14:37:11] Fitting the model...
	[14:40:01] CD4CD8_Treg_MAIT: 0.9725182277061133
	[14:40:01] Done!
[14:40:01] Building classifier for Naive_Eff_Memory...
	[14:40:01] Rebalancing data...
	[14:40:01] Preparing traning and testing data...
	[14:40:03] Fitting the model...
	[14:44:02] Naive_Eff_Memory: 0.9473297274589088
	[14:44:02] Done!
[14:44:02] Building classifier for CD4Naive_CD8Naive..

In [269]:
# Held-out accuracy and balanced class counts per task. The loop variable is
# deliberately not called `model` (nor `task`): either name would overwrite a
# notebook global that other cells use.
for task_name, content in models.items():
    print("{}\t{}\t{}".format(np.round(content["accuracy"], 4), task_name, content["label_count"]))

0.9926	nonProlif_Proliferating	Counter({'nonProlif': 3735, 'Proliferating': 747})
0.9985	nonNK_NK	Counter({'nonNK': 68765, 'NK': 18116})
0.9725	CD4CD8_Treg_MAIT	Counter({'CD4CD8': 12537, 'MAIT': 2784, 'Treg': 2507})
0.9473	Naive_Eff_Memory	Counter({'Naive': 27957, 'Memory': 17772, 'Eff': 17745})
0.9911	CD4Naive_CD8Naive	Counter({'CD4Naive': 17479, 'CD8Naive': 10478})
0.9816	CD4TEM_CD8TEM_CD4CTL	Counter({'CD8TEM': 8680, 'CD4TEM': 4282, 'CD4CTL': 1736})
0.9819	CD4TCM_CD8TCM	Counter({'CD4TCM': 14415, 'CD8TCM': 2883})
0.9739	CD8TermEff_CD8EffMem	Counter({'CD8TermEff': 7912, 'CD8EffMem': 3815})
0.859	CD8TCM1_CD8TCM2_CD8TCM3	Counter({'CD8TCM2': 1322, 'CD8TCM1': 929, 'CD8TCM3': 632})
0.8707	CD8TEM2_CD8TEM4_CD8TEM5	Counter({'CD8TEM4': 3504, 'CD8TEM2': 2435, 'CD8TEM5': 1973})
0.9588	CD8TEM1_CD8TEM3_CD8TEM6	Counter({'CD8TEM1': 1965, 'CD8TEM6': 636, 'CD8TEM3': 393})


## Save the classifiers and feature coefficients

In [626]:
# Genes with the largest positive coefficients (> 0.2) for the Proliferating class
df_coefs.query("celltype == 'Proliferating' & coeff > 0.2")

Unnamed: 0,gene_name,task,celltype,coeff
102,CLSPN,nonProlif_Proliferating,Proliferating,0.219409
275,MCM4,nonProlif_Proliferating,Proliferating,0.246096
279,MKI67,nonProlif_Proliferating,Proliferating,0.249142
330,PCNA,nonProlif_Proliferating,Proliferating,0.210181
433,STMN1,nonProlif_Proliferating,Proliferating,0.695056


In [None]:
# save the classifier
with open('models/hc_allmodels_seurat_genes.pkl', 'wb') as fid:
    pickle.dump(models, fid)    
    
# save model coeff
df_coefs.to_csv("models/featcoef_allmodels_seurat_genes.csv", index=False)

## Prediction

In [281]:
# PACT476 query expression matrix: genes are rows, cells are columns
pact476_csv = "data/PACT476_filtered_rsem_matrix.csv"
df_mat_pact476 = pd.read_csv(pact476_csv, index_col=0)
df_mat_pact476.shape

(60728, 84)

In [307]:
# Keep the reference's common genes (in the same order) and transpose the
# genes x cells table into the cells x genes layout AnnData expects.
pact476_common = df_mat_pact476.loc[common_genes, :]
mat_pact476 = pact476_common.to_numpy(copy=True).T

df_obs_pact476 = pd.DataFrame({"patient_id": "PACT476"}, index=df_mat_pact476.columns)
df_var_pact476 = pd.DataFrame(index=common_genes)

adata_pact476 = sc.AnnData(
    mat_pact476,
    obs=df_obs_pact476,
    var=df_var_pact476,
)


In [312]:
# Same preprocessing as the reference (normalize to 10k, log1p, scale with
# max_value=10). The steps modify the data in place, so a marker in `.uns`
# makes a re-run of this cell a no-op instead of transforming the data twice.
# NOTE(review): sc.pp.scale here centers and scales with the mean/std of these
# query cells, not with the reference statistics the classifiers were trained
# on. Confirm that this is intended.
if not adata_pact476.uns.get("normalized_log1p_scaled", False):
    sc.pp.normalize_total(adata_pact476, target_sum=10000)
    sc.pp.log1p(adata_pact476)
    sc.pp.scale(adata_pact476, max_value=10)
    adata_pact476.uns["normalized_log1p_scaled"] = True

# Classifier feature genes only
adata_pact476_filtered = adata_pact476[:, selected_genes]
adata_pact476.shape, adata_pact476_filtered.shape

((84, 20177), (84, 488))

In [340]:
%%time
df_prob = []
for task, model_obj in models.items():
    
    prob = pd.DataFrame(model_obj["model"].predict_proba(adata_pact476_filtered.X), 
                        columns=model_obj["model"].classes_, 
                        index=adata_pact476_filtered.obs_names)
    df_prob.append(prob)

df_prob = pd.concat(df_prob, axis=1)
df_prob.head()

CPU times: user 75 ms, sys: 281 ms, total: 356 ms
Wall time: 18.7 ms


Unnamed: 0,Proliferating,nonProlif,NK,nonNK,CD4CD8,MAIT,Treg,Eff,Memory,Naive,...,CD8TermEff,CD8TCM1,CD8TCM2,CD8TCM3,CD8TEM2,CD8TEM4,CD8TEM5,CD8TEM1,CD8TEM3,CD8TEM6
TCR24_T02.P02-C05,0.002261,0.997739,3.6e-05,0.999964,0.882477,0.065418,0.052105,0.521385,0.450793,0.027822,...,0.009165,0.734952,0.101459,0.163589,0.653255,0.297265,0.04948,0.751994,0.096663,0.151342
TCR7_T01.P01-B08,0.0017,0.9983,0.106468,0.893532,0.537771,0.009957,0.452273,0.040037,0.098499,0.861464,...,0.164089,0.252467,0.41964,0.327893,0.634037,0.31044,0.055523,0.373041,0.080416,0.546544
TCR14_T02.P02-H10,0.00167,0.99833,0.001984,0.998016,0.672041,0.289703,0.038256,0.343269,0.34666,0.310072,...,0.024515,0.431873,0.429938,0.13819,0.633261,0.329728,0.037011,0.675094,0.140698,0.184209
TCR19_T02.P02-C04,0.002287,0.997713,0.000318,0.999682,0.894146,0.036383,0.069471,0.09396,0.816659,0.089381,...,0.183001,0.350793,0.559307,0.0899,0.61815,0.303279,0.078571,0.418601,0.093419,0.48798
TCR5_T02.P04-F08,0.020707,0.979293,0.002257,0.997743,0.736118,0.03637,0.227512,0.318503,0.651296,0.030202,...,0.549194,0.337891,0.508925,0.153183,0.386772,0.396742,0.216486,0.246987,0.354985,0.398027


In [322]:
# Disabled export of the per-task class probabilities.
# NOTE(review): with index=False the cell IDs (df_prob's index) would not be
# written; use index_label="cell_id" if this is re-enabled.
# df_prob.to_csv("results/PACT476_predprob_seurat_genes.csv", index=False)

## Output decision tree path

In [450]:
# Parent -> children edges of the cell-type decision tree. Node names must match
# the class labels produced by the per-task classifiers. Nodes with no children
# (an empty set) are terminal predictions. All values are sets; leaves used to
# be written as `{}`, which is an empty dict.
DECISION_TREE = {
    "Proliferating": set(),
    "nonProlif": {"NK", "nonNK"},
    "NK": set(),
    "nonNK": {"CD4CD8", "MAIT", "Treg"},
    "CD4CD8": {"Eff", "Memory", "Naive"},
    "MAIT": set(),
    "Treg": set(),
    "Eff": {"CD4CTL", "CD4TEM", "CD8TEM"},
    "Memory": {"CD4TCM", "CD8TCM"},
    "Naive": {"CD4Naive", "CD8Naive"},
    "CD4Naive": set(),
    "CD8Naive": set(),
    "CD4CTL": set(),
    "CD4TEM": set(),
    "CD8TEM": {"CD8EffMem", "CD8TermEff"},
    "CD4TCM": set(),
    "CD8TCM": {"CD8TCM1", "CD8TCM2", "CD8TCM3"},
    "CD8EffMem": {"CD8TEM1", "CD8TEM3", "CD8TEM6"},
    "CD8TermEff": {"CD8TEM2", "CD8TEM4", "CD8TEM5"},
    "CD8TCM1": set(),
    "CD8TCM2": set(),
    "CD8TCM3": set(),
    "CD8TEM2": set(),
    "CD8TEM4": set(),
    "CD8TEM5": set(),
    "CD8TEM1": set(),
    "CD8TEM3": set(),
    "CD8TEM6": set(),
}
# Every child must itself be a node; otherwise the tree walk raises KeyError.
assert all(
    child in DECISION_TREE for children in DECISION_TREE.values() for child in children
), "DECISION_TREE references an undefined node"

TERMINAL_TYPES = {node for node, children in DECISION_TREE.items() if not children}

In [540]:
def get_candidates(winner, candidates, tree=None):
    """Collect the terminal nodes reachable from ``winner``.

    Args:
        winner: node name in the decision tree.
        candidates: list to which the terminal descendants are appended.
        tree: parent -> children mapping; defaults to ``DECISION_TREE``.

    Returns:
        ``candidates`` with every terminal node under ``winner`` appended. When
        ``winner`` is itself terminal, that is ``winner`` (previously the
        top-level call returned None in this case).
    """
    if tree is None:
        tree = DECISION_TREE
    children = tree[winner]
    if not children:
        candidates.append(winner)
    for child in children:
        get_candidates(child, candidates, tree)
    return candidates


def find_path(score, tasks=None, tree=None):
    """Walk the classifier hierarchy for one cell.

    Tasks are visited in order, and at each task the highest-probability label
    wins. A win is recorded only if it is still compatible with the earlier
    decisions, i.e. its reachable terminals intersect the remaining candidates.
    The candidates are then narrowed to that intersection. A terminal winner
    therefore narrows the candidates to itself, so later tasks can no longer
    override it. Previously the candidates were not narrowed, so NK /
    Proliferating / MAIT / Treg predictions were overwritten further down.

    Args:
        score: pd.Series of class probabilities for one cell, indexed by label.
        tasks: mapping task name -> object with a ``labels`` list, in visiting
            order; defaults to ``cth.hierachy``.
        tree: decision-tree mapping; defaults to ``DECISION_TREE``.

    Returns:
        (path, final): the accepted winners formatted "label (prob)" and joined
        by "-> ", and the terminal prediction (None if none was reached).
    """
    if tasks is None:
        tasks = cth.hierachy
    if tree is None:
        tree = DECISION_TREE
        terminals = TERMINAL_TYPES
    else:
        terminals = {node for node, children in tree.items() if not children}

    candidates = set(terminals)
    path = []
    final = None
    for ct in tasks.values():
        task_score = score[ct.labels]
        winner, prob = task_score.idxmax(), task_score.max()

        reachable = candidates.intersection(get_candidates(winner, [], tree))
        if not reachable:
            # This task's branch was already ruled out by an earlier decision.
            continue

        path.append("{} ({})".format(winner, np.round(prob, 4)))
        candidates = reachable
        if winner in terminals:
            final = winner

    return "-> ".join(path), final
    

In [541]:
# Decision path and terminal prediction for every PACT476 cell
path_results = [find_path(row) for _, row in df_prob.iterrows()]
paths = [cell_path for cell_path, _ in path_results]
finals = [cell_final for _, cell_final in path_results]

df_path = pd.DataFrame({"cell_id": df_prob.index, "pred": finals, "path": paths})
        

In [545]:
# First few cells with their prediction and decision path
df_path.head()

Unnamed: 0,cell_id,pred,path
0,TCR24_T02.P02-C05,CD8TEM1,nonProlif (0.9977)-> nonNK (1.0)-> CD4CD8 (0.8...
1,TCR7_T01.P01-B08,CD8Naive,nonProlif (0.9983)-> nonNK (0.8935)-> CD4CD8 (...
2,TCR14_T02.P02-H10,CD8TCM1,nonProlif (0.9983)-> nonNK (0.998)-> CD4CD8 (0...
3,TCR19_T02.P02-C04,CD8TCM2,nonProlif (0.9977)-> nonNK (0.9997)-> CD4CD8 (...
4,TCR5_T02.P04-F08,CD8TCM2,nonProlif (0.9793)-> nonNK (0.9977)-> CD4CD8 (...


In [544]:
# Export one row per cell: cell_id, terminal prediction and decision path
df_path.to_csv("results/PACT476_predpath_seurat_genes.csv", index=False)