In [14]:
import os
from glob import glob
from pathlib import Path
import sys
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import sklearn as sk
import pickle
import copy
from treeinterpreter import treeinterpreter as ti
from sklearn.model_selection import LeaveOneGroupOut, train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.utils import shuffle



In [7]:
# Make the project root (three directory levels above the current working
# directory) importable so that `global_config` can be imported in the next
# cell. The membership check keeps re-runs of this cell from appending
# duplicate entries to sys.path.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

In [8]:
from global_config import ROOT_DIR

In [9]:
# Locations relative to the project root: the input is the pre-computed video
# functionals archive; the output folder holds supervised-learning artefacts
# (e.g. the tuned hyper-parameters loaded further below).
_INPUT_RELATIVE = 'files/out/functionals/video_data_functionals_A220.npz'
_OUTPUT_RELATIVE = 'files/out/functionals/supervised_learning/video'

input_path = os.path.join(ROOT_DIR, _INPUT_RELATIVE)
output_path = os.path.join(ROOT_DIR, _OUTPUT_RELATIVE)

In [10]:
# Load the pre-computed video functionals.
# NOTE(review): allow_pickle=True is presumably needed for object arrays such
# as `col_names`; it also permits arbitrary code execution on load, so only
# open archives produced by this project.
# The context manager closes the underlying zip file once the arrays have been
# read (indexing an NpzFile materialises the array in memory).
with np.load(input_path, allow_pickle=True) as npz_file:
    _XX = npz_file['x']            # raw feature matrix (columns = col_names)
    yy = npz_file['y']             # target labels
    groups = npz_file['groups']    # per-sample group ids
    col_names = npz_file['col_names']

# Rescale every feature to [0, 1].
# NOTE(review): the scaler is fitted on all samples, i.e. before the
# train/validation split made later, so validation statistics leak into the
# transform. Min-max scaling is monotonic per feature, so the random forest
# used later is not affected; fit on the training split only if a
# scale-sensitive model is ever swapped in.
scaler = MinMaxScaler()
XX = scaler.fit_transform(_XX)

In [11]:
# Shuffle samples consistently. `groups` must be permuted together with the
# features and labels; otherwise the per-sample group ids built from it in the
# next cell no longer line up with the rows of XX / yy.
# A fixed random_state makes the shuffle, and everything derived from it,
# reproducible after a kernel restart.
# NOTE(review): re-running this cell alone reshuffles already-shuffled data;
# re-run from the loading cell for identical results.
SHUFFLE_RANDOM_STATE = 42
XX, yy, groups = shuffle(XX, yy, groups, random_state=SHUFFLE_RANDOM_STATE)

In [15]:
from sklearn.ensemble import RandomForestClassifier

# Wrap the (shuffled, scaled) arrays so feature names travel with the data.
X = pd.DataFrame(data=XX, columns=col_names)
y = pd.Series(yy)


# TODO write code to create groups in src
# Per-sample group ids. Not used by the single hold-out split below; kept for
# group-aware CV (e.g. LeaveOneGroupOut).
groups = pd.Series(groups)
n_groups = len(groups.unique())

# Load the tuned hyper-parameters and initialise the classifier.
# NOTE(review): pickle.load can execute arbitrary code - only load parameter
# files produced by this project.
file_path = os.path.join(output_path, 'best_params', 'best_params_rf_intensity_video.sav')
with open(file_path, 'rb') as params_file:
    best_params = pickle.load(params_file)
clf = RandomForestClassifier(**best_params)

# Single random hold-out split (train_test_split default: 25 % validation).
# NOTE(review): an earlier comment announced "LOGO CV (5 GROUPS)", but this is a
# plain random split - samples of one group can land in both train and
# validation. Use LeaveOneGroupOut with `groups` for group-aware evaluation.
SPLIT_RANDOM_STATE = 42
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=SPLIT_RANDOM_STATE)

# Fit model
clf.fit(X_train, y_train)

# Step 0. Decompose each validation prediction into bias + per-feature
# contributions (treeinterpreter) and keep only the predicted class's column.
prediction_proba, bias, contributions = ti.predict(clf, X_val)
prediction = clf.predict(X_val)

# The class axis of `bias` / `contributions` follows the order of
# clf.classes_ (positions 0..n_classes-1), not the label values themselves.
# Indexing with the raw label selects the wrong class (or goes out of range)
# whenever labels are not exactly 0..n_classes-1. clf.classes_ is sorted, so
# searchsorted maps each predicted label to its column position.
predicted_class_idx = np.searchsorted(clf.classes_, prediction)

data_tmp = []  # one row per validation sample: [bias, *contributions, predicted label]
for i, (pred_class_tmp, class_idx) in enumerate(zip(prediction, predicted_class_idx)):
    bias_tmp = bias[i, class_idx]
    contributions_tmp = contributions[i, :, class_idx]
    data_tmp.append([bias_tmp] + contributions_tmp.tolist() + [pred_class_tmp])

interpretation_step_0_df = pd.DataFrame(data=data_tmp, columns=['bias']+X.columns.tolist()+['prediction'])
interpretation_step_0_df

Unnamed: 0,bias,AU01_r_mean,AU01_r_stddevNorm,AU01_r_percentile20.0,AU01_r_percentile50.0,AU01_r_percentile80.0,AU01_r_iqr60_80-20,AU01_r_numPeaks,AU02_r_mean,AU02_r_stddevNorm,...,AU26_r_iqr60_80-20,AU26_r_numPeaks,AU45_r_mean,AU45_r_stddevNorm,AU45_r_percentile20.0,AU45_r_percentile50.0,AU45_r_percentile80.0,AU45_r_iqr60_80-20,AU45_r_numPeaks,prediction
0,0.018371,0.001929,0.000105,0.0,0.002870,0.000144,-0.001039,0.007452,-0.003261,-0.001202,...,0.003084,-0.000072,0.008981,-0.011903,0.0,0.003246,0.002039,0.016343,0.001021,13
1,0.025114,0.000986,0.002730,0.0,-0.000825,0.001775,0.000085,0.000417,0.004464,0.000829,...,-0.002581,0.002049,0.001679,-0.000245,0.0,0.000097,-0.000072,-0.000006,0.000198,26
2,0.022311,0.001667,0.001133,0.0,0.000050,-0.001093,0.001519,-0.001401,0.001914,-0.001250,...,0.006243,0.002404,0.005116,-0.004442,0.0,0.000234,-0.000457,0.000818,-0.001273,30
3,0.022311,0.003443,0.001138,0.0,0.000531,-0.001161,0.000810,0.002539,-0.000406,0.000363,...,0.006608,0.004248,0.003552,0.000951,0.0,0.000234,0.004343,0.003907,0.000630,30
4,0.023447,0.003467,0.000078,0.0,0.000399,-0.000632,0.003370,0.002157,0.004255,0.000023,...,-0.003232,0.000744,0.010032,0.002167,0.0,0.000972,0.000581,0.003352,-0.003487,18
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
83,0.021288,0.002319,0.002014,0.0,0.000732,0.000754,0.002736,0.001332,0.003903,0.003669,...,0.000280,-0.004064,-0.002115,0.002385,0.0,0.000032,-0.003618,-0.000061,-0.002251,7
84,0.028182,-0.010072,0.005046,0.0,0.001267,-0.005180,0.000109,0.000042,0.000824,0.000105,...,0.000831,0.001103,-0.017170,-0.000371,0.0,-0.005118,-0.002283,-0.003824,-0.000154,15
85,0.022121,0.000489,0.000028,0.0,0.000017,0.004127,0.000823,0.000129,-0.002632,0.004167,...,-0.005426,0.002879,0.003376,0.000097,0.0,0.000000,0.000115,0.000245,0.000391,4
86,0.020114,0.005497,0.005269,0.0,0.000020,0.002644,0.001040,0.000003,-0.004812,-0.000943,...,0.000495,0.006692,0.007351,0.003875,0.0,0.004487,0.013702,0.002604,0.006453,17
