In [43]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import make_grid

import matplotlib.pyplot as plt

from mini_imagenet_dataset import MiniImageNetDataset
from tools import getDataset, print_class_distribution

import numpy as np
from sklearn.model_selection import train_test_split

In [44]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    print("MPS device not found.")
print('Device:', device)

Device: mps


In [45]:
batch_size = 128
num_workers = 4
learning_rate = 0.0001 # 0.00005 - the best one so far
num_epochs = 10
image_size = 84
num_classes = 50

root_dir = os.path.join(os.getcwd(), 'datasets/miniImageNet')
dataset, label_mapping = getDataset(path=root_dir, num_classes=num_classes)


train_transforms = transforms.Compose(
        [
            transforms.RandomCrop(84, padding=8),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=5),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )

transforms = transforms.Compose(
        [
            transforms.RandomCrop(84, padding=8),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
       
train_dataset, temp_dataset = train_test_split(dataset, test_size=0.3, random_state=42)
val_dataset, test_dataset = train_test_split(temp_dataset, test_size=0.5, random_state=42)
      
train_dataset = MiniImageNetDataset(dataset=train_dataset, path=root_dir, phase='train', transform=train_transforms)
val_dataset = MiniImageNetDataset(dataset=val_dataset, path=root_dir, phase='val', transform=transforms)
test_dataset = MiniImageNetDataset(dataset=test_dataset, path=root_dir, phase='test', transform=transforms)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
validation_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [46]:
def eval(net, data_loader, criterion=nn.CrossEntropyLoss()):
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()
    net.eval()
    correct = 0.0
    num_images = 0.0
    loss = 0.0
    for i_batch, (images, labels) in enumerate(data_loader):
        images, labels = images.to(device), labels.to(device)
        outs = net(images)
        loss += criterion(outs, labels).item()
        _, predicted = torch.max(outs.data, 1)
        correct += (predicted == labels).sum().item()
        num_images += len(labels)
        print('testing/evaluating -> batch: %d correct: %d numb images: %d' % (i_batch, correct, num_images) + '\r', end='')
    acc = correct / num_images
    loss /= len(data_loader)
    return acc, loss


# training function
def train(net, train_loader, valid_loader):

    criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.SGD(params= net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.0001, betas=(0.5, 0.999))
    scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()

    training_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        net.train()
        correct = 0.0  # used to accumulate number of correctly recognized images
        num_images = 0.0  # used to accumulate number of images
        total_loss = 0.0

        for i_batch, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            output_train = net(images)
            loss = criterion(output_train, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            predicts = output_train.argmax(dim=1)
            correct += predicts.eq(labels).sum().item()
            num_images += len(labels)
            total_loss += loss.item()

            print('training -> epoch: %d, batch: %d, loss: %f' % (epoch, i_batch, loss.item()) + '\r', end='')

        print()
        acc = correct / num_images
        acc_eval, val_loss = eval(net, valid_loader)
        average_loss = total_loss / len(train_loader)
        training_losses.append(average_loss)
        val_losses.append(val_loss)
        print('\nepoch: %d, lr: %f, accuracy: %f, avg. loss: %f, valid accuracy: %f valid loss: %f\n' % (epoch, optimizer.param_groups[0]['lr'], acc, average_loss, acc_eval, val_loss))

        scheduler.step()

    return net, training_losses, val_losses

In [47]:
from models.resnet18 import ResNet18

print(f"Hyperparameters:")
print(f"Batch Size: {batch_size}")
print(f"Learning Rate: {learning_rate}")
print(f"Number of Epochs: {num_epochs}")
print(f"Number of Workers: {num_workers}")
print(f"Number of Classes: {num_classes}\n")

# print_class_distribution(train_dataset, "Training", label_mapping)
# print_class_distribution(val_dataset, "Validation", label_mapping)
# print_class_distribution(test_dataset, "Testing", label_mapping)

model = ResNet18(num_classes=num_classes).to(device)
model, training_losses, val_losses = train(net=model, train_loader=train_loader, valid_loader=validation_loader)

acc_test, test_loss = eval(model, test_loader)
print('\naccuracy on testing data: %f' % acc_test)

plt.plot(training_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

Hyperparameters:
Batch Size: 128
Learning Rate: 0.0001
Number of Epochs: 10
Number of Workers: 4
Number of Classes: 50

training -> epoch: 0, batch: 160, loss: 1.336646
testing/evaluating -> batch: 34 correct: 2986 numb images: 4410
epoch: 0, lr: 0.000100, accuracy: 0.472935, avg. loss: 2.110046, valid accuracy: 0.677098 valid loss: 1.202568

training -> epoch: 1, batch: 160, loss: 1.153150
testing/evaluating -> batch: 34 correct: 3249 numb images: 4410
epoch: 1, lr: 0.000100, accuracy: 0.689504, avg. loss: 1.118639, valid accuracy: 0.736735 valid loss: 0.966275

training -> epoch: 2, batch: 160, loss: 1.002141
testing/evaluating -> batch: 34 correct: 3305 numb images: 4410
epoch: 2, lr: 0.000100, accuracy: 0.748639, avg. loss: 0.880443, valid accuracy: 0.749433 valid loss: 0.905919

training -> epoch: 3, batch: 160, loss: 0.695493
testing/evaluating -> batch: 34 correct: 3280 numb images: 4410
epoch: 3, lr: 0.000100, accuracy: 0.789553, avg. loss: 0.733809, valid accuracy: 0.743764 va

In [None]:
torch.save(model, os.path.join(os.getcwd(), 'pretrained/resnet18_model_75.pth'))

In [None]:
model.eval()  

for i_batch, (images, labels) in enumerate(test_loader):
    images, labels = images.to(device), labels.to(device)
    outs = model(images)
    _, predicted = torch.max(outs.data, 1)
    print(predicted)
    print(labels)
    _, predicted = torch.max(outs.data, 1)
    correct = (predicted == labels).sum().item()
    print(correct, len(labels))
    break

tensor([47, 11, 23, 10,  9,  3,  3, 30, 27, 48, 25, 23, 16, 10, 23, 39, 20, 10,
        31, 22, 18, 17, 25, 17, 19, 30, 42,  1, 30, 43, 45, 38, 45, 27, 25,  7,
         3, 21,  6, 48, 29, 47, 19, 16, 37, 29, 34, 24, 39,  0, 46, 43, 15, 30,
        36, 46, 31, 25, 32,  3,  6, 37, 16, 31,  5, 10, 41, 30, 21,  1, 35, 33,
        29, 27, 47, 23, 25, 12, 39,  8, 10, 19,  0, 39, 40, 13, 13, 20, 41, 22,
        25, 34, 46, 38, 33,  1, 16,  1, 38, 17, 17, 22,  1, 16, 13, 10,  9, 15,
        31, 29, 31, 12, 27, 16,  3, 12, 30, 11, 12, 40, 22, 16, 42, 22, 30,  5,
        18,  6], device='mps:0')
tensor([47, 11, 23, 10,  9, 12,  3, 30, 27, 48, 25, 23, 16, 10, 23, 28, 20, 10,
        31,  2, 34, 17, 25, 17, 19, 30, 42, 15, 30, 28, 11, 28, 45, 27, 25,  7,
         3, 21, 22, 34, 29, 47, 19, 16, 37,  7, 34, 24, 39, 47, 46, 43, 10, 30,
        32, 46, 31, 25, 32, 47,  6, 37, 39, 31,  5, 44, 41, 30, 22,  1, 35, 33,
        29, 27, 47, 23, 19, 36, 39,  8, 10, 19, 32, 44, 40, 13, 13, 28, 41, 22,
       

In [None]:
103/128

0.8046875