## Import libraries 

In [2]:
import os
import labels
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Sampler, SubsetRandomSampler
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from tqdm import tqdm  # Import tqdm for progress visualization
from models.cnn_models import SimpleCNN
import random
import numpy as np
import matplotlib.pyplot as plt

# Seed every random-number source the notebook uses so a fresh Run-All
# reproduces the same results.
seed = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    # torch.cuda.manual_seed is silently ignored when CUDA is unavailable.
    seed_fn(seed)

# Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Define file paths as constants

In [3]:
# Define file paths as constants.
# NOTE(review): these are absolute, machine-specific Windows paths; consider
# reading them from an environment variable or a project-relative location.
CSV_FILE_PATH = r'C:\Users\Sandhra George\avalanche\data\dataset.csv'  # Path to the CSV file
ROOT_DIR_PATH = r'C:\Users\Sandhra George\avalanche\caxton_dataset\print24'  # Path to the image directory

# Lower-case aliases kept for backward compatibility. They are derived from the
# constants above so the two spellings can never point at different locations.
csv_file = CSV_FILE_PATH
root_dir = ROOT_DIR_PATH

## Load data into DataFrame

In [4]:
# Load the raw annotation table. `data` itself is never modified, so the
# filtering below can be re-run idempotently.
data = pd.read_csv(CSV_FILE_PATH)

# Keep only rows belonging to print24. The negative lookahead stops the pattern
# from also matching longer ids such as "print240".
PRINT_ID_PATTERN = r'print24(?!\d)'
# .copy() makes data_filtered an independent frame, so the column assignment
# below does not write into a slice of `data` (SettingWithCopy).
data_filtered = data[data.iloc[:, 0].str.contains(PRINT_ID_PATTERN, regex=True, na=False)].copy()

# Strip the directory prefix so the first column holds only "image-<n>.jpg";
# the Dataset later joins this name onto ROOT_DIR_PATH.
data_filtered.iloc[:, 0] = data_filtered.iloc[:, 0].str.replace(r'.*?/(image-\d+\.jpg)', r'\1', regex=True)

# Display the first and last rows of the filtered DataFrame
print("First rows of filtered DataFrame:")
print(data_filtered.head())

print("\nLast rows of filtered DataFrame:")
print(data_filtered.tail())

First rows of filtered DataFrame:
          img_path               timestamp  flow_rate  feed_rate  z_offset  \
99496  image-4.jpg  2020-10-07T11:45:35-86        100        100       0.0   
99497  image-5.jpg  2020-10-07T11:45:36-32        100        100       0.0   
99498  image-6.jpg  2020-10-07T11:45:36-79        100        100       0.0   
99499  image-7.jpg  2020-10-07T11:45:37-26        100        100       0.0   
99500  image-8.jpg  2020-10-07T11:45:37-72        100        100       0.0   

       target_hotend  hotend    bed  nozzle_tip_x  nozzle_tip_y  img_num  \
99496          205.0  204.86  64.83           654           560        3   
99497          205.0  204.62  65.08           654           560        4   
99498          205.0  204.62  65.08           654           560        5   
99499          205.0  204.62  65.08           654           560        6   
99500          205.0  204.62  65.08           654           560        7   

       print_id  flow_rate_class  feed_r

### Analysing the hotend temperature column

In [5]:
# Summarise the target hotend set-points present in the filtered data.
hotend_targets = data_filtered['target_hotend']
unique_temperatures = sorted(hotend_targets.unique())  # ascending order
temperature_min, temperature_max = hotend_targets.min(), hotend_targets.max()

# Report the sorted set-points, how many there are, and the full range.
summary_lines = [
    "\nUnique target hotend temperatures in the dataset (sorted):",
    unique_temperatures,
    f"\nNumber of unique target hotend temperatures: {len(unique_temperatures)}",
    f"Temperature range: {temperature_min}° to {temperature_max}°",
]
for line in summary_lines:
    print(line)


Unique target hotend temperatures in the dataset (sorted):
[180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 208.0, 209.0, 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0, 224.0, 225.0, 226.0, 227.0, 228.0, 229.0, 230.0]

Number of unique target hotend temperatures: 51
Temperature range: 180.0° to 230.0°


## Create a random temperature sublist (always including the minimum and maximum)

In [6]:
# Number of temperatures drawn at random in addition to the minimum and maximum,
# which are always included.
NUM_RANDOM_TEMPERATURES = 40

# Candidates for the random draw: every temperature except the two extremes.
remaining_temperatures = [temp for temp in unique_temperatures if temp != temperature_min and temp != temperature_max]

# Fail fast: without enough candidates, temperature_sublist would be left
# undefined and every later cell would break with a NameError.
if len(remaining_temperatures) < NUM_RANDOM_TEMPERATURES:
    raise ValueError(
        f"Not enough unique temperatures to select from: need at least "
        f"{NUM_RANDOM_TEMPERATURES + 2}, found {len(unique_temperatures)}."
    )

# A private RNG seeded with `seed` makes this cell idempotent (re-running it
# gives the same sublist). It draws the same sample as the global
# random.seed(seed) stream would when no other `random` calls intervene.
random_temperatures = random.Random(seed).sample(remaining_temperatures, NUM_RANDOM_TEMPERATURES)

# Sort from lowest to highest hotend temperature
temperature_sublist = sorted([temperature_min, temperature_max] + random_temperatures)

print("\nTemperature sublist:")
print(temperature_sublist)


Temperature sublist:
[180.0, 181.0, 182.0, 183.0, 184.0, 186.0, 187.0, 188.0, 189.0, 191.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 201.0, 202.0, 203.0, 205.0, 208.0, 209.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0, 224.0, 226.0, 227.0, 228.0, 229.0, 230.0]


## Split the dataset into separate DataFrames for each class-temperature combination

In [8]:
# Smallest per-class image count a temperature must reach to be considered when
# choosing the balancing size (comments/messages previously said 10 but the
# check used 15; this constant is now the single source of truth).
MIN_REQUIRED_CLASS_SIZE = 15
CLASS_IDS = [0, 1, 2]

# (class_id, temperature) -> shuffled DataFrame of that class at that
# temperature. Combinations with no images are not stored.
class_temperature_datasets = {}

for temp in temperature_sublist:
    print(f"Processing temperature: {temp}°")
    temp_filtered = data_filtered[data_filtered['target_hotend'] == temp]

    for class_id in CLASS_IDS:
        class_temp_data = temp_filtered[temp_filtered['hotend_class'] == class_id]
        if not class_temp_data.empty:
            class_temperature_datasets[(class_id, temp)] = class_temp_data.sample(frac=1, random_state=42)
        print(f"Class {class_id} at {temp}° dataset size: {len(class_temp_data)}")


def _class_size(class_id, temp):
    """Return the number of stored images for (class_id, temp), or 0 if absent."""
    subset = class_temperature_datasets.get((class_id, temp))
    return 0 if subset is None else len(subset)


# Per-temperature summary; absent combinations report 0.
for temp in temperature_sublist:
    print(f"\nSummary for Temperature: {temp}°")
    for class_id in CLASS_IDS:
        print(f"Class {class_id} at {temp}° dataset size: {_class_size(class_id, temp)}")

# Find the smallest per-class count over temperatures where every class is
# present and the smallest class has at least MIN_REQUIRED_CLASS_SIZE images.
min_class_size = float('inf')
min_combinations = []  # (class_id, temperature) pairs attaining min_class_size

for temp in temperature_sublist:
    class_sizes = [_class_size(class_id, temp) for class_id in CLASS_IDS]
    if 0 in class_sizes:
        continue  # at least one class has no images at this temperature

    min_temp_class_size = min(class_sizes)
    if min_temp_class_size < MIN_REQUIRED_CLASS_SIZE:
        continue

    # Class holding the smallest count at this temperature (first one on ties).
    min_class_id = CLASS_IDS[class_sizes.index(min_temp_class_size)]
    if min_temp_class_size < min_class_size:
        min_class_size = min_temp_class_size
        min_combinations = [(min_class_id, temp)]
    elif min_temp_class_size == min_class_size:
        min_combinations.append((min_class_id, temp))

if min_combinations:
    print(f"\nMinimum class size: {min_class_size} (with size >= {MIN_REQUIRED_CLASS_SIZE}) occurs for the following class-temperature combinations:")
    for class_id, temp in min_combinations:
        print(f"Class {class_id} at {temp}°")
else:
    print(f"\nNo valid temperature with all classes having non-zero data and class size >= {MIN_REQUIRED_CLASS_SIZE}.")

Processing temperature: 180.0°
Class 0 at 180.0° dataset size: 618
Class 1 at 180.0° dataset size: 53
Class 2 at 180.0° dataset size: 5
Processing temperature: 181.0°
Class 0 at 181.0° dataset size: 393
Class 1 at 181.0° dataset size: 12
Class 2 at 181.0° dataset size: 0
Processing temperature: 182.0°
Class 0 at 182.0° dataset size: 243
Class 1 at 182.0° dataset size: 27
Class 2 at 182.0° dataset size: 0
Processing temperature: 183.0°
Class 0 at 183.0° dataset size: 262
Class 1 at 183.0° dataset size: 8
Class 2 at 183.0° dataset size: 0
Processing temperature: 184.0°
Class 0 at 184.0° dataset size: 272
Class 1 at 184.0° dataset size: 0
Class 2 at 184.0° dataset size: 0
Processing temperature: 186.0°
Class 0 at 186.0° dataset size: 385
Class 1 at 186.0° dataset size: 21
Class 2 at 186.0° dataset size: 0
Processing temperature: 187.0°
Class 0 at 187.0° dataset size: 110
Class 1 at 187.0° dataset size: 26
Class 2 at 187.0° dataset size: 0
Processing temperature: 188.0°
Class 0 at 188.0° d

## Create a balanced dataset

In [9]:
def _sample_temperature(temp_frame, temp):
    """Draw `min_class_size` rows of every class from one temperature's rows.

    Prints each class's available count and stops at the first class that is
    too small. Returns {class_id: sampled DataFrame} if every class qualifies,
    otherwise None.
    """
    per_class = {}
    for class_id in [0, 1, 2]:
        class_rows = temp_frame[temp_frame['hotend_class'] == class_id]
        available = len(class_rows)
        print(f"Class {class_id} at {temp}° actual dataset size: {available}")
        if available < min_class_size:
            print(f"Class {class_id} at {temp}° does not have enough images ({available}). Skipping this temperature.")
            return None
        per_class[class_id] = class_rows.sample(n=min_class_size, random_state=42)
    return per_class


# One balanced block (min_class_size rows per class) per qualifying temperature.
valid_class_temperature_datasets = []
for temp in temperature_sublist:
    print(f"\nProcessing temperature: {temp}°")
    per_class = _sample_temperature(data_filtered[data_filtered['target_hotend'] == temp], temp)
    if per_class is None:
        continue
    valid_class_temperature_datasets.append(pd.concat(per_class.values(), ignore_index=True))
    print(f"Temperature {temp}° included with {min_class_size} images per class.")

# Merge every qualifying temperature into one frame, then shuffle its rows.
if valid_class_temperature_datasets:
    balanced_data = pd.concat(valid_class_temperature_datasets, ignore_index=True)
else:
    balanced_data = pd.DataFrame()

if balanced_data.empty:
    print("No valid data left after filtering temperatures with insufficient class sizes.")
else:
    balanced_data = balanced_data.sample(frac=1, random_state=42).reset_index(drop=True)
    print(f"\nTotal number of images in the balanced dataset: {len(balanced_data)}")

# Report images per (temperature, class), temperatures in order of first appearance.
if balanced_data.empty:
    print("Balanced dataset is empty.")
else:
    print("\nClass and Temperature counts in the balanced dataset:")
    class_col = balanced_data['hotend_class']
    temp_col = balanced_data['target_hotend']
    for temp in temp_col.unique():
        print(f"\nTemperature: {temp}°")
        at_temp = temp_col == temp
        for class_id in [0, 1, 2]:
            print(f"Class {class_id}: {int((at_temp & (class_col == class_id)).sum())} images")


Processing temperature: 180.0°
Class 0 at 180.0° actual dataset size: 618
Class 1 at 180.0° actual dataset size: 53
Class 2 at 180.0° actual dataset size: 5
Class 2 at 180.0° does not have enough images (5). Skipping this temperature.

Processing temperature: 181.0°
Class 0 at 181.0° actual dataset size: 393
Class 1 at 181.0° actual dataset size: 12
Class 1 at 181.0° does not have enough images (12). Skipping this temperature.

Processing temperature: 182.0°
Class 0 at 182.0° actual dataset size: 243
Class 1 at 182.0° actual dataset size: 27
Class 2 at 182.0° actual dataset size: 0
Class 2 at 182.0° does not have enough images (0). Skipping this temperature.

Processing temperature: 183.0°
Class 0 at 183.0° actual dataset size: 262
Class 1 at 183.0° actual dataset size: 8
Class 1 at 183.0° does not have enough images (8). Skipping this temperature.

Processing temperature: 184.0°
Class 0 at 184.0° actual dataset size: 272
Class 1 at 184.0° actual dataset size: 0
Class 1 at 184.0° does

## Reduce `balanced_data` to the `img_path` and `hotend_class` columns

In [None]:
# Keep only what the Dataset needs: the image filename and its class label.
balanced_data = balanced_data.loc[:, ['img_path', 'hotend_class']]

# Inspect the reduced frame
print(balanced_data)

# The balanced dataset is complete at this point.

## Create training, validation, and testing datasets

In [None]:
# class BalancedDataset(Dataset):
#     def __init__(self, data_frame, root_dir, transform=None):
#         self.data = data_frame
#         self.root_dir = root_dir
#         self.transform = transform or transforms.Compose([transforms.Resize((224, 224)),
#                                                           transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
#                 
#         # Validate that the images exist in the directory
#         self.valid_indices = self.get_valid_indices()
# 
#     def get_valid_indices(self):
#         valid_indices = []
#         for idx in tqdm(range(len(self.data)), desc="Validating images"):
#             img_name = self.data.iloc[idx, 0].strip()
#             img_name = img_name.split('/')[-1]  # Extract file name
#             
#             if img_name.startswith("image-"):
#                 try:
#                     # Ensure we only include images in print24
#                     image_number = int(img_name.split('-')[1].split('.')[0])
#                     if 4 <= image_number <= 26637:
#                         full_img_path = os.path.join(self.root_dir, img_name)
#                         if os.path.exists(full_img_path):
#                             valid_indices.append(idx)
#                         else:
#                             print(f"Image does not exist: {full_img_path}")
#                 except ValueError:
#                     print(f"Invalid filename format for {img_name}. Skipping...")
#         
#         print(f"Total valid indices found: {len(valid_indices)}")  # Debugging output
#         return valid_indices
# 
#     def __len__(self):
#         return len(self.valid_indices)
# 
#     def __getitem__(self, idx):
#         # Get the actual index from valid indices
#         actual_idx = self.valid_indices[idx]
#         img_name = self.data.iloc[actual_idx, 0].strip()
#         full_img_path = os.path.join(self.root_dir, img_name)
#         
#         try:
#             image = Image.open(full_img_path).convert('RGB')  # Ensure image is RGB
#             label_str = self.data.iloc[actual_idx]['hotend_class']  # Use column name 'hotend_class'
#             label = int(label_str)  # Convert label to integer (ensure it's valid)
#             
#             # Apply transformation if any
#             image = self.transform(image)  # Apply transformation
#     
#             return image, label, actual_idx
#         except Exception as e:
#             print(f"Error loading image {full_img_path}: {e}")
#             return None  # Handle error gracefully
# 
# # Assuming the data and labels are already prepared in the dataframe
# # balanced_data should be a pandas DataFrame that contains image paths and labels
# dataset = BalancedDataset(balanced_data, ROOT_DIR_PATH)
# 
# # Step 3: Stratified Split (Ensure each class is represented proportionally in the splits)
# labels = dataset.data['hotend_class'].values
# 
# # Get the number of samples for each class
# class_counts = np.bincount(labels)
# 
# # Calculate the number of samples to allocate for each class in each split (train, val, test)
# def calculate_split_indices(class_counts, split_ratios):
#     train_indices = []
#     val_indices = []
#     test_indices = []
# 
#     # For each class, calculate how many samples go into each split
#     for class_label, count in enumerate(class_counts):
#         # Calculate how many samples per split for this class
#         num_train = int(count * split_ratios[0])
#         num_val = int(count * split_ratios[1])
#         num_test = count - num_train - num_val  # The remaining samples go to test
#         
#         # Get the indices for the class
#         class_indices = np.where(labels == class_label)[0]
#         
#         # Shuffle indices for randomness
#         np.random.shuffle(class_indices)
#         
#         # Split the indices based on the calculated numbers
#         train_indices.extend(class_indices[:num_train])
#         val_indices.extend(class_indices[num_train:num_train+num_val])
#         test_indices.extend(class_indices[num_train+num_val:])
#     
#     return train_indices, val_indices, test_indices
# 
# # Define the split ratios (train, val, test)
# split_ratios = (0.8, 0.1, 0.1)  # 70% training, 20% validation, 10% testing
# train_indices, val_indices, test_indices = calculate_split_indices(class_counts, split_ratios)
# 
# # Print sizes to confirm the split
# print(f"Training set size: {len(train_indices)}")
# print(f"Validation set size: {len(val_indices)}")
# print(f"Test set size: {len(test_indices)}")
# 
# # Step 4: Create DataLoaders with SubsetRandomSampler
# batch_size = 33  # Adjust batch size as needed
# 
# train_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices))
# val_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(val_indices))
# test_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(test_indices))
# 
# # Optionally: You can also print the class distribution in each set
# def print_class_distribution(loader, name):
#     class_counts = np.zeros(3)  # Assuming 3 classes, adjust for your number of classes
#     for _, labels, _ in loader:
#         for label in labels:
#             class_counts[label] += 1
#     print(f"{name} Class Distribution: {class_counts}")
# 
# print_class_distribution(train_loader, "Training")
# print_class_distribution(val_loader, "Validation")
# print_class_distribution(test_loader, "Test")
# 
# # Optionally: You can print details of the indices in the DataLoader as follows:
# def print_loader_info(loader, name):
#     print(f"\n{name} set:")
#     for images, labels, indices in loader:
#         for img, label, idx in zip(images, labels, indices):
#             print(f"Index: {idx}, Label: {label}")
# 
# print_loader_info(train_loader, "Train")
# print_loader_info(val_loader, "Validation")
# print_loader_info(test_loader, "Test")

In [None]:
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset

# Define a worker function for deterministic DataLoader behavior
def seed_worker(worker_id):
    """Seed numpy and `random` inside a DataLoader worker process.

    Follows the recipe in the PyTorch reproducibility notes. Inside a worker,
    ``torch.initial_seed()`` is already ``base_seed + worker_id``, where
    ``base_seed`` comes from the DataLoader's generator. Deriving the numpy and
    ``random`` seeds from it gives each worker a distinct stream that is still
    reproducible and changes from epoch to epoch. The previous
    ``seed + worker_id`` repeated the same stream every epoch and ignored the
    loader's generator.

    Only invoked when the DataLoader uses worker processes (num_workers > 0).

    Args:
        worker_id: Worker index supplied by the DataLoader. Unused here because
            it is already folded into ``torch.initial_seed()``; kept for the
            ``worker_init_fn`` signature.
    """
    worker_seed = torch.initial_seed() % 2**32  # numpy seeds must fit in 32 bits
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Define the BalancedDataset class
class BalancedDataset(Dataset):
    """Image-classification dataset backed by a DataFrame of image paths and labels.

    Column 0 of ``data_frame`` holds the image path. Only its last ``/``
    component is used, joined onto ``root_dir``. The ``hotend_class`` column
    holds the integer label.

    At construction, a row is kept only if its filename looks like
    ``image-<n>.<ext>``, ``<n>`` lies inside ``image_number_range``, and the
    file exists. Dataset index ``i`` refers to the row at position
    ``valid_indices[i]``.

    Items are ``(image_tensor, label, row_position)``, where ``row_position`` is
    the row's position in ``data_frame``.
    """

    def __init__(self, data_frame, root_dir, transform=None, image_number_range=(4, 26637)):
        """
        Args:
            data_frame: Rows to expose (see the class docstring for expected columns).
            root_dir: Directory containing the image files.
            transform: Callable applied to each RGB PIL image. Defaults to a
                224x224 resize, ``ToTensor`` and ImageNet mean/std normalisation.
            image_number_range: Inclusive ``(low, high)`` bounds on ``<n>``. The
                default keeps the original hard-coded range (presumably the
                print24 frames; confirm against the dataset).
        """
        self.data = data_frame
        self.root_dir = root_dir
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.image_number_range = image_number_range

        # Validate that the images exist in the directory
        self.valid_indices = self.get_valid_indices()

    def _image_name(self, idx):
        """Return the bare filename for row position ``idx`` (whitespace stripped, directories dropped).

        Validation and loading both use this helper, so they always resolve the same file.
        """
        return self.data.iloc[idx, 0].strip().split('/')[-1]

    def get_valid_indices(self):
        """Return the row positions whose image is well-named, in range, and present on disk."""
        low, high = self.image_number_range
        valid_indices = []
        for idx in tqdm(range(len(self.data)), desc="Validating images"):
            img_name = self._image_name(idx)
            if not img_name.startswith("image-"):
                continue

            try:
                image_number = int(img_name.split('-')[1].split('.')[0])
            except ValueError:
                print(f"Invalid filename format for {img_name}. Skipping...")
                continue

            if not low <= image_number <= high:
                continue

            full_img_path = os.path.join(self.root_dir, img_name)
            if os.path.exists(full_img_path):
                valid_indices.append(idx)
            else:
                print(f"Image does not exist: {full_img_path}")

        print(f"Total valid indices found: {len(valid_indices)}")  # Debugging output
        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(transformed_image, int_label, row_position)`` for dataset index ``idx``.

        Raises:
            RuntimeError: If the image file cannot be opened or decoded. The
                previous ``return None`` only deferred the failure to the
                DataLoader's default collate, which cannot batch ``None``.
        """
        actual_idx = self.valid_indices[idx]
        full_img_path = os.path.join(self.root_dir, self._image_name(actual_idx))

        try:
            # The context manager closes the file handle once the pixels are converted.
            with Image.open(full_img_path) as img:
                image = img.convert('RGB')
        except OSError as e:
            raise RuntimeError(f"Error loading image {full_img_path}: {e}") from e

        # Column-first access avoids the dtype upcasting a whole-row Series can cause.
        label = int(self.data['hotend_class'].iloc[actual_idx])
        return self.transform(image), label, actual_idx

# Build the Dataset from the balanced (img_path, hotend_class) frame.
dataset = BalancedDataset(balanced_data, ROOT_DIR_PATH)

# Labels aligned with *dataset* indices: labels[i] is the label of dataset[i].
# BalancedDataset may drop rows whose image is missing or misnamed, and it maps
# index i through valid_indices. Taking the raw column here would make the
# stratified split below yield DataFrame positions that the Dataset misreads,
# selecting wrong rows or indexing past the end.
labels = dataset.data['hotend_class'].to_numpy()[dataset.valid_indices]

# Calculate the number of samples for each class in each split (train, val, test)
def calculate_split_indices(class_counts, split_ratios, seed=42, label_array=None):
    """Split dataset indices into stratified train / validation / test lists.

    For each class ``c``, ``int(count * split_ratios[0])`` of its indices go to
    train, ``int(count * split_ratios[1])`` go to validation, and the rest go to
    test. Each class's indices are shuffled first.

    Args:
        class_counts: Sequence where ``class_counts[c]`` is the number of
            samples of class ``c`` (e.g. ``np.bincount(labels)``).
        split_ratios: ``(train, val, test)`` fractions. Only the first two are
            read; test receives the remainder.
        seed: Seed for a private ``np.random.RandomState``. It yields the same
            shuffles as the former ``np.random.seed(seed)`` approach without
            resetting the global numpy RNG as a side effect.
        label_array: Labels aligned with dataset indices. Defaults to the
            notebook-level ``labels`` array for backward compatibility.

    Returns:
        ``(train_indices, val_indices, test_indices)`` as lists of dataset indices.
    """
    if label_array is None:
        label_array = labels
    label_array = np.asarray(label_array)
    rng = np.random.RandomState(seed)

    train_indices, val_indices, test_indices = [], [], []
    for class_label, count in enumerate(class_counts):
        num_train = int(count * split_ratios[0])
        num_val = int(count * split_ratios[1])

        class_indices = np.where(label_array == class_label)[0]
        rng.shuffle(class_indices)

        train_indices.extend(class_indices[:num_train])
        val_indices.extend(class_indices[num_train:num_train + num_val])
        test_indices.extend(class_indices[num_train + num_val:])

    return train_indices, val_indices, test_indices

# Define the split ratios (train, val, test): 80% training, 10% validation, 10% testing
split_ratios = (0.8, 0.1, 0.1)
class_counts = np.bincount(labels)  # samples per class, indexed by class id
train_indices, val_indices, test_indices = calculate_split_indices(class_counts, split_ratios, seed=seed)

# Print sizes to confirm the split
print(f"Training set size: {len(train_indices)}")
print(f"Validation set size: {len(val_indices)}")
print(f"Test set size: {len(test_indices)}")

batch_size = 15  # Adjust batch size as needed


def _make_loader(indices, generator_seed):
    """Build a DataLoader over ``indices`` of ``dataset``, visited in random order.

    A private ``torch.Generator`` drives both the sampler's shuffle and the
    loader's worker base seed. The batch order is then reproducible and does
    not depend on how much of the global torch RNG other cells (e.g. model
    initialisation, earlier loader iterations) have consumed.
    """
    generator = torch.Generator().manual_seed(generator_seed)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(indices, generator=generator),
        worker_init_fn=seed_worker,
        generator=generator,
    )


# A separate generator (and seed) per split, so iterating one loader never
# shifts the order of another.
train_loader = _make_loader(train_indices, seed)
val_loader = _make_loader(val_indices, seed + 1)
test_loader = _make_loader(test_indices, seed + 2)

# Tally how many samples of each class a loader yields.
def print_class_distribution(loader, name, num_classes):
    """Iterate ``loader`` once and print a per-class sample count.

    ``loader`` must yield ``(images, labels, indices)`` batches. Labels index a
    float array of length ``num_classes``, so the counts print as floats.
    """
    tally = np.zeros(num_classes)
    every_label = (label for _, batch_labels, _ in loader for label in batch_labels)
    for label in every_label:
        tally[label] += 1
    print(f"{name} Class Distribution: {tally}")

# One count slot per distinct class label present in the dataset.
num_classes = np.unique(labels).size

# Show how many samples of each class every split contains.
for split_name, split_loader in (("Training", train_loader), ("Validation", val_loader), ("Test", test_loader)):
    print_class_distribution(split_loader, split_name, num_classes)


# List the dataset index and label of every sample a loader yields.
def print_loader_info(loader, name):
    """Print a header for ``name``, then one ``Index/Label`` line per sample in loader order.

    ``loader`` must yield ``(images, labels, indices)`` batches.
    """
    print(f"\n{name} set:")
    rows = (
        (idx, label)
        for batch_images, batch_labels, batch_indices in loader
        for _img, label, idx in zip(batch_images, batch_labels, batch_indices)
    )
    for idx, label in rows:
        print(f"Index: {idx}, Label: {label}")

# Dump every split's (index, label) pairs for inspection.
for split_name, split_loader in (("Train", train_loader), ("Validation", val_loader), ("Test", test_loader)):
    print_loader_info(split_loader, split_name)

# The data loaders are ready; model training follows.

In [None]:
# Set device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

NUM_CLASSES = 3
model = SimpleCNN(num_classes=NUM_CLASSES).to(device)

# Training parameters
num_epochs = 100  # Adjust as needed
# Print "Predicted/Actual" for every sample of every batch. Off by default:
# with 100 epochs this floods the notebook output. Set True to debug.
PRINT_PER_SAMPLE_PREDICTIONS = False

class_weights = torch.ones(NUM_CLASSES).to(device)  # equal weights; update based on your class distribution
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)  # Adjust learning rate if needed
# Multiply the learning rate by 0.1 every 10 epochs.
# NOTE(review): over 100 epochs this decays 1e-4 down to 1e-13, so updates
# become negligible well before the final epoch; confirm this is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


def _run_epoch(loader, train):
    """Run one pass of ``model`` over ``loader``.

    Loop variables are local, so this no longer overwrites the notebook-level
    ``labels`` array or ``class_counts``, as the previous inline loops did.

    Args:
        loader: Yields ``(images, labels, indices)`` batches.
        train: If True, update the weights; otherwise evaluate without gradients.

    Returns:
        ``(mean loss per sample, accuracy, samples seen per class)``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train(train)  # train(False) is equivalent to model.eval()
    desc = "Training" if train else "Validating"
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    seen_per_class = [0] * NUM_CLASSES

    with torch.set_grad_enabled(train):
        for batch_images, batch_labels, _ in tqdm(loader, desc=desc, leave=False):
            batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

            if train:
                optimizer.zero_grad()
            outputs = model(batch_images)
            loss = criterion(outputs, batch_labels)
            if train:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean, so weight it by batch size.
            total_loss += loss.item() * batch_images.size(0)
            _, predicted = torch.max(outputs, 1)
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

            for label in batch_labels.tolist():
                seen_per_class[label] += 1

            if PRINT_PER_SAMPLE_PREDICTIONS:
                for pred, actual in zip(predicted.tolist(), batch_labels.tolist()):
                    print(f"Predicted: {pred}, Actual: {actual}")

    if n_seen == 0:
        raise ValueError(f"{desc} loader yielded no samples")
    return total_loss / n_seen, n_correct / n_seen, seen_per_class


best_val_accuracy = 0.0  # Track the best validation accuracy to save the best model

# Per-epoch losses for plotting
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    epoch_loss, epoch_accuracy, train_class_counts = _run_epoch(train_loader, train=True)
    print(f"Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_accuracy:.4f}")
    print(f"Training Class Distribution: {train_class_counts}")
    scheduler.step()  # update the learning rate once per epoch
    train_losses.append(epoch_loss)

    val_epoch_loss, val_epoch_accuracy, val_class_counts = _run_epoch(val_loader, train=False)
    print(f"Validation Loss: {val_epoch_loss:.4f}, Validation Accuracy: {val_epoch_accuracy:.4f}")
    print(f"Validation Class Distribution: {val_class_counts}")
    val_losses.append(val_epoch_loss)

    # Save the model if it achieves better validation accuracy
    if val_epoch_accuracy > best_val_accuracy:
        best_val_accuracy = val_epoch_accuracy
        torch.save(model.state_dict(), 'best_model.pth')
        print("Saved the model with improved validation accuracy.")

print("Training complete.")

# Plot the training and validation losses
epochs_axis = range(1, num_epochs + 1)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs_axis, train_losses, label='Training Loss', color='blue')
ax.plot(epochs_axis, val_losses, label='Validation Loss', color='red')
ax.set(xlabel='Epochs', ylabel='Loss', title='Training and Validation Losses Over Epochs')
ax.legend()
ax.grid(True)
plt.show()

# Test model function with tqdm progress bar
def test_model(model, test_loader, num_classes=3, verbose=True):
    """Evaluate ``model`` on ``test_loader``; print accuracy and class distribution.

    Args:
        model: Classifier whose output's argmax along dim 1 is the predicted
            class (assumes logits of shape (batch, num_classes); confirm for
            other models).
        test_loader: Yields ``(images, labels, indices)`` batches.
        num_classes: Length of the per-class count list; labels must be
            ``< num_classes``.
        verbose: If True (the default, matching the original behaviour), print
            predicted vs actual for every sample.

    Returns:
        Test accuracy as a float in [0, 1].

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()  # Set model to evaluation mode
    correct_predictions = 0
    total_samples = 0
    test_class_counts = [0] * num_classes

    with torch.no_grad():  # Disable gradients for testing
        for images, labels, _ in tqdm(test_loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)

            for label in labels.tolist():
                test_class_counts[label] += 1

            if verbose:
                for pred, actual in zip(predicted.tolist(), labels.tolist()):
                    print(f"Predicted: {pred}, Actual: {actual}")

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples")

    avg_accuracy = correct_predictions / total_samples
    print(f"Test Accuracy: {avg_accuracy:.4f}")
    print(f"Test Class Distribution: {test_class_counts}")
    return avg_accuracy

# Run the test phase after training.
# NOTE(review): this evaluates the in-memory `model` as it stands after the last
# epoch; the best-validation checkpoint written to 'best_model.pth' during training
# is not reloaded in this cell (e.g. via model.load_state_dict(torch.load(...))).
# Confirm that evaluating the final-epoch weights rather than the best ones is intended.
test_model(model, test_loader)

In [None]:
# class CustomDataset(Dataset):
#     def __init__(self, csv_file=None, root_dir=None, transform=None, data_frame=None):
#         if data_frame is not None:
#             self.data = data_frame
#         elif csv_file is not None:
#             self.data = pd.read_csv(csv_file, header=0, dtype=str)
#         else:
#             raise ValueError("Either csv_file or data_frame must be provided.")
# 
#         self.root_dir = root_dir
#         self.transform = transform or self.default_transform()
#         self.valid_indices = self.get_valid_indices()
# 
#     def default_transform(self):
#         return transforms.Compose([
#             transforms.Resize((224, 224)),
#             transforms.ToTensor(),
#         ])
# 
#     def get_valid_indices(self):
#         valid_indices = []
#         for idx in tqdm(range(len(self.data)), desc="Validating images"):
#             img_name = self.data.iloc[idx, 0].strip()
#             img_name = img_name.split('/')[-1]
#         
#             if img_name.startswith("image-"):
#                 try:
#                     image_number = int(img_name.split('-')[1].split('.')[0])
#                     if image_number <= 3084:
#                         full_img_path = os.path.join(self.root_dir, img_name)
#                         if os.path.exists(full_img_path):
#                             valid_indices.append(idx)
#                             label = self.data.iloc[idx, 15]  # Assuming label is in the 15th column
#                             print(f"Valid image: {img_name}, Label: {label}")
#                         else:
#                             print(f"Image does not exist: {full_img_path}")
#                 except ValueError:
#                     print(f"Invalid filename format for {img_name}. Skipping...")
#         
#         print(f"Total valid indices found: {len(valid_indices)}")  # Debugging output
#         return valid_indices
# 
# 
# 
#     def __len__(self):
#         return len(self.valid_indices)
# 
#     def __getitem__(self, idx):
#         if isinstance(idx, list):
#             items = [self._load_sample(i) for i in idx if self._load_sample(i) is not None]
#             if not items:
#                 raise RuntimeError("No valid items found in the batch.")
#             images, labels = zip(*items)
#             return torch.stack(images), torch.tensor(labels)
#         else:
#             return self._load_sample(idx)
# 
# 
#     def _load_sample(self, idx):
#         # Get the actual index from valid indices
#         actual_idx = self.valid_indices[idx]
#         img_name = self.data.iloc[actual_idx, 0].strip()
#         full_img_path = os.path.join(self.root_dir, img_name)
#     
#         try:
#             image = Image.open(full_img_path).convert('RGB')  # Ensure image is RGB
#             label_str = self.data.iloc[actual_idx, 15]  # Assuming label is in the 15th column
#             
#             # Attempt to convert label to integer; handle exceptions
#             try:
#                 label = int(label_str)  # Try converting to int
#             except ValueError:
#                 print(f"Warning: Non-integer label found for image {img_name}: {label_str}")
#                 print()
#                 return None  # Skip this sample if label conversion fails
#     
#             # Print image and label info when loading a sample
#             print(f"Loading sample: {img_name}, Label: {label}")
#     
#             image = self.transform(image)  # Apply transformation
#     
#             return image, label
#         except Exception as e:
#             print(f"Error loading image {full_img_path}: {e}")
#             return None  # Handle error gracefully