<img src='http://www-scf.usc.edu/~ghasemig/images/sharif.png' alt="SUT logo" width=200 height=200 align=left class="saturate" >

<br>
<font face="Times New Roman">
<div dir=ltr align=center>
<font color=0F5298 size=7>
    Introduction to Machine Learning <br>
<font color=2565AE size=5>
    Computer Engineering Department <br>
    Fall 2022<br>
<font color=3C99D size=5>
    Project <br>
<font color=696880 size=4>
    Project Team 
    
    
____


### Full Name : Mohammad Bagher Soltani, Seyed Mohammad Yousef Najafi
### Student Number : 98105813, 99102361
___

# Introduction

This project is a short, hands-on exercise on a real-world problem: predicting breast cancer survival with machine learning models trained on clinical data and gene expression profiles.

In [1]:
# imports
import numpy as np
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.impute import SimpleImputer
import umap
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm.std import tqdm
from torch.optim import Adam
from copy import deepcopy

In [2]:
SEED = 111

In [3]:
np.random.seed(SEED)

# Data Documentation

We use the "Breast Cancer Gene Expression Profiles (METABRIC)" dataset. 
The first 31 columns contain clinical information, including death status.
The remaining columns contain gene-related information: gene expression values and mutation information (mutation columns are marked with the suffix "_mut"). 
For more information, see the [data documentation](https://www.kaggle.com/datasets/raghadalharbi/breast-cancer-gene-expression-profiles-metabric).

# Data Preparation (15 Points)

In this section you must first split data into three datasets:
<br>
1- clinical dataset
<br>
2- gene expressions dataset
<br>
3- gene mutation dataset. (We will not use this dataset in further steps of the project)
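
The split can be driven purely by column position and name suffix. The following is a minimal sketch on a toy frame; the column names and the value of `N_CLINICAL` are made up for illustration (the real file has 31 clinical columns). A boolean mask is used so that the columns stay in file order.

```python
import pandas as pd

# Toy stand-in for the METABRIC table: 2 clinical columns, then gene columns.
toy = pd.DataFrame({
    "age": [50.0, 61.0],
    "overall_survival": [1, 0],
    "brca1": [0.3, -1.2],
    "tp53": [0.9, 0.1],
    "brca1_mut": ["0", "0"],
    "tp53_mut": ["0", "R175H"],
})

N_CLINICAL = 2  # assumption for the toy frame only
gene_cols = toy.columns[N_CLINICAL:]
is_mut = gene_cols.str.endswith("_mut")

clinical_cols = toy.columns[:N_CLINICAL]
mut_cols = gene_cols[is_mut]
expr_cols = gene_cols[~is_mut]  # boolean mask keeps the original column order

print(list(clinical_cols), list(expr_cols), list(mut_cols))
```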

## Data Loading & Splitting

In [4]:
# load the dataset
df = pd.read_csv('METABRIC_RNA_Mutation.csv', low_memory=False)
df.head(5)

Unnamed: 0,patient_id,age_at_diagnosis,type_of_breast_surgery,cancer_type,cancer_type_detailed,cellularity,chemotherapy,pam50_+_claudin-low_subtype,cohort,er_status_measured_by_ihc,...,mtap_mut,ppp2cb_mut,smarcd1_mut,nras_mut,ndfip1_mut,hras_mut,prps2_mut,smarcb1_mut,stmn2_mut,siah1_mut
0,0,75.65,MASTECTOMY,Breast Cancer,Breast Invasive Ductal Carcinoma,,0,claudin-low,1.0,Positve,...,0,0,0,0,0,0,0,0,0,0
1,2,43.19,BREAST CONSERVING,Breast Cancer,Breast Invasive Ductal Carcinoma,High,0,LumA,1.0,Positve,...,0,0,0,0,0,0,0,0,0,0
2,5,48.87,MASTECTOMY,Breast Cancer,Breast Invasive Ductal Carcinoma,High,1,LumB,1.0,Positve,...,0,0,0,0,0,0,0,0,0,0
3,6,47.68,MASTECTOMY,Breast Cancer,Breast Mixed Ductal and Lobular Carcinoma,Moderate,1,LumB,1.0,Positve,...,0,0,0,0,0,0,0,0,0,0
4,8,76.97,MASTECTOMY,Breast Cancer,Breast Mixed Ductal and Lobular Carcinoma,High,1,LumB,1.0,Positve,...,0,0,0,0,0,0,0,0,0,0


In [5]:
# Get column names for the clinical, gene expression and gene mutation datasets

columns = df.columns
clinical_columns = columns[:31]
# clinical features: the clinical columns without the label (index 24) and the last clinical column (index 30)
clinical_data_columns = columns[:24].append(columns[25:30])
label_column = columns[24]
gene_columns = columns[31:]
gene_mut_columns = columns[columns.str.endswith('_mut')]
# note: the set difference does not preserve the original column order
gene_expr_columns = pd.Index(set(gene_columns) - set(gene_mut_columns))

print(f'Number of clinical columns {len(clinical_columns)}')
print(f'Number of gene expression columns {len(gene_expr_columns)}')
print(f'Number of gene mutation columns {len(gene_mut_columns)}')

Number of clinical columns 31
Number of gene expression columns 489
Number of gene mutation columns 173


In [6]:
clinical_dataset = df[clinical_columns]
gene_expr_dataset = df[gene_expr_columns]
gene_mut_dataset = df[gene_mut_columns]

## EDA

For each dataset, you must perform a sufficient EDA.

In [7]:
clinical_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1904 entries, 0 to 1903
Data columns (total 31 columns):
 #   Column                          Non-Null Count  Dtype  
---  ------                          --------------  -----  
 0   patient_id                      1904 non-null   int64  
 1   age_at_diagnosis                1904 non-null   float64
 2   type_of_breast_surgery          1882 non-null   object 
 3   cancer_type                     1904 non-null   object 
 4   cancer_type_detailed            1889 non-null   object 
 5   cellularity                     1850 non-null   object 
 6   chemotherapy                    1904 non-null   int64  
 7   pam50_+_claudin-low_subtype     1904 non-null   object 
 8   cohort                          1904 non-null   float64
 9   er_status_measured_by_ihc       1874 non-null   object 
 10  er_status                       1904 non-null   object 
 11  neoplasm_histologic_grade       1832 non-null   float64
 12  her2_status_measured_by_snp6    19

In [8]:
clinical_dataset.describe()

Unnamed: 0,patient_id,age_at_diagnosis,chemotherapy,cohort,neoplasm_histologic_grade,hormone_therapy,lymph_nodes_examined_positive,mutation_count,nottingham_prognostic_index,overall_survival_months,overall_survival,radio_therapy,tumor_size,tumor_stage
count,1904.0,1904.0,1904.0,1904.0,1832.0,1904.0,1904.0,1859.0,1904.0,1904.0,1904.0,1904.0,1884.0,1403.0
mean,3921.982143,61.087054,0.207983,2.643908,2.415939,0.616597,2.002101,5.697687,4.033019,125.121324,0.420693,0.597164,26.238726,1.750535
std,2358.478332,12.978711,0.405971,1.228615,0.650612,0.486343,4.079993,4.058778,1.144492,76.334148,0.4938,0.490597,15.160976,0.628999
min,0.0,21.93,0.0,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0
25%,896.5,51.375,0.0,1.0,2.0,0.0,0.0,3.0,3.046,60.825,0.0,0.0,17.0,1.0
50%,4730.5,61.77,0.0,3.0,3.0,1.0,0.0,5.0,4.042,115.616667,0.0,1.0,23.0,2.0
75%,5536.25,70.5925,0.0,3.0,3.0,1.0,2.0,7.0,5.04025,184.716667,1.0,1.0,30.0,2.0
max,7299.0,96.29,1.0,5.0,3.0,1.0,45.0,80.0,6.36,355.2,1.0,1.0,182.0,4.0


In [9]:
gene_expr_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1904 entries, 0 to 1903
Columns: 489 entries, prkg1 to casp9
dtypes: float64(489)
memory usage: 7.1 MB


In [10]:
gene_expr_dataset.describe()

Unnamed: 0,prkg1,gdf11,mapk7,numbl,mmp1,col12a1,acvr1c,mapk3,magea8,nek1,...,prkd1,igf1,ctcf,bmpr1a,fbxw7,casp6,pdgfb,mapk14,tgfbr3,casp9
count,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,...,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0,1904.0
mean,1.05042e-07,-5.252101e-08,-2.10084e-07,-3.151261e-07,-9.453782e-07,2e-06,-6.827731e-07,-9.453782e-07,5.777311e-07,3.676471e-07,...,-3.151261e-07,6.302521e-07,-2.62605e-07,-2.10084e-07,1e-06,9.978992e-07,-2.10084e-07,-2e-06,-3.676471e-07,-1.57563e-07
std,1.000263,1.000263,1.000262,1.000264,1.000264,1.000263,1.000262,1.000263,1.000264,1.000261,...,1.000262,1.000263,1.000264,1.000263,1.000264,1.000262,1.000262,1.000264,1.000263,1.000264
min,-4.0392,-2.8462,-2.9241,-2.7227,-1.377,-2.843,-3.2574,-3.1421,-1.2158,-3.6108,...,-2.4195,-1.6046,-5.4334,-4.1234,-2.0713,-3.3717,-2.6355,-2.9367,-2.1243,-3.5596
25%,-0.68615,-0.65185,-0.646625,-0.65195,-0.591675,-0.7249,-0.660625,-0.658475,-0.31475,-0.68155,...,-0.725875,-0.664975,-0.6245,-0.57465,-0.70195,-0.6667,-0.670375,-0.731475,-0.7889,-0.66415
50%,-0.05055,-0.1695,-0.06985,-0.0812,-0.36035,0.0012,-0.0452,-0.0096,-0.12835,-0.01835,...,-0.0731,-0.28525,-0.00615,0.01685,-0.14125,-0.0007,-0.16895,-0.0451,-0.08895,0.0055
75%,0.60645,0.4389,0.54825,0.52865,0.1817,0.7476,0.5924,0.6262,0.081675,0.6679,...,0.65035,0.33835,0.5982,0.592875,0.562725,0.67205,0.551525,0.6718,0.70525,0.632575
max,5.1692,5.457,6.6266,5.5721,6.2827,2.6088,8.7014,4.7229,12.6817,3.8106,...,3.7757,7.4657,8.0724,7.2683,4.7139,3.5192,4.4059,3.9901,3.3077,4.1732


In [11]:
gene_mut_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1904 entries, 0 to 1903
Columns: 173 entries, pik3ca_mut to siah1_mut
dtypes: object(173)
memory usage: 2.5+ MB


In [12]:
gene_mut_dataset.describe()

Unnamed: 0,pik3ca_mut,tp53_mut,muc16_mut,ahnak2_mut,kmt2c_mut,syne1_mut,gata3_mut,map3k1_mut,ahnak_mut,dnah11_mut,...,mtap_mut,ppp2cb_mut,smarcd1_mut,nras_mut,ndfip1_mut,hras_mut,prps2_mut,smarcb1_mut,stmn2_mut,siah1_mut
count,1904,1904,1904,1904,1904,1904,1904,1904,1904,1904,...,1904,1904,1904,1904,1904,1904,1904,1904,1904,1904
unique,160,343,298,248,222,200,128,194,153,154,...,5,5,5,4,4,3,3,3,3,2
top,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
freq,1109,1245,1578,1593,1670,1672,1674,1706,1728,1729,...,1900,1900,1900,1901,1901,1902,1902,1902,1902,1903


In [13]:
# clean data means data with no NaN value in any column
def clean_stats(ds):
    return '''clean data: {0}'''.format(ds.shape[0] - ds.isnull().any(axis=1).sum())

print(f'Clinical dataset {clean_stats(clinical_dataset)}')
print(f'Gene expression dataset {clean_stats(gene_expr_dataset)}')
print(f'Gene mutation dataset {clean_stats(gene_mut_dataset)}')

Clinical dataset clean data: 1092
Gene expression dataset clean data: 1904
Gene mutation dataset clean data: 1904
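
The row-level count above says how many rows are complete, but not which columns are responsible for the missing values. Here is a small sketch of a per-column summary on made-up toy data; the same calls could be applied to `clinical_dataset`.

```python
import numpy as np
import pandas as pd

# Toy data with hypothetical values.
toy = pd.DataFrame({
    "tumor_size": [22.0, np.nan, 30.0, 15.0],
    "tumor_stage": [2.0, np.nan, np.nan, 1.0],
    "chemotherapy": [0, 1, 1, 0],
})

missing = toy.isnull().sum()  # number of missing cells per column
summary = pd.DataFrame({
    "n_missing": missing,
    "frac_missing": missing / len(toy),
}).sort_values("n_missing", ascending=False)

# rows without any missing value, same definition as clean_stats
complete_rows = int((~toy.isnull().any(axis=1)).sum())

print(summary)
print("complete rows:", complete_rows)
```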


In [14]:
clinical_dataset.dtypes

patient_id                          int64
age_at_diagnosis                  float64
type_of_breast_surgery             object
cancer_type                        object
cancer_type_detailed               object
cellularity                        object
chemotherapy                        int64
pam50_+_claudin-low_subtype        object
cohort                            float64
er_status_measured_by_ihc          object
er_status                          object
neoplasm_histologic_grade         float64
her2_status_measured_by_snp6       object
her2_status                        object
tumor_other_histologic_subtype     object
hormone_therapy                     int64
inferred_menopausal_state          object
integrative_cluster                object
primary_tumor_laterality           object
lymph_nodes_examined_positive     float64
mutation_count                    float64
nottingham_prognostic_index       float64
oncotree_code                      object
overall_survival_months           

In [15]:
def dtype_stats(ds):
    return '''
    columns: {0}, object columns: {1}, int columns: {2}, float columns: {3}
    '''.format(len(ds.columns),
               (ds.dtypes == 'object').sum(),
               (ds.dtypes == 'int64').sum(),
               (ds.dtypes == 'float64').sum())

print(f'Clinical dataset: {dtype_stats(clinical_dataset)}')
print(f'Gene expression dataset: {dtype_stats(gene_expr_dataset)}')
print(f'Gene mutation dataset: {dtype_stats(gene_mut_dataset)}')

Clinical dataset: 
    columns: 31, object columns: 17, int columns: 5, float columns: 9
    
Gene expression dataset: 
    columns: 489, object columns: 0, int columns: 0, float columns: 489
    
Gene mutation dataset: 
    columns: 173, object columns: 173, int columns: 0, float columns: 0
    


In [16]:
# check if int data needs scaling
clinical_dataset[clinical_columns[clinical_dataset.dtypes == 'int64']].head(5)

Unnamed: 0,patient_id,chemotherapy,hormone_therapy,overall_survival,radio_therapy
0,0,0,1,1,1
1,2,0,1,1,1
2,5,1,1,0,0
3,6,1,1,1,1
4,8,1,1,0,1


In [17]:
# define data and labels for each dataset

labels = clinical_dataset[label_column].to_numpy()
clinical_data = clinical_dataset[clinical_data_columns].to_numpy()
gene_expr_data = gene_expr_dataset.to_numpy()

In [18]:
# Convert categorical data to numerical data for clinical dataset
ordinal_encoder = OrdinalEncoder()
clinical_data = ordinal_encoder.fit_transform(clinical_data)

In [19]:
# Perform data imputation for clinical dataset
imputer = SimpleImputer(missing_values=np.nan, strategy='mean')
clinical_data = imputer.fit_transform(clinical_data)

In [20]:
scaler = StandardScaler()
clinical_data = scaler.fit_transform(clinical_data)
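
The three cells above ordinal-encode every column (numeric ones included) and fit the encoder, imputer and scaler on all rows before the split. A possible alternative, shown here only as a sketch on toy data with hypothetical values, is to encode just the categorical columns and to fit every preprocessing step on the training rows alone. This is not what the notebook does, and the choice of `most_frequent` imputation and of `-1` for unseen categories are assumptions of this example.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder, StandardScaler

# Toy data (hypothetical values) with one numeric and one categorical column.
toy = pd.DataFrame({
    "age_at_diagnosis": [75.6, 43.2, 48.9, np.nan, 77.0, 60.1],
    "cellularity": pd.Series(
        ["High", "Moderate", np.nan, "High", "Low", "High"], dtype=object
    ),
})
train, test = toy.iloc[:4], toy.iloc[4:]

preprocess = ColumnTransformer([
    # numeric columns: mean imputation, then standardisation
    ("num", Pipeline([
        ("impute", SimpleImputer(strategy="mean")),
        ("scale", StandardScaler()),
    ]), ["age_at_diagnosis"]),
    # categorical columns: mode imputation, then integer codes
    ("cat", Pipeline([
        ("impute", SimpleImputer(strategy="most_frequent")),
        ("encode", OrdinalEncoder(handle_unknown="use_encoded_value",
                                  unknown_value=-1)),
    ]), ["cellularity"]),
])

X_train = preprocess.fit_transform(train)  # statistics come from training rows only
X_test = preprocess.transform(test)        # test rows reuse those statistics

print(X_train)
print(X_test)
```

Categories that appear only in the test rows (here `"Low"`) are mapped to `-1` instead of raising an error.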

In [21]:
_clinical_train_X, _clinical_test_X, _clinical_train_y, _clinical_test_y = train_test_split(clinical_data, labels, test_size=0.10, random_state=SEED)
_clinical_train_X, _clinical_val_X, _clinical_train_y, _clinical_val_y = train_test_split(_clinical_train_X, _clinical_train_y, test_size=0.10, random_state=SEED)

_gene_expr_train_X, _gene_expr_test_X, _gene_expr_train_y, _gene_expr_test_y = train_test_split(gene_expr_data, labels, test_size=0.10, random_state=SEED)
_gene_expr_train_X, _gene_expr_val_X, _gene_expr_train_y, _gene_expr_val_y = train_test_split(_gene_expr_train_X, _gene_expr_train_y, test_size=0.10, random_state=SEED)

dataset = {
    'clinical':{
        'X_train': _clinical_train_X,
        'X_val': _clinical_val_X,
        'X_test': _clinical_test_X,
        'y_train': _clinical_train_y,
        'y_val': _clinical_val_y,
        'y_test': _clinical_test_y
    },
    'gene_expr':{
        'X_train': _gene_expr_train_X,
        'X_val': _gene_expr_val_X,
        'X_test': _gene_expr_test_X,
        'y_train': _gene_expr_train_y,
        'y_val': _gene_expr_val_y,
        'y_test': _gene_expr_test_y
    },
    'gene_expr_reduced':{
    }
}
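
Two successive 10% hold-outs leave roughly 81% / 9% / 10% of the rows for training, validation and test. The arithmetic below assumes scikit-learn's behaviour of rounding a fractional test size up to the next integer.

```python
import math

n_samples = 1904
test_fraction = 0.10

# train_test_split rounds a fractional test size up to the next integer
n_test = math.ceil(test_fraction * n_samples)
n_train_val = n_samples - n_test

# the second split takes 10% of the remaining rows for validation
n_val = math.ceil(test_fraction * n_train_val)
n_train = n_train_val - n_val

print(n_train, n_val, n_test)  # 1541 172 191
```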

## Dimension Reduction (20 + Up to 10 Points Optional)

For each dataset, investigate whether dimensionality reduction is needed. If so, reduce the dataset's dimension. You can use UMAP for this purpose, but any other approach is acceptable. Finding the most important features earns extra points.

<span style="color:orange">
    We check whether dimensionality reduction is needed by using a simple linear regression model as a baseline.
</span>



In [22]:
# predict for the clinical dataset using linear regression
_clf = LinearRegression()
_clf.fit(dataset['clinical']['X_train'], dataset['clinical']['y_train'])
_clinical_baseline_pred = np.round(_clf.predict(dataset['clinical']['X_test']))
_clinical_baseline_accuracy = accuracy_score(dataset['clinical']['y_test'], _clinical_baseline_pred)

# predict for the gene expression dataset using linear regression
_clf = LinearRegression()
_clf.fit(dataset['gene_expr']['X_train'], dataset['gene_expr']['y_train'])
_gene_expr_baseline_pred = np.round(_clf.predict(dataset['gene_expr']['X_test']))
_gene_expr_baseline_accuracy = accuracy_score(dataset['gene_expr']['y_test'], _gene_expr_baseline_pred)

print(f'Accuracy of simple linear regression model on clinical data: {_clinical_baseline_accuracy:.3f}')
print(f'Accuracy of simple linear regression model on gene expression data: {_gene_expr_baseline_accuracy:.3f}')

Accuracy of simple linear regression model on clinical data: 0.733
Accuracy of simple linear regression model on gene expression data: 0.550


<span style="color:orange">
    As we can see, the results are much better for the clinical dataset, which has few dimensions, than for the gene expression dataset.
    Therefore, we will only reduce the dimensions of the gene expression dataset.
</span>
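
One caveat of this baseline: because it rounds a regression output, any prediction below -0.5 or above 1.5 becomes a label outside {0, 1} and is always counted as an error. The illustration below uses made-up numbers; thresholding at 0.5 is shown as an alternative and is not what the cell above does.

```python
import numpy as np

# Hypothetical raw outputs of a linear regression fitted on 0/1 labels.
raw = np.array([-0.7, 0.2, 0.5, 0.8, 1.6])

rounded = np.round(raw)                 # can leave the label set {0, 1}
thresholded = (raw >= 0.5).astype(int)  # always a valid class label

print(rounded)      # [-1.  0.  0.  1.  2.]
print(thresholded)  # [0 0 1 1 1]
```

Note that `np.round` rounds halves to the nearest even value, so 0.5 becomes 0 there, whereas the explicit threshold assigns it to class 1.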



In [23]:
# reduce the dimensions for clinical data and predict using baseline model
CLINICAL_REDUCED_DIMENSIONS = 10
GENE_EXPR_REDUCED_DIMENSIONS = 20


_reducer = umap.UMAP(n_components=CLINICAL_REDUCED_DIMENSIONS, random_state=SEED)
_reducer.fit(clinical_data)
_reduced_X_train = _reducer.transform(dataset['clinical']['X_train'])
_reduced_X_test = _reducer.transform(dataset['clinical']['X_test'])

_clf = LinearRegression()
_clf.fit(_reduced_X_train, dataset['clinical']['y_train'])
_clinical_reduced_baseline_pred = np.round(_clf.predict(_reduced_X_test))
_clinical_reduced_baseline_accuracy = accuracy_score(dataset['clinical']['y_test'], _clinical_reduced_baseline_pred)

# reduce the dimensions for gene expression data and predict using baseline model
_reducer = umap.UMAP(n_components=GENE_EXPR_REDUCED_DIMENSIONS, random_state=SEED)
_reducer.fit(gene_expr_data)
_reduced_X_train = _reducer.transform(dataset['gene_expr']['X_train'])
_reduced_X_val = _reducer.transform(dataset['gene_expr']['X_val'])
_reduced_X_test = _reducer.transform(dataset['gene_expr']['X_test'])

_clf = LinearRegression()
_clf.fit(_reduced_X_train, dataset['gene_expr']['y_train'])
_gene_expr_reduced_baseline_pred = np.round(_clf.predict(_reduced_X_test))
_gene_expr_reduced_baseline_accuracy = accuracy_score(dataset['gene_expr']['y_test'], _gene_expr_reduced_baseline_pred)

print(f'Accuracy of simple linear regression model on reduced clinical data: {_clinical_reduced_baseline_accuracy:.3f}')
print(f'Accuracy of simple linear regression model on reduced gene expression data: {_gene_expr_reduced_baseline_accuracy:.3f}')

Accuracy of simple linear regression model on reduced clinical data: 0.618
Accuracy of simple linear regression model on reduced gene expression data: 0.613


<span style="color:orange">
    As we can see, applying dimension reduction to the clinical dataset leads to worse results (0.618 vs. 0.733), while on the gene expression dataset it improves the predictions (0.613 vs. 0.550).
    Therefore, we choose to reduce the dimensions of only the gene expression dataset. 
</span>
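
`PCA` is imported at the top of the notebook but not used in the cells shown. As a complementary check, and a different technique from the UMAP comparison above, the cumulative explained variance of a PCA fit gives a rough idea of how redundant a feature matrix is. The sketch below runs on synthetic data; the 95% cut-off is an arbitrary assumption.

```python
import numpy as np
from sklearn.decomposition import PCA

# Synthetic data: 50 observed features driven by 3 hidden factors plus small noise.
rng = np.random.default_rng(0)
latent = rng.normal(size=(200, 3))
mixing = rng.normal(size=(3, 50))
X = latent @ mixing + 0.05 * rng.normal(size=(200, 50))

pca = PCA().fit(X)
cumulative = np.cumsum(pca.explained_variance_ratio_)

# smallest number of components whose cumulative variance reaches 95%
n_keep = int(np.searchsorted(cumulative, 0.95) + 1)

print(n_keep, cumulative[:5].round(3))
```

On real data the reducer would ideally be fitted on the training split only, so that the test rows do not influence the learned projection.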



In [24]:
dataset['gene_expr_reduced'] = {
    'X_train': _reduced_X_train,
    'X_val': _reduced_X_val,
    'X_test': _reduced_X_test,
    'y_train': _gene_expr_train_y,
    'y_val': _gene_expr_val_y,
    'y_test': _gene_expr_test_y
}

# Classic Model (25 Points)

In this section, you must implement a classic classification model for clinical, gene expressions, and reduced gene expressions datasets. Using Random Forest is suggested. (minimum acceptable accuracy = 60%)

In [25]:
random_forst_models = {
    'clinical': None,
    'gene_expr': None,
    'gene_expr_reduced': None
}

for ds_name in random_forst_models:
    clf = RandomForestClassifier(random_state=SEED)
    ds = dataset[ds_name]
    clf.fit(ds['X_train'], ds['y_train'])
    y_pred = clf.predict(ds['X_test'])
    acc = accuracy_score(ds['y_test'], y_pred)
    random_forst_models[ds_name] = {
        'model': clf,
        'accuracy': acc
    }

    print(f'random forest on {ds_name} dataset had accuracy of {acc:.4f}')

svm_models = {}

# store the SVM results separately so the random forest entries are not overwritten
for ds_name in random_forst_models:
    clf = SVC(random_state=SEED)
    ds = dataset[ds_name]
    clf.fit(ds['X_train'], ds['y_train'])
    y_pred = clf.predict(ds['X_test'])
    acc = accuracy_score(ds['y_test'], y_pred)
    svm_models[ds_name] = {
        'model': clf,
        'accuracy': acc
    }

    print(f'svm on {ds_name} dataset had accuracy of {acc:.4f}')



random forest on clinical dataset had accuracy of 0.7435
random forest on gene_expr dataset had accuracy of 0.6230
random forest on gene_expr_reduced dataset had accuracy of 0.6178
svm on clinical dataset had accuracy of 0.7016
svm on gene_expr dataset had accuracy of 0.6178
svm on gene_expr_reduced dataset had accuracy of 0.5654
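
A fitted `RandomForestClassifier` also exposes `feature_importances_`, which is one simple way to approach the "most important features" part of the dimension reduction task. The sketch below uses synthetic data in which only one feature carries signal; the feature names are hypothetical.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Synthetic data: six features, only feature 2 determines the label.
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 6))
y = (X[:, 2] > 0).astype(int)
feature_names = [f"f{i}" for i in range(6)]  # hypothetical names

clf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

# indices of the features, most important first
order = np.argsort(clf.feature_importances_)[::-1]
for i in order[:3]:
    print(feature_names[i], round(float(clf.feature_importances_[i]), 3))
```

Impurity-based importances can favour high-cardinality features, so the ranking is best read as a rough guide.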


# Neural Network (30 Points)

In this section, you must implement a neural network model for clinical, gene expressions and reduced gene expressions datasets. Using an MLP model is suggested. (minimum acceptable accuracy = 60%)

In [26]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [27]:
class CancerDataset(Dataset):

    def __init__(self, X, y) -> None:
        super().__init__()
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx].astype(np.float32), self.y[idx].astype(np.float32)


In [28]:
dataloaders = {}
batch_size = 64

for ds_name, ds_split in dataset.items():
    dataloaders[ds_name] = {}
    X_train = ds_split['X_train']
    X_val = ds_split['X_val']
    X_test = ds_split['X_test']
    y_train = ds_split['y_train']
    y_val = ds_split['y_val']
    y_test = ds_split['y_test']
    
    train_ds = CancerDataset(X_train, y_train)
    val_ds = CancerDataset(X_val, y_val)
    test_ds = CancerDataset(X_test, y_test)

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=True)

    dataloaders[ds_name]['train'] = train_dl
    dataloaders[ds_name]['val'] = val_dl
    dataloaders[ds_name]['test'] = test_dl

In [29]:
def update_confusion_matrix(confusion, out, labels):
    out = out.astype(int)
    labels = labels.astype(int)
    for i in range(confusion.shape[0]):
        for j in range(confusion.shape[1]):
            confusion[i, j] += ((out == i) & (labels == j)).sum()
    
    return confusion
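
The double loop above can also be written as a single indexed accumulation. This is a sketch with the same convention (rows = predicted class, columns = true class); `confusion_counts` is a hypothetical helper name and not part of the notebook.

```python
import numpy as np

def confusion_counts(pred, labels, n_classes=2):
    """Return an (n_classes, n_classes) array with rows = predicted, columns = true."""
    pred = np.asarray(pred).astype(int)
    labels = np.asarray(labels).astype(int)
    counts = np.zeros((n_classes, n_classes), dtype=int)
    # unbuffered add, so repeated (pred, label) pairs are all counted
    np.add.at(counts, (pred, labels), 1)
    return counts

pred = np.array([1.0, 0.0, 1.0, 1.0, 0.0])
true = np.array([1.0, 0.0, 0.0, 1.0, 1.0])
cm = confusion_counts(pred, true)
print(cm)
# [[1 1]
#  [1 2]]
```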

In [30]:
def evaluate(model, dataloader):
    model.eval()
    total, correct = 0, 0
    confusion = np.zeros((2, 2))
    with torch.no_grad():
        for _, (data, labels) in enumerate(dataloader):
            data, labels = data.to(device), labels.to(device)

            pred = model(data).squeeze()
            out = torch.round(pred)
            update_confusion_matrix(confusion, out.cpu().numpy(), labels.cpu().numpy())

            correct = correct + (labels == out).sum().detach().cpu().numpy()
            total = total + len(data)
    
    return correct / total, confusion

In [31]:
def train(model, criterion, optimizer, train_dataloader, val_dataloader, num_epochs):
    best_model = model
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_loss = np.inf

    for epoch in range(num_epochs):
        # train 
        model.train()
        total, correct = 0, 0
        train_loss = 0.0
        with tqdm(enumerate(train_dataloader), total=len(train_dataloader)) as pbar:
            for i, (data, labels) in pbar:
                data, labels = data.to(device), labels.to(device)
                optimizer.zero_grad()

                pred = model(data).squeeze()
                out = pred.detach().round()

                total = total + len(data)
                correct = correct + (labels == out).sum().detach().cpu().numpy()
                
                loss = criterion(pred, labels)
                loss.backward()
                optimizer.step()

                train_loss = train_loss + loss.detach().cpu().numpy()
                
                pbar.set_description('Epoch {0}: train loss={1}, train accuracy={2}'.format(epoch, train_loss / total,
                                                                                            correct / total))
        
        train_losses.append(train_loss / total)
        train_accs.append((correct / total))
        
        # validation
        model.eval()
        total, correct = 0, 0
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(enumerate(val_dataloader), total=len(val_dataloader)) as pbar:
                for i, (data, labels) in pbar:
                    data, labels = data.to(device), labels.to(device)

                    pred = model(data).squeeze()
                    out = torch.round(pred)

                    correct = correct + (labels == out).sum().detach().cpu().numpy()
                    total = total + len(data)

                    val_loss += criterion(pred, labels).detach().cpu().numpy()
                    
                    pbar.set_description('Epoch {0}: val loss={1}, val accuracy={2}'.format(epoch, val_loss / total,
                                                                                                correct / total))
        
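        val_loss = val_loss / total
        # record the validation history (otherwise the returned lists stay empty)
        val_losses.append(val_loss)
        val_accs.append(correct / total)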
        if val_loss < best_val_loss:
            print('New model saved, val loss {0} -> {1}'.format(best_val_loss, val_loss))
            best_val_loss = val_loss
            best_model = deepcopy(model)

    return train_losses, val_losses, train_accs, val_accs, best_model
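
The checkpoint rule used in `train` (keep a copy of the model whenever the validation loss strictly improves) can be isolated in a few lines. In the sketch below, plain dictionaries stand in for model weights, the loss values are made up, and `select_best` is a hypothetical helper name.

```python
from copy import deepcopy

def select_best(states, val_losses):
    """Return (best_epoch, best_state): the first state with the lowest validation loss."""
    best_loss = float("inf")
    best_epoch, best_state = None, None
    for epoch, (state, loss) in enumerate(zip(states, val_losses)):
        if loss < best_loss:  # strict '<' keeps the earliest of equal losses
            best_loss = loss
            # copy, so later in-place updates do not alter the saved state
            best_epoch, best_state = epoch, deepcopy(state)
    return best_epoch, best_state

states = [{"w": 0.1}, {"w": 0.4}, {"w": 0.7}, {"w": 0.9}]
val_losses = [0.30, 0.22, 0.25, 0.22]
print(select_best(states, val_losses))  # (1, {'w': 0.4})
```

The `deepcopy` matters in the real loop because the optimizer keeps updating the live model in place after the checkpoint is taken.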

In [32]:
mlp_models = random_forst_models.copy()
lr = 1e-4
num_epochs = 100
torch.manual_seed(SEED)

for ds_name in mlp_models:
    net = nn.Sequential(
        nn.Linear(dataset[ds_name]['X_train'].shape[1], 64),
        nn.Tanh(),
        nn.BatchNorm1d(64),
        nn.Linear(64, 32),
        nn.Tanh(),
        nn.BatchNorm1d(32),
        nn.Linear(32, 32),
        nn.Tanh(),
        nn.BatchNorm1d(32),
        nn.Linear(32, 16),
        nn.Tanh(),
        nn.BatchNorm1d(16),
        nn.Linear(16, 1),
        nn.Sigmoid()
    )
    net = net.to(device)
    optimizer = Adam(net.parameters(), lr=lr)
    criterion = nn.MSELoss()
    train_dl = dataloaders[ds_name]['train']
    val_dl = dataloaders[ds_name]['val']
    test_dl = dataloaders[ds_name]['test']

    return_vals = train(net, criterion, optimizer, train_dl, val_dl, num_epochs)
    train_losses, val_losses, train_accs, val_accs, best_model = return_vals
    mlp_models[ds_name] = {
        'model': best_model,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs
    }

Epoch 0: train loss=0.004186857025163812, train accuracy=0.5139519792342635: 100%|█████| 25/25 [00:02<00:00, 10.14it/s]
Epoch 0: val loss=0.004254257523043211, val accuracy=0.5813953488372093: 100%|██████████| 3/3 [00:00<00:00, 228.85it/s]


New model saved, val loss inf -> 0.004254257523043211


Epoch 1: train loss=0.0037064810490778405, train accuracy=0.6417910447761194: 100%|███| 25/25 [00:00<00:00, 107.66it/s]
Epoch 1: val loss=0.0037059891362522922, val accuracy=0.6918604651162791: 100%|█████████| 3/3 [00:00<00:00, 330.92it/s]


New model saved, val loss 0.004254257523043211 -> 0.0037059891362522922


Epoch 2: train loss=0.0034773016830762435, train accuracy=0.6904607397793641: 100%|███| 25/25 [00:00<00:00, 107.28it/s]
Epoch 2: val loss=0.003442274796408276, val accuracy=0.7441860465116279: 100%|██████████| 3/3 [00:00<00:00, 259.99it/s]


New model saved, val loss 0.0037059891362522922 -> 0.003442274796408276


Epoch 3: train loss=0.0033907042858586692, train accuracy=0.7047371836469825: 100%|███| 25/25 [00:00<00:00, 106.08it/s]
Epoch 3: val loss=0.0033052011804525242, val accuracy=0.7732558139534884: 100%|█████████| 3/3 [00:00<00:00, 259.28it/s]


New model saved, val loss 0.003442274796408276 -> 0.0033052011804525242


Epoch 4: train loss=0.00323977057884298, train accuracy=0.7177157689811811: 100%|█████| 25/25 [00:00<00:00, 105.04it/s]
Epoch 4: val loss=0.0031748180126034936, val accuracy=0.7732558139534884: 100%|█████████| 3/3 [00:00<00:00, 259.43it/s]


New model saved, val loss 0.0033052011804525242 -> 0.0031748180126034936


Epoch 5: train loss=0.003200355160553528, train accuracy=0.7287475665152499: 100%|████| 25/25 [00:00<00:00, 108.93it/s]
Epoch 5: val loss=0.0030229012806748234, val accuracy=0.75: 100%|███████████████████████| 3/3 [00:00<00:00, 259.37it/s]


New model saved, val loss 0.0031748180126034936 -> 0.0030229012806748234


Epoch 6: train loss=0.0031071244451923853, train accuracy=0.7391304347826086: 100%|███| 25/25 [00:00<00:00, 111.03it/s]
Epoch 6: val loss=0.002880395411751991, val accuracy=0.7616279069767442: 100%|██████████| 3/3 [00:00<00:00, 284.29it/s]


New model saved, val loss 0.0030229012806748234 -> 0.002880395411751991


Epoch 7: train loss=0.0030752084980066683, train accuracy=0.7404282933160286: 100%|███| 25/25 [00:00<00:00, 108.28it/s]
Epoch 7: val loss=0.0028898170174554336, val accuracy=0.7732558139534884: 100%|█████████| 3/3 [00:00<00:00, 262.44it/s]
Epoch 8: train loss=0.0029681421173307354, train accuracy=0.7488643737832577: 100%|███| 25/25 [00:00<00:00, 110.97it/s]
Epoch 8: val loss=0.002813892568959746, val accuracy=0.7616279069767442: 100%|██████████| 3/3 [00:00<00:00, 269.46it/s]


New model saved, val loss 0.002880395411751991 -> 0.002813892568959746


Epoch 9: train loss=0.0029251424524398845, train accuracy=0.7527579493835171: 100%|███| 25/25 [00:00<00:00, 109.94it/s]
Epoch 9: val loss=0.0027268126953479857, val accuracy=0.7848837209302325: 100%|█████████| 3/3 [00:00<00:00, 259.86it/s]


New model saved, val loss 0.002813892568959746 -> 0.0027268126953479857


Epoch 10: train loss=0.002883716068779935, train accuracy=0.7475665152498377: 100%|███| 25/25 [00:00<00:00, 112.17it/s]
Epoch 10: val loss=0.002699995889913204, val accuracy=0.7848837209302325: 100%|█████████| 3/3 [00:00<00:00, 260.75it/s]


New model saved, val loss 0.0027268126953479857 -> 0.002699995889913204


Epoch 11: train loss=0.002848796106483316, train accuracy=0.7560025957170668: 100%|███| 25/25 [00:00<00:00, 108.31it/s]
Epoch 11: val loss=0.0027269656055195386, val accuracy=0.7906976744186046: 100%|████████| 3/3 [00:00<00:00, 283.62it/s]
Epoch 12: train loss=0.0028345436486309802, train accuracy=0.7560025957170668: 100%|██| 25/25 [00:00<00:00, 106.87it/s]
Epoch 12: val loss=0.0027188251531401345, val accuracy=0.7965116279069767: 100%|████████| 3/3 [00:00<00:00, 284.66it/s]
Epoch 13: train loss=0.002853294023440149, train accuracy=0.7592472420506164: 100%|███| 25/25 [00:00<00:00, 107.30it/s]
Epoch 13: val loss=0.0026331493674322617, val accuracy=0.813953488372093: 100%|█████████| 3/3 [00:00<00:00, 284.80it/s]


New model saved, val loss 0.002699995889913204 -> 0.0026331493674322617


Epoch 14: train loss=0.002748334890329397, train accuracy=0.7560025957170668: 100%|███| 25/25 [00:00<00:00, 107.18it/s]
Epoch 14: val loss=0.0027068249534728914, val accuracy=0.8023255813953488: 100%|████████| 3/3 [00:00<00:00, 284.46it/s]
Epoch 15: train loss=0.0027736915755318327, train accuracy=0.7611940298507462: 100%|██| 25/25 [00:00<00:00, 102.38it/s]
Epoch 15: val loss=0.0025859022902887923, val accuracy=0.8023255813953488: 100%|████████| 3/3 [00:00<00:00, 296.50it/s]


New model saved, val loss 0.0026331493674322617 -> 0.0025859022902887923


Epoch 16: train loss=0.002799772630257702, train accuracy=0.7566515249837767: 100%|███| 25/25 [00:00<00:00, 109.51it/s]
Epoch 16: val loss=0.0025719814868860468, val accuracy=0.8081395348837209: 100%|████████| 3/3 [00:00<00:00, 260.29it/s]


New model saved, val loss 0.0025859022902887923 -> 0.0025719814868860468


Epoch 17: train loss=0.002752687170434354, train accuracy=0.7598961713173265: 100%|███| 25/25 [00:00<00:00, 100.82it/s]
Epoch 17: val loss=0.002607990489449612, val accuracy=0.7965116279069767: 100%|█████████| 3/3 [00:00<00:00, 246.81it/s]
Epoch 18: train loss=0.002775334087693018, train accuracy=0.7637897469175859: 100%|███| 25/25 [00:00<00:00, 103.20it/s]
Epoch 18: val loss=0.002621636314447536, val accuracy=0.7848837209302325: 100%|█████████| 3/3 [00:00<00:00, 284.35it/s]
Epoch 19: train loss=0.0027289823432544236, train accuracy=0.7624918883841661: 100%|██| 25/25 [00:00<00:00, 102.63it/s]
Epoch 19: val loss=0.002621595509523569, val accuracy=0.7965116279069767: 100%|█████████| 3/3 [00:00<00:00, 265.71it/s]
Epoch 20: train loss=0.0028638462212703043, train accuracy=0.7696301103179753: 100%|██| 25/25 [00:00<00:00, 100.11it/s]
Epoch 20: val loss=0.002595703962237336, val accuracy=0.7906976744186046: 100%|█████████| 3/3 [00:00<00:00, 247.50it/s]
Epoch 21: train loss=0.00270771585274330

New model saved, val loss 0.0025719814868860468 -> 0.002551177646531615


Epoch 23: train loss=0.002714113870818146, train accuracy=0.7573004542504866: 100%|███| 25/25 [00:00<00:00, 101.82it/s]
Epoch 23: val loss=0.0025676287017589393, val accuracy=0.8023255813953488: 100%|████████| 3/3 [00:00<00:00, 259.81it/s]
Epoch 24: train loss=0.002744743918382371, train accuracy=0.7573004542504866: 100%|████| 25/25 [00:00<00:00, 92.11it/s]
Epoch 24: val loss=0.002569200600995574, val accuracy=0.813953488372093: 100%|██████████| 3/3 [00:00<00:00, 238.77it/s]
Epoch 25: train loss=0.0027235983922294328, train accuracy=0.763140817650876: 100%|████| 25/25 [00:00<00:00, 99.07it/s]
Epoch 25: val loss=0.0025356146831845127, val accuracy=0.8081395348837209: 100%|████████| 3/3 [00:00<00:00, 228.93it/s]


New model saved, val loss 0.002551177646531615 -> 0.0025356146831845127


Epoch 26: train loss=0.002688529703579964, train accuracy=0.7637897469175859: 100%|████| 25/25 [00:00<00:00, 95.65it/s]
Epoch 26: val loss=0.0025716767061588377, val accuracy=0.8081395348837209: 100%|████████| 3/3 [00:00<00:00, 234.18it/s]
Epoch 27: train loss=0.0027279196371125525, train accuracy=0.7657365347177157: 100%|███| 25/25 [00:00<00:00, 95.59it/s]
Epoch 27: val loss=0.0024787207935438598, val accuracy=0.7906976744186046: 100%|████████| 3/3 [00:00<00:00, 259.75it/s]


New model saved, val loss 0.0025356146831845127 -> 0.0024787207935438598


Epoch 28: train loss=0.002778125570278366, train accuracy=0.7585983127839065: 100%|███| 25/25 [00:00<00:00, 101.16it/s]
Epoch 28: val loss=0.002591439544461494, val accuracy=0.813953488372093: 100%|██████████| 3/3 [00:00<00:00, 221.35it/s]
Epoch 29: train loss=0.0027430106744914763, train accuracy=0.7689811810512654: 100%|███| 25/25 [00:00<00:00, 99.21it/s]
Epoch 29: val loss=0.002566968892202821, val accuracy=0.8023255813953488: 100%|█████████| 3/3 [00:00<00:00, 248.73it/s]
Epoch 30: train loss=0.002651317529791288, train accuracy=0.7696301103179753: 100%|███| 25/25 [00:00<00:00, 100.85it/s]
Epoch 30: val loss=0.002582995004432146, val accuracy=0.813953488372093: 100%|██████████| 3/3 [00:00<00:00, 259.99it/s]
Epoch 31: train loss=0.002709914823192656, train accuracy=0.7709279688513953: 100%|███| 25/25 [00:00<00:00, 100.57it/s]
Epoch 31: val loss=0.0026330609134463377, val accuracy=0.7965116279069767: 100%|████████| 3/3 [00:00<00:00, 229.29it/s]
Epoch 32: train loss=0.00278263784688761

New model saved, val loss 0.0024787207935438598 -> 0.002470434145178906


Epoch 33: train loss=0.0027059535656092926, train accuracy=0.7767683322517845: 100%|██| 25/25 [00:00<00:00, 102.38it/s]
Epoch 33: val loss=0.002509467130483583, val accuracy=0.813953488372093: 100%|██████████| 3/3 [00:00<00:00, 238.67it/s]
Epoch 34: train loss=0.002706418450895182, train accuracy=0.7761194029850746: 100%|███| 25/25 [00:00<00:00, 102.97it/s]
Epoch 34: val loss=0.002498256605724956, val accuracy=0.8081395348837209: 100%|█████████| 3/3 [00:00<00:00, 247.09it/s]
Epoch 35: train loss=0.0027804335918927797, train accuracy=0.7702790395846852: 100%|███| 25/25 [00:00<00:00, 99.56it/s]
Epoch 35: val loss=0.00249504410596781, val accuracy=0.813953488372093: 100%|███████████| 3/3 [00:00<00:00, 258.15it/s]
Epoch 36: train loss=0.0026349666915800106, train accuracy=0.773523685918235: 100%|████| 25/25 [00:00<00:00, 99.27it/s]
Epoch 36: val loss=0.002519461218007775, val accuracy=0.8081395348837209: 100%|█████████| 3/3 [00:00<00:00, 238.49it/s]
Epoch 37: train loss=0.00259207374743251

New model saved, val loss 0.002470434145178906 -> 0.002441390395857567


Epoch 44: train loss=0.0025537763658782867, train accuracy=0.7845554834523037: 100%|██| 25/25 [00:00<00:00, 101.59it/s]
Epoch 44: val loss=0.0024957123190857645, val accuracy=0.8023255813953488: 100%|████████| 3/3 [00:00<00:00, 221.95it/s]
Epoch 45: train loss=0.0025639731840681984, train accuracy=0.7839065541855937: 100%|███| 25/25 [00:00<00:00, 97.07it/s]
Epoch 45: val loss=0.002545339077018028, val accuracy=0.7848837209302325: 100%|█████████| 3/3 [00:00<00:00, 242.61it/s]
Epoch 46: train loss=0.0026636457410359987, train accuracy=0.773523685918235: 100%|████| 25/25 [00:00<00:00, 97.83it/s]
Epoch 46: val loss=0.0025616156067266023, val accuracy=0.8081395348837209: 100%|████████| 3/3 [00:00<00:00, 260.32it/s]
Epoch 47: train loss=0.0025691587108052603, train accuracy=0.781310837118754: 100%|███| 25/25 [00:00<00:00, 101.74it/s]
Epoch 47: val loss=0.0025486660211585287, val accuracy=0.7906976744186046: 100%|████████| 3/3 [00:00<00:00, 238.99it/s]
Epoch 48: train loss=0.00254125515026832

New model saved, val loss 0.002441390395857567 -> 0.002437549621559853


Epoch 53: train loss=0.002527007464462097, train accuracy=0.7878001297858533: 100%|████| 25/25 [00:00<00:00, 99.06it/s]
Epoch 53: val loss=0.0025431053323108094, val accuracy=0.7906976744186046: 100%|████████| 3/3 [00:00<00:00, 259.62it/s]
Epoch 54: train loss=0.0024662448176697117, train accuracy=0.790395846852693: 100%|████| 25/25 [00:00<00:00, 98.35it/s]
Epoch 54: val loss=0.0024659236849740493, val accuracy=0.8081395348837209: 100%|████████| 3/3 [00:00<00:00, 221.33it/s]
Epoch 55: train loss=0.0024856995745265275, train accuracy=0.7878001297858533: 100%|███| 25/25 [00:00<00:00, 95.69it/s]
Epoch 55: val loss=0.0025571531854396644, val accuracy=0.7906976744186046: 100%|████████| 3/3 [00:00<00:00, 199.45it/s]
Epoch 56: train loss=0.0024873079571300014, train accuracy=0.7929915639195327: 100%|██| 25/25 [00:00<00:00, 100.69it/s]
Epoch 56: val loss=0.0025105114246523658, val accuracy=0.8081395348837209: 100%|████████| 3/3 [00:00<00:00, 212.96it/s]
Epoch 57: train loss=0.00253855349045152

New model saved, val loss 0.002437549621559853 -> 0.00243528819707937


Epoch 58: train loss=0.00248988357110893, train accuracy=0.7936404931862427: 100%|████| 25/25 [00:00<00:00, 100.98it/s]
Epoch 58: val loss=0.002453657950079718, val accuracy=0.8255813953488372: 100%|█████████| 3/3 [00:00<00:00, 260.88it/s]
Epoch 59: train loss=0.0024649978970104343, train accuracy=0.7975340687865022: 100%|███| 25/25 [00:00<00:00, 98.13it/s]
Epoch 59: val loss=0.002485594733856445, val accuracy=0.8081395348837209: 100%|█████████| 3/3 [00:00<00:00, 222.35it/s]
Epoch 60: train loss=0.00253715509721631, train accuracy=0.7890979883192732: 100%|█████| 25/25 [00:00<00:00, 97.12it/s]
Epoch 60: val loss=0.0025055673579837, val accuracy=0.7965116279069767: 100%|███████████| 3/3 [00:00<00:00, 284.05it/s]
Epoch 61: train loss=0.002429900300495041, train accuracy=0.7923426346528228: 100%|████| 25/25 [00:00<00:00, 97.07it/s]
Epoch 61: val loss=0.0025215205238309016, val accuracy=0.8023255813953488: 100%|████████| 3/3 [00:00<00:00, 248.98it/s]
Epoch 62: train loss=0.00251751523138562

New model saved, val loss 0.00243528819707937 -> 0.002358339770242225


Epoch 63: train loss=0.0024349105862887317, train accuracy=0.7878001297858533: 100%|███| 25/25 [00:00<00:00, 99.51it/s]
Epoch 63: val loss=0.0024348319356524667, val accuracy=0.8081395348837209: 100%|████████| 3/3 [00:00<00:00, 247.70it/s]
Epoch 64: train loss=0.002391303604798685, train accuracy=0.8007787151200519: 100%|███| 25/25 [00:00<00:00, 101.43it/s]
Epoch 64: val loss=0.002481071928212809, val accuracy=0.8081395348837209: 100%|█████████| 3/3 [00:00<00:00, 262.54it/s]
Epoch 65: train loss=0.0024942989805006193, train accuracy=0.7942894224529526: 100%|███| 25/25 [00:00<00:00, 98.28it/s]
Epoch 65: val loss=0.002474072800819264, val accuracy=0.7965116279069767: 100%|█████████| 3/3 [00:00<00:00, 221.85it/s]
Epoch 66: train loss=0.0024788776952906293, train accuracy=0.8066190785204412: 100%|███| 25/25 [00:00<00:00, 99.78it/s]
Epoch 66: val loss=0.0025176136472890545, val accuracy=0.813953488372093: 100%|█████████| 3/3 [00:00<00:00, 239.72it/s]
Epoch 67: train loss=0.00235116831344261

Epoch 97: train loss=0.0022651973790274588, train accuracy=0.8215444516547696: 100%|██| 25/25 [00:00<00:00, 102.99it/s]
Epoch 97: val loss=0.002356385119086088, val accuracy=0.7848837209302325: 100%|█████████| 3/3 [00:00<00:00, 228.11it/s]


New model saved, val loss 0.002358339770242225 -> 0.002356385119086088


Epoch 98: train loss=0.0021023074918022256, train accuracy=0.827384815055159: 100%|████| 25/25 [00:00<00:00, 99.07it/s]
Epoch 98: val loss=0.002469970952979354, val accuracy=0.8081395348837209: 100%|█████████| 3/3 [00:00<00:00, 198.66it/s]
Epoch 99: train loss=0.002157867887939129, train accuracy=0.827384815055159: 100%|█████| 25/25 [00:00<00:00, 99.70it/s]
Epoch 99: val loss=0.002433848675600318, val accuracy=0.7848837209302325: 100%|█████████| 3/3 [00:00<00:00, 239.72it/s]
Epoch 0: train loss=0.004034069288790342, train accuracy=0.5476963011031798: 100%|█████| 25/25 [00:00<00:00, 99.04it/s]
Epoch 0: val loss=0.004259580529706423, val accuracy=0.5581395348837209: 100%|██████████| 3/3 [00:00<00:00, 247.37it/s]


New model saved, val loss inf -> 0.004259580529706423


Epoch 1: train loss=0.0038334910046505044, train accuracy=0.6048020765736535: 100%|███| 25/25 [00:00<00:00, 100.24it/s]
Epoch 1: val loss=0.004185498887023261, val accuracy=0.5872093023255814: 100%|██████████| 3/3 [00:00<00:00, 239.09it/s]


New model saved, val loss 0.004259580529706423 -> 0.004185498887023261


Epoch 2: train loss=0.0036430633799121566, train accuracy=0.6171317326411421: 100%|████| 25/25 [00:00<00:00, 96.75it/s]
Epoch 2: val loss=0.004182288379863251, val accuracy=0.5872093023255814: 100%|██████████| 3/3 [00:00<00:00, 229.64it/s]


New model saved, val loss 0.004185498887023261 -> 0.004182288379863251


Epoch 3: train loss=0.0035741528487685126, train accuracy=0.645684620376379: 100%|█████| 25/25 [00:00<00:00, 95.72it/s]
Epoch 3: val loss=0.004101254998944526, val accuracy=0.6046511627906976: 100%|██████████| 3/3 [00:00<00:00, 221.41it/s]


New model saved, val loss 0.004182288379863251 -> 0.004101254998944526


Epoch 4: train loss=0.0035001646188232206, train accuracy=0.6541207008436081: 100%|███| 25/25 [00:00<00:00, 100.04it/s]
Epoch 4: val loss=0.004130225282075794, val accuracy=0.6162790697674418: 100%|██████████| 3/3 [00:00<00:00, 206.37it/s]
Epoch 5: train loss=0.0034288239413310306, train accuracy=0.6658014276443868: 100%|████| 25/25 [00:00<00:00, 94.91it/s]
Epoch 5: val loss=0.004041883747938068, val accuracy=0.6104651162790697: 100%|██████████| 3/3 [00:00<00:00, 230.65it/s]


New model saved, val loss 0.004101254998944526 -> 0.004041883747938068


Epoch 6: train loss=0.0033424875416746395, train accuracy=0.6852693056456847: 100%|████| 25/25 [00:00<00:00, 94.14it/s]
Epoch 6: val loss=0.004010418300018754, val accuracy=0.5930232558139535: 100%|██████████| 3/3 [00:00<00:00, 221.68it/s]


New model saved, val loss 0.004041883747938068 -> 0.004010418300018754


Epoch 7: train loss=0.003178748718567439, train accuracy=0.7034393251135627: 100%|█████| 25/25 [00:00<00:00, 97.24it/s]
Epoch 7: val loss=0.004062511546667232, val accuracy=0.6046511627906976: 100%|██████████| 3/3 [00:00<00:00, 206.58it/s]
Epoch 8: train loss=0.0031457001284141963, train accuracy=0.7170668397144712: 100%|████| 25/25 [00:00<00:00, 92.66it/s]
Epoch 8: val loss=0.004066143444804258, val accuracy=0.6046511627906976: 100%|██████████| 3/3 [00:00<00:00, 215.46it/s]
Epoch 9: train loss=0.003071419851181184, train accuracy=0.7157689811810513: 100%|█████| 25/25 [00:00<00:00, 97.11it/s]
Epoch 9: val loss=0.003959581429182097, val accuracy=0.5813953488372093: 100%|██████████| 3/3 [00:00<00:00, 199.13it/s]


New model saved, val loss 0.004010418300018754 -> 0.003959581429182097


Epoch 10: train loss=0.0029488064440560758, train accuracy=0.7410772225827384: 100%|███| 25/25 [00:00<00:00, 97.50it/s]
Epoch 10: val loss=0.00400231986544853, val accuracy=0.5930232558139535: 100%|██████████| 3/3 [00:00<00:00, 206.11it/s]
Epoch 11: train loss=0.0028982112192627984, train accuracy=0.7475665152498377: 100%|███| 25/25 [00:00<00:00, 96.10it/s]
Epoch 11: val loss=0.003921451870092126, val accuracy=0.5813953488372093: 100%|█████████| 3/3 [00:00<00:00, 213.43it/s]


New model saved, val loss 0.003959581429182097 -> 0.003921451870092126


Epoch 12: train loss=0.0028020766920891773, train accuracy=0.7689811810512654: 100%|███| 25/25 [00:00<00:00, 93.91it/s]
Epoch 12: val loss=0.004053608795931173, val accuracy=0.5755813953488372: 100%|█████████| 3/3 [00:00<00:00, 214.83it/s]
Epoch 13: train loss=0.0026703186086599585, train accuracy=0.7715768981181051: 100%|███| 25/25 [00:00<00:00, 98.16it/s]
Epoch 13: val loss=0.004066407593876817, val accuracy=0.5930232558139535: 100%|█████████| 3/3 [00:00<00:00, 239.33it/s]
Epoch 14: train loss=0.0026383090046301514, train accuracy=0.7793640493186242: 100%|███| 25/25 [00:00<00:00, 97.46it/s]
Epoch 14: val loss=0.004054821854413942, val accuracy=0.5930232558139535: 100%|█████████| 3/3 [00:00<00:00, 238.68it/s]
Epoch 15: train loss=0.0025841803480324073, train accuracy=0.7839065541855937: 100%|███| 25/25 [00:00<00:00, 96.13it/s]
Epoch 15: val loss=0.003971133121224337, val accuracy=0.5988372093023255: 100%|█████████| 3/3 [00:00<00:00, 239.04it/s]
Epoch 16: train loss=0.00256364424649497

Epoch 46: train loss=0.0007273377554514128, train accuracy=0.9753406878650227: 100%|███| 25/25 [00:00<00:00, 93.93it/s]
Epoch 46: val loss=0.0043620068666546845, val accuracy=0.6337209302325582: 100%|████████| 3/3 [00:00<00:00, 259.17it/s]
Epoch 47: train loss=0.0007506796001111587, train accuracy=0.9779364049318624: 100%|███| 25/25 [00:00<00:00, 93.18it/s]
Epoch 47: val loss=0.004500604386246482, val accuracy=0.6569767441860465: 100%|█████████| 3/3 [00:00<00:00, 229.44it/s]
Epoch 48: train loss=0.0007542728214953179, train accuracy=0.9811810512654121: 100%|███| 25/25 [00:00<00:00, 98.82it/s]
Epoch 48: val loss=0.0044765278350475226, val accuracy=0.6104651162790697: 100%|████████| 3/3 [00:00<00:00, 221.80it/s]
Epoch 49: train loss=0.0006912186388881232, train accuracy=0.9798831927319922: 100%|██| 25/25 [00:00<00:00, 103.08it/s]
Epoch 49: val loss=0.0045532015000664915, val accuracy=0.627906976744186: 100%|█████████| 3/3 [00:00<00:00, 222.58it/s]
Epoch 50: train loss=0.00077547866181457

Epoch 80: train loss=0.0004506467803490773, train accuracy=0.9935107073329007: 100%|██| 25/25 [00:00<00:00, 101.12it/s]
Epoch 80: val loss=0.004507343349761741, val accuracy=0.6395348837209303: 100%|█████████| 3/3 [00:00<00:00, 212.48it/s]
Epoch 81: train loss=0.00025686196496953136, train accuracy=0.9954574951330305: 100%|█| 25/25 [00:00<00:00, 101.00it/s]
Epoch 81: val loss=0.0045166178498157235, val accuracy=0.6569767441860465: 100%|████████| 3/3 [00:00<00:00, 221.42it/s]
Epoch 82: train loss=0.00021564082673458912, train accuracy=0.9987021414665801: 100%|██| 25/25 [00:00<00:00, 98.79it/s]
Epoch 82: val loss=0.004567274693832841, val accuracy=0.6511627906976745: 100%|█████████| 3/3 [00:00<00:00, 239.42it/s]
Epoch 83: train loss=0.00032871180636437195, train accuracy=0.9967553536664504: 100%|██| 25/25 [00:00<00:00, 98.23it/s]
Epoch 83: val loss=0.004804154814675797, val accuracy=0.627906976744186: 100%|██████████| 3/3 [00:00<00:00, 248.67it/s]
Epoch 84: train loss=0.00020278694100876

New model saved, val loss inf -> 0.004363667653050534


Epoch 1: train loss=0.003816731008208471, train accuracy=0.6048020765736535: 100%|█████| 25/25 [00:00<00:00, 93.25it/s]
Epoch 1: val loss=0.004434009017639382, val accuracy=0.4883720930232558: 100%|██████████| 3/3 [00:00<00:00, 187.14it/s]
Epoch 2: train loss=0.0036871449890111966, train accuracy=0.6054510058403634: 100%|████| 25/25 [00:00<00:00, 85.13it/s]
Epoch 2: val loss=0.0043147989483766774, val accuracy=0.5465116279069767: 100%|█████████| 3/3 [00:00<00:00, 236.69it/s]


New model saved, val loss 0.004363667653050534 -> 0.0043147989483766774


Epoch 3: train loss=0.0036833712916959035, train accuracy=0.6145360155743024: 100%|████| 25/25 [00:00<00:00, 99.44it/s]
Epoch 3: val loss=0.004250412763551224, val accuracy=0.5755813953488372: 100%|██████████| 3/3 [00:00<00:00, 239.55it/s]


New model saved, val loss 0.0043147989483766774 -> 0.004250412763551224


Epoch 4: train loss=0.003686450832776633, train accuracy=0.6242699545749514: 100%|████| 25/25 [00:00<00:00, 103.51it/s]
Epoch 4: val loss=0.0042106672081836435, val accuracy=0.5813953488372093: 100%|█████████| 3/3 [00:00<00:00, 202.19it/s]


New model saved, val loss 0.004250412763551224 -> 0.0042106672081836435


Epoch 5: train loss=0.003741206247105682, train accuracy=0.618429591174562: 100%|█████| 25/25 [00:00<00:00, 102.89it/s]
Epoch 5: val loss=0.004147362397160641, val accuracy=0.6046511627906976: 100%|██████████| 3/3 [00:00<00:00, 239.02it/s]


New model saved, val loss 0.0042106672081836435 -> 0.004147362397160641


Epoch 6: train loss=0.00377381614746017, train accuracy=0.6112913692407528: 100%|█████| 25/25 [00:00<00:00, 104.47it/s]
Epoch 6: val loss=0.00413021401957024, val accuracy=0.5872093023255814: 100%|███████████| 3/3 [00:00<00:00, 239.26it/s]


New model saved, val loss 0.004147362397160641 -> 0.00413021401957024


Epoch 7: train loss=0.003681831656532733, train accuracy=0.6255678131083712: 100%|████| 25/25 [00:00<00:00, 104.85it/s]
Epoch 7: val loss=0.004214010006466577, val accuracy=0.5755813953488372: 100%|██████████| 3/3 [00:00<00:00, 229.33it/s]
Epoch 8: train loss=0.003648810257855984, train accuracy=0.6262167423750811: 100%|████| 25/25 [00:00<00:00, 103.64it/s]
Epoch 8: val loss=0.004156720586294352, val accuracy=0.5872093023255814: 100%|██████████| 3/3 [00:00<00:00, 228.78it/s]
Epoch 9: train loss=0.003660760863426055, train accuracy=0.6223231667748216: 100%|████| 25/25 [00:00<00:00, 106.05it/s]
Epoch 9: val loss=0.004163230401138926, val accuracy=0.5813953488372093: 100%|██████████| 3/3 [00:00<00:00, 206.34it/s]
Epoch 10: train loss=0.003632847033561939, train accuracy=0.6190785204412719: 100%|███| 25/25 [00:00<00:00, 106.32it/s]
Epoch 10: val loss=0.0041625070883784185, val accuracy=0.5755813953488372: 100%|████████| 3/3 [00:00<00:00, 248.84it/s]
Epoch 11: train loss=0.00365583470972384

New model saved, val loss 0.00413021401957024 -> 0.004105454267457474


Epoch 20: train loss=0.0036980367418545396, train accuracy=0.6353017521090201: 100%|██| 25/25 [00:00<00:00, 101.65it/s]
Epoch 20: val loss=0.004099061929209288, val accuracy=0.5697674418604651: 100%|█████████| 3/3 [00:00<00:00, 239.20it/s]


New model saved, val loss 0.004105454267457474 -> 0.004099061929209288


Epoch 21: train loss=0.0036704860903240505, train accuracy=0.6288124594419209: 100%|██| 25/25 [00:00<00:00, 101.66it/s]
Epoch 21: val loss=0.004152823759372844, val accuracy=0.5930232558139535: 100%|█████████| 3/3 [00:00<00:00, 242.66it/s]
Epoch 22: train loss=0.003638769679453217, train accuracy=0.6353017521090201: 100%|███| 25/25 [00:00<00:00, 100.66it/s]
Epoch 22: val loss=0.00418674893850504, val accuracy=0.5930232558139535: 100%|██████████| 3/3 [00:00<00:00, 239.37it/s]
Epoch 23: train loss=0.003677395152473202, train accuracy=0.6255678131083712: 100%|████| 25/25 [00:00<00:00, 94.32it/s]
Epoch 23: val loss=0.004162428597378177, val accuracy=0.5930232558139535: 100%|█████████| 3/3 [00:00<00:00, 229.47it/s]
Epoch 24: train loss=0.00372271457328029, train accuracy=0.6301103179753407: 100%|█████| 25/25 [00:00<00:00, 95.10it/s]
Epoch 24: val loss=0.004169519210970679, val accuracy=0.5930232558139535: 100%|█████████| 3/3 [00:00<00:00, 206.36it/s]
Epoch 25: train loss=0.00365367120809944

New model saved, val loss 0.004099061929209288 -> 0.004084218293428421


Epoch 45: train loss=0.003574960644488053, train accuracy=0.6398442569759896: 100%|███| 25/25 [00:00<00:00, 105.67it/s]
Epoch 45: val loss=0.004233074569424918, val accuracy=0.5930232558139535: 100%|█████████| 3/3 [00:00<00:00, 238.29it/s]
Epoch 46: train loss=0.003663880021675773, train accuracy=0.6359506813757301: 100%|███| 25/25 [00:00<00:00, 106.30it/s]
Epoch 46: val loss=0.004182740786047869, val accuracy=0.5813953488372093: 100%|█████████| 3/3 [00:00<00:00, 221.32it/s]
Epoch 47: train loss=0.0036104722679783045, train accuracy=0.6340038935756003: 100%|██| 25/25 [00:00<00:00, 106.58it/s]
Epoch 47: val loss=0.004189945584119752, val accuracy=0.5813953488372093: 100%|█████████| 3/3 [00:00<00:00, 259.83it/s]
Epoch 48: train loss=0.0035834780650661796, train accuracy=0.6443867618429591: 100%|██| 25/25 [00:00<00:00, 106.32it/s]
Epoch 48: val loss=0.004074531672305839, val accuracy=0.5930232558139535: 100%|█████████| 3/3 [00:00<00:00, 265.79it/s]


New model saved, val loss 0.004084218293428421 -> 0.004074531672305839


Epoch 49: train loss=0.003596805253638454, train accuracy=0.6417910447761194: 100%|███| 25/25 [00:00<00:00, 102.80it/s]
Epoch 49: val loss=0.004090035724085431, val accuracy=0.5813953488372093: 100%|█████████| 3/3 [00:00<00:00, 238.80it/s]
Epoch 50: train loss=0.0035973504272704773, train accuracy=0.6320571057754705: 100%|██| 25/25 [00:00<00:00, 106.38it/s]
Epoch 50: val loss=0.0041561210744602735, val accuracy=0.5813953488372093: 100%|████████| 3/3 [00:00<00:00, 239.01it/s]
Epoch 51: train loss=0.0036263509685539874, train accuracy=0.6411421155094095: 100%|██| 25/25 [00:00<00:00, 105.39it/s]
Epoch 51: val loss=0.004095247145309005, val accuracy=0.5697674418604651: 100%|█████████| 3/3 [00:00<00:00, 241.18it/s]
Epoch 52: train loss=0.003573281067984975, train accuracy=0.6430889033095393: 100%|███| 25/25 [00:00<00:00, 105.51it/s]
Epoch 52: val loss=0.004109212652195332, val accuracy=0.6046511627906976: 100%|█████████| 3/3 [00:00<00:00, 248.27it/s]
Epoch 53: train loss=0.00365032421003139

New model saved, val loss 0.004074531672305839 -> 0.004006221110737601


Epoch 54: train loss=0.003604464875963274, train accuracy=0.6346528228423102: 100%|███| 25/25 [00:00<00:00, 106.25it/s]
Epoch 54: val loss=0.004177480416242467, val accuracy=0.5813953488372093: 100%|█████████| 3/3 [00:00<00:00, 239.03it/s]
Epoch 55: train loss=0.0036314222756180957, train accuracy=0.6314081765087606: 100%|██| 25/25 [00:00<00:00, 105.76it/s]
Epoch 55: val loss=0.004127488839764928, val accuracy=0.5697674418604651: 100%|█████████| 3/3 [00:00<00:00, 195.92it/s]
Epoch 56: train loss=0.0035856849357418083, train accuracy=0.6320571057754705: 100%|██| 25/25 [00:00<00:00, 107.13it/s]
Epoch 56: val loss=0.004101856503375741, val accuracy=0.5813953488372093: 100%|█████████| 3/3 [00:00<00:00, 259.60it/s]
Epoch 57: train loss=0.003661104934235659, train accuracy=0.63659961064244: 100%|█████| 25/25 [00:00<00:00, 105.16it/s]
Epoch 57: val loss=0.004112706974495289, val accuracy=0.5813953488372093: 100%|█████████| 3/3 [00:00<00:00, 229.34it/s]
Epoch 58: train loss=0.00356331770293206

Epoch 88: train loss=0.0035555663722277773, train accuracy=0.6378974691758599: 100%|██| 25/25 [00:00<00:00, 102.45it/s]
Epoch 88: val loss=0.004096621864063795, val accuracy=0.5697674418604651: 100%|█████████| 3/3 [00:00<00:00, 260.09it/s]
Epoch 89: train loss=0.003571536230004352, train accuracy=0.6437378325762492: 100%|████| 25/25 [00:00<00:00, 93.46it/s]
Epoch 89: val loss=0.004142723890931108, val accuracy=0.5872093023255814: 100%|█████████| 3/3 [00:00<00:00, 245.90it/s]
Epoch 90: train loss=0.0035631358430108645, train accuracy=0.6359506813757301: 100%|███| 25/25 [00:00<00:00, 93.57it/s]
Epoch 90: val loss=0.0041223127654818606, val accuracy=0.5813953488372093: 100%|████████| 3/3 [00:00<00:00, 219.33it/s]
Epoch 91: train loss=0.003626017331458159, train accuracy=0.6327060350421804: 100%|████| 25/25 [00:00<00:00, 96.94it/s]
Epoch 91: val loss=0.004050551978654639, val accuracy=0.5581395348837209: 100%|█████████| 3/3 [00:00<00:00, 239.48it/s]
Epoch 92: train loss=0.00368884864331529
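
The `New model saved, val loss ... -> ...` lines in the logs above suggest that the training loop keeps the weights with the lowest validation loss seen so far. The following is only a minimal sketch of that pattern, not the notebook's actual training code: `BestCheckpoint` and the `{'epoch': ...}` state are hypothetical stand-ins (in practice the state would be something like the model's `state_dict()`), and the five losses are the epoch 0 to 4 validation losses of one of the runs logged above.

```python
from copy import deepcopy
import math


class BestCheckpoint:
    """Hypothetical helper: keep a copy of the state with the lowest validation loss."""

    def __init__(self):
        self.best_loss = math.inf
        self.best_state = None

    def update(self, val_loss, state):
        """Store a copy of `state` and return True if `val_loss` is a new best."""
        if val_loss < self.best_loss:
            print(f'New model saved, val loss {self.best_loss} -> {val_loss}')
            self.best_loss = val_loss
            # deepcopy so that later training steps cannot mutate the saved state
            self.best_state = deepcopy(state)
            return True
        return False


# Validation losses for epochs 0-4, copied from one of the runs in the log
val_losses = [
    0.004259580529706423,
    0.004185498887023261,
    0.004182288379863251,
    0.004101254998944526,
    0.004130225282075794,
]

tracker = BestCheckpoint()
for epoch, loss in enumerate(val_losses):
    # Placeholder state; a real loop would pass the model weights here
    tracker.update(loss, {'epoch': epoch})
```

Epoch 4 does not improve on epoch 3, so no checkpoint is written for it, which matches the absence of a `New model saved` line after epoch 4 in the log.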

In [33]:
for ds_name in mlp_models:
    # Evaluate the best saved MLP of each dataset on its test split
    best_model = mlp_models[ds_name]['model']
    test_dl = dataloaders[ds_name]['test']
    acc, confusion = evaluate(best_model, test_dl)
    # Keep the metrics next to the model for the comparison section
    mlp_models[ds_name]['accuracy'] = acc
    mlp_models[ds_name]['confusion'] = confusion
    print(f'MLP test accuracy on {ds_name} dataset: {acc:.4f}')

MLP test accuracy on clinical dataset: 0.7173
MLP test accuracy on gene_expr dataset: 0.5969
MLP test accuracy on gene_expr_reduced dataset: 0.6440


In [40]:
confusion = np.asarray(mlp_models['clinical']['confusion'])
print(confusion)
# Flattened integer counts, row by row
labels = [int(v) for v in confusion.ravel()]
print(labels)
# `annot` must have the same shape as `data`, so reshape the flat list to 2x2
sns.heatmap(confusion, annot=np.array(labels).reshape(confusion.shape), fmt='d')
plt.show()

[[91. 37.]
 [17. 46.]]
[91, 37, 17, 46]
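
As a sanity check, the reported clinical test accuracy can be recovered from the confusion matrix printed above: the correctly classified samples sit on the diagonal. The sketch below uses only those four counts. The row-wise ratio is included as an illustration; whether rows correspond to true labels or to predictions depends on how `evaluate` builds the matrix, which is an assumption not confirmed here.

```python
import numpy as np

# Confusion matrix of the clinical MLP, copied from the output above
confusion = np.array([[91., 37.],
                      [17., 46.]])

# Accuracy = correctly classified samples (diagonal) / all samples
accuracy = np.trace(confusion) / confusion.sum()

# Share of each row that lies on the diagonal; this is per-class recall
# only if rows are the true labels (an assumption, see the note above)
row_share = np.diag(confusion) / confusion.sum(axis=1)

print(f'accuracy = {accuracy:.4f}')
print('row-wise correct share =', np.round(row_share, 4))
```

The diagonal holds 137 of the 191 test samples, which agrees with the 0.7173 printed by the evaluation cell.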

# Model Comparison (10 Points)

Compare the different models across the three datasets (clinical, gene expression, and reduced gene expression) and try to explain the differences in their performance.

#### \# TODO
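
One possible starting point for the comparison is a small table with one row per dataset and one column per model. The sketch below fills in only the MLP test accuracies printed earlier in this section; the name `comparison` is hypothetical, and columns for the other models would have to be added from their own evaluation results, which are not reproduced here.

```python
import pandas as pd

# MLP test accuracies copied from the evaluation output above
mlp_test_accuracy = {
    'clinical': 0.7173,
    'gene_expr': 0.5969,
    'gene_expr_reduced': 0.6440,
}

# One column per model; further models can be added as extra columns
comparison = pd.DataFrame({'MLP': mlp_test_accuracy})
comparison.index.name = 'dataset'

# Rank the datasets by MLP accuracy and pick the best one
ranked = comparison.sort_values('MLP', ascending=False)
best_dataset = comparison['MLP'].idxmax()

print(ranked)
print(f'best dataset for the MLP: {best_dataset}')
```

For the MLP, the clinical features give the highest test accuracy, and the reduced gene expression features do better than the full gene expression features.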