In [2]:
import os
import sys

# Build the absolute path to the rotnet checkout: start at the current working
# directory, move up one level, then descend into Repos/rotnet.
path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Repos', 'rotnet'))
# Make the checkout importable. The membership check keeps this cell idempotent:
# re-running it no longer piles up duplicate entries in sys.path.
if path not in sys.path:
    sys.path.append(path)

import medmnist
from medmnist import INFO

import importlib
import moment_kernels as mk
importlib.reload(mk)

import torch
import torch.nn as tnn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms

import e2cnn.nn as enn
import e2cnn.gspaces as gspaces

# located in /rotnet/benchmark/benchmark_models.py
from benchmark.benchmark_models import *

## **Hyperparameters**

In [3]:
# training configuration
EPOCHS = 20        # number of passes over the training split
BATCH_SIZE = 128   # samples per mini-batch, shared by every DataLoader
lr = 1e-3          # optimizer learning rate

## **Dataset**

In [4]:
# which MedMNIST dataset to use, and the download flag handed to the dataset class
data_flag = "dermamnist"
download = True

# metadata entry for the selected dataset
info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
# one class per entry in the label mapping
n_classes = len(info['label'])
# dataset class looked up on the medmnist package by the name stored in the metadata
DataClass = getattr(medmnist, info['python_class'])

# use the GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## **Transforms**

In [5]:
# Training pipeline: image -> tensor, then (x - 0.5) / 0.5 normalization.
# No augmentation is applied; extra transforms would be added to this list.
train_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean = [0.5], std = [0.5])]
)

# Evaluation pipeline, kept as its own object so that later changes to the
# training pipeline do not affect validation/test preprocessing.
test_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean = [0.5], std = [0.5])]
)

## **Splitting and Shuffling**

In [6]:
# One dataset object per split; train uses the training transforms, while
# validation and test share the evaluation transforms.
train_dataset = DataClass(split = "train", transform = train_transforms, download = download)
valid_dataset = DataClass(split = "val", transform = test_transforms, download = download)
test_dataset = DataClass(split = "test", transform = test_transforms, download = download)

# Only the training loader shuffles, so each epoch visits the samples in a new
# order. Validation and test keep a fixed order since shuffling does not change
# their metrics and a stable order keeps per-sample outputs comparable.
train_loader = data.DataLoader(dataset = train_dataset, batch_size = BATCH_SIZE, shuffle = True)
valid_loader = data.DataLoader(dataset = valid_dataset, batch_size = BATCH_SIZE, shuffle = False)
test_loader = data.DataLoader(dataset = test_dataset, batch_size = BATCH_SIZE, shuffle = False)

Using downloaded and verified file: /ifshome/jliem/.medmnist/dermamnist.npz
Using downloaded and verified file: /ifshome/jliem/.medmnist/dermamnist.npz
Using downloaded and verified file: /ifshome/jliem/.medmnist/dermamnist.npz


## **Model**

In [7]:
# device name as a plain string: 'cuda' when a GPU is usable, otherwise 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Earlier instantiations with n0 = 32 for most models, kept for reference only;
# the Training section below builds nn0..nn5 with other widths.
# nn0 = VanillaCNN(img_channels = 3, n0 = 32, n_classes = n_classes, kernel_size = 3, padding = 1, num_layers = 5).to(device)
# nn1 = TrivialECNN(img_channels = 3, n0 = 32, n_classes = n_classes, kernel_size = 3, padding = 1, num_layers = 5).to(device)
# nn2 = TrivialIrrepECNN(img_channels = 3, n0 = 32, n_classes = n_classes, kernel_size = 3, padding = 1, num_layers = 5).to(device)
# nn3 = RegularECNN(img_channels = 3, n0 = 32, n_classes = n_classes, kernel_size = 3, padding = 1, num_layers = 5).to(device)
# nn4 = TrivialMoment(img_channels = 3, n0 = 32, n_classes = n_classes, kernel_size = 3, padding = 1, num_layers = 5).to(device)
# nn5 = TrivialIrrepMoment(img_channels = 3, n0 = 16, n_classes = n_classes, kernel_size = 3, padding = 1, num_layers = 5).to(device)

## **Random Testing**

In [8]:
def test_model(model: torch.nn.Module):
    model.eval()

    x = torch.randn(1, 3, 28, 28).to(device)
    with torch.no_grad():
        y = model(x)
        print(y.shape)
        print(y)
    return y

# output = test_model(nn5)

torch.Size([1, 7])
tensor([[-0.7257,  1.4012,  0.6860,  0.3275, -1.3211, -0.3494, -0.8620]])


## **Training**

In [9]:
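# Experiment plan (DermaMNIST data):
# - Compare our 6 models.
# - Optimizer: Adam (lr 1e-4), with no weight decay and no data augmentation.
# - Train each model for 100 epochs.
# - After every epoch, save the AUC and accuracy on the validation set.
# - At the end, report the AUC and accuracy of the final network.
# - Also report the AUC and accuracy of the best version of the network:
#   the best-accuracy version and the best-AUC version.
# - Each network therefore gets 3 scores (presumably final, best-accuracy, best-AUC).
# - Time each epoch and save the timing.
# - Find an approximate n_channels for the vanilla model, then approximately
#   match its number of parameters in all other models.
# NOTE(review): the hyperparameters cell sets a learning rate of 1e-3 and 20
# epochs, which differ from the plan above — confirm which values are intended.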

In [10]:
# vanilla CNN baseline (n0 = 32)
nn0 = VanillaCNN(
    img_channels = 3,
    n0 = 32,
    n_classes = n_classes,
    kernel_size = 3,
    padding = 1,
    num_layers = 5,
).to(device)

def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Parameters whose ``requires_grad`` flag is False are left out of the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# show how many trainable parameters the vanilla baseline has
print(count_parameters(nn0))

1574151


In [19]:
# TrivialMoment with n0 = 55; per the recorded output below this gives 1554527
# trainable parameters, close to the 1574151 recorded for the VanillaCNN baseline
nn4 = TrivialMoment(
    img_channels = 3,
    n0 = 55,
    n_classes = n_classes,
    kernel_size = 3,
    padding = 1,
    num_layers = 5,
).to(device)

print(count_parameters(nn4))

1554527


In [22]:
# TrivialIrrepMoment with n0 = 59; per the recorded output below this gives 1565621
# trainable parameters, close to the 1574151 recorded for the VanillaCNN baseline
nn5 = TrivialIrrepMoment(
    img_channels = 3,
    n0 = 59,
    n_classes = n_classes,
    kernel_size = 3,
    padding = 1,
    num_layers = 5,
).to(device)

print(count_parameters(nn5))

1565621


In [29]:
# TrivialECNN with n0 = 67; per the recorded output below this gives 1540404
# trainable parameters, close to the 1574151 recorded for the VanillaCNN baseline
nn1 = TrivialECNN(
    img_channels = 3,
    n0 = 67,
    n_classes = n_classes,
    kernel_size = 3,
    padding = 1,
    num_layers = 5,
).to(device)

print(count_parameters(nn1))

1540404


In [34]:
# TrivialIrrepECNN with n0 = 62; per the recorded output below this gives 1562221
# trainable parameters, close to the 1574151 recorded for the VanillaCNN baseline
nn2 = TrivialIrrepECNN(
    img_channels = 3,
    n0 = 62,
    n_classes = n_classes,
    kernel_size = 3,
    padding = 1,
    num_layers = 5,
).to(device)

print(count_parameters(nn2))

1562221


In [40]:
# RegularECNN with n0 = 29; per the recorded output below this gives 1500090
# trainable parameters, roughly in line with the 1574151 recorded for the VanillaCNN baseline
nn3 = RegularECNN(
    img_channels = 3,
    n0 = 29,
    n_classes = n_classes,
    kernel_size = 3,
    padding = 1,
    num_layers = 5,
).to(device)

print(count_parameters(nn3))

1500090
