In [1]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold, cross_val_score
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score, make_scorer
from sklearn.linear_model import LogisticRegression

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [2]:
feats_dummy = pd.read_pickle(r'/Users/philliprichardson/Metis/Module 4/feats.pkl')
outcome = pd.read_pickle(r'/Users/philliprichardson/Metis/Module 4/outcome.pkl')

In [3]:
kf = StratifiedKFold(n_splits = 5, random_state = 13, shuffle = True)


def quick_test(model, X, y):
    xtrain, xtest, ytrain, ytest = train_test_split(X, y, test_size=0.2, stratify = y, random_state = 13)
    model.fit(xtrain, ytrain)
    y_pred = model.predict(xtest)
    score = np.mean(cross_val_score(model, X, y, cv=kf, scoring = make_scorer(f1_score, average = 'macro')))
    fitting = np.std(cross_val_score(model, X, y, cv=kf, scoring = make_scorer(f1_score, average = 'macro')))
    return score, fitting


In [4]:
randomforest = RandomForestClassifier(n_estimators=300, max_features = 5)
randomforest_bal = RandomForestClassifier(n_estimators=300, max_features = 5, class_weight = 'balanced')


In [17]:
lst = [randomforest, randomforest_bal]

for i in lst:
    print(quick_test(i, feats_dummy, outcome))
    

(0.4996772890484296, 0.007414979835703855)
(0.5014808368218296, 0.0029608429486447487)


In [6]:
results = [['n_estimators', 'num features', 'avg f1 weighted', 'std f1 weighted']]
for i in range(50, 501, 50):
    for j in range(0, len(feats_dummy.columns), 2):
        randomforest_bal = RandomForestClassifier(n_estimators=i, max_features = j+1, class_weight = 'balanced')
        mean, std = quick_test(randomforest_bal, feats_dummy, outcome)
        results.append([i,j,mean, std])

In [7]:
results

[['n_estimators', 'num features', 'avg f1 weighted', 'std f1 weighted'],
 [50, 0, 0.5016203692927844, 0.0007725523643620418],
 [50, 2, 0.5015004594732216, 0.0015989219403514314],
 [50, 4, 0.4995132192942348, 0.0023391290346025852],
 [50, 6, 0.5012741920763322, 0.002138429541780052],
 [50, 8, 0.5016232958600204, 0.0015492013047610032],
 [50, 10, 0.5007130709151696, 0.0014372922158743262],
 [50, 12, 0.5002595968128826, 0.0023012106878560934],
 [50, 14, 0.5015695486817294, 0.001671391340062846],
 [50, 16, 0.4998193899796237, 0.002506695936603263],
 [50, 18, 0.5004953529493326, 0.0008262605083209332],
 [50, 20, 0.4996983385721331, 0.00162681324832265],
 [50, 22, 0.5005786458531465, 0.00195997539818328],
 [100, 0, 0.5015953494251273, 0.0025210027924655294],
 [100, 2, 0.500732489055231, 0.002211460411327778],
 [100, 4, 0.5011649460987397, 0.001545451268842622],
 [100, 6, 0.5011854067009179, 0.0019980580349322243],
 [100, 8, 0.5014213755212276, 0.0026361303440553935],
 [100, 10, 0.50081264705