## Install wandb (if it is not already installed) and log in

In [None]:
# Only needs to be run once per machine.
# Lines starting with `!` are executed as shell commands by IPython/Jupyter.
# Uncomment the next line to install wandb if it is missing; the last line
# authenticates this machine with wandb.
#! pip install wandb
! wandb login

In [None]:
import os
import copy
import pandas as pd
import numpy as np

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torchvision.io import read_image
from torchsummary import summary

from tqdm import tqdm
from datetime import datetime
from sklearn.model_selection import train_test_split

import wandb

# Train on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Hyper-parameters recorded with the wandb run. This is a small example; extend
# the dict when doing hyper-parameter tuning.
wandb_config = {
    "learning_rate": 0.001,
    "epochs": 10,
    "loss_function": "BCEWithLogitsLoss",
}

# Passing the dict via `config=` is how wandb records it (and exposes it as
# wandb.config for this run); overwriting the library's wandb.config attribute
# by hand beforehand is unnecessary.
wandb.init(project="breast-histopathology-classification",
           entity="ece9603_project",
           job_type="model_training",
           config=wandb_config)

In [None]:
class CustomImageDataset(Dataset):
    """Image/label dataset backed by a DataFrame.

    The DataFrame must have an ``imgPath`` column (path of each image file) and
    a ``class`` column (its label). Each item is read from disk as a 3-channel
    RGB image, converted to a float tensor with pixel values scaled to [0, 1],
    and then passed through ``transform`` if one was given.
    """

    def __init__(self, df, transform=None):
        self.info = df
        self.transform = transform

    def __len__(self):
        return len(self.info)

    def __getitem__(self, idx):
        path = self.info.imgPath.values[idx]
        label = self.info['class'].values[idx]
        # read_image yields uint8 pixels in [0, 255]. Scale to [0, 1]: the
        # ImageNet mean/std used for Normalize are defined on that range, so
        # unscaled pixels would come out of Normalize hundreds of times too large.
        image = read_image(path, mode=torchvision.io.image.ImageReadMode.RGB).float() / 255.0

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
def _patient_subset(frame, patients, count_label):
    """Return the rows of ``frame`` whose patient is in ``patients``.

    Also prints a preview of the subset and ``count_label`` followed by the
    number of distinct patients it contains.
    """
    subset = frame.loc[frame['patient'].isin(patients)]
    print(subset.head())
    print(count_label, subset.patient.nunique())
    return subset


df = pd.read_csv("breastCancerDataframe.csv", index_col=0)
print(df.head())

patientIDs = df.patient.unique()
print("Number of Unique Patients: ", len(patientIDs))

# Split on patient IDs rather than on rows, so no patient's images appear in
# more than one split: 70% of patients for training, and the remaining 30%
# halved into validation and test (15% each).
patients_train, rest = train_test_split(patientIDs, test_size=0.3, random_state=42)
patients_val, patients_test = train_test_split(rest, test_size=0.5, random_state=42)

df_train = _patient_subset(df, patients_train, "Number of Train Patients: ")
df_val = _patient_subset(df, patients_val, "Number of Validation Patients: ")
df_test = _patient_subset(df, patients_test, "Number of Test Patients: ")

In [None]:
# ImageNet per-channel mean/std, the statistics torchvision's pretrained weights
# were trained with. They are defined for pixel values scaled to [0, 1].
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize to the 224x224 model input and apply random flips
# as augmentation.
transform = transforms.Compose([
        #transforms.RandomRotation(45),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

# Evaluation pipeline: same resize and normalization but no random augmentation,
# so validation/test metrics do not change from one pass to the next.
eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

train_dataset = CustomImageDataset(df_train, transform=transform)
val_dataset = CustomImageDataset(df_val, transform=eval_transform)
test_dataset = CustomImageDataset(df_test, transform=eval_transform)

BATCH_SIZE = 128

# drop_last=True also guards against a final batch of one, which the
# BatchNorm1d layers in the classifier head reject in train mode.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Evaluation results do not depend on order, so keep it fixed and reproducible.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [None]:
# Disabled alternative: a ResNet-18 backbone with a three-layer MLP head
# (512 -> 256 -> 1 logit) whose Linear layers get Xavier-uniform init.
# It is kept inside a string literal, so running this cell executes none of it.
# Unlike the EfficientNet cell below, it does not freeze the backbone, and the
# optimizer in train() would have to use model.fc.parameters() instead of
# model.classifier.parameters().
"""
model=models.resnet18(pretrained=True)
print(model)

model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.BatchNorm1d(512),
    nn.Dropout(0.5),

    nn.Linear(512, 256),
    nn.ReLU(),
    nn.BatchNorm1d(256),
    nn.Dropout(0.5),

    nn.Linear(256, 1))

def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

model.apply(init_weights)
"""

In [None]:
# Pretrained EfficientNet-B0 used as a frozen feature extractor with a new
# binary-classification head.
model = models.efficientnet_b0(pretrained=True)

# Freeze every pretrained parameter. The replacement head below is created
# afterwards, so its parameters stay trainable.
for param in model.parameters():
    param.requires_grad = False

# Read the feature width from the stock head (Dropout, Linear) instead of
# hard-coding it (1280 for B0), so the new head still fits if the backbone
# variant is changed.
num_features = model.classifier[1].in_features

model.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(num_features, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),

    nn.Dropout(0.2),
    nn.Linear(512, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),

    # A single output logit per image, consumed by BCEWithLogitsLoss in train().
    nn.Linear(256, 1))

def init_weights(m):
    """Initialize an ``nn.Linear`` layer in place; leave any other module untouched.

    Intended for ``Module.apply``. The weight matrix gets Xavier-uniform values
    and the bias, when the layer has one, is set to 0.01.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # Linear(bias=False) has bias None; writing to it would raise.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

# Run init_weights on every submodule: each nn.Linear gets Xavier-uniform
# weights and 0.01 biases.
# NOTE(review): assumes the EfficientNet backbone itself contains no nn.Linear,
# so only the new classifier head is re-initialized — confirm.
model.apply(init_weights)
# Uncomment to print a per-layer summary (torchsummary) for 3x224x224 inputs.
#summary(model, (3,224,224))

In [None]:
def _safe_ratio(num, den):
    """Return ``num / den`` as a float, or nan when ``den`` is 0 (metric undefined)."""
    return num / den if den else float("nan")


def compute_performance(yhat, y, pos_cutoff, log_to_wandb=True):
    """Compute, print and (optionally) log binary-classification metrics.

    Args:
        yhat: Predicted positive-class scores, one per sample (tensor, array or
            sequence). Scores ``>= pos_cutoff`` count as positive predictions,
            scores ``< pos_cutoff`` as negative.
        y: Ground-truth labels with the same number of elements; 1 marks a
            positive sample and 0 a negative one.
        pos_cutoff: Decision threshold applied to ``yhat``.
        log_to_wandb: When True (default), log the metrics to the active wandb
            run with ``commit=False``, so they are attached to the next
            committed step.

    Returns:
        dict with "Balanced Accuracy", "Specificity", "Sensitivity",
        "Precision" and "F1". A metric whose denominator is zero is nan
        instead of raising.
    """
    yhat = torch.as_tensor(yhat).detach().cpu().flatten()
    y = torch.as_tensor(y).detach().cpu().flatten()

    pred_pos = yhat >= pos_cutoff
    pred_neg = yhat < pos_cutoff
    is_pos = y == 1
    is_neg = y == 0

    # Count with torch reductions (int64) instead of Python's sum() over a
    # numpy array: faster, and the counts cannot wrap around.
    tp = int((pred_pos & is_pos).sum())
    tn = int((pred_neg & is_neg).sum())
    fp = int((pred_pos & is_neg).sum())
    fn = int((pred_neg & is_pos).sum())

    print(f"tp: {tp} tn: {tn} fp: {fp} fn: {fn}")

    # Precision: of the samples labelled +, how many are actually +?
    precision = _safe_ratio(tp, tp + fp)
    # Sensitivity (= recall): of all the + in the data, how many are labelled +?
    sensitivity = _safe_ratio(tp, tp + fn)
    # Specificity: of all the - in the data, how many are labelled -?
    specificity = _safe_ratio(tn, fp + tn)

    balanced_acc = 0.5 * (sensitivity + specificity)
    # F1 written in counts form: it is 0 (not undefined) when tp == 0 but some
    # errors were made.
    f1 = _safe_ratio(2 * tp, 2 * tp + fp + fn)

    print("Balanced Accuracy: ", balanced_acc, " Specificity: ", specificity,
          " Sensitivity: ", sensitivity, " Precision: ", precision)

    if log_to_wandb:
        wandb.log({"Balanced Accuracy": balanced_acc,
                   "Specificity": specificity,
                   "Sensitivity": sensitivity,
                   "Precision": precision},
                  commit=False)

    return {"Balanced Accuracy": balanced_acc,
            "Specificity": specificity,
            "Sensitivity": sensitivity,
            "Precision": precision,
            "F1": f1}

In [None]:
def _train_one_epoch(model, loader, criterion, opt, device):
    """Run one optimization pass over ``loader``; return the mean per-sample loss (nan if empty)."""
    model.train()
    loss_sum = torch.zeros((), device=device)
    n_seen = 0
    for inputs, targets in tqdm(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        model.zero_grad(set_to_none=True)
        # Call the module rather than .forward() so registered hooks (e.g. from
        # wandb.watch) run. view(-1) keeps a (batch,) shape even for a batch of
        # one, where squeeze() would produce a 0-d tensor.
        logits = model(inputs).view(-1)
        loss = criterion(logits, targets.float())
        loss.backward()
        opt.step()

        # criterion returns the batch mean; weight by batch size so the epoch
        # value is a true per-sample mean. detach() keeps the running total from
        # holding on to each batch's autograd graph.
        batch_n = targets.size(0)
        loss_sum += loss.detach() * batch_n
        n_seen += batch_n
    return loss_sum.item() / n_seen if n_seen else float("nan")


def _validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean per-sample loss, sigmoid probabilities, float targets). The two
        tensors are 1-D and on the CPU; the loss is nan for an empty loader.
    """
    model.eval()
    loss_sum = torch.zeros((), device=device)
    n_seen = 0
    logit_chunks, target_chunks = [], []
    with torch.no_grad():
        for inputs, targets in tqdm(loader):
            inputs = inputs.to(device)
            targets = targets.to(device).float()

            logits = model(inputs).view(-1)
            batch_n = targets.size(0)
            loss_sum += criterion(logits, targets) * batch_n
            n_seen += batch_n

            # Gather whole batches and concatenate once, instead of extending a
            # Python list element by element.
            logit_chunks.append(logits.cpu())
            target_chunks.append(targets.cpu())

    if not n_seen:
        return float("nan"), torch.empty(0), torch.empty(0)
    probs = torch.sigmoid(torch.cat(logit_chunks))
    return loss_sum.item() / n_seen, probs, torch.cat(target_chunks)


def train(model, device='cpu', epochs=10, early_stop=2, lr=0.001, verbose=True,
          train_loader=None, val_loader=None):
    """Train ``model.classifier`` with Adam and early stopping on validation loss.

    After every epoch the mean training/validation losses are printed (when
    ``verbose``) and logged to wandb, together with the validation metrics from
    ``compute_performance`` at a 0.5 probability cut-off.

    Args:
        model: Network with a ``classifier`` sub-module; it must output one
            logit per sample. Only ``model.classifier`` parameters are optimized.
        device: Device to train on.
        epochs: Maximum number of epochs.
        early_stop: Stop after this many consecutive epochs without a new
            lowest validation loss.
        lr: Adam learning rate.
        verbose: Print per-epoch progress and a final summary.
        train_loader: DataLoader for training; defaults to the module-level
            ``train_dataloader``.
        val_loader: DataLoader for validation; defaults to the module-level
            ``val_dataloader``.

    Returns:
        (model with the lowest-validation-loss weights loaded, list of
        per-epoch training losses, list of per-epoch validation losses).
        Losses are per-sample means.
    """
    if train_loader is None:
        train_loader = train_dataloader
    if val_loader is None:
        val_loader = val_dataloader

    # For the ResNet variant, optimize model.fc.parameters() instead.
    opt = torch.optim.Adam(model.classifier.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    model.to(device)
    # wandb documents watch() as something to call before training so it can
    # record gradients as they are computed; calling it after the loop (as
    # before) left it no training steps to observe.
    wandb.watch(model)

    lowest_val_loss = np.inf
    lowest_val_epoch = 0
    epochs_wo_improvement = 0
    best_model = copy.deepcopy(model.state_dict())
    train_losses, val_losses = [], []

    for e in range(epochs):
        epoch_train_loss = _train_one_epoch(model, train_loader, criterion, opt, device)
        train_losses.append(epoch_train_loss)

        epoch_val_loss, val_probs, val_targets = _validate(model, val_loader, criterion, device)
        val_losses.append(epoch_val_loss)

        # Logs metrics with commit=False; the wandb.log below commits the step.
        compute_performance(val_probs, val_targets, 0.5)

        # <= lets a later epoch with an equal loss replace the saved weights.
        if epoch_val_loss <= lowest_val_loss:
            best_model = copy.deepcopy(model.state_dict())
            lowest_val_loss = epoch_val_loss
            lowest_val_epoch = e
            epochs_wo_improvement = 0
        else:
            epochs_wo_improvement += 1

        if verbose:
            print("Epoch: {}/{}...".format(e, epochs), "Loss: {:.4f}...".format(epoch_train_loss), "Val Loss: {:.4f}".format(epoch_val_loss),)
        wandb.log({"training_loss": epoch_train_loss,
                   "validation_loss": epoch_val_loss})

        if epochs_wo_improvement >= early_stop:
            if verbose:
                print("Early Stop no improvement in validation loss in "+str(early_stop)+" validation steps")
            break

    if verbose:
        print("\nLowest Validation Loss: "+str(lowest_val_loss)+" at epoch "+str(lowest_val_epoch)+'\n')

    model.load_state_dict(best_model)
    return model, train_losses, val_losses

# Every input is resized to 224x224, so let cuDNN benchmark and reuse the
# fastest convolution algorithms for that fixed shape.
torch.backends.cudnn.benchmark = True
# Take epochs and learning rate from the config sent to wandb, so the recorded
# hyper-parameters are the ones actually used.
model, train_losses, val_losses = train(model, device=device,
                                        epochs=wandb_config["epochs"],
                                        lr=wandb_config["learning_rate"])
# Done with this training run.
wandb.finish()

In [None]:
# Name the checkpoint after the epoch whose weights train() restored. train()
# accepts ties with <=, so that is the last epoch reaching the lowest
# validation loss.
best_epoch = len(val_losses) - 1 - val_losses[::-1].index(min(val_losses))
checkpoint_dir = "./BestModels"
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing directories
checkpoint_path = os.path.join(
    checkpoint_dir,
    f"E_{best_epoch}_TL_{train_losses[best_epoch]:.4f}_VL_{val_losses[best_epoch]:.4f}.pt")
torch.save({'model_state_dict': model.state_dict()}, checkpoint_path)