In [1]:

from models.resnet import Resnet50Model
from datetime import datetime

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2

from datasets.LabeledImageDataset import LabeledImageDataset, default_image_transform
from train import train_classifier
from utils import reduce_dataset, split_dataset, oversample_dataset

# --- Runtime configuration ---------------------------------------------------
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Seed torch's global RNG so that DataLoader shuffling, the random torchvision
# transforms below and the model's weight initialisation are repeatable across
# "Restart & Run All".
# NOTE(review): the helpers imported from `utils` may draw from other RNGs
# (e.g. `random` / numpy) — confirm and seed those as well if so.
SEED = 42
torch.manual_seed(SEED)

batch_size = 128

# --- Data --------------------------------------------------------------------
dataset = LabeledImageDataset("data/TrueDataset", extension=".png")
# discard_ratio=0.0 presumably keeps every sample — TODO confirm in utils.reduce_dataset.
dataset = reduce_dataset(dataset, discard_ratio=0.0)
train_dataset, test_dataset = split_dataset(dataset, train_ratio=0.7)

# Random augmentations handed to `oversample_dataset`; `default_image_transform`
# runs last. Only the training split is passed through the oversampler, so the
# test split is left exactly as `split_dataset` returned it.
train_augmentation = v2.Compose([
    v2.RandomHorizontalFlip(),
    v2.RandomVerticalFlip(),
    v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    v2.RandomRotation(30),
    v2.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.8, 1.2)),
    v2.RandomPerspective(distortion_scale=0.5, p=0.5),
    v2.GaussianBlur(kernel_size=3),
    v2.RandomErasing(p=0.5),
    default_image_transform,
])
train_dataset = oversample_dataset(train_dataset, train_augmentation)

# --- Loaders -----------------------------------------------------------------
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps the test
# batches identical from one evaluation pass to the next.
test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False)

# --- Model -------------------------------------------------------------------
# A fresh model is built here. To resume from a saved checkpoint instead, load
# it at this point (a previous run used "resnet-50-epoch-1.pickle"); only
# unpickle files you trust.
model = Resnet50Model(hidden_layers=2, units_per_layer=2048, dropout=0.4)

print(f"Dataset: {len(train_dataset):,} training, {len(test_dataset):,} testing")


Device: cuda:0
Dataset: 48,017 training, 10,294 testing


In [None]:
# Log the wall-clock start time so the duration of the run can be read off later.
started_at = datetime.now().isoformat()
print(f"Training starts {started_at}")

# Hyper-parameters forwarded verbatim to `train_classifier`.
# NOTE(review): `eval_every=10` and `checkpoint_every=None` presumably mean
# "evaluate every 10 epochs" and "never checkpoint" — confirm in train.py.
training_options = {
    "learning_rate": 0.0001,
    "max_epochs": 50,
    "checkpoint_every": None,
    "eval_every": 10,
}

# Move the model to the selected device before handing it to the trainer.
model = model.to(device)
model, model_metrics = train_classifier(
    model,
    train_loader,
    test_loader,
    device,
    **training_options,
)

Training starts 2024-09-13T12:52:44.495130


Epoch 1 training:   0%|          | 0/376 [00:00<?, ?it/s]