In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision import transforms, models
from torchvision.models import EfficientNet_V2_S_Weights
from PIL import Image
import pandas as pd
import numpy as np
import os
import csv
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import platform
import multiprocessing

# Custom dataset class for handling stability data
class StabilityDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None, augment=False, image_size=224, zoom_proportion=0.15):
        self.stability_data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.augment = augment
        self.image_size = image_size
        self.zoom_proportion = zoom_proportion
        self.augmented_indices = self._create_augmented_indices() if augment else None

    # Create indices for augmented data
    def _create_augmented_indices(self):
        base_indices = list(range(len(self.stability_data)))
        flipped_indices = [idx + len(self.stability_data) for idx in base_indices]
        zoomed_indices = [idx + 2 * len(self.stability_data) for idx in base_indices]
        zoomed_flipped_indices = [idx + 3 * len(self.stability_data) for idx in base_indices]
        return base_indices + flipped_indices + zoomed_indices + zoomed_flipped_indices

    def __len__(self):
        return len(self.stability_data) * 4 if self.augment else len(self.stability_data)

    def __getitem__(self, idx):
        if self.augment:
            original_idx = idx % len(self.stability_data)
            augmentation = idx // len(self.stability_data)
        else:
            original_idx = idx
            augmentation = 0

        # Load image and stability class
        img_name = str(self.stability_data.iloc[original_idx, 0])
        img_path = os.path.join(self.img_dir, img_name)
        if not os.path.exists(img_path):
            img_path = os.path.join(self.img_dir, f"{img_name}.jpg")
        
        image = Image.open(img_path).convert('RGB')
        
        stability_height = self.stability_data.iloc[original_idx, -1]
        stability_class = int(stability_height) - 1

        # Apply augmentations if enabled
        if self.augment:
            if augmentation in [1, 3]:  # Flip
                image = image.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
            if augmentation in [2, 3]:  # Zoom
                width, height = image.size
                crop_size = int(min(width, height) * (1 - self.zoom_proportion))
                left = (width - crop_size) // 2
                top = (height - crop_size) // 2
                right = left + crop_size
                bottom = top + crop_size
                image = image.crop((left, top, right, bottom))

        # Resize the image to ensure consistent size
        image = image.resize((self.image_size, self.image_size), Image.BILINEAR)
        
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(stability_class, dtype=torch.long)

# Neural network model for stability prediction
class StabilityPredictor(nn.Module):
    def __init__(self, num_classes=6, dropout_rate=0.3):
        super(StabilityPredictor, self).__init__()
        weights = EfficientNet_V2_S_Weights.DEFAULT
        self.efficientnet = models.efficientnet_v2_s(weights=weights)
        num_ftrs = self.efficientnet.classifier[1].in_features
        self.efficientnet.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate, inplace=True),
            nn.Linear(num_ftrs, num_classes)
        )

    def forward(self, x):
        return self.efficientnet(x)

# Function to train the model
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, patience, device):
    model.to(device)
    
    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model = None
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        
        # Training phase
        model.train()
        train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer, device, is_training=True)
        
        # Validation phase
        model.eval()
        val_loss, val_acc = run_epoch(model, val_loader, criterion, optimizer, device, is_training=False)
        
        # Learning rate scheduler step
        scheduler.step(val_loss)

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
        print('-' * 60)

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            best_model = model.state_dict()
        else:
            epochs_no_improve += 1

        if epochs_no_improve == patience:
            print(f'Early stopping triggered after {epoch + 1} epochs')
            model.load_state_dict(best_model)
            break

    return model

# Function to run a single epoch (training or validation)
def run_epoch(model, data_loader, criterion, optimizer, device, is_training=True):
    running_loss = 0.0
    correct = 0
    total = 0

    # Create progress bar
    progress_bar = tqdm(data_loader, desc="Training" if is_training else "Validating")

    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)
        
        if is_training:
            optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        if is_training:
            loss.backward()
            optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100. * correct / total:.2f}%'
        })
    
    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

# Function to calculate dataset statistics
def calculate_stats(dataset):
    loader = DataLoader(dataset, batch_size=100, num_workers=0, shuffle=False)
    mean = 0.
    std = 0.
    for images, _ in loader:
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
    
    mean /= len(dataset)
    std /= len(dataset)
    return mean, std

# Function to make predictions on the test set
def predict(model, test_loader, device):
    model.eval()
    predictions = []
    image_ids = []
    with torch.no_grad():
        for inputs, ids in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            predictions.extend(preds.cpu().numpy() + 1)  # Add 1 to convert back to 1-6 range
            image_ids.extend(ids.numpy())  # Convert tensor to numpy array
    return predictions, image_ids

# Main function to orchestrate the training and prediction process
def main(config):
    # Set up device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create full dataset without normalization and augmentation
    full_dataset = StabilityDataset(csv_file=config['train_csv'], 
                                    img_dir=config['train_img_dir'], 
                                    transform=transforms.ToTensor(),
                                    augment=False,
                                    image_size=config['image_size'])

    # Split dataset into train and validation
    dataset_size = len(full_dataset)
    indices = list(range(dataset_size))
    np.random.shuffle(indices)
    split = int(np.floor(config['val_ratio'] * dataset_size))
    train_indices, val_indices = indices[split:], indices[:split]

    # Calculate statistics for training set only
    train_subset = Subset(full_dataset, train_indices)

    print("Calculating training dataset statistics...")
    train_mean, train_std = calculate_stats(train_subset)
    print(f"Training dataset mean: {train_mean}")
    print(f"Training dataset std: {train_std}")

    # Create transforms for both training and validation using training set statistics
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=train_mean, std=train_std),
    ])

    # Create augmented training dataset and non-augmented validation dataset
    train_dataset = StabilityDataset(csv_file=config['train_csv'], 
                                     img_dir=config['train_img_dir'], 
                                     transform=data_transform,
                                     augment=config['use_augmentation'],
                                     image_size=config['image_size'],
                                     zoom_proportion=config['zoom_proportion'])
    train_dataset = Subset(train_dataset, [i for i in range(len(train_dataset)) if i % len(full_dataset) in train_indices])

    val_dataset = StabilityDataset(csv_file=config['train_csv'], 
                                   img_dir=config['train_img_dir'], 
                                   transform=data_transform,
                                   augment=False,
                                   image_size=config['image_size'])
    val_dataset = Subset(val_dataset, val_indices)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'])
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'])

    # Initialize model, criterion, optimizer, and scheduler
    model = StabilityPredictor(num_classes=config['num_classes'], dropout_rate=config['dropout_rate'])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=config['lr_factor'], patience=config['lr_patience'])

    # Train model
    print('Training...')
    model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, 
                        num_epochs=config['num_epochs'], patience=config['early_stopping_patience'], device=device)

    torch.save(model.state_dict(), config['model_save_path'])
    print("Training complete. Model saved.")

    # Prediction on test set
    test_dataset = StabilityDataset(csv_file=config['test_csv'],
                                    img_dir=config['test_img_dir'],
                                    transform=data_transform)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'])

    predictions, image_ids = predict(model, test_loader, device)

    # Save predictions to CSV
    with open(config['predictions_save_path'], 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['id', 'labels'])
        for img_id, pred in zip(image_ids, predictions):
            writer.writerow([int(img_id) + 1, int(pred)])  # Ensure both are integers
    print(f"Predictions saved to {config['predictions_save_path']}")

# Function to get optimal number of workers for data loading
def get_optimal_num_workers():
    # Windows can't do multiprocessing
    if platform.system() == 'Windows':
        return 0
    else:
        return multiprocessing.cpu_count()

# Hyperparameters and configuration
config = {
    # Path to the CSV file containing training data
    'train_csv': './COMP90086_2024_Project_train/train.csv',
    
    # Directory containing training images
    'train_img_dir': './COMP90086_2024_Project_train/train',
    
    # Path to the CSV file containing test data
    'test_csv': './COMP90086_2024_Project_test/test.csv',
    
    # Directory containing test images
    'test_img_dir': './COMP90086_2024_Project_test/test',
    
    # Size to which all images will be resized (224x224 pixels)
    'image_size': 224,
    
    # Proportion of training data to use for validation
    'val_ratio': 0.05,
    
    # Whether to use data augmentation during training
    'use_augmentation': True,
    
    # Proportion of zoom to apply during augmentation
    'zoom_proportion': 0.15,
    
    # Number of samples per batch for training and validation
    'batch_size': 16,
    
    # Number of subprocesses to use for data loading (automatically determined)
    'num_workers': get_optimal_num_workers(),
    
    # Number of classes in the classification task (6 stability levels)
    'num_classes': 6,
    
    # Dropout rate for regularization in the model
    'dropout_rate': 0.3,
    
    # Initial learning rate for the optimizer
    'learning_rate': 0.001,
    
    # Factor by which to reduce learning rate on plateau
    'lr_factor': 0.1,
    
    # Number of epochs with no improvement after which learning rate will be reduced
    'lr_patience': 2,
    
    # Maximum number of training epochs
    'num_epochs': 30,
    
    # Number of epochs with no improvement after which training will be stopped
    'early_stopping_patience': 5,
    
    # File path where the trained model will be saved
    'model_save_path': 'stability_predictor_efficientnetv2.pth',
    
    # File path where the predictions on the test set will be saved
    'predictions_save_path': 'predictions.csv'
}
    
# Run the main function
main(config)

Calculating training dataset statistics...
Training dataset mean: tensor([0.4675, 0.4409, 0.4062])
Training dataset std: tensor([0.2719, 0.2284, 0.1912])
Training...
Epoch 1/30




Training:   0%|          | 0/1824 [00:00<?, ?it/s]

KeyboardInterrupt: 