In [2]:
from utils import ImageDataset

import numpy as np
import torch
import torch.nn as nn
from tqdm.notebook import tqdm

from torch.utils.data import DataLoader, random_split
from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomHorizontalFlip, RandomResizedCrop, RandomRotation, RandomVerticalFlip, ToDtype, ToImage

In [7]:
rng = torch.Generator().manual_seed(77)
BATCH_SIZE = 128
LEARNING_RATE = 0.1
WEIGHT_DECAY = 0.001
N_EPOCHS = 100
IMAGE_DIR = 'Data/downsampled_data'
METADATA_PATH = 'Data/metadata_BR00116991.csv'
CNN_OUT_PATH = 'cnn.pth'

In [4]:
class Classifier(nn.Module):
    def __init__(self, output_size: int):
        super().__init__()
        
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding='same')
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding='same')
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.mlp1 = nn.Linear(128 * 32 * 32, 256)
        self.mlp2 = nn.Linear(256, output_size)

        self.dropout = nn.Dropout(0.5)

        self.activation = nn.LeakyReLU()        
        
    def forward(self, x):
        x = self.pool(self.activation(self.conv1(x)))
        x = self.pool(self.activation(self.conv2(x)))
        x = self.pool(self.activation(self.conv3(x)))

        x = x.view(x.size(0), -1)

        x = self.dropout(self.activation(self.mlp1(x)))
        x = self.mlp2(x)
        
        return x


In [5]:
def train_loop(device, classifier, train_loader, optimizer, scheduler, loss_fn, n_epochs):
    losses, accs = [], []
    classifier.train()
    for epoch in range(1, 1+n_epochs):
        loss_val = 0
        acc_val = 0
        n = 0
        for (X, y) in tqdm(train_loader):
            m = y.size(0)
    
            X, y = X.to(device), y.to(device)
            yhat = classifier(X)
            pred = torch.argmax(yhat, dim=1)
            
            loss = loss_fn(yhat, y)
            acc = (pred == y).sum()
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            loss_val += loss.item() * m
            acc_val += acc.item()
            n += m
        
        print(f'Epoch {epoch}: Loss = {loss_val / n:.3f}; Accuracy = {acc_val / n:.3f}')
        losses.append(loss_val / n)
        accs.append(acc_val / n)
        
        scheduler.step()
    
    return np.array(losses), np.array(accs)
        

In [6]:
transforms = Compose([
    ToImage(),
    RandomResizedCrop((256, 256)),
    ToDtype(torch.float32, scale=True),
])
        
dataset = ImageDataset(image_dir=IMAGE_DIR, metadata_path=METADATA_PATH, transforms=transforms, convert_rgb=False)
train_dataset, test_dataset = random_split(dataset, [0.8, 0.2], generator=rng)
train_loader, test_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True), DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
device = torch.device('mps')
clf = Classifier(dataset.n_classes()).to(device)
optim = torch.optim.Adam(clf.parameters(), LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.1)
loss_fn = nn.CrossEntropyLoss()

losses, accs = train_loop(device, clf, train_loader, optim, scheduler, loss_fn, N_EPOCHS)

  0%|          | 0/18 [00:00<?, ?it/s]

Epoch 1: Loss = 123793.530; Accuracy = 0.040


  0%|          | 0/18 [00:00<?, ?it/s]

Epoch 2: Loss = 44168.996; Accuracy = 0.010


  0%|          | 0/18 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [9]:
torch.save(clf.state_dict(), CNN_OUT_PATH)