## Install Requirements

In [None]:
# Confirm a CUDA GPU is visible: the model, training and attack cells below all call .cuda().
!nvidia-smi

In [None]:
!git clone https://github.com/Omid-Nejati/MedViT.git
%cd /content/MedViT

In [1]:
pip install -r requirements.txt

Defaulting to user installation because normal site-packages is not writeable
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[K     |████████████████████████████████| 43 kB 1.5 MB/s eta 0:00:01
[?25hCollecting timm
  Downloading timm-1.0.11-py3-none-any.whl (2.3 MB)
[K     |████████████████████████████████| 2.3 MB 2.8 MB/s eta 0:00:01
[?25hCollecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl (11.0 MB)
[K     |████████████████████████████████| 11.0 MB 9.3 MB/s eta 0:00:01
[?25hCollecting scikit-image
  Downloading scikit_image-0.24.0-cp39-cp39-macosx_12_0_arm64.whl (13.4 MB)
[K     |████████████████████████████████| 13.4 MB 12.7 MB/s eta 0:00:01
[?25hCollecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[K     |████████████████████████████████| 50 kB 11.7 MB/s eta 0:00:01
[?25hCollecting tqdm
  Downloading tqdm-4.66.5-py3-none-

In [2]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torchsummary import summary

from tqdm import tqdm
import medmnist
from medmnist import INFO, Evaluator

import torchattacks
from torchattacks import PGD, FGSM

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# Record the library versions this notebook was run with, for reproducibility.
for lib_name, lib_version in (
    ("PyTorch", torch.__version__),
    ("Torchvision", torchvision.__version__),
    ("Torchattacks", torchattacks.__version__),
    ("Numpy", np.__version__),
    ("Medmnist", medmnist.__version__),
):
    print(lib_name, lib_version)

## Dataset

`data_flag` selects the MedMNIST dataset. It must be one of: `pathmnist`, `chestmnist`, `dermamnist`, `octmnist`, `pneumoniamnist`, `retinamnist`, `breastmnist`, `bloodmnist`, `tissuemnist`, `organamnist`, `organcmnist`, `organsmnist`.

In [None]:
# Experiment configuration: dataset choice, training hyper-parameters and seeding.
data_flag = 'retinamnist'
# One of: pathmnist, chestmnist, dermamnist, octmnist, pneumoniamnist, retinamnist,
#         breastmnist, bloodmnist, tissuemnist, organamnist, organcmnist, organsmnist
download = True

NUM_EPOCHS = 10
BATCH_SIZE = 10
lr = 0.005
SEED = 42

# Training is stochastic (AugMix, DataLoader shuffling and presumably the model's
# weight initialisation), so seed the RNGs up front to make Restart & Run All
# reproduce the reported numbers. torch.manual_seed also seeds every CUDA device.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset metadata from MedMNIST: task type (drives the loss choice later),
# channel count, and number of classes/labels.
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

# Dataset class for this flag, looked up by name on the medmnist module.
DataClass = getattr(medmnist, info['python_class'])

print("number of channels : ", n_channels)
print("number of classes : ", n_classes)

In [None]:
# Preprocessing: upsample each PIL image to 224 px, force 3-channel RGB (so grayscale
# and colour datasets yield the same tensor layout), then scale to [-1, 1].
# AugMix is applied to the training split only.
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Lambda(lambda image: image.convert('RGB')),
    transforms.AugMix(),
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Lambda(lambda image: image.convert('RGB')),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])

# Load the data. No `size=` is passed, so the library's default resolution is used;
# the Evaluator in the Test section must use that same size.
train_dataset = DataClass(split='train', transform=train_transform, download=download)
test_dataset = DataClass(split='test', transform=test_transform, download=download)

# Wrap the datasets in loaders. Evaluation loaders are unshuffled so predictions stay
# aligned with the label order the Evaluator uses, and use 2x the batch size
# because no gradients are kept during evaluation.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

In [None]:
# Show the dataset summaries, separated by a divider line.
print(train_dataset, test_dataset, sep="\n===================\n")

## Model

Available MedViT variants: `MedViT_small`, `MedViT_base` and `MedViT_large`.

In [63]:
from MedViT import MedViT_small, MedViT_base, MedViT_large

# Choose the MedViT variant by name instead of commenting lines in and out.
# Each constructor takes the number of output classes.
MODEL_VARIANT = 'small'  # one of: 'small', 'base', 'large'
MEDVIT_VARIANTS = {
    'small': MedViT_small,
    'base': MedViT_base,
    'large': MedViT_large,
}

model = MEDVIT_VARIANTS[MODEL_VARIANT](num_classes=n_classes).cuda()

initialize_weights...


## Train

In [64]:
# Loss follows the MedMNIST task type: multi-label problems get independent
# per-label sigmoids (BCE on logits); every other task gets softmax cross-entropy.
criterion = (
    nn.BCEWithLogitsLoss()
    if task == "multi-label, binary-class"
    else nn.CrossEntropyLoss()
)

# SGD with momentum over all of the model's parameters.
optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)

In [None]:
# Train for NUM_EPOCHS epochs, reporting the mean training loss after each epoch.
for epoch in range(NUM_EPOCHS):
    print('Epoch [%d/%d]' % (epoch + 1, NUM_EPOCHS))
    model.train()
    running_loss = 0.0
    n_batches = 0

    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.cuda(), targets.cuda()

        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)

        if task == 'multi-label, binary-class':
            # BCEWithLogitsLoss needs float targets shaped like the logits.
            targets = targets.to(torch.float32)
        else:
            # CrossEntropyLoss needs a 1-D LongTensor of class indices. Squeeze only
            # dim 1: a bare squeeze() would also drop the batch dim when the final
            # batch holds a single sample, and the loss would then fail.
            targets = targets.squeeze(1).long()
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    print('  mean train loss: %.4f' % (running_loss / max(n_batches, 1)))

## Test

In [66]:
# Score one split with MedMNIST's Evaluator (AUC and accuracy).
split = 'test'

model.eval()
# NOTE(review): train_loader_at_eval reuses train_dataset, whose transform includes
# AugMix, so 'train' metrics are computed on augmented images.
data_loader = train_loader_at_eval if split == 'train' else test_loader

score_batches = []
with torch.no_grad():
    for inputs, _ in data_loader:
        outputs = model(inputs.cuda())
        # Convert logits to the probabilities each loss was trained for:
        # independent sigmoids for multi-label (BCEWithLogitsLoss), softmax otherwise.
        if task == 'multi-label, binary-class':
            outputs = outputs.sigmoid()
        else:
            outputs = outputs.softmax(dim=-1)
        score_batches.append(outputs.cpu())

# Concatenate once (not inside the loop) to avoid quadratic copying.
y_score = torch.cat(score_batches, 0).numpy()

# The Evaluator loads ground-truth labels from the .npz of the size it is given.
# The datasets were built without `size=`, so use the same default size here.
evaluator = Evaluator(data_flag, split)
metrics = evaluator.evaluate(y_score)

print('%s  auc: %.3f  acc: %.3f' % (split, *metrics))

test  auc: 0.623  acc: 0.472


## Adversarial Robustness

Reduce the batch size so the adversarial attacks fit within GPU memory limits.

In [67]:
# Gradient-based attacks back-propagate to the input images and need far more GPU
# memory than plain inference, so the robustness evaluation uses a smaller batch.
# A separate constant leaves the training BATCH_SIZE unchanged if earlier cells are re-run.
ADV_BATCH_SIZE = 5
test_loader = data.DataLoader(dataset=test_dataset, batch_size=ADV_BATCH_SIZE, shuffle=False)

In [None]:
# Robust accuracy on the test set under a single-step FGSM attack.
model.eval()

correct = 0
total = 0

# NOTE(review): torchattacks assumes inputs in [0, 1] and (by default) clamps the
# adversarial images to that range. test_loader, however, yields images already
# normalised to [-1, 1], so negative pixels get clipped and eps is measured in
# normalised units. Consider feeding un-normalised images together with
# atk.set_normalization_used(...). TODO confirm against the installed torchattacks version.
atk = FGSM(model, eps=0.01)

for images, labels in test_loader:
    labels = labels.squeeze(1)
    adv_images = atk(images, labels).cuda()

    # Only the attack needs gradients; score the adversarial batch without building
    # an autograd graph (saves GPU memory, the reason the batch size was reduced).
    with torch.no_grad():
        predicted = model(adv_images).argmax(dim=1)

    total += labels.size(0)
    correct += (predicted == labels.cuda()).sum().item()

print('FGSM Robust accuracy: %.2f %%' % (100 * float(correct) / total))

In [None]:
# Robust accuracy on the test set under a 10-step PGD attack with random start.
model.eval()

correct = 0
total = 0

# NOTE(review): as with FGSM, torchattacks assumes [0, 1] inputs and (by default)
# clamps to that range, but test_loader yields images normalised to [-1, 1]. The
# eps/alpha of 8/255 and 4/255 are therefore in normalised units, and negative
# pixels are clipped. TODO confirm, and consider atk.set_normalization_used(...).
atk = PGD(model, eps=8/255, alpha=4/255, steps=10, random_start=True)

for images, labels in test_loader:
    labels = labels.squeeze(1)
    adv_images = atk(images, labels).cuda()

    # Score the adversarial batch without autograd; only the attack needs gradients.
    with torch.no_grad():
        predicted = model(adv_images).argmax(dim=1)

    total += labels.size(0)
    correct += (predicted == labels.cuda()).sum().item()

print('PGD Robust accuracy: %.2f %%' % (100 * float(correct) / total))