In [None]:
import os
import time
import copy
import torch
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from PIL import Image


### Set parameters

In [None]:
# Dataset location. The original local path stays the default, but it can be
# overridden with the DATA_ROOT environment variable so the notebook is not
# tied to a single machine's drive layout.
data_root = os.environ.get("DATA_ROOT", r"E:/iCloudDrive/UoM/Y3/34212robotics/cw/dataset1")

# Hyperparameters
num_epochs = 15          # upper bound on the number of training epochs
batch_size = 32
learning_rate = 1e-4
train_ratio = 0.7        # share of the images used for training
val_ratio = 0.2          # share used for validation; the remainder is the test set
patience = 4             # epochs without validation-loss improvement before stopping

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed for reproducibility
torch.manual_seed(42)

# Output directory for figures and reports (plain string: nothing to interpolate)
save_dir = "results/1"
os.makedirs(save_dir, exist_ok=True)

### Data preprocessing & dataset splitting

In [None]:
# Training data augmentation: random crop to 224x224.
# (No horizontal flip is part of this pipeline.)
train_transform = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Validation and test transforms: deterministic center crop to 224x224
val_test_transform = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Load base dataset without transform for splitting
base_dataset = datasets.ImageFolder(root=data_root)
dataset_size = len(base_dataset)
train_size = int(train_ratio * dataset_size)
val_size = int(val_ratio * dataset_size)
test_size = dataset_size - train_size - val_size  # remainder, so the sizes always sum up

# Split dataset. The split uses its own seeded generator so that re-running
# this cell always yields the same partition; otherwise a re-run could move
# images that were trained on into the validation/test subsets.
train_subset, val_subset, test_subset = random_split(
    base_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Apply transforms
def apply_transform(subset, transform):
    """Return a view of ``subset`` whose samples are passed through ``transform``.

    All subsets produced by ``random_split`` share one underlying dataset
    object, so assigning ``subset.dataset.transform`` would change the
    transform of every split at once (the last assignment would win). The
    underlying dataset is therefore shallow-copied: the copy shares the sample
    list but carries its own ``transform`` attribute.

    Args:
        subset: A ``Subset`` exposing ``.dataset`` and ``.indices``.
        transform: Callable applied to each loaded image.

    Returns:
        A new ``torch.utils.data.Subset`` over the same indices. The input
        subset and its underlying dataset are left unmodified.
    """
    dataset_view = copy.copy(subset.dataset)
    dataset_view.transform = transform
    return torch.utils.data.Subset(dataset_view, subset.indices)

train_subset = apply_transform(train_subset, train_transform)
val_subset = apply_transform(val_subset, val_test_transform)
test_subset = apply_transform(test_subset, val_test_transform)

# Create DataLoaders (only the training data is shuffled)
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, num_workers=4)

# Dataset info
class_names = base_dataset.classes
num_classes = len(class_names)
print(f"Total images: {dataset_size}")
print(f"Train/Val/Test sizes: {train_size}/{val_size}/{test_size}")
print(f"Classes: {class_names}")

In [None]:
# # Load a raw sample image
# raw_img_path, raw_label = base_dataset.samples[571]
# raw_img = Image.open(raw_img_path).convert("RGB")

# # Apply transforms
# train_img_transformed = train_transform(raw_img)
# val_img_transformed = val_test_transform(raw_img)

# # Function to display a normalized tensor image (for visualization)
# def imshow(tensor, title):
#     img = tensor.clone().detach().cpu().numpy()
#     img = img.transpose(1, 2, 0)
#     mean = np.array([0.485, 0.456, 0.406])
#     std  = np.array([0.229, 0.224, 0.225])
#     img = std * img + mean  # Denormalize
#     img = np.clip(img, 0, 1)
#     plt.imshow(img)
#     plt.title(title)
#     plt.axis('off')

# # Plot original and transformed versions
# plt.figure(figsize=(12, 4))
# plt.subplot(1, 3, 1)
# plt.imshow(raw_img)
# plt.title("Original Image")
# plt.axis('off')

# plt.subplot(1, 3, 2)
# imshow(train_img_transformed, "Train Transform (RandomCrop)")

# plt.subplot(1, 3, 3)
# imshow(val_img_transformed, "Val/Test Transform (CenterCrop)")
# plt.tight_layout()
# plt.show()

### Build ResNet18 model

In [None]:
# ResNet18 initialised from the library's default weights; its final fully
# connected layer is replaced by one with `num_classes` outputs.
model = resnet18(weights=ResNet18_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=num_classes)
model = model.to(device)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer over every model parameter
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)
# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)


### Training

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50, patience=5):
    """Train ``model`` with early stopping on the validation loss.

    Each epoch runs one pass over ``train_loader`` (with gradient updates)
    followed by one pass over ``val_loader`` (no gradients). The weights from
    the epoch with the lowest validation loss are kept and loaded back into
    ``model`` before returning.

    Args:
        model: Network to train. Batches are moved to the device of its
            parameters, so the function does not rely on a global ``device``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Maximum number of epochs.
        patience: Training stops once the validation loss has failed to
            improve for this many consecutive epochs.

    Returns:
        ``(model, (train_loss_history, val_loss_history, train_acc_history,
        val_acc_history))`` where each history is a list of floats with one
        entry per completed epoch.

    Raises:
        ValueError: If a loader yields no samples, since no average loss or
            accuracy can be computed for that phase.
    """
    # Follow the model's own device instead of a module-level variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_val_loss = float('inf')
    early_stop_counter = 0

    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []

    loaders = {"train": train_loader, "val": val_loader}

    since = time.time()
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        for phase in ["train", "val"]:
            is_train = phase == "train"
            model.train(is_train)  # train mode for "train", eval mode for "val"
            dataloader = loaders[phase]

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0

            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Weight by the batch size so a smaller final batch is
                # averaged correctly.
                batch_count = inputs.size(0)
                running_loss += loss.item() * batch_count
                running_corrects += (preds == labels).sum().item()
                samples_seen += batch_count

            if samples_seen == 0:
                raise ValueError(f"The {phase} dataloader yielded no samples.")

            # Normalise by the samples actually processed; this stays correct
            # for loaders that drop or skip samples.
            epoch_loss = running_loss / samples_seen
            epoch_acc = running_corrects / samples_seen

            if is_train:
                train_loss_history.append(epoch_loss)
                train_acc_history.append(epoch_acc)
            else:
                val_loss_history.append(epoch_loss)
                val_acc_history.append(epoch_acc)

                if epoch_loss < best_val_loss:
                    # Model selection is driven by validation loss; best_acc
                    # is the accuracy measured at that same epoch.
                    best_val_loss = epoch_loss
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    early_stop_counter = 0
                else:
                    early_stop_counter += 1

            print(f"{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

        if early_stop_counter >= patience:
            print("Early stopping triggered!")
            break

    time_elapsed = time.time() - since
    print(f"\nTraining complete in {int(time_elapsed//60)}m {int(time_elapsed % 60)}s")
    print(f"Best val Acc: {best_acc:.4f}")

    # Restore the weights of the best validation epoch.
    model.load_state_dict(best_model_wts)
    return model, (train_loss_history, val_loss_history, train_acc_history, val_acc_history)


In [None]:
# Run training, then unpack the per-epoch curves for the plots further down
best_model, history = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=num_epochs,
    patience=patience,
)
(train_loss_history,
 val_loss_history,
 train_acc_history,
 val_acc_history) = history


### Evaluate on test set

In [None]:
def evaluate_model(model, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` without computing gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode. Batches are
            moved to the device of its parameters, so the function does not
            rely on a global ``device``.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        ``(test_loss, test_acc, all_preds, all_labels)``: the sample-weighted
        mean loss as a float, the accuracy as a 0-dim float64 tensor, and the
        per-sample predicted and true class indices as flat lists.

    Raises:
        ValueError: If the loader yields no samples.
    """
    # Follow the model's own device instead of a module-level variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    running_loss = 0.0
    running_corrects = 0
    samples_seen = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

            # Weight by the batch size so a smaller final batch is averaged
            # correctly.
            batch_count = inputs.size(0)
            running_loss += loss.item() * batch_count
            running_corrects += (preds == labels).sum().item()
            samples_seen += batch_count

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if samples_seen == 0:
        raise ValueError("The test dataloader yielded no samples.")

    test_loss = running_loss / samples_seen
    # Kept as a 0-dim float64 tensor so existing callers that format it or
    # call .item() on it keep working.
    test_acc = torch.tensor(running_corrects / samples_seen, dtype=torch.float64)
    return test_loss, test_acc, all_preds, all_labels

# Score the model returned by training on the test loader
test_metrics = evaluate_model(best_model, test_loader, criterion)
test_loss, test_acc, all_preds, all_labels = test_metrics
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

### Confusion Matrix on Test Set

In [None]:
# Pin the label set to every class index so the matrix always has one row and
# column per entry of class_names, even if a class is absent from the test split.
cm = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

# Draw on an explicitly created axes. Without ax=..., disp.plot() opens a new
# default-sized figure of its own, leaving a separately created 10x10 figure
# empty and its size unused.
fig, ax = plt.subplots(figsize=(10, 10))
disp.plot(ax=ax, cmap=plt.cm.Blues, values_format='d', xticks_rotation='vertical')
ax.set_title("Confusion Matrix")
fig.tight_layout()
fig.savefig(os.path.join(save_dir, "confusion_matrix.png"), dpi=300)
plt.show()


In [None]:
# Pass the full list of class indices explicitly: with target_names alone,
# classification_report raises a ValueError when a class is missing from both
# the true and the predicted labels (possible with a small test split).
report = classification_report(
    all_labels,
    all_preds,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    digits=4,
)
report_path = os.path.join(save_dir, "classification_report.txt")
# Explicit encoding so class names containing non-ASCII characters are written
# the same way on every platform.
with open(report_path, "w", encoding="utf-8") as f:
    f.write(report)
print("Classification Report:\n")
print(report)


### Plot Training Curves

In [None]:
# Shared x-axis: 1-based index of every epoch that was recorded
epochs_range = range(1, 1 + len(train_loss_history))

# One entry per figure:
# (train series, train label, val series, val label, title, y-axis label, file name)
curve_specs = [
    (train_loss_history, "Train Loss", val_loss_history, "Val Loss",
     "Loss Curve", "Loss", "loss_curve.png"),
    (train_acc_history, "Train Acc", val_acc_history, "Val Acc",
     "Accuracy Curve", "Accuracy", "accuracy_curve.png"),
]

for train_series, train_label, val_series, val_label, title, y_label, file_name in curve_specs:
    # Each curve pair gets its own figure, saved to save_dir and then shown
    curve_fig, curve_ax = plt.subplots(figsize=(6, 5))
    curve_ax.plot(epochs_range, train_series, label=train_label)
    curve_ax.plot(epochs_range, val_series, label=val_label)
    curve_ax.set_title(title)
    curve_ax.set_xlabel("Epoch")
    curve_ax.set_ylabel(y_label)
    curve_ax.legend()
    curve_fig.tight_layout()
    curve_fig.savefig(os.path.join(save_dir, file_name), dpi=300)
    plt.show()

### Save the Best Model

In [None]:
model_save_path = "./models/best_model.pth"
# Create the target directory first: it is not created anywhere else in the
# notebook, and torch.save fails when the parent directory is missing.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(best_model.state_dict(), model_save_path)
print(f"Best model saved to {model_save_path}")