In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import itertools
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import math
from matplotlib.colors import TwoSlopeNorm
import scipy
from matplotlib.cm import ScalarMappable
import matplotlib as mpl
from matplotlib.patches import Rectangle
import networkx as nx
from matplotlib_venn import venn3
from scipy.interpolate import make_interp_spline
from sklearn.base import clone
from sklearn.utils import check_random_state
from sklearn.metrics import adjusted_rand_score, silhouette_score, normalized_mutual_info_score, adjusted_mutual_info_score
import warnings
from fastcluster import linkage
from scipy.cluster.hierarchy import dendrogram, set_link_color_palette
from matplotlib.colors import rgb2hex, colorConverter
from sklearn.cluster import AgglomerativeClustering
warnings.filterwarnings('ignore')


In [None]:
def make_corr_graph(hspc_type, threshold, absolute = False, just_positive = False,
                    data_dir = "pseudotime_lin_temp"):
    """Build an undirected pathway graph from one window's z-scored correlation table.

    Reads ``<data_dir>/<hspc_type>_correlations_pivot_ZSCORE.csv`` (first column
    is the row index), stacks it into (pathway, pathway, value) rows and keeps
    each pair whose absolute value exceeds ``threshold`` as an edge carrying a
    ``'Value'`` attribute.

    Parameters
    ----------
    hspc_type : str
        Window identifier used to build the file name.
    threshold : float
        Edges need ``|Value| > threshold``.
    absolute : bool
        If True, store ``|Value|`` on the edges instead of the signed value.
    just_positive : bool
        If True, keep only pairs with ``Value > 0``.
    data_dir : str
        Directory holding the correlation tables (defaults to the previously
        hard-coded location).

    Returns
    -------
    networkx.Graph
        Undirected graph. A pair present as both (A, B) and (B, A) in the
        table becomes one edge; the later row's value is the one kept.
    """
    correlations_pivot = pd.read_csv(data_dir + "/" + hspc_type + "_correlations_pivot_ZSCORE.csv",
                                     index_col = 0)

    links = correlations_pivot.stack().reset_index()
    links.columns = ['Pathway 1', 'Pathway 2', 'Value']

    keep = links['Value'].abs() > threshold
    if just_positive:
        keep &= links['Value'] > 0
    # Take an explicit copy so the optional in-place column update below acts
    # on an independent frame, not on a view of the stacked table.
    links = links.loc[keep].copy()

    if absolute:
        links['Value'] = links['Value'].abs()

    return nx.from_pandas_edgelist(links, 'Pathway 1', 'Pathway 2', ['Value'])

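# Degree centrality(G, N) = Degree(N) / (#nodes(G) - 1)
# Degree(G, N) = number of edges incident to node N
# Note: get_degree_centrality below actually returns weighted closeness
# centrality (nx.closeness_centrality with distance = 1 / edge value).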

def get_degree_centrality(hspc_type, threshold, absolute=True):
    """Per-pathway centrality of one window's correlation graph.

    Despite the name, this returns weighted *closeness* centrality. The graph
    comes from ``make_corr_graph(hspc_type, threshold, absolute)``. Every edge
    gets ``distance = 1 / Value``, so stronger correlations are shorter paths,
    and ``nx.closeness_centrality`` is evaluated on those distances.

    Parameters
    ----------
    hspc_type : str
        Window identifier forwarded to make_corr_graph.
    threshold : float
        Edge cutoff forwarded to make_corr_graph.
    absolute : bool
        Forwarded to make_corr_graph. With False, negative correlations become
        negative distances, which are not valid shortest-path lengths.

    Returns
    -------
    pandas.DataFrame
        One column, ``'degree proportion'``, indexed by ``'Pathway name'``,
        with one row per graph node. If the graph is empty, a single
        placeholder row labelled 0 with value 0 is returned instead. This keeps
        downstream concatenation working; get_centrality_clustered_df drops
        that label.
    """
    graph = make_corr_graph(hspc_type, threshold, absolute)

    # Stronger correlation -> shorter distance.
    for _, _, edge_data in graph.edges(data=True):
        edge_data['distance'] = 1 / edge_data['Value']

    closeness = nx.closeness_centrality(graph, distance = 'distance')

    if closeness:
        return (pd.DataFrame(list(closeness.items()),
                             columns=['Pathway name', 'degree proportion'])
                .set_index('Pathway name'))

    # Empty graph: placeholder row so callers can still concatenate windows.
    return pd.DataFrame({'degree proportion': [0]},
                        index=pd.Index([0], name='Pathway name'))

def filter_out_rows(df, threshold):
    """Keep the rows of ``df`` whose maximum value is at least ``threshold``.

    Row maxima skip NaN, so an all-NaN row has a NaN maximum and is dropped.
    Selection is positional: with duplicate index labels, each row is kept or
    dropped on its own. Label-based ``.loc[labels]`` would instead return every
    row sharing a passing label, including rows that failed the threshold.
    """
    passes = (df.max(axis = 1) >= threshold).to_numpy()
    return df.loc[passes]

def filter_rows_rank(df, threshold, n = 40):
    """Return the ``n`` rows of ``df`` with the largest row maximum, highest first.

    ``threshold`` is not used. It is kept only so the signature matches
    filter_out_rows. Rows are picked by position, so duplicate index labels
    are not re-expanded the way label-based ``.loc[labels]`` would do.
    """
    # reset_index gives positional labels; the sort permutation depends only on
    # the values, so the ranking matches sorting the labelled Series.
    top_positions = (df.max(axis = 1)
                     .reset_index(drop = True)
                     .sort_values(ascending = False)
                     .index[:n])
    return df.iloc[top_positions]

# Module-level generator used by cluster_stability when no random_state is given.
rng = np.random.RandomState(1)

def cluster_stability(X, est, n_iter=200, random_state=None):
    """Bootstrap estimate of clustering stability (mean pairwise adjusted Rand index).

    For each of ``n_iter`` iterations, rows of ``X`` are drawn with
    replacement, a fresh ``clone`` of ``est`` is fitted on them, and the
    resulting ``labels_`` are mapped back onto the original row positions.
    Rows that were not drawn are labelled -1. The score is the mean ARI over
    every ordered pair of runs, computed only on rows drawn in both runs. The
    ``n_iter`` self-comparisons (i == j) are still included, as before.

    Parameters
    ----------
    X : pandas.DataFrame
        Data to cluster; rows are resampled with ``X.iloc``.
    est : estimator
        Unfitted clusterer exposing ``labels_`` after ``fit``.
    n_iter : int
        Number of bootstrap runs. Cost grows with ``n_iter ** 2`` ARI evaluations.
    random_state : None, int or RandomState
        If None, the module-level ``rng`` is used and advanced, which was the
        previous behaviour. Otherwise it is resolved with ``check_random_state``.
        Previously this argument was silently ignored.

    Returns
    -------
    float
    """
    generator = rng if random_state is None else check_random_state(random_state)
    n_samples = X.shape[0]

    labels = []
    indices = []
    for _ in range(n_iter):
        # Draw a bootstrap sample and remember which rows it used.
        sample_indices = generator.randint(0, n_samples, n_samples)
        indices.append(sample_indices)
        fitted = clone(est)
        fitted.fit(X.iloc[sample_indices])
        # Store the clustering outcome at the original row positions.
        relabel = -np.ones(n_samples)
        relabel[sample_indices] = fitted.labels_
        labels.append(relabel)

    scores = []
    for l, i in zip(labels, indices):
        for k, j in zip(labels, indices):
            in_both = np.intersect1d(i, j)
            scores.append(adjusted_rand_score(l[in_both], k[in_both]))
    return np.mean(scores)

def extract_mathdefault_num(s):
    """Strip a ``$\\mathdefault{...}$`` wrapper from a tick-label string.

    Returns the text between the first ``$\\mathdefault{`` and the last ``}$``,
    so nested braces such as ``10^{2}`` survive. A string without the opening
    wrapper is returned unchanged; previously ``find`` returned -1 and produced
    a garbage slice. If the closing ``}$`` is missing, everything after the
    opening wrapper is returned.
    """
    start = '$\\mathdefault{'
    end = '}$'
    begin = s.find(start)
    if begin == -1:
        return s
    begin += len(start)
    stop = s.rfind(end)
    if stop < begin:
        return s[begin:]
    return s[begin:stop]

In [None]:
def get_centrality_clustered_df(lineage, corr_threshold = 0.1, 
                              degree_centrality_threshold = 0.1
                             ):
    """Build a (pathway x window) centrality matrix with rows ordered by hierarchical clustering.

    Parameters
    ----------
    lineage : iterable of str
        Window identifiers; each is passed to get_degree_centrality.
    corr_threshold : float
        Edge cutoff for every window's correlation graph.
    degree_centrality_threshold : float
        Rows whose maximum centrality across windows is below this are dropped.

    Returns
    -------
    pandas.DataFrame
        Columns are ``'0'``..``'n-1'`` (one per window, in input order). A
        pathway missing from a window gets 0. The empty-graph placeholder row
        labelled 0 is removed. With two or more rows, rows are in the order of
        the seaborn clustermap row dendrogram (``metric='correlation'``). That
        clustermap figure is created and left open, so a later ``plt.show()``
        also renders it. With fewer than two rows, clustering is impossible and
        the rows are returned unclustered.
    """
    per_window = [get_degree_centrality(window, corr_threshold) for window in lineage]

    centrality = pd.concat(per_window, axis = 1)
    centrality.columns = [str(i) for i in range(len(centrality.columns))]
    centrality = centrality.fillna(0)
    centrality = filter_out_rows(centrality, degree_centrality_threshold)

    # Drop the placeholder row emitted for windows whose graph had no nodes.
    if 0 in centrality.index:
        centrality = centrality.drop(0)

    if len(centrality) < 2:
        return centrality

    cg = sns.clustermap(centrality, yticklabels = centrality.index,
                        xticklabels = [], metric = 'correlation',
                        linewidth = 0.1, linecolor = 'k',
                        square = True, cmap = 'gist_heat_r',
                        col_cluster = False, row_cluster = True)

    # Reorder by dendrogram position rather than by tick-label text, which
    # fails for non-string or duplicate pathway labels.
    return centrality.iloc[cg.dendrogram_row.reordered_ind]

In [None]:
def plot_centrality_and_weights(centrality_dict, lineage, corr_threshold = 0.1, 
                              degree_centrality_threshold = 0.1,
                               lineage_name = 'test_corr', sww = None):
    """Plot a lineage's pathway cross-talk heatmap under its commitment-score and mean cross-talk traces.

    Parameters
    ----------
    centrality_dict : dict
        Keyed by lineage number (``'1'``, ``'2'``, ...). Each entry must provide
        ``'peak2_windows'`` and ``'peak3_windows'``, whose min/max are drawn as
        dotted vertical lines on the top panel.
    lineage : list of str
        Per-window identifiers, passed to get_centrality_clustered_df. When
        ``sww`` is None, the last two ``'_'``-separated tokens of ``lineage[0]``
        are read as the step size and the window size.
    corr_threshold, degree_centrality_threshold : float
        Forwarded to get_centrality_clustered_df.
    lineage_name : str
        e.g. ``'lin1'``. The text between its first and second ``'n'`` (here
        ``'1'``) is the key into ``centrality_dict``.
    sww : sequence of float, optional
        Precomputed sliding-window weights, one per heatmap column. If None,
        they are computed with sliding_window_weights.

    Returns
    -------
    pandas.DataFrame
        The clustered centrality matrix that was drawn.

    Raises
    ------
    ValueError
        If the number of weights differs from the number of centrality windows.
    """
    lineage_number = lineage_name.split('n')[1]
    peak2_windows = centrality_dict[lineage_number]['peak2_windows']
    peak3_windows = centrality_dict[lineage_number]['peak3_windows']

    # PREPARE PLOTS: a short trace panel (ax1) sits above the heatmap (ax2).
    # In the narrow right column, cbar_ax is hidden and dummy_ax hosts the colorbar.
    fig, ((ax1, cbar_ax), (ax2, dummy_ax)) = plt.subplots(nrows=2, ncols=2, figsize=(4, 7),
                                                          sharex='col',
                                                          gridspec_kw={'height_ratios': [6, 30],
                                                                       'width_ratios': [20, 1]})

    # CALCULATE THE CENTRALITY OF EACH WINDOW
    centrality_df = get_centrality_clustered_df(lineage, corr_threshold = corr_threshold,
                                                degree_centrality_threshold = degree_centrality_threshold)
    n_windows = len(centrality_df.columns)

    # PLOT THE CENTRALITY
    heatmap = sns.heatmap(centrality_df, cmap="gist_heat_r", cbar_ax=dummy_ax,
                          xticklabels=False, linewidths=0,
                          ax=ax2, yticklabels = False, linecolor = 'k',
                          cbar_kws={'label': 'Pathway Cross-Talk'})
    heatmap.add_patch(Rectangle((0,0), n_windows, len(centrality_df.index), fill = False,
                                edgecolor = 'black', lw = 1))
    ax2.set_xlabel('Pseudotime ' + "\u2192")
    ax2.set_ylabel('Pathway')

    # GET THE AVERAGE WEIGHT PER WINDOW
    if sww is None:
        # Identifiers end in "_<start>_<step>_<window>" (e.g. "..._0_20_500"),
        # so read step and window from the end. Fixed positions [3]/[4] would
        # land on "downsampled" and the cell count for downsampled identifiers.
        *_, step_size, window_size = lineage[0].split('_')
        sliding_window_weight = sliding_window_weights(lineage_name,
                                                       window_size = int(window_size),
                                                       step_size = int(step_size))
    else:
        sliding_window_weight = sww
    if len(sliding_window_weight) != n_windows:
        raise ValueError(
            f"{lineage_name}: got {len(sliding_window_weight)} sliding-window weights "
            f"but the centrality matrix has {n_windows} windows")

    # PLOT THE WINDOW WEIGHT AND THE MEAN WINDOW CROSS-TALK, with each point
    # centred on its heatmap column (column i spans [i, i + 1]).
    x_centres = [i + 0.5 for i in range(n_windows)]
    ax1.plot(x_centres, sliding_window_weight, c='#3b3bc4')
    axTwin = ax1.twinx()
    axTwin.plot(x_centres, centrality_df.mean(), c = '#c93853')

    # TOUCH UP THE FIGURE
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)
    axTwin.spines['top'].set_visible(False)

    # Dotted bounds of the peak windows, in heatmap-column units.
    for windows, colour in ((peak2_windows, 'black'), (peak3_windows, 'purple')):
        ax1.axvline(x = min(windows), linestyle = 'dotted', c = colour, linewidth = 2)
        ax1.axvline(x = max(windows), linestyle = 'dotted', c = colour, linewidth = 2)

    ax1.set_ylabel('Mean Lineage\nCommitment Score', c = '#3b3bc4')
    axTwin.set_ylabel('Mean Pathway\nCross-Talk', c = '#c93853')
    cbar_ax.axis('off')

    # Outline the heatmap colorbar.
    cbar = heatmap.collections[0].colorbar
    for spine in cbar.ax.spines.values():
        spine.set_visible(True)
        spine.set(lw = 1, edgecolor = 'k')

    fig.set_size_inches(5,6)
    plt.show()

    return centrality_df

In [None]:
def sliding_window_celltypes(lineage_name, window_size = 100, step_size = 10):
    """Sum the single-cell-type columns over sliding windows of pseudotime-ordered metacells.

    Reads ``../hspc_pseudotime_metacell_<lineage_name>_metadata_weights_v2_ALL_CELLS.csv``
    (first column is the index) and sorts the rows by ``median_pseudotime``.
    A window covers rows ``[start, start + window_size)``, with ``start``
    advancing by ``step_size``. Only complete windows are kept.

    Returns
    -------
    pandas.DataFrame
        One row per ``singlecelltype_*`` column, in the order listed below, and
        one column per window labelled by its starting row position. Values are
        raw column sums within the window, not fractions. If no complete window
        fits, the result is empty.

    Raises
    ------
    ValueError
        If ``step_size`` is not positive; previously this looped forever.
    """
    if step_size <= 0:
        raise ValueError("step_size must be a positive integer")

    celltype_columns = ['singlecelltype_HSC', 'singlecelltype_MPP.Flk2n', 'singlecelltype_MPP.Flk2p',
                        'singlecelltype_CMP',
                        'singlecelltype_CLP', 'singlecelltype_GMP', 'singlecelltype_MEP',
                        'singlecelltype_MkP']

    weightDF = pd.read_csv("../hspc_pseudotime_metacell_" + lineage_name + "_metadata_weights_v2_ALL_CELLS.csv", index_col = 0)
    celltypes = weightDF.sort_values(by = 'median_pseudotime')[celltype_columns]

    # Collect every window first and build the frame once. Inserting columns
    # one at a time fragments the DataFrame.
    window_sums = {start: celltypes.iloc[start:start + window_size].sum(axis = 0)
                   for start in range(0, len(celltypes) - window_size + 1, step_size)}
    return pd.DataFrame(window_sums)

In [None]:
def sliding_window_weights(lineage_name, window_size = 100, step_size = 10, subsampling = 0):
    """Mean ``mean_pseudotime_weight`` over sliding windows of pseudotime-ordered metacells.

    Reads ``../hspc_pseudotime_metacell_<lineage_name>_metadata_weights_v2_ALL_CELLS.csv``
    with no index column. When ``subsampling > 0``, rows are first restricted
    with ``.loc`` to the ``ids`` column of
    ``../hspc_pseudotime_metacell_<lineage_name>_subsample_ID_<subsampling>.csv``.
    Because the weights file has a default 0..n-1 index, the ids are matched
    against those row numbers. NOTE(review): if the ids are cell names, this
    raises KeyError — confirm against the file that produces them.

    Rows are then sorted by ``median_pseudotime``. For every complete window
    ``[start, start + window_size)``, with ``start`` stepping by ``step_size``,
    the sum of the weights divided by ``window_size`` is recorded.

    Returns
    -------
    list of float
        One value per complete window; empty if none fits.
    """
    metadata = pd.read_csv("../hspc_pseudotime_metacell_" + str(lineage_name) + "_metadata_weights_v2_ALL_CELLS.csv")

    if subsampling > 0:
        ids_file = "../hspc_pseudotime_metacell_" + str(lineage_name) + "_subsample_ID_" + str(subsampling) + ".csv"
        keep_ids = pd.read_csv(ids_file, index_col = 0)['ids']
        metadata = metadata.loc[keep_ids,]

    ordered = metadata.sort_values(by = 'median_pseudotime')
    weights = list(ordered['mean_pseudotime_weight'])

    averages = []
    lo = 0
    while lo + window_size <= len(weights):
        averages.append(sum(weights[lo:lo + window_size]) / window_size)
        lo += step_size
    return averages

In [None]:
# CORR_THRESHOLD: a pathway pair becomes a graph edge only if the absolute
#   z-score of its correlation exceeds this (1.96 is the two-sided 5% critical
#   value of a standard normal).
# CENTRALITY_THRESHOLD: pathways whose peak per-window centrality is below this
#   are dropped from the heatmaps.
CORR_THRESHOLD, CENTRALITY_THRESHOLD = 1.96, 0.1

In [None]:
# Sliding-window parameters. They are encoded in every per-window identifier
# ("..._<start>_<step>_<window>") and must match the values passed to
# sliding_window_weights, so they are defined in one place.
SLIDING_STEP_SIZE = 20
SLIDING_WINDOW_SIZE = 500

# Cells retained per lineage after downsampling. This number appears in the
# window identifiers and in the subsample-ID / pathway-activity file names.
subsample_numbers = {'1':608,
                     '2':624,
                     '3':362,
                     '4':197
                     }


def _downsampled_window_ids(lineage_id, stop):
    """Per-window identifiers of a downsampled lineage, one per start in range(0, stop, SLIDING_STEP_SIZE)."""
    prefix = 'lin' + lineage_id + '_v2_allcells_downsampled_' + str(subsample_numbers[lineage_id]) + '_'
    suffix = '_' + str(SLIDING_STEP_SIZE) + '_' + str(SLIDING_WINDOW_SIZE)
    return [prefix + str(start) + suffix for start in range(0, stop, SLIDING_STEP_SIZE)]


def _activity_path(lineage_id, suffix):
    """Path of a lineage's subsampled pathway-activity table; suffix is '.csv' (mean) or '_MEDIAN.csv'."""
    return ("pseudotime_lineages/sliding_window/lin" + lineage_id
            + '_v2_allcells/pathway_activity_subsampled_'
            + str(subsample_numbers[lineage_id]) + suffix)


def _max_normalize_rows(activity):
    """Divide each row by its maximum and drop rows that end up entirely NaN."""
    return activity.div(activity.max(axis=1), axis=0).dropna(how='all')


lineage1 = _downsampled_window_ids('1', 2330)
lineage2 = _downsampled_window_ids('2', 1850)
lineage3 = _downsampled_window_ids('3', 1230)
lineage4 = _downsampled_window_ids('4', 450)

lineage_peak1_windows = {'1' : range(0,1),
                        '2' : range(0,1),
                        '3' : range(0,1),
                        '4' : range(0,1)}

lineage_peak2_windows = {'1' : range(5,25),
                        '2' : range(1,21),
                        '3' : range(2,18),
                        '4' : range(0,9)}

lineage_peak3_windows = {'1' : range(48,73),
                        '2' : range(37,57),
                        '3' : range(27,44),
                        '4' : range(13,20)}


lineages = [lineage1, lineage2, lineage3, lineage4]
lineage_ids = ['1', '2', '3', '4']

lineage_name = {'1' : 'GMP/Gran. Lineage',
                '2' : 'MEP Lineage',
                '3' : 'GMP/Mono. Lineage',
                '4' : 'Lymphoid Lineage'}

centrality_dict_downsampled = {}

for lineage, lineage_id in zip(lineages, lineage_ids):
    entry = {
        'lineage_name': lineage_name[lineage_id],
        'peak1_windows': lineage_peak1_windows[lineage_id],
        'peak2_windows': lineage_peak2_windows[lineage_id],
        'peak3_windows': lineage_peak3_windows[lineage_id],
    }
    # Registered before plotting: plot_centrality_and_weights reads the peak windows.
    centrality_dict_downsampled[lineage_id] = entry

    # Commitment-score trace over the same windows, restricted to the downsampled cells.
    sww = sliding_window_weights('lin' + lineage_id, SLIDING_WINDOW_SIZE, SLIDING_STEP_SIZE,
                                 subsampling = subsample_numbers[lineage_id])

    centrality_heatmap = plot_centrality_and_weights(centrality_dict_downsampled, lineage,
                                                     corr_threshold = CORR_THRESHOLD,
                                                     degree_centrality_threshold = CENTRALITY_THRESHOLD,
                                                     lineage_name = 'lin' + lineage_id,
                                                     sww = sww)
    entry['centrality_matrix'] = centrality_heatmap
    print(len(centrality_heatmap))

    # Pathway activity of the downsampled cells, each pathway scaled by its own maximum.
    lineage_pathway_activity_median = pd.read_csv(_activity_path(lineage_id, '_MEDIAN.csv'), index_col = 0)
    lineage_pathway_activity_mean = pd.read_csv(_activity_path(lineage_id, '.csv'), index_col = 0)

    entry['activity_matrix_norm_median'] = _max_normalize_rows(lineage_pathway_activity_median)
    pathway_activity_windows_norm = _max_normalize_rows(lineage_pathway_activity_mean)
    entry['activity_matrix_norm_mean'] = pathway_activity_windows_norm