## Existence Of Node Clusters

Here we demonstrate that, in a random forest that has been trained on some set of data, the nodes can be reasonably organized into clusters.

First, we must train or load a forest:

In [6]:
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc

import sys
sys.path.append('../')
import tree_reader as tr 
import lumberjack

data_location = "/Users/bbrener1/battle/rusty_forest_4/data/aging_brain/"
!ls {data_location}
# data_location = "../data/aging_brain/"

forest = tr.Forest.load(data_location + 'cv_forest_trimmed_extra')
forest.arguments


In [8]:
# print(len(forest.output_features))
# print(len(forest.split_clusters))
print(len(forest.nodes()))


A Random Forest is a collection of decision trees, and a decision tree is a collection of individual decision points, commonly known as "Nodes".

To understand Random Forests and Decision Trees, it is important to understand how Nodes work. Each individual node is a (very crappy) regressor, i.e. each Node makes a prediction based on a rule like "If Gene 1 has expression > 10, Gene 2 will have expression < 5", or "If a house is < 5 miles from a school, it will cost > $100,000". A very important property of each node, however, is that it can also have children, which are other nodes. When a node makes a prediction like "If Gene 1 has expression > 10 then Gene 2 has expression < 5", it can pass all the samples for which Gene 1 is > 10 to one of its children, and all the samples for which Gene 1 is <= 10 to the other child. After that, each of its children can make a different prediction, which results in compound rules.

This is how a decision tree is formed. A decision tree with a depth of 2 might contain a rule like "If Gene 1 > 10 AND Gene 3 > 10, THEN Gene 2 and Gene 4 are both < 2", which would represent one of its "Leaf" nodes. Leaf nodes are nodes with no children.

Individual decision trees, then, are somewhat crappy predictors, but they're better than individual nodes. In order to improve the performance of decision trees, we can construct a Random Forest. To construct a random forest, we train many decision trees on bootstraps of a dataset (random samples drawn with replacement).

If many decision trees are combined and their predictions averaged together, you have a Random Forest, which is a pretty good kind of regressor. 
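To make the node, tree, and forest vocabulary concrete, here is a minimal toy sketch. It is not how `tree_reader` or the forest loaded above is implemented: the data, `fit_stump`, and `predict_stump` are all made up for illustration, and each "tree" is a single node (a stump) to keep things short.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: one input feature ("gene_1") and one target ("gene_2").
# The target is high when gene_1 <= 10 and low otherwise, plus noise.
gene_1 = rng.uniform(0, 20, size=200)
gene_2 = np.where(gene_1 > 10, 2.0, 8.0) + rng.normal(0, 1, size=200)


def fit_stump(x, y):
    """Fit a single node: pick the threshold that minimizes squared error."""
    best = None
    for threshold in np.unique(x)[:-1]:  # skip the max so both sides are non-empty
        left, right = y[x <= threshold], y[x > threshold]
        error = np.sum((left - left.mean()) ** 2) + np.sum((right - right.mean()) ** 2)
        if best is None or error < best[0]:
            best = (error, threshold, left.mean(), right.mean())
    # (threshold, prediction if x <= threshold, prediction if x > threshold)
    return best[1:]


def predict_stump(stump, x):
    threshold, low_side, high_side = stump
    return np.where(x <= threshold, low_side, high_side)


# A "forest" of stumps, each trained on a bootstrap (sampling with replacement).
stumps = []
for _ in range(25):
    draw = rng.integers(0, len(gene_1), size=len(gene_1))
    stumps.append(fit_stump(gene_1[draw], gene_2[draw]))

# The forest prediction is the average of the individual predictions.
query = np.array([3.0, 17.0])
forest_prediction = np.mean([predict_stump(s, query) for s in stumps], axis=0)
print(forest_prediction)
```

Each stump sees a slightly different bootstrap, so the thresholds differ a little from stump to stump, and averaging smooths that out. A real tree would recurse on each side of the split instead of stopping after one node.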

A practical demonstration might help:

In [None]:
forest.reset_split_clusters()
forest.interpret_splits(depth=4,mode='sample',metric='cosine',pca=100,relatives=False,k=10,resolution=1)

So now that we know that random forests are collections of ordered nodes, we can examine a more interesting question: do certain nodes occur repeatedly in the forest, despite operating on bootstrapped samples? 

To examine this question, we must first understand the different ways of describing a node. I think there are generally three helpful ways of looking at a node:

* **Node Sample Encoding**: A binary vector the length of the number of samples you are considering. A 0 or false means the sample is absent from the node. A 1 or true means the sample is present in the node. 

* **Node Mean Encoding**: A float vector the length of the number of targets you are considering. Each value is the mean of the target values for all samples in this node. This is the node's prediction for samples that occur in it.

* **Node Additive Encoding**: A float vector the length of the number of targets you are considering. Each value is THE DIFFERENCE between the mean value for that target in THIS NODE and the mean value for that target IN THE PARENT of this node. For root nodes, which have no parents, the additive encoding is simply the mean value across the entire dataset (as if the mean of a hypothetical parent had been 0). This encoding represents the marginal effect of each node.
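The three encodings above can be sketched on a tiny hand-built tree. This is only an illustration of the definitions: the data, the node names, and the dictionaries below are hypothetical and are not the output of `forest.node_representation`.

```python
import numpy as np

# Six samples, two targets (rows = samples, columns = targets).
targets = np.array([
    [1.0, 10.0],
    [2.0, 12.0],
    [3.0, 11.0],
    [7.0, 2.0],
    [8.0, 3.0],
    [9.0, 1.0],
])

# A hypothetical tree: the root holds every sample and splits into two children.
# Each node is described by the indices of the samples it contains and by its parent.
node_samples = {
    "root": [0, 1, 2, 3, 4, 5],
    "left": [0, 1, 2],
    "right": [3, 4, 5],
}
parent_of = {"root": None, "left": "root", "right": "root"}

n_samples = targets.shape[0]
sample_encoding, mean_encoding, additive_encoding = {}, {}, {}

for name, members in node_samples.items():
    # Sample encoding: True where the sample is present in the node.
    present = np.zeros(n_samples, dtype=bool)
    present[members] = True
    sample_encoding[name] = present

    # Mean encoding: per-target mean over the samples in the node.
    mean_encoding[name] = targets[present].mean(axis=0)

for name in node_samples:
    # Additive encoding: node mean minus parent mean (parent mean taken as 0 for the root).
    parent = parent_of[name]
    parent_mean = mean_encoding[parent] if parent is not None else 0.0
    additive_encoding[name] = mean_encoding[name] - parent_mean

for name in node_samples:
    print(name, sample_encoding[name].astype(int), mean_encoding[name], additive_encoding[name])
```

Here the left child has means of 2 and 11 for the two targets, while the root has means of 5 and 6.5, so the additive encoding of the left child is -3 and 4.5: relative to its parent, it lowers the first target and raises the second.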

We should examine if there are any common patterns that appear if we encode many nodes from a forest using each of these representations:

In [9]:
# Here we plot the sample representations of nodes. 
# This generates a set of figures demonstrating the existence of node clusters

from sklearn.decomposition import PCA

# For ease of processing we have to construct dimensionally reduced representations of the encodings. 

nodes = forest.nodes(root=False,depth=3)

sample_encoding = forest.node_representation(nodes,mode='sample')
reduced_sample = PCA(n_components=100).fit_transform(sample_encoding.T)
reduced_sample_node = PCA(n_components=100).fit_transform(sample_encoding)

print(sample_encoding.shape)
print(reduced_sample.shape)
print(reduced_sample_node.shape)

from scipy.cluster.hierarchy import linkage,dendrogram


# sample_agglomeration = dendrogram(linkage(reduced_sample, metric='cosine', method='average'), no_plot=True)['leaves']
# node_sample_agglomeration = dendrogram(linkage(reduced_node, metric='cosine', method='average'), no_plot=True)['leaves']

# plt.figure()
# plt.title("Cell Presence in Node (Clustered)")
# plt.imshow(sample_encoding[node_sample_agglomeration].T[sample_agglomeration].T,cmap='binary',aspect='auto',interpolation='none')
# plt.xlabel("Cells")
# plt.ylabel("Nodes")
# plt.colorbar()
# plt.tight_layout()
# plt.show()

# # # And here we sort the nodes after they have been clustered (more on the clustering procedure in a bit)

# node_cluster_sort = np.argsort([n.split_cluster for n in nodes])

# plt.figure()
# plt.title("Cell Presence in Node (Clustered)")
# plt.imshow(sample_encoding[node_cluster_sort].T[sample_agglomeration].T,cmap='binary',aspect='auto',interpolation='none')
# plt.xlabel("Cells")
# plt.ylabel("Nodes")
# plt.colorbar()
# plt.tight_layout()
# plt.show()

plt.figure()
plt.suptitle("Cell Presence in Node (Two-Way Agglomerated)")
ax1 = plt.axes([0,.7,.8,.2])
node_sample_agglomeration = dendrogram(linkage(reduced_sample_node, metric='cosine', method='average'),orientation='top', no_plot=False)['leaves']
plt.xticks([])
plt.yticks([])
ax2 = plt.axes([.8,0,.2,.7])
sample_agglomeration = dendrogram(linkage(reduced_sample, metric='cosine', method='average'),orientation='right', no_plot=False)['leaves']
plt.xticks([])
plt.yticks([])
ax3 = plt.axes([0,0,.8,.7])
im = plt.imshow(sample_encoding[node_sample_agglomeration].T[sample_agglomeration[::-1]],cmap='binary',aspect='auto',interpolation='none')
plt.xlabel("Nodes")
plt.ylabel("Cells")
plt.tight_layout()
plt.show()


In [10]:
## from sklearn.decomposition import PCA

sister_encoding = forest.node_representation(nodes,mode='sister')
reduced_sister = PCA(n_components=100).fit_transform(sister_encoding.T)
reduced_sister_node = PCA(n_components=100).fit_transform(sister_encoding)

print(sister_encoding.shape)
print(reduced_sister.shape)
print(reduced_sister_node.shape)

# from scipy.cluster.hierarchy import linkage,dendrogram

# sister_agglomeration = dendrogram(linkage(reduced_sister, metric='cosine', method='average'), no_plot=True)['leaves']
# node_sister_agglomeration = dendrogram(linkage(reduced_sister_node, metric='cosine', method='average'), no_plot=True)['leaves']

# plt.figure()
# plt.title("Sample Presence in Node vs Sister (Two-Way Agglomerated)")
# plt.imshow(sister_encoding[node_sister_agglomeration].T[sister_agglomeration].T,cmap='bwr',aspect='auto',interpolation='none')
# plt.xlabel("Samples")
# plt.ylabel("Nodes")
# plt.colorbar()
# plt.tight_layout()
# plt.show()

# plt.figure()
# plt.title("Sample Presence in Node vs Sister (Clustered By Gain)")
# plt.imshow(sister_encoding[node_cluster_sort].T[sister_agglomeration].T,cmap='bwr',aspect='auto',interpolation='none')
# plt.xlabel("Samples")
# plt.ylabel("Nodes")
# plt.colorbar()
# plt.tight_layout()
# plt.show()


In [14]:
# Here we construct the additive gain representation and its reduced forms (agglomeration and plotting happen in the next cell)


feature_encoding = forest.node_representation(nodes,mode='partial_absolute')
reduced_feature = PCA(n_components=100).fit_transform(feature_encoding.T)
reduced_feature_node = PCA(n_components=100).fit_transform(feature_encoding)

minimax = np.max(np.abs(feature_encoding))

# feature_agglomeration = dendrogram(linkage(reduced_feature, metric='cosine', method='average'), no_plot=True)['leaves']
# node_feature_agglomeration = dendrogram(linkage(reduced_node, metric='cosine', method='average'), no_plot=True)['leaves']

# node_cluster_sort = np.argsort([n.split_cluster for n in nodes])

In [15]:
# Here we plot the additive gain representation 

# print(feature_encoding.shape)

# plt.figure()
# plt.title("Target Gain in Node (Double-Agglomerated)")
# plt.imshow(feature_encoding[node_feature_agglomeration].T[feature_agglomeration].T,cmap='bwr',interpolation='none',aspect='auto',vmin=-2,vmax=2)
# plt.xlabel("Genes")
# plt.ylabel("Nodes")
# plt.colorbar(label="Parent Mean - Node Mean (Log TPM)")
# plt.tight_layout()
# plt.show()

# plt.figure()
# plt.title("Target Gain in Node (Clustered)")
# plt.imshow(feature_encoding[node_cluster_sort].T[feature_agglomeration].T,cmap='bwr',interpolation='none',aspect='auto',vmin=-2,vmax=2)
# plt.xlabel("Genes")
# plt.ylabel("Nodes")
# plt.colorbar(label="Parent Mean - Node Mean (Log TPM)")
# plt.tight_layout()
# plt.show()

plt.figure()
plt.suptitle("Gene Variance Explained by Node (Two-Way Agglomerated)")
ax1 = plt.axes([0,.7,.8,.2])
node_feature_agglomeration = dendrogram(linkage(reduced_feature_node, metric='cosine', method='average'),orientation='top', no_plot=False)['leaves']
plt.xticks([])
plt.yticks([])
ax2 = plt.axes([.8,0,.2,.7])
feature_agglomeration = dendrogram(linkage(reduced_feature, metric='cosine', method='average'),orientation='right', no_plot=False)['leaves']
plt.xticks([])
plt.yticks([])
ax3 = plt.axes([0,0,.8,.7])
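# Order the columns by the node dendrogram drawn above (node_feature_agglomeration) so the heatmap lines up with it
im = plt.imshow(feature_encoding[node_feature_agglomeration].T[feature_agglomeration[::-1]],cmap='bwr',vmin=-minimax,vmax=minimax,aspect='auto',interpolation='none')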
plt.xlabel("Nodes")
plt.ylabel("Genes")
plt.tight_layout()
plt.show()


In [20]:

plt.figure(figsize=(4,4))
plt.suptitle("Node Encodings (Two-Way Agglomerated)")

# ax1 = plt.axes([0,.7,.8,.2])
# node_sister_agglomeration = dendrogram(linkage(reduced_sister_node, metric='cosine', method='average'), orientation='top', no_plot=False)['leaves']
# plt.xticks([])
# plt.yticks([])
# ax2 = plt.axes([.8,.35,.2,.33])
# sister_agglomeration = dendrogram(linkage(reduced_sister, metric='cosine', method='average'), orientation='right', no_plot=False)['leaves']
# plt.xticks([])
# plt.yticks([])
# ax3 = plt.axes([0,.35,.8,.33])
# im1 = plt.imshow(sister_encoding[node_sister_agglomeration].T[sister_agglomeration[::-1]],cmap='bwr',aspect='auto',interpolation='none')
# plt.xticks([])
# plt.yticks([])

ax1 = plt.axes([0,.7,.8,.2])
node_sample_agglomeration = dendrogram(linkage(reduced_sample_node, metric='cosine', method='average'),orientation='top', color_threshold=0, no_plot=False)['leaves']
plt.xticks([])
plt.yticks([])
ax2 = plt.axes([.8,.35,.2,.33])
sample_agglomeration = dendrogram(linkage(reduced_sample, metric='cosine', method='average'),orientation='right', color_threshold=0, no_plot=False)['leaves']
plt.xticks([])
plt.yticks([])
ax3 = plt.axes([0,.35,.8,.33])
im1 = plt.imshow(sample_encoding[node_sample_agglomeration].T[sample_agglomeration[::-1]],cmap='binary',aspect='auto',interpolation='none')
plt.xticks([])
plt.yticks([])
plt.ylabel("Cells")
ax4 = plt.axes([.8,0,.2,.33])
feature_agglomeration = dendrogram(linkage(reduced_feature, metric='cosine', method='average'),orientation='right', color_threshold=0, no_plot=False)['leaves']
plt.xticks([])
plt.yticks([])
ax5 = plt.axes([0,0,.8,.33])
# im2 = plt.imshow(feature_encoding[node_sister_agglomeration].T[feature_agglomeration[::-1]],cmap='bwr',vmin=-minimax,vmax=minimax,aspect='auto',interpolation='none')
im2 = plt.imshow(feature_encoding[node_sample_agglomeration].T[feature_agglomeration[::-1]],cmap='seismic',vmin=-minimax,vmax=minimax,aspect='auto',interpolation='none')
plt.yticks([])
plt.xlabel("Nodes")
plt.ylabel("Genes")
plt.colorbar(im1,ax=ax2,orientation='vertical',label="Cell in Node",shrink=.7)
plt.colorbar(im2,ax=ax4,orientation='vertical',label="Δ Mean Expression",shrink=.7)
plt.text(-.03,.65,"A.",ha='right',transform=plt.gcf().transFigure)
plt.text(-.03,.3,"B.",ha='right',transform=plt.gcf().transFigure)
plt.tight_layout()
plt.show()

In [19]:
plt.figure()
plt.hist(feature_encoding.flatten())
plt.show()

In [None]:
# Let's try to look at some distance metrics for the sample encodings

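from scipy.spatial.distance import pdist,squareform

# Pairwise cosine similarity between the sample encodings of all nodes
cosine_sample = 1 - squareform(pdist(sample_encoding,metric='cosine'))

# Order nodes by their assigned split cluster (requires forest.interpret_splits to have been run)
node_cluster_sort = np.argsort([n.split_cluster for n in nodes])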

# plt.figure()
# plt.imshow(cosine_sample[node_sample_agglomeration].T[node_sample_agglomeration],aspect='auto',cmap='bwr',vmin=-1,vmax=1)
# plt.colorbar()
# plt.show()

# plt.figure()
# plt.imshow(cosine_sample[node_feature_agglomeration].T[node_feature_agglomeration],aspect='auto',cmap='bwr',vmin=-1,vmax=1)
# plt.colorbar()
# plt.show()

plt.figure()
plt.imshow(cosine_sample[node_cluster_sort].T[node_cluster_sort],aspect='auto',cmap='bwr',vmin=-1,vmax=1)
plt.colorbar()
plt.show()


In [None]:
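from scipy.spatial.distance import pdist,squareform

# Pairwise similarity between the feature encodings of all nodes (both are plotted below)
cosine_feature = 1 - squareform(pdist(feature_encoding,metric='cosine'))
correlation_feature = 1 - squareform(pdist(feature_encoding,metric='correlation'))

# Order nodes by their assigned split cluster (recomputed so this cell runs on its own)
node_cluster_sort = np.argsort([n.split_cluster for n in nodes])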

# plt.figure()
# plt.imshow(cosine_feature[node_sample_agglomeration].T[node_sample_agglomeration],aspect='auto',cmap='bwr',vmin=-1,vmax=1)
# plt.colorbar()
# plt.show()

# plt.figure()
# plt.imshow(cosine_feature[node_feature_agglomeration].T[node_feature_agglomeration],aspect='auto',cmap='bwr',vmin=-1,vmax=1)
# plt.colorbar()
# plt.show()

plt.figure()
plt.imshow(cosine_feature[node_cluster_sort].T[node_cluster_sort],aspect='auto',cmap='bwr',vmin=-1,vmax=1)
plt.colorbar()
plt.show()

# euclidean_feature = squareform(pdist(feature_encoding,metric='euclidean'))

# plt.figure()
# plt.imshow(euclidean_feature[node_cluster_sort].T[node_cluster_sort],aspect='auto',)
# plt.colorbar()
# plt.show()

# plt.figure()
# plt.imshow(euclidean_feature[node_feature_agglomeration].T[node_feature_agglomeration],aspect='auto',)
# plt.colorbar()
# plt.show()

# correlation_feature = 1 - squareform(pdist(feature_encoding,metric='correlation'))

# plt.figure()
# plt.imshow(correlation_feature[node_feature_agglomeration].T[node_feature_agglomeration],aspect='auto',cmap='bwr',vmin=-1,vmax=1)
# plt.colorbar()
# plt.show()

plt.figure()
plt.imshow(correlation_feature[node_cluster_sort].T[node_cluster_sort],aspect='auto',cmap='bwr',vmin=-1,vmax=1)
plt.colorbar()
plt.show()


Finally, we can look at silhouette scores for the various node encodings to get a feel for whether we are clustering them adequately and whether the clusters meaningfully exist. 
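As a reminder of what a silhouette score measures: for each point, `a` is the mean distance to the other members of its own cluster, `b` is the smallest mean distance to the members of any other cluster, and the score is `(b - a) / max(a, b)`. The cells below use `sklearn.metrics.silhouette_samples`. The sketch here is a small NumPy re-implementation on made-up data, only meant to show the idea, with `silhouette_values` being a hypothetical helper name.

```python
import numpy as np


def silhouette_values(points, labels):
    """Per-point silhouette values with Euclidean distance (assumes >= 2 clusters)."""
    points = np.asarray(points, dtype=float)
    labels = np.asarray(labels)
    # Pairwise Euclidean distances between all points.
    distances = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    values = np.zeros(len(points))
    for i, own in enumerate(labels):
        same = labels == own
        if same.sum() == 1:
            continue  # convention: a singleton cluster gets a silhouette of 0
        # a: mean distance to the other members of the point's own cluster
        # (the distance to itself is 0, so it drops out of the sum).
        a = distances[i, same].sum() / (same.sum() - 1)
        # b: smallest mean distance to the members of any other cluster.
        b = min(distances[i, labels == other].mean()
                for other in set(labels.tolist()) if other != own)
        values[i] = (b - a) / max(a, b)
    return values


rng = np.random.default_rng(0)
# Two well-separated blobs stand in for two clusters of node encodings.
blob_a = rng.normal(loc=0.0, scale=0.5, size=(30, 5))
blob_b = rng.normal(loc=5.0, scale=0.5, size=(30, 5))
encodings = np.vstack([blob_a, blob_b])

true_labels = np.repeat([0, 1], 30)
random_labels = rng.integers(0, 2, size=60)

print(silhouette_values(encodings, true_labels).mean())    # high: labels match real structure
print(silhouette_values(encodings, random_labels).mean())  # much lower: arbitrary labels
```

Scores close to 1 mean a node sits comfortably inside its cluster, and scores below 0 mean it is on average closer to some other cluster than to its own.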

In [None]:
# Silhouette Plots For Node Clusters 

from sklearn.metrics import silhouette_samples, silhouette_score

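# Labels must come from the same node list (`nodes`) that feature_encoding was built from,
# otherwise the number of labels will not match the number of rows in the encoding
node_labels = np.array([n.split_cluster for n in nodes])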

# silhouette_scores = silhouette_samples(reduced_node,node_labels,metric='cosine')
silhouette_scores = silhouette_samples(feature_encoding,node_labels,metric='euclidean')
# silhouette_scores = silhouette_samples(sample_encoding,node_labels,metric='cosine')
# silhouette_scores = silhouette_samples(sister_encoding,node_labels,metric='cosine')

sorted_silhouette = np.zeros(silhouette_scores.shape)
sorted_colors = np.zeros(silhouette_scores.shape)

current_index = 0
next_index = 0
for i in sorted(set(node_labels)):
    mask = node_labels == i
    selected_values = sorted(silhouette_scores[mask])    
    next_index = current_index + np.sum(mask)
    sorted_silhouette[current_index:next_index] = selected_values
    sorted_colors[current_index:next_index] = i
    current_index = next_index

In [None]:
import matplotlib.cm as cm

plt.figure()
plt.title("Silhouette Plots For Nodes Clustered By Gain")
for i,node in enumerate(sorted_silhouette):
    plt.plot([0,node],[i,i],color=cm.nipy_spectral(sorted_colors[i] / len(forest.split_clusters)),linewidth=0.5)
# plt.scatter(sorted_silhouette,np.arange(len(sorted_silhouette)),s=1)
plt.plot([0,0],[0,len(sorted_silhouette)],color='red')
plt.xlabel("Silhouette Score")
plt.ylabel("Nodes")
plt.show()

In [None]:

node_populations = np.array([n.pop() for n in forest.nodes(root=False)])
mask = node_populations > 100


feature_encoding = forest.node_representation(forest.nodes(root=False),mode='partial')[mask]
reduced_feature = PCA(n_components=100).fit_transform(feature_encoding.T)
reduced_node = PCA(n_components=100).fit_transform(feature_encoding)


feature_agglomeration = dendrogram(linkage(reduced_feature, metric='cosine', method='average'), no_plot=True)['leaves']
node_agglomeration = dendrogram(linkage(reduced_node, metric='cosine', method='average'), no_plot=True)['leaves']

node_cluster_sort = np.argsort(np.array([n.split_cluster for n in forest.nodes(root=False)])[mask])

In [None]:

print(feature_encoding.shape)

plt.figure()
plt.title("Target Gain in Node (Double-Agglomerated)")
plt.imshow(feature_encoding[node_agglomeration].T[feature_agglomeration].T,cmap='bwr',interpolation='none',aspect='auto',vmin=-2,vmax=2)
plt.xlabel("Genes")
plt.ylabel("Nodes")
plt.colorbar(label="Parent Mean - Node Mean (Log TPM)")
plt.tight_layout()
plt.show()


plt.figure()
plt.title("Target Gain in Node (Clustered)")
plt.imshow(feature_encoding[node_cluster_sort].T[feature_agglomeration].T,cmap='bwr',interpolation='none',aspect='auto',vmin=-2,vmax=2)
plt.xlabel("Genes")
plt.ylabel("Nodes")
plt.colorbar(label="Parent Mean - Node Mean (Log TPM)")
plt.tight_layout()
plt.show()

In [None]:
from sklearn.metrics import silhouette_samples, silhouette_score

node_labels = np.array([n.split_cluster for n in forest.nodes(root=False)])[mask]

# silhouette_scores = silhouette_samples(reduced_node,node_labels,metric='cosine')
silhouette_scores = silhouette_samples(feature_encoding,node_labels,metric='euclidean')
# silhouette_scores = silhouette_samples(sample_encoding,node_labels,metric='cosine')
# silhouette_scores = silhouette_samples(sister_encoding,node_labels,metric='cosine')

sorted_silhouette = np.zeros(silhouette_scores.shape)
sorted_colors = np.zeros(silhouette_scores.shape)

current_index = 0
next_index = 0
for i in sorted(set(node_labels)):
    # Use a separate name so the population filter `mask` defined earlier is not overwritten
    cluster_mask = node_labels == i
    selected_values = sorted(silhouette_scores[cluster_mask])    
    next_index = current_index + np.sum(cluster_mask)
    sorted_silhouette[current_index:next_index] = selected_values
    sorted_colors[current_index:next_index] = i
    current_index = next_index

In [None]:
import matplotlib.cm as cm

plt.figure()
plt.title("Silhouette Plots For Nodes Clustered By Gain")
for i,node in enumerate(sorted_silhouette):
    plt.plot([0,node],[i,i],color=cm.nipy_spectral(sorted_colors[i] / len(forest.split_clusters)),linewidth=0.5)
# plt.scatter(sorted_silhouette,np.arange(len(sorted_silhouette)),s=1)
plt.plot([0,0],[0,len(sorted_silhouette)],color='red')
plt.xlabel("Silhouette Score")
plt.ylabel("Nodes")
plt.show()

In [None]:
from sklearn.manifold import TSNE

In [None]:
reduced_feature.shape

In [None]:
trans_node = TSNE(n_components=2,metric='cosine').fit_transform(reduced_node)

In [None]:
plt.figure()
plt.imshow(reduced_node[node_agglomeration][:,:20],aspect='auto',cmap='bwr',vmin=-20,vmax=20,interpolation='none')
plt.colorbar()
plt.show()

In [None]:
plt.figure()
plt.scatter(*trans_node.T,s=2,c=node_labels,cmap='rainbow')
plt.show()

In [None]:
from sklearn.cluster import KMeans

# Compare silhouette plots for k-means clusterings with k = 5, 10, ..., 45
for k in range(5,50,5):

    node_labels = KMeans(n_clusters=k).fit_predict(reduced_node[:,:20])
        
    silhouette_scores = silhouette_samples(feature_encoding,node_labels,metric='cosine')

    sorted_silhouette = np.zeros(silhouette_scores.shape)
    sorted_colors = np.zeros(silhouette_scores.shape)

    # Group the scores by cluster and sort them within each cluster
    current_index = 0
    next_index = 0
    for label in sorted(set(node_labels)):
        label_mask = node_labels == label
        selected_values = sorted(silhouette_scores[label_mask])    
        next_index = current_index + np.sum(label_mask)
        sorted_silhouette[current_index:next_index] = selected_values
        sorted_colors[current_index:next_index] = label
        current_index = next_index
        
        
    plt.figure()
    plt.title(f"Silhouette Plots For Nodes Clustered By K-Means (k={k})")
    for row,score in enumerate(sorted_silhouette):
        # Colors are normalized by k, the number of k-means clusters in this iteration
        plt.plot([0,score],[row,row],color=cm.nipy_spectral(sorted_colors[row] / k),linewidth=0.5)
    # plt.scatter(sorted_silhouette,np.arange(len(sorted_silhouette)),s=1)
    plt.plot([0,0],[0,len(sorted_silhouette)],color='red')
    plt.xlabel("Silhouette Score")
    plt.ylabel("Nodes")
    plt.show()

In [None]:
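# NOTE: `clustered` is not defined anywhere in this notebook. The name `optics_sort`
# used below suggests it held OPTICS cluster labels for the rows of `reduced_node`;
# the two lines below are an assumed reconstruction, not the original call.
from sklearn.cluster import OPTICS
clustered = OPTICS().fit_predict(reduced_node)

len(set(clustered))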

In [None]:
len(clustered)

In [None]:
optics_sort = np.argsort(clustered)

plt.figure()
plt.imshow(reduced_node[optics_sort][:,:20],aspect='auto',cmap='bwr',vmin=-20,vmax=20,interpolation='none')
plt.colorbar()
plt.show()

## Clustering Explanatory Power

In [None]:
feature_encoding = forest.node_representation(forest.nodes(root=False),mode='additive_mean')
labels = np.array([n.split_cluster for n in forest.nodes(root=False)])

In [None]:
feature_encoding.shape


In [None]:
remaining = 0

for cluster in sorted(list(set(labels))):
    mask = labels == cluster
    means = np.mean(feature_encoding[mask],axis=0)
    residuals = feature_encoding[mask] - means
    mse = np.sum(np.power(residuals,2)) / (np.sum(mask) * feature_encoding.shape[1])
    remaining += (np.sum(mask) / feature_encoding.shape[0]) * mse
    print((cluster,mse))
    
means = np.mean(feature_encoding,axis=0)
residuals = feature_encoding - means
mse = np.sum(np.power(residuals,2)) / (feature_encoding.shape[0] * feature_encoding.shape[1])
print(f"Remaining: {remaining}")
print(f"All:{mse}")
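The cell above compares two numbers: "All" is the mean squared deviation of every node encoding from the global mean, and "Remaining" is the same quantity measured against each node's own cluster mean, weighted by cluster size. If the clusters capture real structure, "Remaining" should be much smaller than "All". Below is a compact sketch of the same calculation on made-up data. The helper names `within_cluster_mse` and `total_mse` are hypothetical and not part of `tree_reader`.

```python
import numpy as np


def within_cluster_mse(encoding, labels):
    """Size-weighted mean squared deviation of each row from its cluster mean."""
    encoding = np.asarray(encoding, dtype=float)
    labels = np.asarray(labels)
    remaining = 0.0
    for cluster in np.unique(labels):
        members = encoding[labels == cluster]
        # Mean over both rows and columns, i.e. sum of squares / (n_members * n_features).
        cluster_mse = np.mean((members - members.mean(axis=0)) ** 2)
        remaining += (len(members) / len(encoding)) * cluster_mse
    return remaining


def total_mse(encoding):
    """Mean squared deviation of each row from the global mean."""
    encoding = np.asarray(encoding, dtype=float)
    return np.mean((encoding - encoding.mean(axis=0)) ** 2)


rng = np.random.default_rng(0)
# Toy encodings: three groups of rows, each scattered around its own centre.
centres = rng.normal(0, 3, size=(3, 8))
toy_labels = np.repeat([0, 1, 2], 40)
toy_encoding = centres[toy_labels] + rng.normal(0, 1, size=(120, 8))

remaining = within_cluster_mse(toy_encoding, toy_labels)
total = total_mse(toy_encoding)
print(f"Remaining: {remaining:.3f}, All: {total:.3f}, explained: {1 - remaining / total:.2%}")
```

Because the total sum of squares splits into a within-cluster part and a between-cluster part, "Remaining" can never exceed "All" for any labelling. The interesting question is how much smaller it is than what we would get by clustering structureless data, which is what the shuffled comparison below is for.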


In [None]:
from tree_reader_utils import fast_knn,hacked_louvain

shuffled = feature_encoding.copy()

for f in shuffled.T:
    np.random.shuffle(f)
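The shuffle above permutes every column of the encoding independently. Each target keeps exactly the same set of values, but the relationships between targets within a node are destroyed, so any clusters found afterwards can only reflect chance. A small sketch of that property on made-up data (the arrays and variable names here are for illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy encoding with two strongly correlated columns (rows = nodes, columns = targets).
base = rng.normal(size=500)
encoding = np.column_stack([base, base + rng.normal(scale=0.1, size=500)])

shuffled = encoding.copy()
for column in shuffled.T:
    # Each `column` is a view into `shuffled`, so shuffling it in place
    # permutes that column independently of the others.
    rng.shuffle(column)

# Per-column values are unchanged (only their order differs)...
same_marginals = all(
    np.array_equal(np.sort(encoding[:, j]), np.sort(shuffled[:, j]))
    for j in range(encoding.shape[1])
)
# ...but the relationship between the columns is gone.
corr_before = np.corrcoef(encoding.T)[0, 1]
corr_after = np.corrcoef(shuffled.T)[0, 1]
print(same_marginals, round(corr_before, 2), round(corr_after, 2))
```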



In [None]:
relabel = hacked_louvain(fast_knn(shuffled,50))
    

In [None]:
remaining_sims = []

for i in range(20):
    print(i)
    
    shuffled = feature_encoding.copy()

    for f in shuffled.T:
        np.random.shuffle(f)

    relabel = hacked_louvain(fast_knn(shuffled,50))

    remaining = 0
    for cluster in sorted(list(set(relabel))):
        mask = relabel == cluster
        means = np.mean(shuffled[mask],axis=0)
        residuals = shuffled[mask] - means
        mse = np.sum(np.power(residuals,2)) / (np.sum(mask) * shuffled.shape[1])
        remaining += (np.sum(mask) / shuffled.shape[0]) * mse
#         print((cluster,mse))
    remaining_sims.append(remaining)
    
means = np.mean(shuffled,axis=0)
residuals = shuffled - means
mse = np.sum(np.power(residuals,2)) / (shuffled.shape[0] * shuffled.shape[1])
print(f"Remaining: {remaining}")
print(f"All:{mse}")


### Now for sample encoding

In [None]:
nodes = forest.nodes(root=False)

sample_encoding = forest.node_representation(nodes,mode='sample')
labels = np.array([n.split_cluster for n in nodes])

sample_encoding.shape

In [None]:
remaining = 0

for cluster in sorted(list(set(labels))):
    mask = labels == cluster
    means = np.mean(sample_encoding[mask],axis=0)
    residuals = sample_encoding[mask] - means
    mse = np.sum(np.power(residuals,2)) / (np.sum(mask) * sample_encoding.shape[1])
    remaining += (np.sum(mask) / sample_encoding.shape[0]) * mse
    print((cluster,mse))
    print( np.sum(mask))
    
means = np.mean(sample_encoding,axis=0)
residuals = sample_encoding - means
mse = np.sum(np.power(residuals,2)) / (sample_encoding.shape[0] * sample_encoding.shape[1])
print(f"Remaining: {remaining}")
print(f"All:{mse}")


In [None]:
from tree_reader_utils import fast_knn,hacked_louvain

# This section is about the sample encoding, so shuffle that rather than the feature encoding
shuffled = sample_encoding.copy()

for f in shuffled.T:
    np.random.shuffle(f)

relabel = hacked_louvain(fast_knn(shuffled,50))


In [None]:

remaining = 0
for cluster in sorted(list(set(relabel))):
    mask = relabel == cluster
    means = np.mean(shuffled[mask],axis=0)
    residuals = shuffled[mask] - means
    mse = np.sum(np.power(residuals,2)) / (np.sum(mask) * shuffled.shape[1])
    remaining += (np.sum(mask) / shuffled.shape[0]) * mse
    print((cluster,mse))
remaining_sims.append(remaining)

means = np.mean(shuffled,axis=0)
residuals = shuffled - means
mse = np.sum(np.power(residuals,2)) / (shuffled.shape[0] * shuffled.shape[1])
print(f"Remaining: {remaining}")
print(f"All:{mse}")


In [None]:
np.var(remaining_sims)
# np.mean(remaining_sims)

In [None]:
remaining_sims

In [None]:
tn = forest.trees[0].root.nodes()

In [None]:
len(tn)

In [None]:
additive = forest.node_representation(tn,mode='additive_mean')
sample = forest.node_representation(tn,mode='sample')

pops = np.sum(sample,axis=1)
pops.shape

In [None]:
plt.figure()
plt.imshow(additive,aspect='auto',cmap='bwr')
plt.show()

plt.figure()
plt.imshow(sample,aspect='auto',cmap='binary')
plt.show()

In [None]:
np.dot(additive.T,pops) # equals zero, law of total expectation
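The reason the population-weighted sum of additive encodings vanishes is that a parent's mean is the population-weighted average of its children's means: n_L(μ_L − μ_P) + n_R(μ_R − μ_P) = (n_L μ_L + n_R μ_R) − n_P μ_P = 0. The sketch below checks this for a single made-up split. Note that, under the definition of the additive encoding given earlier, a root node's encoding is the dataset mean itself, so the identity applies to non-root nodes (or to data whose overall mean is zero).

```python
import numpy as np

rng = np.random.default_rng(0)
values = rng.normal(size=(10, 3))  # 10 samples, 3 targets

# One parent node split into two children of unequal size.
parent = np.arange(10)
left, right = parent[:3], parent[3:]

parent_mean = values[parent].mean(axis=0)
additive = np.array([
    values[left].mean(axis=0) - parent_mean,   # additive encoding of the left child
    values[right].mean(axis=0) - parent_mean,  # additive encoding of the right child
])
pops = np.array([len(left), len(right)])       # number of samples in each child

# Population-weighted sum of the children's additive encodings, one value per target.
weighted = additive.T @ pops
print(np.allclose(weighted, 0.0))
```

Since every pair of siblings cancels in this way, the sum over all non-root nodes of a tree cancels as well.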

In [None]:
explained = np.dot(np.power(additive.T,2),pops) / 16027

In [None]:
total = np.var(forest.output,axis=0)

In [None]:
explained / total
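The ratio above rests on the usual split of variance into a between-group part and a within-group part. For one split of a parent into two children, the population-weighted squared additive encodings (divided by the number of samples) give the variance the split explains, and the rest is the variance left inside the children. A minimal check on made-up numbers, with a single target and a single split (the names below are illustrative only):

```python
import numpy as np

rng = np.random.default_rng(1)
y = rng.normal(size=12)        # one target measured on 12 samples
left, right = y[:5], y[5:]     # a hypothetical split into two children
n = len(y)

# Between-child part: squared additive encodings weighted by child population.
explained = (
    len(left) * (left.mean() - y.mean()) ** 2
    + len(right) * (right.mean() - y.mean()) ** 2
) / n

# Within-child part: squared deviations from each child's own mean.
within = (np.sum((left - left.mean()) ** 2) + np.sum((right - right.mean()) ** 2)) / n

total = np.var(y)
print(np.isclose(explained + within, total))
print(explained / total)       # fraction of variance explained by this one split
```

The same bookkeeping only works if the denominator used for `explained` is the number of samples that `total` was computed over, which is worth double-checking wherever a sample count is written out by hand.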

In [None]:
nse = forest.node_representation(forest.nodes(),mode='sample')
nge = forest.node_representation(forest.nodes(),mode='additive_mean')

In [None]:
nse.shape
nge.shape

In [None]:
ncv = np.cov(nse.T,nge.T)

In [None]:
ncv.shape