# Train Classifier

#### Import libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms.v2 as v2

import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter

# Fix the torch and NumPy RNG seeds so runs are repeatable (torch's global RNG
# drives DataLoader shuffling and random_split later in this notebook, and
# presumably the model's weight init).
torch.manual_seed(0)
np.random.seed(0)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

#### Tensorboard logging

In [None]:
import pathlib

# Each execution logs into its own ./logs/cls/runN directory: probe run1, run2, ...
# and take the first N whose directory does not exist yet.
runs_root = pathlib.Path('./logs/cls')
run_number = 0
while True:
    run_number += 1
    logdir = runs_root / f'run{run_number}'
    if not logdir.exists():
        break

logdir.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(logdir)

print(f'Logging to: {logdir}')

#### Hyperparameters

In [None]:
import yaml

# Optional YAML file whose keys override the defaults below; leave empty to
# train with the defaults as-is.
hparams_file = ''
# hparams_file = './hparams_cls.yaml'

default_hparams = {
    'image_size': [224, 224],
    'batch_size': 32,
    'num_epochs': 10,
    'lr': 1e-4,
}

# Start from the defaults and overlay the file, so a partial YAML file does not
# cause a KeyError further down (e.g. at hparams['lr']).
hparams = dict(default_hparams)
if hparams_file:
    with open(hparams_file, encoding='utf-8') as f:
        # safe_load returns None for an empty document; treat that as "no overrides".
        hparams.update(yaml.safe_load(f) or {})

# Record the effective hyperparameters alongside the run's TensorBoard logs.
writer.add_text('hparams', yaml.dump(hparams, sort_keys=False))

#### Prepare dataset

In [None]:
import os
from PIL import Image

class CustomDataset(torch.utils.data.Dataset):
    """Dataset of (positive, negative) image pairs read from ``root/pos`` and ``root/neg``.

    Item ``idx`` pairs the idx-th positive file with the idx-th negative file.
    Both file lists are sorted by path, so the pairing and ordering are
    deterministic across runs and filesystems (``os.scandir`` order is not).
    The dataset length is the size of the smaller folder, so every valid index
    exists in both lists; surplus files in the larger folder are ignored.

    Args:
        root: Directory containing ``pos`` and ``neg`` sub-directories. Only
            regular files directly inside them are used (sub-directories are skipped).
        transforms: Optional callable applied independently to each PIL image.
    """

    def __init__(self, root, transforms=None):
        self.transforms = transforms
        self.pos_files = self._list_files(os.path.join(root, 'pos'))
        self.neg_files = self._list_files(os.path.join(root, 'neg'))

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths of the regular files directly inside *directory*."""
        with os.scandir(directory) as it:
            return sorted(entry.path for entry in it if entry.is_file())

    @staticmethod
    def _load_image(path):
        """Open *path* and return an in-memory copy so the file handle is closed."""
        with Image.open(path) as img:
            # copy() forces the pixel data to load before the file is closed.
            return img.copy()

    def __len__(self):
        # __getitem__ indexes both lists with the same idx, so the usable length
        # is bounded by the shorter one (otherwise neg_files[idx] can IndexError).
        return min(len(self.pos_files), len(self.neg_files))

    def __getitem__(self, idx):
        pos_img = self._load_image(self.pos_files[idx])
        neg_img = self._load_image(self.neg_files[idx])

        if self.transforms:
            pos_img = self.transforms(pos_img)
            neg_img = self.transforms(neg_img)

        return (pos_img, neg_img)

In [None]:
image_size = hparams['image_size']
batch_size = hparams['batch_size']

# Deterministic preprocessing shared by train and val. Note that enabling the
# commented-out flip here would also augment the validation split.
transforms_list = [
    v2.ToImage(),
    # v2.RandomHorizontalFlip(),
    v2.Resize(image_size),
    v2.ToDtype(torch.float, scale=True),
]
transforms_composed = v2.Compose(transforms_list)

dataset = CustomDataset('./dataset/preprocessed/', transforms=transforms_composed)

# Split with a dedicated fixed-seed generator rather than the global torch RNG.
# Otherwise re-running this cell after other cells have consumed random numbers
# produces a different split, and an already-trained model could then be
# "validated" on images it was trained on. (The split permutes dataset indices,
# so it is only reproducible if the dataset's file order is stable too.)
split_generator = torch.Generator().manual_seed(0)
dataset_train, dataset_val = torch.utils.data.random_split(
    dataset, [0.8, 0.2], generator=split_generator
)

train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Peek at one training batch. When the batch holds at least 4 pairs, the grid's
# first row shows 4 positives and the second row the matching 4 negatives.
pos_batch, neg_batch = next(iter(train_loader))
pos, neg = pos_batch[:4], neg_batch[:4]
print(pos.shape)
print(neg.shape)

stacked = torch.cat((pos, neg), dim=0)
grid_img = torchvision.utils.make_grid(stacked, nrow=4)
plt.imshow(grid_img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.show()

#### Build the model

In [None]:
import cls_model

# One output unit: the training code labels positives 1 and negatives 0.
classifier = cls_model.Classifier(num_classes=1)
model = classifier.to(device)

from torchinfo import summary

# Layer-by-layer summary for a batch of 3-channel images at the training size.
summary_input_size = (batch_size, 3, *image_size)
print(summary(model, input_size=summary_input_size))

#### Training and evaluation

In [None]:
def evaluate(model, dataloader):
    """Compute the mean loss and accuracy of *model* over a loader of (pos, neg) pairs.

    Each batch from *dataloader* is unpacked as ``(pos, neg)``. Positives are
    labelled 1 and negatives 0, and both halves are scored in a single forward
    pass on the module-level ``device``. The model is put in eval mode under
    ``torch.no_grad()``, and its previous train/eval mode is restored afterwards,
    even if evaluation is interrupted.

    Args:
        model: Module whose ``loss_fn(y_pred, y_true)`` returns a scalar loss
            tensor.
        dataloader: Iterable of ``(pos, neg)`` batches.

    Returns:
        dict with ``'loss'`` (mean loss per scored sample, counting both
        positives and negatives) and ``'acc'`` (fraction of samples counted as
        correct).

    Raises:
        ValueError: if *dataloader* yields no samples.
    """
    total_loss = 0.0
    n_correct = 0
    count = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for pos, neg in tqdm(dataloader, leave=False):
                pos = pos.to(device)
                neg = neg.to(device)

                x = torch.cat((pos, neg), dim=0)
                y_true = torch.cat((
                    torch.ones(pos.shape[0], dtype=torch.float, device=device),
                    torch.zeros(neg.shape[0], dtype=torch.float, device=device),
                ), dim=0).unsqueeze(1)

                n = x.shape[0]
                y_pred = model(x)
                # Assumes loss_fn returns a per-sample mean (TODO confirm in
                # cls_model). Weighting by n makes the final division a true
                # per-sample mean even when the last batch is smaller.
                total_loss += model.loss_fn(y_pred, y_true).item() * n

                # NOTE(review): assumes y_pred are probabilities in [0, 1];
                # a prediction within 0.5 of its label counts as correct.
                # If the model emits logits this metric is wrong — confirm.
                n_correct += torch.isclose(y_pred, y_true, atol=0.5).sum().item()
                count += n
    finally:
        # Restore train/eval mode even on KeyboardInterrupt, so continued
        # training in the notebook does not silently run in eval mode.
        model.train(was_training)

    if count == 0:
        raise ValueError('evaluate() got a dataloader that yielded no samples')

    # Divide by the number of scored samples (2 per pair), not by
    # len(dataloader.dataset) (number of pairs); the latter doubled the loss.
    return {
        'loss': total_loss / count,
        'acc': n_correct / count,
    }

##### Training

In [None]:
num_epochs = hparams['num_epochs']
optimizer = torch.optim.Adam(model.parameters(), lr=hparams['lr'])

model.train()
step = 0  # global optimizer-step counter; x-axis for both train and val curves
best_val_loss = np.inf

for epoch in tqdm(range(num_epochs)):
    for pos, neg in tqdm(train_loader, leave=False):
        pos, neg = pos.to(device), neg.to(device)

        # Positives followed by negatives in one batch, with labels 1 / 0
        # arranged as an (N, 1) column.
        inputs = torch.cat((pos, neg), dim=0)
        targets = torch.cat((
            torch.ones(pos.shape[0], dtype=torch.float),
            torch.zeros(neg.shape[0], dtype=torch.float),
        ), dim=0).unsqueeze(1).to(device)

        loss = model.loss_fn(model(inputs), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step += 1
        writer.add_scalar('train/loss', loss.item(), step)

    # End of epoch: log validation metrics and checkpoint whenever val loss improves.
    stats_val = evaluate(model, val_loader)
    for name, value in stats_val.items():
        writer.add_scalar(f'val/{name}', value, step)

    if stats_val['loss'] < best_val_loss:
        best_val_loss = stats_val['loss']
        torch.save(model.state_dict(), logdir / 'best_model.pth')

# Keep the final weights as well, regardless of validation performance.
torch.save(model.state_dict(), logdir / 'last_model.pth')

#### Evaluate best checkpoint on validation set

In [None]:
# Reload the checkpoint with the lowest validation loss. weights_only=True
# restricts unpickling to tensors/containers, and map_location keeps the
# tensors on the current device.
best_state = torch.load(logdir / 'best_model.pth', map_location=device, weights_only=True)
model.load_state_dict(best_state)

# NOTE: val_loader is the same split that selected best_model.pth, so this is an
# optimistically biased validation score, not a held-out test score.
test_stats = evaluate(model, val_loader)
print(test_stats)