### Import Packages

In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset, TensorDataset, ConcatDataset
from torchsummary import summary
import torchvision
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
from collections import Counter

# from google.colab import drive
# drive.mount('/content/gdrive')

In [2]:
# Seed every random number generator the notebook relies on.
def set_seed(seed_value=42):
    """Seed Python, NumPy and PyTorch RNGs with the same value.

    Args:
        seed_value: Integer seed applied to every generator (default 42).

    Side effects:
        * Writes ``PYTHONHASHSEED`` into ``os.environ``.
          NOTE(review): the running interpreter reads its hash seed at
          start-up, so this presumably only affects child processes - confirm.
        * When CUDA is available, also seeds the CUDA generators and switches
          cuDNN to deterministic mode with benchmarking disabled.
    """
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    # CPU-side generators: Python's `random`, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    # Nothing GPU-specific to configure on a CPU-only machine.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed_value)      # current device
    torch.cuda.manual_seed_all(seed_value)  # every visible device (multi-GPU)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(24)  # Seed used for this notebook's runs

### Load Data

In [3]:
# TODO: Define transformations - crop or resize
transform = transforms.Compose([
    # transforms.Resize((224, 224)), # Ensure correct size
    # transforms.CenterCrop((224, 224)),  # Crop the image at the center to 224x224
    transforms.ToTensor(),
    # ImageNet channel statistics, matching the pre-trained backbone.
    # Open question kept from the original author: is this the ideal normalization for this data?
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Root folder holding the dataset directories. Override with the DATA_ROOT
# environment variable to run on another machine; the default is the original
# author's location, so behaviour is unchanged when the variable is unset.
DATA_ROOT = os.environ.get("DATA_ROOT", "/Users/dawsonhaddox/Documents/COSC 78/Final Project")

# Dataset variant: " Entropy" selects the "... Data Entropy" folders,
# "" selects the plain "Train Data" / "Validation Data" / "Test Data" folders.
DATASET_SUFFIX = " Entropy"

# Samples per batch, shared by all three loaders.
BATCH_SIZE = 256

# Define data paths
train_datapath = os.path.join(DATA_ROOT, f"Train Data{DATASET_SUFFIX}")
validation_datapath = os.path.join(DATA_ROOT, f"Validation Data{DATASET_SUFFIX}")
test_datapath = os.path.join(DATA_ROOT, f"Test Data{DATASET_SUFFIX}")

# Setup datasets using ImageFolder (one sub-directory per class)
train_dataset = datasets.ImageFolder(train_datapath, transform=transform)
val_dataset = datasets.ImageFolder(validation_datapath, transform=transform)
test_dataset = datasets.ImageFolder(test_datapath, transform=transform)

# Create dataloaders; only the training data is shuffled.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Load and Perform Initial Modifications on ResNet Model

In [4]:
# Load pre-trained ResNet-18 model.
# NOTE(review): `pretrained=True` is deprecated in recent torchvision releases
# in favour of the `weights=` argument; kept here so older installs still work.
resnet = models.resnet18(pretrained=True)

# Freeze every backbone parameter so only the new head is trained.
for param in resnet.parameters():
    param.requires_grad = False

# Replace the final fully connected layer to match the number of classes
# (assuming three classes: Non Demented, Very Mild Dementia, Mild Dementia).
resnet.fc = nn.Linear(resnet.fc.in_features, 3)

# `resnet.fc.requires_grad = True` would only attach a plain attribute to the
# module and touch no parameter. `requires_grad_()` sets the flag on every
# parameter of the layer, making the "head is trainable" intent explicit.
resnet.fc.requires_grad_(True)

# Print model summary
# summary(resnet, input_size=(3, 224, 224))



### Fine Tune ResNet18 Model

In [5]:
# Multi-class classification loss over the model's raw logits.
loss_func = nn.CrossEntropyLoss()

# Hand the optimizer only the parameters that are actually trainable; the
# frozen backbone parameters never receive gradients, so tracking them in
# Adam is pure overhead and hides which weights are being updated.
trainable_params = [param for param in resnet.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.0001)

# Multiply the learning rate by 0.1 every 10 scheduler steps.
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [6]:
def _run_epoch(model, loader, loss_func, device, desc, training, optimizer=None):
    """Run one full pass over ``loader`` and return its aggregate metrics.

    Args:
        model: Network to run; switched to train or eval mode per ``training``.
        loader: Iterable yielding ``(images, labels)`` batches.
        loss_func: Callable ``loss_func(outputs, labels)`` returning a scalar tensor.
        device: Device the batches are moved to.
        desc: Label shown on the tqdm progress bar.
        training: When True, gradients are enabled and ``optimizer`` is stepped
            after every batch; when False the pass is evaluation-only.
        optimizer: Optimizer used when ``training`` is True.

    Returns:
        Tuple ``(mean_loss, accuracy_percent, weighted_f1)``.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0
    num_correct = 0
    num_seen = 0
    all_preds, all_targets = [], []

    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = loss_func(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # loss.item() is the batch mean; weight it by batch size so the
            # epoch mean is exact even when the last batch is smaller.
            running_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            num_correct += (predicted == labels).sum().item()
            num_seen += labels.size(0)

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    if num_seen == 0:
        raise ValueError(f"'{desc}': data loader produced no samples.")

    # Normalise by the samples actually processed so loss and accuracy share
    # the same denominator.
    mean_loss = running_loss / num_seen
    accuracy = 100 * num_correct / num_seen
    weighted_f1 = f1_score(all_targets, all_preds, average='weighted')
    return mean_loss, accuracy, weighted_f1


def train_and_validate(model, train_loader, val_loader, optimizer, scheduler, loss_func, epochs=25, patience=10, save_path='best_model.pth'):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    After every epoch the validation loss is compared with the best seen so
    far. On improvement the model's ``state_dict`` is written to ``save_path``;
    otherwise a patience counter is incremented and training stops once it
    reaches ``patience``. Whenever a checkpoint was written during this call,
    the best weights are loaded back into ``model`` before returning - both on
    early stop and when all epochs complete.

    Args:
        model: Network to train; moved to CUDA when available, else CPU.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        val_loader: Loader yielding ``(images, labels)`` validation batches.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch after the
            training pass, or None to keep the learning rate fixed.
        loss_func: Callable ``loss_func(outputs, labels)`` returning a scalar tensor.
        epochs: Maximum number of epochs.
        patience: Consecutive non-improving epochs tolerated before stopping.
        save_path: File path for the best-model checkpoint.

    Returns:
        Dict with keys ``train_loss``, ``val_loss``, ``train_accuracy``,
        ``val_accuracy``, ``train_f1`` and ``val_f1``; each maps to a list
        holding one value per completed epoch.

    Raises:
        ValueError: If either loader yields no samples.
    """
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # To store the training and validation metrics for plotting or analysis
    history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': [], 'train_f1': [], 'val_f1': []}

    best_val_loss = float('inf')
    patience_counter = 0  # Counter for the early stopping
    # Tracks whether *this call* wrote save_path, so a stale file left by an
    # earlier run is never loaded and a missing file is never opened.
    checkpoint_saved = False

    for epoch in range(epochs):
        epoch_train_loss, train_accuracy, train_f1 = _run_epoch(
            model, train_loader, loss_func, device,
            desc=f"Epoch {epoch + 1}/{epochs} - Training",
            training=True, optimizer=optimizer,
        )
        history['train_loss'].append(epoch_train_loss)
        history['train_accuracy'].append(train_accuracy)
        history['train_f1'].append(train_f1)

        # Scheduler step (commonly after training step, can be adjusted as per scheduler type)
        if scheduler is not None:
            scheduler.step()

        epoch_val_loss, val_accuracy, val_f1 = _run_epoch(
            model, val_loader, loss_func, device,
            desc=f"Epoch {epoch + 1}/{epochs} - Validation",
            training=False,
        )
        history['val_loss'].append(epoch_val_loss)
        history['val_accuracy'].append(val_accuracy)
        history['val_f1'].append(val_f1)

        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Train F1: {train_f1:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%, Validation F1: {val_f1:.4f}')

        # Check for improvement in validation loss
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)  # Save the best model
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Stopping early after {epoch + 1} epochs due to no improvement in validation loss.")
                break

    # Leave the caller with the best weights regardless of how the loop ended.
    if checkpoint_saved:
        model.load_state_dict(torch.load(save_path, map_location=device))

    return history


In [7]:
# Reference log line from an earlier run of this cell:
# Epoch [1/200], Train Loss: 0.6838, Train Accuracy: 77.68%, Train F1: 0.6893, Validation Loss: 0.6437, Validation Accuracy: 77.11%, Validation F1: 0.6714
#
# Train for up to 200 epochs. No scheduler is passed (scheduler=None), so the
# learning rate stays at the optimizer's configured value throughout.
history = train_and_validate(
    model=resnet,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=None,
    loss_func=loss_func,
    epochs=200,
)

Epoch 1/200 - Training: 100%|███████████████████| 98/98 [25:08<00:00, 15.40s/it]
Epoch 1/200 - Validation: 100%|█████████████████| 19/19 [04:40<00:00, 14.78s/it]


Epoch [1/200], Train Loss: 0.6838, Train Accuracy: 77.68%, Train F1: 0.6893, Validation Loss: 0.6437, Validation Accuracy: 77.11%, Validation F1: 0.6714


Epoch 2/200 - Training:  50%|█████████▌         | 49/98 [12:49<12:49, 15.70s/it]


KeyboardInterrupt: 

In [None]:
# Plot training vs. validation curves side by side: accuracy, F1 score, loss.
plt.figure(figsize=(18, 6))

# One entry per panel: (history key suffix, legend/title label, y-axis label)
panels = [
    ('accuracy', 'Accuracy', 'Accuracy (%)'),
    ('f1', 'F1 Score', 'F1 Score'),
    ('loss', 'Loss', 'Loss'),
]

for position, (metric, label, y_label) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.plot(history[f'train_{metric}'], label=f'Train {label}')
    plt.plot(history[f'val_{metric}'], label=f'Validation {label}')
    plt.title(f'Model {label}')
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()