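# **Camelyon-17 Experiments**

Compares ResNet-50 models trained with cross-entropy (CE) and with a supervised contrastive (SupCon) loss — each with and without PGD ($\ell_\infty$) adversarial training — plus CE fine-tuning of the SupCon-trained backbones, on subsamples of the WILDS Camelyon-17 train/val/test splits.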

## **Dependencies**

In [1]:
!pip install wilds

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms

from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper

import sys
import os
from google.colab import drive, output
from collections import defaultdict
import copy

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Seed the global RNGs so subsampling and weight initialisation are repeatable.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Use the first GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
os.makedirs('./data', exist_ok=True)
os.makedirs('./data/camelyon17_v1.0', exist_ok=True)
drive.mount('/content/drive')

!tar -xvzf /content/drive/MyDrive/research/biostat/data/camelyon17_v1.0.tar.gz -C ./data/camelyon17_v1.0

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
./patches/patient_073_node_1/patch_patient_073_node_1_x_3744_y_40608.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_2528_y_39168.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_14688_y_17536.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_1312_y_42240.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_4480_y_41376.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_6944_y_40480.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_4864_y_42112.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_3680_y_44864.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_7776_y_40288.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_14176_y_18720.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_6688_y_40160.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_4864_y_40096.png
./patches/patient_073_node_1/patch_patient_073_node_1_x_1

## **Data Processing**

In [None]:
# Load WILDS Camelyon-17 from the local ./data directory; never download.
camelyon17_dir = "./data"
dataset = get_dataset(
    dataset="camelyon17",
    root_dir=camelyon17_dir,
    download=False,
)
print(dataset.metadata_fields)

In [None]:
# Side length (pixels) every patch is resized to.
height = 224

# NOTE: no augmentation is applied yet, so the training pipeline is currently identical
# to the evaluation pipeline. Normalization uses ImageNet channel statistics.
train_transform = transforms.Compose([
    transforms.Resize((height, height)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
eval_transform = transforms.Compose([
    transforms.Resize((height, height)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# The WILDS train/val/test splits, each paired with its preprocessing pipeline.
train_data, val_data, test_data = (
    dataset.get_subset(split_name, transform=split_transform)
    for split_name, split_transform in (
        ("train", train_transform),
        ("val", eval_transform),
        ("test", eval_transform),
    )
)

In [None]:
batch_size = 64
# WILDS "standard" loaders over the full splits; they are only used to stream batches
# for the subsampling cells below.
train_loader = get_train_loader("standard", train_data, batch_size=batch_size)
val_loader, test_loader = (
    get_eval_loader("standard", split_subset, batch_size=batch_size)
    for split_subset in (val_data, test_data)
)

In [None]:
# Build an in-memory training subset by streaming batches from the WILDS train loader.
# Whole batches are kept until at least TRAIN_SIZE samples are collected, so the subset
# can exceed TRAIN_SIZE by less than one batch.

grouper = CombinatorialGrouper(dataset, ['hospital'])  # metadata -> hospital group id

TRAIN_SIZE = 5000  # target number of training samples
remaining_samples = TRAIN_SIZE

print(f"Total samples remaining: {remaining_samples}")
train_x = []
train_y = []
train_z = []
for x, y, metadata in train_loader:
    if remaining_samples <= 0:
        break
    train_x.append(x)
    train_y.append(y)
    train_z.append(grouper.metadata_to_group(metadata))  # hospital
    remaining_samples -= len(y)  # count the samples actually received
    output.clear()
    print(f"Total samples remaining: {remaining_samples}")

# torch.cat([]) gives an opaque error; fail with a clear message instead.
assert train_x, "train_loader yielded no batches"
train_x = torch.cat(train_x)
train_y = torch.cat(train_y)
train_z = torch.cat(train_z)  # domain (hospital) ids

In [None]:
# Build a validation subset of about VAL_SIZE samples from the WILDS val loader.
# Sampling keeps whole batches and runs in two phases:
#   1. while more than VAL_SIZE // 2 samples are still wanted, keep only batches that
#      contain at least one nonzero label (y.sum() != 0);
#   2. afterwards, keep only batches whose labels are all zero.
# Batches that don't match the current phase are skipped. Collection stops once the
# remaining count drops below zero, so the total can overshoot VAL_SIZE by up to one batch.

grouper = CombinatorialGrouper(dataset, ['hospital'])  # metadata -> hospital group id
VAL_SIZE = 2500
remaining_samples = VAL_SIZE

print(f"Total samples remaining: {remaining_samples}")
val_x = []
val_y = []
val_z = []
for x, y, metadata in val_loader:
    if remaining_samples < 0:
        break
    in_nonzero_phase = remaining_samples > VAL_SIZE // 2
    batch_has_nonzero = bool(y.sum() != 0)
    if in_nonzero_phase != batch_has_nonzero:
        continue  # batch does not match the current phase
    val_x.append(x)
    val_y.append(y)
    val_z.append(grouper.metadata_to_group(metadata))  # hospital
    remaining_samples -= len(y)  # count the samples actually received
    output.clear()
    print(f"Total samples remaining: {remaining_samples}")

# torch.cat([]) gives an opaque error; fail with a clear message instead.
assert val_x, "no validation batches matched the sampling criteria"
val_x = torch.cat(val_x)
val_y = torch.cat(val_y)
val_z = torch.cat(val_z)  # domain (hospital) ids

In [None]:
# Split layout: train on hospitals 0, 3, 4 (subset); validate on hospital 1; test on hospital 2.
#
# Build a test subset of about TEST_SIZE samples from the WILDS test loader, using the
# same two-phase batch sampling as the validation subset:
#   1. while more than TEST_SIZE // 2 samples are still wanted, keep only batches that
#      contain at least one nonzero label (y.sum() != 0);
#   2. afterwards, keep only batches whose labels are all zero.
# Batches that don't match the current phase are skipped. Collection stops once the
# remaining count drops below zero, so the total can overshoot TEST_SIZE by up to one batch.

grouper = CombinatorialGrouper(dataset, ['hospital'])  # metadata -> hospital group id
TEST_SIZE = 7500
remaining_samples = TEST_SIZE

print(f"Total samples remaining: {remaining_samples}")
test_x = []
test_y = []
test_z = []
for x, y, metadata in test_loader:
    if remaining_samples < 0:
        break
    in_nonzero_phase = remaining_samples > TEST_SIZE // 2
    batch_has_nonzero = bool(y.sum() != 0)
    if in_nonzero_phase != batch_has_nonzero:
        continue  # batch does not match the current phase
    test_x.append(x)
    test_y.append(y)
    test_z.append(grouper.metadata_to_group(metadata))  # hospital
    remaining_samples -= len(y)  # count the samples actually received
    output.clear()
    print(f"Total samples remaining: {remaining_samples}")

# torch.cat([]) gives an opaque error; fail with a clear message instead.
assert test_x, "no test batches matched the sampling criteria"
test_x = torch.cat(test_x)
test_y = torch.cat(test_y)
test_z = torch.cat(test_z)  # domain (hospital) ids


In [None]:
# Sum of labels in each subset (the positive count when labels are 0/1).
print(*(labels.sum() for labels in (val_y, test_y)))

In [None]:
class CamelyonDataset(Dataset):
    """In-memory dataset over pre-collected tensors.

    Each item is an ``(x, y, z)`` triple, matching the ``(x, y, meta)`` triples the
    training/evaluation loops unpack, so the same loops work on these loaders.

    Args:
        x: inputs, indexed along the first dimension.
        y: labels, one per row of ``x``.
        z: per-sample group ids (hospital), one per row of ``x``.
    """

    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z

    def __getitem__(self, index):
        return self.x[index, ...], self.y[index], self.z[index]

    def __len__(self):
        # Dataset size is the number of rows in x.
        return len(self.x)

In [None]:
batch_size = 64
# Wrap each collected (x, y, z) triple in an in-memory dataset.
train_dataset, val_dataset, test_dataset = (
    CamelyonDataset(*split_tensors)
    for split_tensors in (
        (train_x, train_y, train_z),
        (val_x, val_y, val_z),
        (test_x, test_y, test_z),
    )
)

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## **Model and Loss**

### **Loss**

In [None]:
# https://github.com/HobbitLong/SupContrast/blob/master/losses.py
"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.

    -- Ryan
    Modifications made as per https://github.com/HobbitLong/SupContrast/issues/104
    NOTE(review): this copy L2-normalizes features inside ``forward`` and adds 1e-6
    epsilons to the log-denominator and the positive count; presumably these are the
    issue-#104 changes — confirm against the issue.

    Args:
        temperature: temperature dividing the cosine similarities.
        contrast_mode: 'all' uses every view as an anchor; 'one' uses only view 0.
        base_temperature: the loss is scaled by ``temperature / base_temperature``.
    """
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dims, both `labels` and `mask`
                are given, the label count mismatches, or `contrast_mode` is unknown.
        """
        # Build every helper tensor on the device that holds `features`.
        # torch.device('cuda') would resolve to the *current* CUDA device, which on a
        # multi-GPU machine need not be the one `features` lives on.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        # normalize each view's feature vector to unit L2 norm
        features = torch.nn.functional.normalize(features, p=2, dim=2)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # SimCLR: each sample is only a positive of itself (its other views)
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # stack views along the batch axis: [n_views * bsz, dim]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one': # compare to one anchor
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all': # compare all
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability (subtracting the row max does not change the softmax)
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)

        # compute mean of log-likelihood over positive (anchors with no positive give 0)
        mean_log_prob_pos = (mask*log_prob).sum(1)/(mask.sum(1)+1e-6)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

### **Model**

In [None]:
class ResNetFeat(nn.Module):
    """ResNet-50 feature extractor for SupCon training.

    ``forward`` returns the flattened backbone output with a singleton "views" axis,
    i.e. ``[batch, 1, features]``, the ``[bsz, n_views, ...]`` layout ``SupConLoss``
    expects.

    Args:
        num_channels: number of input image channels (conv1 is rebuilt for it).
        num_classes: output size of ``resnet.fc``. The layer is kept as a submodule,
            but ``forward`` does not use it (the extractor excludes it).
    """
    def __init__(self, num_channels, num_classes):
        super(ResNetFeat, self).__init__()
        # Import torchvision's model zoo locally: a later notebook cell rebinds the
        # global name `models` to a dict of networks, after which `models.resnet50`
        # raises AttributeError.
        from torchvision import models as tv_models
        self.resnet = tv_models.resnet50(weights=None)  # random init (no pretrained weights)
        self.resnet.conv1 = nn.Conv2d(num_channels, 64, 7, stride=2, padding=3, bias=False)
        self.resnet.fc = nn.Linear(2048, num_classes)
        # Every child except the final fc layer.
        self.feature_extractor = torch.nn.Sequential(*list(self.resnet.children())[:-1])

    def forward(self, x):
        z = self.feature_extractor(x)
        # Flatten to [B, features], then add the n_views axis -> [B, 1, features].
        return z.view(z.shape[0], -1).unsqueeze(1)

In [None]:
class ResNet50(nn.Module):
    """ResNet-50 classifier: convolutional backbone followed by an MLP head.

    ``features`` is the torchvision ResNet-50 without its final fc layer (the fine-tuning
    cell swaps it for a SupCon-trained extractor); ``fc`` maps the 2048-d pooled
    features to ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super(ResNet50, self).__init__()
        # Import torchvision's model zoo locally: a later notebook cell rebinds the
        # global name `models` to a dict of networks, after which `models.resnet50`
        # raises AttributeError (e.g. in the fine-tuning cell).
        from torchvision import models as tv_models
        # Randomly initialised backbone; `weights=None` replaces the deprecated `pretrained=False`.
        self.resnet = tv_models.resnet50(weights=None)

        # Remove the last layer: (fc): Linear(in_features=2048, out_features=1000, bias=True)
        self.features = nn.Sequential(*list(self.resnet.children())[:-1])

        # Classification head. The final layer's width follows `num_classes`.
        # NOTE: BatchNorm1d needs more than one sample per batch in training mode.
        self.fc = nn.Sequential(
            nn.Linear(in_features=self.resnet.fc.in_features, out_features=1000),
            nn.BatchNorm1d(1000),
            nn.Dropout(0.2),
            nn.Linear(1000, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten pooled features to [B, 2048]
        return self.fc(x)

## **Training**

### **Training Utils**

In [None]:
def pgd_linf(model, x, y, criterion, eps, step_size, adv_steps, randomize = True):
    """
    PGD l-inf norm: signed-gradient ascent on the loss, projected onto the l-inf ball.
    https://adversarial-ml-tutorial.org/adversarial_examples/

    Args:
        model: network whose output is passed to ``criterion``.
        x: clean input batch.
        y: targets passed to ``criterion``.
        criterion: loss callable ``criterion(model(x + delta), y)`` to maximise.
        eps: l-inf radius; every entry of the perturbation stays in [-eps, eps].
        step_size: size of each signed-gradient ascent step.
        adv_steps: number of ascent steps.
        randomize: start from uniform noise in [-eps, eps] instead of zeros.

    Returns:
        The detached perturbation ``delta`` (same shape as ``x``).

    Only the gradient with respect to ``delta`` is computed (``torch.autograd.grad``),
    so the model parameters' ``.grad`` buffers are left untouched.
    """
    if randomize:
        # uniform random in [-eps, eps] -> perturbations can start at different locations
        delta = torch.rand_like(x) * 2 * eps - eps
    else:
        delta = torch.zeros_like(x)  # start w/ no noise
    delta.requires_grad_(True)

    for _ in range(adv_steps):
        loss = criterion(model(x + delta), y)  # perturbed loss
        (grad,) = torch.autograd.grad(loss, delta)  # gradient wrt delta only
        with torch.no_grad():
            # ascent step, then project back onto the l-inf ball
            delta.add_(step_size * grad.sign()).clamp_(-eps, eps)
    return delta.detach()


In [None]:

def epoch_standard(model, criterion, loader, epoch, optimizer = None, device = 'cpu', show_acc = False):
    """
    Run one standard (non-adversarial) pass over ``loader``.

    Trains when ``optimizer`` is given; otherwise evaluates with ``model.eval()`` and
    gradient tracking disabled.

    Args:
        model: network; its output goes to ``criterion`` (and is arg-maxed over dim 1
            when ``show_acc`` is set).
        criterion: loss callable ``criterion(model(x), y)``.
        loader: sized iterable of ``(x, y, meta)`` batches.
        epoch: epoch index, only used in the progress-bar text.
        optimizer: if given, one optimizer step is taken per batch.
        device: device that ``x`` and ``y`` are moved to.
        show_acc: track accuracy (percent) of the arg-max predictions.

    Returns:
        ``(mean_loss, acc)``: the mean per-batch loss and the accuracy in percent as a
        0-d numpy array (``-1.0`` when ``show_acc`` is False).
    """
    training = optimizer is not None
    if training:
        model.train()
        mode = 'Train'
    else:
        model.eval()
        mode = 'Val'

    train_loss = []
    batches = tqdm(enumerate(loader), total=len(loader))
    batches.set_description("Epoch NA: Loss (NA)")
    correct = 0
    count = 0
    # 0-d sentinel: a 1-element 1-D tensor cannot be formatted with "{:.2e}".
    # Initialised before the loop so an empty loader still returns a value.
    acc = torch.tensor(-1.0)
    for batch_idx, (x, y, meta) in batches:
        x, y = x.to(device), y.to(device)
        # Build the autograd graph only when the optimizer will step.
        with torch.set_grad_enabled(training):
            z = model(x)
            loss = criterion(z, y)
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss.append(loss.item())

        if show_acc:
            correct += (z.argmax(dim=1) == y).float().sum()
            count += y.shape[0]
            acc = 100 * correct / count

        batches.set_description(
            "Epoch {:d}: {:s} Loss ({:.2e}) ACC ({:.2e})".format(
                epoch, mode, loss.item(), float(acc)
            )
        )

    return np.mean(train_loss), acc.detach().cpu().numpy()

def epoch_adversarial(model, criterion, loader, epoch, eps=0.2, step_size=1e-2, adv_steps=40, optimizer = None, device = 'cpu', show_acc = False):
    """
    Run one adversarial pass over ``loader``: each batch is perturbed with
    ``pgd_linf`` and the model is trained (if ``optimizer`` is given) or evaluated
    on the perturbed inputs.

    Args:
        model, criterion, loader, epoch, optimizer, device, show_acc: as in
            ``epoch_standard``.
        eps: l_inf bound of the perturbation.
        step_size: delta step size for the inner maximization.
        adv_steps: number of steps of adversarial perturbation.

    Returns:
        ``(mean_loss, acc)``: the mean per-batch adversarial loss and the accuracy in
        percent as a 0-d numpy array (``-1.0`` when ``show_acc`` is False).
    """
    training = optimizer is not None
    if training:
        model.train()
        mode = 'Train'
    else:
        model.eval()
        mode = 'Val'

    train_loss = []
    batches = tqdm(enumerate(loader), total=len(loader))
    batches.set_description("Epoch NA: Adversarial Loss (NA)")
    correct = 0
    count = 0
    # 0-d sentinel: a 1-element 1-D tensor cannot be formatted with "{:.2e}".
    # Initialised before the loop so an empty loader still returns a value.
    acc = torch.tensor(-1.0)
    for batch_idx, (x, y, meta) in batches:
        x, y = x.to(device), y.to(device)
        # inner maximization (needs input gradients, so it runs with grad enabled)
        delta = pgd_linf(model, x, y, criterion, eps, step_size, adv_steps)

        # outer minimization; only build a graph when the optimizer will step
        with torch.set_grad_enabled(training):
            z = model(x + delta)
            loss = criterion(z, y)
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss.append(loss.item())

        if show_acc:
            correct += (z.argmax(dim=1) == y).float().sum()
            count += y.shape[0]
            acc = 100 * correct / count

        batches.set_description(
            "Adversarial Epoch {:d}: {:s} Loss ({:.2e}) ACC ({:.2e})".format(
                epoch, mode, loss.item(), float(acc)
            )
        )
    return np.mean(train_loss), acc.detach().cpu().numpy()

## **Fitting the Model**

### **Fitting Utils**

In [None]:
def fit_model(net, criterion, train_loader, val_loader, optimizer, n_epochs, adv_kwargs = None, device = 'cpu', use_adv = True, show_acc = False, val_acc = False, val_type = None):
    """
    Train ``net`` for ``n_epochs`` and load the weights of the best validation epoch.

    Each epoch runs a standard training pass, optionally an adversarial training pass
    (``use_adv``), then a standard validation pass.

    Args:
        net: model to train (updated in place; best weights are loaded at the end).
        criterion: loss callable.
        train_loader, val_loader: sized iterables of ``(x, y, meta)`` batches.
        optimizer: optimizer over ``net``'s parameters.
        n_epochs: number of epochs.
        adv_kwargs: keyword arguments forwarded to ``epoch_adversarial``
            (e.g. eps, step_size, adv_steps).
        device: device batches are moved to.
        use_adv: also run an adversarial training pass every epoch.
        show_acc: record train/val accuracy.
        val_acc: select the best epoch by highest validation accuracy (requires
            ``show_acc``) instead of lowest validation loss.
        val_type: 'standard' or 'adversarial' — which validation loss drives
            loss-based selection. Defaults to the notebook-level ``VAL_TYPE`` if it
            is defined, else 'standard'. 'adversarial' adds a PGD validation pass.

    Returns:
        ``(train_losses, val_losses, train_accs, val_accs)``.

    Raises:
        ValueError: on an unknown ``val_type`` or ``val_acc`` without ``show_acc``.
    """
    if adv_kwargs is None:
        adv_kwargs = {}
    if val_type is None:
        # Fall back to the notebook-level VAL_TYPE setting when present.
        val_type = globals().get('VAL_TYPE', 'standard')

    ###### Train Model ######
    train_losses = {'standard':[], 'adversarial':[]}
    val_losses = {'standard':[], 'adversarial':[]}
    if val_type not in val_losses:
        raise ValueError(f"unknown val_type {val_type!r}; expected 'standard' or 'adversarial'")
    if val_acc and not show_acc:
        raise ValueError("val_acc=True requires show_acc=True (no accuracies are recorded otherwise)")

    best_val = float("inf")
    best_epoch = 0
    best_acc = 0
    # Start from the initial weights so there is always a state to load
    # (n_epochs == 0, or a NaN loss that never compares as an improvement).
    best_model = copy.deepcopy(net.state_dict())
    train_accs = []
    val_accs = []

    for epoch in tqdm(range(n_epochs)):

        # train
        train_loss, train_acc = epoch_standard(net, criterion, train_loader, epoch, optimizer, device = device, show_acc = show_acc)
        train_losses['standard'].append(train_loss)
        if show_acc:
            train_accs.append(train_acc)
        if use_adv:
            train_loss_adv, _ = epoch_adversarial(net, criterion, train_loader, epoch, **adv_kwargs, optimizer = optimizer, device = device, show_acc = show_acc)
            train_losses['adversarial'].append(train_loss_adv)

        # eval. `epoch_val_acc` must not reuse the name `val_acc`, which is the
        # selection-flag argument checked below.
        val_loss, epoch_val_acc = epoch_standard(net, criterion, val_loader, epoch, optimizer = None, device = device, show_acc = show_acc)
        val_losses['standard'].append(val_loss)
        if show_acc:
            val_accs.append(epoch_val_acc)
        if val_type == 'adversarial':
            val_loss_adv, _ = epoch_adversarial(net, criterion, val_loader, epoch, **adv_kwargs, optimizer = None, device = device, show_acc = show_acc)
            val_losses['adversarial'].append(val_loss_adv)

        # retain best val
        if val_acc:
            if best_acc <= val_accs[-1]:
                best_acc = val_accs[-1]
                best_epoch = epoch
                print(f"Updating at {best_epoch}")
                # save model parameter/state dictionary
                best_model = copy.deepcopy(net.state_dict())
        else:
            if best_val >= val_losses[val_type][-1]:
                best_val = val_losses[val_type][-1]
                best_epoch = epoch
                print(f"Updating at {best_epoch}")
                # save model parameter/state dictionary
                best_model = copy.deepcopy(net.state_dict())

    # load best weights
    if val_acc:
        print(f"Best epoch at {best_epoch} with val accuracy: {best_acc}")
    else:
        print(f"Best epoch at {best_epoch} with {val_type} loss: {best_val}")
    net.load_state_dict(best_model)
    return train_losses, val_losses, train_accs, val_accs


### **Model Training**

In [None]:
# Optimisation settings shared by every run.
NUM_EPOCHS = 20
LR = 3e-4
# Validation-loss series used for loss-based checkpoint selection ('standard' or 'adversarial').
VAL_TYPE = 'standard'

# Supervised contrastive loss temperature.
temp = 0.07

# PGD (l-inf) settings forwarded to the adversarial epochs.
adv_kwargs = dict(
    eps=0.2,
    step_size=1e-2,
    adv_steps=10,
)

In [None]:
num_classes = len(torch.unique(train_loader.dataset.y))  # distinct training labels
num_channels = 3

# NOTE(review): this rebinds `models`, shadowing the `torchvision.models` module imported
# at the top of the notebook. Anything that later resolves `models.resnet50` through the
# global name will raise AttributeError once this cell has run.
models = {
    'ce': ResNet50(num_classes).to(device),
    'ce_adv': ResNet50(num_classes).to(device),
    'supcon': ResNetFeat(num_channels, num_classes).to(device),
    'supcon_adv': ResNetFeat(num_channels, num_classes).to(device),
}

criterions = {name: nn.CrossEntropyLoss() for name in ('ce', 'ce_adv')}
criterions.update({
    name: SupConLoss(temperature=temp, contrast_mode='all', base_temperature=0.07)
    for name in ('supcon', 'supcon_adv')
})

train_histories, val_histories = {}, {}
train_acc_histories, val_acc_histories = {}, {}

# CE models track accuracy and select on it; SupCon models (which consume the
# penultimate-layer features) only produce a loss. Names ending in '_adv' also get
# an adversarial training pass every epoch.
for model in tqdm(models):
    is_ce = 'ce' in model
    optimizer = torch.optim.AdamW(models[model].parameters(), lr=LR)
    train_hist, val_hist, train_accs, val_accs = fit_model(
        models[model], criterions[model],
        train_loader, val_loader,
        optimizer, NUM_EPOCHS, adv_kwargs,
        device=device, use_adv=model.split('_')[-1] == 'adv',
        show_acc=is_ce, val_acc=is_ce,
    )
    train_histories[model] = train_hist
    val_histories[model] = val_hist
    if is_ce:
        train_acc_histories[model] = train_accs
        val_acc_histories[model] = val_accs

In [None]:
# finetune SupCon to CE
# Fine-tune SupCon backbones with cross-entropy: each new ResNet50 classifier gets a deep
# copy of the matching SupCon feature extractor in place of its own `features` module,
# while its classification head starts from fresh initialisation.
# NOTE(review): the classifier's original `resnet` backbone is still registered as a
# submodule, so its (now unused) parameters are also handed to the optimizer.
models_ft = {}
for ft_name, src_name in (('ce_supcon', 'supcon'), ('ce_supcon_adv', 'supcon_adv')):
    models_ft[ft_name] = ResNet50(num_classes).to(device)
    models_ft[ft_name].features = copy.deepcopy(models[src_name].feature_extractor)

criterions_ft = {name: nn.CrossEntropyLoss() for name in models_ft}

# Plain (non-adversarial) CE training, tracking accuracy and selecting on it.
for model in models_ft:
    optimizer = torch.optim.AdamW(models_ft[model].parameters(), lr=LR)
    train_hist, val_hist, train_accs, val_accs = fit_model(
        models_ft[model], criterions_ft[model],
        train_loader, val_loader,
        optimizer, NUM_EPOCHS, adv_kwargs,
        device=device, use_adv=False,
        show_acc=True, val_acc=True,
    )
    train_histories[model] = train_hist
    val_histories[model] = val_hist
    train_acc_histories[model] = train_accs
    val_acc_histories[model] = val_accs

In [None]:
# Cross-entropy loss curves (train vs. validation) for every CE-trained model.
# Each history entry is a dict {'standard': [...], 'adversarial': [...]}; plot the
# per-epoch 'standard' series. Models that were not trained are skipped.
model_list = ['ce', 'ce_adv', 'ce_supcon', 'ce_supcon_adv']
markers = ['o', '^', 'x', '.']
fig, ax = plt.subplots()
for name, marker in zip(model_list, markers):
    if name not in train_histories:
        continue
    train_series = train_histories[name]['standard']
    val_series = val_histories[name]['standard']
    ax.plot(range(len(train_series)), train_series, c='b', marker=marker, linestyle='dashed', label=f"Train: {name}")
    ax.plot(range(len(val_series)), val_series, c='m', marker=marker, linestyle='dashed', label=f"Val: {name}")
ax.set(title="Cross-entropy loss", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()

In [None]:
# Accuracy curves (train vs. validation, percent) for every CE-trained model.
# Models that were not trained are skipped.
model_list = ['ce', 'ce_adv', 'ce_supcon', 'ce_supcon_adv']
markers = ['o', '^', 'x', '.']
fig, ax = plt.subplots()
for name, marker in zip(model_list, markers):
    if name not in train_acc_histories:
        continue
    train_series = train_acc_histories[name]
    val_series = val_acc_histories[name]
    ax.plot(range(len(train_series)), train_series, c='b', marker=marker, linestyle='dashed', label=f"Train: {name}")
    ax.plot(range(len(val_series)), val_series, c='m', marker=marker, linestyle='dashed', label=f"Val: {name}")
ax.set(title="Classification accuracy", xlabel="Epoch", ylabel="Accuracy (%)")
ax.legend()
plt.show()

In [None]:
# Supervised contrastive loss curves (train vs. validation).
# Each history entry is a dict {'standard': [...], 'adversarial': [...]}; plot the
# per-epoch 'standard' series. Models that were not trained are skipped.
model_list = ['supcon', 'supcon_adv']
markers = ['o', '^']
fig, ax = plt.subplots()
for name, marker in zip(model_list, markers):
    if name not in train_histories:
        continue
    train_series = train_histories[name]['standard']
    val_series = val_histories[name]['standard']
    ax.plot(range(len(train_series)), train_series, c='b', marker=marker, linestyle='dashed', label=f"Train: {name}")
    ax.plot(range(len(val_series)), val_series, c='m', marker=marker, linestyle='dashed', label=f"Val: {name}")
ax.set(title="SupCon loss", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()

## **Interpretability**