In [1]:
# Make sure to download the necessary libraries: torch, spuco, pandas, tqdm
# We need to import os, a built-in Python module that allows us to 
# interact with the operating system, to handle files and directories
import os

# import the Pandas library for data manipulation and analysis
import pandas as pd
# import the core library of PyTorch the machine learning framework
# for tensor operations and neural networks
import torch
# Here we are importing api references and classes from the spuco package 
# The datasets api reference contains the classes we need to initialize our dataset
# SpuCoMNIST initializes the dataset of images from the MNIST dataset
# where each image has spurious features added 
# The SpuriousFeatureDifficulty is the level of the spurious feature in each image
from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
# The models api reference contains the function to create a machine learning SpuCoModel
from spuco.models import model_factory
# The robust_train api reference contains the Empirical Risk Minimization (ERM) algorithm
from spuco.robust_train.erm import ERM
# The optim module in PyTorch used to adjust the weights and biases of the NN to minimize loss
import torch.optim as optim

# We need tqdm to display progress bars while iterating over batches
from tqdm import tqdm

print("Starting the process...")
print("Loading dataset...")

# Each inner list is one class: the ten digits are paired off into five
# classes (0/1, 2/3, 4/5, 6/7, 8/9).
digit_groups = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

# Build the SpuCoMNIST training split under ./data, downloading the data if
# it is not available locally. Arguments not listed here keep their defaults.
#   - spurious_feature_difficulty: large-magnitude spurious feature.
#   - spurious_correlation_strength: 0.9 -- presumably the fraction of
#     samples whose spurious feature agrees with the class label; confirm
#     against the SpuCo documentation.
dataset = SpuCoMNIST(
    root="./data",
    spurious_feature_difficulty=SpuriousFeatureDifficulty.MAGNITUDE_LARGE,
    classes=digit_groups,
    spurious_correlation_strength=0.9,
    split="train",
    download=True,
)

# Initialize the dataset object before it is indexed or iterated below.
dataset.initialize()
print("Dataset initialized!")


# Build a LeNet CNN sized for this dataset.
print("Initializing the model...")
model = model_factory(
    arch="lenet",
    # Indexing the dataset yields an (image, label) pair, so dataset[0][0] is
    # a single image and its shape is the model's input shape.
    input_shape=dataset[0][0].shape,
    # One output logit per class. The ten digits are grouped into five
    # classes, so the count is read from the dataset rather than hard-coded
    # (it is 5 here, not the number of digits).
    num_classes=dataset.num_classes,
)
print("Model initialized!")

# Adam (Adaptive Moment Estimation) optimizer over all model parameters.
# It combines momentum with a per-parameter adaptive step size; 0.001 is a
# commonly used learning rate.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Configure plain ERM training on the dataset: mini-batches of 32 samples,
# 5 passes (epochs) over the data, using the optimizer created above.
# NOTE: despite its name, `trained_model` holds the ERM trainer object;
# training itself only happens when .train() is called.
trained_model = ERM(
    model=model,
    trainset=dataset,
    batch_size=32,
    optimizer=optimizer,
    num_epochs=5,
)

print("Starting ERM training...")
trained_model.train()
print("ERM training complete!")

# Training is done: switch the model to evaluation mode so that layers which
# behave differently during training (e.g. dropout) act deterministically.
model.eval()

from torch.utils.data import DataLoader

# Walk the dataset in its original order (no shuffling) in batches of 32, so
# that row i of the collected outputs belongs to sample i of the dataset.
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

# Per-batch logits (raw, un-normalized class scores) are gathered here.
outputs = []
print("Generating outputs...")

# Gradients are not needed for inference; disabling them saves memory.
with torch.no_grad():
    # Labels are ignored -- only the images are fed through the network.
    for inputs, _ in tqdm(dataloader, desc="Processing batches"):
        outputs.append(model(inputs))

# Stack the per-batch results along the sample dimension into one tensor,
# which is what the clustering step below consumes.
outputs = torch.cat(outputs, dim=0)
print("Outputs generated!")

# Cluster is SpuCo's clustering-based group inference class.
from spuco.group_inference import Cluster

# Cluster the model outputs collected above; the resulting groups are what the
# group-balanced training step further down samples from.
print("Performing clustering...")

# Ask for 5 clusters over the model outputs.
# NOTE(review): no class labels are passed here, so the clustering presumably
# runs over all samples together rather than per class -- confirm this is the
# intended grouping.
cluster = Cluster(outputs, num_clusters=5)

# infer_groups() returns the group partition that is later handed to
# GroupBalanceBatchERM via its group_partition argument.
group_partition = cluster.infer_groups()

print("Clustering complete!")


# A separate Adam optimizer for the second training stage, so no optimizer
# state carries over from the ERM run.
optimizer_new = optim.Adam(model.parameters(), lr=0.001)

from spuco.robust_train import GroupBalanceBatchERM

# Group-balanced training: batches are drawn using the inferred group
# partition so that every group is represented, instead of letting the
# majority (spuriously correlated) groups dominate.
# NOTE: this continues training the same `model` object that ERM already
# trained above; it does not start from a freshly initialized network.
balance = GroupBalanceBatchERM(
    model=model,
    trainset=dataset,
    group_partition=group_partition,
    batch_size=32,
    optimizer=optimizer_new,
    num_epochs=10,
)

print("Start group-balanced training...")
balance.train()
print("Group-balanced training complete!")


def evaluate_model(model, dataset, *, batch_size=32, num_to_print=5):
    """Return the fraction of samples in *dataset* that *model* classifies correctly.

    The model is switched to evaluation mode (and left in it), the dataset is
    walked in order without shuffling, and the predicted class of each sample
    is the index of its largest logit.

    Args:
        model: Callable network mapping a batch of inputs to per-class logits
            of shape (batch, num_classes).
        dataset: Dataset yielding (input, label) pairs.
        batch_size: Number of samples per forward pass.
        num_to_print: How many (truth, prediction) pairs to print from the
            start of the dataset as a quick sanity check.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when the dataset is empty.
    """
    model.eval()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    correct = 0        # number of correct predictions so far
    total = 0          # number of samples processed so far
    printed_count = 0  # number of truth/prediction pairs already printed

    print("Evaluating the model...")
    with torch.no_grad():  # inference only, no gradients needed
        for inputs, labels in tqdm(dataloader, desc="Evaluating batches"):
            # Predicted class = index of the largest logit in each row.
            predicted = model(inputs).argmax(dim=1)

            # Only walk individual samples while there is still something to
            # print; afterwards the batch is handled purely with tensor ops.
            # NOTE: the values printed are whatever labels the dataset
            # yields; with grouped classes these are class indices.
            if printed_count < num_to_print:
                remaining = num_to_print - printed_count
                for truth, pred in zip(labels[:remaining], predicted[:remaining]):
                    print(f"Real digit: {truth.item()}, Predicted digit: {pred.item()}")
                    printed_count += 1

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty dataset has no meaningful accuracy; avoid dividing by zero.
    if total == 0:
        return 0.0
    return correct / total

# Measure the accuracy of the retrained model.
# NOTE(review): `dataset` is the same object that was created with
# split="train" and used for both training stages, so this number is training
# accuracy, not held-out performance -- consider evaluating on a test split.
accuracy = evaluate_model(model, dataset)

# Report the accuracy as a percentage with two decimals, then as the raw
# fraction.
print(f"Accuracy: {accuracy * 100:.2f}%")
print(accuracy)

Starting the process...
Loading dataset...


100%|████████████████████████████████████████████████████████████████| 48004/48004 [00:04<00:00, 11385.67it/s]


Dataset initialized!
Initializing the model...
Model initialized!
Starting ERM training...
ERM training complete!
Generating outputs...


Processing batches: 100%|████████████████████████████████████████████████| 1501/1501 [00:01<00:00, 816.17it/s]


Outputs generated!
Performing clustering...
Clustering complete!
Start group-balanced training...
Group-balanced training complete!
Evaluating the model...


Evaluating batches:   1%|▋                                                 | 22/1501 [00:00<00:07, 202.41it/s]

Real digit: 2, Predicted digit: 2
Real digit: 0, Predicted digit: 0
Real digit: 2, Predicted digit: 2
Real digit: 0, Predicted digit: 0
Real digit: 4, Predicted digit: 4


Evaluating batches: 100%|████████████████████████████████████████████████| 1501/1501 [00:02<00:00, 710.83it/s]

Accuracy: 99.21%
0.992125656195317



