In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import os
from torchvision import datasets, models, transforms

# Pick the compute device: Apple-silicon MPS if available, then CUDA, else CPU.
# `torch.has_mps` is deprecated; `torch.backends.mps.is_available()` is the
# supported check (the getattr guard keeps very old torch builds working).
# The CUDA device type string is "cuda" -- "gpu" is not a valid torch device
# and `tensor.to("gpu")` would raise on CUDA machines.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(device)

mps


In [2]:
# Preprocessing shared by the train and validation splits.
# The backbone is an ImageNet-pretrained ResNet-50, so inputs are normalized
# with the ImageNet per-channel statistics those weights were trained with,
# rather than a generic 0.5/0.5 rescaling.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize the image to 224x224 pixels
    transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),  # per-channel standardization
])

In [3]:
# Build one ImageFolder dataset per split from ../data/<split>/<class>/...,
# all using the same preprocessing `transform`.
data_dir = '../data'
splits = ['train', 'valid']
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), transform)
    for split in splits
}

# Shuffle only the training split: evaluation order does not change accuracy,
# and a fixed validation order keeps evaluation runs reproducible.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=32,
        shuffle=(split == 'train'),
        num_workers=4,
    )
    for split in splits
}

In [7]:
# Load a ResNet-50 backbone with ImageNet-pretrained weights.
# `pretrained=True` is deprecated since torchvision 0.13; the explicit weights
# enum below selects the same IMAGENET1K_V1 weights the legacy flag loaded.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

In [8]:
# Replace the pretrained classifier head with a fresh linear layer that has one
# output per class folder found in the training split.
num_classes = len(image_datasets['train'].classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move the whole network (including the new head) to the selected device.
# Without this the training/eval loops send `device` tensors into a CPU model
# and fail; it is done before the optimizer is built, as PyTorch recommends.
model = model.to(device)

In [9]:
# Objective and optimizer for fine-tuning: cross-entropy over the logits from
# `model`, minimized with Adam across every parameter of the network.
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Fine-tune the network: one pass over the training split per epoch, reporting
# the mean per-sample loss over the whole epoch (not just the last batch).
num_epochs = 10
model.train()  # ensure train-mode layers (BatchNorm batch stats) even after an eval run
for epoch in range(num_epochs):
    running_loss = 0.0
    samples_seen = 0
    for inputs, labels in dataloaders['train']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss defaults to reduction='mean', so weight by batch size
        # to get an exact per-sample average even when the last batch is short.
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # An empty training loader yields NaN instead of a NameError on `loss`.
    epoch_loss = running_loss / samples_seen if samples_seen else float('nan')
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, epoch_loss))


In [None]:
# Persist only the learned parameters (the state_dict), not the pickled module.
checkpoint_path = 'model.ckpt'
torch.save(obj=model.state_dict(), f=checkpoint_path)

In [None]:
# Measure top-1 accuracy on the validation split.
model.eval()  # BatchNorm must use running statistics, not per-batch statistics
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in dataloaders['valid']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)  # class index of the highest logit per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# An empty validation split yields NaN instead of a ZeroDivisionError.
accuracy = correct / total if total else float('nan')
print('Accuracy on the validation set: {:.4f}%'.format(100 * accuracy))
