In [13]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell executes, so edits to imported project code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import os
import sys
from datetime import datetime
from PIL import Image
import cv2
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.classification import Accuracy, AUROC, F1Score
from torchvision.utils import make_grid
from torch.optim.lr_scheduler import ReduceLROnPlateau

sys.path.append('..')

from src.data.utils import get_image_paths
from src.data.dataset import HalfCircleBinaryClfDataset
from src.data.transforms import TRAIN_TRANSFORMS, TEST_TRANSFORMS
from src.modeling.model import HCCLF

In [15]:
# Use Apple's Metal (MPS) backend when this PyTorch build and machine support
# it, otherwise fall back to the CPU. `torch.backends.mps.is_available()` is
# the documented availability check and exists in every MPS-capable release,
# whereas `torch.mps.is_available()` is only present in recent versions.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device

device(type='mps')

In [16]:
# Fixed seed for both random splits below, so the train/val/test membership is
# the same after every kernel restart. Without it, a checkpoint trained under
# one split may later be evaluated on images it was trained on.
# NOTE(review): this only pins the split if get_image_paths returns the paths
# in a stable order -- confirm.
SPLIT_SEED = 42

# Dataset root. The default is the original author's location; set the
# HCCLF_DATA_DIR environment variable to run the notebook on another machine.
DATA_DIR = os.environ.get(
    "HCCLF_DATA_DIR",
    "/Users/alexandrepoupeau/Documents/work/code/perso/aitt-symbol-clf/data/",
)

images_filepaths = get_image_paths(directory=DATA_DIR)

# 20% of all images are held out for testing; 20% of the remainder (16% of the
# total) is used for validation, leaving 64% for training.
train_images_filepaths, test_images_filepaths = train_test_split(
    images_filepaths, test_size=0.2, random_state=SPLIT_SEED
)
train_images_filepaths, val_images_filepaths = train_test_split(
    train_images_filepaths, test_size=0.2, random_state=SPLIT_SEED
)

# Only the training split uses TRAIN_TRANSFORMS; validation and test share
# TEST_TRANSFORMS.
train_ds = HalfCircleBinaryClfDataset(images_filepaths=train_images_filepaths, transform=TRAIN_TRANSFORMS)
val_ds = HalfCircleBinaryClfDataset(images_filepaths=val_images_filepaths, transform=TEST_TRANSFORMS)
test_ds = HalfCircleBinaryClfDataset(images_filepaths=test_images_filepaths, transform=TEST_TRANSFORMS)

In [17]:
# Training batches are reshuffled every epoch so the optimizer does not see
# the same batches in the same order each time. Validation and test loaders
# keep a fixed order, which does not matter for averaged metrics.
# NOTE(review): prefetch_factor=64 lets each of the 4 workers queue up to 64
# batches ahead; lower it if host memory becomes a problem.
train_loader = DataLoader(
    train_ds,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
    prefetch_factor=64,
)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4, persistent_workers=True)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=2, persistent_workers=False)

In [18]:
# Sanity check: dtype of the first element (the model input) of test sample 0.
first_test_sample = test_ds[0]
first_test_sample[0].dtype

torch.float32

In [19]:
# Sanity check: second element (the label) of the test sample at index 7.
eighth_test_sample = test_ds[7]
eighth_test_sample[1]

0.0

In [22]:
# Set to False to train from randomly initialised weights instead of resuming
# from the checkpoint below.
load_pretrained = True
pretrained_model_checkpoint_filepath = "../models/keep/model_20250315_234944_47.pt"

# The model is built the same way in both cases; the flag only decides whether
# saved weights are loaded into it afterwards.
model = HCCLF(lr=1e-4).to(device)

if load_pretrained:
    print("Load pretrained model")
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on one backend (e.g. MPS) still loads on a machine
    # without it. weights_only=True restricts unpickling to tensors and plain
    # containers, which is all a state_dict needs.
    state_dict = torch.load(
        pretrained_model_checkpoint_filepath,
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)

Load pretrained model


In [23]:
# Training components: binary cross-entropy loss, the optimizer supplied by
# the model itself, and a scheduler that lowers the learning rate when the
# monitored value (minimised) stops improving.
loss_fn = torch.nn.BCELoss()
optimizer = model.configure_optimizers()
scheduler = ReduceLROnPlateau(optimizer, mode='min')

In [24]:
def train_one_epoch(epoch_index, tb_writer):
    """Train ``model`` for one pass over ``train_loader`` and return windowed metrics.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn``,
    ``train_loader`` and ``device``. Loss and metrics are averaged over
    windows of ``step_logs`` batches, and each window's loss is written to
    ``tb_writer`` under ``Loss/train``. A final, shorter window is reported as
    well, so an epoch with fewer than ``step_logs`` batches still yields
    values instead of failing on return, and trailing batches are not dropped.

    Args:
        epoch_index: Zero-based epoch number; combined with the batch index to
            build the TensorBoard step.
        tb_writer: Writer exposing ``add_scalar(tag, value, step)``.

    Returns:
        Tuple ``(loss, accuracy, f1, auroc)`` averaged over the last reporting
        window of the epoch. All values are 0 if ``train_loader`` yields no
        batches.
    """
    threshold = 0.5  # outputs at or above this value are counted as positive
    step_logs = 100  # number of batches per reporting window

    running_loss = 0.
    running_auroc = 0.
    running_accuracy = 0.
    running_f1 = 0.
    window_batches = 0

    # Initialise every returned value up front: previously only last_loss was,
    # so the return statement raised UnboundLocalError whenever no full window
    # had completed.
    last_loss = 0.
    last_accuracy = 0.
    last_f1 = 0.
    last_auroc = 0.

    num_batches = len(train_loader)

    # enumerate() gives the batch index needed for intra-epoch reporting.
    for i, data in enumerate(train_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to(device)
        # Adds a trailing dimension to the labels -- presumably to match
        # outputs of shape (batch, 1); confirm against the model.
        labels = labels.unsqueeze(1).to(torch.float32).to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, weight update.
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Metrics are for reporting only, so compute them on detached outputs
        # and outside of autograd.
        running_loss += loss.item()
        with torch.no_grad():
            scores = outputs.detach()
            preds = (scores >= threshold).float()
            running_auroc += AUROC("binary").to(device)(scores, labels)
            running_accuracy += Accuracy("binary").to(device)(preds, labels)
            running_f1 += F1Score("binary").to(device)(preds, labels)
        window_batches += 1

        # Report after every full window and once more for the trailing
        # partial window at the end of the epoch.
        if window_batches == step_logs or i == num_batches - 1:
            last_loss = running_loss / window_batches  # loss per batch
            last_accuracy = running_accuracy / window_batches
            last_f1 = running_f1 / window_batches
            last_auroc = running_auroc / window_batches

            # Global step: batches seen in earlier epochs plus this epoch's.
            tb_x = epoch_index * num_batches + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)

            running_loss = 0.
            running_accuracy = 0.
            running_f1 = 0.
            running_auroc = 0.
            window_batches = 0

    return last_loss, last_accuracy, last_f1, last_auroc

In [None]:
# Run identifier: the timestamp names both the TensorBoard run directory and
# the checkpoints saved below, so re-running this cell starts a new run.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('../runs/hcclf_trainer_{}'.format(timestamp))

# Log one grid of training images as a visual check of the input pipeline.
dataiter = iter(train_loader)
images, labels = next(dataiter)

grid = make_grid(images)
writer.add_image('images', grid)

# writer.add_graph(model=model, input_to_model=images)

epoch_number = 0

EPOCHS = 5

best_vloss = 1e6

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss, avg_acc, avg_f1, avg_auroc = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    running_vacc = 0.0
    running_vf1 = 0.
    running_vauroc = 0.
    # Count batches explicitly rather than relying on a loop variable that may
    # be stale or undefined when the loader yields nothing.
    num_val_batches = 0

    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vdata in val_loader:
            vinputs, vlabels = vdata
            vinputs = vinputs.to(device)
            vlabels = vlabels.unsqueeze(1).to(torch.float32).to(device)
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            auroc_value = AUROC("binary").to(device)(voutputs, vlabels)
            running_vauroc += auroc_value
            accuracy_value = Accuracy("binary").to(device)(voutputs, vlabels)
            running_vacc += accuracy_value
            f1_value = F1Score("binary").to(device)(voutputs, vlabels)
            running_vf1 += f1_value
            # .item() keeps the running loss a plain float.
            running_vloss += vloss.item()
            num_val_batches += 1

    if num_val_batches == 0:
        raise ValueError("val_loader yielded no batches; cannot compute validation metrics")

    avg_vloss = running_vloss / num_val_batches
    avg_vacc = running_vacc / num_val_batches
    avg_vf1 = running_vf1 / num_val_batches
    avg_vauroc = running_vauroc / num_val_batches

    # Drive the LR schedule with the epoch-average validation loss. Stepping
    # on `vloss` would use only the last batch, a noisy signal.
    scheduler.step(avg_vloss)

    # print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    # print('ACC train {} valid {}'.format(avg_acc, avg_vacc))
    # print('F1SCORE train {} valid {}'.format(avg_f1, avg_vf1))
    # print('AUROC train {} valid {}'.format(avg_auroc, avg_vauroc))

    # Log the per-batch averages for both training and validation
    writer.add_scalars('Loss/diff',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.add_scalars('Accuracy/diff',
                    { 'Training' : avg_acc, 'Validation' : avg_vacc },
                    epoch_number + 1)
    writer.add_scalars('F1Score/diff',
                    { 'Training' : avg_f1, 'Validation' : avg_vf1 },
                    epoch_number + 1)
    writer.add_scalars('AUROC/diff',
                    { 'Training' : avg_auroc, 'Validation' : avg_vauroc },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = '../models/model_{}_{}.pt'.format(timestamp, epoch_number)
        # Make sure the target directory exists so a finished epoch is not
        # lost to a missing folder.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1



EPOCH 1:
EPOCH 2:
EPOCH 3:
EPOCH 4:
EPOCH 5:


In [26]:
# Final evaluation on the held-out test loader: per-batch loss and metrics are
# summed and then divided by the number of batches.
test_totals = {"loss": 0.0, "acc": 0.0, "f1": 0., "auroc": 0.}

model.eval()

# No gradients are needed for evaluation.
with torch.no_grad():
    for batch_count, (vinputs, vlabels) in enumerate(test_loader, start=1):
        vinputs = vinputs.to(device)
        vlabels = vlabels.unsqueeze(1).to(torch.float32).to(device)
        voutputs = model(vinputs)

        test_totals["loss"] += loss_fn(voutputs, vlabels)
        test_totals["auroc"] += AUROC("binary").to(device)(voutputs, vlabels)
        test_totals["acc"] += Accuracy("binary").to(device)(voutputs, vlabels)
        test_totals["f1"] += F1Score("binary").to(device)(voutputs, vlabels)

# batch_count holds the number of batches processed (enumerate starts at 1).
avg_vloss = test_totals["loss"] / batch_count
avg_vacc = test_totals["acc"] / batch_count
avg_vf1 = test_totals["f1"] / batch_count
avg_vauroc = test_totals["auroc"] / batch_count

print('LOSS test {}'.format(avg_vloss))
print('ACC test {}'.format(avg_vacc))
print('F1SCORE test {}'.format(avg_vf1))
print('AUROC test {}'.format(avg_vauroc))

LOSS test 0.008875399827957153
ACC test 0.9986076951026917
F1SCORE test 0.9967607259750366
AUROC test 0.999913215637207
