In [13]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid

import matplotlib.pyplot as plt

from mini_imagenet_dataset import MiniImageNetDataset
from tools import getDataset, print_class_distribution

import numpy as np

In [14]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    print("MPS device not found.")
print('Device:', device)

Device: mps


In [15]:
batch_size = 128
num_workers = 4
learning_rate = 0.001
num_epochs = 20
image_size = 84
num_classes = 60

root_dir = os.path.join(os.getcwd(), 'datasets/miniImageNet')
dataset, label_mapping = getDataset(path=root_dir, num_classes=num_classes)


train_transforms = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.1, hue=0.1),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

transforms = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
       
       
train_dataset = MiniImageNetDataset(dataset=dataset, path=root_dir, phase='train', transform=train_transforms)
val_dataset = MiniImageNetDataset(dataset=dataset, path=root_dir, phase='val', transform=transforms)
test_dataset = MiniImageNetDataset(dataset=dataset, path=root_dir, phase='test', transform=transforms)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
validation_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [16]:
def eval(net, data_loader, criterion=nn.CrossEntropyLoss()):
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()
    net.eval()
    correct = 0.0
    num_images = 0.0
    loss = 0.0
    for i_batch, (images, labels) in enumerate(data_loader):
        images, labels = images.to(device), labels.to(device)
        outs = net(images)
        loss += criterion(outs, labels).item()
        _, predicted = torch.max(outs.data, 1)
        correct += (predicted == labels).sum().item()
        num_images += len(labels)
        print('testing/evaluating -> batch: %d correct: %d numb images: %d' % (i_batch, correct, num_images) + '\r', end='')
    acc = correct / num_images
    loss /= len(data_loader)
    return acc, loss


# training function
def train(net, train_loader, valid_loader):

    criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.SGD(params= net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.0001, betas=(0.5, 0.999))
    scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()

    training_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        net.train()
        correct = 0.0  # used to accumulate number of correctly recognized images
        num_images = 0.0  # used to accumulate number of images
        total_loss = 0.0

        for i_batch, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            output_train = net(images)
            loss = criterion(output_train, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            predicts = output_train.argmax(dim=1)
            correct += predicts.eq(labels).sum().item()
            num_images += len(labels)
            total_loss += loss.item()

            print('training -> epoch: %d, batch: %d, loss: %f' % (epoch, i_batch, loss.item()) + '\r', end='')

        print()
        acc = correct / num_images
        acc_eval, val_loss = eval(net, valid_loader)
        average_loss = total_loss / len(train_loader)
        training_losses.append(average_loss)
        val_losses.append(val_loss)
        print('\nepoch: %d, lr: %f, accuracy: %f, avg. loss: %f, valid accuracy: %f valid loss: %f\n' % (epoch, optimizer.param_groups[0]['lr'], acc, average_loss, acc_eval, val_loss))

        scheduler.step()

    return net, training_losses, val_losses

In [17]:
from models.resnet18 import ResNet18

print(f"Hyperparameters:")
print(f"Batch Size: {batch_size}")
print(f"Learning Rate: {learning_rate}")
print(f"Number of Epochs: {num_epochs}")
print(f"Number of Workers: {num_workers}\n")

# print_class_distribution(train_dataset, "Training", label_mapping)
# print_class_distribution(val_dataset, "Validation", label_mapping)
# print_class_distribution(test_dataset, "Testing", label_mapping)

# model = ResNet18(num_classes=num_classes).to(device)
model, training_losses, val_losses = train(net=model, train_loader=train_loader, valid_loader=validation_loader)

acc_test, test_loss = eval(model, test_loader)
print('\naccuracy on testing data: %f' % acc_test)

plt.plot(training_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

Hyperparameters:
Batch Size: 128
Learning Rate: 0.001
Number of Epochs: 20
Number of Workers: 4



RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same

In [None]:
torch.save(model, os.path.join(os.getcwd(), 'pretrained/resnet18_model_full_2.pth'))

In [8]:
model.eval()  

for i_batch, (images, labels) in enumerate(test_loader):
    images, labels = images.to(device), labels.to(device)
    outs = model(images)
    _, predicted = torch.max(outs.data, 1)
    print(predicted)
    print(labels)
    _, predicted = torch.max(outs.data, 1)
    correct = (predicted == labels).sum().item()
    print(correct, len(labels))
    break

tensor([11, 24, 30, 17, 23, 39, 26,  4, 29, 40, 37,  0, 34, 18, 36, 55,  0, 27,
        18, 42, 48, 35, 41, 43, 31, 24, 19, 28,  6, 20, 47, 17, 50, 38, 41, 47,
        56, 56, 10,  2, 22, 51, 56, 58, 40,  0, 13, 10,  1, 34, 40, 14, 15, 46,
        54, 36, 57, 21,  0, 38, 31, 47,  8,  1, 30, 15, 15, 35, 10, 27, 25, 34,
        18, 29, 33, 50,  1, 36, 12, 57, 20, 33, 32,  3, 45, 34, 50, 10, 25,  9,
        55, 38, 25,  4, 40, 50, 12, 10, 17, 54, 54, 25, 13, 26, 11,  1, 46, 38,
         9, 22,  2, 28, 46, 27, 40, 52, 44, 10, 12, 27, 50, 14,  5, 23, 47, 26,
        43,  9], device='mps:0')
tensor([23, 26, 16, 32, 23, 41, 40,  4, 29, 40, 37, 46, 18, 47,  6, 55, 49, 20,
         0, 12, 48, 35, 41, 57, 31, 24, 19, 28,  6, 28, 48,  9, 34, 19, 41, 47,
        56, 22, 48, 43, 55, 22, 56, 58, 49, 20, 13, 10,  1, 50,  0, 14, 25, 53,
        20,  6, 57, 41, 34, 50, 31,  0,  8, 14, 44, 15,  1, 46, 10, 49, 25, 22,
         3, 29, 20, 15,  1, 47, 12, 57, 32, 33, 32,  3, 15, 25, 32, 51, 20,  9,
       