In [None]:
import os
import sys

# Make the sibling "rotnet" repo importable: starting from the current working
# directory, go up one level, then down into Repos/rotnet (presumably where the
# local modules `moment_kernels` and `benchmark_models` live -- verify).
path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Repos', 'rotnet'))
# Guard so re-running this cell in the same kernel does not pile up duplicate
# sys.path entries.
if path not in sys.path:
    sys.path.append(path)

import medmnist
from medmnist import INFO

import importlib
import moment_kernels as mk
importlib.reload(mk)

import torch
import torch.nn as tnn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms

import e2cnn.nn as enn
import e2cnn.gspaces as gspaces

from benchmark_models import *

## **Hyperparameters**

In [None]:
# Training hyperparameters -- everything a reader might want to tune.
EPOCHS = 20
BATCH_SIZE = 128
lr = 0.001            # learning rate (lower-case name kept as-is; later cells may reference it)
WEIGHT_DECAY = 0.0001
SEED = 42             # seed for torch's global RNG

# Seed torch's global RNG so weight initialization, random augmentation and
# DataLoader shuffling are reproducible across Restart & Run All.
torch.manual_seed(SEED)

## **Dataset**

In [None]:
# Pick the MedMNIST dataset and pull its metadata from the INFO registry.
data_flag = "pathmnist"
download = True

info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
n_classes = len(info['label'])                       # one entry per class label
DataClass = getattr(medmnist, info['python_class'])  # dataset class named in the metadata

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## **Transforms**

In [None]:
# Evaluation pipeline (val/test): image -> float tensor -> normalized with
# mean 0.5 / std 0.5 (broadcast over all channels).
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Training pipeline: the same conversion, preceded by a random horizontal flip
# as augmentation. Normalization is per-pixel, so flipping before or after it
# yields the same tensor.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

## **Splitting and Shuffling**

In [None]:
# Load the train/val/test splits. Train uses the augmenting pipeline;
# val/test use the deterministic one.
train_dataset = DataClass(split="train", transform=train_transforms, download=download)
valid_dataset = DataClass(split="val", transform=test_transforms, download=download)
test_dataset = DataClass(split="test", transform=test_transforms, download=download)

# Untransformed copy of the test split, set aside for visualization as PIL images.
vis_dataset = DataClass(split="test", download=download)

# Shuffle only the training data so every epoch sees batches in a new order;
# evaluation order does not affect metrics and is kept deterministic.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = data.DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## **Model**

In [None]:
# String device name (the dataset cell also defines `device`; this cell keeps
# re-assigning it so later cells see the same value as before).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Architecture settings shared by every benchmark model. Input channels and the
# number of output classes come from the dataset metadata rather than being
# hard-coded: a fixed 5-way head does not match pathmnist's label set
# (n_classes = len(info['label'])).
model_kwargs = dict(
    img_channels=n_channels,
    n_classes=n_classes,
    kernel_size=3,
    padding=1,
    num_layers=4,
)

# Alternative benchmark models -- uncomment one to compare against nn5.
# nn0 = VanillaCNN(n0=64, **model_kwargs).to(device)
# nn1 = TrivialECNN(n0=16, **model_kwargs).to(device)
# nn2 = TrivialIrrepECNN(n0=16, **model_kwargs).to(device)
# nn3 = RegularECNN(n0=16, **model_kwargs).to(device)
# nn4 = TrivialMoment(n0=16, **model_kwargs).to(device)
nn5 = TrivialIrrepMoment(n0=16, **model_kwargs).to(device)

## **Smoke Test on a Random Input**

In [None]:
def test_model(model: torch.nn.Module):
    model.eval()

    x = torch.randn(1, 3, 28, 28).to(device)
    with torch.no_grad():
        y = model(x)
        print(y.shape)
        print(y)
    return y

# Smoke test: push one random 3-channel 28x28 input through nn5 and print the output.
output = test_model(nn5)

## **Training**