## Importing libraries

In [1]:
from torch.utils.data import DataLoader, Sampler, SubsetRandomSampler
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
from tqdm import tqdm  
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
import random
import os
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import pandas as pd
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics,\
    loss_metrics, timing_metrics, cpu_usage_metrics, StreamConfusionMatrix,\
    disk_usage_metrics
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training import EWC
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from avalanche.benchmarks import nc_benchmark
from models.cnn_models import SimpleCNN



# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Set the random seed for reproducibility
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)  # If using a GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)

## Define filepaths as constant

In [3]:
# Define file paths as constants
CSV_FILE_PATH = r'C:\Users\Sandhra George\avalanche\data\dataset.csv'
ROOT_DIR_PATH = r'C:\Users\Sandhra George\avalanche\caxton_dataset\print24'

csv_file = r'C:\Users\Sandhra George\avalanche\data\dataset.csv'  # Path to the CSV file
root_dir = r'C:\Users\Sandhra George\avalanche\caxton_dataset\print24'  # Path to the image directory

## Load data into DataFrame and filter print24

In [4]:
# Load data into a DataFrame for easier processing
data = pd.read_csv(CSV_FILE_PATH)

# Limit dataset to the images between row indices 454 and 7058 (inclusive)
#data_limited = data.iloc[454:7059].reset_index(drop=True)

# Filter the dataset to only include images containing "print24"
data_filtered = data[data.iloc[:, 0].str.contains('print24', na=False)]

# Update the first column to contain only the image filenames
data_filtered.iloc[:, 0] = data_filtered.iloc[:, 0].str.replace(r'.*?/(image-\d+\.jpg)', r'\1', regex=True)

# Display the updated DataFrame
print("First rows of filtered DataFrame:")
print(data_filtered.head())

# Display the last few rows of the updated DataFrame
print("\nLast rows of filtered DataFrame:")
print(data_filtered.tail())

First rows of filtered DataFrame:
          img_path               timestamp  flow_rate  feed_rate  z_offset  \
99496  image-4.jpg  2020-10-07T11:45:35-86        100        100       0.0   
99497  image-5.jpg  2020-10-07T11:45:36-32        100        100       0.0   
99498  image-6.jpg  2020-10-07T11:45:36-79        100        100       0.0   
99499  image-7.jpg  2020-10-07T11:45:37-26        100        100       0.0   
99500  image-8.jpg  2020-10-07T11:45:37-72        100        100       0.0   

       target_hotend  hotend    bed  nozzle_tip_x  nozzle_tip_y  img_num  \
99496          205.0  204.86  64.83           654           560        3   
99497          205.0  204.62  65.08           654           560        4   
99498          205.0  204.62  65.08           654           560        5   
99499          205.0  204.62  65.08           654           560        6   
99500          205.0  204.62  65.08           654           560        7   

       print_id  flow_rate_class  feed_r

## Analysing the target hotend temperature column

In [5]:
# Extract unique temperatures in the 'target_hotend' column and sort them
unique_temperatures = sorted(data_filtered['target_hotend'].unique())  # Sort temperatures in ascending order

# Calculate the full range of temperatures (min and max)
temperature_min = data_filtered['target_hotend'].min()
temperature_max = data_filtered['target_hotend'].max()

# Print the unique temperatures (sorted), count, and full range
print("\nUnique target hotend temperatures in the dataset (sorted):")
print(unique_temperatures)
print(f"\nNumber of unique target hotend temperatures: {len(unique_temperatures)}")
print(f"Temperature range: {temperature_min} to {temperature_max}")


Unique target hotend temperatures in the dataset (sorted):
[180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 208.0, 209.0, 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0, 224.0, 225.0, 226.0, 227.0, 228.0, 229.0, 230.0]

Number of unique target hotend temperatures: 51
Temperature range: 180.0 to 230.0


## Create a random temperature sub list and new dataframes with equal class distribution

In [8]:
# Extract unique temperatures and sort them
unique_temperatures = sorted(data_filtered['target_hotend'].unique())  # Sort temperatures in ascending order

# Check if we have enough unique temperatures to select from
if len(unique_temperatures) >= 50:
    # Select the lowest and highest temperatures
    temperature_min = unique_temperatures[0]
    temperature_max = unique_temperatures[-1]

    # Remove the lowest and highest temperatures from the unique temperatures list
    remaining_temperatures = [temp for temp in unique_temperatures if temp != temperature_min and temp != temperature_max]

    # Randomly select 40 other temperatures from the remaining ones
    random_temperatures = random.sample(remaining_temperatures, 40)

    # Add the random temperatures to the temperature_sublist
    temperature_sublist = [temperature_min, temperature_max] + random_temperatures
    
    # Sort from lowest to highest hotend temperature
    temperature_sublist = sorted(temperature_sublist)

    # Print the temperature sublist
    print("\nTemperature sublist:")
    print(temperature_sublist)
    
    # Split into three experience groups
    split_size = len(temperature_sublist) // 3
    experience_1 = temperature_sublist[:split_size]  # First third
    experience_2 = temperature_sublist[split_size:2*split_size]  # Second third
    experience_3 = temperature_sublist[2*split_size:]  # Last third

    # Print the results
    print("\nExperience Group 1:", experience_1)
    print("\nExperience Group 2:", experience_2)
    print("\nExperience Group 3:", experience_3)
else:
    print("Not enough unique temperatures to select from. At least 50 unique temperatures are required.")
    experience_1 = experience_2 = experience_3 = []

# Initialize a dictionary to store DataFrames for each class per experience
experience_datasets = {1: {}, 2: {}, 3: {}}

# Iterate through the three experience groups
for exp_id, experience_temps in enumerate([experience_1, experience_2, experience_3], start=1):
    if not experience_temps:
        print(f"Skipping Experience {exp_id} due to insufficient temperatures.")
        continue

    print(f"\nProcessing Experience {exp_id} with temperatures: {experience_temps}...")

    # Filter the dataset based on the current experience's temperature range
    exp_data = data_filtered[data_filtered['target_hotend'].isin(experience_temps)]
    
    # Check if exp_data is empty after filtering
    if exp_data.empty:
        print(f"No data found for Experience {exp_id} with temperatures {experience_temps}. Skipping...")
        continue

    # Create a dictionary to store class-wise data for this experience
    class_datasets = {}

    # Iterate through each class (0, 1, 2) and filter data
    for class_id in [0, 1, 2]:
        class_data = exp_data[exp_data['hotend_class'] == class_id]
        
        if class_data.empty:
            print(f"Warning: Class {class_id} in Experience {exp_id} has no data!")
        else:
            class_datasets[class_id] = class_data
            print(f"Class {class_id} dataset size in Experience {exp_id}: {len(class_data)}")

    # Ensure that all classes have data before proceeding to balance
    if len(class_datasets) != 3:
        print(f"Skipping Experience {exp_id} because one or more classes are missing data!")
        continue  # Skip processing this experience if any class has no data

    # Find the smallest class size in this experience
    min_class_size = min(len(class_datasets[class_id]) for class_id in class_datasets)
    print(f"Smallest class size in Experience {exp_id}: {min_class_size}")

    # Balance the dataset for this experience
    balanced_data = []

    for class_id in class_datasets:
        class_data = class_datasets[class_id]
        # Randomly sample 'min_class_size' images from the class data to balance class distribution
        sampled_class_data = class_data.sample(n=min_class_size, random_state=42)  # Sample equally
        balanced_data.append(sampled_class_data)

    # Combine all class data for this experience into one balanced dataset
    balanced_dataset = pd.concat(balanced_data).reset_index(drop=True)

    # Shuffle the final balanced dataset
    balanced_dataset = balanced_dataset.sample(frac=1, random_state=42).reset_index(drop=True)

    # Store the balanced dataset in the experience_datasets dictionary
    experience_datasets[exp_id] = balanced_dataset

    # Print summary for this experience
    print(f"\nBalanced dataset size for Experience {exp_id}: {len(balanced_dataset)}")
    print("Number of images in each class after balancing:")

    for class_id in [0, 1, 2]:
        class_count = len(balanced_dataset[balanced_dataset['hotend_class'] == class_id])
        print(f"Class {class_id}: {class_count} images")

    print("-" * 50)

# Print the first few rows for verification
for exp_id in [1, 2, 3]:
    if exp_id in experience_datasets:
        print(f"\nFirst five rows of Experience {exp_id} dataset:")
        print(experience_datasets[exp_id].head())


Temperature sublist:
[180.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 192.0, 193.0, 195.0, 197.0, 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0, 224.0, 225.0, 227.0, 228.0, 229.0, 230.0]

Experience Group 1: [180.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 192.0, 193.0, 195.0]

Experience Group 2: [197.0, 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 210.0, 211.0, 212.0, 213.0, 214.0]

Experience Group 3: [215.0, 216.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0, 224.0, 225.0, 227.0, 228.0, 229.0, 230.0]

Processing Experience 1 with temperatures: [180.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 192.0, 193.0, 195.0]...
Class 0 dataset size in Experience 1: 5448
Class 1 dataset size in Experience 1: 759
Class 2 dataset size in Experience 1: 33
Smallest class size in Experience 1: 33

Balanced d

## Checking the class distribution of all the experience datasets

In [9]:
# Iterate over all experience datasets (1, 2, 3)
for exp_id in [1, 2, 3]:
    # Check if the experience dataset exists (in case an experience was skipped)
    if exp_id in experience_datasets:
        # Select only the 'img_path' and 'hotend_class' columns
        balanced_dataset_filtered = experience_datasets[exp_id][['img_path', 'hotend_class']]

        # Check the class distribution in the filtered dataset
        class_distribution = balanced_dataset_filtered['hotend_class'].value_counts()
        
        # Print the class distribution for the current experience
        print(f"\nClass distribution for Experience {exp_id}:")
        print(class_distribution)


Class distribution for Experience 1:
hotend_class
1    33
2    33
0    33
Name: count, dtype: int64

Class distribution for Experience 2:
hotend_class
1    46
2    46
0    46
Name: count, dtype: int64

Class distribution for Experience 3:
hotend_class
0    249
1    249
2    249
Name: count, dtype: int64


## Printing the indices, the classes, and the number of images in each class

In [10]:
# Iterate over all experience datasets (1, 2, 3)
for exp_id in [1, 2, 3]:
    # Check if the experience dataset exists (in case an experience was skipped)
    if exp_id in experience_datasets:
        # Select only the 'img_path' and 'hotend_class' columns for the current experience dataset
        balanced_dataset_filtered = experience_datasets[exp_id][['img_path', 'hotend_class']]

        # Get the class distribution for the current experience dataset
        class_distribution = balanced_dataset_filtered['hotend_class'].value_counts()
        
        # Step 1: Print the indices, the classes, and the number of images in each class
        print(f"\n--- Experience {exp_id} ---")
        for class_label in class_distribution.index:
            # Get all indices for the current class
            class_indices = balanced_dataset_filtered[balanced_dataset_filtered['hotend_class'] == class_label].index.tolist()

            # Count the number of images for the current class
            num_images_in_class = len(class_indices)

            # Print the details for this class
            print(f"\nClass: {class_label} (Total images: {num_images_in_class})")
            print("Indices: ", class_indices)
            print(f"Number of images in class {class_label}: {num_images_in_class}")

        # Step 2: Get the number of unique classes
        num_classes = len(class_distribution)

        # Step 3: Set a small batch size
        small_batch_size = 15  # You can change this to a value like 32, 64, etc.

        # Step 4: Calculate the number of samples per class per batch
        samples_per_class = small_batch_size // num_classes  # Ensure it's divisible

        # Make sure we don't ask for more samples than available in the smallest class
        samples_per_class = min(samples_per_class, class_distribution.min())

        # Step 5: Calculate the total batch size
        batch_size = samples_per_class * num_classes

        print(f"\nRecommended Small Batch Size for Experience {exp_id}: {batch_size}")
        print(f"Samples per class in Experience {exp_id}: {samples_per_class}")
        print("-" * 50)  # To separate each experience's results


--- Experience 1 ---

Class: 1 (Total images: 33)
Indices:  [0, 1, 6, 7, 12, 14, 20, 24, 25, 29, 35, 37, 40, 42, 47, 52, 57, 59, 61, 62, 63, 65, 67, 70, 71, 72, 74, 78, 79, 81, 84, 94, 98]
Number of images in class 1: 33

Class: 2 (Total images: 33)
Indices:  [2, 4, 5, 11, 18, 19, 21, 22, 27, 30, 33, 36, 39, 45, 54, 56, 58, 60, 64, 66, 68, 69, 73, 75, 77, 80, 88, 89, 90, 91, 92, 95, 97]
Number of images in class 2: 33

Class: 0 (Total images: 33)
Indices:  [3, 8, 9, 10, 13, 15, 16, 17, 23, 26, 28, 31, 32, 34, 38, 41, 43, 44, 46, 48, 49, 50, 51, 53, 55, 76, 82, 83, 85, 86, 87, 93, 96]
Number of images in class 0: 33

Recommended Small Batch Size for Experience 1: 15
Samples per class in Experience 1: 5
--------------------------------------------------

--- Experience 2 ---

Class: 1 (Total images: 46)
Indices:  [0, 4, 5, 6, 9, 10, 12, 13, 18, 19, 30, 33, 34, 42, 50, 53, 55, 61, 62, 64, 65, 66, 68, 69, 70, 76, 77, 86, 92, 93, 95, 97, 99, 101, 103, 104, 105, 106, 109, 110, 113, 114, 121

## At this point a balanced dataset for each experience has been created

## Create training, validation, and testing datasets

In [11]:
# Iterate over all experience datasets (1, 2, 3)
for exp_id in [1, 2, 3]:
    # Check if the experience dataset exists (in case an experience was skipped)
    if exp_id in experience_datasets:
        # Select only the 'img_path' and 'hotend_class' columns for the current experience dataset
        balanced_dataset_filtered = experience_datasets[exp_id][['img_path', 'hotend_class']]

        # Number of images per class (this will be the same after balancing)
        num_images_per_class = len(balanced_dataset_filtered) // 3  # Assuming there are 3 classes (0, 1, 2)

        # Calculate the number of samples per class for train, validation, and test sets
        train_size = int(0.7 * num_images_per_class)
        valid_size = int(0.15 * num_images_per_class)
        test_size = num_images_per_class - train_size - valid_size

        # Lists to hold indices for each class's dataset (train, validation, test)
        train_indices, valid_indices, test_indices = [], [], []

        # Split the data by class (assuming classes are 0, 1, 2)
        for class_label in [0, 1, 2]:
            class_data = balanced_dataset_filtered[balanced_dataset_filtered['hotend_class'] == class_label].index.tolist()

            # Shuffle the indices of the current class
            random.shuffle(class_data)

            # Split the indices for each class into train, validation, and test
            train_indices.extend(class_data[:train_size])
            valid_indices.extend(class_data[train_size:train_size + valid_size])
            test_indices.extend(class_data[train_size + valid_size:])

        # Sort the indices to ensure consistent processing
        train_indices, valid_indices, test_indices = sorted(train_indices), sorted(valid_indices), sorted(test_indices)

        # Create DataFrames for train, validation, and test sets based on the indices
        globals()[f'train_{exp_id}'] = balanced_dataset_filtered.loc[train_indices].reset_index(drop=True)
        globals()[f'valid_{exp_id}'] = balanced_dataset_filtered.loc[valid_indices].reset_index(drop=True)
        globals()[f'test_{exp_id}'] = balanced_dataset_filtered.loc[test_indices].reset_index(drop=True)

        # Count class distribution for each of the datasets
        def count_class_distribution(indices):
            class_counts = [0, 0, 0]  # Assuming 3 classes (0, 1, 2)
            for index in indices:
                class_label = balanced_dataset_filtered.loc[index, 'hotend_class']
                class_counts[class_label] += 1
            return class_counts

        # Count class distribution for each of the datasets
        train_class_distribution = count_class_distribution(train_indices)
        valid_class_distribution = count_class_distribution(valid_indices)
        test_class_distribution = count_class_distribution(test_indices)

        # Print the class distribution and dataset sizes
        print(f"\n--- Experience {exp_id} ---")
        print(f"Train set size: {len(train_indices)} | Class distribution: {train_class_distribution}")
        print(f"Validation set size: {len(valid_indices)} | Class distribution: {valid_class_distribution}")
        print(f"Test set size: {len(test_indices)} | Class distribution: {test_class_distribution}")

        print(f"Experience {exp_id} datasets created successfully!\n")

# Now, the datasets are directly available as:
# train_1, valid_1, test_1, train_2, valid_2, test_2, train_3, valid_3, test_3


--- Experience 1 ---
Train set size: 69 | Class distribution: [23, 23, 23]
Validation set size: 12 | Class distribution: [4, 4, 4]
Test set size: 18 | Class distribution: [6, 6, 6]
Experience 1 datasets created successfully!


--- Experience 2 ---
Train set size: 96 | Class distribution: [32, 32, 32]
Validation set size: 18 | Class distribution: [6, 6, 6]
Test set size: 24 | Class distribution: [8, 8, 8]
Experience 2 datasets created successfully!


--- Experience 3 ---
Train set size: 522 | Class distribution: [174, 174, 174]
Validation set size: 111 | Class distribution: [37, 37, 37]
Test set size: 114 | Class distribution: [38, 38, 38]
Experience 3 datasets created successfully!



## Check for Missing or Invalid Labels in Training, Validation, and Test Data

In [12]:
# Check for any missing labels or invalid labels
print(train_1['hotend_class'].isnull().sum())  # Count missing labels
print(train_1['hotend_class'].unique())  # Check unique labels to ensure there are no unexpected values

print(train_2['hotend_class'].isnull().sum())  # Count missing labels
print(train_2['hotend_class'].unique())  # Check unique labels to ensure there are no unexpected values

print(train_3['hotend_class'].isnull().sum())  # Count missing labels
print(train_3['hotend_class'].unique())  # Check unique labels to ensure there are no unexpected values

print(valid_1['hotend_class'].isnull().sum())  # Count missing labels
print(valid_1['hotend_class'].unique())  # Check unique labels to ensure there are no unexpected values

print(valid_2['hotend_class'].isnull().sum())  # Count missing labels
print(valid_2['hotend_class'].unique())  # Check unique labels to ensure there are no unexpected values

print(valid_3['hotend_class'].isnull().sum())  # Count missing labels
print(valid_3['hotend_class'].unique())  # Check unique labels to ensure there are no unexpected values

print(test_1['hotend_class'].isnull().sum())  # Count missing labels
print(test_1['hotend_class'].unique())  # Check unique labels to ensure there are no unexpected values

print(test_2['hotend_class'].isnull().sum())  # Count missing labels
print(test_2['hotend_class'].unique())  # Check unique labels to ensure there are no unexpected values

print(test_3['hotend_class'].isnull().sum())  # Count missing labels
print(test_3['hotend_class'].unique())  # Check unique labels to ensure there are no unexpected values

0
[1 0 2]
0
[1 0 2]
0
[0 1 2]
0
[1 0 2]
0
[1 0 2]
0
[0 1 2]
0
[2 0 1]
0
[2 1 0]
0
[1 2 0]


## Balanced Dataset class

In [13]:
# Define the dataset class
class BalancedDataset(Dataset):
    def __init__(self, data_frame, root_dir, transform=None):
        self.data = data_frame
        self.root_dir = root_dir
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        # Validate that the images exist in the directory
        self.valid_indices = self.get_valid_indices()

    def get_valid_indices(self):
        valid_indices = []
        for idx in tqdm(range(len(self.data)), desc="Validating images"):
            img_name = self.data.iloc[idx, 0].strip()
            img_name = img_name.split('/')[-1]  # Extract file name
            
            if img_name.startswith("image-"):
                try:
                    # Ensure we only include images in the valid range
                    image_number = int(img_name.split('-')[1].split('.')[0])
                    if 4 <= image_number <= 26637:
                        full_img_path = os.path.join(self.root_dir, img_name)
                        if os.path.exists(full_img_path):
                            valid_indices.append(idx)
                        else:
                            print(f"Image does not exist: {full_img_path}")
                except ValueError:
                    print(f"Invalid filename format for {img_name}. Skipping...")
        
        print(f"Total valid indices found: {len(valid_indices)}")  # Debugging output
        return valid_indices

    def __len__(self):
            return len(self.valid_indices)
    
    def __getitem__(self, idx):
        # Wrap around the index if it exceeds the length of valid indices
        idx = idx % len(self.valid_indices)
        
        # Get the actual index from valid indices
        actual_idx = self.valid_indices[idx]
        img_name = self.data.iloc[actual_idx, 0].strip()
        full_img_path = os.path.join(self.root_dir, img_name)
        label = self.targets[actual_idx]  # Get the label from the targets tensor
    
        try:
            # Attempt to open the image and convert to RGB
            image = Image.open(full_img_path).convert('RGB')
    
            # Apply transformations if defined
            if self._transform_groups.get('train'):
                image = self._transform_groups['train'](image)
    
            return image, label, task_label  # Return image, label, and task label
        except (OSError, IOError, ValueError) as e:
            # Print error message for debugging
            print(f"Error loading image {full_img_path}: {e}")
    
            # Handle gracefully by skipping the corrupted/missing file
            return self.__getitem__((idx + 1) % len(self.valid_indices))  # Try next valid index


## Balanced Batch Sampler class

In [14]:
class BalancedBatchSampler(Sampler):
    def __init__(self, data_frame, batch_size=15, samples_per_class=5):
        """
        data_frame: Pandas DataFrame with image paths and their respective class labels.
        batch_size: Total batch size.
        samples_per_class: Number of samples to draw from each class per batch.
        """
        self.data_frame = data_frame
        self.batch_size = batch_size
        self.samples_per_class = samples_per_class
        self.num_classes = len(data_frame['hotend_class'].unique())
        
        if self.batch_size % self.num_classes != 0:
            raise ValueError("Batch size must be divisible by the number of classes.")

        self.class_indices = {
            class_id: self.data_frame[self.data_frame['hotend_class'] == class_id].index.tolist()
            for class_id in self.data_frame['hotend_class'].unique()
        }
        
        # Shuffle class indices initially
        for class_id in self.class_indices:
            random.shuffle(self.class_indices[class_id])

        self.num_samples_per_epoch = sum(len(indices) for indices in self.class_indices.values())
        self.indices_used = {class_id: [] for class_id in self.class_indices}

    def __iter__(self):
        batches = []

        # Replenish indices for each class
        for class_id in self.class_indices:
            if not self.class_indices[class_id]:
                raise ValueError(f"Class {class_id} has no samples. Cannot form balanced batches.")

            # Shuffle and use all indices from this class
            self.indices_used[class_id] = self.class_indices[class_id].copy()
            random.shuffle(self.indices_used[class_id])

        # Generate balanced batches
        while len(batches) * self.batch_size < self.num_samples_per_epoch:
            batch = []
            for class_id in self.indices_used:
                if len(self.indices_used[class_id]) < self.samples_per_class:
                    # If a class runs out of samples, reshuffle and replenish
                    self.indices_used[class_id] = self.class_indices[class_id].copy()
                    random.shuffle(self.indices_used[class_id])

                # Take `samples_per_class` indices from the current class
                batch.extend(self.indices_used[class_id][:self.samples_per_class])
                self.indices_used[class_id] = self.indices_used[class_id][self.samples_per_class:]

            # Shuffle the batch and append
            random.shuffle(batch)
            batches.append(batch)

        return iter(batches)

    def __len__(self):
        # Total number of batches per epoch
        return self.num_samples_per_epoch // self.batch_size

In [15]:
# Define a dictionary to store datasets and DataLoaders
datasets = {}
dataloaders = {}

# Iterate over all experience datasets (1, 2, 3)
for exp_id in [1, 2, 3]:
    # Ensure the dataset exists
    if f"train_{exp_id}" in globals():
        train_data = globals()[f"train_{exp_id}"]
        val_data = globals()[f"valid_{exp_id}"]
        test_data = globals()[f"test_{exp_id}"]

        # Create dataset instances
        datasets[f"train_{exp_id}"] = BalancedDataset(data_frame=train_data, root_dir=root_dir)
        datasets[f"valid_{exp_id}"] = BalancedDataset(data_frame=val_data, root_dir=root_dir)
        datasets[f"test_{exp_id}"] = BalancedDataset(data_frame=test_data, root_dir=root_dir)

        # Create batch samplers for balanced training
        train_sampler = BalancedBatchSampler(data_frame=train_data, batch_size=15, samples_per_class=5)
        val_sampler = BalancedBatchSampler(data_frame=val_data, batch_size=15, samples_per_class=5)
        test_sampler = BalancedBatchSampler(data_frame=test_data, batch_size=15, samples_per_class=5)

        # Create DataLoaders
        dataloaders[f"train_{exp_id}"] = DataLoader(datasets[f"train_{exp_id}"], batch_sampler=train_sampler, shuffle=False)
        dataloaders[f"valid_{exp_id}"] = DataLoader(datasets[f"valid_{exp_id}"], batch_sampler=val_sampler, shuffle=False)
        dataloaders[f"test_{exp_id}"] = DataLoader(datasets[f"test_{exp_id}"], batch_sampler=test_sampler)

        # Print dataset lengths
        print(f"   Experience {exp_id} datasets and DataLoaders created successfully!")
        print(f"   Train dataset length: {len(datasets[f'train_{exp_id}'])}")
        print(f"   Validation dataset length: {len(datasets[f'valid_{exp_id}'])}")
        print(f"   Test dataset length: {len(datasets[f'test_{exp_id}'])}")


Validating images: 100%|██████████| 69/69 [00:00<00:00, 975.79it/s]


Total valid indices found: 69


Validating images: 100%|██████████| 12/12 [00:00<00:00, 928.58it/s]


Total valid indices found: 12


Validating images: 100%|██████████| 18/18 [00:00<00:00, 1182.90it/s]


Total valid indices found: 18
   Experience 1 datasets and DataLoaders created successfully!
   Train dataset length: 69
   Validation dataset length: 12
   Test dataset length: 18


Validating images: 100%|██████████| 96/96 [00:00<00:00, 984.85it/s]


Total valid indices found: 96


Validating images: 100%|██████████| 18/18 [00:00<00:00, 903.65it/s]


Total valid indices found: 18


Validating images: 100%|██████████| 24/24 [00:00<00:00, 933.33it/s]


Total valid indices found: 24
   Experience 2 datasets and DataLoaders created successfully!
   Train dataset length: 96
   Validation dataset length: 18
   Test dataset length: 24


Validating images: 100%|██████████| 522/522 [00:00<00:00, 895.07it/s]


Total valid indices found: 522


Validating images: 100%|██████████| 111/111 [00:00<00:00, 1153.15it/s]


Total valid indices found: 111


Validating images: 100%|██████████| 114/114 [00:00<00:00, 1215.80it/s]

Total valid indices found: 114
   Experience 3 datasets and DataLoaders created successfully!
   Train dataset length: 522
   Validation dataset length: 111
   Test dataset length: 114





## Setting up a new folder for each experiment

In [16]:
# Set base directory
base_dir = "experiments"
os.makedirs(base_dir, exist_ok=True)

# Function to get the next experiment folder
def get_experiment_folder(exp_num):
    return os.path.join(base_dir, f"Experiment_{exp_num:02d}")  # Keeps two-digit format (01, 02, ..., 10)

# Set initial experiment number
experiment_num = 1
experiment_folder = get_experiment_folder(experiment_num)

# Create the main experiment directory if it doesn't exist
os.makedirs(experiment_folder, exist_ok=True)

# Set model path inside experiment folder
model_path = os.path.join(experiment_folder, "best_model.pth")

# Create subdirectories for training, validation, and test confusion matrices
train_folder = os.path.join(experiment_folder, "training_confusion_matrices")
val_folder = os.path.join(experiment_folder, "validation_confusion_matrices")
test_folder = os.path.join(experiment_folder, "test_confusion_matrices")

# Ensure that the subdirectories exist
os.makedirs(train_folder, exist_ok=True)
os.makedirs(val_folder, exist_ok=True)
os.makedirs(test_folder, exist_ok=True)

# Print the directory where results will be saved
print(f"Saving results to: {experiment_folder}")

Saving results to: experiments\Experiment_01


## Display a Random Image from the Dataset with Its Label

In [17]:
import random
import os
import matplotlib.pyplot as plt

def save_random_image_from_experiment(exp_id, dataset_type):
    """
    Selects a random image from the specified dataset (train, valid, or test) for a given experience ID,
    loads it, displays it, and saves it to the corresponding experiment folder.

    Args:
        exp_id (int): The experience group number (1, 2, or 3).
        dataset_type (str): The dataset type - 'train', 'valid', or 'test'.
    """
    # Ensure the dataset exists
    dataset_key = f"{dataset_type}_{exp_id}"  # Example: 'train_1', 'valid_2', 'test_3'
    if dataset_key not in datasets:
        print(f"Dataset {dataset_key} not found!")
        return

    dataset = datasets[dataset_key]  # Retrieve the dataset
    data_frame = dataset.data  # Get the underlying DataFrame

    # Ensure the dataset is not empty
    if data_frame.empty:
        print(f"Dataset {dataset_key} is empty!")
        return

    # Select a random index
    random_index = random.choice(data_frame.index)
    img_path = os.path.join(root_dir, data_frame.iloc[random_index, 0].strip())
    label = data_frame.loc[random_index, 'hotend_class']

    # Load and display the image
    img = plt.imread(img_path)
    plt.imshow(img)
    plt.title(f"Label: {label}")

    # Define the path to save the image inside the current experiment folder
    experiment_folder = os.path.join("experiments", f"experiment_{exp_id}")
    os.makedirs(experiment_folder, exist_ok=True)  # Ensure folder exists

    output_path = os.path.join(experiment_folder, f"random_{dataset_type}.png")

    # Save the figure
    plt.savefig(output_path)
    plt.clf()  # Clear the plot to avoid overlaps

    print(f"Image saved to: {output_path}")

# Example Usage:
save_random_image_from_experiment(exp_id=1, dataset_type='train')  # Random training image from Experience 1
save_random_image_from_experiment(exp_id=2, dataset_type='valid')  # Random validation image from Experience 2
save_random_image_from_experiment(exp_id=3, dataset_type='test')   # Random test image from Experience 3

Image saved to: experiments\experiment_1\random_train.png
Image saved to: experiments\experiment_2\random_valid.png
Image saved to: experiments\experiment_3\random_test.png


<Figure size 640x480 with 0 Axes>

In [18]:
# Iterate over all experience groups
for exp_id in [1, 2, 3]:  
    dataset_key = f"train_{exp_id}"  # e.g., 'train_1', 'train_2', 'train_3'
    
    # Ensure the dataset exists
    if dataset_key in datasets:
        data_frame = datasets[dataset_key].data  # Access the DataFrame from BalancedDataset

        # Ensure the dataset is not empty
        if not data_frame.empty:
            # First image
            first_index = data_frame.index[0]
            first_image = data_frame.loc[first_index, 'img_path']
            first_label = data_frame.loc[first_index, 'hotend_class']
            print(f"Experience {exp_id} - First Image Path: {first_image}, First Label: {first_label}")

            # Last image
            last_index = data_frame.index[-1]
            last_image = data_frame.loc[last_index, 'img_path']
            last_label = data_frame.loc[last_index, 'hotend_class']
            print(f"Experience {exp_id} - Last Image Path: {last_image}, Last Label: {last_label}\n")
        else:
            print(f"Experience {exp_id} - Training dataset is empty!\n")
    else:
        print(f"Experience {exp_id} - Training dataset not found!\n")

Experience 1 - First Image Path: image-12253.jpg, First Label: 1
Experience 1 - Last Image Path: image-5677.jpg, Last Label: 2

Experience 2 - First Image Path: image-559.jpg, First Label: 1
Experience 2 - Last Image Path: image-522.jpg, Last Label: 2

Experience 3 - First Image Path: image-23305.jpg, First Label: 0
Experience 3 - Last Image Path: image-15003.jpg, Last Label: 0



## Creating an EWC Class which inherits from AvalancheDataset and contains all the expected functions

In [19]:
import os
from tqdm import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from avalanche.benchmarks.utils import AvalancheDataset, DataAttribute
from avalanche.benchmarks.utils.transforms import TupleTransform

class EWCCompatibleBalancedDataset(AvalancheDataset):
    def __init__(self, data_frame, root_dir=None, transform=None, task_label=0, indices=None):
        """
        Custom dataset compatible with EWC that inherits from AvalancheDataset.
        It loads images from disk, applies transforms, and provides sample-wise
        attributes for targets and task labels.
        
        Args:
            data_frame (pd.DataFrame or list): If a DataFrame, it must contain columns
                'image_path' and 'hotend_class'. If a list, it is assumed to be a pre-built
                list of datasets (used in subset calls).
            root_dir (str, optional): Directory where images are stored. Must be provided if data_frame is a DataFrame.
            transform (callable, optional): Transformations to apply.
            task_label (int, optional): Task label for continual learning.
            indices (Sequence[int], optional): Optional indices for subsetting.
        """
        # If data_frame is a list, assume this is a call from subset() and forward the call.
        if isinstance(data_frame, list):
            super().__init__(data_frame, indices=indices)
            return

        # Otherwise, data_frame is a DataFrame. Ensure root_dir is provided.
        if root_dir is None:
            raise ValueError("root_dir must be provided when data_frame is a DataFrame")
        
        # Reset DataFrame index for consistency.
        self.data = data_frame.reset_index(drop=True)
        self.root_dir = root_dir
        self.task_label = task_label

        # Define a default transform if none provided.
        default_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Wrap the transform in TupleTransform so that it applies only to the image element.
        self._transform_groups = {
            "train": TupleTransform([transform or default_transform]),
            "eval": TupleTransform([transform or default_transform])
        }
        
        # Ensure required columns exist.
        if 'hotend_class' not in self.data.columns:
            raise ValueError("DataFrame must contain 'hotend_class' for labels.")
        if 'image_path' not in self.data.columns:
            raise ValueError("DataFrame must contain 'image_path' for image paths.")
        
        # Validate image paths and obtain valid indices.
        valid_indices = self.get_valid_indices()
        if len(valid_indices) == 0:
            raise ValueError("No valid image paths found.")
        
        # Compute targets and task labels for valid samples.
        targets_data = torch.tensor(self.data.loc[valid_indices, 'hotend_class'].values)
        targets_task_labels_data = torch.full_like(targets_data, self.task_label)
        
        # Prepare sample entries (one per valid image).
        samples = []
        for idx in valid_indices:
            img_name = self.data.loc[idx, 'image_path'].strip()
            full_img_path = os.path.join(self.root_dir, img_name)
            label = int(self.data.loc[idx, 'hotend_class'])
            samples.append({
                "img_path": full_img_path,
                "label": label,
                "task_label": self.task_label
            })
        
        # Define an internal basic dataset that loads images.
        class BasicDataset(Dataset):
            def __init__(self, samples):
                self.samples = samples

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                sample = self.samples[idx]
                img_path = sample["img_path"]
                try:
                    # Load the image (ensure it is a PIL image).
                    image = Image.open(img_path).convert('RGB')
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")
                    # If an error occurs, try the next sample.
                    return self.__getitem__((idx + 1) % len(self.samples))
                return image, sample["label"], sample["task_label"]
        
        basic_dataset = BasicDataset(samples)
        
        # Create data attributes.
        data_attributes = [
            DataAttribute(targets_data, name="targets", use_in_getitem=True),
            DataAttribute(targets_task_labels_data, name="targets_task_labels", use_in_getitem=True)
        ]
        
        # IMPORTANT: Pass the basic_dataset inside a list so that AvalancheDataset
        # correctly sets up its internal flat data, and forward the indices parameter.
        super().__init__(
            [basic_dataset],
            data_attributes=data_attributes,
            transform_groups=self._transform_groups,
            indices=indices
        )
    
    def get_valid_indices(self):
        """Return indices for which the image file exists."""
        valid_indices = []
        for idx in tqdm(range(len(self.data)), desc="Validating images"):
            img_name = self.data.loc[idx, 'image_path'].strip()
            full_img_path = os.path.join(self.root_dir, img_name)
            if os.path.exists(full_img_path):
                valid_indices.append(idx)
            else:
                print(f"Image does not exist: {full_img_path}")
        print(f"Total valid images: {len(valid_indices)}")
        return valid_indices

## Creating training, validation and testing datasets to implement EWC

In [20]:
from torchvision import transforms

# Define the transformation (e.g., normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Experience 1
train_dataset_exp1 = EWCCompatibleBalancedDataset(
    data_frame=train_1.rename(columns={'img_path': 'image_path', 'class': 'hotend_class'}),
    root_dir=root_dir,
    transform=transform,
    task_label=0
)

val_dataset_exp1 = EWCCompatibleBalancedDataset(
    data_frame=valid_1.rename(columns={'img_path': 'image_path', 'class': 'hotend_class'}),
    root_dir=root_dir,
    transform=transform,
    task_label=0
)

test_dataset_exp1 = EWCCompatibleBalancedDataset(
    data_frame=test_1.rename(columns={'img_path': 'image_path', 'class': 'hotend_class'}),
    root_dir=root_dir,
    transform=transform,
    task_label=0
)

# Experience 2
train_dataset_exp2 = EWCCompatibleBalancedDataset(
    data_frame=train_2.rename(columns={'img_path': 'image_path', 'class': 'hotend_class'}),
    root_dir=root_dir,
    transform=transform,
    task_label=0
)

val_dataset_exp2 = EWCCompatibleBalancedDataset(
    data_frame=valid_2.rename(columns={'img_path': 'image_path', 'class': 'hotend_class'}),
    root_dir=root_dir,
    transform=transform,
    task_label=0
)

test_dataset_exp2 = EWCCompatibleBalancedDataset(
    data_frame=test_2.rename(columns={'img_path': 'image_path', 'class': 'hotend_class'}),
    root_dir=root_dir,
    transform=transform,
    task_label=0
)

# Experience 3
train_dataset_exp3 = EWCCompatibleBalancedDataset(
    data_frame=train_3.rename(columns={'img_path': 'image_path', 'class': 'hotend_class'}),
    root_dir=root_dir,
    transform=transform,
    task_label=0
)

val_dataset_exp3 = EWCCompatibleBalancedDataset(
    data_frame=valid_3.rename(columns={'img_path': 'image_path', 'class': 'hotend_class'}),
    root_dir=root_dir,
    transform=transform,
    task_label=0
)

test_dataset_exp3 = EWCCompatibleBalancedDataset(
    data_frame=test_3.rename(columns={'img_path': 'image_path', 'class': 'hotend_class'}),
    root_dir=root_dir,
    transform=transform,
    task_label=0
)

Validating images: 100%|██████████| 69/69 [00:00<00:00, 3866.54it/s]


Total valid images: 69


Validating images: 100%|██████████| 12/12 [00:00<00:00, 2931.03it/s]


Total valid images: 12


Validating images: 100%|██████████| 18/18 [00:00<00:00, 4445.74it/s]


Total valid images: 18


Validating images: 100%|██████████| 96/96 [00:00<00:00, 6586.73it/s]


Total valid images: 96


Validating images: 100%|██████████| 18/18 [00:00<00:00, 3032.39it/s]


Total valid images: 18


Validating images: 100%|██████████| 24/24 [00:00<00:00, 2438.02it/s]


Total valid images: 24


Validating images: 100%|██████████| 522/522 [00:00<00:00, 5030.48it/s]


Total valid images: 522


Validating images: 100%|██████████| 111/111 [00:00<00:00, 5424.62it/s]


Total valid images: 111


Validating images: 100%|██████████| 114/114 [00:00<00:00, 4987.80it/s]

Total valid images: 114





## Creating Dataloaders for more efficient data processing

In [21]:
from torch.utils.data.dataloader import DataLoader

# Experience 1
train_sampler_exp1 = BalancedBatchSampler(data_frame=train_1.rename(columns={'img_path': 'image_path'}), 
                                          batch_size=15, samples_per_class=5)
val_sampler_exp1 = BalancedBatchSampler(data_frame=valid_1.rename(columns={'img_path': 'image_path'}), 
                                        batch_size=15, samples_per_class=5)
test_sampler_exp1 = BalancedBatchSampler(data_frame=test_1.rename(columns={'img_path': 'image_path'}), 
                                         batch_size=15, samples_per_class=5)

train_loader_exp1 = DataLoader(train_dataset_exp1, batch_sampler=train_sampler_exp1, shuffle=False)
val_loader_exp1 = DataLoader(val_dataset_exp1, batch_sampler=val_sampler_exp1, shuffle=False)
test_loader_exp1 = DataLoader(test_dataset_exp1, batch_sampler=test_sampler_exp1, shuffle=False)

# Experience 2
train_sampler_exp2 = BalancedBatchSampler(data_frame=train_2.rename(columns={'img_path': 'image_path'}), 
                                          batch_size=15, samples_per_class=5)
val_sampler_exp2 = BalancedBatchSampler(data_frame=valid_2.rename(columns={'img_path': 'image_path'}), 
                                        batch_size=15, samples_per_class=5)
test_sampler_exp2 = BalancedBatchSampler(data_frame=test_2.rename(columns={'img_path': 'image_path'}), 
                                         batch_size=15, samples_per_class=5)

train_loader_exp2 = DataLoader(train_dataset_exp2, batch_sampler=train_sampler_exp2, shuffle=False)
val_loader_exp2 = DataLoader(val_dataset_exp2, batch_sampler=val_sampler_exp2, shuffle=False)
test_loader_exp2 = DataLoader(test_dataset_exp2, batch_sampler=test_sampler_exp2, shuffle=False)

# Experience 3
train_sampler_exp3 = BalancedBatchSampler(data_frame=train_3.rename(columns={'img_path': 'image_path'}), 
                                          batch_size=15, samples_per_class=5)
val_sampler_exp3 = BalancedBatchSampler(data_frame=valid_3.rename(columns={'img_path': 'image_path'}), 
                                        batch_size=15, samples_per_class=5)
test_sampler_exp3 = BalancedBatchSampler(data_frame=test_3.rename(columns={'img_path': 'image_path'}), 
                                         batch_size=15, samples_per_class=5)

train_loader_exp3 = DataLoader(train_dataset_exp3, batch_sampler=train_sampler_exp3, shuffle=False)
val_loader_exp3 = DataLoader(val_dataset_exp3, batch_sampler=val_sampler_exp3, shuffle=False)
test_loader_exp3 = DataLoader(test_dataset_exp3, batch_sampler=test_sampler_exp3, shuffle=False)

# Print to check if the DataLoaders are created successfully
print("DataLoaders for all experiences created successfully!")

DataLoaders for all experiences created successfully!


## Checking if the datasets are AvalancheDatasets and whether they contain the correct Attributes

In [22]:
# Function to check if a dataset is an instance of AvalancheDataset
def check_avalanche_dataset(dataset):
    # Check if dataset is an instance of AvalancheDataset
    if isinstance(dataset, AvalancheDataset):
        print(f"Dataset is an instance of AvalancheDataset.")
    else:
        print(f"Dataset is NOT an instance of AvalancheDataset.")
        
    # Inspect the internal structure to understand where the data attributes are stored
    print(f"Dataset internal structure: {dir(dataset)}")

    # Check if dataset has the core attributes: 'data', 'targets', 'task_label'
    if hasattr(dataset, 'data') and hasattr(dataset, 'targets') and hasattr(dataset, 'task_label'):
        print("Dataset contains 'data', 'targets', and 'task_label' attributes.")
    else:
        print("Dataset is missing one or more of the required attributes: 'data', 'targets', 'task_label'.")
        
    # Verify the length and sample data
    try:
        # Let's print the first sample to see how data is structured
        sample = dataset[0]
        print(f"First sample structure: {sample}")
    except Exception as e:
        print(f"Error fetching first sample: {e}")
    
    # If there's data, check for its expected shape and content
    if hasattr(dataset, 'data'):
        print(f"Dataset contains data with shape: {len(dataset.data)} samples.")
    
    if hasattr(dataset, 'targets'):
        print(f"Dataset contains targets with length: {len(dataset.targets)}.")

# Experience 1
check_avalanche_dataset(train_dataset_exp1)
check_avalanche_dataset(val_dataset_exp1)
check_avalanche_dataset(test_dataset_exp1)

# Experience 2
check_avalanche_dataset(train_dataset_exp2)
check_avalanche_dataset(val_dataset_exp2)
check_avalanche_dataset(test_dataset_exp2)

# Experience 3
check_avalanche_dataset(train_dataset_exp3)
check_avalanche_dataset(val_dataset_exp3)
check_avalanche_dataset(test_dataset_exp3)

Dataset is an instance of AvalancheDataset.
Dataset internal structure: ['__abstractmethods__', '__add__', '__annotations__', '__class__', '__class_getitem__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__orig_bases__', '__parameters__', '__protocol_attrs__', '__radd__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_data_attributes', '_datasets', '_flat_data', '_init_collate_fn', '_is_protocol', '_is_runtime_protocol', '_shallow_clone_dataset', '_transform_groups', '_tree_depth', 'collate_fn', 'concat', 'data', 'eval', 'freeze_transforms', 'get_valid_indices', 'remove_current_transform_group', 'replace_current_transform_group', 'root_dir', 'subset', 'targets', 'targets_task_labe

In [23]:
# Function to print all attributes of the dataset
def print_all_attributes(dataset):
    print(f"Attributes of the dataset:")
    for attr in dir(dataset):
        # Skip private attributes (those starting with '_')
        if not attr.startswith('_'):
            print(f"  {attr}")

# Check all datasets in the "train" and "test" streams
dataset_streams = {
    "train": [train_dataset_exp1, train_dataset_exp2, train_dataset_exp3],
    "test": [test_dataset_exp1, test_dataset_exp2, test_dataset_exp3]
}

# Iterate over the streams and check each dataset
for stream_name, datasets in dataset_streams.items():
    print(f"\nChecking {stream_name} datasets:")
    for i, dataset in enumerate(datasets):
        print(f"\n  Checking dataset {stream_name}_{i + 1}:")
        print_all_attributes(dataset)


Checking train datasets:

  Checking dataset train_1:
Attributes of the dataset:
  collate_fn
  concat
  data
  eval
  freeze_transforms
  get_valid_indices
  remove_current_transform_group
  replace_current_transform_group
  root_dir
  subset
  targets
  targets_task_labels
  task_label
  train
  transform
  update_data_attribute
  with_transforms

  Checking dataset train_2:
Attributes of the dataset:
  collate_fn
  concat
  data
  eval
  freeze_transforms
  get_valid_indices
  remove_current_transform_group
  replace_current_transform_group
  root_dir
  subset
  targets
  targets_task_labels
  task_label
  train
  transform
  update_data_attribute
  with_transforms

  Checking dataset train_3:
Attributes of the dataset:
  collate_fn
  concat
  data
  eval
  freeze_transforms
  get_valid_indices
  remove_current_transform_group
  replace_current_transform_group
  root_dir
  subset
  targets
  targets_task_labels
  task_label
  train
  transform
  update_data_attribute
  with_transfo

In [24]:
# Check if they are instances of AvalancheDataset
print("Is train_dataset_exp1 an AvalancheDataset? ", isinstance(train_dataset_exp1, AvalancheDataset))
print("Is train_dataset_exp2 an AvalancheDataset? ", isinstance(train_dataset_exp2, AvalancheDataset))
print("Is train_dataset_exp3 an AvalancheDataset? ", isinstance(train_dataset_exp3, AvalancheDataset))

Is train_dataset_exp1 an AvalancheDataset?  True
Is train_dataset_exp2 an AvalancheDataset?  True
Is train_dataset_exp3 an AvalancheDataset?  True


In [25]:
def check_dataset(avalanche_dataset):
    try:
        # Access the first sample using __getitem__
        first_sample = avalanche_dataset[0]  # This might be a tuple (image, label, task_label)
        
        # Print the entire first sample to check the structure
        print(f"First sample: {first_sample}")
        
        # Print the type of each element in the sample (image, target, task_label)
        if isinstance(first_sample, tuple):
            print(f"First element type (image): {type(first_sample[0])}")
            print(f"Second element type (target): {type(first_sample[1])}")
            if len(first_sample) >= 3:
                print(f"Third element type (task_label): {type(first_sample[2])}")
            else:
                print("No task label in the sample.")
        else:
            print("The first sample is not a tuple as expected.")

        # Check if the first element (image) is a string (file path) or a tensor
        if isinstance(first_sample[0], str):
            print("The first element is a string, which might be a file path.")
        elif hasattr(first_sample[0], 'shape'):
            print(f"The first element is an image tensor with shape: {first_sample[0].shape}")
        else:
            print("The first element is neither a string nor a tensor.")

        # Check if the dataset has 3 elements (image, target, task_label)
        if len(first_sample) >= 3:
            print(f"Target (label): {first_sample[1]}")
            print(f"Task label: {first_sample[2]}")
        else:
            print("Warning: The dataset does not contain all expected elements (image, target, task_label).")

    except AttributeError as e:
        print(f"Error accessing dataset attributes: {e}")
    except IndexError as e:
        print(f"Error accessing dataset elements: {e}")

# Running the function on the first dataset
check_dataset(train_dataset_exp1)

First sample: [tensor([[[-1.8268, -1.8268, -1.8268,  ..., -0.4054, -0.3883, -0.3541],
         [-1.8268, -1.8268, -1.8268,  ..., -0.3883, -0.3541, -0.3198],
         [-1.8268, -1.8268, -1.8610,  ..., -0.3541, -0.3198, -0.2684],
         ...,
         [ 1.8722,  1.8722,  1.8208,  ...,  1.1015,  1.1015,  1.1187],
         [ 1.8208,  1.8037,  1.7694,  ...,  1.0844,  1.0844,  1.1187],
         [ 1.8208,  1.7865,  1.7523,  ...,  1.0502,  1.0673,  1.1187]],

        [[-1.7731, -1.7731, -1.7906,  ..., -0.6702, -0.6352, -0.5651],
         [-1.7731, -1.7731, -1.7906,  ..., -0.6527, -0.6001, -0.5301],
         [-1.7731, -1.7731, -1.8081,  ..., -0.6352, -0.5826, -0.5301],
         ...,
         [ 2.4286,  2.4286,  2.4111,  ...,  0.5203,  0.5028,  0.5028],
         [ 2.4286,  2.4286,  2.3936,  ...,  0.5203,  0.5203,  0.5203],
         [ 2.4286,  2.4111,  2.3761,  ...,  0.5028,  0.5203,  0.5378]],

        [[-1.7522, -1.7522, -1.7347,  ..., -1.8044, -1.8044, -1.8044],
         [-1.7522, -1.7522, -1

In [26]:
sample = train_dataset_exp1[0]
image, label, task_label = sample[:3]
print(f"Image shape: {image.shape}, Label: {label}, Task Label: {task_label}")

Image shape: torch.Size([3, 224, 224]), Label: 1, Task Label: 0


## Checking class distribution in each dataset

In [27]:
import torch
from collections import Counter

def count_classes(dataset):
    # Convert the FlatData into a list of values via list comprehension.
    values = [x for x in dataset.targets]
    # Convert the list of values to a tensor.
    t = torch.tensor(values)
    # Now, convert the tensor to a NumPy array and count the classes.
    return Counter(t.numpy())

print("Class distribution in Train Dataset 1:", count_classes(train_dataset_exp1))
print("Class distribution in Train Dataset 2:", count_classes(train_dataset_exp2))
print("Class distribution in Train Dataset 3:", count_classes(train_dataset_exp3))
print("Class distribution in Validation Dataset 1:", count_classes(val_dataset_exp1))
print("Class distribution in Validation Dataset 2:", count_classes(val_dataset_exp2))
print("Class distribution in Validation Dataset 3:", count_classes(val_dataset_exp3))
print("Class distribution in Test Dataset 1:", count_classes(test_dataset_exp1))
print("Class distribution in Test Dataset 2:", count_classes(test_dataset_exp2))
print("Class distribution in Test Dataset 3:", count_classes(test_dataset_exp3))

Class distribution in Train Dataset 1: Counter({1: 23, 0: 23, 2: 23})
Class distribution in Train Dataset 2: Counter({1: 32, 0: 32, 2: 32})
Class distribution in Train Dataset 3: Counter({0: 174, 1: 174, 2: 174})
Class distribution in Validation Dataset 1: Counter({1: 4, 0: 4, 2: 4})
Class distribution in Validation Dataset 2: Counter({1: 6, 0: 6, 2: 6})
Class distribution in Validation Dataset 3: Counter({0: 37, 1: 37, 2: 37})
Class distribution in Test Dataset 1: Counter({2: 6, 0: 6, 1: 6})
Class distribution in Test Dataset 2: Counter({2: 8, 1: 8, 0: 8})
Class distribution in Test Dataset 3: Counter({1: 38, 2: 38, 0: 38})


## Checking class distribution in each experience

In [28]:
from avalanche.benchmarks.utils import DataAttribute
from avalanche.benchmarks import benchmark_from_datasets
# Create the benchmark from your datasets
dataset_streams = {
    "train": [train_dataset_exp1, train_dataset_exp2, train_dataset_exp3],
    "test": [test_dataset_exp1, test_dataset_exp2, test_dataset_exp3]
}
# You might want to ensure the benchmark is created here
benchmark = benchmark_from_datasets(**dataset_streams)

for experience in benchmark.train_stream:
    print(f"Start of experience: {experience.current_experience}")
    
    # Try to get the targets via the dynamic property.
    try:
        targets_data = experience.dataset.targets.data
    except AttributeError:
        # Fallback: access the internal _data_attributes dictionary.
        targets_data = experience.dataset._data_attributes["targets"].data

    # If targets_data doesn't have 'tolist', assume it's already iterable.
    if hasattr(targets_data, "tolist"):
        unique_classes = set(targets_data.tolist())
    else:
        unique_classes = set(targets_data)
        
    print(f"Classes in this experience: {unique_classes}")

Start of experience: 0
Classes in this experience: {0, 1, 2}
Start of experience: 1
Classes in this experience: {0, 1, 2}
Start of experience: 2
Classes in this experience: {0, 1, 2}


## Overfitting Test

In [29]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from avalanche.training import EWC
from avalanche.benchmarks import benchmark_from_datasets
from models.cnn_models import SimpleCNN

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model and move it to the device
model = SimpleCNN(num_classes=3).to(device)

# Define your loss function and optimizer with a high learning rate for overfitting test
criterion = CrossEntropyLoss()
#optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0)  # High LR, no weight decay
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

# Disable dropout in the model to aid overfitting
def disable_dropout(model):
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.p = 0.0

disable_dropout(model)

# Create subsets of your AvalancheDataset (assuming train_dataset_exp1 exists)
# Here, we take the first 30 samples for training and use the same set for evaluation.
small_train_dataset = train_dataset_exp1.subset(list(range(30)))
small_test_dataset = train_dataset_exp1.subset(list(range(30)))

# Create a benchmark from these subsets
small_benchmark = benchmark_from_datasets(train=[small_train_dataset],
                                          test=[small_test_dataset])

# Create an EWC strategy instance.
# We set ewc_lambda to 0.0 to disable the EWC penalty and train_epochs to 50 so that
# the strategy runs for 50 epochs internally on the small subset.
cl_strategy = EWC(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_mb_size=5,
    train_epochs=100,
    eval_mb_size=5,
    ewc_lambda=0.0,
    device=device
)

# Run training on the small benchmark
for experience in small_benchmark.train_stream:
    print(f"=== Overfitting Test on Small Subset, Experience {experience.current_experience} ===")
    cl_strategy.train(experience)
    
    # Evaluate on the training subset to check if the model overfits
    train_eval_res = cl_strategy.eval(small_benchmark.test_stream)
    print("Evaluation on training subset:", train_eval_res)
    
    # Optionally, print predictions on a few samples to verify correctness
    for sample in small_benchmark.test_stream[0].dataset:
        image, label, *rest = sample
        image = image.to(device).unsqueeze(0)  # Add batch dimension
        output = model(image)
        predicted_class = output.argmax(dim=1).item()
        print(f"True label: {label}, Predicted: {predicted_class}")
        break  # Print only one sample per experience

=== Overfitting Test on Small Subset, Experience 0 ===
-- >> Start of training phase << --
0it [00:00, ?it/s]

KeyboardInterrupt: 

## Implementing EWC using Avalanche - the end-to-end continual learning library

In [None]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from avalanche.training.plugins import SupervisedPlugin
import os
from models.cnn_models import SimpleCNN

# Ensure the folder "loss_plots" exists.
if not os.path.exists("loss_plots"):
    os.makedirs("loss_plots")

# Helper function to compute the average loss on a dataset.
def compute_loss(model, dataset, criterion, device, batch_size=15):
    model.eval()
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    losses = []
    with torch.no_grad():
        for batch in data_loader:
            if isinstance(batch, (list, tuple)):
                inputs, targets = batch[0], batch[1]
            else:
                raise ValueError("Expected the data loader to return a tuple or list.")
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            losses.append(loss.item())
    model.train()  # Set back to train mode
    return sum(losses) / len(losses) if losses else 0.0

# Custom plugin to record training and validation loss after each epoch.
class LossHistoryPlugin(SupervisedPlugin):
    def __init__(self, validation_dataset, criterion, device, batch_size=15):
        super().__init__()
        self.validation_dataset = validation_dataset
        self.criterion = criterion
        self.device = device
        self.batch_size = batch_size
        self.epoch_train_losses = []
        self.epoch_val_losses = []

    def after_training_epoch(self, strategy):
        # Compute training loss on the full training dataset of the current experience.
        train_loss = compute_loss(strategy.model, strategy.experience.dataset, self.criterion, self.device, self.batch_size)
        # Compute validation loss on the provided validation dataset.
        val_loss = compute_loss(strategy.model, self.validation_dataset, self.criterion, self.device, self.batch_size)
        self.epoch_train_losses.append(train_loss)
        self.epoch_val_losses.append(val_loss)
        print(f"Epoch {len(self.epoch_train_losses)}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
# -------------------------------
# Define candidate hyperparameters for grid search.
# -------------------------------
learning_rates = [0.001]
ewc_lambdas = [50, 60, 70, 80, 90, 100]

# -------------------------------
# Setup loggers and device
# -------------------------------
tb_logger = TensorboardLogger()
text_logger = TextLogger(open('log.txt', 'a'))
interactive_logger = InteractiveLogger()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -------------------------------
# Setup benchmark and validation datasets
# -------------------------------
# (Assume train_dataset_exp1, train_dataset_exp2, train_dataset_exp3,
#  test_dataset_exp1, test_dataset_exp2, test_dataset_exp3,
#  val_dataset_exp1, val_dataset_exp2, val_dataset_exp3 are defined.)
dataset_streams = {
    "train": [train_dataset_exp1, train_dataset_exp2, train_dataset_exp3],
    "test": [test_dataset_exp1, test_dataset_exp2, test_dataset_exp3]
}
benchmark = benchmark_from_datasets(**dataset_streams)
validation_datasets = [val_dataset_exp1, val_dataset_exp2, val_dataset_exp3]

# -------------------------------
# Grid search loop over learning rate and ewc_lambda.
# -------------------------------
results_summary = []

for lr, ewc_lambda in itertools.product(learning_rates, ewc_lambdas):
    print(f"\n=== Hyperparameters: lr={lr}, ewc_lv   ambda={ewc_lambda} ===")
    
    # Reinitialize model, criterion, and optimizer.
    model = SimpleCNN(num_classes=3).to(device)
    criterion = CrossEntropyLoss()
    # For example, using SGD:
    #optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0)
    
    # Setup a learning rate scheduler and its plugin.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
    lr_plugin = LRSchedulerPlugin(scheduler)
    
    # Enable the evaluator to log additional metrics.
    evaluator = EvaluationPlugin(
        accuracy_metrics(minibatch=False, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=False, epoch=True, experience=True, stream=True),
        loggers=[interactive_logger, text_logger, tb_logger]
    )
    
    # Instantiate the EWC strategy with eval_every=-1 since we’ll do our own per-epoch validation.
    cl_strategy = EWC(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_mb_size=15,
        train_epochs=50,
        eval_mb_size=15,
        ewc_lambda=ewc_lambda,
        evaluator=evaluator,
        eval_every=-1,  # disable the internal evaluation calls
        device=device,
        plugins=[lr_plugin]  # we will add our loss history plugin per experience below
    )
    
    # Prepare a CSV file for logging summary metrics for this hyperparameter setting.
    csv_file_path = f"ewc_grid_lr{lr}_lambda{ewc_lambda}.csv"
    with open(csv_file_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["Experience", "Epochs", "FinalTrainLoss", "FinalValLoss"])
    
    # Train across experiences.
    for i, experience in enumerate(benchmark.train_stream):
        print(f"\n=== Start of Experience {experience.current_experience} ===")
        
        # Create a validation benchmark for the current experience if needed.
        # Here we use the corresponding validation dataset directly.
        current_val_dataset = validation_datasets[i]
        
        # Instantiate the loss history plugin for this experience.
        loss_history = LossHistoryPlugin(validation_dataset=current_val_dataset,
                                         criterion=criterion, device=device, batch_size=15)
        # Add our plugin to the strategy.
        cl_strategy.plugins.append(loss_history)
        
        print(f"Training Experience {experience.current_experience} for 50 epochs...")
        train_res = cl_strategy.train(experience)
        # After training, remove the loss history plugin so it doesn't accumulate data from previous experiences.
        cl_strategy.plugins.remove(loss_history)
        
        # At this point, loss_history.epoch_train_losses and .epoch_val_losses contain the per-epoch losses.
        final_train_loss = loss_history.epoch_train_losses[-1] if loss_history.epoch_train_losses else None
        final_val_loss = loss_history.epoch_val_losses[-1] if loss_history.epoch_val_losses else None
        
        # Save per-epoch losses to a plot.
        epochs = list(range(1, len(loss_history.epoch_train_losses) + 1))
        plt.figure()
        plt.plot(epochs, loss_history.epoch_train_losses, label='Train Loss')
        plt.plot(epochs, loss_history.epoch_val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f"Exp {experience.current_experience} | Optimiser: Adam | lr={lr}, ewc_lambda={ewc_lambda} | {cl_strategy.train_epochs} epochs")
        plt.legend()
        plot_filename = os.path.join("loss_plots", f"loss_plot_exp{experience.current_experience}_lr{lr}_lambda{ewc_lambda}.png")
        plt.savefig(plot_filename)
        plt.close()
        print(f"Saved loss plot to {plot_filename}")
        
        # Log summary metrics for this experience.
        with open(csv_file_path, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([experience.current_experience, cl_strategy.train_epochs, final_train_loss, final_val_loss])
        
        # Evaluate on the entire test stream.
        print("Testing on the entire test stream...")
        test_res = cl_strategy.eval(benchmark.test_stream)
        print("Test results:", test_res)
    
    results_summary.append({
        "lr": lr,
        "ewc_lambda": ewc_lambda,
        "final_train_loss": final_train_loss,
        "final_val_loss": final_val_loss,
        "test_results": test_res    
    })

print("\n=== Hyperparameter Search Summary ===")
for res in results_summary: 
    print(res)


=== Hyperparameters: lr=0.001, ewc_lambda=50 ===

=== Start of Experience 0 ===
Training Experience 0 for 50 epochs...
-- >> Start of training phase << --
100%|██████████| 5/5 [01:10<00:00, 14.04s/it]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream = 1.1020
	Top1_Acc_Epoch/train_phase/train_stream = 0.3333
Epoch 1: Train Loss = 1.0935, Val Loss = 1.0940
100%|██████████| 5/5 [00:43<00:00,  8.66s/it]
Epoch 1 ended.
	Loss_Epoch/train_phase/train_stream = 1.0886
	Top1_Acc_Epoch/train_phase/train_stream = 0.4058
Epoch 2: Train Loss = 1.0703, Val Loss = 1.0741
100%|██████████| 5/5 [00:37<00:00,  7.58s/it]
Epoch 2 ended.
	Loss_Epoch/train_phase/train_stream = 1.0621
	Top1_Acc_Epoch/train_phase/train_stream = 0.5217
Epoch 3: Train Loss = 1.0161, Val Loss = 1.0301
100%|██████████| 5/5 [00:36<00:00,  7.21s/it]
Epoch 3 ended.
	Loss_Epoch/train_phase/train_stream = 1.0076
	Top1_Acc_Epoch/train_phase/train_stream = 0.5652
Epoch 4: Train Loss = 0.9563, Val Loss = 0.9875
100%|██████████| 5/5 [00