# Notebook

In [None]:
# Setup cell: clone the project repo (comment the clone out once it exists),
# import dependencies, seed RNGs and define device / folder configuration.

!git clone https://github.com/wtlu71/cnn-comp-med.git

import numpy as np
import os
from tqdm import tqdm
import sys
import os  # NOTE(review): duplicate of the `import os` above; harmless
import pandas as pd
import seaborn as sns
from scipy.stats import ttest_rel, f_oneway
import statsmodels.api as sm
from statsmodels.stats.multicomp import pairwise_tukeyhsd

# Assumes the clone above landed in /content (the Colab default working directory).
repo_path = '/content/cnn-comp-med'
# Add to Python path so the project's my_scripts package is importable
sys.path.append(repo_path)
# imports
from my_scripts.test import potato  # NOTE(review): `potato` appears unused in this notebook
from my_scripts.my_models import SmallCNN, SmallMLP,LargeCNN
from my_scripts.dataset_loading import H5Dataset
from my_scripts.utils import run_epoch


import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader, Subset,TensorDataset
import torchvision.models as models
from torchvision.models import resnet50
from torchvision import datasets, transforms
from sklearn.model_selection import StratifiedShuffleSplit


# Seed numpy and torch (CPU + all CUDA devices) for reproducible runs.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
torch.cuda.manual_seed_all(RANDOM_STATE)

# Ask cuDNN for deterministic kernels and disable autotuning (reproducibility over speed).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Device preference: Apple MPS, then CUDA, then CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): weights_folder and external_dataset_folder are relative to the CWD,
# while dataset_folder is made absolute from the CWD at this point.
weights_folder = "weights"
dataset_folder = os.path.join(os.getcwd(),"data")
external_dataset_folder = "breast-histopathology-images-subset"

In [None]:
# gdown is used by the next cell to fetch the shared Google Drive folders/files.
!pip install -U gdown

In [3]:
# Create the local output folders, then optionally fetch pre-computed artifacts
# (stratified split indices, model weights, saved datasets, external test subset)
# from Google Drive. The `download` flag is also checked by the Zenodo download cell.
os.makedirs(weights_folder,exist_ok=True)
os.makedirs(external_dataset_folder,exist_ok=True)
download = True # set true if need to download
if download:
  # NOTE(review): the -O targets are relative ("weights", "data", CWD); they match
  # weights_folder / dataset_folder only while the working directory is unchanged.
  # NOTE(review): newer gdown releases deprecate `--id` — confirm for the installed version.
  !gdown --folder --remaining-ok https://drive.google.com/drive/folders/1EnF-qAa2JMD66KRQCIs-xOUmWiXiDA11 -O weights #weights folder storing train,test,val stratified idx
  !gdown --folder --remaining-ok https://drive.google.com/drive/folders/14EiLf9I6FGopdydykxPLmVXsiltzBGDH -O data #saved data for deterministic results
  !gdown --id 1lAtil7jdHGi9MM6d3WbxALmYYWVC_2R5 -O breast-histopathology-images-subset.zip #external testing subset
  !unzip -q breast-histopathology-images-subset.zip

Retrieving folder contents
Retrieving folder 1TGWX-NN_KmssOnxQucKMhV_t_Fbxobsm .ipynb_checkpoints
Retrieving folder 18Lyiz4SQdaDn0Ot0oQj4lN9NKf5qkd_e resnet
Processing file 1Z72UXI0XmKiYEEQJDAyV3LFp15_mG4pM resnet.pth
Processing file 1HncVXpZ74ubG0dDDgOJ2mwZL8ANnMiCH efficientNet.pth
Processing file 19nhMhct1JG_61LIP5YLLEURSLT0i1Xkh LargeCNN_subset.pth
Processing file 1SXSm5-C9mwdvmPzfgOlcqj0UipK17PWg ResNet_subset.pth
Processing file 1XDoT88p-K5nU-gwBq-26c0OCOM9Nrzyp ResNet.pth
Processing file 14o0LoXRmt-YvXEE1fIqXUlvKmxoFtSxw SmallCNN_subset.pth
Processing file 1_CjNFjxdXJztdQAb6ASvm_mPQluqU_Jk SmallMLP_subset.pth
Processing file 1HkfPOp6pRIM6aYzYTyV7jDjRgueM5PjH test.idx.npy
Processing file 1j9CoCJtBHgbnvM4ZrT2mz4wqYQ341Hoz test.npy
Processing file 1QXXpAsAGwGknjcOl-j8-TjALetQObNdn train.idx.npy
Processing file 1Yl_L14-Q-tGEOEa0KOZwxOnGGulPxHtC train.npy
Processing file 1HLwctzyNbIAUcoIpwvh7RXNbA0f27S8V val.idx.npy
Processing file 1krzZg7q3pTGPbILC9U9nt4YCFblIRme_ val.npy
Retrieving

Download the PatchCamelyon (PCam) train/validation/test splits from Zenodo:

In [None]:
os.makedirs(dataset_folder,exist_ok=True)
# !apt-get install -y wget2
if download:
  !apt-get install -y aria2

  # !aria2c -x 16 -s 16 -k 1M -c -j 1 "URL" -o output_filename
  #train
  !aria2c -x 16 -s 16 -k 1M -c -j 1 "https://zenodo.org/records/2546921/files/camelyonpatch_level_2_split_train_x.h5.gz?download=1" -o data/train_x.h5.gz
  !gunzip data/train_x.h5.gz

  !wget https://zenodo.org/records/2546921/files/camelyonpatch_level_2_split_train_y.h5.gz?download=1 -O data/train_y.h5.gz
  !gunzip data/train_y.h5.gz

  # val
  !aria2c -x 16 -s 16 -k 1M -c -j 1 "https://zenodo.org/records/2546921/files/camelyonpatch_level_2_split_valid_x.h5.gz?download=1" -o data/valid_x.h5.gz
  !gunzip data/valid_x.h5.gz

  !wget https://zenodo.org/records/2546921/files/camelyonpatch_level_2_split_valid_y.h5.gz?download=1 -O data/valid_y.h5.gz
  !gunzip data/valid_y.h5.gz

  # test
  !aria2c -x 16 -s 16 -k 1M -c -j 1 "https://zenodo.org/records/2546921/files/camelyonpatch_level_2_split_test_x.h5.gz?download=1" -o data/test_x.h5.gz
  !gunzip data/test_x.h5.gz

  !wget https://zenodo.org/records/2546921/files/camelyonpatch_level_2_split_test_y.h5.gz?download=1 -O data/test_y.h5.gz
  !gunzip data/test_y.h5.gz

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
aria2 is already the newest version (1.36.0-1).
0 upgraded, 0 newly installed, 0 to remove and 41 not upgraded.

12/12 17:22:27 [[1;32mNOTICE[0m] Downloading 1 item(s)
[0m
12/12 17:22:29 [[1;32mNOTICE[0m] Shutdown sequence commencing... Press Ctrl-C again for emergency shutdown.

12/12 17:22:29 [[1;32mNOTICE[0m] Download GID#359719237273b2d6 not complete: /content/data/train_x.h5.gz

Download Results:
gid   |stat|avg speed  |path/URI
359719|[1;34mINPR[0m|    32MiB/s|/content/data/train_x.h5.gz

Status Legend:
(INPR):download in-progress.

aria2 will resume download if the transfer is restarted.
If there are any errors, then see the log file. See '-l' option in help/man page for details.

gzip: data/train_x.h5.gz: invalid compressed data--format violated
--2025-12-12 17:22:34--  https://zenodo.org/records/2546921/files/camelyonpatch_level_2_split_train_y.h5.gz?download=1
Resolving z

In [None]:
# define transforms
# Both pipelines resize to IMG_SIZE and apply the same ImageNet mean/std normalisation;
# only the training one adds random augmentation.

# is there a way to get this programmatically from the data or not worth it
IMG_SIZE = 96
# BATCH_SIZE = 2048
BATCH_SIZE = 1024

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize -> random H/V flips -> random rotation (+-10 deg) -> normalise.
# NOTE(review): unlike eval_transforms there is no ToTensor() here, so this seems to
# assume the dataset already yields tensors — confirm against H5Dataset.
train_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(10),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Validation/test pipeline (no augmentation): resize -> to tensor -> normalise.
eval_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

print("Data transforms defined")

In [None]:
def load_dataset(path):
  """Load a dataset file written by ``save_dataset`` as a ``TensorDataset``.

  Args:
    path: file produced by ``torch.save`` holding a dict with ``"images"`` and
      ``"labels"`` tensors.

  Returns:
    TensorDataset yielding ``(image, label)`` pairs.
  """
  # weights_only=True limits unpickling to tensors and plain containers. These files
  # may be downloaded from a shared Drive folder, so arbitrary pickles are not trusted.
  # map_location="cpu" lets files saved on another device load anywhere.
  dataset_dict = torch.load(path, map_location="cpu", weights_only=True)
  dataset = TensorDataset(dataset_dict['images'],dataset_dict['labels'])
  return dataset


def save_dataset(dataset,save_path):
  """Materialise every (image, label) pair of ``dataset`` and write it to ``save_path``.

  Iterating the dataset applies its transform once per sample, so the file is a
  frozen snapshot (any random augmentation is drawn a single time). The file holds
  ``{"images": stacked image tensor, "labels": label tensor}``, the layout that
  ``load_dataset`` reads back.

  Args:
    dataset: iterable of ``(image_tensor, label)`` pairs; all images must share a shape
      so they can be stacked.
    save_path: destination path passed to ``torch.save``.
  """
  collected = {"images": [], "labels": []}
  for sample_image, sample_label in dataset:
    collected["images"].append(sample_image)
    collected["labels"].append(sample_label)

  torch.save(
      {
          "images": torch.stack(collected["images"]),
          "labels": torch.tensor(collected["labels"]),
      },
      save_path,
  )

In [None]:

# dataset paths- Colab virtual session
# Build the train/val/test/external datasets and their DataLoaders.
# - load_saved=True: read the tensors an earlier run materialised in dataset_folder.
# - Otherwise: draw stratified subsets from the PCam H5 files and write them to those
#   same paths, so later runs reuse identical data.
load_saved = True
pin_memory = torch.cuda.is_available()

# One set of paths, used for both loading and saving.
train_path = os.path.join(dataset_folder,"train_dataset.pt")
val_path = os.path.join(dataset_folder,"val_dataset.pt")
test_path = os.path.join(dataset_folder,"test_dataset.pt")
external_dataset_path = os.path.join(dataset_folder,"external_test_dataset.pt")

if load_saved:
  train_dataset = load_dataset(train_path)
  # Previously val_dataset was read from test_path and test_dataset from val_path, so
  # early stopping and model selection ran on the test split.
  val_dataset = load_dataset(val_path)
  test_dataset = load_dataset(test_path)
  external_test_dataset = load_dataset(external_dataset_path)

else:
  train_img_h5_path = os.path.join(dataset_folder,"train_x.h5")
  train_label_h5_path = os.path.join(dataset_folder,"train_y.h5")

  val_img_h5_path = os.path.join(dataset_folder,"valid_x.h5")
  val_label_h5_path = os.path.join(dataset_folder,"valid_y.h5")

  test_img_h5_path = os.path.join(dataset_folder,"test_x.h5")
  test_label_h5_path = os.path.join(dataset_folder,"test_y.h5")

  train_subset_size = 50000 #keeping compute restraints in mind
  eval_subset_size = 5000
  use_precomputed_indices = True

  # Val/test get the training pipeline minus its random flips/rotation. Evaluation must
  # be deterministic, while input handling (resize + normalise, no ToTensor) stays the
  # same as the pipeline already used with H5Dataset for training.
  random_augmentations = (transforms.RandomHorizontalFlip, transforms.RandomVerticalFlip, transforms.RandomRotation)
  h5_eval_transforms = transforms.Compose(
      [t for t in train_transforms.transforms if not isinstance(t, random_augmentations)])

  train_dataset = H5Dataset(train_img_h5_path,train_label_h5_path,transform=train_transforms)
  val_dataset = H5Dataset(val_img_h5_path,val_label_h5_path,transform=h5_eval_transforms)
  test_dataset = H5Dataset(test_img_h5_path,test_label_h5_path,transform=h5_eval_transforms)

  if use_precomputed_indices:
    train_idx = np.load(os.path.join(weights_folder,"train.npy"))
    val_idx = np.load(os.path.join(weights_folder,"val.npy"))
    test_idx = np.load(os.path.join(weights_folder,"test.npy"))
  else:
    # Labels are only needed for stratification. Collecting them this way loads and
    # transforms every image, so it is skipped when precomputed indices are used.
    train_labels = np.array([label for _, label in train_dataset])
    val_labels = np.array([label for _, label in val_dataset])
    test_labels = np.array([label for _, label in test_dataset])

    sss = StratifiedShuffleSplit(n_splits=1, train_size=train_subset_size, random_state=RANDOM_STATE)
    train_idx, _ = next(sss.split(np.zeros(len(train_labels)), train_labels))  # indices for stratified subset
    np.save(os.path.join(weights_folder,"train.npy"),train_idx)

    sss = StratifiedShuffleSplit(n_splits=1, train_size=eval_subset_size, random_state=RANDOM_STATE)
    val_idx, _ = next(sss.split(np.zeros(len(val_labels)), val_labels))  # indices for stratified subset
    np.save(os.path.join(weights_folder,"val.npy"),val_idx)

    sss = StratifiedShuffleSplit(n_splits=1, train_size=eval_subset_size, random_state=RANDOM_STATE)
    test_idx, _ = next(sss.split(np.zeros(len(test_labels)), test_labels))  # indices for stratified subset
    np.save(os.path.join(weights_folder,"test.npy"),test_idx)

  train_dataset = Subset(train_dataset, train_idx)
  val_dataset = Subset(val_dataset, val_idx)
  test_dataset = Subset(test_dataset, test_idx)
  #external is already stratified
  external_test_dataset = datasets.ImageFolder(external_dataset_folder,transform=eval_transforms)
  # Ensure the same stratified datasets are reused across runs. Files go to the paths
  # the load_saved branch reads; they were previously written to the CWD, where that
  # branch would not find them.
  # NOTE: saving materialises ONE random augmentation per training image, so a reloaded
  # training set has fixed augmentation rather than fresh draws each epoch.
  save_dataset(train_dataset,train_path)
  save_dataset(val_dataset,val_path)
  save_dataset(test_dataset,test_path)
  save_dataset(external_test_dataset,external_dataset_path)

# Only the training loader shuffles. Evaluation order is fixed, so per-image results
# (e.g. the saliency "image index" and the representative-image batch) line up across
# models and runs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,pin_memory=pin_memory)
external_test_loader = DataLoader(external_test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,pin_memory=pin_memory)

In [None]:
# define/load models
num_classes = 2
mlpmodel = SmallMLP(size=((IMG_SIZE,IMG_SIZE))).to(device)
smallcnnmodel = SmallCNN(channels=(32,64,128)).to(device)
largecnnmodel = LargeCNN(channels=(32,64,128,64,64,64)).to(device)

# ImageNet-pretrained ResNet-50 with its classifier head replaced by a num_classes-way
# linear layer. Move it to the device only AFTER swapping fc; otherwise the freshly
# created head stays on the CPU.
resnetmodel = resnet50(weights = models.ResNet50_Weights.IMAGENET1K_V2)
# resnetmodel = resnet50(weights = None)
resnetmodel.fc = nn.Linear(resnetmodel.fc.in_features, num_classes)
resnetmodel = resnetmodel.to(device)
# wresnetmodel = models.wide_resnet50_2(pretrained=True).to(device)
# wresnetmodel.fc = nn.Linear(wresnetmodel.fc.in_features, num_classes)

# Name -> model. Training/evaluation cells iterate this in insertion order, and the
# names are used for weight file names ("<name>_subset.pth") and plot labels.
modeldict = {
    "SmallMLP": mlpmodel,
    "SmallCNN": smallcnnmodel,
    "LargeCNN": largecnnmodel,
    "ResNet": resnetmodel
}
modellist = list(modeldict.values())  # previously a hand-written list that omitted LargeCNN

# training hyperparameters
# can change learning rate or use scheduler
WEIGHT_DECAY = 1e-4
LEARNING_RATE = 3e-4
# finetune with less epochs to avoid forgetting
# EPOCHS = 15
EPOCHS = 1
# patience: number of consecutive epochs without validation-AUC improvement before
# stopping. With PATIENCE == EPOCHS, early stopping can never trigger before the last epoch.
PATIENCE = EPOCHS

# consider other loss functions https://neptune.ai/blog/pytorch-loss-functions
criterion = nn.CrossEntropyLoss()
# The optimizer (AdamW) is created per model inside the training loop.

In [None]:
import gc
import torch

# Release memory between experiments. First let Python reclaim unreachable objects,
# dropping their tensor references. Then return PyTorch's cached, unused CUDA blocks
# to the driver; that step only matters when CUDA is available.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


In [None]:

import matplotlib.pyplot as plt

historylist = []  # one metrics dict per model, in modeldict order
stopepochs = []   # last epoch actually run for each model


def _plot_train_val(history, metric, ylabel, title):
    """Plot history["train_<metric>"] and history["val_<metric>"] on one figure and show it."""
    plt.plot(history[f"train_{metric}"], label=f"train_{metric}")
    plt.plot(history[f"val_{metric}"], label=f"val_{metric}")
    plt.xlabel("epochs")
    plt.ylabel(ylabel)
    plt.legend()
    plt.title(title)
    plt.show()


# Train every model with AdamW and early stopping on validation AUC. Keep the best
# weights, save them to weights_folder/<name>_subset.pth and plot the learning curves.
for modelname, model in modeldict.items():
    print(f"Training {modelname}...")
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),weight_decay=WEIGHT_DECAY, lr=LEARNING_RATE)

    # try to integrate into wandb instead of storing this so we can have a pretty dashboard?
    history = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "train_auc": [],
        "val_acc": [],
        "val_auc": []
    }

    best_auc = -np.inf
    best_state = None
    bad_epochs = 0
    epoch = 0  # stays 0 if EPOCHS == 0, so stopepochs.append below still works

    print("\nStarting training...\n")
    for epoch in tqdm(range(1, EPOCHS + 1)):
        tr_loss, tr_acc, _, _, tr_auc,_ = run_epoch(train_loader, model, criterion, optimizer=optimizer, train=True, device=device)
        va_loss, va_acc, va_sens, va_spec, va_auc,_ = run_epoch(val_loader, model, criterion, optimizer=None, train=False,device=device)

        history["train_loss"].append(tr_loss)
        history["val_loss"].append(va_loss)
        history["train_auc"].append(tr_auc)
        history["train_acc"].append(tr_acc)
        history["val_acc"].append(va_acc)
        history["val_auc"].append(va_auc)

        print(f"Epoch {epoch:02d}: "
            f"train_loss={tr_loss:.4f} "
            f"val_loss={va_loss:.4f} "
            f"val_acc={va_acc:.3f} "
            f"val_auc={va_auc:.3f}")

        # An improvement must exceed 1e-4. The best weights are snapshotted on the CPU so
        # later training steps cannot overwrite them.
        if va_auc > best_auc + 1e-4:
            best_auc = va_auc
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= PATIENCE:
                print(f"\nEarly stopping at epoch {epoch}")
                break

    # Persist and restore the best model; later cells reload weights from this file.
    if best_state is not None:
        weights_path = os.path.join(weights_folder,f"{modelname}_subset.pth")
        torch.save(best_state,weights_path)
        model.load_state_dict(best_state)
        print(f"\nRestored best model (val_auc={best_auc:.4f})")
    else:
        print(f"\nNo checkpoint saved for {modelname}: no epoch improved on the initial AUC")

    _plot_train_val(history, "loss", "CrossEntropyLoss", modelname)
    _plot_train_val(history, "acc", "Accuracy", modelname)
    _plot_train_val(history, "auc", "AUC", modelname)  # previously had no legend or title

    historylist.append(history)
    stopepochs.append(epoch)

In [None]:
import json

# Persist per-model training curves next to the weights. default=float converts
# numpy/torch scalar metrics (e.g. np.float32 or 0-dim tensors, if run_epoch returns
# them), which the json module cannot serialise natively.
history_path = os.path.join(weights_folder,"history.json")
with open(history_path,'w') as f:
  json.dump(historylist,f,default=float)

In [None]:

# Overlay every model's training-loss curve on one figure.
for modelname, history in zip(modeldict, historylist):
  plt.plot(history["train_loss"], label=modelname)

# Labels, legend and title only need setting once, after all curves are drawn.
if historylist:
  plt.xlabel("epochs")
  plt.ylabel("CrossEntropyLoss")
  plt.legend()
  plt.title("Train Loss")

plt.show()



In [None]:
# Overlay every model's validation-loss curve on one figure.
for modelname, history in zip(modeldict, historylist):
  plt.plot(history["val_loss"], label=modelname)

# Labels, legend and title only need setting once, after all curves are drawn.
if historylist:
  plt.xlabel("epochs")
  plt.ylabel("CrossEntropyLoss")
  plt.legend()
  plt.title("Val Loss")

plt.show()



In [None]:
# Overlay every model's training-accuracy curve on one figure (y fixed to [0, 1]).
for modelname, history in zip(modeldict, historylist):
  plt.plot(history["train_acc"], label=modelname)

# Labels, legend, title and limits only need setting once, after all curves are drawn.
if historylist:
  plt.xlabel("epochs")
  plt.ylabel("Accuracy")
  plt.legend()
  plt.title("Train Accuracy")
  plt.ylim([0,1])

plt.show()


In [None]:
# Overlay every model's validation-accuracy curve on one figure (y fixed to [0, 1]).
for modelname, history in zip(modeldict, historylist):
  plt.plot(history["val_acc"], label=modelname)

# Labels, legend, title and limits only need setting once, after all curves are drawn.
if historylist:
  plt.xlabel("epochs")
  plt.ylabel("Accuracy")
  plt.legend()
  plt.title("Valid Accuracy")
  plt.ylim([0,1])

plt.show()

In [None]:
# Overlay every model's training-AUC curve on one figure (y fixed to [0, 1]).
for modelname, history in zip(modeldict, historylist):
  plt.plot(history["train_auc"], label=modelname)

# Labels, legend, title and limits only need setting once, after all curves are drawn.
if historylist:
  plt.xlabel("epochs")
  plt.ylabel("AUC")
  plt.legend()
  plt.title("Train AUC")
  plt.ylim([0,1])

plt.show()

In [None]:
# Overlay every model's validation-AUC curve on one figure (y fixed to [0, 1]).
for modelname, history in zip(modeldict, historylist):
  plt.plot(history["val_auc"], label=modelname)

# Labels, legend, title and limits only need setting once, after all curves are drawn.
if historylist:
  plt.xlabel("epochs")
  plt.ylabel("AUC")
  plt.legend()
  plt.title("Val AUC")
  plt.ylim([0,1])

plt.show()

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Evaluate each model's saved best checkpoint on the held-out INTERNAL test set and
# plot its confusion matrix.
class_names = ["Normal", "Cancer"]  # display names for labels 0 and 1

for modelname, model in modeldict.items():
  weights_path = os.path.join(weights_folder,f"{modelname}_subset.pth")
  # map_location lets checkpoints saved on another device load here. weights_only
  # refuses arbitrary pickles, since the weights folder may be downloaded from Drive.
  model_state_dict = torch.load(weights_path, map_location=device, weights_only=True)
  model.load_state_dict(model_state_dict)
  print(f"\nEvaluating {modelname}...")
  model.to(device)
  _, te_acc, te_sens, te_spec, te_auc, cm = run_epoch(test_loader, model, criterion, train=False, device=device)
  # These are test-set numbers (the old heading mislabelled them as validation).
  print("Final Internal Test Performance:")
  print(f"  AUC:         {te_auc:.4f}")
  print(f"  Accuracy:    {te_acc:.4f}")
  print(f"  Sensitivity: {te_sens:.4f}")
  print(f"  Specificity: {te_spec:.4f}")
  disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                display_labels=class_names)
  disp.plot(cmap='Blues', values_format='d')
  plt.title(f"Confusion Matrix {modelname} Internal test set")
  plt.show()



In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Evaluate each model's saved best checkpoint on the EXTERNAL test set and plot its
# confusion matrix.
class_names = ["Normal", "Cancer"]  # display names for labels 0 and 1

for modelname, model in modeldict.items():
    weights_path = os.path.join(weights_folder,f"{modelname}_subset.pth")
    # map_location lets checkpoints saved on another device load here. weights_only
    # refuses arbitrary pickles, since the weights folder may be downloaded from Drive.
    model_state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(model_state_dict)
    print(f"\nEvaluating {modelname}...")
    model.to(device)
    _, ext_acc, ext_sens, ext_spec, ext_auc, cm = run_epoch(external_test_loader, model, criterion, train=False, device=device)
    # These are external-test numbers (the old heading mislabelled them as validation).
    print("Final External Test Performance:")
    print(f"  AUC:         {ext_auc:.4f}")
    print(f"  Accuracy:    {ext_acc:.4f}")
    print(f"  Sensitivity: {ext_sens:.4f}")
    print(f"  Specificity: {ext_spec:.4f}")
    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=class_names)
    disp.plot(cmap='Blues', values_format='d')
    plt.title(f"Confusion Matrix {modelname} External test set")
    plt.show()

In [None]:
# evaluating and saliency
!pip install torch torchvision captum pillow matplotlib
# !pip install captum

import torch
from torchvision import models, transforms
from captum.attr import Saliency
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

model_attr_scores = []
model_attr_channels = []
model_preds = []
model_targets = []

for modelname, model in modeldict.items():
  weights_path = os.path.join(weights_folder,f"{modelname}_subset.pth")
  model_state_dict = torch.load(weights_path)
  model.load_state_dict(model_state_dict)
  print(f"\nEvaluating saliency for {modelname}...")
  model.to(device)
  print(f"Saliency for {modelname}")
  model.eval()
  saliency = Saliency(model)

  attr_scores_list = []
  attr_channels_list = []
  pred_list = []
  target_list = []

  for images, targets in test_loader:
    images = images.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    images.requires_grad_()

    with torch.set_grad_enabled(True):
      logits = model(images)
      # loss = criterion(logits, targets)
      pred_classes = logits.argmax(dim=1)
      attr = saliency.attribute(images, target=pred_classes)
      # attr = attr.abs().cpu().detach().numpy()
      attr = attr.abs()
      attr_scores = attr.max(dim=1)[0]    # (batch_size, H, W)
      attr_channel = attr.argmax(dim=1)   # (batch_size, H, W)
    pred_list.append(pred_classes.cpu())
    target_list.append(targets.cpu())
    attr_scores_list.append(attr_scores)
    attr_channels_list.append(attr_channel)

  all_attr_scores = torch.cat(attr_scores_list, dim=0)
  all_attr_channels = torch.cat(attr_channels_list, dim=0)
  model_attr_scores.append(all_attr_scores)
  model_attr_channels.append(all_attr_channels)
  model_preds.append(torch.cat(pred_list, dim=0))
  model_targets.append(torch.cat(target_list, dim=0))

In [None]:
import torch.nn.functional as F

# Split each IMG_SIZE x IMG_SIZE saliency map into a 3x3 grid of center_size tiles.
# The middle tile is the "center" and the other 8 tiles form the "outer" region
# (IMG_SIZE=96 gives the original 32-pixel tiles).
center_size = IMG_SIZE // 3
pool_center_mask = torch.zeros((3, 3), dtype=torch.bool)
pool_center_mask[1, 1] = True

# The same center tile at pixel resolution, for the dominant-channel analysis.
center_mask = torch.zeros((IMG_SIZE, IMG_SIZE), dtype=torch.bool)
center_mask[center_size:2 * center_size, center_size:2 * center_size] = True

all_sal_results = []

# assess the saliency score maps for each model
for modelname, saliencies, imp_channels in zip(modeldict, model_attr_scores, model_attr_channels):
  # The masks are CPU tensors, so bring the maps to the CPU as well.
  saliencies = saliencies.cpu()
  imp_channels = imp_channels.cpu()

  # Mean saliency per grid tile: (N, H, W) -> (N, 1, H, W) -> pooled (N, 3, 3).
  avgpool_sal = F.avg_pool2d(saliencies.unsqueeze(1), kernel_size=center_size, stride=center_size).squeeze(1)
  center_avg = avgpool_sal[:, pool_center_mask].view(-1)     # (N,) center tile
  outer_avg = avgpool_sal[:, ~pool_center_mask].mean(dim=1)  # (N,) mean of the 8 outer tiles
  center_focused = center_avg > outer_avg
  ratio = center_avg / outer_avg  # inf/nan if an image's outer region is all zero
  assert center_avg.shape == outer_avg.shape == ratio.shape == (len(saliencies),)

  # Most frequent "winning" colour channel inside vs. outside the center.
  imp_channel_center = torch.mode(imp_channels[:, center_mask], dim=1).values
  imp_channel_outer = torch.mode(imp_channels[:, ~center_mask], dim=1).values

  # One .tolist() per column instead of an .item() call per cell.
  columns = zip(center_avg.tolist(), outer_avg.tolist(), center_focused.tolist(),
                ratio.tolist(), imp_channel_center.tolist(), imp_channel_outer.tolist())
  for j, (center, outer, focused, center_ratio, channel_center, channel_outer) in enumerate(columns):
    all_sal_results.append({
        "modelname": modelname,
        "image index": j,
        "center": center,
        "outer": outer,
        "center focused": focused,
        "center ratio": center_ratio,
        "important channel center": channel_center,
        "important channel outer": channel_outer,
    })

In [None]:
# bar plot- four groups two bars each- for each model, avg. center score vs. avg outer score
sal_df = pd.DataFrame(all_sal_results)
# melt the center and outer columns together
sal_melt_df = sal_df.melt(id_vars=["modelname", "image index", "center focused"], value_vars=["center", "outer"], value_name="saliency", var_name="region")
plt.figure(figsize=(8, 5))
p = sns.barplot(data=sal_melt_df, x="modelname", y="saliency", hue="region", errorbar="sd")
p.set(ylabel="Avg Saliency Score", title="Center vs. Outer Saliency by Model")
for container in p.containers:
  p.bar_label(container, fmt='%.3f', padding=-10)
plt.show()

# Paired t-test of center vs. outer for each model.
# The names are model_names / model_name on purpose. The old `models` / `model` clobbered
# torchvision.models and left `model` holding a *string*, which broke the later cells
# that call model.eval().
model_names = sal_df["modelname"].unique()
for model_name in model_names:
    sub = sal_df[sal_df["modelname"] == model_name]
    tval, pval = ttest_rel(sub["center"], sub["outer"])
    print(f"{model_name}: t={tval:.3f}, p={pval:.4f}")

# plot center vs. outer focus ratio for each model
p = sns.barplot(data=sal_df, x="modelname", y="center ratio")
p.set(ylabel="Center vs. Outer Saliency Ratio", title="Center vs. Outer Saliency Ratio by Model")
for container in p.containers:
  p.bar_label(container, fmt='%.3f', padding=-10)
plt.show()
#1-way ANOVA and Tukey post-hoc for significant difference in center focus b/w models
groups = [sal_df.loc[sal_df["modelname"] == name, "center ratio"] for name in model_names]
Fval, pval = f_oneway(*groups)
print(f"ANOVA: F = {Fval:.4f}, p = {pval:.4e}")
tukey = pairwise_tukeyhsd(endog=sal_df["center ratio"], groups=sal_df["modelname"], alpha=0.05)
print(tukey)

# count how many images were center-focused vs. not
hue_order = [True, False]
p = sns.countplot(data=sal_df, x="modelname", hue="center focused", hue_order=hue_order)
p.set(ylabel="Number of Images Center-Focused", title="Count of Images Where Model Focused on Center")
for container in p.containers:
  p.bar_label(container, padding=-10)
plt.show()

# Confusion matrix: center-focused or not on one axis, TP/TN/FP/FN on the other.
# sal_df rows and model_preds/model_targets were built in the same model order and the
# same per-model image order, so plain concatenation lines them up.
sal_df["Pred Class"] = torch.cat(model_preds, dim=0).cpu().numpy()
sal_df["Target Class"] = torch.cat(model_targets, dim=0).cpu().numpy()

sal_df["Class Result"] = -1

sal_df.loc[(sal_df["Pred Class"] == 1) & (sal_df["Target Class"] == 1), "Class Result"] = 0 # TP
sal_df.loc[(sal_df["Pred Class"] == 0) & (sal_df["Target Class"] == 0), "Class Result"] = 1 # TN
sal_df.loc[(sal_df["Pred Class"] == 1) & (sal_df["Target Class"] == 0), "Class Result"] = 2 # FP
sal_df.loc[(sal_df["Pred Class"] == 0) & (sal_df["Target Class"] == 1), "Class Result"] = 3 # FN

# Every row must have received an outcome, i.e. predictions/targets are exactly 0/1.
print(sal_df["Class Result"].unique())
assert (sal_df["Class Result"] >= 0).all(), "found predictions/targets outside {0, 1}"

# 0/1 integer encoding of "center focused" so it can serve as confusion-matrix labels
sal_df["Center-Focused?"] = sal_df["center focused"].astype(int)

n_models = len(modeldict)
fig, axes = plt.subplots(n_models, 1, figsize=(16, 4 * n_models), squeeze=False)

# plot confusion matrices
for ax, modelname in zip(axes[:, 0], modeldict.keys()):
  print("confusing", modelname)
  model_df = sal_df[sal_df["modelname"] == modelname]
  # sanity check
  print(f"{model_df['Center-Focused?'].sum()} center-focused images")
  print(f"{(model_df['Class Result'] == 0).sum()} true positives")
  print(f"{(model_df['Class Result'] == 1).sum()} true negatives")
  print(f"{(model_df['Class Result'] == 2).sum()} false positives")
  print(f"{(model_df['Class Result'] == 3).sum()} false negatives")
  # labels=[0, 1, 2, 3] pins the matrix to 4x4, so columns always mean TP/TN/FP/FN.
  # Without it, a missing outcome (e.g. no false negatives) shifts columns under the
  # tick labels. Rows 2-3 are always empty because y_true is only 0/1.
  cm = confusion_matrix(model_df["Center-Focused?"], model_df["Class Result"], labels=[0, 1, 2, 3])
  sns.heatmap(
    cm[:2], annot=True, fmt="d", cmap="Blues", ax=ax, cbar=False,
    xticklabels=["True Positive", "True Negative", "False Positive", "False Negative"],
    yticklabels=["Not Center-Focused", "Center-Focused"]
  )
  ax.set_title(f"{modelname}", fontsize=12, fontweight="bold")
  ax.set_xlabel("Result", fontsize=10)
  ax.set_ylabel("Saliency", fontsize=10)
plt.tight_layout() # adjust spacing
plt.show()

# plot saliency ratio by class result and model
# make a new column where the class result labels are strings
sal_df["ClassResult"] = sal_df["Class Result"].map({
    0: "True Positive",
    1: "True Negative",
    2: "False Positive",
    3: "False Negative"
})
hue_order = ["True Positive", "False Positive", "True Negative", "False Negative"]
# bar plot- ratio in TP, TN, FP, FN for each model
p = sns.barplot(data=sal_df, x="modelname", y="center ratio", hue="ClassResult", hue_order=hue_order)
p.set(ylabel="Center vs. Outer Saliency Ratio", title="Center vs. Outer Saliency Ratio by Class Result and Model")
for container in p.containers:
  p.bar_label(container, fmt='%.3f', padding=2)
plt.show()

# make new columns where the channel index is the color name
sal_df["Center Important Channel"] = sal_df["important channel center"].map({
    0: "Red",
    1: "Green",
    2: "Blue",
})
sal_df["Outer Important Channel"] = sal_df["important channel outer"].map({
    0: "Red",
    1: "Green",
    2: "Blue",
})
hue_order = ["Red", "Green", "Blue"]
palette = ["red", "green", "blue"]
# plot important channel center by model
p = sns.countplot(data=sal_df, x="modelname", hue="Center Important Channel", hue_order=hue_order, palette=palette)
p.set(ylabel="Count of Images", title="# Images by Model: Most Important Channel in Center")
for container in p.containers:
  p.bar_label(container, padding=2)
plt.show()

# plot important channel outer by model
p = sns.countplot(data=sal_df, x="modelname", hue="Outer Important Channel", hue_order=hue_order, palette=palette)
p.set(ylabel="Count of Images", title="# Images by Model: Most Important Channel in Outer Area")
for container in p.containers:
  p.bar_label(container, padding=2)
plt.show()


In [None]:
# representative images

# pip install torch torchvision captum pillow matplotlib

import torch
from torchvision import models, transforms
from captum.attr import Saliency
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Internal test set, true negative. Pick one image that SELECTION_MODEL classifies
# correctly as class 0, then overlay every model's saliency map on it.
SELECTION_MODEL = "ResNet"  # previously whatever `model` happened to leak from an earlier cell
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# load internal test set images and run evaluation
images, targets = next(iter(test_loader))
targets = targets.reshape(-1)
selector = modeldict[SELECTION_MODEL].to(device).eval()
with torch.no_grad():
  batch_preds = selector(images.to(device)).argmax(dim=1).cpu()

# We want target == 0 AND prediction == 0. The old `while target != 0 and pred != 0`
# stopped as soon as EITHER one was 0, and a rejection loop can spin forever on a batch
# with no match. Sampling from the matching indices fixes both.
candidates = torch.nonzero((targets == 0) & (batch_preds == 0)).flatten()
if len(candidates) == 0:
  raise RuntimeError("No true-negative image in this batch")
idx = candidates[np.random.randint(0, len(candidates))].item()

input_tensor = images[idx].unsqueeze(0).to(device).requires_grad_()
true_class = targets[idx].item()
# Undo Normalize and reorder (C, H, W) -> (H, W, C). The old permute(2, 1, 0) produced
# (W, H, C), i.e. an image transposed relative to the (H, W) saliency map.
display_img = (images[idx] * IMAGENET_STD + IMAGENET_MEAN).clamp(0, 1).permute(1, 2, 0).numpy()

for modelname, model in modeldict.items():
  model.to(device)
  model.eval()
  pred_class = model(input_tensor).argmax(dim=1).item()
  print(pred_class, true_class)
  saliency = Saliency(model)
  attr = saliency.attribute(input_tensor, target=pred_class)

  # |gradient| max-reduced over the colour channels -> (H, W) saliency map
  attr = attr.abs().squeeze(0).max(dim=0).values.detach().cpu().numpy()
  print(attr.shape)
  plt.figure(figsize=(5,5))
  plt.imshow(display_img, alpha=0.6)
  plt.imshow(attr, cmap='hot', alpha=0.4)
  plt.axis('off')
  plt.title(f"Predicted class: {pred_class} True class:{true_class} {modelname}")

  plt.show()


In [None]:
# pip install torch torchvision captum pillow matplotlib

import torch
from torchvision import models, transforms
from captum.attr import Saliency
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# External test set, true negative. Pick one image that SELECTION_MODEL classifies
# correctly as class 0, then overlay every model's saliency map on it.
SELECTION_MODEL = "ResNet"  # previously whatever `model` happened to leak from an earlier cell
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

images, targets = next(iter(external_test_loader))
targets = targets.reshape(-1)
selector = modeldict[SELECTION_MODEL].to(device).eval()
with torch.no_grad():
  batch_preds = selector(images.to(device)).argmax(dim=1).cpu()

# We want target == 0 AND prediction == 0. The old `while target != 0 and pred != 0`
# stopped as soon as EITHER one was 0, and a rejection loop can spin forever on a batch
# with no match. Sampling from the matching indices fixes both.
candidates = torch.nonzero((targets == 0) & (batch_preds == 0)).flatten()
if len(candidates) == 0:
  raise RuntimeError("No true-negative image in this batch")
idx = candidates[np.random.randint(0, len(candidates))].item()

input_tensor = images[idx].unsqueeze(0).to(device).requires_grad_()
true_class = targets[idx].item()
# Undo Normalize and reorder (C, H, W) -> (H, W, C). The old permute(2, 1, 0) produced
# (W, H, C), i.e. an image transposed relative to the (H, W) saliency map.
display_img = (images[idx] * IMAGENET_STD + IMAGENET_MEAN).clamp(0, 1).permute(1, 2, 0).numpy()

for modelname, model in modeldict.items():
  model.to(device)
  model.eval()
  pred_class = model(input_tensor).argmax(dim=1).item()
  print(pred_class, true_class)
  saliency = Saliency(model)
  attr = saliency.attribute(input_tensor, target=pred_class)

  # |gradient| max-reduced over the colour channels -> (H, W) saliency map
  attr = attr.abs().squeeze(0).max(dim=0).values.detach().cpu().numpy()
  print(attr.shape)
  plt.figure(figsize=(5,5))
  plt.imshow(display_img, alpha=0.6)
  plt.imshow(attr, cmap='hot', alpha=0.4)
  plt.axis('off')
  plt.title(f"Predicted class: {pred_class} True class:{true_class} {modelname}")

  plt.show()


In [None]:
# pip install torch torchvision captum pillow matplotlib

import torch
from torchvision import models, transforms
from captum.attr import Saliency
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Internal test set, true positive. Pick one image that SELECTION_MODEL classifies
# correctly as class 1, then overlay every model's saliency map on it.
SELECTION_MODEL = "ResNet"  # previously whatever `model` happened to leak from an earlier cell
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

images, targets = next(iter(test_loader))
targets = targets.reshape(-1)
selector = modeldict[SELECTION_MODEL].to(device).eval()
with torch.no_grad():
  batch_preds = selector(images.to(device)).argmax(dim=1).cpu()

# We want target == 1 AND prediction == 1. The old `while target != 1 and pred != 1`
# stopped as soon as EITHER one was 1, and a rejection loop can spin forever on a batch
# with no match. Sampling from the matching indices fixes both.
candidates = torch.nonzero((targets == 1) & (batch_preds == 1)).flatten()
if len(candidates) == 0:
  raise RuntimeError("No true-positive image in this batch")
idx = candidates[np.random.randint(0, len(candidates))].item()

input_tensor = images[idx].unsqueeze(0).to(device).requires_grad_()
true_class = targets[idx].item()
# Undo Normalize and reorder (C, H, W) -> (H, W, C). The old permute(2, 1, 0) produced
# (W, H, C), i.e. an image transposed relative to the (H, W) saliency map.
display_img = (images[idx] * IMAGENET_STD + IMAGENET_MEAN).clamp(0, 1).permute(1, 2, 0).numpy()

for modelname, model in modeldict.items():
  model.to(device)
  model.eval()
  pred_class = model(input_tensor).argmax(dim=1).item()
  print(pred_class, true_class)
  saliency = Saliency(model)
  attr = saliency.attribute(input_tensor, target=pred_class)

  # |gradient| max-reduced over the colour channels -> (H, W) saliency map
  attr = attr.abs().squeeze(0).max(dim=0).values.detach().cpu().numpy()
  print(attr.shape)
  plt.figure(figsize=(5,5))
  plt.imshow(display_img, alpha=0.6)
  plt.imshow(attr, cmap='hot', alpha=0.4)
  plt.axis('off')
  plt.title(f"Predicted class: {pred_class} True class:{true_class} {modelname}")

  plt.show()


In [None]:
# pip install torch torchvision captum pillow matplotlib

import torch
from torchvision import models, transforms
from captum.attr import Saliency
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Internal test set, misclassified example. Pick one image that SELECTION_MODEL gets
# wrong, then overlay every model's saliency map on it.
SELECTION_MODEL = "ResNet"  # previously whatever `model` happened to leak from an earlier cell
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

images, targets = next(iter(test_loader))
targets = targets.reshape(-1)
selector = modeldict[SELECTION_MODEL].to(device).eval()
with torch.no_grad():
  batch_preds = selector(images.to(device)).argmax(dim=1).cpu()

# Sample from the misclassified indices instead of a rejection loop, which would spin
# forever on a batch the model classifies perfectly.
candidates = torch.nonzero(targets != batch_preds).flatten()
if len(candidates) == 0:
  raise RuntimeError("No misclassified image in this batch")
idx = candidates[np.random.randint(0, len(candidates))].item()

input_tensor = images[idx].unsqueeze(0).to(device).requires_grad_()
true_class = targets[idx].item()
# Undo Normalize and reorder (C, H, W) -> (H, W, C). The old permute(2, 1, 0) produced
# (W, H, C), i.e. an image transposed relative to the (H, W) saliency map.
display_img = (images[idx] * IMAGENET_STD + IMAGENET_MEAN).clamp(0, 1).permute(1, 2, 0).numpy()

for modelname, model in modeldict.items():
  model.to(device)
  model.eval()
  pred_class = model(input_tensor).argmax(dim=1).item()
  print(pred_class, true_class)
  saliency = Saliency(model)
  attr = saliency.attribute(input_tensor, target=pred_class)

  # |gradient| max-reduced over the colour channels -> (H, W) saliency map
  attr = attr.abs().squeeze(0).max(dim=0).values.detach().cpu().numpy()
  print(attr.shape)
  plt.figure(figsize=(5,5))
  plt.imshow(display_img, alpha=0.6)
  plt.imshow(attr, cmap='hot', alpha=0.4)
  plt.axis('off')
  plt.title(f"Predicted class: {pred_class} True class:{true_class} {modelname}")

  plt.show()
