In [1]:


import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2

from datasets.CSVStreamDataset import CSVStreamDataset
from datasets.LabeledImageDataset import LabeledImageDataset
from models.resnet import Resnet18Model
from utils import oversample_dataset
from utils import reduce_dataset, split_dataset, undersample_dataset

# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

# Mini-batch size; reused below when the DataLoaders are created.
batch_size = 128

# Read the CSV-backed dataset and pass it through reduce_dataset.
# With discard_ratio=0.0 this presumably keeps every sample -- confirm in utils.
dataset = reduce_dataset(CSVStreamDataset("pretrained_outputs.csv"), discard_ratio=0.0)

# NOTE(review): no random seed is set in this cell, so the split and the
# resampling below may differ between runs unless the helpers seed themselves.
train_dataset, test_dataset = split_dataset(dataset, train_ratio=0.7)

# Only the training split is resampled; test_dataset stays exactly as split.
# Step 1: undersample with target_size=2000 (presumably a per-class cap -- TODO confirm).
# An earlier variant of this cell stopped here and used undersampling only.
train_dataset = undersample_dataset(train_dataset, target_size=2000)

# Step 2: oversample with augment_Size=1800, handing the helper the augmentation
# pipeline below (how the helper applies it is defined in utils -- TODO confirm).
# Currently disabled augmentations, kept for reference:
#   RandomResizedCrop(size=(256, 256), scale=(0.8, 1.0))
#   ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
#   GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
train_dataset = oversample_dataset(
    train_dataset,
    augment_Size=1800,
    transforms=v2.Compose([
        v2.ToImage(),
        # Geometric augmentations: random flips on both axes, rotation within +/-30 degrees.
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
        v2.RandomRotation(degrees=30),
        # Convert to float32 (scale=True rescales the value range) before normalizing.
        v2.ToDtype(torch.float32, scale=True),
        # Per-channel mean/std; these are the commonly used ImageNet statistics.
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

# Training batches are drawn in a fresh random order on every pass over the data.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Evaluation gains nothing from shuffling: metrics aggregated over the whole
# test split do not depend on order. A fixed order keeps per-batch results and
# any inspected predictions comparable between runs, so shuffling is off here.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Build the classifier. The keyword arguments are forwarded as-is to
# Resnet18Model; their exact meaning is defined in models.resnet.
# NOTE(review): the model is not moved to `device` in this cell -- confirm that
# a later cell (or the training routine) calls model.to(device).
model = Resnet18Model(
    hidden_layers=2,
    units_per_layer=2048,
    dropout=0.4,
)

# Report the split sizes after resampling, with thousands separators.
print(f"Dataset: {len(train_dataset):,} training, {len(test_dataset):,} testing")


Device: cuda:0
Dataset: 48,017 training, 10,294 testing
