In [None]:
import copy
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models

In [None]:
# Pick the fastest available backend: CUDA GPU, then Apple-silicon GPU (MPS), then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

In [None]:
# Preprocessing pipelines keyed by split ('train' / 'val' / 'test').
# Every pipeline finishes with the standard ImageNet per-channel mean/std normalization.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _eval_pipeline():
    """Deterministic preprocessing: resize to 256, center-crop 224, to tensor, normalize."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


transform = {
    # Training adds randomness (random resized crop + horizontal flip) for augmentation.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]),
    # Validation and test use the same deterministic pipeline (built as separate instances).
    'val': _eval_pipeline(),
    'test': _eval_pipeline(),
}

In [None]:
# Load CIFAR-10, downloading it only if needed.
# With download=True, torchvision verifies the files under `root` and skips the download when
# they are already present. No separate existence check is needed, and the cell is safe to re-run.
# Both datasets must always be constructed: later cells use `trainset` and `testset`.
# NOTE: each root receives its own copy of the full CIFAR-10 archive. Pointing both at one
# directory would avoid the duplicate download, because the split is chosen by `train=`.
trainset = torchvision.datasets.CIFAR10(root="../../data/CIFAR-10/train/", train=True, download=True, transform=transform['train'])
testset = torchvision.datasets.CIFAR10(root="../../data/CIFAR-10/test/", train=False, download=True, transform=transform['test'])
len(trainset), len(testset)

In [None]:
# Hold out 10% of the official training split for validation. The fixed seed makes the
# split identical on every run.
VAL_FRACTION = 0.1
BATCH_SIZE = 32

total_size = len(trainset)
val_size = int(VAL_FRACTION * total_size)  # random_split needs integer lengths
train_size = total_size - val_size
generator = torch.Generator().manual_seed(42)

# Bind the subsets to new names instead of overwriting `trainset`. Re-running this cell then
# splits the full dataset again, rather than splitting an already-split subset.
train_subset, val_subset = torch.utils.data.random_split(
    trainset, [train_size, val_size], generator=generator
)

# A Subset fetches items through its parent dataset, so validation images would otherwise
# receive the random *training* augmentation.
# Fix: give validation a shallow copy of the dataset that uses the deterministic 'val'
# transform. Attributes such as the image array are shared, not duplicated.
# This relies on the dataset applying `self.transform` in __getitem__, as torchvision's
# CIFAR10/ImageFolder do.
val_base = copy.copy(trainset)
val_base.transform = transform['val']
val_subset = torch.utils.data.Subset(val_base, val_subset.indices)

trainloader = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# Evaluation does not depend on sample order, so validation/test are not shuffled.
valloader = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
# The test loader must wrap the held-out `testset`. Wrapping training data would report
# training-set accuracy as test accuracy.
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [None]:
# Transfer learning: an ImageNet-pretrained ResNet-50 whose final fully connected layer is
# replaced by a freshly initialized classifier with one output per CIFAR-10 class.
NUM_CLASSES = 10

# `pretrained=True` is deprecated since torchvision 0.13. IMAGENET1K_V1 is the exact set of
# weights it loaded.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

num_features = model.fc.in_features
model.fc = nn.Linear(num_features, NUM_CLASSES)

model = model.to(device)

criterion = nn.CrossEntropyLoss()

# All parameters, backbone included, are passed to the optimizer, so the whole network is fine-tuned.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Multiply the learning rate by 0.1 every 7 scheduler steps (train_model steps it once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=50, dataloaders=None):
    """Train ``model`` and print per-epoch loss and accuracy for the train and val splits.

    Each epoch runs a training pass and then an evaluation pass:
    - Training: gradients enabled, ``optimizer`` stepped per batch, ``scheduler`` stepped
      once at the end of the pass.
    - Validation: gradients disabled.

    Args:
        model: network to train. Batches are moved to the device of its first parameter.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch after the training pass.
        num_epochs: number of passes over the training data.
        dataloaders: optional ``{'train': loader, 'val': loader}`` mapping. When omitted,
            the notebook-level ``trainloader`` and ``valloader`` are used.

    Returns:
        The same ``model`` object, holding the weights from the final epoch.
    """
    if dataloaders is None:
        dataloaders = {'train': trainloader, 'val': valloader}
    # Inputs and labels must live on the same device as the model's weights.
    device = next(model.parameters()).device

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print('--' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()
            dataloader = dataloaders[phase]

            running_loss = 0.0
            running_correct = 0
            num_samples = 0

            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # Accumulate plain Python numbers. Summing the loss *tensor* would keep
                # extending the autograd graph across batches.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size  # undo the per-batch mean
                running_correct += (preds == labels).sum().item()
                num_samples += batch_size

            if is_train:
                scheduler.step()

            # Normalize by the number of samples, not batches. Dividing the correct-count by
            # len(dataloader) inflates accuracy by roughly the batch size.
            denom = max(num_samples, 1)  # guard against an empty loader
            epoch_loss = running_loss / denom
            epoch_acc = running_correct / denom

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    print("Train complete")
    return model


In [None]:
# Fine-tune for 30 epochs (overriding train_model's default of 50).
model_ft = train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    num_epochs=30,
)

In [None]:
# Persist only the learned parameters (the state_dict), not the pickled module object.
# To reload, build the same architecture and call load_state_dict on it.
model_save_path = './saved_model.pth'  # Change this path as needed
trained_state = model_ft.state_dict()
torch.save(trained_state, model_save_path)