In [None]:
import copy
import pandas as pd
import numpy as np
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torchvision.io import read_image

from tqdm import tqdm
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

import wandb
from ece9603_project import lossFunctions

# Use the GPU when PyTorch reports one as available, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Run once to set up the connection to wandb.
# login() is called without arguments, so no API key is embedded in this notebook.
wandb.login()

In [None]:
# Sweep definition: a grid over one learning rate, one epoch budget and every
# candidate loss function, so each run differs only in the loss being used.
_loss_function_choices = [
    'dice_loss',
    'bce_dice_loss',
    'jaccard_iou_loss',
    'focal_loss',
    'tversky_loss',
    'focal_tversky_loss',
    'bce_with_logits_loss',
]

sweep_config = dict(
    name='loss_function_evaluation',
    method='grid',
    parameters=dict(
        learning_rate=dict(values=[0.001]),
        epochs=dict(values=[10]),
        loss_function=dict(values=_loss_function_choices),
    ),
)

# Register the sweep with wandb; the returned id is what the agent is started with.
sweep_id = wandb.sweep(
    sweep_config,
    project="breast-histopathology-classification",
    entity="ece9603_project",
)

In [None]:
class CustomImageDataset(Dataset):
    """Map-style dataset that loads images from paths listed in a dataframe.

    The dataframe is expected to provide an ``imgPath`` column (image file
    path) and a ``class`` column (label); rows are addressed by position.

    Args:
        df: table describing the images.
        transform: optional callable applied to each decoded image tensor.
    """

    def __init__(self, df, transform=None):
        self.info, self.transform = df, transform

    def __len__(self):
        """Number of rows (images) in the underlying table."""
        return len(self.info)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``."""
        image_path = self.info.imgPath.values[idx]
        target = self.info['class'].values[idx]

        # Decode as 3-channel RGB and cast to float; pixel values are NOT
        # rescaled here, the cast alone keeps them in the decoded range.
        rgb_mode = torchvision.io.image.ImageReadMode.RGB
        picture = read_image(image_path, mode=rgb_mode).float()

        if not self.transform:
            return picture, target
        return self.transform(picture), target

In [None]:
# Load the per-image metadata table; the first CSV column is used as the index.
df = pd.read_csv("breastCancerDataframe.csv", index_col=0)
print(df.head())

patientIDs = df.patient.unique()
print("Number of Unique Patients: ", len(patientIDs))

# Split on patient IDs rather than on images, so every image of a given patient
# lands in exactly one subset: 70% of patients for training, and the remaining
# 30% split evenly between validation and test.
patients_train, temp = train_test_split(patientIDs, random_state=42, test_size=0.3)
patients_val, patients_test = train_test_split(temp, random_state=42, test_size=0.5)

df_train = df[df['patient'].isin(patients_train)]
df_val = df[df['patient'].isin(patients_val)]
df_test = df[df['patient'].isin(patients_test)]

# Report a preview and the patient count of each subset.
for _caption, _subset in (("Number of Train Patients: ", df_train),
                          ("Number of Validation Patients: ", df_val),
                          ("Number of Test Patients: ", df_test)):
    print(_subset.head())
    print(_caption, _subset.patient.nunique())

In [None]:
BATCH_SIZE = 128

# CustomImageDataset returns `read_image(...).float()`, i.e. pixel values in
# [0, 255], whereas the ImageNet statistics are defined for images in [0, 1].
# Scaling mean and std by 255 is equivalent to dividing the image by 255 first
# and then normalizing with the usual statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
_normalize = transforms.Normalize([255.0 * m for m in IMAGENET_MEAN],
                                  [255.0 * s for s in IMAGENET_STD])

# Training pipeline: resize + random flips (augmentation) + normalization.
transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        _normalize])

# Evaluation pipeline: same preprocessing WITHOUT random augmentation, so that
# validation/test metrics are deterministic and comparable between runs.
eval_transform = transforms.Compose([
        transforms.Resize((224,224)),
        _normalize])

train_dataset = CustomImageDataset(df_train, transform=transform)
val_dataset = CustomImageDataset(df_val, transform=eval_transform)
test_dataset = CustomImageDataset(df_test, transform=eval_transform)

# Only the training loader needs shuffling; evaluation order does not affect
# the metrics, and a fixed order keeps the evaluated samples identical per run.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
# NOTE(review): drop_last=True discards the final partial test batch (up to
# BATCH_SIZE - 1 images). Kept as in the original; consider drop_last=False
# once the evaluation code is confirmed to handle a final batch of any size.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [None]:
# Calculate performance measures
def compute_performance(yhat, y, pos_cutoff, evaluation_phase='validation'):
    """Compute, print and log binary-classification metrics.

    Args:
        yhat: predicted scores/probabilities (tensor, array or sequence).
        y: ground-truth labels, 1 for positive and 0 for negative.
        pos_cutoff: a score ``>= pos_cutoff`` counts as a positive prediction.
        evaluation_phase: prefix used for the metric names logged to wandb
            (e.g. ``'training'``, ``'validation'``, ``'testing'``).

    Returns:
        dict with ``tp``, ``tn``, ``fp``, ``fn``, ``precision``, ``recall``,
        ``sensitivity``, ``specificity``, ``balanced_accuracy`` and ``auroc``.
        A metric whose denominator is zero (or AUROC when only one class is
        present in ``y``) is reported as NaN instead of raising.
    """
    # Work on flat CPU tensors so the counts do not depend on numpy/torch interop.
    scores = torch.as_tensor(yhat).detach().cpu().reshape(-1)
    labels = torch.as_tensor(y).detach().cpu().reshape(-1)

    predicted_pos = scores >= pos_cutoff
    predicted_neg = scores < pos_cutoff
    actual_pos = labels == 1
    actual_neg = labels == 0

    # First, get tp, tn, fp, fn
    tp = int((predicted_pos & actual_pos).sum())
    tn = int((predicted_neg & actual_neg).sum())
    fp = int((predicted_pos & actual_neg).sum())
    fn = int((predicted_neg & actual_pos).sum())

    print(f"tp: {tp} tn: {tn} fp: {fp} fn: {fn}")

    def _ratio(numerator, denominator):
        # An undefined ratio is reported as NaN rather than crashing the run.
        return numerator / denominator if denominator else float('nan')

    # Precision: "Of the ones I labeled +, how many are actually +?"
    precision = _ratio(tp, tp + fp)

    # Recall / sensitivity: "Of all the + in the data, how many do I correctly label?"
    recall = _ratio(tp, tp + fn)
    sensitivity = recall

    # Specificity: "Of all the - in the data, how many do I correctly label?"
    specificity = _ratio(tn, fp + tn)

    balanced_acc = 0.5*(sensitivity+specificity)

    # AUROC is undefined unless both classes are present in the labels.
    if bool(actual_pos.any()) and bool(actual_neg.any()):
        auroc = roc_auc_score(labels.numpy(), scores.numpy())
    else:
        auroc = float('nan')

    # Print results
    print("Balanced Accuracy: ", balanced_acc," Specificity: ",specificity, " AUROC Score: ", auroc,
          " Sensitivity: ", sensitivity," Precision: ", precision)

    # Log results to WandB; commit=False lets the caller attach these values to
    # the same step as its own subsequent wandb.log call.
    wandb.log({"{} Balanced Accuracy".format(evaluation_phase): balanced_acc,
               "{} Specificity".format(evaluation_phase): specificity,
               "{} Sensitivity".format(evaluation_phase): sensitivity,
               "{} Precision".format(evaluation_phase): precision,
               "{} AUROC Score".format(evaluation_phase): auroc},
              commit=False)

    return {"tp": tp, "tn": tn, "fp": fp, "fn": fn,
            "precision": precision, "recall": recall,
            "sensitivity": sensitivity, "specificity": specificity,
            "balanced_accuracy": balanced_acc, "auroc": auroc}

def init_model():
    """Build an ImageNet-pretrained EfficientNet-B0 with a new binary head.

    All pretrained parameters are frozen; only the replacement classifier
    (two hidden blocks followed by a single-logit output) is trainable.

    Returns:
        The model, with ``model.classifier`` freshly initialised.
    """
    # NOTE(review): newer torchvision releases deprecate `pretrained=True` in
    # favour of the `weights=` argument; kept as-is for compatibility with the
    # version this notebook was written against.
    model = models.efficientnet_b0(pretrained=True)

    # Freeze the backbone (this runs before the new head is attached, so the
    # head's parameters keep requires_grad=True).
    for param in model.parameters():
        param.requires_grad = False

    # Read the feature width from the stock classifier instead of hard-coding
    # it (1280 for EfficientNet-B0).
    in_features = next(layer.in_features for layer in model.classifier.modules()
                       if isinstance(layer, nn.Linear))

    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),

        nn.Dropout(0.2),
        nn.Linear(512, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),

        nn.Linear(256, 1))

    def init_weights(m):
        # Xavier-uniform weights and a small constant bias for linear layers.
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    # Initialise only the new head so pretrained backbone weights can never be
    # overwritten by the re-initialisation.
    model.classifier.apply(init_weights)

    return model

def train(model, dataloader_train, dataloader_val, device='cpu', epochs=10, early_stop=2, lr=0.001,
          loss_function='bce_with_logits_loss', verbose=True):
    """Train the classifier head with early stopping on the validation loss.

    Args:
        model: network whose ``classifier`` sub-module is optimised.
        dataloader_train: loader yielding ``(inputs, targets)`` training batches.
        dataloader_val: loader yielding ``(inputs, targets)`` validation batches.
        device: device the model, criterion and batches are moved to.
        epochs: maximum number of epochs.
        early_stop: stop after this many consecutive epochs without a
            validation-loss improvement.
        lr: Adam learning rate.
        loss_function: name passed to ``lossFunctions.getLossFunction``.
        verbose: print per-epoch and final summaries.

    Returns:
        ``model`` loaded with the weights of the epoch that had the lowest
        validation loss. Those weights are also saved under ``./BestModels/``.

    Reported losses are the mean of the per-batch criterion values.
    """
    import os  # local import: only needed to create the checkpoint directory

    opt = torch.optim.Adam(model.classifier.parameters(), lr=lr)
    criterion = lossFunctions.getLossFunction(loss_function)
    criterion.to(device)
    model.to(device)

    # Record model to wandb. This must happen before training: watch() hooks
    # into the backward pass, so calling it afterwards records nothing.
    wandb.watch(model)

    lowest_val_loss, train_loss = np.inf, 0
    lowest_val_epoch = 0
    epochs_wo_improvement = 0
    best_model = copy.deepcopy(model.state_dict())
    train_losses, val_losses=[], []

    for e in range(epochs):
        # TRAINING
        model.train()
        running_train_loss, n_train_batches = 0.0, 0
        # Reset every epoch so the training metrics describe this epoch only.
        train_preds, train_targets_list = [], []

        for inputs, targets in tqdm(dataloader_train):

            inputs, targets = inputs.to(device), targets.to(device)

            model.zero_grad(set_to_none=True)

            # reshape(-1) instead of squeeze(): keeps a 1-D tensor even when
            # the batch holds a single sample.
            train_output = model(inputs).reshape(-1)
            loss = criterion(train_output, targets.float())
            loss.backward()
            opt.step()

            # .item()/.detach() so no autograd graph is kept alive across batches.
            running_train_loss += loss.item()
            n_train_batches += 1
            train_preds.append(train_output.detach().cpu())
            train_targets_list.append(targets.detach().cpu())

        if train_preds:
            compute_performance(torch.sigmoid(torch.cat(train_preds)),
                                torch.cat(train_targets_list).float(), 0.5,
                                evaluation_phase='training')
        epoch_train_loss = running_train_loss/n_train_batches if n_train_batches else float('nan')

        train_losses.append(epoch_train_loss)

        #VALIDATION

        model.eval()
        model.zero_grad(set_to_none=True)
        running_val_loss, n_val_batches = 0.0, 0
        val_preds, val_targets_list = [], []

        with torch.no_grad():
            for val_inputs, val_targets in tqdm(dataloader_val):

                val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)

                val_output = model(val_inputs).reshape(-1)
                running_val_loss += criterion(val_output, val_targets.float()).item()
                n_val_batches += 1

                val_preds.append(val_output.cpu())
                val_targets_list.append(val_targets.cpu())

            epoch_val_loss = running_val_loss/n_val_batches if n_val_batches else float('nan')
            val_losses.append(epoch_val_loss)

            if val_preds:
                compute_performance(torch.sigmoid(torch.cat(val_preds)),
                                    torch.cat(val_targets_list).float(), 0.5,
                                    evaluation_phase='validation')

        if epoch_val_loss <= lowest_val_loss:
            best_model = copy.deepcopy(model.state_dict())
            lowest_val_loss = epoch_val_loss
            train_loss=epoch_train_loss
            lowest_val_epoch=e
            epochs_wo_improvement=0
        else:
            epochs_wo_improvement+=1

        if verbose:
            print("Epoch: {}/{}...".format(e, epochs), "Loss: {:.4f}...".format(epoch_train_loss), "Val Loss: {:.4f}".format(epoch_val_loss),)

        # Log to wandb project (this call also commits the metrics that
        # compute_performance logged with commit=False for this epoch).
        wandb.log({"training_loss": epoch_train_loss, "validation_loss": epoch_val_loss})

        #early stopping
        if epochs_wo_improvement>=early_stop:
            if verbose:
                print("Early Stop no improvement in validation loss in "+str(early_stop)+" validation steps")
            break

    if verbose:
        print("\nLowest Validation Loss: "+str(lowest_val_loss)+" at epoch "+str(lowest_val_epoch)+'\n')

    model.load_state_dict(best_model)

    # Make sure the checkpoint directory exists so a finished training run is
    # not lost to a missing-folder error at save time.
    checkpoint_dir = './BestModels/'
    os.makedirs(checkpoint_dir, exist_ok=True)

    run_ID = datetime.now().strftime("%Y-%m-%d_%H-%M")
    torch.save({'model_state_dict': best_model}, checkpoint_dir+str(run_ID)+'_'+loss_function+'_E_'+str(lowest_val_epoch)+'_TL_'+str(round(train_loss, 4))+'_VL_'+str(round(lowest_val_loss, 4))+'.pt')

    torch.cuda.empty_cache()

    return model

def test_model(model, test_dataloader, loss_function='bce_with_logits_loss', device='cuda'):
    """Evaluate ``model`` on the test set and log the results to wandb.

    Args:
        model: trained network to evaluate.
        test_dataloader: loader yielding ``(inputs, targets)`` test batches.
        loss_function: name passed to ``lossFunctions.getLossFunction``.
        device: device the model, criterion and batches are moved to.

    Returns:
        The test loss, i.e. the mean of the per-batch criterion values
        (NaN when the loader yields no batch).
    """
    #TESTING PHASE
    model.eval()
    model.to(device)
    test_preds, test_targets_list = [], []
    running_test_loss, n_test_batches = 0.0, 0
    criterion = lossFunctions.getLossFunction(loss_function)
    # Keep the criterion on the same device as the data, as train() does.
    criterion.to(device)
    model.zero_grad(set_to_none=True)

    with torch.no_grad():

        for test_inputs, test_targets in tqdm(test_dataloader):

            test_inputs, test_targets = test_inputs.to(device), test_targets.to(device)

            # reshape(-1) instead of squeeze(): keeps a 1-D tensor even when
            # the batch holds a single sample.
            test_output = model(test_inputs).reshape(-1)
            running_test_loss += criterion(test_output, test_targets.float()).item()
            n_test_batches += 1

            test_preds.append(test_output.cpu())
            test_targets_list.append(test_targets.cpu())

        test_loss = running_test_loss/n_test_batches if n_test_batches else float('nan')
        print("Test Loss: ", test_loss)

        if test_preds:
            compute_performance(torch.sigmoid(torch.cat(test_preds)),
                                torch.cat(test_targets_list).float(), 0.5,
                                evaluation_phase='testing')
        # Log to wandb project
        wandb.log({"test_loss": test_loss})

    return test_loss

In [None]:
def trainAndTestModel():
    """Run one sweep trial: initialise a wandb run, train a model, test it.

    Hyperparameters come from ``wandb.config``; when the function is run
    outside a sweep the values in ``config_defaults`` apply.
    """
    config_defaults = {
        "learning_rate": 0.001,
        "epochs": 10,
        "loss_function": "bce_with_logits_loss"
    }
    wandb.init(project="breast-histopathology-classification",
               entity="ece9603_project",
               job_type="model_training",
               config=config_defaults)

    try:
        config = wandb.config
        single_model = init_model()
        torch.backends.cudnn.benchmark = True
        # Use the device detected at the top of the notebook (so CPU-only
        # machines work) and actually apply the sweep's epochs/learning rate.
        trained_model = train(single_model, train_dataloader, val_dataloader,
                              device=device,
                              epochs=config.epochs,
                              lr=config.learning_rate,
                              loss_function=config.loss_function)
        test_model(trained_model, test_dataloader, loss_function=config.loss_function, device=device)
    except Exception:
        # Close the run as failed so it is not left open on the wandb side.
        wandb.finish(exit_code=1)
        raise

    # Done this training run
    wandb.finish()

In [None]:
# Run the sweep: the agent is given the sweep registered above (sweep_id) and
# trainAndTestModel as the function to run for it.
wandb.agent(sweep_id, trainAndTestModel)