In [140]:
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score

In [141]:
FEATURES_TO_RANK = ['decay_CaDynamics_E2_axonal', 'decay_CaDynamics_E2_somatic', 'e_pas_axonal', 
                    'e_pas_somatic', 'gCa_LVAstbar_Ca_LVAst_axonal', 'gCa_LVAstbar_Ca_LVAst_somatic',
                    'gCabar_Ca_axonal', 'gCabar_Ca_somatic', 'gIhbar_Ih_dend', 'gImbar_Im_axonal',
                    'gImbar_Im_dend', 'gImbar_Im_somatic', 'gK_Pstbar_K_Pst_axonal', 'gK_Pstbar_K_Pst_somatic',
                    'gK_Tstbar_K_Tst_axonal', 'gK_Tstbar_K_Tst_dend', 'gK_Tstbar_K_Tst_somatic', 'gNaTa_tbar_NaTa_t_axonal',
                    'gNaTs2_tbar_NaTs2_t_dend', 'gNaTs2_tbar_NaTs2_t_somatic', 'gNap_Et2bar_Nap_Et2_axonal', 
                    'gNap_Et2bar_Nap_Et2_dend', 'gNap_Et2bar_Nap_Et2_somatic', 'gSK_E2bar_SK_E2_axonal', 
                    'gSK_E2bar_SK_E2_somatic', 'gSKv3_1bar_SKv3_1_axonal', 'gSKv3_1bar_SKv3_1_dend', 'gSKv3_1bar_SKv3_1_somatic', 
                    'g_pas_axonal', 'g_pas_dend', 'g_pas_somatic', 'gamma_CaDynamics_E2_axonal', 'gamma_CaDynamics_E2_somatic']

In [142]:
DCA_THRESHOLD_PER_CELL_UNNORM_95 = pd.read_csv('data/dca_threshold_per_cell_unnorm_950.0.csv')
DCA_THRESHOLD_PER_CELL_UNNORM_68 = pd.read_csv('data/dca_threshold_per_cell_unnorm_680.0.csv')

DCA_THRESHOLD_PER_CELL_UNNORM = DCA_THRESHOLD_PER_CELL_UNNORM_95

MAPPINGS_WITH_CONDUCTANCE = pd.read_csv('data/mappings_with_conductance.csv')

In [143]:
E_FEATURES = MAPPINGS_WITH_CONDUCTANCE[['bbp_name'] + FEATURES_TO_RANK].rename(columns={'bbp_name': 'cell_name'})

In [144]:
E_FEATURES_DCA = E_FEATURES.merge(DCA_THRESHOLD_PER_CELL_UNNORM[['cell_name', 'pi']], left_on='cell_name', right_on='cell_name', how='inner')
E_FEATURES_DCA = E_FEATURES_DCA.merge(DCA_THRESHOLD_PER_CELL_UNNORM[['cell_name', 'dca_level']], left_on='cell_name', right_on='cell_name', how='inner')

In [145]:
E_FEATURES_DCA = E_FEATURES_DCA[E_FEATURES_DCA['dca_level'] <= 67]
E_FEATURES_DCA

Unnamed: 0,cell_name,decay_CaDynamics_E2_axonal,decay_CaDynamics_E2_somatic,e_pas_axonal,e_pas_somatic,gCa_LVAstbar_Ca_LVAst_axonal,gCa_LVAstbar_Ca_LVAst_somatic,gCabar_Ca_axonal,gCabar_Ca_somatic,gIhbar_Ih_dend,...,gSKv3_1bar_SKv3_1_axonal,gSKv3_1bar_SKv3_1_dend,gSKv3_1bar_SKv3_1_somatic,g_pas_axonal,g_pas_dend,g_pas_somatic,gamma_CaDynamics_E2_axonal,gamma_CaDynamics_E2_somatic,pi,dca_level
0,bbp002,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,33.246692,25.0
1,bbp002,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,33.246692,25.0
2,bbp002,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,33.246692,25.0
3,bbp002,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,33.246692,25.0
4,bbp003,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,37.653326,28.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
655,bbp203,573.007045,967.678789,-64.601696,-69.781406,0.009986,0.009728,0.000501,0.000028,0.000023,...,0.517764,0.000041,0.260872,0.000063,0.000001,0.000020,0.000503,0.000814,66.713427,52.0
656,bbp204,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,62.539454,46.0
657,bbp204,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,62.539454,46.0
658,bbp204,64.277990,731.707637,-60.216510,-62.442793,0.000015,0.001067,0.000003,0.000032,0.000052,...,0.317363,0.004399,0.297559,0.000094,0.000001,0.000091,0.010353,0.000511,62.539454,46.0


In [146]:
X = E_FEATURES_DCA.drop(['pi', 'dca_level'], axis=1)  # 'target_column' is the name of the column you want to predict
y = E_FEATURES_DCA['pi']

In [147]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

In [148]:
rf_model = RandomForestRegressor(n_estimators=100)

# For regression tasks:
# rf_model = RandomForestRegressor(n_estimators=100, random_state=42)

# Fit the model to the training data
X_train = X_train.drop(['cell_name'], axis=1)
rf_model.fit(X_train, y_train)

In [149]:
X_test_1 = X_test.drop(['cell_name'], axis=1)
y_pred = rf_model.predict(X_test_1)

print(y_test, y_pred)

576    82.589248
607    46.935766
297    22.834912
348    27.850303
192    53.027112
         ...    
308    52.792132
623    35.506194
399    75.196074
596    63.342093
505    74.776191
Name: pi, Length: 130, dtype: float64 [48.07832785 46.65740241 46.65740241 46.65740241 63.31511758 46.65740241
 48.07832785 48.07832785 48.07832785 48.07832785 63.31511758 34.70038053
 63.31511758 46.65740241 46.65740241 48.07832785 48.07832785 63.31511758
 63.31511758 46.65740241 48.07832785 63.31511758 63.31511758 63.31511758
 63.31511758 48.07832785 46.65740241 46.65740241 46.65740241 63.31511758
 63.31511758 48.07832785 46.65740241 46.65740241 67.73343912 63.31511758
 48.07832785 46.65740241 46.65740241 46.65740241 48.07832785 48.07832785
 46.65740241 67.73343912 46.65740241 46.65740241 46.65740241 46.65740241
 46.65740241 63.31511758 46.65740241 48.07832785 48.07832785 46.65740241
 46.65740241 48.07832785 46.65740241 46.65740241 48.07832785 63.31511758
 48.07832785 48.07832785 34.70038053 46.65740

In [150]:
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("Mean Squared Error:", mse)
print("R-squared:", r2)

Mean Squared Error: 230.81411034506834
R-squared: 0.2300132776435283
