In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision import transforms, models
from torchvision.models import EfficientNet_V2_S_Weights, convnext_base
from PIL import Image
import pandas as pd
import numpy as np
import os
import csv
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import platform
import multiprocessing
import cv2
import json

class StabilityDataset(Dataset):
    """Image dataset driven by a CSV whose first column is the image id.

    Each CSV row maps to one image file (or four, when ``augment`` is set:
    original, flipped, zoomed, zoomed+flipped). Categorical columns listed in
    ``additional_columns`` and the ``target_column`` are encoded as integer
    indices.

    Args:
        csv_file: Path of the CSV to read with ``pd.read_csv``.
        img_dir: Directory the generated file names are joined onto.
        transform: Optional callable applied to the RGB ``numpy`` image.
        augment: If True, index the three pre-generated augmented variants
            alongside each original image.
        use_quantized: If True, use the ``quantized/..._quantized.jpg`` files.
        additional_columns: CSV columns returned as extra categorical inputs.
        target_column: CSV column used as the class label (ignored when
            ``is_test`` is True).
        balance_dataset: If True, undersample every class to the size of the
            rarest one.
        is_test: If True, no target is read or returned.
    """

    def __init__(self, csv_file, img_dir, transform=None, augment=False, use_quantized=False,
                 additional_columns=None, target_column=None, balance_dataset=False, is_test=False):
        self.stability_data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.augment = augment
        self.use_quantized = use_quantized
        self.additional_columns = additional_columns or []
        self.target_column = target_column if not is_test else None
        self.is_test = is_test
        self.image_files = self._get_image_files()
        self.feature_categories = self._get_feature_categories()

        if balance_dataset and self.target_column and not is_test:
            self._balance_dataset()

    @staticmethod
    def _build_category_info(series):
        """Return ``num_categories`` and a value->index map for one column.

        Values are sorted so the mapping does not depend on the row order of
        the CSV; two datasets holding the same set of values therefore agree
        on every index. Columns whose values cannot be ordered fall back to
        order of first appearance.
        """
        values = list(series.unique())
        try:
            values = sorted(values)
        except TypeError:
            # Mixed / unorderable values: keep first-appearance order.
            pass
        return {
            'num_categories': len(values),
            'value_to_index': {val: idx for idx, val in enumerate(values)}
        }

    def _get_feature_categories(self):
        """Build the categorical encodings for the extra columns and the target."""
        # NOTE(review): each dataset derives its mapping from its own CSV, so a
        # split that lacks one of the values still ends up with different
        # indices than the training set - confirm all splits cover every value.
        feature_categories = {}
        for col in self.additional_columns:
            if col in self.stability_data.columns:
                feature_categories[col] = self._build_category_info(self.stability_data[col])
        if self.target_column and not self.is_test:
            feature_categories[self.target_column] = self._build_category_info(
                self.stability_data[self.target_column])
        return feature_categories

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, image_id, target_index, *additional_indices)`` when a
            target column is configured, otherwise
            ``(image, image_id, *additional_indices)``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # With augmentation every CSV row owns four consecutive image files.
        original_idx = idx // 4 if self.augment else idx
        image_id = self.stability_data.iloc[original_idx, 0]

        if self.transform:
            image = self.transform(image)
        else:
            # HWC uint8 -> CHW float in [0, 1].
            image = torch.from_numpy(image.transpose((2, 0, 1))).float() / 255.0

        additional_data = []
        for col in self.additional_columns:
            if col in self.feature_categories:
                value = self.stability_data[col].iloc[original_idx]
                index = self.feature_categories[col]['value_to_index'][value]
                additional_data.append(torch.tensor(index, dtype=torch.long))

        if self.target_column and not self.is_test:
            target_value = self.stability_data[self.target_column].iloc[original_idx]
            target_class = self.feature_categories[self.target_column]['value_to_index'][target_value]
            return (image, image_id, torch.tensor(target_class, dtype=torch.long), *additional_data)
        else:
            return (image, image_id, *additional_data)

    def _balance_dataset(self):
        """Undersample every class down to the size of the rarest class."""
        if self.target_column is None:
            return

        # Count occurrences of each class
        class_counts = self.stability_data[self.target_column].value_counts()
        min_class_count = class_counts.min()

        # Undersample each class
        balanced_data = []
        for class_label in class_counts.index:
            class_data = self.stability_data[self.stability_data[self.target_column] == class_label]
            balanced_data.append(class_data.sample(min_class_count, replace=False))

        # Combine the balanced classes
        self.stability_data = pd.concat(balanced_data).reset_index(drop=True)

        # The file list must mirror the (now shorter) table row-for-row.
        self.image_files = self._get_image_files()

    def _get_image_files(self):
        """List the image file names, four per row when augmenting, else one."""
        if self.use_quantized:
            suffixes = ["_quantized", "_flipped_quantized", "_zoomed_quantized", "_zoomed_flipped_quantized"]
            prefix = "quantized/"
        else:
            suffixes = ["_original", "_flipped", "_zoomed", "_zoomed_flipped"]
            prefix = ""
        if not self.augment:
            suffixes = suffixes[:1]

        image_files = []
        # Read the id column directly: DataFrame.iterrows() upcasts an integer
        # id to float when the frame also has float columns, which would turn
        # "7" into "7.0" in the file name.
        for raw_name in self.stability_data.iloc[:, 0]:
            img_name = str(raw_name)
            image_files.extend(f"{prefix}{img_name}{suffix}.jpg" for suffix in suffixes)
        return image_files

    def __len__(self):
        """Number of image files (rows x 4 when augmenting)."""
        return len(self.image_files)

    def get_feature_dimensions(self):
        """Return ``{column: num_categories}`` for the additional columns only."""
        return {col: info['num_categories'] for col, info in self.feature_categories.items() if col != self.target_column}

    def get_target_dimension(self):
        """Return the number of distinct target classes."""
        return self.feature_categories[self.target_column]['num_categories']

class StabilityPredictor(nn.Module):
    """EfficientNetV2-S classifier with optional categorical side inputs.

    The backbone's pooled features are used directly when no additional
    features are configured. Otherwise each additional feature is embedded
    (16 dims each), concatenated with the pooled features and projected back
    to the backbone width before the dropout + linear head.

    Args:
        num_classes: Size of the output layer.
        dropout_rate: Dropout probability in the classification head.
        additional_features: Mapping ``{feature_name: num_categories}``.
    """

    def __init__(self, num_classes, dropout_rate=0.3, additional_features=None):
        super().__init__()

        # Backbone initialised from torchvision's default pre-trained weights.
        self.efficientnet = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT)

        # Width of the vector the stock classifier head would have received.
        backbone_width = self.efficientnet.classifier[1].in_features

        # One embedding table per categorical side input.
        self.additional_features = additional_features or {}
        self.embedding_layers = nn.ModuleDict()
        self.embedding_dim = 16  # You can adjust this value
        for name, cardinality in self.additional_features.items():
            self.embedding_layers[name] = nn.Embedding(cardinality, self.embedding_dim)
        embedded_width = self.embedding_dim * len(self.additional_features)

        # Projects [image features | embeddings] back to the backbone width.
        self.combined_layer = nn.Linear(backbone_width + embedded_width, backbone_width)

        # Custom head used instead of the backbone's own classifier.
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate, inplace=True),
            nn.Linear(backbone_width, num_classes),
        )

    def forward(self, x, *additional_inputs):
        """Return class logits for images ``x`` and positional side inputs.

        ``additional_inputs`` must be given in the iteration order of
        ``additional_features``.
        """
        pooled = self.efficientnet.avgpool(self.efficientnet.features(x))
        x = torch.flatten(pooled, 1)

        # Look up one embedding per configured feature, matched by position.
        embedded = [
            self.embedding_layers[name](additional_inputs[position])
            for position, name in enumerate(self.additional_features)
        ]

        if embedded:
            x = self.combined_layer(torch.cat([x, *embedded], dim=1))

        return self.classifier(x)


class EfficientAttentionNet(nn.Module):
    """EfficientNetV2-S with spatial attention on the final feature map.

    The backbone's convolutional output is re-weighted by a
    ``SpatialAttentionModule`` before global pooling. Optional categorical
    side inputs are embedded (16 dims each), concatenated with the pooled
    features and projected back to the backbone width ahead of the
    dropout + linear head.

    Args:
        num_classes: Size of the output layer.
        dropout_rate: Dropout probability in the classification head.
        additional_features: Mapping ``{feature_name: num_categories}``.
    """

    def __init__(self, num_classes, dropout_rate=0.3, additional_features=None):
        super().__init__()

        # Backbone initialised from torchvision's default pre-trained weights.
        self.efficientnet = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT)

        # Attention applied to the backbone's last feature map.
        self.spatial_attention = SpatialAttentionModule(kernel_size=7)

        # Width of the vector the stock classifier head would have received.
        backbone_width = self.efficientnet.classifier[1].in_features

        # One embedding table per categorical side input.
        self.additional_features = additional_features or {}
        self.embedding_layers = nn.ModuleDict()
        self.embedding_dim = 16
        for name, cardinality in self.additional_features.items():
            self.embedding_layers[name] = nn.Embedding(cardinality, self.embedding_dim)
        embedded_width = self.embedding_dim * len(self.additional_features)

        # Projects [image features | embeddings] back to the backbone width.
        self.combined_layer = nn.Linear(backbone_width + embedded_width, backbone_width)

        # Custom head used instead of the backbone's own classifier.
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate, inplace=True),
            nn.Linear(backbone_width, num_classes),
        )

    def forward(self, x, *additional_inputs):
        """Return class logits for images ``x`` and positional side inputs.

        ``additional_inputs`` must be given in the iteration order of
        ``additional_features``.
        """
        # Backbone features -> spatial attention -> global pooling -> flatten.
        attended = self.spatial_attention(self.efficientnet.features(x))
        x = torch.flatten(self.efficientnet.avgpool(attended), 1)

        # Look up one embedding per configured feature, matched by position.
        embedded = [
            self.embedding_layers[name](additional_inputs[position])
            for position, name in enumerate(self.additional_features)
        ]

        if embedded:
            x = self.combined_layer(torch.cat([x, *embedded], dim=1))

        return self.classifier(x)

class SpatialAttentionModule(nn.Module):
    """Re-weights every spatial location of a feature map.

    The mean and max over the channel dimension are stacked into a
    two-channel map, passed through a single bias-free convolution and a
    sigmoid, and the resulting one-channel mask is multiplied onto the input.

    Args:
        kernel_size: Size of the convolution kernel; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its spatial attention mask."""
        # Collapse the channel axis (dim=1) two ways, keeping it as size 1.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat([channel_mean, channel_max], dim=1)
        mask = self.sigmoid(self.conv(descriptor))
        return x * mask

class EfficientChannelAttentionNet(nn.Module):
    """EfficientNetV2-S with channel attention after its first two stages.

    ``ChannelAttentionModule`` instances (24 and 48 channels) are applied to
    the outputs of ``features[1]`` and ``features[2]``; the remaining backbone
    stages run unchanged. Optional categorical side inputs are embedded
    (16 dims each), concatenated with the pooled features and projected back
    to the backbone width ahead of the dropout + linear head.

    Args:
        num_classes: Size of the output layer.
        dropout_rate: Dropout probability in the classification head.
        additional_features: Mapping ``{feature_name: num_categories}``.
    """

    def __init__(self, num_classes=6, dropout_rate=0.0, additional_features=None):
        super().__init__()

        # Backbone initialised from torchvision's default pre-trained weights.
        self.efficientnet = models.efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT)

        # Attention blocks sized for the outputs of features[1] and features[2].
        self.channel_attention1 = ChannelAttentionModule(in_planes=24)
        self.channel_attention2 = ChannelAttentionModule(in_planes=48)

        # Width of the vector the stock classifier head would have received.
        backbone_width = self.efficientnet.classifier[1].in_features

        # One embedding table per categorical side input.
        self.additional_features = additional_features or {}
        self.embedding_layers = nn.ModuleDict()
        self.embedding_dim = 16
        for name, cardinality in self.additional_features.items():
            self.embedding_layers[name] = nn.Embedding(cardinality, self.embedding_dim)
        embedded_width = self.embedding_dim * len(self.additional_features)

        # Projects [image features | embeddings] back to the backbone width.
        self.combined_layer = nn.Linear(backbone_width + embedded_width, backbone_width)

        # Custom head used instead of the backbone's own classifier.
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate, inplace=True),
            nn.Linear(backbone_width, num_classes),
        )

    def forward(self, x, *additional_inputs):
        """Return class logits for images ``x`` and positional side inputs.

        ``additional_inputs`` must be given in the iteration order of
        ``additional_features``.
        """
        # Attention block to run right after the backbone stage at each index.
        attention_after = {1: self.channel_attention1, 2: self.channel_attention2}

        for stage_index, stage in enumerate(self.efficientnet.features):
            x = stage(x)
            if stage_index in attention_after:
                x = attention_after[stage_index](x)

        # Global average pooling, then flatten to (batch, features).
        x = torch.flatten(self.efficientnet.avgpool(x), 1)

        # Look up one embedding per configured feature, matched by position.
        embedded = [
            self.embedding_layers[name](additional_inputs[position])
            for position, name in enumerate(self.additional_features)
        ]

        if embedded:
            x = self.combined_layer(torch.cat([x, *embedded], dim=1))

        return self.classifier(x)

class ChannelAttentionModule(nn.Module):
    """Re-weights the channels of a feature map.

    Global average- and max-pooled descriptors are each passed through the
    same two-layer bottleneck (1x1 convolutions, no bias); their sum goes
    through a sigmoid and the resulting per-channel weights are multiplied
    onto the input.

    Args:
        in_planes: Number of input (and output) channels.
        ratio: Bottleneck reduction factor. The hidden width is
            ``in_planes // ratio``, clamped to at least 1.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Integer division yields 0 whenever in_planes < ratio, which would
        # request a zero-channel bottleneck; always keep at least one channel.
        hidden_planes = max(1, in_planes // ratio)

        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its per-channel attention weights."""
        # Both pooled descriptors share the same bottleneck weights.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        attention = self.sigmoid(avg_out + max_out)
        return x * attention
    
class ConvnextPredictor(nn.Module):
    """ConvNeXt-Base classifier with optional categorical side inputs.

    Pooled backbone features go straight to a LayerNorm + linear head when no
    additional features are configured. Otherwise each additional feature is
    embedded (16 dims each), concatenated with the pooled features and
    projected back to the backbone width first.

    Args:
        num_classes: Size of the output layer.
        freeze_layers: If True, disable gradients for every parameter of the
            ConvNeXt backbone; the head, the combining layer and the
            embeddings stay trainable.
        additional_features: Mapping ``{feature_name: num_categories}``.
    """

    def __init__(self, num_classes=6, freeze_layers=True, additional_features=None):
        super().__init__()

        # Backbone initialised from torchvision's default pre-trained weights.
        self.convnextnet = convnext_base(weights=models.convnext.ConvNeXt_Base_Weights.DEFAULT)

        # Input width of the backbone's own final linear layer.
        backbone_width = self.convnextnet.classifier[2].in_features

        # One embedding table per categorical side input.
        self.additional_features = additional_features or {}
        self.embedding_layers = nn.ModuleDict()
        self.embedding_dim = 16  # You can adjust this value
        for name, cardinality in self.additional_features.items():
            self.embedding_layers[name] = nn.Embedding(cardinality, self.embedding_dim)
        embedded_width = self.embedding_dim * len(self.additional_features)

        # Projects [image features | embeddings] back to the backbone width.
        self.combined_layer = nn.Linear(backbone_width + embedded_width, backbone_width)

        # Custom head used instead of the backbone's own classifier.
        self.classifier = nn.Sequential(
            nn.LayerNorm(backbone_width),  # ConvNeXt uses LayerNorm instead of BatchNorm
            nn.Flatten(start_dim=1),
            nn.Linear(backbone_width, num_classes),
        )

        if freeze_layers:
            print('Layers frozen!')
            # Stop gradient updates for the whole pre-trained backbone ...
            self.convnextnet.requires_grad_(False)
            # ... and make sure everything added on top of it stays trainable.
            for trainable in (self.classifier, self.combined_layer, self.embedding_layers):
                trainable.requires_grad_(True)

    def forward(self, x, *additional_inputs):
        """Return class logits for images ``x`` and positional side inputs.

        ``additional_inputs`` must be given in the iteration order of
        ``additional_features``.
        """
        pooled = self.convnextnet.avgpool(self.convnextnet.features(x))
        x = torch.flatten(pooled, 1)

        # Look up one embedding per configured feature, matched by position.
        embedded = [
            self.embedding_layers[name](additional_inputs[position])
            for position, name in enumerate(self.additional_features)
        ]

        if embedded:
            x = self.combined_layer(torch.cat([x, *embedded], dim=1))

        return self.classifier(x)

def colour_quantisation(image, k=20):
    """Reduce ``image`` to at most ``k`` colours using OpenCV k-means.

    Args:
        image: Array whose total size is a multiple of 3; it is viewed as a
            list of 3-component pixels.
        k: Number of clusters (palette size).

    Returns:
        A ``uint8`` array with the same shape as ``image`` in which every
        pixel is replaced by the centre of the cluster it was assigned to.
    """
    # k-means wants one float32 sample per row.
    samples = image.reshape(-1, 3).astype(np.float32)

    # Terminate after 10 iterations or once the epsilon of 1.0 is reached.
    stop_criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)

    # 10 attempts with random initial centres; the compactness score is unused.
    kmeans_result = cv2.kmeans(samples, k, None, stop_criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
    assignments, centres = kmeans_result[1], kmeans_result[2]

    # Swap each pixel for its 8-bit cluster centre and restore the layout.
    palette_u8 = centres.astype(np.uint8)
    return palette_u8[assignments.ravel()].reshape(image.shape)

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, patience, device, use_full_dataset, save_path=None, load_path=None):
    """Train ``model`` and return it.

    Two modes are supported:

    * Validation mode (``use_full_dataset`` False): after every epoch the
      model is evaluated on ``val_loader``, ``scheduler.step(val_loss)`` is
      called, and training stops early once the validation loss has not
      improved for ``patience`` consecutive epochs. The weights of the best
      validation epoch are restored before returning. If ``save_path`` is
      given, the number of epochs run and the learning rate used in each
      epoch are written there as JSON.
    * Full-dataset mode (``use_full_dataset`` True): no validation is run.
      If ``load_path`` is given, the per-epoch learning rates stored in that
      JSON file are replayed; otherwise the scheduler is stepped each epoch.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Batches for the training phase.
        val_loader: Batches for the validation phase (unused in full mode).
        criterion: Loss function.
        optimizer: Optimizer updating ``model``.
        scheduler: Learning-rate scheduler.
        num_epochs: Maximum number of epochs (raised in full-dataset mode if
            the loaded schedule covers more epochs).
        patience: Epochs without validation improvement before stopping.
        device: Device passed to ``model.to`` and ``run_epoch``.
        use_full_dataset: Selects the mode described above.
        save_path: JSON file to write training parameters to (validation mode).
        load_path: JSON file to read training parameters from (full mode).

    Returns:
        The trained ``model`` (same object that was passed in).
    """
    model.to(device)

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model = None
    lr_schedule = None
    lr_history = {}  # epoch index -> learning rate used during that epoch
    epochs_run = 0

    if use_full_dataset and load_path:
        with open(load_path, 'r') as f:
            loaded_params = json.load(f)
        print("Loaded training parameters")
        # JSON object keys are always strings; convert back to epoch indices.
        lr_schedule = {int(k): v for k, v in loaded_params['lr_schedule'].items()}
        if lr_schedule:
            num_epochs = max(num_epochs, max(lr_schedule.keys()) + 1)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        # Apply pre-determined learning rate if using full dataset and lr_schedule is available
        if use_full_dataset and lr_schedule and epoch in lr_schedule:
            new_lr = lr_schedule[epoch]
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr
            print(f'Learning rate set to {new_lr} for epoch {epoch+1}')

        # Record the rate actually used this epoch so it can be replayed later.
        lr_history[epoch] = optimizer.param_groups[0]['lr']

        # Training phase
        model.train()
        train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer, device, is_training=True)
        epochs_run = epoch + 1

        if not use_full_dataset:
            # Validation phase
            model.eval()
            val_loss, val_acc = run_epoch(model, val_loader, criterion, optimizer, device, is_training=False)

            # Learning rate scheduler step
            scheduler.step(val_loss)
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        else:
            if not lr_schedule:
                # ReduceLROnPlateau.step() requires a metric; with no
                # validation set the training loss is the only one available.
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(train_loss)
                else:
                    scheduler.step()
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')

        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
        print('-' * 60)

        # Early stopping check (only if not using full dataset)
        if not use_full_dataset:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                # state_dict() hands back references to the live tensors, so
                # they must be cloned or later epochs overwrite the snapshot.
                best_model = {name: tensor.detach().clone()
                              for name, tensor in model.state_dict().items()}
            else:
                epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f'Early stopping triggered after {epoch + 1} epochs')
                break

    # Return the best validation weights whether training stopped early or
    # simply ran out of epochs.
    if not use_full_dataset and best_model is not None:
        model.load_state_dict(best_model)

    if not use_full_dataset and save_path:
        save_params = {
            'epochs': epochs_run,
            'lr_schedule': lr_history
        }
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(save_path, 'w') as f:
            json.dump(save_params, f)
        print("Saved training parameters")

    return model

def run_epoch(model, data_loader, criterion, optimizer, device, is_training=True):
    """Run one pass over ``data_loader`` and return ``(mean_loss, accuracy_%)``.

    Each batch is expected to be indexable as
    ``(inputs, image_ids, labels, *additional_inputs)``. The caller is
    responsible for putting ``model`` into train/eval mode beforehand.

    Args:
        model: Network called as ``model(inputs, *additional_inputs)``.
        data_loader: Iterable of batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer; only used when ``is_training`` is True.
        device: Device every tensor of the batch is moved to.
        is_training: If True, back-propagate and update the weights;
            otherwise run with gradient tracking disabled.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the per-sample mean loss and the
        accuracy in percent over the samples seen. ``(0.0, 0.0)`` is returned
        if the loader yields no samples.
    """
    running_loss = 0.0
    correct = 0
    total = 0

    # Create progress bar
    progress_bar = tqdm(data_loader, desc="Training" if is_training else "Validating")

    # Gradient tracking is only needed when the weights are being updated;
    # disabling it for validation avoids building the autograd graph.
    with torch.set_grad_enabled(is_training):
        for batch in progress_bar:
            inputs = batch[0].to(device)
            labels = batch[2].to(device)
            # batch[1] holds the image ids; side inputs start at index 3.
            additional_inputs = [feature.to(device) for feature in batch[3:]]

            if is_training:
                optimizer.zero_grad()

            outputs = model(inputs, *additional_inputs)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so the epoch loss is a
            # per-sample mean even when the last batch is smaller.
            running_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100. * correct / total:.2f}%'
            })

    if total == 0:
        return 0.0, 0.0

    # Normalise by the samples actually processed, which stays correct for
    # loaders that use a sampler or drop the last batch.
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

def predict(model, test_loader, device):
    """Run inference over ``test_loader``.

    Each batch is expected to be indexable as
    ``(inputs, image_ids, *additional_inputs)`` - no label entry.

    Returns:
        Tuple ``(predictions, image_ids)`` of equal-length lists. Predictions
        are the arg-max class index plus one.
    """
    model.eval()
    image_ids, predictions = [], []

    with torch.no_grad():
        for batch in test_loader:
            images = batch[0].to(device)
            batch_ids = batch[1]
            # Without a label entry, side inputs start right after the ids.
            extras = [tensor.to(device) for tensor in batch[2:]]

            logits = model(images, *extras)
            class_indices = logits.max(dim=1).indices

            # Shift 0-based class indices back to the 1-6 range.
            predictions += list(class_indices.cpu().numpy() + 1)
            image_ids += list(batch_ids.numpy())

    return predictions, image_ids

# DataLoader worker processes are disabled on Windows (presumably to avoid
# multiprocessing problems there - confirm); elsewhere one per CPU is used.
def get_optimal_num_workers():
    """Return 0 on Windows, otherwise the machine's CPU count."""
    on_windows = platform.system() == 'Windows'
    return 0 if on_windows else multiprocessing.cpu_count()
    
def load_stats(stats_file):
    """Read a JSON file and return its ``'mean'`` and ``'std'`` entries.

    Args:
        stats_file: Path of a JSON file containing ``mean`` and ``std`` keys.

    Returns:
        Tuple ``(mean, std)`` exactly as stored in the file.
    """
    with open(stats_file, 'r') as handle:
        payload = json.load(handle)
    mean = payload['mean']
    std = payload['std']
    return mean, std

def train_and_save(config):
    """Build datasets and a model from ``config``, train it and save the result.

    Steps:
        1. Load normalisation statistics from ``dataset_stats/``.
        2. Build the training dataset (plus a validation dataset unless
           ``config['use_full_dataset']`` is set).
        3. Instantiate the model named by ``config['model']``.
        4. Train via ``train_model`` and save the weights under ``models/``.
        5. If ``config['run_predictions']`` is set, predict on the test set
           and write a CSV under ``predictions/``.

    Args:
        config: Dict of file paths and hyper-parameters; see the keys read
            below. ``use_full_dataset`` is optional and defaults to False.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    use_full_dataset = config.get('use_full_dataset', False)
    num_workers = get_optimal_num_workers()

    # Create the output folders up front so a missing directory cannot make
    # the run fail only after training has already finished.
    os.makedirs('models', exist_ok=True)
    if not use_full_dataset:
        os.makedirs('training_params', exist_ok=True)
    if config['run_predictions']:
        os.makedirs('predictions', exist_ok=True)

    # Determine which stats file to use: full vs. split, plain vs. quantized.
    stats_folder = 'dataset_stats'
    stats_prefix = 'full' if use_full_dataset else 'split'
    stats_file = f'{stats_prefix}_quantized.json' if config['use_quantized'] else f'{stats_prefix}.json'
    mean, std = load_stats(os.path.join(stats_folder, stats_file))

    normalize_transform = transforms.Normalize(mean=mean, std=std)
    base_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize_transform,
    ])

    def build_training_dataset(csv_file):
        # Training data is the only place augmentation and balancing apply.
        return StabilityDataset(csv_file=csv_file,
                                img_dir=config['train_img_dir'],
                                transform=base_transform,
                                augment=config['use_augmentation'],
                                use_quantized=config['use_quantized'],
                                additional_columns=config['additional_columns'],
                                target_column=config['target_column'],
                                balance_dataset=config['balance_dataset'])

    if use_full_dataset:
        # Use full dataset without validation
        train_dataset = build_training_dataset(config['full_training_csv'])
        val_loader = None
    else:
        # Use split dataset with validation
        train_dataset = build_training_dataset(config['split_training_csv'])
        val_dataset = StabilityDataset(csv_file=config['validation_csv'],
                                       img_dir=config['train_img_dir'],
                                       transform=base_transform,
                                       augment=False,
                                       use_quantized=config['use_quantized'],
                                       additional_columns=config['additional_columns'],
                                       target_column=config['target_column'],
                                       balance_dataset=False)
        val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                                shuffle=False, num_workers=num_workers)

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                              shuffle=True, num_workers=num_workers)

    # Get the number of categories for each additional feature and the target
    additional_features = train_dataset.get_feature_dimensions()
    num_classes = train_dataset.get_target_dimension()

    # Initialize model. ConvnextPredictor takes `freeze_layers`; the
    # EfficientNet-based models take `dropout_rate` instead.
    dropout_models = {
        'StabilityPredictor': StabilityPredictor,
        'EfficientAttentionNet': EfficientAttentionNet,
        'EfficientChannelAttentionNet': EfficientChannelAttentionNet,
    }
    if config['model'] == 'ConvnextPredictor':
        model = ConvnextPredictor(num_classes=num_classes, freeze_layers=config['freeze_layers'],
                                  additional_features=additional_features)
    else:
        model_class = dropout_models.get(config['model'])
        if model_class is None:
            print('Unrecognised model in config. Defaulting to StabilityPredictor (EfficientNet)')
            model_class = StabilityPredictor
        model = model_class(num_classes=num_classes, dropout_rate=config['dropout_rate'],
                            additional_features=additional_features)

    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases -
    # confirm against the installed version.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=config['lr_factor'], patience=config['lr_patience'], verbose=True)

    # Train model. A validation run writes the params file; a full-dataset
    # run reads it back.
    print('Training...')
    params_file = f"training_params/{config['model']}_training_params.json"
    model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                        num_epochs=config['num_epochs'], patience=config['early_stopping_patience'],
                        device=device, use_full_dataset=use_full_dataset,
                        save_path=None if use_full_dataset else params_file,
                        load_path=params_file if use_full_dataset else None)

    model_save_path = f"models/{config['model']}_trained_model.pth"
    torch.save(model.state_dict(), model_save_path)
    print(f"Saved model to {model_save_path}")

    # Prediction on test set
    if config['run_predictions']:
        # NOTE(review): the test dataset derives the category indices of the
        # additional columns from the test CSV itself - confirm they match the
        # indices the model was trained with.
        test_dataset = StabilityDataset(csv_file=config['test_csv'],
                                        img_dir=config['test_img_dir'],
                                        transform=base_transform,
                                        use_quantized=config['use_quantized'],
                                        additional_columns=config['additional_columns'],
                                        is_test=True)  # Set is_test to True for test dataset
        test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers)

        predictions, image_ids = predict(model, test_loader, device)

        # Save predictions to CSV
        predictions_save_path = f"predictions/{config['model']}_predictions.csv"
        with open(predictions_save_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['id', config['target_column']])
            for img_id, pred in zip(image_ids, predictions):
                writer.writerow([int(img_id), int(pred)])
        print(f"Predictions saved to {predictions_save_path}")

In [None]:
# Run configuration consumed by train_and_save().
config = {
    # Files: CSV metadata and image directories for each data split.
    'full_training_csv': './COMP90086_2024_Project_train/train.csv',
    'split_training_csv': 'split_train.csv',
    'validation_csv': 'split_val.csv',
    'train_img_dir': './preprocessed_images/train',
    'test_csv': './COMP90086_2024_Project_test/test.csv',
    'test_img_dir': './preprocessed_images/test',

    # Training parameters
    # 'model' selects the architecture; unrecognised names fall back to StabilityPredictor.
    'model': 'EfficientAttentionNet',
    # CSV column used as the class label.
    'target_column': 'stable_height',
    # Extra categorical CSV columns fed to the model through embeddings.
    'additional_columns': [],
    # False: train on the split CSV with validation and early stopping.
    # True: train on the full CSV, replaying the saved training parameters.
    'use_full_dataset': False,
    # Undersample every class to the size of the rarest one (training data only).
    'balance_dataset': False,
    # Predict on the test set and write a CSV after training.
    'run_predictions': False,
    # Include the flipped / zoomed image variants in the training data.
    'use_augmentation': True,
    # Read the colour-quantized images and the matching normalisation stats.
    'use_quantized': False,
    'batch_size': 16,
    # Dropout probability of the classification head (not used by ConvnextPredictor).
    'dropout_rate': 0.3,
    # Initial learning rate for Adam.
    'learning_rate': 0.001,
    # ReduceLROnPlateau settings: multiplicative factor and patience in epochs.
    'lr_factor': 0.1,
    'lr_patience': 3,
    # Only read when 'model' is 'ConvnextPredictor'.
    'freeze_layers': False,
    'num_epochs': 30,
    # Epochs without validation-loss improvement before training stops.
    'early_stopping_patience': 6
}

# Launch training with the configuration above.
train_and_save(config)

Training...
Epoch 1/30




Training:   0%|          | 0/1728 [00:00<?, ?it/s]