In [None]:
"""
Cell For Papermill Parameters
"""

# --- Input / schema ---------------------------------------------------------
input_data = "to_ngboost.csv"   # CSV file read into the modelling table
Label_Column = "label_clas"     # name of the target column in that table

# --- Model / evaluation -----------------------------------------------------
estimators = 100                # handed to RandomForestClassifier as n_estimators
test_size = 0.2                 # intended hold-out fraction; stored in the W&B run config

# --- Logging ----------------------------------------------------------------
wandb_log = 0                   # Weights & Biases calls run only when this equals 1

In [None]:
import pandas as pd
import json
import wandb
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import mean_squared_error
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import plot_confusion_matrix

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

import matplotlib.pyplot as plt
import itertools

In [None]:
# Global seaborn appearance shared by the figures drawn below.
# Keep this order: sns.set() is called first, then the style is switched to
# "white" and the palette to "rocket".
sns.set(font_scale=1)
sns.set_style("white")
sns.set_palette("rocket")

In [None]:
#SOURCE https://runawayhorse001.github.io/LearningApacheSpark/classification.html

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print a status line and draw a confusion matrix as an annotated heat map.

    A new 3x3 inch, 200 dpi figure is created and left open as the current
    pyplot figure, so the caller can follow up with ``plt.savefig`` /
    ``plt.show``.

    Args:
        cm: square confusion matrix (array-like); row ``i`` / column ``j`` is
            drawn at y position ``i`` ("True label") and x position ``j``
            ("Predicted label").
        classes: tick labels, one per row/column of ``cm``.
        normalize: when True each row is divided by its row sum before
            plotting. Rows whose sum is zero are shown as all zeros instead of
            producing NaN and a divide-by-zero warning.
        title: figure title.
        cmap: matplotlib colormap used for the heat map.

    Returns:
        None.
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Guarded division: a class with no true samples keeps a zero row.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.figure(figsize=(3, 3), frameon=False, dpi=200)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    fmt = '.1f' if normalize else 'd'
    # Cells darker than half the maximum get white text so they stay readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass last so the axis labels are included in it.
    plt.tight_layout()

In [None]:
def make_ROC_graph(labels_test,prediction):
    """
    Plot the ROC curve of ``prediction`` against ``labels_test`` and save it.

    The curve points come from ``roc_curve`` and the area from ``auc``. The
    plot is drawn on a new 3x3 inch, 200 dpi figure together with a dashed red
    diagonal and written to ``roc_graph.png`` in the working directory. The
    figure is left open for the caller; nothing is returned.
    """
    fpr, tpr, _ = roc_curve(labels_test, prediction)
    area_under_curve = auc(fpr, tpr)

    plt.figure(figsize=(3, 3), frameon=False, dpi=200)
    plt.title('Receiver Operating Characteristic')
    plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % area_under_curve)
    # The legend is built before the diagonal is added, so it lists the curve only.
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')

    # Both axes span the unit interval.
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')

    plt.savefig("roc_graph.png", bbox_inches='tight')

In [None]:
def make_precision_recall_graph(labels_test,prediction):
    """
    Plot the precision-recall curve of ``prediction`` and save it.

    Args:
        labels_test: true labels, forwarded to ``precision_recall_curve`` and
            ``average_precision_score``.
        prediction: scores for the same samples, forwarded to the same two
            functions.

    The curve is drawn on a new 3x3 inch, 200 dpi figure whose title carries
    the average precision (two decimals), and is written to
    ``precision_recall.png`` in the working directory. The figure is left open
    for the caller; nothing is returned.
    """
    precision, recall, _ = precision_recall_curve(labels_test, prediction)
    average_precision = average_precision_score(labels_test, prediction)

    # No plt.clf() before this: a new figure is created anyway, and clf() would
    # itself open an extra blank figure whenever no figure is currently active.
    plt.figure(figsize=(3, 3), frameon=False, dpi=200)
    plt.plot(recall, precision, color='navy', label='Precision-Recall curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve | Area = %0.2f' % average_precision)
    plt.savefig("precision_recall.png", bbox_inches='tight')

In [None]:
def make_importance_graph(df):
    """
    Draw a horizontal bar chart of feature importances and save it.

    Args:
        df: table with a ``Feature_Name`` column (bar labels, y axis) and an
            ``Importance`` column (bar lengths, x axis). Rows are plotted in
            the order given.

    The chart is drawn on a new 5x3 inch, 200 dpi figure and written to
    ``importance.png`` in the working directory. The figure is left open for
    the caller; nothing is returned.
    """
    # No plt.clf() before this: a new figure is created anyway, and clf() would
    # itself open an extra blank figure whenever no figure is currently active.
    plt.figure(figsize=(5, 3), frameon=False, dpi=200)
    # Plot the frame that was passed in, not a notebook-level global.
    sns.barplot(y="Feature_Name", x="Importance", data=df)
    plt.xlabel('Relative Importance')
    plt.ylabel('Feature')
    plt.title('Feature Importance')
    plt.savefig("importance.png", bbox_inches='tight')

In [None]:
# Read the JSON config that sits one directory above the notebook and pull
# out the four settings it is expected to contain.
with open('../config/config.json') as config_f:
    data = json.loads(config_f.read())

aws_access_key_id, aws_secret_access_key, region_name, WANDB_API_KEY = (
    data[setting]
    for setting in ('aws_access_key_id', 'aws_secret_access_key', 'region_name', 'WANDB_API_KEY')
)

In [None]:
# Start a Weights & Biases run only when logging is switched on (wandb_log == 1).
if wandb_log == 1:
    # Put the API key in the environment before logging in.
    os.environ["WANDB_API_KEY"] = WANDB_API_KEY
    wandb.login()
    # Only test_size is stored in the run config here.
    run = wandb.init(project="Final-Model", entity='prostate-cancer', config={"test_size":test_size})
    # Keep the run's name and id, and a handle on its config, as notebook globals.
    wandb_run_name = wandb.run.name
    wandb_run_id = wandb.run.id
    config = wandb.config

# RandomForest

In [None]:
# Load the modelling table named by the `input_data` parameter.
df = pd.read_csv(input_data)

# Columns that are dropped before the feature matrix is built.
df = df.drop(columns=["name","smoking_status","occupation","Model-MRI-DNN","Model-US-DNN"])

# Fill gaps in these four numeric columns with the column mean.
age_mean = df["age"].mean()
size_mean = df["size"].mean()
weight_mean = df["weight"].mean()
psa_mean = df["PSA"].mean()
df = df.fillna({"age":age_mean, "size":size_mean, "weight":weight_mean, "PSA":psa_mean})

# Snapshot of the imputed table, written before incomplete rows are removed.
df.to_csv("final_result.csv", sep=',', encoding='utf-8', index=False)

# Any row that still has a missing value in another column is discarded.
df = df.dropna()

# Collapse the "refused" / "unknown" ethnicity spellings into one category.
unknown_ethnicity = df["ethnic_grp"].isin(["Patient Refused", "Unknown [3]"])
df = df.assign(ethnic_grp=df["ethnic_grp"].mask(unknown_ethnicity, "Unknown"))

# Label_Column comes from the parameters cell; it is deliberately NOT
# re-assigned here so a value injected by papermill is honoured.
df = df.astype({Label_Column: int})

# Features: everything except the label, with string columns one-hot encoded.
features_df = df.drop(columns=[Label_Column])
features_one_hot_df = pd.get_dummies(features_df)

# Labels are kept as a one-column DataFrame.
labels_df = df[[Label_Column]]

In [None]:
# Preview the first five rows of the one-hot encoded feature table.
features_one_hot_df.head(5)

In [None]:
# Convert the feature table to a list of rows and the labels to a flat array.
features_np = features_one_hot_df.values.tolist()
labels_np = labels_df.values.ravel()

# Use the `test_size` parameter (not a literal) so the split matches the value
# recorded in the W&B run config; random_state pins the shuffle.
X_train, X_test, Y_train, Y_test = train_test_split(
    features_np, labels_np, random_state=0, test_size=test_size
)

In [None]:
# Random forest with `estimators` trees, trained on all available cores.
# random_state is fixed (same seed as the train/test split) so repeated runs
# of the notebook produce the same model and metrics.
clf = RandomForestClassifier(n_estimators=estimators, n_jobs=-1, verbose=0, random_state=0)
clf = clf.fit(X_train, Y_train)

In [None]:
# Hard class predictions for the held-out split.
prediction = clf.predict(X_test)

In [None]:
# Cross-validate on the training split: cross_val_score refits clones of the
# estimator on each fold, so running it on the held-out split would score
# models trained on a small part of the test data rather than assess `clf`.
scores = cross_val_score(clf, X_train, Y_train)

# Confusion matrix of the fitted model on the held-out split.
cnf_matrix = confusion_matrix(Y_test, prediction)

# Results

In [None]:
classes = ['neg','pos']
class_names = classes

# One entry per figure: (normalize flag, title, output PNG, W&B media key).
confusion_plots = [
    (False, 'Confusion matrix',
     "confusion-matrix.png", "Media/Confusion Matrix"),
    (True, 'Normalized confusion matrix',
     "confusion-matrix-normalized.png", "Media/Normalized Confusion Matrix"),
]

for cm_normalize, cm_title, cm_png, cm_wandb_key in confusion_plots:
    # plot_confusion_matrix opens its own figure; calling plt.figure() first
    # would only leave an extra blank figure behind.
    plot_confusion_matrix(cnf_matrix, classes=class_names,
                          normalize=cm_normalize, title=cm_title)
    # bbox_inches='tight' keeps the axis labels inside the saved image,
    # matching the other figures saved in this notebook.
    plt.savefig(cm_png, bbox_inches='tight')
    if wandb_log == 1:
        wandb.log({cm_wandb_key: wandb.Image(cm_png)})
    plt.show()
    plt.close()

In [None]:
# Per-class precision / recall / F1 on the held-out split, labelled with `classes`.
report_text = classification_report(Y_test, prediction, target_names=classes)
print(report_text)

# Overall accuracy on the same split.
test_accuracy = accuracy_score(Y_test, prediction)
print(test_accuracy)

In [None]:
# A ROC curve needs a continuous score per sample; hard 0/1 predictions give
# a single operating point. Use the predicted probability of the second class
# in clf.classes_ instead.
# NOTE(review): assumes a binary problem whose positive class is the last
# entry of clf.classes_ - confirm against the label encoding.
positive_class_scores = clf.predict_proba(X_test)[:, 1]
make_ROC_graph(Y_test, positive_class_scores)
plt.show()
plt.close()
if wandb_log == 1:
    wandb.log({"Media/ROC-Graph-RF": wandb.Image("roc_graph.png")})

In [None]:
# A precision-recall curve needs a continuous score per sample; hard 0/1
# predictions give a single operating point. Use the predicted probability of
# the second class in clf.classes_ instead.
# NOTE(review): assumes a binary problem whose positive class is the last
# entry of clf.classes_ - confirm against the label encoding.
positive_class_scores = clf.predict_proba(X_test)[:, 1]
make_precision_recall_graph(Y_test, positive_class_scores)
plt.show()
plt.close()
if wandb_log == 1:
    wandb.log({"Media/Precision_Recall": wandb.Image("precision_recall.png")})

# Importance

In [None]:
# Pair every one-hot feature column with the importance the fitted forest
# assigns to it, most important first.
importances = clf.feature_importances_
column_names_df = pd.DataFrame(features_one_hot_df.columns.tolist())
importances_df = pd.DataFrame(importances)

# Both frames share the same default row index, so placing them side by side
# lines each name up with its importance.
importances_df_final = pd.concat([column_names_df, importances_df], axis=1)
importances_df_final.columns = ["Feature_Name", "Importance"]
importances_df_final = importances_df_final.sort_values("Importance", ascending=False)

# Header row followed by the data rows, as plain lists.
imp_list = [importances_df_final.columns.tolist()] + importances_df_final.values.tolist()

importances_df_final.to_csv('imp_features.csv', index=False)

In [None]:
# Draw the importance chart (the helper also saves it as importance.png),
# display it, then close the current figure.
make_importance_graph(importances_df_final)
plt.show()
plt.close()
# Upload the saved PNG only when W&B logging is enabled.
if wandb_log == 1:
    wandb.log({"Media/Importance": wandb.Image("importance.png")})

# Prediction Probability

In [None]:
# Hard predictions and per-class probabilities for the held-out split.
prediction = clf.predict(X_test)
probability = clf.predict_proba(X_test)

# Name each probability column after the class it belongs to, in
# clf.classes_ order (gives Prob_0 / Prob_1 for integer labels 0 and 1)
# rather than assuming exactly two classes.
probability_df = pd.DataFrame(
    probability,
    columns=['Prob_%s' % class_label for class_label in clf.classes_],
)

In [None]:
# Display the per-class probability table for the held-out split.
probability_df

In [None]:
# Mark the W&B run as finished; skipped when logging is disabled,
# because no run was started in that case.
if wandb_log == 1:
    wandb.finish()