# Classification Model

5-fold cross-validation classification code for the BraTS2019 dataset (4-channel RGBA images).

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import fastai
from fastai.vision import *
from fastai.metrics import error_rate
import os
import pandas as pd
from collections import Counter

import warnings
warnings.filterwarnings('ignore')
import numpy as np
from pathlib import Path

from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import f1_score,precision_score,recall_score
from sklearn.metrics import roc_curve,auc

In [None]:
import torch
# Pin CUDA work in this notebook to GPU 0; the bare expression on the last line
# echoes the active device index as the cell's output.
torch.cuda.set_device(0)
torch.cuda.current_device()

In [None]:
# Pin fastai to the v1 API this notebook is written against (cnn_learner, ImageList, ...);
# --no-deps stops pip from touching the already-installed torch/torchvision.
# NOTE(review): fastai is already imported in the first cell, so a version installed
# here only takes effect after restarting the kernel and re-running from the top.
!pip install fastai==1.0.61 --no-deps
# !pip install torch==1.4 torchvision==0.5.0

Set the batch size (`bs`) according to the available GPU memory, and the image size (`imsize`) according to the model's input requirements.

In [None]:
# Echo the fastai version in use (the pip cell above pins 1.0.61). Because fastai was
# already imported in the first cell, this reports the module loaded then, not
# necessarily a version pip just installed.
import fastai; fastai.__version__

In [None]:
import random

bs = 32      # batch size; lower it if the GPU runs out of memory
imsize = 96  # image size passed to .transform(size=imsize) when building the DataBunch

# Seeding only NumPy leaves Python's `random` and PyTorch's RNG unseeded, and PyTorch's
# RNG drives the initialisation of newly created layers, dropout and data-loader
# shuffling. Seed all three. cuDNN kernels may still be nondeterministic, so repeated
# runs become closer but are not guaranteed bit-identical.
SEED = 2
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on every device, CUDA included

# Training

In [None]:
# import torch.nn as nn

class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    For each sample, with ``CE`` the unreduced cross-entropy of the logits and
    ``pt = exp(-CE)`` the probability the model assigns to the true class, the
    loss is ``alpha * (1 - pt) ** gamma * CE``.

    Args:
        alpha: scalar weight multiplied into every sample's loss (the same for
            all classes).
        gamma: focusing exponent; larger values down-weight samples that are
            already classified confidently. ``alpha=1, gamma=0`` reduces to
            plain cross-entropy.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.

    The ``reduction`` attribute and the ``reduction=`` call keyword are both
    honoured because fastai v1 requests per-sample losses (e.g. in
    ``Learner.get_preds(with_loss=True)``, which ``ClassificationInterpretation``
    uses) by temporarily setting ``loss_func.reduction = 'none'`` or by passing
    ``reduction='none'``.
    """

    def __init__(self, alpha, gamma, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets, *, reduction=None, **kwargs):
        """Return the focal loss of logits ``inputs`` against class-index ``targets``.

        ``reduction`` overrides ``self.reduction`` for this call only. Any other
        keyword arguments are accepted and ignored for call-site compatibility.

        Raises:
            ValueError: if the effective reduction is not 'none', 'sum' or 'mean'.
        """
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss

        reduction = self.reduction if reduction is None else reduction
        if reduction == 'none':
            return focal
        if reduction == 'sum':
            return focal.sum()
        if reduction == 'mean':
            return focal.mean()
        raise ValueError(
            f"unsupported reduction {reduction!r}; expected 'none', 'sum' or 'mean'")

In [None]:
def _majority_vote(labels, group_size):
    """Return one label per consecutive group of ``group_size`` entries of ``labels``.

    Each group is reduced to its most frequent label. On a tie, the label that
    occurs first within the group wins (``Counter`` keeps insertion order and
    ``max`` returns the first maximal key). A shorter trailing group, if any,
    is voted on as-is.
    """
    votes = []
    for start in range(0, len(labels), group_size):
        counts = Counter(labels[start:start + group_size])
        votes.append(max(counts, key=counts.get))
    return votes


def view_result(learn, slices_per_patient=5):
    """Print slice-level and patient-level validation metrics for ``learn``.

    Steps:
      1. Plot fastai's confusion matrix for the validation set.
      2. Print accuracy and weighted F1 over individual validation items.
      3. Collapse every ``slices_per_patient`` consecutive validation items into
         one patient-level label by majority vote (for predictions and targets
         alike), then print accuracy, weighted F1 and the confusion matrix at
         that level.

    NOTE(review): step 3 assumes the validation rows are ordered so that each
    patient's slices are contiguous and every patient has exactly
    ``slices_per_patient`` of them. Confirm this against the fold CSVs.

    Args:
        learn: trained fastai v1 ``Learner``.
        slices_per_patient: number of consecutive validation items belonging to
            one patient (previously hard-coded as 5).

    Returns:
        dict with keys ``'accuracy'``, ``'f1'``, ``'accuracy_mv'``, ``'f1_mv'``
        and ``'confusion_matrix_mv'`` (rows = true class, columns = predicted).

    Raises:
        ValueError: if ``slices_per_patient`` < 1 or the validation set is empty.
    """
    if slices_per_patient < 1:
        raise ValueError(f'slices_per_patient must be >= 1, got {slices_per_patient}')

    interp = ClassificationInterpretation.from_learner(learn)
    interp.plot_confusion_matrix()

    preds, y = learn.get_preds()[:2]
    y_true = y.numpy()
    y_pred = preds.argmax(dim=1).numpy()
    if len(y_true) == 0:
        raise ValueError('validation set is empty')

    # Slice-level metrics. sklearn metrics take (y_true, y_pred); the weighted
    # F1 average uses true-class support, so the argument order matters.
    acc = float(np.mean(y_pred == y_true))
    print('Validation set: the accuracy is {0}.'.format(acc))
    f1 = f1_score(y_true, y_pred, average='weighted')
    print('Validation set: F1_score is {0}.'.format(f1))

    # Patient-level metrics after majority voting.
    mv_true = np.asarray(_majority_vote(y_true, slices_per_patient))
    mv_pred = np.asarray(_majority_vote(y_pred, slices_per_patient))
    acc_mv = float(np.mean(mv_pred == mv_true))
    print('Validation set: Accuracy after majority of votes is {0}.' .format(acc_mv))
    f1_mv = f1_score(mv_true, mv_pred, average='weighted')
    print('Validation set: F1_score after majority of votes is {0}.'.format(f1_mv))

    cm_mv = confusion_matrix(mv_true, mv_pred)
    print('Validation set: confusion matrix after majority of votes '
          '(rows = true, columns = predicted):')
    print(cm_mv)

    return {
        'accuracy': acc,
        'f1': f1,
        'accuracy_mv': acc_mv,
        'f1_mv': f1_mv,
        'confusion_matrix_mv': cm_mv,
    }
    


In [None]:
def new_model(*args, **kwargs):
    """Build a ResNet-152 whose stem accepts 4-channel (RGBA) input.

    All arguments are forwarded to ``models.resnet152``; fastai's
    ``cnn_learner`` calls this with its ``pretrained`` flag.

    The stock 3-channel ``conv1`` is replaced by a 4-channel convolution with
    the same geometry. Instead of leaving the new stem randomly initialised,
    its first three input channels are copied from the original ``conv1`` and
    the fourth is set to their mean. This matters because ``cnn_learner``
    freezes the pretrained body, so a random stem would stay random throughout
    ``fit_one_cycle`` unless the model is explicitly unfrozen.

    NOTE(review): the four channels here are MRI data stored as RGBA, not natural
    RGB, so the copied filters are only a warm start, not a semantic match.
    """
    model = models.resnet152(*args, **kwargs)
    stem = model.conv1
    conv4 = nn.Conv2d(4, stem.out_channels, kernel_size=stem.kernel_size,
                      stride=stem.stride, padding=stem.padding, bias=False)
    with torch.no_grad():
        conv4.weight[:, :3] = stem.weight
        conv4.weight[:, 3:] = stem.weight.mean(dim=1, keepdim=True)
    model.conv1 = conv4
    return model


In [None]:
path = #BLINDED
path_images = Path(path + '/TRAIN_images')

epochs = 2
lr=1e-02 # 0.01
loss_func = FocalLoss(alpha=0.25, gamma=2) # best - .81 - (alpha=0.2, gamma=3)

for i_fold in range(1,6):
    
    df=pd.read_csv(path+f'/TRAIN_FOLD{i_fold}.csv') 

    IS_VALID="Val"

    # Creating augmentation
    tfms = get_transforms(do_flip=True,flip_vert=True, max_rotate=2.0, max_lighting=0.5, max_zoom=0)

    data = (ImageList.from_df(df, path_images, suffix='.TIFF',convert_mode='RGBA')
           .split_from_df(IS_VALID)
           .label_from_df()
           .transform(tfms, size=imsize)
           .databunch(bs=bs))
    
    
    print(f'Training fold {i_fold}:')
    learn = cnn_learner(data, new_model, metrics=[accuracy],callback_fns=[callbacks.OverSamplingCallback],loss_func=loss_func, wd=0.001,ps=0.5)
    learn.fit_one_cycle(epochs, max_lr=slice(lr/3,lr), callbacks=[callbacks.SaveModelCallback(learn, every='improvement', monitor='accuracy',name='best')]) # save the best mode
    view_result(learn)
          
    learn.save(f"Classification_model_fold{i_fold}")
    learn.export(f'models/Classification_model{i_fold}.pkl')
    print('//////////////////////////////')