This notebook gives an example of applying the ClinTrajan package (in particular, ElPiGraph) to the analysis of the METABRIC breast cancer dataset. It is based on the notebook: https://github.com/auranic/ClinTrajan/blob/master/ClinTrajan_tutorial.ipynb

ClinTrajan package:
https://github.com/auranic/ClinTrajan 

ElPiGraph package: 
https://github.com/j-bac/elpigraph-python


In [None]:
#pd.set_option('display.max_rows', 500)
#pd.set_option('display.max_columns', 500)
#pd.set_option('display.width', 1000)

# Example of using ClinTrajan

# Part 1. Quantification of Data

### Importing libraries for quantification

In [None]:
# This environment is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

from scipy.stats import mode
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt
import seaborn as sns

from importlib import reload  
import scipy.stats


In [None]:
!pip install lifelines
from lifelines import KaplanMeierFitter
from lifelines.utils import concordance_index
!pip install  --no-dependencies  git+https://github.com/j-bac/elpigraph-python.git
import elpigraph
!pip install trimap

import sys
sys.path.insert(0,'/kaggle/input/breast-cancer-omics-bulk-data/code/')# "/path/to/your/package_or_module")
print(sys.path)

from clintraj_qi import *
from clintraj_eltree import *
from clintraj_util import *
from clintraj_ml import *
from clintraj_optiscale import *

### Loading data (categorical variables are assumed to be dummy-encoded already)

In [None]:
# Load omics data (rows: genes, columns: samples) and transpose to samples x genes
df1 = pd.read_csv('/kaggle/input/breast-cancer-omics-bulk-data/METABRIC.txt', sep = '\t', index_col = 0)
df1 = df1.T
# Strip the dataset prefix so that sample IDs match the 'Patient ID' column of the clinical table
df1.index = [s.replace('BRCA-METABRIC-S1-','') for s in df1.index]

# Load clinical data
df2 = pd.read_csv('/kaggle/input/breast-cancer-omics-bulk-data/METABRIC_clinical.txt', sep = '\t')
df2 = df2.set_index('Patient ID')

# Keep only patients present in both tables
df = df2.join(df1, how = 'inner')
print('Joined data shape', df.shape)

# Keep patients with known relapse status and encode it as 0/1
m = df['Relapse Free Status'].notnull()
print(m.sum())
df = df[m].copy()
df['Relapse Free Status'] = df['Relapse Free Status'].map({'0:Not Recurred':0,'1:Recurred':1})
print(df.shape)
display(df.head())

df_full = df.copy()
df = df.iloc[:,37:] # OMICS data only: the clinical columns of df2 come first in the joined table

X = df.values
X1 = X.copy()
X_columns = df.columns
X

In [None]:
display(df)
quantify_nans(df)

### Keep the 1000 variables with the highest variance

This step is intended for gene expression data: we keep 1000 out of tens of thousands of gene expression variables.


In [None]:
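# Per-gene variance; nanvar ignores missing values (np.argsort puts NaN last,
# so NaN variances would otherwise end up among the "top" genes)
X_var = np.nanvar(X, axis = 0)
print(X_var.shape, X.shape, X_var[:5])
ix = np.argsort(X_var)            # ascending order of variance
X = X[:,ix[-1000:]]               # 1000 highest-variance genes
df = df.iloc[:,ix[-1000:]]
print(X.shape, df.shape)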

### Detecting variable types

We need to know the type of each variable, which can be 'BINARY', 'ORDINAL' or 'CONTINUOUS':
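We assume that the second argument of `detect_variable_type` (10 here) bounds the number of distinct values an ordinal variable may take. The sketch below illustrates a rule of that kind with a hypothetical helper `detect_types_sketch`; it is not ClinTrajan's actual code, and the library's exact thresholds may differ.

```python
import numpy as np
import pandas as pd


def detect_types_sketch(df: pd.DataFrame, max_ordinal_levels: int = 10) -> dict:
    """Hypothetical typing rule (an assumption, not ClinTrajan's code):
    at most 2 distinct values -> BINARY,
    up to max_ordinal_levels distinct values -> ORDINAL,
    otherwise -> CONTINUOUS. Missing values are ignored when counting."""
    types = {}
    for col in df.columns:
        n_levels = df[col].dropna().nunique()
        if n_levels <= 2:
            types[col] = 'BINARY'
        elif n_levels <= max_ordinal_levels:
            types[col] = 'ORDINAL'
        else:
            types[col] = 'CONTINUOUS'
    return types


demo = pd.DataFrame({
    'smoker': [0, 1, 1, 0, np.nan],   # 2 levels
    'grade': [1, 2, 3, 2, 1],          # 3 levels
    'expr': np.linspace(0.0, 1.0, 5),  # 5 levels
})
print(detect_types_sketch(demo, max_ordinal_levels=3))
```

Note that the threshold matters: with the default of 10, the five-level `expr` column would be typed as ordinal.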

In [None]:
variable_types, binary, continuous, ordinal = detect_variable_type(df,10,verbose=False)

In [None]:
print( type(variable_types), type(binary), type(continuous), type(ordinal) ) # all four are lists
variable_types[:3], binary[:3], continuous[:3], ordinal[:3]

### Simplest univariate quantification (can work when there are missing values)

We need to impute missing values in the table, but for this we first need some quantification. A simple univariate quantification can be done with the quantify_dataframe_univariate function.

Note that we save the quantification parameters to a file, so that after imputing the missing values we can use them to restore the data table to the original variable scales.

Now we can impute the missing values. One simple idea is to compute the SVD of the complete part of the data matrix (the rows without missing values) and then project the data points with missing values onto the principal components. The imputed value is the value of the variable at the projection point.
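A minimal NumPy sketch of this idea, assuming a numeric matrix with enough complete rows. The function name `svd_impute_sketch` and the least-squares projection using only the observed entries are our illustration, not the code of ClinTrajan's `SVDcomplete_imputation_method`.

```python
import numpy as np


def svd_impute_sketch(X: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Fill NaNs using principal components computed on the complete rows only.

    For each incomplete row, its coordinates along the first `n_components`
    principal axes are fitted by least squares on the observed entries; the
    missing entries are then read off the resulting point on the PC subspace.
    """
    X = np.asarray(X, dtype=float)
    complete = ~np.isnan(X).any(axis=1)
    mean = X[complete].mean(axis=0)
    # Right singular vectors of the centred complete part are the principal axes
    _, _, vt = np.linalg.svd(X[complete] - mean, full_matrices=False)
    V = vt[:n_components].T                  # shape (n_features, n_components)

    X_imp = X.copy()
    for i in np.where(~complete)[0]:
        obs = ~np.isnan(X[i])
        # Least-squares coordinates of the row in PC space, from observed entries only
        coef, *_ = np.linalg.lstsq(V[obs], X[i, obs] - mean[obs], rcond=None)
        projection = mean + V @ coef         # projection point on the PC subspace
        X_imp[i, ~obs] = projection[~obs]    # observed values are kept unchanged
    return X_imp


# Usage: points on a line in 3D, with two entries removed
t = np.arange(6, dtype=float)
X_true = np.outer(t, [1.0, 2.0, 3.0]) + np.array([0.0, 1.0, 2.0])
X_missing = X_true.copy()
X_missing[4, 2] = np.nan
X_missing[5, 0] = np.nan
print(svd_impute_sketch(X_missing, n_components=1))
```

Because these toy data lie exactly on a one-dimensional line, a single component recovers the removed values exactly; on real data the result depends on the chosen number of components.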

In [None]:
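# Univariate quantification is switched off here (if 0): the table is used as is.
# Change to `if 1:` to quantify the variables and save the quantification parameters.
if 0:
    dfq,replacement_info = quantify_dataframe_univariate(df,variable_types)
    with open('temp.txt','w') as fid:
        fid.write(replacement_info)
else:
    dfq = df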

### Simplest missing value imputation using SVD computed at complete part of the dataset

In [None]:
# SVD imputation is switched off here as well (if 0): the table is used without imputation
if 0:
    dfq_imputed = SVDcomplete_imputation_method(dfq, variable_types, verbose=True,num_components=-1)
    #dequant_info = invert_quant_info(replacement_info)
    dequant_info = invert_quant_info(load_quantification_info('temp.txt'))
    df_imputed = dequantify_table(dfq_imputed,dequant_info)
    display(df_imputed)
else:
    dfq_imputed = df.copy()
    df_imputed = dfq_imputed

## Now we quantify (and optimize) the ordinal variables via optimal scaling

Now we are ready to quantify the data table, by applying optimal scaling to the ordinal variables.

In [None]:
df = remove_constant_columns_from_dataframe(df_imputed)
variable_names = [str(s) for s in df.columns[0:]]
X = df[df.columns[0:]].to_numpy()
X_original = X
X_before_scaling = X.copy()
X,cik = optimal_scaling(X,variable_types,verbose=True,vmax=0.6)

In [None]:
X.shape

#### OK, we have finished preparing the data matrix X, which is now complete and properly quantified. We also keep the 'original matrix' X_original with the 'raw' values of the variables (it will be needed for visualizations)

# Part 2. Computing the principal tree

## Visualization function

In [None]:
from sklearn.decomposition import PCA
try:
    import umap
except ImportError:
    print('cannot import umap')

def plot_graph(edges, nodes_positions, data = None, dim_reduction = 'PCA', graph_color = 'black', graph_linewidth=2, 
               plot_data = True, data_linewidth = 1,  data_color = 'tab:red', data_transparency_alpha = 0.9,
               umap_n_neighbors = 50, umap_min_dist = 0.99):
  '''
  Plot a graph defined by edges and nodes_positions; optionally scatter-plot "data" on the same axes.
  Optionally reduces the dimension first (PCA or UMAP, depending on dim_reduction).

  Parameters
  ----------
  edges : (N, 2) integer array; edges[k,0] and edges[k,1] are the ends of the k-th edge
  nodes_positions : matrix of node positions
  data : dataset to scatter-plot; must have the same shape[1] as nodes_positions
  plot_data : True/False, whether to scatter-plot data
  dim_reduction : 'PCA', 'umap' or 'plot_first2axis'
  data_color : a vector or a single color, passed as c = data_color to scatter

  Examples
  --------
  edges = np.array([[0,1],[1,2],[2,0]])
  nodes_positions = np.random.rand(3,10) # 3 points in 10d space
  plot_graph(edges, nodes_positions)

  t = elpigraph_output
  edges = t[0]['Edges'][0]
  nodes_positions = t[0]['NodePositions']
  plot_graph(edges, nodes_positions)
  '''
  str_dim_reduction = dim_reduction
  if dim_reduction in ['PCA', 'umap' ]: #  not 'plot_first2axis':
    if dim_reduction.upper() == 'PCA':
      reducer = PCA()
    elif dim_reduction.lower() == 'umap':
      n_neighbors = umap_n_neighbors#  50
      min_dist= umap_min_dist # 0.99
      #n_components=n_components
      reducer = umap.UMAP( n_neighbors=n_neighbors,        min_dist=min_dist, n_components = 2)

    if data is not None:
      data2 = reducer.fit_transform(data)
      if plot_data == True:
        if data_color is None:
          plt.scatter(data2[:,0],data2[:,1], linewidth = data_linewidth , alpha = data_transparency_alpha)# ,cmap=plt.cm.Paired) # ,c=np.array(irx) 
          plt.xlabel(str_dim_reduction+'1')
          plt.ylabel(str_dim_reduction+'2')
        else:
          #plt.scatter(data2[:,0],data2[:,1] ,cmap=plt.cm.Paired,c= data_color, linewidth = data_linewidth, alpha = data_transparency_alpha ) 
          sns.scatterplot( x=data2[:,0], y=data2[:,1], hue = data_color )  # plot the reduced coordinates

          plt.xlabel(str_dim_reduction+'1')
          plt.ylabel(str_dim_reduction+'2')
    else:
      reducer.fit(nodes_positions)

    nodes_positions2 = reducer.transform( nodes_positions )
  else:
    if plot_data == True:
      if data is not None:
        if data_color is None:
          plt.scatter(data[:,0],data[:,1] , linewidth = data_linewidth, alpha = data_transparency_alpha )# ,cmap=plt.cm.Paired) # ,c=np.array(irx) 
        else:
          plt.scatter(data[:,0],data[:,1] ,cmap=plt.cm.Paired,c= data_color , linewidth = data_linewidth, alpha = data_transparency_alpha ) 
          #sns.scatterplot( x=data[:,0], y=data[:,1], hue = data_color )

    nodes_positions2 = nodes_positions

  plt.scatter(nodes_positions2[:,0],nodes_positions2[:,1],c = graph_color, linewidth = graph_linewidth)#, cmap=plt.cm.Paired)

  edgeCount = edges.shape[0]
  for k in range(edgeCount):
    n0 = edges[k,0]
    n1 = edges[k,1]
    x_line = [ nodes_positions2[n0,0],  nodes_positions2[n1,0] ]
    y_line = [ nodes_positions2[n0,1],  nodes_positions2[n1,1] ]
    plt.plot(x_line, y_line, graph_color, linewidth = graph_linewidth) # 'black')

    
edges = np.array([ [0,1],[1,2],[2,0] ] )
nodes_positions = np.random.rand(3,10) # 3 points in 10d space
plot_graph(edges, nodes_positions)
plt.title('Example graph plot with the plot_graph function')
plt.show()

## Loading ClinTrajan libraries

In [None]:
from clintraj_eltree import *
from clintraj_util import *
from clintraj_ml import *
from clintraj_optiscale import *

## First of all, we will reduce the dimension using PCA

In [None]:
reduced_dimension = 30
X = scipy.stats.zscore(X)
pca = PCA(n_components=X.shape[1],svd_solver='full')
Y = pca.fit_transform(X)
v = pca.components_.T
mean_val = np.mean(X,axis=0)
X = Y[:,0:reduced_dimension]
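The choice `reduced_dimension = 30` is fixed by hand. One common way to sanity-check such a choice is to look at the cumulative explained variance ratio of the fitted PCA. The sketch below shows this on synthetic data (the 90 % target is an arbitrary illustration, not a value from the original notebook).

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
# Synthetic stand-in for a z-scored expression matrix: 200 samples, 50 features,
# with most of the variance concentrated in 3 latent directions
latent = rng.normal(size=(200, 3)) * np.array([10.0, 5.0, 3.0])
loadings = rng.normal(size=(3, 50))
X_demo = latent @ loadings + rng.normal(scale=0.5, size=(200, 50))

pca_demo = PCA(svd_solver='full').fit(X_demo)
cumulative = np.cumsum(pca_demo.explained_variance_ratio_)
n_keep = int(np.searchsorted(cumulative, 0.9) + 1)  # smallest k reaching 90 % of the variance
print(f'{n_keep} components explain {cumulative[n_keep - 1]:.1%} of the variance')
```

In the notebook itself, the same check can be run on the `pca` object fitted in the cell above, e.g. by inspecting `np.cumsum(pca.explained_variance_ratio_)[reduced_dimension - 1]`.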

## We are now ready to compute the principal tree

In [None]:
nnodes = 20
tree_elpi = elpigraph.computeElasticPrincipalTree(X,nnodes, # drawPCAView=True,
                                                  alpha=0.01,Mu=0.1,Lambda=0.05,
                                                  FinalEnergy='Penalized')
tree_elpi = tree_elpi[0]
# some additional pruning of the graph
prune_the_tree(tree_elpi)
# extend the leaves to reach the extreme data points
tree_extended = ExtendLeaves_modified(X, tree_elpi, Mode = "QuantDists", ControlPar = .5, DoSA = False)
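For orientation: as we understand it, the principal tree minimizes an elastic energy that combines an approximation term with edge-stretching and star-bending penalties, weighted by the `Lambda` (λ) and `Mu` (μ) parameters passed above. Schematically, up to the normalizations used inside ElPiGraph, with node positions $y_j$, $K_j$ the data points whose closest node is $j$, $E$ the set of edges and $S$ the set of stars (a node $s_0$ with its $m_s$ neighbours $s_1, \dots, s_{m_s}$):

$$
U(X, G) = \frac{1}{|X|}\sum_{j=1}^{k}\sum_{x \in K_j} \lVert x - y_j \rVert^2
+ \lambda \sum_{(i,j) \in E} \lVert y_i - y_j \rVert^2
+ \mu \sum_{s \in S} \Big\lVert y_{s_0} - \frac{1}{m_s}\sum_{l=1}^{m_s} y_{s_l} \Big\rVert^2
$$

With `FinalEnergy='Penalized'`, the `alpha` parameter adds a further penalty on high-degree (branching) nodes, which is not written out here.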

## Now we create two data partitionings: by the nodes of the principal tree and by its linear segments (branches)

In [None]:
# partitioning the data by tree branches
vec_labels_by_branches = partition_data_by_tree_branches(X,tree_extended)
print(len(set(vec_labels_by_branches)),'labels generated')
# partitioning the data by proximity to nodes
nodep = tree_elpi['NodePositions']
partition, dists = elpigraph.src.core.PartitionData(X = X, NodePositions = nodep, 
                                                    SquaredX = np.sum(X**2,axis=1,keepdims=1),
                                                    MaxBlockSize = 100000000, TrimmingRadius = np.inf
                                                    )
partition_by_node = np.zeros(len(partition))
for i,p in enumerate(partition):
    partition_by_node[i] = p[0]
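The node partition above uses ElPiGraph's internal `PartitionData`, whose `SquaredX` argument hints at the usual expansion of squared distances. A small NumPy sketch of the same nearest-node assignment follows; `partition_by_nearest_node` is a hypothetical helper, not ElPiGraph code, and it ignores the trimming radius.

```python
import numpy as np


def partition_by_nearest_node(X: np.ndarray, node_positions: np.ndarray):
    """Assign each data point to its closest node (squared Euclidean distance).

    Uses ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, so that all point-node
    distances come from a single matrix product.
    """
    sq_x = np.sum(X ** 2, axis=1, keepdims=True)          # (n_points, 1)
    sq_nodes = np.sum(node_positions ** 2, axis=1)        # (n_nodes,)
    d2 = sq_x - 2.0 * X @ node_positions.T + sq_nodes     # (n_points, n_nodes)
    d2 = np.maximum(d2, 0.0)                              # clip tiny negative round-off
    labels = np.argmin(d2, axis=1)
    return labels, d2[np.arange(len(X)), labels]


nodes = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
pts = np.array([[1.0, 1.0], [9.0, -1.0], [0.5, 8.0], [6.0, 0.0]])
labels, d2 = partition_by_nearest_node(pts, nodes)
print(labels, d2)
```

Applied as `partition_by_nearest_node(X, tree_elpi['NodePositions'])`, it would be expected to reproduce `partition_by_node` up to ties.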

In [None]:
print(vec_labels_by_branches.shape, np.unique( vec_labels_by_branches) )  # .unique() 
# column4color = 'Pam50 + Claudin-low subtype' #
# vec4color = df_full[column4color]

from sklearn.decomposition import PCA
import seaborn as sns
import matplotlib.pyplot as plt
r = PCA().fit_transform(X)  # 2D view of the (already PCA-reduced) data
c = 0; fig = plt.figure(figsize = (20,10))

c+=1; fig.add_subplot(1, 2 , c) 
vec4color = vec_labels_by_branches
sns.scatterplot( x=r[:,0], y=r[:,1], hue = vec4color )
plt.title('PCA for Omics data colored by Graph groups')



c+=1; fig.add_subplot(1, 2 , c) 
column4color = 'Pam50 + Claudin-low subtype' #
vec4color = df_full[column4color]

sns.scatterplot( x=r[:,0], y=r[:,1], hue = vec4color )
plt.title('PCA for Omics data colored by Pam50 groups')


plt.show()

In [None]:
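# Approximate correspondence between tree-branch labels and PAM50 subtypes, assigned by comparing
# the two PCA plots above; it needs to be re-checked whenever the tree is recomputed
dict_groups_correspondence_approximate = {0:'Graph LumAB', 1:'Graph LumB',2:'Graph Basal',3:'Graph Her2',4:'Graph LumA'}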

In [None]:
df_small = df_full[['Relapse Free Status (Months)', 'Relapse Free Status']].copy()
f = 'Pam50 + Claudin-low subtype'
df_small['Groups1'] = df_full[f].copy()
df_small['Groups2'] = np.asarray(vec_labels_by_branches).astype(int)

from sklearn import preprocessing
le = preprocessing.LabelEncoder()
df_small['Groups1'] = le.fit_transform(df_small['Groups1'])  # PAM50 subtype names -> integer codes

df_small.columns = ['Months', 'Status', 'Groups1','Groups2']

# Administrative censoring at 5 years: follow-up longer than 60 months is truncated to 60
# and any relapse after that point is treated as not observed
late = df_small['Months'] > 60
df_small.loc[late, 'Months'] = 60
df_small.loc[late, 'Status'] = 0

df_small.to_csv('RelapseFreeAndGroups.csv', index=False)
display(df_small['Status'].value_counts())
df_small


In [None]:
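# Kaplan-Meier estimate of relapse-free survival for all patients (censored at 60 months)
T = df_small['Months']
E = df_small['Status']
# alternative endpoint (overall survival):
#T = df['Overall Survival (Months)'][m]
#E = df['Overall Survival Status'][m].map({'Living':1, 'Deceased':0})
kmf = KaplanMeierFitter()
kmf.fit(T,E)
kmf.plot()
plt.show()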
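To make explicit what `KaplanMeierFitter` estimates, here is a minimal NumPy version of the product-limit estimator. It is an illustration only (hypothetical helper `kaplan_meier`), not the lifelines implementation.

```python
import numpy as np


def kaplan_meier(durations, events):
    """Product-limit estimate S(t) = prod over event times t_i <= t of (1 - d_i / n_i).

    d_i = number of events at time t_i; n_i = number of subjects still at risk,
    i.e. with duration >= t_i. Censored subjects (events == 0) only leave the
    risk set. Returns the distinct event times and S just after each of them.
    """
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(durations[events])
    d = np.array([np.sum((durations == t) & events) for t in event_times])
    n_at_risk = np.array([np.sum(durations >= t) for t in event_times])
    return event_times, np.cumprod(1.0 - d / n_at_risk)


times, surv = kaplan_meier([1, 2, 2, 3, 4], [1, 1, 0, 1, 0])
print(times, surv)
```

With the 60-month censoring applied above, every patient with longer follow-up stays in the risk set up to month 60 and contributes no event, which is why the curve stops changing after that point.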

In [None]:

c1, c2 = 'Relapse Free Status (Months)', 'Relapse Free Status'
T0 = df_full[c1] 
E0 = df_full[c2]


c = 0; fig = plt.figure(figsize = (20,10))

c+=1; fig.add_subplot(1, 2 , c) 

for uv in  np.unique( vec_labels_by_branches):
    mask = vec_labels_by_branches == uv
    print(uv, mask.sum())
  
    T = T0[mask]
    E = E0[mask]
    #T = df['Overall Survival (Months)'][m]
    #E =  df['Overall Survival Status'][m].map({'Living':1, 'Deceased':0} )
    lbl = dict_groups_correspondence_approximate[uv]
    kmf = KaplanMeierFitter(label=lbl)
    kmf.fit(T,E)
    kmf.plot()    
plt.xlim([0,200])
plt.title('Relapse Free' + ' Graph based groups') # str(c2.split(' ')[:2] ) ) 

c+=1; fig.add_subplot(1, 2 , c) 
f = 'Pam50 + Claudin-low subtype'
vec4types = df_full[f]
for uv in  np.unique( vec4types):
    mask = vec4types == uv
    print(uv, mask.sum())
    if mask.sum() < 30 : continue
  
    T = T0[mask]
    E = E0[mask]
    #T = df['Overall Survival (Months)'][m]
    #E =  df['Overall Survival Status'][m].map({'Living':1, 'Deceased':0} )
    lbl = uv # dict_groups[uv]
    kmf = KaplanMeierFitter(label=lbl)
    kmf.fit(T,E)
    kmf.plot()    
plt.xlim([0,200])
plt.title('Relapse Free ' + ' Original PAM50 groups') # str(c2.split(' ')[:2] ) ) 




plt.show()

In [None]:

c1, c2 = 'Relapse Free Status (Months)', 'Relapse Free Status'
T0 = df_full[c1] 
E0 = df_full[c2]


c = 0; fig = plt.figure(figsize = (20,10))

c+=1; fig.add_subplot(1, 2 , c) 

for uv in  np.unique( vec_labels_by_branches):
    mask = vec_labels_by_branches == uv
    print(uv, mask.sum())
  
    T = T0[mask]
    E = E0[mask]
    #T = df['Overall Survival (Months)'][m]
    #E =  df['Overall Survival Status'][m].map({'Living':1, 'Deceased':0} )
    lbl = dict_groups_correspondence_approximate[uv]
    kmf = KaplanMeierFitter(label=lbl)
    kmf.fit(T,E)
    kmf.plot()    
plt.xlim([0,12*5])
plt.title('Relapse Free' + ' Graph based groups') # str(c2.split(' ')[:2] ) ) 

c+=1; fig.add_subplot(1, 2 , c) 
f = 'Pam50 + Claudin-low subtype'
vec4types = df_full[f]
for uv in  np.unique( vec4types):
    mask = vec4types == uv
    print(uv, mask.sum())
    if mask.sum() < 30 : continue
  
    T = T0[mask]
    E = E0[mask]
    #T = df['Overall Survival (Months)'][m]
    #E =  df['Overall Survival Status'][m].map({'Living':1, 'Deceased':0} )
    lbl = uv # dict_groups[uv]
    kmf = KaplanMeierFitter(label=lbl)
    kmf.fit(T,E)
    kmf.plot()    
plt.xlim([0,12*5])
plt.title('Relapse Free ' + ' Original PAM50 groups') # str(c2.split(' ')[:2] ) ) 




plt.show()

## Let us visualize the tree, with data points, colored by the tree segments

In [None]:
column4color = 'Pam50 + Claudin-low subtype' #
vec4color = df_full[column4color].values.copy()
from sklearn import preprocessing
le = preprocessing.LabelEncoder()
vec4color = le.fit_transform(vec4color) # [1, 1, 2, 6])

fig = plt.figure(figsize=(8, 8))
visualize_eltree_with_data(tree_extended,X,X_original,v,mean_val,'k',variable_names,
                          Color_by_partitioning = True, visualize_partition =vec4color)
plt.legend()
plt.show()

In [None]:
column4color = 'Pam50 + Claudin-low subtype' #
vec4color = df_full[column4color]

fig = plt.figure(figsize=(8, 8))
visualize_eltree_with_data(tree_extended,X,X_original,v,mean_val,'k',variable_names,
                          Color_by_partitioning = True, visualize_partition =vec4color)
plt.legend()
plt.show()

In [None]:
column4color = 'Pam50 + Claudin-low subtype' #
vec4color = df_full[column4color]

fig = plt.figure(figsize=(8, 8))
visualize_eltree_with_data(tree_extended,X,X_original,v,mean_val,'k',variable_names,
                          Color_by_partitioning = True, visualize_partition =vec4color)
plt.colorbar()
plt.show()

In [None]:
fig = plt.figure(figsize=(8, 8))
visualize_eltree_with_data(tree_extended,X,X_original,v,mean_val,'k',variable_names,
                          Color_by_partitioning = True, visualize_partition = vec_labels_by_branches)
plt.show()

In [None]:
column4color =  'Nottingham prognostic index'# 'Pam50 + Claudin-low subtype' #
vec4color = df_full[column4color]
fig = plt.figure(figsize=(8, 8))
visualize_eltree_with_data(tree_extended,X,X_original,v,mean_val,'k',variable_names,
                          Color_by_partitioning = True, visualize_partition = vec4color)
plt.title('colored by '+column4color)
plt.show()


In [None]:
column4color = 'Pam50 + Claudin-low subtype' #
vec4color = df_full[column4color]

from sklearn.decomposition import PCA
import seaborn as sns
import matplotlib.pyplot as plt
r = PCA().fit_transform(X)
plt.figure(figsize = (20,10))
sns.scatterplot( x=r[:,0], y=r[:,1], hue = vec4color )
plt.title('PCA for Omics data colored by Pam50 groups')
plt.show()

In [None]:

column4color =  'Nottingham prognostic index'# 'Pam50 + Claudin-low subtype' #
vec4color = df_full[column4color]

plt.figure(figsize = (20,10))
sns.scatterplot( x=r[:,0], y=r[:,1], hue = vec4color )
plt.title('PCA for Omics data colored by '+column4color)
plt.show()

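## Now let us show, on top of the tree, lethal cases, and show the lethality trend by the edge width

Note: this cell and the following ones are carried over from the original ClinTrajan tutorial (myocardial infarction dataset). Variables such as `LET_IS_0`, `AGE`, `zab_leg_03`, the complication variables and `ritm_ecg_p_*` are not among the METABRIC gene-expression columns in `variable_names`, so these cells need to be adapted to METABRIC variables before they can run.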

In [None]:
fig = plt.figure(figsize=(8, 8))
non_lethal_feature = 'LET_IS_0'
visualize_eltree_with_data(tree_extended,X,X_original,v,mean_val,'k',variable_names,
                          Color_by_feature=non_lethal_feature, Feature_Edge_Width=non_lethal_feature,
                           Invert_Edge_Value=True,Min_Edge_Width=10,Max_Edge_Width=50,
                           Visualize_Edge_Width_AsNodeCoordinates=True,cmap='winter')
plt.show()

## Let us highlight patients with AGE<=65 and having bronchial asthma in their anamnesis

In [None]:
fig = plt.figure(figsize=(8, 8))
inds = np.where((X_original[:,variable_names.index('AGE')]<=65)&(X_original[:,variable_names.index('zab_leg_03')]==1))[0]
colors = ['k' for i in range(len(X))]
for i in inds:
    colors[i] = 'r'
colors = list(colors)
visualize_eltree_with_data(tree_extended,X,X_original,v,mean_val,colors,variable_names,
                          highlight_subset=inds,Big_Point_Size=100,cmap='hot')

plt.show()

## Now let us quantify the pseudotime value, for each trajectory

#### 1. We need to specify the root node. In order to do this, we will highlight all cases without any myocardial infarction complications, and will select the node where the complications are rare. In order to make the selection visual, we will show the node numbers as well

In [None]:
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(1,1,1)
complication_vars = ['FIBR_PREDS','PREDS_TAH','JELUD_TAH','FIBR_JELUD',
                     'A_V_BLOK','OTEK_LANC','RAZRIV','DRESSLER',
                     'ZSN','REC_IM','P_IM_STEN']
inds_compl = [variable_names.index(a) for a in complication_vars]
lethal = 1-X_original[:,variable_names.index('LET_IS_0')]
has_complication = np.sum(X_original[:,inds_compl],axis=1)>0
inds = np.where((has_complication==0)&(lethal==0))[0]
colors = ['r' for i in range(len(X))]
for i in inds:
    colors[i] = 'k'
visualize_eltree_with_data(tree_extended,X,X_original,v,mean_val,colors,variable_names,
                          highlight_subset=inds,Big_Point_Size=2,Normal_Point_Size=2,showNodeNumbers=True)
add_pie_charts(ax,tree_extended['NodePositions2D'],colors,['r','k'],partition,scale=30)
plt.show()
root_node = 8  # chosen visually from the pie charts above
print('Root node =', root_node)
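The root node above is chosen by eye from the pie charts. A hypothetical automated variant of the same criterion picks the node whose assigned points have the lowest fraction of complicated cases, ignoring nodes with too few points. The helper `pick_root_node` and the `min_points` cutoff are our assumptions, not part of ClinTrajan.

```python
import numpy as np


def pick_root_node(node_labels, is_bad, min_points=10):
    """Return (node, fraction) for the node with the lowest fraction of `is_bad` cases.

    Nodes with fewer than `min_points` assigned points are skipped, to avoid
    choosing a root on the basis of a handful of patients.
    """
    node_labels = np.asarray(node_labels).astype(int)
    is_bad = np.asarray(is_bad, dtype=bool)
    best_node, best_frac = None, np.inf
    for node in np.unique(node_labels):
        mask = node_labels == node
        if mask.sum() < min_points:
            continue
        frac = is_bad[mask].mean()
        if frac < best_frac:
            best_node, best_frac = int(node), float(frac)
    return best_node, best_frac


labels = np.array([0] * 20 + [1] * 20 + [2] * 5)
bad = np.array([1] * 10 + [0] * 10 + [1] * 2 + [0] * 18 + [0] * 5)
print(pick_root_node(labels, bad))
```

In the cell above, `partition_by_node` and `has_complication | (lethal > 0)` would be natural inputs; node 2 in the toy example has no bad cases but is skipped because it holds only 5 points.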

#### 2. Now we are ready to quantify pseudo-time

In [None]:
all_trajectories,all_trajectories_edges = extract_trajectories(tree_extended,root_node)
print(len(all_trajectories),' trajectories found.')
ProjStruct = project_on_tree(X,tree_extended)
PseudoTimeTraj = quantify_pseudotime(all_trajectories,all_trajectories_edges,ProjStruct)
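As a rough illustration of what pseudotime means along one trajectory (not the implementation of `quantify_pseudotime`, which works from the `project_on_tree` projections): treat the trajectory as a polyline of node positions from the root to a leaf, project each point onto the nearest segment, and measure the path length from the root to that projection. The helper `pseudotime_along_path` is hypothetical.

```python
import numpy as np


def pseudotime_along_path(points, path_nodes):
    """Arc-length position of each point's projection onto a polyline.

    `path_nodes` (n_nodes, d) lists node positions from the root to the leaf.
    Each point is projected onto its nearest segment; its pseudotime is the
    length of the path from the root to that projection.
    """
    points = np.atleast_2d(np.asarray(points, dtype=float))
    path_nodes = np.asarray(path_nodes, dtype=float)
    starts, ends = path_nodes[:-1], path_nodes[1:]
    seg = ends - starts                                          # (n_seg, d)
    seg_len = np.linalg.norm(seg, axis=1)
    offsets = np.concatenate([[0.0], np.cumsum(seg_len)[:-1]])   # arc length at segment starts

    times = np.empty(len(points))
    for i, p in enumerate(points):
        # Position of the orthogonal projection along each segment, clipped to [0, 1]
        t = np.einsum('sd,sd->s', p - starts, seg) / np.maximum(seg_len ** 2, 1e-12)
        t = np.clip(t, 0.0, 1.0)
        proj = starts + t[:, None] * seg
        nearest = np.argmin(np.sum((proj - p) ** 2, axis=1))
        times[i] = offsets[nearest] + t[nearest] * seg_len[nearest]
    return times


path = [[0.0, 0.0], [1.0, 0.0], [1.0, 2.0]]   # an L-shaped path of length 3
print(pseudotime_along_path([[0.5, 0.1], [1.2, 1.5], [-1.0, 0.0]], path))
```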

#### 3. Let us find all associations by regression of a clinical variable with pseudotime along all trajectories

In [None]:
vars = ['ritm_ecg_p_01','ritm_ecg_p_02','ritm_ecg_p_04']
for var in vars:
    List_of_Associations = regression_of_variable_with_trajectories(PseudoTimeTraj,var,variable_names,
                                                                    variable_types,X_original,R2_Threshold=0.5,
                                                                    producePlot=True,
                                                                    Continuous_Regression_Type='gpr',
                                                                    verbose=True)
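The `Continuous_Regression_Type='gpr'` option suggests that, for continuous variables, a Gaussian process regression of the variable on pseudotime is fitted and its R² is compared with `R2_Threshold`. A sketch of that idea with scikit-learn follows; the kernel (RBF plus white noise) and the helper `gpr_trend_r2` are our assumptions, not ClinTrajan's settings, and binary variables such as the `ritm_ecg_p_*` dummies may be handled differently by the library.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel


def gpr_trend_r2(pseudotime, values):
    """Fit a smooth GP trend of `values` vs `pseudotime`; return (in-sample R^2, model)."""
    t = np.asarray(pseudotime, dtype=float).reshape(-1, 1)
    y = np.asarray(values, dtype=float)
    # Smooth trend (RBF) plus observation noise (WhiteKernel); hyperparameters are fitted
    kernel = RBF(length_scale=np.ptp(t) / 4 + 1e-9) + WhiteKernel(noise_level=0.1)
    gpr = GaussianProcessRegressor(kernel=kernel, normalize_y=True, random_state=0)
    gpr.fit(t, y)
    return gpr.score(t, y), gpr


t = np.linspace(0, 10, 60)
noisy = np.sin(t / 2.0) + np.random.default_rng(0).normal(scale=0.05, size=t.size)
r2, model = gpr_trend_r2(t, noisy)
print(f'R^2 = {r2:.3f}, passes 0.5 threshold: {r2 > 0.5}')
```

Since R² is computed on the same points used for fitting, it is optimistic; a flexible GP can reach a high in-sample R² even on weak trends.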

#### 4. We can plot several variable dependencies against pseudotime

In [None]:
pstt = PseudoTimeTraj[1]
colors = ['r','b','g']
for i,var in enumerate(vars):
    vals = draw_pseudotime_dependence(pstt,var,variable_names,variable_types,X_original,colors[i],
                                               linewidth=3,draw_datapoints=False)
plt.legend()
plt.show()


#### 5. Now let us show how we can plot anything as a function of pseudotime, for example, the hazard of death (a smoothed Nelson-Aalen estimate) computed with standard survival analysis tools

In [None]:
import lifelines
from lifelines import SplineFitter
from lifelines import NelsonAalenFitter
from lifelines import KaplanMeierFitter
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'tab:orange','tab:pink','tab:green']

# event_data[:,0]: pseudotime of each point; event_data[:,1]: 1 if the event (death) occurred
event_data = np.zeros((len(df),2))
events = 1-np.array(df['LET_IS_0'])
label = 'Death'

for pstt in PseudoTimeTraj:
    points = pstt['Points']
    times = pstt['Pseudotime']
    for j,p in enumerate(points):
        event_data[p,0] = times[j]
        event_data[p,1] = events[p]

plt.figure(figsize=(8,8))

for i,pstt in enumerate(PseudoTimeTraj):
    TrajName = 'Trajectory:'+str(pstt['Trajectory'][0])+'--'+str(pstt['Trajectory'][-1])
    points = pstt['Points']
    T = event_data[points,0]   # pseudotime plays the role of "time"
    E = event_data[points,1]   # event indicator
    naf = NelsonAalenFitter()
    naf.fit(T, event_observed=E, label=TrajName)
    # smoothed hazard rate (derivative of the Nelson-Aalen cumulative hazard)
    naf.plot_hazard(bandwidth=3.0,fontsize=20,linewidth=10,color=colors[i % len(colors)])
plt.show()
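For reference, the quantity that `NelsonAalenFitter` estimates is the cumulative hazard H(t) = Σ d_i / n_i over event times t_i ≤ t; `plot_hazard` then shows a kernel-smoothed derivative of it. The NumPy sketch below (hypothetical helper `nelson_aalen`) uses the plain estimator; lifelines by default applies a tie-smoothing variant (`nelson_aalen_smoothing=True`), so its values can differ slightly when several events share the same time.

```python
import numpy as np


def nelson_aalen(durations, events):
    """Nelson-Aalen cumulative hazard H(t) = sum over event times t_i <= t of d_i / n_i.

    d_i = number of events at time t_i; n_i = number of subjects still at risk
    (duration >= t_i). Returns the distinct event times and H at those times.
    """
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(durations[events])
    d = np.array([np.sum((durations == t) & events) for t in event_times])
    n_at_risk = np.array([np.sum(durations >= t) for t in event_times])
    return event_times, np.cumsum(d / n_at_risk)


# Here "durations" would be pseudotime values along one trajectory
times, H = nelson_aalen([1, 2, 2, 3, 4], [1, 1, 0, 1, 0])
print(times, H)
```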

