## Local Effects

One of the more important properties of random forest nodes, and by extension node clusters, is that they describe what we call "local effects".

While a conventional linear regression describes a single relationship between a feature and a target that holds across the entire dataset, a node in a random forest is usually the child of another node and is therefore trained on only part of the dataset. A relationship that such a node describes between a feature and a target may hold across the entire dataset, or it may hold only conditionally, given the splits made by the node's ancestors.

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc

import sys
sys.path.append('/Users/bbrener1/battle/rf_5/rusty_axe')
sys.path.append('/Users/bbrener1/battle/rf_5/')
import tree_reader as tr
import lumberjack

# sys.path.append('../')
# import rusty_axe.tree_reader as tr 
# import rusty_axe.lumberjack

import pickle 

# data_location = "../data/aging_brain/"
data_location = "/Users/bbrener1/battle/rusty_forest_4/data/aging_brain/"


def _load_pickle(path):
    """Unpickle the object stored at `path`, closing the file handle afterwards.

    NOTE: unpickling can execute arbitrary code; only load trusted, locally produced files.
    """
    with open(path, mode='rb') as handle:
        return pickle.load(handle)


# Young and old cohorts, plus the combined `filtered` dataset (used below as
# AnnData-like objects exposing .X, .var_names, .obsm and .shape).
young = _load_pickle(data_location + "aging_brain_young.pickle")
old = _load_pickle(data_location + "aging_brain_old.pickle")

filtered = _load_pickle(data_location + "aging_brain_filtered.pickle")

batch_encoding = np.loadtxt(data_location + 'aging_batch_encoding.tsv')
batch_encoding = batch_encoding.astype(dtype=bool)

# Row masks over `filtered`. Its rows are assumed to be the young cells followed by the
# old cells. The length is derived from the data instead of the previously hard-coded 37069.
n_cells = filtered.shape[0]
if young.shape[0] + old.shape[0] != n_cells:
    raise ValueError(
        f"Expected filtered to hold young + old cells ({young.shape[0]} + {old.shape[0]}), "
        f"but it has {n_cells} rows"
    )

young_mask = np.zeros(n_cells, dtype=bool)
old_mask = np.zeros(n_cells, dtype=bool)

young_mask[:young.shape[0]] = True
old_mask[young.shape[0]:] = True

forest = tr.Forest.load(data_location + 'cv_forest_trimmed_extra')
forest.arguments

In [None]:
# forest = lumberjack.fit(
#     young.X,
#     header=young.var_names,
#     trees=100,
#     ifs=500,
#     ofs=500,
#     ss=500,
#     depth=8,
#     leaves=10,
#     sfr=1,
#     norm='l1',
#     dispersion_mode='var',
#     standardize='true',
#     reduction=10,
#     reduce_input='true',
#     reduce_output='true'
# )
# forest.set_cache(True)
# forest.trim(.01)
# forest.backup(data_location+"cv_forest_13")

In [None]:
# Number of nodes in the loaded forest (shown as the cell's output).
len(forest.nodes())

In [None]:
# Boolean mask over young.var_names that selects the genes the forest predicts
# (forest.output_features), used to restrict expression matrices to those genes.
# The mask length comes from the data rather than the previously hard-coded 2000.

# Name -> first column position. This replaces a list.index scan per feature.
var_position = {}
for i, name in enumerate(young.var_names):
    var_position.setdefault(name, i)  # keep the first occurrence, as list.index did

filtered_feature_mask = np.zeros(len(young.var_names), dtype=bool)
for feature in forest.output_features:
    filtered_feature_mask[var_position[feature]] = True

young_filtered = young[:, filtered_feature_mask]
old_filtered = old[:, filtered_feature_mask]
# Bare `.shape` expressions in the middle of a cell are never displayed, so print them.
print(young_filtered.shape)
print(old_filtered.shape)

np.savetxt("filtered_feature_mask.txt", filtered_feature_mask)

In [None]:
# Reset the forest's split clusters (presumably discarding any earlier clustering) before
# recomputing them below.
forest.reset_split_clusters()

# Group the forest's nodes into split clusters ("factors"). All keyword values are passed
# straight through to tree_reader.Forest.interpret_splits.
# NOTE(review): the meaning of each keyword (e.g. pca=100 dimensions, k=20 neighbours,
# resolution=1) is defined in tree_reader, not here; confirm there before changing them.
forest.interpret_splits(
    depth=10,
    mode='partial_absolute',
    metric='cosine',
    pca=100,
    relatives=False,
    k=20,
    resolution=1,
)

# Number of split clusters produced.
print(len(forest.split_clusters))

In [None]:
# Build the forest's maximum spanning tree in 'samples' mode.
# NOTE(review): what the tree spans and how 'samples' weights it is defined in tree_reader; confirm there.
forest.maximum_spanning_tree(mode='samples')

In [None]:

# forest.tsne_coordinates = young.obsm['X_umap']
# forest.html_tree_summary(n=5)

In [None]:
# forest.backup(data_location + "full_clustering")

In [None]:
# We now would like to see if there are any local associations that are dramatically different
# from global ones, to the degree that it is impossible to recapture them using PCA-based analysis. 

# We will need to perform a PCA analysis first. 

# from sklearn.decomposition import PCA

# model = PCA(n_components=40).fit(young_filtered.X)
# transformed = model.transform(young_filtered.X)
# recovered = model.inverse_transform(transformed)

# centered = young_filtered.X - np.mean(young_filtered.X,axis=0)
# null_squared_residual = np.power(centered,2)

# recovered_residual = young_filtered.X - recovered
# recovered_squared_residual = np.power(recovered_residual,2)

# pca_recovered_per_sample = np.sum(recovered_squared_residual,axis=1)
# pca_recovered_fraction_per_sample = np.sum(recovered_squared_residual,axis=1) / np.sum(null_squared_residual,axis=1)
# print(np.sum(null_squared_residual))
# print(np.sum(recovered_squared_residual))

# print(f"Remaining variance:{(np.sum(recovered_squared_residual) / np.sum(null_squared_residual))}")


from sklearn.decomposition import PCA

# Fit a 25-component PCA to the young cells, project them, and reconstruct the data
# from those projections.
model = PCA(n_components=25).fit(young.X)
transformed = model.transform(young.X)
recovered = model.inverse_transform(transformed)

# Null model: squared deviation of every value from its per-gene mean.
centered = young.X - np.mean(young.X, axis=0)
null_squared_residual = np.power(centered, 2)

# PCA model: squared error of the reconstruction.
recovered_residual = young.X - recovered
recovered_squared_residual = np.power(recovered_residual, 2)

# Per-cell residual left by the PCA, absolute and as a fraction of the null residual.
null_per_sample = np.sum(null_squared_residual, axis=1)
pca_recovered_per_sample = np.sum(recovered_squared_residual, axis=1)
pca_recovered_fraction_per_sample = pca_recovered_per_sample / null_per_sample

# Totals over all cells and genes; their ratio is the variance the PCA leaves unexplained.
total_null = np.sum(null_squared_residual)
total_remaining = np.sum(recovered_squared_residual)
print(total_null)
print(total_remaining)

print(f"Remaining variance:{(total_remaining / total_null)}")

In [None]:
# Compare how two genes of interest are weighted in each principal component.
# `model` was fit on young.X, so gene indices must come from young.var_names,
# the columns the PCA actually saw.

f1 = "Syt1"
f2 = "Cd74"

pca_feature_names = list(young.var_names)
f1_index = pca_feature_names.index(f1)
f2_index = pca_feature_names.index(f2)

# Row p of components_ is PC p; these pick each gene's weight in every PC.
f1_weights = model.components_[:, f1_index]
f2_weights = model.components_[:, f2_index]

plt.figure()
plt.title(f"PC Weights of {f1} vs {f2}")
# Reference line through the origin: endpoints (.2, -.11) and (-.2, .11) give slope -0.55.
plt.plot([.2, -.2], np.array([-.2, .2]) * .55, color='red', label="Slope of -.55")
# Light-grey axes through the origin.
plt.plot([.2, -.2], np.array([0, 0]), '--', color='lightgray')
plt.plot([0, 0], [.2, -.2], '--', color='lightgray')
plt.scatter(f1_weights, f2_weights, s=2)
plt.xlabel(f"{f1} weight")
plt.ylabel(f"{f2} weight")
# Annotate each point with its PC number.
for i, (x, y) in enumerate(zip(f1_weights, f2_weights)):
    plt.text(x + .005, y - .01, str(i), fontsize=5)
plt.legend()
plt.show()

for i, pc in enumerate(model.components_):
    print(f"PC:{i}, {f1}:{pc[f1_index]}, {f2}:{pc[f2_index]}")

In [None]:

# For every split cluster (factor), print the local correlation between two genes.
f1, f2 = "Tmem119", "Cd74"

f1_index = forest.truth_dictionary.feature_dictionary[f1]
f2_index = forest.truth_dictionary.feature_dictionary[f2]

banner = "====================================="
for factor in forest.split_clusters:
    print(banner)
    print(factor.name())
    print(banner)
    local = factor.local_correlations()
    print(local[f1_index, f2_index])

In [None]:
# Visualize each principal component's per-cell scores (the columns of `transformed`,
# i.e. young.X projected onto each PC) on the embedding, to see where each PC makes
# meaningful predictions. This may hint at whether a PC tracks a particular cell type.
# (The per-gene weights in model.components_ are the loadings; these are scores.)

for i, pc in enumerate(transformed.T):
    plt.figure(figsize=(6, 5))
    plt.title(f"PC {i} Scores")
    # Symmetric colour limits, so white corresponds to a score of zero.
    ab_max = np.max(np.abs(pc))
    plt.scatter(*forest.tsne_coordinates.T, c=pc, s=3, alpha=.4, cmap='bwr', vmin=-ab_max, vmax=ab_max)
    plt.xlabel("UMAP Embedding, (AU)")
    plt.ylabel("UMAP Embedding, (AU)")
    plt.colorbar(label="PC score")
    plt.show()

## Local Discrepancy Analysis

In [None]:
# For each factor, list the feature pairs whose local correlation differs most from
# the global correlation, with both values side by side.
# split_clusters[0] is skipped, as before.
# NOTE(review): presumably because cluster 0 is not a meaningful factor; confirm.

global_correlations = forest.global_correlations()
features = forest.output_features  # loop-invariant, so fetched once

for factor in forest.split_clusters[1:]:
    fi_pairs = factor.most_local_correlations(n=20)
    local_correlations = factor.local_correlations()
    print("=====================================")
    print(factor.name())
    print("=====================================")
    print("F1\tF2\tGlobal\tLocal")
    print("-------------------------------------")
    for f1, f2 in fi_pairs:
        print(f"{features[f1]}\t{features[f2]}\t{np.around(global_correlations[f1,f2],3)}\t{np.around(local_correlations[f1,f2],3)}")
    print("=====================================")

# Interesting result from 7: Rtn1,Bcan


In [None]:
# Find "interesting" gene pairs (the 3 most locally-discrepant pairs of each factor).
# For every factor, tabulate how far each pair's local correlation departs from its
# global correlation. Then display the table with rows and columns clustered.

global_correlations = forest.global_correlations()

# Unique unordered pairs, in first-seen order.
interesting_pairs = []
for factor in forest.split_clusters:
    for ip in factor.most_local_correlations(n=3):
        if ip not in interesting_pairs and (ip[1], ip[0]) not in interesting_pairs:
            interesting_pairs.append(ip)

interesting_pair_names = [f"{forest.output_features[f1]}, {forest.output_features[f2]}" for (f1, f2) in interesting_pairs]

# Rows: pairs; columns: factors; entries: local minus global correlation.
factor_correlation_table = np.zeros((len(interesting_pairs), len(forest.split_clusters)))

for i, factor in enumerate(forest.split_clusters):
    local_correlations = factor.local_correlations()
    for j, (f1, f2) in enumerate(interesting_pairs):
        factor_correlation_table[j, i] = local_correlations[f1, f2] - global_correlations[f1, f2]

from scipy.cluster.hierarchy import linkage, dendrogram

# Order pairs and factors by average-linkage hierarchical clustering on cosine distance.
pair_agglomeration = dendrogram(linkage(factor_correlation_table, metric='cosine', method='average'), no_plot=True)['leaves']
factor_agglomeration = dendrogram(linkage(factor_correlation_table.T, metric='cosine', method='average'), no_plot=True)['leaves']

print(len(pair_agglomeration))
print(len(factor_agglomeration))

agg_indices = np.array(interesting_pairs)[pair_agglomeration]
agg_names = np.array(interesting_pair_names)[pair_agglomeration]

plt.figure(figsize=(7, 20))
# Values are local - global differences and can exceed +/-1; colours saturate beyond that.
plt.title("Local - Global Correlation of Selected Feature Pairs")
plt.imshow(factor_correlation_table[np.ix_(pair_agglomeration, factor_agglomeration)], interpolation='none', aspect='auto', cmap='bwr', vmin=-1, vmax=1)
plt.yticks(np.arange(len(agg_names)), labels=agg_names)
plt.colorbar(label="Local - Global Correlation")
plt.xlabel("Factors")
plt.xticks(np.arange(len(factor_agglomeration)), labels=factor_agglomeration, rotation=90)
plt.show()


print([(x, y) for x, y in enumerate(agg_names)])

In [None]:
# Local correlation matrix of split cluster 0 (shown as the cell's output).
forest.split_clusters[0].local_correlations()


In [None]:
# Forest-wide (global) correlation matrix (shown as the cell's output).
forest.global_correlations()
# forest.output_features[1639]

# print(forest.split_clusters[23].local_correlations(indices=[717,1639]))
# print(forest.split_clusters[20].local_correlations(indices=[717,1639]))

# cluster 23, Rrares2 (717), Meg3 (1639)

In [None]:
# Naive linear fit between two features across all cells (a simple global correlation).

from scipy.stats import linregress

f1 = "Tmem119"
f2 = "Cd74"

f1_index = forest.truth_dictionary.feature_dictionary[f1]
f2_index = forest.truth_dictionary.feature_dictionary[f2]

f1_values = forest.output[:, f1_index]
f2_values = forest.output[:, f2_index]

# linregress's third return value is Pearson's r, not R^2; it is squared for the label.
slope, intercept, r_fit, _, _ = linregress(f1_values, f2_values)
# Draw the fit over the observed range of f1 instead of a fixed 0..6 range.
fit_x = np.linspace(np.min(f1_values), np.max(f1_values), 50)

plt.figure(figsize=(3, 2.5))
plt.title(f"Linear Fit, {f1}, {f2}, Naive")
plt.scatter(f1_values, f2_values, s=3)
plt.plot(fit_x, intercept + (fit_x * slope), c='red', label=f"Slope:{np.around(slope,3)},R2:{np.around(r_fit**2,3)}")
plt.legend()
plt.xlabel(f"{f1}")
plt.ylabel(f"{f2}")
plt.show()

In [None]:
# Keep only cells with a strongly positive or strongly negative sister score for one
# factor, then linearly regress two genes within them to check for a "local" association.


from scipy.stats import linregress

factor = forest.split_clusters[18]
factor_threshold = .1
sister_scores = factor.sister_scores()
# Two-sided selection on |sister score|.
factor_mask = np.abs(sister_scores) > factor_threshold

plt.figure()
plt.title(f"Sister scores, {factor.name()}")
plt.hist(sister_scores, bins=50)
# axvline spans the full y-range; a fixed [-100, 100] segment is cut off on tall histograms.
plt.axvline(factor_threshold, color='red')
plt.axvline(-factor_threshold, color='red', label="Sister score threshold")
plt.legend()
plt.show()

f1 = "Tmem119"
f2 = "Cd74"

f1_index = forest.truth_dictionary.feature_dictionary[f1]
f2_index = forest.truth_dictionary.feature_dictionary[f2]

f1_values = forest.output[:, f1_index][factor_mask]
f2_values = forest.output[:, f2_index][factor_mask]

# linregress's third return value is Pearson's r, not R^2; it is squared for the label.
slope, intercept, r_fit, _, _ = linregress(f1_values, f2_values)
# Draw the fit over the observed range of f1 instead of a fixed 0..6 range.
fit_x = np.linspace(np.min(f1_values), np.max(f1_values), 50)

plt.figure(figsize=(3, 2.5))
plt.title(f"Linear Fit, {f1}, {f2}, \n Factor {factor.name()}, Filtered")
plt.scatter(f1_values, f2_values, s=3)
plt.plot(fit_x, intercept + (fit_x * slope), c='red', label=f"Slope:{np.around(slope,3)},R2:{np.around(r_fit**2,3)}")
plt.xlabel(f"{f1}")
plt.ylabel(f"{f2}")
plt.legend()
plt.show()

# Where the selected cells sit on the embedding.
plt.figure()
plt.title(f"Cells with |sister score| > {factor_threshold}, {factor.name()}")
plt.scatter(*forest.tsne_coordinates.T, c=factor_mask, s=1)
plt.show()

In [None]:
# Same as the previous analysis with a looser threshold: keep only cells with a strongly
# positive or strongly negative sister score for one factor, then linearly regress two
# genes within them to check for a "local" association.


from scipy.stats import linregress

factor = forest.split_clusters[18]
factor_threshold = .05
sister_scores = factor.sister_scores()
# Two-sided selection on |sister score|.
factor_mask = np.abs(sister_scores) > factor_threshold

plt.figure()
plt.title(f"Sister scores, {factor.name()}")
plt.hist(sister_scores, bins=50)
# axvline spans the full y-range; a fixed [-100, 100] segment is cut off on tall histograms.
plt.axvline(factor_threshold, color='red')
plt.axvline(-factor_threshold, color='red', label="Sister score threshold")
plt.legend()
plt.show()

f1 = "Tmem119"
f2 = "Cd74"

f1_index = forest.truth_dictionary.feature_dictionary[f1]
f2_index = forest.truth_dictionary.feature_dictionary[f2]

f1_values = forest.output[:, f1_index][factor_mask]
f2_values = forest.output[:, f2_index][factor_mask]

# linregress's third return value is Pearson's r, not R^2; it is squared for the label.
slope, intercept, r_fit, _, _ = linregress(f1_values, f2_values)
# Draw the fit over the observed range of f1 instead of a fixed 0..6 range.
fit_x = np.linspace(np.min(f1_values), np.max(f1_values), 50)

plt.figure(figsize=(3, 2.5))
plt.title(f"Linear Fit, {f1}, {f2}, \n Factor {factor.name()}, Filtered")
plt.scatter(f1_values, f2_values, s=3)
plt.plot(fit_x, intercept + (fit_x * slope), c='red', label=f"Slope:{np.around(slope,3)},R2:{np.around(r_fit**2,3)}")
plt.xlabel(f"{f1}")
plt.ylabel(f"{f2}")
plt.legend()
plt.show()

# Where the selected cells sit on the embedding.
plt.figure()
plt.title(f"Cells with |sister score| > {factor_threshold}, {factor.name()}")
plt.scatter(*forest.tsne_coordinates.T, c=factor_mask, s=1)
plt.show()

In [None]:
# Rank genes by |weight| in one principal component and report where two features of
# interest fall, to see whether they make up an important part of the PC's variance.
# `model` was fit on young.X, so its weights are indexed by young.var_names.
# Forest output-feature indices may not line up with those columns.
pc = 8
n_top = 20

f1 = "Tmem119"
f2 = "Cd74"

pca_feature_names = np.array(young.var_names)
f1_index = list(pca_feature_names).index(f1)
f2_index = list(pca_feature_names).index(f2)

weights = model.components_[pc]

# Ascending by |weight|; reversed so the most heavily weighted genes come first.
weight_sort = np.argsort(np.abs(weights))
top_genes = weight_sort[::-1][:n_top]  # the old [:-20:-1] slice yielded only 19 genes

print(list(pca_feature_names[top_genes]))
print(list(weights[top_genes]))

# 1-based rank by |weight| (1 = most heavily weighted).
print(f"{f1}: {len(weights) - list(weight_sort).index(f1_index)}")
print(f"{f2}: {len(weights) - list(weight_sort).index(f2_index)}")

print(weights[f1_index])
print(weights[f2_index])

In [None]:
# For every split cluster, map its raw and log sister scores onto the embedding.
# Colour limits are symmetric so white corresponds to zero. Titles name the cluster,
# so the many figures this loop produces can be told apart.
for s_c in forest.split_clusters:
    scores = s_c.sister_scores()
    log_scores = s_c.log_sister_scores(prior=10)

    abmax = np.max(np.abs(scores))

    plt.figure()
    plt.title(f"Regular, {s_c.name()}")
    plt.scatter(*forest.tsne_coordinates.T, c=scores, cmap='bwr', s=1, vmin=-abmax, vmax=abmax)
    plt.colorbar()
    plt.show()

    abmax = np.max(np.abs(log_scores))

    plt.figure()
    plt.title(f"Log, {s_c.name()}")
    plt.scatter(*forest.tsne_coordinates.T, c=log_scores, cmap='bwr', s=1, vmin=-abmax, vmax=abmax)
    plt.colorbar()
    plt.show()

In [None]:
# Distributions of sample, sister and log-sister scores for one factor.
factor = forest.split_clusters[34]

samples = factor.sample_scores()
sisters = factor.sister_scores()
# NOTE: default prior here, whereas the sister-score map cell uses prior=10.
log_sisters = factor.log_sister_scores()

plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of Sample Scores In {factor.name()}")
plt.hist(samples, bins=50)
plt.ylabel("Frequency")
plt.xlabel("Sample Scores")
plt.show()

plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of Sister Scores In {factor.name()}")
plt.hist(sisters, bins=50)
plt.ylabel("Frequency")
plt.xlabel("Sister Scores")
plt.show()

# Previously unlabeled; labelled to match the two plots above.
plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of Log Sister Scores In {factor.name()}")
plt.hist(log_sisters, bins=50)
plt.ylabel("Frequency")
plt.xlabel("Log Sister Scores")
plt.show()




In [None]:
# Test whether a particular factor over-expresses a gene of interest, comparing cells
# with a high sister score against all other cells with Welch's t-test.
# (Used as a statistical test for cell type identity, e.g. "is factor 34 immune cells?")
from scipy.stats import ttest_ind
from scipy.stats import sem

factor = forest.split_clusters[3]
factor_threshold = .05
sister_scores = factor.sister_scores()
# One-sided selection: only cells scoring above the threshold count as "in" the factor.
mask = sister_scores > factor_threshold

plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of sister scores in Cluster {factor.name()}")
plt.hist(sister_scores, bins=50, log=True)
plt.xlabel("Sister scores")
plt.ylabel("Frequency")
plt.show()

feature = "Cldn5"

f_index = list(young.var_names).index(feature)

# Slice each group's expression once and reuse it for the test and the summary statistics.
in_factor = young.X[mask][:, f_index]
rest = young.X[~mask][:, f_index]

test = ttest_ind(in_factor, rest, equal_var=False)

print(f"{feature} in {factor.name()} vs all other: {test}")

m1 = np.around(np.mean(in_factor), 3)
se1 = np.around(sem(in_factor), 3)
m2 = np.around(np.mean(rest), 3)
se2 = np.around(sem(rest), 3)

print(f"Mean expression: {str(m1)} +/- {str(se1)} vs {str(m2)} +/- {str(se2)}")

plt.figure(figsize=(3, 2.5))
plt.title(f"{feature} Mean Expression")
plt.bar([0, 1], [m1, m2], yerr=[se1, se2], width=.5, tick_label=[f"NC {factor.name()}", "Rest", ])
plt.ylabel("Mean Expression (Log TPM)")
plt.show()

In [None]:
# Test whether a particular factor over-expresses a gene of interest, comparing cells
# with a high sister score against all other cells with Welch's t-test.
# (Used as a statistical test for cell type identity, e.g. "is factor 34 immune cells?")
from scipy.stats import ttest_ind
from scipy.stats import sem

factor = forest.split_clusters[13]
factor_threshold = .05
sister_scores = factor.sister_scores()
# One-sided selection: only cells scoring above the threshold count as "in" the factor.
mask = sister_scores > factor_threshold

plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of sister scores in Cluster {factor.name()}")
plt.hist(sister_scores, bins=50)
# Only the positive threshold is drawn because the mask is one-sided. axvline spans the
# full y-range, unlike the old fixed [-100, 100] segment.
plt.axvline(factor_threshold, color='red', label="Sister score threshold")
plt.xlabel("Sister scores")
plt.ylabel("Frequency")
plt.legend()
plt.show()

feature = "Cldn5"

f_index = list(young.var_names).index(feature)

# Slice each group's expression once and reuse it for the test and the summary statistics.
in_factor = young.X[mask][:, f_index]
rest = young.X[~mask][:, f_index]

test = ttest_ind(in_factor, rest, equal_var=False)

print(f"{feature} in {factor.name()} vs all other: {test}")

m1 = np.around(np.mean(in_factor), 3)
se1 = np.around(sem(in_factor), 3)
m2 = np.around(np.mean(rest), 3)
se2 = np.around(sem(rest), 3)

print(f"Mean expression: {str(m1)} +/- {str(se1)} vs {str(m2)} +/- {str(se2)}")

plt.figure(figsize=(3, 2.5))
plt.title(f"{feature} Mean Expression")
plt.bar([0, 1], [m1, m2], yerr=[se1, se2], width=.5, tick_label=[f"NC {factor.name()}", "Rest", ])
plt.ylabel("Mean Expression (Log TPM)")
plt.show()

In [None]:
# Compare a gene's expression between the high-sister-score cells of two factors
# with Welch's t-test.
# (Used as a statistical test for cell type identity, e.g. "is factor 34 immune cells?")
from scipy.stats import ttest_ind
from scipy.stats import sem

factor = forest.split_clusters[13]
factor_threshold = .05
sister_scores = factor.sister_scores()
# One-sided selection: only cells scoring above the threshold count as "in" the factor.
mask = sister_scores > factor_threshold

plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of sister scores in Cluster {factor.name()}")
plt.hist(sister_scores, bins=50)
# Only the positive threshold is drawn because the mask is one-sided.
plt.axvline(factor_threshold, color='red', label="Sister score threshold")
plt.xlabel("Sister scores")
plt.ylabel("Frequency")
plt.legend()
plt.show()

factor_2 = forest.split_clusters[21]
factor_2_threshold = .05
sister_scores_2 = factor_2.sister_scores()
mask_2 = sister_scores_2 > factor_2_threshold

plt.figure(figsize=(3, 2.5))
plt.title(f"Distribution of sister scores in Cluster {factor_2.name()}")
plt.hist(sister_scores_2, bins=50)
plt.axvline(factor_2_threshold, color='red', label="Sister score threshold")
plt.xlabel("Sister scores")
plt.ylabel("Frequency")
plt.legend()
plt.show()


feature = "Cldn5"

f_index = list(young.var_names).index(feature)

values_1 = young.X[mask][:, f_index]
values_2 = young.X[mask_2][:, f_index]

# The two masks are not mutually exclusive. Cells in both groups make the samples
# non-independent, which the t-test assumes, so report the overlap.
n_overlap = int(np.sum(mask & mask_2))

test = ttest_ind(values_1, values_2, equal_var=False)

print(f"{feature} in {factor.name()} vs {factor_2.name()}:")
print(f"Cells in both groups: {n_overlap}")

m1 = np.around(np.mean(values_1), 3)
se1 = np.around(sem(values_1), 3)
m2 = np.around(np.mean(values_2), 3)
se2 = np.around(sem(values_2), 3)

print(f"Mean expression: {str(m1)} +/- {str(se1)} vs {str(m2)} +/- {str(se2)}")
print(test)

plt.figure(figsize=(3, 2.5))
plt.title(f"{feature} Mean Expression")
plt.bar([0, 1], [m1, m2], yerr=[se1, se2], width=.5, tick_label=[f"NC {factor.name()}", f"{factor_2.name()}", ])
plt.ylabel("Mean Expression (Log TPM)")
plt.show()

In [None]:
# List every gene whose name contains the substring "Rad".
rad_like_genes = [gene for gene in young.var_names if "Rad" in gene]
for gene in rad_like_genes:
    print(gene)

In [None]:
# Plot one gene's expression on the shared UMAP, separately for young and old cells.
# Point colour and size both scale with expression. Titles say which cohort is shown;
# previously both figures had the same title.
feature = "Cdkn1a"

# Column in young.var_names; old.X is assumed to use the same gene order.
f_index = list(young.var_names).index(feature)

plt.figure()
plt.title(f"Distribution of {feature} in Young Cells")
plt.scatter(*filtered.obsm["X_umap"][young_mask].T, c=young.X[:, f_index], s=young.X[:, f_index] + 1)
plt.colorbar(label="Log TPM")
plt.show()

plt.figure()
plt.title(f"Distribution of {feature} in Old Cells")
plt.scatter(*filtered.obsm["X_umap"][old_mask].T, c=old.X[:, f_index], s=old.X[:, f_index] + 1)
plt.colorbar(label="Log TPM")
plt.show()

# Working cell cycle markers:
# Ccn*
# Gem (weak)
# Plk2? (unconfirmed; Plk1 is the conventional marker)
# Top2a (very specific)
#


In [17]:
feature = "Fth1"

f_index = list(young.var_names).index(feature)

# One large UMAP panel per cohort (young first, then old), coloured by the gene's
# expression in `filtered`.
for cohort_mask in (young_mask, old_mask):
    plt.figure(figsize=(10, 10))
    cohort_coords = filtered.obsm['X_umap'][cohort_mask]
    plt.scatter(*cohort_coords.T, c=filtered.X[cohort_mask, f_index], s=4)
    plt.colorbar()
    plt.show()

# f_index = list(forest.output_features).index(feature)

# plt.figure(figsize=(10,10))
# plt.scatter(*filtered.obsm['X_umap'][young_mask].T,c=ns_residuals[:,f_index],s=4)
# plt.show()

# plt.figure()
# plt.scatter(ns[:,0],forest.output[:,f_index],alpha=.1)
# plt.show()
# plt.figure()
# plt.scatter(ns[:,0],ns_residuals[:,f_index],alpha=.1)
# plt.show()

## Examining Partial Effects

In [None]:
# Compare additive and partial node representations for the nodes of one factor.
# Several globals defined here (klk_6_index, partials) are read by later cells.
factor = forest.split_clusters[21]


# Mean 'additive_mean' representation over the factor's nodes: one value per output
# feature (representation presumably has one row per node; confirm in tree_reader).
additive = np.mean(forest.node_representation(factor.nodes,mode='additive_mean'),axis=0)
additive_sort = np.argsort(additive)

# Ten features with the largest mean additive value.
print(forest.output_features[additive_sort[-10:]])
print(additive[additive_sort[-10:]])

# Positions of genes of interest among forest.output_features (rtn_1_index is not used in this cell).
rtn_1_index = list(forest.output_features).index('Rtn1')
klk_6_index = list(forest.output_features).index('Klk6')
print(additive[klk_6_index])

partials = forest.node_representation(factor.nodes,mode='partial_absolute') 
partial_means = np.mean(partials,axis=0)
# Ascending by |mean partial value|.
partial_sort = np.argsort(np.abs(partial_means))

# Ten features with the smallest |mean partial value|.
print(forest.output_features[partial_sort[:10]])
print(partial_means[partial_sort[:10]])

# Ten features with the largest |mean partial value|.
print(forest.output_features[partial_sort[-10:]])
print(partial_means[partial_sort[-10:]])

print(additive[klk_6_index])
print(partial_means[klk_6_index])

# Distribution of Klk6's partial values across the factor's nodes.
plt.figure()
plt.hist(partials[:,klk_6_index],bins=50)
plt.show()

# ratio = np.abs(additive/partial)
# ratio_sort = np.argsort(ratio)
# print(ratio[ratio_sort])
# print(forest.output_features[ratio_sort[:10]])
# print(forest.output_features[ratio_sort[-10:]])

In [None]:
# Shape of the `partials` array computed in an earlier cell (shown as the cell's output).
partials.shape

In [None]:
# For one factor, compare the additive and (signed) partial node representations of a
# gene of interest (Klk6), and show its partial values across the factor's nodes.
factor = forest.split_clusters[23]

f_index = list(forest.output_features).index('Klk6')

additive = np.mean(forest.node_representation(factor.nodes, mode='additive_mean'), axis=0)
additive_sort = np.argsort(additive)

# Ten features with the largest mean additive value.
print(forest.output_features[additive_sort[-10:]])
print(additive[additive_sort[-10:]])

# Signed 'partial' mode, sorted by signed mean (not by magnitude).
partials = forest.node_representation(factor.nodes, mode='partial')
partial_means = np.mean(partials, axis=0)
partial_sort = np.argsort(partial_means)

print(forest.output_features[partial_sort[-10:]])
print(partial_means[partial_sort[-10:]])

print(additive[f_index])
print(partial_means[f_index])

# Use this cell's own f_index. The previous klk_6_index only exists if an earlier
# cell has been run.
plt.figure()
plt.title(f"Partial values of {forest.output_features[f_index]}, {factor.name()}")
plt.hist(partials[:, f_index], bins=50)
plt.show()


In [None]:
# Mean and variance of Klk6 expression in all cells, and within cells with strongly
# positive or strongly negative sister scores for factors 19 and 23.

# Look the column up by name in young_filtered itself. Indices into
# forest.output_features may not follow young_filtered's column order.
klk6_col = list(young_filtered.var_names).index('Klk6')
sister_threshold = .05

print(np.mean(young_filtered.X, axis=0)[klk6_col])
print(np.var(young_filtered.X, axis=0)[klk6_col])

sister_19 = forest.split_clusters[19].sister_scores()
mask_19p = sister_19 > sister_threshold
# The negative side must use -threshold. The old `< .05` selected every cell outside
# mask_19p and disagreed with the factor-23 masks below.
mask_19m = sister_19 < -sister_threshold

sister_23 = forest.split_clusters[23].sister_scores()
mask_23p = sister_23 > sister_threshold
mask_23m = sister_23 < -sister_threshold

for label, rows in (("19+", mask_19p), ("19-", mask_19m), ("23+", mask_23p), ("23-", mask_23m)):
    subset = young_filtered.X[rows]
    print(f"{label}: n={np.sum(rows)}, "
          f"mean={np.mean(subset, axis=0)[klk6_col]}, "
          f"var={np.var(subset, axis=0)[klk6_col]}")


In [None]:
# Shape of young.X restricted to the forest's output features (shown as the cell's output).
young_filtered.shape

# for feature in young.var_names:
#     if "Ccn" in feature:
#         print(feature)

In [None]:
# Which split clusters the parents of factor 27's nodes belong to, and in what proportion.
parent_class = [p.split_cluster for p in forest.split_clusters[27].parents()]

# Compute np.unique once. Previously the first call's result was discarded: only the
# last expression in a cell is displayed.
parent_labels, parent_counts = np.unique(parent_class, return_counts=True)
print(parent_labels, parent_counts)
parent_counts / len(parent_class)

In [None]:
import matplotlib.patheffects as PathEffects

# Heatmap of one factor's local (weighted) correlations among its top features, with the
# rounded value written in each cell.
factor = forest.split_clusters[34]

n = 5

lt, hd = factor.top_local_table(n)
# Derive the heatmap size from the returned table and labels instead of assuming 10.
m = len(hd)
n_rows, n_cols = np.shape(lt)


plt.figure()
plt.title(f"Local Correlations in NC{factor.name()}")
plt.imshow(lt, cmap='bwr', vmin=-1, vmax=1)
for i in range(n_rows):
    for j in range(n_cols):
        text = plt.text(j, i, np.around(lt[i, j], 1),
                        ha="center", va="center", c='w', fontsize=7)
        # Thin black outline keeps the white labels legible on pale cells.
        text.set_path_effects(
            [PathEffects.withStroke(linewidth=.8, foreground='black')])
plt.xticks(ticks=np.arange(m), labels=hd, rotation=45, fontsize=10)
plt.yticks(ticks=np.arange(m), labels=hd, rotation=45, fontsize=10)
plt.colorbar(label="Weighted Correlation")
plt.show()

In [None]:
import matplotlib.patheffects as PathEffects

# Heatmap of the global correlations among one factor's top features (for comparison
# with the local table), with the rounded value written in each cell.
factor = forest.split_clusters[34]

n = 5

gt, hd = factor.top_global_table(n)
# Derive the heatmap size from the returned table and labels instead of assuming 10.
m = len(hd)
n_rows, n_cols = np.shape(gt)


plt.figure()
plt.title("Global Correlations")
plt.imshow(gt, cmap='bwr', vmin=-1, vmax=1)
for i in range(n_rows):
    for j in range(n_cols):
        text = plt.text(j, i, np.around(gt[i, j], 1),
                        ha="center", va="center", c='w', fontsize=7)
        # Thin black outline keeps the white labels legible on pale cells.
        text.set_path_effects(
            [PathEffects.withStroke(linewidth=.8, foreground='black')])
plt.xticks(ticks=np.arange(m), labels=hd, rotation=45, fontsize=10)
plt.yticks(ticks=np.arange(m), labels=hd, rotation=45, fontsize=10)
plt.colorbar(label="Weighted Correlation")
plt.show()

In [None]:
from glob import glob

# Notebook files in the current working directory (shown as the cell's output).
glob("./*.ipynb")


In [None]:
# IPython shell escape (not plain Python): print the current working directory.
!pwd

In [None]:
# nbformat's notebook read/write functions, aliased so they don't clash with other
# `read`/`write` names. (A dangling `from ` line here was a syntax error and is removed.)
from nbformat import read as nb_read
from nbformat import write as nb_write
