## Importing libraries

In [1]:
from torch.utils.data import DataLoader, Sampler, SubsetRandomSampler
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
from tqdm import tqdm  
import torch
from collections import Counter
from torch.utils.data import ConcatDataset
import random
import os
import torchvision.transforms as transforms
import pandas as pd

# Allow PIL to load image files whose data ends early (truncated) instead of
# erroring on them. NOTE(review): such images are decoded with incomplete pixel
# data rather than rejected — confirm that is acceptable for training data.
ImageFile.LOAD_TRUNCATED_IMAGES = True

## Define file paths as constants

In [None]:
# Define file paths as constants (single source of truth for input locations).
CSV_FILE_PATH = r'C:\Users\Sandhra George\avalanche\data\dataset.csv'  # Path to the CSV file
ROOT_DIR_PATH = r'C:\Users\Sandhra George\avalanche\caxton_dataset\print24'  # Path to the image directory

# Lower-case aliases kept for backward compatibility (e.g. `root_dir` is passed
# to BalancedDataset later). They are derived from the constants so the two
# spellings can never point at different locations.
csv_file = CSV_FILE_PATH
root_dir = ROOT_DIR_PATH

## Load data into a DataFrame and filter to print24 images

In [None]:
# Load data into a DataFrame for easier processing
data = pd.read_csv(CSV_FILE_PATH)

# (Disabled) limit the dataset to row indices 454-7058 (inclusive):
# data_limited = data.iloc[454:7059].reset_index(drop=True)

# Keep only rows whose first column mentions "print24" (missing values count as
# no match). `.copy()` makes data_filtered an independent frame, so the in-place
# column rewrite below explicitly modifies it rather than a slice of `data`
# (this avoids pandas' SettingWithCopyWarning).
data_filtered = data[data.iloc[:, 0].str.contains('print24', na=False)].copy()

# Where the first column contains ".../image-<n>.jpg", drop everything before
# the "image-<n>.jpg" part so the column holds bare filenames.
data_filtered.iloc[:, 0] = data_filtered.iloc[:, 0].str.replace(r'.*?/(image-\d+\.jpg)', r'\1', regex=True)

# Display the updated DataFrame
print("First rows of filtered DataFrame:")
print(data_filtered.head())

# Display the last few rows of the updated DataFrame
print("\nLast rows of filtered DataFrame:")
print(data_filtered.tail())

## Analysing the target hotend temperature column

In [None]:
# Summarise the target hotend temperatures present in the filtered data:
# the distinct values (ascending), how many there are, and their min/max.
hotend_series = data_filtered['target_hotend']
unique_temperatures = sorted(hotend_series.unique())
temperature_min, temperature_max = hotend_series.min(), hotend_series.max()

print("\nUnique target hotend temperatures in the dataset (sorted):")
print(unique_temperatures)
print(f"\nNumber of unique target hotend temperatures: {len(unique_temperatures)}")
print(f"Temperature range: {temperature_min} to {temperature_max}")

## Create a random temperature sublist and new DataFrames with equal class distribution

In [None]:
# Seed for choosing the random temperatures (42 matches the random_state used
# for the pandas sampling below). Set to None for a fresh draw on every run.
TEMPERATURE_SAMPLING_SEED = 42
# How many temperatures are drawn at random in addition to the min and max.
NUM_RANDOM_TEMPERATURES = 40
# Minimum number of distinct temperatures required before sampling.
MIN_UNIQUE_TEMPERATURES = 50
# Class labels expected in the 'hotend_class' column.
CLASS_IDS = (0, 1, 2)

# Extract unique temperatures and sort them in ascending order
unique_temperatures = sorted(data_filtered['target_hotend'].unique())

if len(unique_temperatures) >= MIN_UNIQUE_TEMPERATURES:
    # Always keep the extreme temperatures; sample the rest from the interior
    # (the list is sorted and unique, so [1:-1] excludes exactly min and max).
    temperature_min = unique_temperatures[0]
    temperature_max = unique_temperatures[-1]
    remaining_temperatures = unique_temperatures[1:-1]

    temperature_rng = random.Random(TEMPERATURE_SAMPLING_SEED)
    random_temperatures = temperature_rng.sample(remaining_temperatures, NUM_RANDOM_TEMPERATURES)

    # Sort from lowest to highest hotend temperature
    temperature_sublist = sorted([temperature_min, temperature_max] + random_temperatures)

    print("\nTemperature sublist:")
    print(temperature_sublist)

    # Split into three contiguous temperature bands (the last one absorbs any remainder)
    split_size = len(temperature_sublist) // 3
    experience_1 = temperature_sublist[:split_size]
    experience_2 = temperature_sublist[split_size:2 * split_size]
    experience_3 = temperature_sublist[2 * split_size:]

    print("\nExperience Group 1:", experience_1)
    print("\nExperience Group 2:", experience_2)
    print("\nExperience Group 3:", experience_3)
else:
    print(f"Not enough unique temperatures to select from. At least {MIN_UNIQUE_TEMPERATURES} unique temperatures are required.")
    experience_1 = experience_2 = experience_3 = []


def _balance_experience(exp_data, exp_id):
    """Down-sample ``exp_data`` so every class in CLASS_IDS has the same row count.

    Returns the balanced, shuffled DataFrame (fresh 0..N-1 index), or None when
    any class has no rows (a message is printed in that case).
    """
    class_datasets = {}
    for class_id in CLASS_IDS:
        class_data = exp_data[exp_data['hotend_class'] == class_id]
        if class_data.empty:
            print(f"Warning: Class {class_id} in Experience {exp_id} has no data!")
        else:
            class_datasets[class_id] = class_data
            print(f"Class {class_id} dataset size in Experience {exp_id}: {len(class_data)}")

    if len(class_datasets) != len(CLASS_IDS):
        print(f"Skipping Experience {exp_id} because one or more classes are missing data!")
        return None

    min_class_size = min(len(class_data) for class_data in class_datasets.values())
    print(f"Smallest class size in Experience {exp_id}: {min_class_size}")

    # Sample min_class_size rows from every class, then shuffle the combined rows
    balanced = pd.concat(
        [class_data.sample(n=min_class_size, random_state=42) for class_data in class_datasets.values()]
    ).reset_index(drop=True)
    return balanced.sample(frac=1, random_state=42).reset_index(drop=True)


# Only experiences that survive filtering and balancing get an entry, so the
# downstream `exp_id in experience_datasets` checks actually detect skips
# (pre-filling the keys with {} made every check pass and later cells crash).
experience_datasets = {}

for exp_id, experience_temps in enumerate([experience_1, experience_2, experience_3], start=1):
    if not experience_temps:
        print(f"Skipping Experience {exp_id} due to insufficient temperatures.")
        continue

    print(f"\nProcessing Experience {exp_id} with temperatures: {experience_temps}...")

    # Rows whose target temperature belongs to this experience's band
    exp_data = data_filtered[data_filtered['target_hotend'].isin(experience_temps)]
    if exp_data.empty:
        print(f"No data found for Experience {exp_id} with temperatures {experience_temps}. Skipping...")
        continue

    balanced_dataset = _balance_experience(exp_data, exp_id)
    if balanced_dataset is None:
        continue

    experience_datasets[exp_id] = balanced_dataset

    print(f"\nBalanced dataset size for Experience {exp_id}: {len(balanced_dataset)}")
    print("Number of images in each class after balancing:")
    for class_id in CLASS_IDS:
        class_count = int((balanced_dataset['hotend_class'] == class_id).sum())
        print(f"Class {class_id}: {class_count} images")

    print("-" * 50)

# Print the first few rows for verification
for exp_id in [1, 2, 3]:
    if exp_id in experience_datasets:
        print(f"\nFirst five rows of Experience {exp_id} dataset:")
        print(experience_datasets[exp_id].head())

## Checking the class distribution of all the experience datasets

In [None]:
# Report the per-class row counts of every experience that was built.
for exp_id in [1, 2, 3]:
    # Experiences skipped upstream have no entry to report on.
    if exp_id not in experience_datasets:
        continue

    balanced_dataset_filtered = experience_datasets[exp_id][['img_path', 'hotend_class']]
    class_distribution = balanced_dataset_filtered['hotend_class'].value_counts()

    print(f"\nClass distribution for Experience {exp_id}:")
    print(class_distribution)

## Printing the indices, the classes, and the number of images in each class

In [None]:
# For each experience: list the row indices of every class, then derive a
# batch size that draws the same number of samples from every class.
for exp_id in [1, 2, 3]:
    # Experiences skipped upstream have nothing to report.
    if exp_id not in experience_datasets:
        continue

    balanced_dataset_filtered = experience_datasets[exp_id][['img_path', 'hotend_class']]
    label_column = balanced_dataset_filtered['hotend_class']
    class_distribution = label_column.value_counts()

    print(f"\n--- Experience {exp_id} ---")
    for class_label in class_distribution.index:
        class_indices = label_column[label_column == class_label].index.tolist()
        num_images_in_class = len(class_indices)

        print(f"\nClass: {class_label} (Total images: {num_images_in_class})")
        print("Indices: ", class_indices)
        print(f"Number of images in class {class_label}: {num_images_in_class}")

    num_classes = len(class_distribution)
    small_batch_size = 15  # target total batch size; values like 32 or 64 also work

    # Equal share per class (floor division), capped by the rarest class's size.
    samples_per_class = min(small_batch_size // num_classes, class_distribution.min())
    batch_size = samples_per_class * num_classes

    print(f"\nRecommended Small Batch Size for Experience {exp_id}: {batch_size}")
    print(f"Samples per class in Experience {exp_id}: {samples_per_class}")
    print("-" * 50)  # To separate each experience's results

## Create training, validation, and testing datasets

In [None]:
# Fractions of each class assigned to train and validation; the rest is test.
TRAIN_FRACTION = 0.7
VALID_FRACTION = 0.15
# Seed for the per-class shuffles so the split is reproducible
# (set to None for a different split on every run).
SPLIT_SEED = 42
# Class labels expected in the 'hotend_class' column.
SPLIT_CLASS_IDS = (0, 1, 2)


def _split_class_distribution(frame, indices):
    """Return [count of each SPLIT_CLASS_IDS label] among ``frame.loc[indices]``."""
    label_counts = Counter(frame.loc[indices, 'hotend_class'].tolist())
    return [label_counts[class_id] for class_id in SPLIT_CLASS_IDS]


split_rng = random.Random(SPLIT_SEED)

for exp_id in [1, 2, 3]:
    # Skip experiences that were not built (missing, placeholder or empty).
    exp_df = experience_datasets.get(exp_id)
    if not isinstance(exp_df, pd.DataFrame) or exp_df.empty:
        continue

    balanced_dataset_filtered = exp_df[['img_path', 'hotend_class']]

    # Stratified split: each class is shuffled and cut into train/valid/test.
    # Sizes are computed per class, so a class with a different row count still
    # gets the same 70/15/15 proportions (for balanced data this equals the
    # previous len // 3 based sizes).
    train_indices, valid_indices, test_indices = [], [], []
    for class_label in SPLIT_CLASS_IDS:
        class_data = balanced_dataset_filtered[balanced_dataset_filtered['hotend_class'] == class_label].index.tolist()
        split_rng.shuffle(class_data)

        train_size = int(TRAIN_FRACTION * len(class_data))
        valid_size = int(VALID_FRACTION * len(class_data))

        train_indices.extend(class_data[:train_size])
        valid_indices.extend(class_data[train_size:train_size + valid_size])
        test_indices.extend(class_data[train_size + valid_size:])

    # Sort the indices to ensure consistent processing
    train_indices, valid_indices, test_indices = sorted(train_indices), sorted(valid_indices), sorted(test_indices)

    # Later cells look these up by name (train_1, valid_1, test_1, ...).
    globals()[f'train_{exp_id}'] = balanced_dataset_filtered.loc[train_indices].reset_index(drop=True)
    globals()[f'valid_{exp_id}'] = balanced_dataset_filtered.loc[valid_indices].reset_index(drop=True)
    globals()[f'test_{exp_id}'] = balanced_dataset_filtered.loc[test_indices].reset_index(drop=True)

    train_class_distribution = _split_class_distribution(balanced_dataset_filtered, train_indices)
    valid_class_distribution = _split_class_distribution(balanced_dataset_filtered, valid_indices)
    test_class_distribution = _split_class_distribution(balanced_dataset_filtered, test_indices)

    print(f"\n--- Experience {exp_id} ---")
    print(f"Train set size: {len(train_indices)} | Class distribution: {train_class_distribution}")
    print(f"Validation set size: {len(valid_indices)} | Class distribution: {valid_class_distribution}")
    print(f"Test set size: {len(test_indices)} | Class distribution: {test_class_distribution}")

    print(f"Experience {exp_id} datasets created successfully!\n")

# Now, the datasets are directly available as:
# train_1, valid_1, test_1, train_2, valid_2, test_2, train_3, valid_3, test_3

## Check for Missing or Invalid Labels in Training, Validation, and Test Data

In [None]:
# For every split (train 1-3, then valid 1-3, then test 1-3) print the number
# of missing labels followed by the distinct label values, so unexpected
# values stand out.
for split_frame in (train_1, train_2, train_3,
                    valid_1, valid_2, valid_3,
                    test_1, test_2, test_3):
    split_labels = split_frame['hotend_class']
    print(split_labels.isnull().sum())
    print(split_labels.unique())

## Balanced Dataset class

In [None]:
# Define the dataset class
class BalancedDataset(Dataset):
    """Map-style dataset serving ``(image, label)`` pairs listed in a DataFrame.

    Column 0 of ``data_frame`` holds the image filename (only the text after
    the last ``/`` is used) and column 1 holds the label. At construction each
    row is validated and kept only if the filename has the form
    ``image-<n>.<ext>`` with ``min_image_number <= n <= max_image_number`` and
    the file exists under ``root_dir``. Positions of kept rows are stored in
    ``valid_indices``.

    Args:
        data_frame: DataFrame with filename in column 0 and label in column 1.
        root_dir: Directory containing the image files.
        transform: Callable applied to each RGB PIL image. Defaults to a
            224x224 resize, ``ToTensor`` and per-channel mean/std normalisation.
        min_image_number: Smallest accepted ``<n>`` (inclusive).
        max_image_number: Largest accepted ``<n>`` (inclusive).

    NOTE(review): ``__getitem__`` treats ``idx`` as a position in
    ``valid_indices`` (wrapped modulo its length), while BalancedBatchSampler
    yields DataFrame index labels. The two agree only when every row passes
    validation and the frame has a 0..N-1 index; otherwise per-batch class
    balance is not preserved.
    """

    def __init__(self, data_frame, root_dir, transform=None,
                 min_image_number=4, max_image_number=26637):
        self.data = data_frame
        self.root_dir = root_dir
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Inclusive range of accepted image numbers (defaults keep the
        # previously hard-coded 4..26637).
        self.min_image_number = min_image_number
        self.max_image_number = max_image_number

        # Validate that the images exist in the directory
        self.valid_indices = self.get_valid_indices()

    @staticmethod
    def _image_name(raw_name):
        """Return ``raw_name`` as a stripped string reduced to its last ``/`` component."""
        # str() keeps non-string cells (e.g. NaN) from raising AttributeError;
        # they simply fail the "image-" prefix check.
        return str(raw_name).strip().split('/')[-1]

    def get_valid_indices(self):
        """Return positional indices of rows whose image passes validation.

        Rows are skipped when the filename does not start with ``image-``,
        when its number cannot be parsed (message printed), when the number
        is outside the configured range (silently), or when the file is
        missing (message printed).
        """
        valid_indices = []
        for idx in tqdm(range(len(self.data)), desc="Validating images"):
            img_name = self._image_name(self.data.iloc[idx, 0])
            if not img_name.startswith("image-"):
                continue

            try:
                image_number = int(img_name.split('-')[1].split('.')[0])
            except ValueError:
                print(f"Invalid filename format for {img_name}. Skipping...")
                continue

            if not self.min_image_number <= image_number <= self.max_image_number:
                continue

            full_img_path = os.path.join(self.root_dir, img_name)
            if os.path.exists(full_img_path):
                valid_indices.append(idx)
            else:
                print(f"Image does not exist: {full_img_path}")

        print(f"Total valid indices found: {len(valid_indices)}")  # Debugging output
        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th valid row.

        ``idx`` wraps modulo the number of valid rows. If an image fails to
        load or transform, the following valid rows are tried in order; each
        valid row is tried at most once.

        Raises:
            IndexError: if the dataset has no valid rows.
            RuntimeError: if no valid row can be loaded.
        """
        num_valid = len(self.valid_indices)
        if num_valid == 0:
            raise IndexError("BalancedDataset has no valid images to load.")

        start = idx % num_valid
        last_error = None
        for offset in range(num_valid):
            actual_idx = self.valid_indices[(start + offset) % num_valid]
            # Same basename extraction as validation, so the checked file is the loaded file.
            img_name = self._image_name(self.data.iloc[actual_idx, 0])
            full_img_path = os.path.join(self.root_dir, img_name)
            label = self.data.iloc[actual_idx, 1]

            try:
                # `with` closes the file handle; convert() returns a loaded copy.
                with Image.open(full_img_path) as img:
                    image = img.convert('RGB')
                if self.transform:
                    image = self.transform(image)
                return image, label
            except (OSError, ValueError) as e:
                print(f"Error loading image {full_img_path}: {e}")
                last_error = e

        # Previously this recursed forever (RecursionError, a RuntimeError subclass).
        raise RuntimeError("No image in the dataset could be loaded.") from last_error

## Balanced Batch Sampler class

In [None]:
class BalancedBatchSampler(Sampler):
    """Batch sampler whose batches hold the same number of samples per class.

    Each batch takes ``samples_per_class`` indices from every class found in
    ``data_frame['hotend_class']`` and shuffles them together. Yielded values
    are the DataFrame's index labels. When a class has fewer than
    ``samples_per_class`` unused indices left, its pool is refilled with a
    fresh shuffle of all its indices (leftovers are dropped), so small classes
    are oversampled rather than exhausted.

    Args:
        data_frame: DataFrame with a ``hotend_class`` column.
        batch_size: Nominal batch size; must be positive and divisible by the
            number of classes. An epoch has ceil(num_rows / batch_size) batches.
        samples_per_class: Indices drawn from each class per batch, so each
            batch holds ``samples_per_class * num_classes`` entries (equal to
            ``batch_size`` only if the two are chosen consistently).

    Raises:
        ValueError: if ``batch_size`` is not positive, the frame has no
            classes, or ``batch_size`` is not divisible by the class count.
    """

    def __init__(self, data_frame, batch_size=15, samples_per_class=5):
        self.data_frame = data_frame
        self.batch_size = batch_size
        self.samples_per_class = samples_per_class
        self.num_classes = len(data_frame['hotend_class'].unique())

        # A non-positive batch size made the epoch loop in __iter__ never terminate.
        if self.batch_size <= 0:
            raise ValueError("Batch size must be positive.")
        if self.num_classes == 0:
            raise ValueError("data_frame contains no classes; cannot form balanced batches.")
        if self.batch_size % self.num_classes != 0:
            raise ValueError("Batch size must be divisible by the number of classes.")

        self.class_indices = {
            class_id: self.data_frame[self.data_frame['hotend_class'] == class_id].index.tolist()
            for class_id in self.data_frame['hotend_class'].unique()
        }

        # Shuffle class indices initially
        for class_id in self.class_indices:
            random.shuffle(self.class_indices[class_id])

        self.num_samples_per_epoch = sum(len(indices) for indices in self.class_indices.values())
        self.indices_used = {class_id: [] for class_id in self.class_indices}

    def __iter__(self):
        # Start every epoch from a freshly shuffled copy of each class's indices.
        for class_id, indices in self.class_indices.items():
            if not indices:
                raise ValueError(f"Class {class_id} has no samples. Cannot form balanced batches.")
            self.indices_used[class_id] = indices.copy()
            random.shuffle(self.indices_used[class_id])

        batches = []
        for _ in range(len(self)):
            batch = []
            for class_id in self.indices_used:
                if len(self.indices_used[class_id]) < self.samples_per_class:
                    # Refill an exhausted class pool (its leftovers are dropped).
                    self.indices_used[class_id] = self.class_indices[class_id].copy()
                    random.shuffle(self.indices_used[class_id])

                # Take `samples_per_class` indices from the current class
                batch.extend(self.indices_used[class_id][:self.samples_per_class])
                self.indices_used[class_id] = self.indices_used[class_id][self.samples_per_class:]

            random.shuffle(batch)
            batches.append(batch)

        return iter(batches)

    def __len__(self):
        # Number of batches __iter__ yields: ceil(num_samples / batch_size).
        # (Previously floor, which under-reported whenever there was a remainder.)
        return (self.num_samples_per_epoch + self.batch_size - 1) // self.batch_size

In [None]:
# Registries of per-split Datasets and DataLoaders, keyed "train_1", "valid_2", ...
datasets = {}
dataloaders = {}

for exp_id in [1, 2, 3]:
    # Experiences whose split DataFrames were never created are skipped.
    if f"train_{exp_id}" not in globals():
        continue

    train_key, valid_key, test_key = f"train_{exp_id}", f"valid_{exp_id}", f"test_{exp_id}"
    train_data, val_data, test_data = (globals()[key] for key in (train_key, valid_key, test_key))

    # One image-validating Dataset per split
    datasets[train_key] = BalancedDataset(data_frame=train_data, root_dir=root_dir)
    datasets[valid_key] = BalancedDataset(data_frame=val_data, root_dir=root_dir)
    datasets[test_key] = BalancedDataset(data_frame=test_data, root_dir=root_dir)

    # Balanced batches of 15 with 5 samples per class, built in train/valid/test order
    train_sampler, val_sampler, test_sampler = (
        BalancedBatchSampler(data_frame=frame, batch_size=15, samples_per_class=5)
        for frame in (train_data, val_data, test_data)
    )

    dataloaders[train_key] = DataLoader(datasets[train_key], batch_sampler=train_sampler, shuffle=False)
    dataloaders[valid_key] = DataLoader(datasets[valid_key], batch_sampler=val_sampler, shuffle=False)
    dataloaders[test_key] = DataLoader(datasets[test_key], batch_sampler=test_sampler)

    print(f"   Experience {exp_id} datasets and DataLoaders created successfully!")
    for split_title, key in (("Train", train_key), ("Validation", valid_key), ("Test", test_key)):
        print(f"   {split_title} dataset length: {len(datasets[key])}")

## Checking class distribution in each dataset

In [None]:
def count_classes(dataset):
    """Count how many samples of each label ``dataset`` will serve.

    Labels are read from column 1 of ``dataset.data``. When the dataset has a
    ``valid_indices`` attribute (as BalancedDataset does), only those rows are
    counted, because rows that failed validation are never returned and are
    not included in ``len(dataset)``. Otherwise every row is counted.

    Returns:
        collections.Counter mapping label -> count (plain Python int keys).
    """
    labels = dataset.data.iloc[:, 1]
    valid_indices = getattr(dataset, "valid_indices", None)
    if valid_indices is not None:
        labels = labels.iloc[valid_indices]
    # tolist() yields plain Python scalars; no tensor/NumPy round trip needed.
    return Counter(labels.tolist())

# Class distribution of every per-experience split: all train sets first,
# then validation, then test.
for split_title, split_key in (("Train", "train"), ("Validation", "valid"), ("Test", "test")):
    for exp_id in (1, 2, 3):
        print(f"Class distribution in {split_title} Dataset {exp_id}:",
              count_classes(datasets[f"{split_key}_{exp_id}"]))

## Creating experience datasets

In [None]:
# Experience 1: the experience-1 split datasets, used as-is.
exp1_train, exp1_valid, exp1_test = (datasets[f"{split}_1"] for split in ("train", "valid", "test"))

# Experience 1_2: experiences 1 and 2 concatenated, split by split.
exp1_2_train, exp1_2_valid, exp1_2_test = (
    ConcatDataset([datasets[f"{split}_{k}"] for k in (1, 2)])
    for split in ("train", "valid", "test")
)

# Experience 1_2_3: all three experiences concatenated, split by split.
exp1_2_3_train, exp1_2_3_valid, exp1_2_3_test = (
    ConcatDataset([datasets[f"{split}_{k}"] for k in (1, 2, 3)])
    for split in ("train", "valid", "test")
)

## Checking class distribution in each experience dataset

In [None]:
def count_classes(dataset):
    """Count how many samples of each label ``dataset`` will serve.

    A ConcatDataset (possibly nested) is handled by summing the counts of its
    sub-datasets. For any other dataset, labels are read from column 1 of
    ``dataset.data``. If the dataset has a ``valid_indices`` attribute (as
    BalancedDataset does), only those rows are counted, since rows that failed
    validation are never returned and are not included in ``len(dataset)``.

    Returns:
        collections.Counter mapping label -> count (plain Python int keys).
    """
    if isinstance(dataset, ConcatDataset):
        counts = Counter()
        for sub_dataset in dataset.datasets:
            # Recurse so nested ConcatDatasets are handled too.
            counts.update(count_classes(sub_dataset))
        return counts

    labels = dataset.data.iloc[:, 1]
    valid_indices = getattr(dataset, "valid_indices", None)
    if valid_indices is not None:
        labels = labels.iloc[valid_indices]
    return Counter(labels.tolist())

# Class distribution of each cumulative experience (1, 1_2, 1_2_3), reported
# for its train, valid and test datasets in that order.
cumulative_experiences = (
    ("1", exp1_train, exp1_valid, exp1_test),
    ("1_2", exp1_2_train, exp1_2_valid, exp1_2_test),
    ("1_2_3", exp1_2_3_train, exp1_2_3_valid, exp1_2_3_test),
)

for exp_tag, *split_datasets in cumulative_experiences:
    for split_name, split_dataset in zip(("train", "valid", "test"), split_datasets):
        print(f"Class distribution in Experience {exp_tag} {split_name} dataset:",
              count_classes(split_dataset))

## Benchmark experiment

In [None]:
import os
import csv
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from torch.utils.data import DataLoader, ConcatDataset
from torchmetrics import ConfusionMatrix
from models.cnn_models import SimpleCNN
from sklearn.metrics import confusion_matrix

# -------------------------------------------------------------------
# Benchmark: train a fresh SimpleCNN on each cumulative experience
# (1, 1+2, 1+2+3), keep the checkpoint with the best validation accuracy,
# and evaluate that checkpoint on the experience's test split.
#
# Requires the datasets built in the cells above:
#   exp1_train, exp1_valid, exp1_test,
#   exp1_2_train, exp1_2_valid, exp1_2_test,
#   exp1_2_3_train, exp1_2_3_valid, exp1_2_3_test
#
# Outputs per experiment (under benchmark_experiment/<experiment>/):
#   model_<experiment>.pth, training_validation_losses.csv,
#   per-epoch training/validation confusion-matrix PNGs,
#   test_confusion_matrix.png and training_validation_loss.png.
# -------------------------------------------------------------------

# Define the experiment configurations in a dictionary.
experiments = {
    "experience_1": (exp1_train, exp1_valid, exp1_test),
    "experience_1_2": (exp1_2_train, exp1_2_valid, exp1_2_test),
    "experience_1_2_3": (exp1_2_3_train, exp1_2_3_valid, exp1_2_3_test)
}

# Create the top-level benchmark experiment folder.
benchmark_folder = "benchmark_experiment"
os.makedirs(benchmark_folder, exist_ok=True)

# Training settings
num_epochs = 30
batch_size = 15
num_classes = 3  # update if needed
# When True, print the raw logits of every batch and a "Predicted/Actual" line
# for every sample. That is one line per image per epoch, which floods the
# notebook and slows training considerably, so it is off by default.
verbose_batches = False

# Set device to GPU if available (the same for every experiment).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def save_confusion_matrix_plot(cm, title, output_path):
    """Save `cm` (a num_classes x num_classes integer array) as an annotated heatmap.

    A dedicated figure is created and closed on every call, so plots never draw
    on top of each other and open figures do not accumulate over epochs and
    experiments.
    """
    fig = plt.figure()
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=range(num_classes), yticklabels=range(num_classes))
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(title)
    fig.savefig(output_path)
    plt.close(fig)


def run_epoch(model, loader, criterion, cm_metric, optimizer=None, desc=""):
    """Make one pass over `loader`, updating `cm_metric` with every batch.

    With an `optimizer` the model is trained: train mode, gradients enabled,
    and one optimizer step per batch. Without one it is evaluated in eval mode
    under no_grad.

    Returns:
        (mean per-sample loss, accuracy, list of per-class label counts)
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    class_counts = [0] * num_classes

    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader, desc=desc, leave=False):
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            if verbose_batches:
                print(f"Outputs (Raw): {outputs}")  # Log raw outputs
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # Weight the batch loss by batch size so the epoch loss is a
            # per-sample average even when the last batch is smaller. This
            # assumes the criterion averages over the batch (the default 'mean'
            # reduction of CrossEntropyLoss).
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)

            cm_metric.update(predicted, labels)
            batch_labels = labels.tolist()  # one device->host copy per batch
            for label in batch_labels:
                class_counts[label] += 1

            if verbose_batches:
                for pred, actual in zip(predicted.tolist(), batch_labels):
                    print(f"Predicted: {pred}, Actual: {actual}")

    return running_loss / total_samples, correct_predictions / total_samples, class_counts


def test_model(model, test_loader, output_folder=None):
    """Evaluate `model` on `test_loader` and report the results.

    Prints the test accuracy, the class distribution and the confusion matrix,
    and saves the matrix as test_confusion_matrix.png in `output_folder`.
    `output_folder` defaults to the module-level `exp_folder` (the current
    experiment's folder), so the old two-argument call keeps working.
    """
    if output_folder is None:
        output_folder = exp_folder
    model.eval()
    correct_predictions = 0
    total_samples = 0
    test_class_counts = [0] * num_classes
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            batch_labels = labels.tolist()
            batch_predictions = predicted.tolist()
            all_labels.extend(batch_labels)
            all_predictions.extend(batch_predictions)
            for label in batch_labels:
                test_class_counts[label] += 1
            if verbose_batches:
                for pred, actual in zip(batch_predictions, batch_labels):
                    print(f"Predicted: {pred}, Actual: {actual}")
    avg_accuracy = correct_predictions / total_samples
    print(f"Test Accuracy: {avg_accuracy:.4f}")
    print(f"Test Class Distribution: {test_class_counts}")

    # Compute the confusion matrix with sklearn (imported at the top of this cell).
    cm_test = confusion_matrix(all_labels, all_predictions, labels=list(range(num_classes)))
    print(f"Test Confusion Matrix:\n{cm_test}")
    output_path_test = os.path.join(output_folder, "test_confusion_matrix.png")
    save_confusion_matrix_plot(cm_test, 'Test Confusion Matrix', output_path_test)
    print(f"Test Confusion Matrix saved to: {output_path_test}")


# Loop over each experiment configuration.
for exp_name, (train_dataset, val_dataset, test_dataset) in experiments.items():
    print(f"\nStarting experiment: {exp_name}\n")

    # Create a subfolder for this experiment.
    exp_folder = os.path.join(benchmark_folder, exp_name)
    os.makedirs(exp_folder, exist_ok=True)

    # Best model path (e.g., benchmark_experiment/experience_1/model_experience_1.pth)
    best_model_path = os.path.join(exp_folder, f"model_{exp_name}.pth")

    # Create DataLoaders for train, validation, and test.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model, loss function, optimizer, and scheduler (fresh for every experiment).
    model = SimpleCNN(num_classes=num_classes).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Initialize confusion matrix trackers.
    train_cm = ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)
    val_cm = ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)

    # For plotting losses.
    train_losses = []
    val_losses = []

    # Every run trains from scratch, so recreate the loss log rather than
    # appending this run's epochs after rows left over from a previous run.
    csv_file_path = os.path.join(exp_folder, "training_validation_losses.csv")
    with open(csv_file_path, mode='w', newline='') as file:
        csv.writer(file).writerow(["Epoch", "Training Loss", "Validation Loss"])

    best_val_accuracy = 0.0
    # Tracks whether *this* run wrote a checkpoint, so a stale file from an
    # earlier run is never loaded for testing.
    best_checkpoint_saved = False

    # ----------------- Training and Validation Loop -----------------
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        train_epoch_loss, train_epoch_accuracy, class_counts = run_epoch(
            model, train_loader, criterion, train_cm, optimizer=optimizer, desc="Training")
        print(f"Training Loss: {train_epoch_loss:.4f}, Training Accuracy: {train_epoch_accuracy:.4f}")
        print(f"Training Class Distribution: {class_counts}")

        # Update learning rate scheduler.
        scheduler.step()

        train_losses.append(train_epoch_loss)

        # Compute and save training confusion matrix.
        cm_train = train_cm.compute()
        print(f"Training Confusion Matrix:\n{cm_train}")
        output_path_train = os.path.join(exp_folder, f"training_confusion_matrix_epoch_{epoch + 1}.png")
        save_confusion_matrix_plot(cm_train.cpu().numpy(),
                                   f'Training Confusion Matrix - Epoch {epoch + 1}',
                                   output_path_train)
        print(f"Training Confusion Matrix saved to: {output_path_train}")
        train_cm.reset()

        # ----------------- Validation Phase -----------------
        val_epoch_loss, val_epoch_accuracy, val_class_counts = run_epoch(
            model, val_loader, criterion, val_cm, desc="Validating")
        print(f"Validation Loss: {val_epoch_loss:.4f}, Validation Accuracy: {val_epoch_accuracy:.4f}")
        print(f"Validation Class Distribution: {val_class_counts}")
        val_losses.append(val_epoch_loss)

        cm_val = val_cm.compute()
        print(f"Validation Confusion Matrix:\n{cm_val}")
        output_path_val = os.path.join(exp_folder, f"validation_confusion_matrix_epoch_{epoch + 1}.png")
        save_confusion_matrix_plot(cm_val.cpu().numpy(),
                                   f'Validation Confusion Matrix - Epoch {epoch + 1}',
                                   output_path_val)
        print(f"Validation Confusion Matrix saved to: {output_path_val}")
        val_cm.reset()

        # Save the best model if validation accuracy improves.
        if val_epoch_accuracy > best_val_accuracy:
            best_val_accuracy = val_epoch_accuracy
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_val_accuracy": best_val_accuracy
            }, best_model_path)
            best_checkpoint_saved = True
            print(f"Saved best model for {exp_name} at epoch {epoch + 1} with accuracy {best_val_accuracy:.4f}")

        # Append this epoch's losses to the CSV (written per epoch so progress survives interruption).
        with open(csv_file_path, mode='a', newline='') as file:
            csv.writer(file).writerow([epoch + 1, train_epoch_loss, val_epoch_loss])

    print(f"Experiment {exp_name} training complete. Losses saved to: {csv_file_path}")

    # ----------------- Testing Phase -----------------
    # Evaluate the checkpoint selected by validation accuracy, not merely the
    # weights left over after the final epoch.
    if best_checkpoint_saved:
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        print(f"Loaded best model from epoch {checkpoint['epoch']} for testing")
    else:
        print("No epoch exceeded 0.0 validation accuracy; testing the final-epoch model")
    test_model(model, test_loader, exp_folder)

    # Plot the training and validation losses for this experiment.
    fig = plt.figure(figsize=(10, 6))
    epochs_axis = range(1, num_epochs + 1)
    plt.plot(epochs_axis, train_losses, label='Training Loss', color='blue')
    plt.plot(epochs_axis, val_losses, label='Validation Loss', color='red')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(f'Training and Validation Losses for {exp_name}')
    plt.legend()
    plt.grid(True)
    loss_plot_path = os.path.join(exp_folder, "training_validation_loss.png")
    fig.savefig(loss_plot_path)
    plt.close(fig)
    print(f"Training and Validation Loss plot saved to: {loss_plot_path}")

print("\nAll benchmark experiments completed.")