In [94]:
import numpy as np
import pandas as pd
import os
import yaml
import pickle
from scipy import signal
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score

In [95]:
FEATURES_TO_RANK = ['decay_CaDynamics_E2_axonal', 'decay_CaDynamics_E2_somatic', 'e_pas_axonal', 
                    'e_pas_somatic', 'gCa_LVAstbar_Ca_LVAst_axonal', 'gCa_LVAstbar_Ca_LVAst_somatic',
                    'gCabar_Ca_axonal', 'gCabar_Ca_somatic', 'gIhbar_Ih_dend', 'gImbar_Im_axonal',
                    'gImbar_Im_dend', 'gImbar_Im_somatic', 'gK_Pstbar_K_Pst_axonal', 'gK_Pstbar_K_Pst_somatic',
                    'gK_Tstbar_K_Tst_axonal', 'gK_Tstbar_K_Tst_dend', 'gK_Tstbar_K_Tst_somatic', 'gNaTa_tbar_NaTa_t_axonal',
                    'gNaTs2_tbar_NaTs2_t_dend', 'gNaTs2_tbar_NaTs2_t_somatic', 'gNap_Et2bar_Nap_Et2_axonal', 
                    'gNap_Et2bar_Nap_Et2_dend', 'gNap_Et2bar_Nap_Et2_somatic', 'gSK_E2bar_SK_E2_axonal', 
                    'gSK_E2bar_SK_E2_somatic', 'gSKv3_1bar_SKv3_1_axonal', 'gSKv3_1bar_SKv3_1_dend', 'gSKv3_1bar_SKv3_1_somatic', 
                    'g_pas_axonal', 'g_pas_dend', 'g_pas_somatic', 'gamma_CaDynamics_E2_axonal', 'gamma_CaDynamics_E2_somatic']

In [96]:
DCA_THRESHOLD_PER_CELL_UNNORM_95 = pd.read_csv('data/dca_threshold_per_cell_unnorm_950.0.csv')
DCA_THRESHOLD_PER_CELL_UNNORM_68 = pd.read_csv('data/dca_threshold_per_cell_unnorm_680.0.csv')

DCA_THRESHOLD_PER_CELL_UNNORM = DCA_THRESHOLD_PER_CELL_UNNORM_95

MAPPINGS_WITH_CONDUCTANCE = pd.read_csv('data/mappings_with_conductance.csv')

In [97]:
E_FEATURES = MAPPINGS_WITH_CONDUCTANCE[['bbp_name'] + FEATURES_TO_RANK].rename(columns={'bbp_name': 'cell_name'})

In [98]:
E_FEATURES_DCA = E_FEATURES.merge(DCA_THRESHOLD_PER_CELL_UNNORM[['cell_name', 'pi']], left_on='cell_name', right_on='cell_name', how='inner')

In [99]:
E_FEATURES_DCA

Unnamed: 0,cell_name,decay_CaDynamics_E2_axonal,decay_CaDynamics_E2_somatic,e_pas_axonal,e_pas_somatic,gCa_LVAstbar_Ca_LVAst_axonal,gCa_LVAstbar_Ca_LVAst_somatic,gCabar_Ca_axonal,gCabar_Ca_somatic,gIhbar_Ih_dend,...,gSK_E2bar_SK_E2_somatic,gSKv3_1bar_SKv3_1_axonal,gSKv3_1bar_SKv3_1_dend,gSKv3_1bar_SKv3_1_somatic,g_pas_axonal,g_pas_dend,g_pas_somatic,gamma_CaDynamics_E2_axonal,gamma_CaDynamics_E2_somatic,pi
0,bbp002,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.019726,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,33.246692
1,bbp002,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.019726,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,33.246692
2,bbp003,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.019726,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,37.653326
3,bbp003,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.019726,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,37.653326
4,bbp004,468.069681,645.079741,-63.854018,-67.128897,0.009017,0.003242,0.000400,0.000174,0.000049,...,0.000523,0.386953,0.000083,0.503893,0.000008,0.000001,0.000100,0.001739,0.000500,52.524267
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
331,bbp205,103.091390,873.498863,-60.216510,-62.442793,0.000007,0.005592,0.000003,0.000032,0.000080,...,0.003869,1.936176,0.004399,0.072929,0.000094,0.000001,0.000091,0.001734,0.000996,131.212123
332,bbp206,103.091390,873.498863,-60.216510,-62.442793,0.000007,0.005592,0.000003,0.000032,0.000080,...,0.003869,1.936176,0.004399,0.072929,0.000094,0.000001,0.000091,0.001734,0.000996,133.391516
333,bbp206,103.091390,873.498863,-60.216510,-62.442793,0.000007,0.005592,0.000003,0.000032,0.000080,...,0.003869,1.936176,0.004399,0.072929,0.000094,0.000001,0.000091,0.001734,0.000996,133.391516
334,bbp207,103.091390,873.498863,-60.216510,-62.442793,0.000007,0.005592,0.000003,0.000032,0.000080,...,0.003869,1.936176,0.004399,0.072929,0.000094,0.000001,0.000091,0.001734,0.000996,139.814076


In [100]:
X = E_FEATURES_DCA.drop(['pi'], axis=1)  # 'target_column' is the name of the column you want to predict
y = E_FEATURES_DCA['pi']

In [101]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

In [102]:
rf_model = RandomForestRegressor(n_estimators=100)

# For regression tasks:
# rf_model = RandomForestRegressor(n_estimators=100, random_state=42)

# Fit the model to the training data
X_train = X_train.drop(['cell_name'], axis=1)
rf_model.fit(X_train, y_train)

In [103]:
X_test_1 = X_test.drop(['cell_name'], axis=1)
y_pred = rf_model.predict(X_test_1)

print(y_test, y_pred)

263    53.646514
62     55.416542
79     27.768780
271    34.469678
31     32.210337
         ...    
204    51.114298
143    54.104920
297    85.438884
321    29.945794
329    62.539454
Name: pi, Length: 68, dtype: float64 [48.36396958 63.84954829 46.85529168 48.36396958 48.36396958 46.85529168
 63.84954829 63.84954829 48.36396958 48.36396958 46.85529168 63.84954829
 48.36396958 48.36396958 48.36396958 48.36396958 46.85529168 46.85529168
 63.84954829 63.84954829 48.36396958 48.36396958 46.85529168 48.36396958
 42.62509778 48.36396958 48.36396958 46.85529168 63.84954829 48.59559997
 46.85529168 48.36396958 48.36396958 63.84954829 63.84954829 48.36396958
 63.84954829 48.36396958 63.84954829 48.36396958 42.62509778 48.36396958
 46.85529168 48.36396958 48.36396958 63.84954829 48.36396958 48.36396958
 46.85529168 48.36396958 48.36396958 48.36396958 65.69471024 48.36396958
 63.84954829 63.84954829 48.36396958 46.85529168 46.85529168 46.85529168
 63.84954829 63.84954829 46.85529168 48.363969

In [104]:
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("Mean Squared Error:", mse)
print("R-squared:", r2)

Mean Squared Error: 302.6921506526314
R-squared: 0.13307972216579655
