# MNIST 10 Digit Classifier With 2 hidden layers

## Import the libraries

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import pandas as pd
import pickle
import random
import math
import os
from sklearn import metrics

## Declarations

### Given Template Functions

In [None]:
class My_dataset(Dataset):
    """
    Map-style dataset backed by an annotation CSV and per-sample ``.npy`` files.

    The CSV is read with ``pandas.read_csv``. Column 0 holds each sample's file
    name (resolved relative to ``data_dir``) and column 1 holds its integer
    class label. Labels are returned one-hot encoded over 10 classes.

    For background on custom PyTorch datasets, see
    https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files
    """

    def __init__(self, data_dir, anno_csv) -> None:
        """
        Args:
            data_dir: Directory that the file names in the CSV are relative to.
            anno_csv: Path of the annotation CSV (file name, label).
        """
        self.anno_data = pd.read_csv(anno_csv)
        self.data_dir = data_dir

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.anno_data)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            ``(data, gt_y)``: the sample as a float32 NumPy array and its
            label as a float32 one-hot vector of length 10.
        """
        data_name = self.anno_data.iloc[idx, 0]
        # os.path.join keeps this working whether or not data_dir ends with a
        # path separator.
        data_location = os.path.join(self.data_dir, data_name)
        data = np.float32(np.load(data_location))

        # One-hot encoding of the output label.
        gt_y = np.zeros(10, dtype=np.float32)
        gt_y[self.anno_data.iloc[idx, 1]] = 1
        return data, gt_y

    def split_to_batches(self, batch_num):
        """
        Shuffle the annotation rows and split them into ``batch_num`` parts.

        Returns:
            A list of ``batch_num`` DataFrames of (nearly) equal length that
            together contain every annotation row exactly once.
        """
        # Shuffle the annotation dataframe.
        # Source: https://www.aporia.com/resources/how-to/shuffle-dataframe-rows/
        shuffled_anno_data = self.anno_data.sample(frac=1)

        # Split row *positions* rather than the frame itself: calling
        # np.array_split directly on a DataFrame relies on a deprecated
        # pandas code path.
        position_chunks = np.array_split(np.arange(len(shuffled_anno_data)), batch_num)
        return [shuffled_anno_data.iloc[chunk] for chunk in position_chunks]

    def head(self):
        """Return the first rows of the annotation table."""
        return self.anno_data.head()

    def get_batches(self, batch_num):
        """
        Randomly partition the whole dataset into ``batch_num`` mini-batches.

        The DataLoader class is deliberately not used here (assignment
        constraint); shuffling is done with Python's ``random`` module.

        Args:
            batch_num: Number of mini-batches. It must divide ``len(self)``
                evenly so the batches can be stacked into one tensor.

        Returns:
            ``(data, gt_y)``: CPU tensors whose first dimension is the batch
            index and whose second dimension is the sample index in the batch.

        Raises:
            ValueError: If the dataset cannot be split into equal batches.
        """
        device = torch.device('cpu')

        # Shuffle a plain list of sample indices.
        indices = list(range(len(self)))
        random.shuffle(indices)

        # Split the shuffled indices into batches.
        batched_indices = np.array_split(indices, batch_num)
        if len({len(chunk) for chunk in batched_indices}) > 1:
            raise ValueError(
                f"Cannot split {len(self)} samples into {batch_num} equally sized batches"
            )

        # Load each batch, then stack the batches along a new leading axis.
        batched_data = []
        train_batched_Y = []
        for batch_indices in batched_indices:
            data_batch, gt_y_batch = My_dataset.get_data(self, batch_indices)
            batched_data.append(data_batch)
            train_batched_Y.append(gt_y_batch)

        data = torch.stack(batched_data).to(device=device)
        gt_y = torch.stack(train_batched_Y).to(device=device)

        return data, gt_y

    @staticmethod
    def get_data(dataset: Dataset, indices):
        """
        Gather the samples at ``indices`` from ``dataset``.

        Returns:
            ``(data, gt_y)`` as CPU tensors with one row per requested index.
        """
        # Build the tensors on the CPU explicitly instead of changing the
        # process-wide default device as a side effect of loading data.
        device = torch.device('cpu')

        data_list = []
        gt_y_list = []
        for index in indices:
            datum, gt_y_single = dataset[index]
            data_list.append(datum)
            gt_y_list.append(gt_y_single)
        data = torch.as_tensor(np.array(data_list), device=device)
        gt_y = torch.as_tensor(np.array(gt_y_list), device=device)

        return data, gt_y


class utilsNN:
    """
    Stateless tensor helpers for a hand-written fully connected network.

    Conventions used by these helpers: activations are ``(batch, nodes)``
    tensors, and each weight matrix ("Theta") has shape
    ``(inputs + 1, outputs)`` with the bias stored as its last row.
    """

    @staticmethod
    def init_weight(input_dim, output_dim):
        """
        Create a ``(input_dim + 1, output_dim)`` float32 weight matrix.

        Uses a variant of Xavier initialization: values are drawn uniformly
        from ``[-1, 1)`` and divided by ``sqrt(input_dim)``.
        https://machinelearningmastery.com/weight-initialization-for-deep-learning-neural-networks/
        """
        uniform_dist_one = torch.as_tensor(
            np.random.rand(input_dim + 1, output_dim), dtype=torch.float32
        )
        weight = torch.div(2 * (uniform_dist_one - 0.5), math.sqrt(input_dim))
        return weight

    @staticmethod
    def ReLU(layer):
        """Apply the element-wise ReLU, ``max(x, 0)``, to ``layer``."""
        return torch.maximum(layer, torch.zeros_like(layer))

    @staticmethod
    def forward_step(layer, weights, activation_func):
        """
        Compute ``activation_func([layer, 1] @ weights)`` for one layer.

        ``weights`` is moved to the device of ``layer`` before multiplying.
        """
        X = utilsNN.cat_ones(layer)
        z = torch.matmul(X, weights.to(X.device))
        return activation_func(z)

    @staticmethod
    def softmax(tensor):
        """
        Apply softmax along the last dimension of ``tensor``.

        The row maximum is subtracted first for numerical stability
        (https://stackoverflow.com/questions/42599498/numerically-stable-softmax).
        """
        shifted = tensor - torch.max(input=tensor, dim=-1, keepdim=True)[0]
        exponentials = torch.exp(shifted)
        return exponentials / torch.sum(exponentials, dim=-1, keepdim=True)

    @staticmethod
    def loss_NLL(predicted, actual):
        """
        Negative log likelihood loss, averaged over the batch.

        Args:
            predicted: ``(M, K)`` predicted probabilities.
            actual: ``(M, K)`` one-hot targets.

        Returns:
            The loss as a Python float.
        """
        M = actual.shape[0]
        # xlogy(a, p) is a * log(p) but yields 0 where a == 0. Plain
        # multiplication gives NaN when a probability underflows to 0 for a
        # class that is not the target (0 * -inf).
        loss = torch.sum(-torch.xlogy(actual, predicted) / M)
        return loss.item()

    @staticmethod
    def cat_ones(my_tensor):
        """
        Append a column of ones along the last dimension (the bias input).

        The ones are allocated on the same device as ``my_tensor`` so the
        concatenation also works when the default device differs.
        """
        ones_shape = (*my_tensor.shape[:-1], 1)
        ones_column = torch.ones(ones_shape, device=my_tensor.device)
        return torch.cat((my_tensor, ones_column), dim=-1)

    @staticmethod
    def calc_output_grad(predicted, actual):
        """
        Gradient of the NLL loss with respect to the predicted probabilities.

        This is ``-actual / predicted``. Entries whose target is 0 are set to
        exactly 0, so a probability that underflowed to 0 for a non-target
        class does not turn into NaN (0 / 0).
        """
        grad = -torch.div(actual, predicted)
        return torch.where(actual == 0, torch.zeros_like(grad), grad)

    @staticmethod
    def softmax_deriv(softmax_result):
        """
        Per-sample Jacobian of softmax, ``diag(y) - y y^T``.

        Formula reference:
        https://stats.stackexchange.com/questions/215521/how-to-find-derivative-of-softmax-function-for-the-purpose-of-gradient-descent

        Args:
            softmax_result: ``(M, K)`` softmax outputs.

        Returns:
            An ``(M, K, K)`` tensor.
        """
        diag_y = torch.diag_embed(softmax_result)
        # (M, K, 1) * (M, 1, K) broadcasts to the per-sample outer product.
        yy_T = softmax_result.unsqueeze(2) * softmax_result.unsqueeze(1)
        return diag_y - yy_T

    @staticmethod
    def relu_deriv(relu_result):
        """
        Per-sample Jacobian of ReLU as an ``(M, N, N)`` diagonal tensor.

        The diagonal is 1 where the ReLU output is positive and 0 elsewhere.
        It is built on the device of ``relu_result``.
        """
        bin_ten = torch.where(relu_result > 0, 1.0, 0.0)
        return torch.diag_embed(bin_ten)

    @staticmethod
    def derivative_dzdW(prev_layer, curr_layer_node_num):
        """
        Derivative of the pre-activation z with respect to the weights W.

        Args:
            prev_layer: ``(batch, n_in)`` activations of the previous layer
                (must be 2-D).
            curr_layer_node_num: Number of nodes in the current layer.

        Returns:
            A ``(batch, n_out, n_in, n_out)`` tensor. This is large; see
            ``backward_step`` for the equivalent outer-product shortcut.
        """
        assert(prev_layer.dim() == 2)
        identity_m = torch.eye(curr_layer_node_num, device=prev_layer.device)
        derivative = torch.einsum("ij,bn->binj", identity_m, prev_layer)
        return derivative

    @staticmethod
    def backward_step(model, layer_index, layer_grad, activation : str, last_pass : bool = False):
        """
        Back-propagate through one layer.

        Args:
            model: Object exposing ``layers`` (activations from the latest
                forward pass) and ``weights``.
            layer_index: Index into ``model.layers`` of the layer to process.
            layer_grad: ``(batch, n_out)`` gradient of the loss with respect
                to this layer's activations.
            activation: ``'softmax'`` or ``'ReLU'``.
            last_pass: True for the first hidden layer, where no gradient is
                needed for the layer below.

        Returns:
            ``grad_Theta`` alone when ``last_pass`` is True, otherwise
            ``(grad_Theta, prev_layer_grad)``. ``grad_Theta`` holds one
            ``(n_in + 1, n_out)`` gradient per sample (bias row last).

        Raises:
            Exception: If ``activation`` is not a supported name.
        """
        # Derivative of the activation function with respect to z.
        if activation == 'softmax':
            dg_dz = utilsNN.softmax_deriv(model.layers[layer_index])
        elif activation == 'ReLU':
            dg_dz = utilsNN.relu_deriv(model.layers[layer_index])
        else:
            raise Exception(f'Invalid activation function string provided: "{activation}"')

        # Gradient with respect to z, which is also the bias gradient.
        grad_W0 = torch.einsum("bij,bj->bi", dg_dz, layer_grad)

        # Gradient of the weights. Contracting derivative_dzdW (identity
        # (x) prev_layer) with grad_W0 reduces to the per-sample outer product
        # prev_layer[b, n] * grad_W0[b, i], so compute that directly rather
        # than building the (batch, n_out, n_in, n_out) intermediate.
        grad_W = torch.einsum("bn,bi->bni", model.layers[layer_index - 1], grad_W0)

        # Concatenate to form the Theta gradient (bias as the last row).
        grad_Theta = torch.cat((grad_W, grad_W0.unsqueeze(1)), dim=1)

        if last_pass:
            return grad_Theta

        # Propagate to the previous layer through the weights (bias row dropped).
        weights = model.weights[layer_index - 1][:-1].to(device=grad_W0.device)
        prev_layer_grad = torch.einsum("ij,bj->bi", weights, grad_W0)
        return (grad_Theta, prev_layer_grad)

    @staticmethod
    def update_weights(weights, gradients, learning_rate : float):
        """
        Apply one gradient-descent step, replacing the entries of ``weights``.

        Each gradient is moved to the device of its weight matrix first.
        """
        for i, gradient in enumerate(gradients[:len(weights)]):
            weights[i] = weights[i] - learning_rate * gradient.to(device=weights[i].device)

    @staticmethod
    def calc_acc(prediction, actual):
        """
        Fraction of rows whose arg-max prediction matches the one-hot target.

        Returns NaN when there are no rows.
        """
        accuracy = torch.sum(torch.argmax(prediction, 1) == torch.argmax(actual, 1)) / actual.shape[0]
        return accuracy.item()

    @staticmethod
    def calc_digit_errors(prediction, actual, total_digits):
        """
        Error rate for each digit, grouped by the *predicted* digit.

        Entry ``y`` is ``1 - accuracy`` over the samples predicted as ``y``.
        It is NaN when no sample was predicted as ``y``.
        """
        predicted_digits = torch.argmax(prediction, 1)
        digit_errors = []
        for y in range(total_digits):
            rows = torch.nonzero(predicted_digits == y, as_tuple=False)[:, 0]
            digit_errors.append(1 - utilsNN.calc_acc(prediction[rows], actual[rows]))

        return digit_errors

    @staticmethod
    def set_torch_device(forced_device="none"):
        """
        Choose the compute device and make it torch's default device.

        CUDA is used when available unless ``forced_device == "cpu"``.
        Requesting ``"cuda"`` on a machine without CUDA falls back to the CPU.

        Returns:
            The selected ``torch.device``.
        """
        if forced_device == "cpu" or not torch.cuda.is_available():
            device = torch.device('cpu')
        else:
            device = torch.device('cuda')
        torch.set_default_device(device)
        return device
    
    # @staticmethod
    # def test(dir : str, anno : str, model = None):



class DigitClassifierModelDNN:
    """
    Fully connected classifier with ReLU hidden layers and a softmax output.

    ``self.weights[k]`` maps layer ``k`` to layer ``k + 1`` and has shape
    ``(nodes_in + 1, nodes_out)`` with the bias as its last row.
    ``self.layers`` caches the activations of the most recent forward pass
    (index 0 is the input, the last index is the softmax output).
    """

    def __init__(self, input_dim: int, hidden_layer_num: int, hidden_node_num: int, output_dim: int):
        """
        Args:
            input_dim: Number of input features.
            hidden_layer_num: Number of hidden layers.
            hidden_node_num: Number of nodes in every hidden layer.
            output_dim: Number of output classes.
        """
        self.output_dim = output_dim
        # Placeholders for the input, each hidden layer and the output.
        self.layers = [0]*(2 + hidden_layer_num)

        # Input -> first hidden layer (bias included as the last row).
        self.weights = [utilsNN.init_weight(input_dim, hidden_node_num)]

        # Hidden -> hidden for the remaining hidden layers.
        for _ in range(2, hidden_layer_num + 1):
            self.weights.append(utilsNN.init_weight(hidden_node_num, hidden_node_num))

        # Last hidden layer -> output.
        self.weights.append(utilsNN.init_weight(hidden_node_num, output_dim))

    def forward(self, input):
        """
        Run a forward pass and cache every layer's activations.

        Args:
            input: 2-D tensor with one sample per row.

        Returns:
            The softmax output, one row of class probabilities per sample.
        """
        L = len(self.layers) - 2
        self.layers[0] = input

        # Hidden layers use ReLU.
        for n in range(1, L + 1):
            self.layers[n] = utilsNN.forward_step(self.layers[n-1], self.weights[n-1], utilsNN.ReLU)

        # Output layer uses softmax.
        self.layers[L+1] = utilsNN.forward_step(self.layers[L], self.weights[L], utilsNN.softmax)

        return self.layers[L+1]

    def backward(self, output_grad):
        """
        Back-propagate ``output_grad`` through every layer.

        Must be called after ``forward``, since it reads the cached
        activations in ``self.layers``. The layers are walked from the output
        down to the first hidden layer, so the method works for any number of
        hidden layers rather than exactly two.

        Args:
            output_grad: Gradient of the loss with respect to the softmax
                output, one row per sample.

        Returns:
            A tuple of batch-averaged weight gradients, ordered like
            ``self.weights``.
        """
        output_index = len(self.layers) - 1
        layer_grad = output_grad
        averaged = []

        for layer_index in range(output_index, 0, -1):
            activation = "softmax" if layer_index == output_index else "ReLU"
            if layer_index == 1:
                # Nothing below the first hidden layer needs a gradient.
                grad_Theta = utilsNN.backward_step(
                    model = self,
                    layer_index = layer_index,
                    layer_grad = layer_grad,
                    activation = activation,
                    last_pass = True
                )
            else:
                grad_Theta, layer_grad = utilsNN.backward_step(
                    model = self,
                    layer_index = layer_index,
                    layer_grad = layer_grad,
                    activation = activation,
                    last_pass = False
                )
            # Average the per-sample gradients over the batch.
            averaged.append(torch.mean(grad_Theta, dim=0))

        # Gradients were collected output-first; return them input-first.
        averaged.reverse()
        return tuple(averaged)

    def pred(self, input, actual_out):
        """
        Predict on ``input`` and score the result against ``actual_out``.

        Returns:
            ``(accuracy, digit_errors, loss)``.
        """
        y_hat = self.forward(input)

        accuracy = utilsNN.calc_acc(prediction=y_hat, actual=actual_out)
        digit_errors = utilsNN.calc_digit_errors(prediction=y_hat, actual=actual_out, total_digits=10)
        loss = utilsNN.loss_NLL(predicted=y_hat, actual=actual_out)

        return (accuracy, digit_errors, loss)

    def save_weights(self, fname : str):
        """
        Pickle the weights to ``fname`` as ``[W1, b1, W2, b2, ...]``.

        Each weight matrix is split into its weight rows and its bias row,
        both stored as NumPy arrays (pickle protocol 2).
        """
        Theta = []
        for weight in self.weights:
            as_numpy = weight.detach().cpu().numpy()
            Theta.append(as_numpy[:-1])
            Theta.append(as_numpy[-1])

        # The context manager closes the file even if pickling fails.
        with open(fname, "wb") as filehandler:
            pickle.dump(Theta, filehandler, protocol=2)

    def load_weights(self, fname : str):
        """
        Load weights written by ``save_weights`` into ``self.weights``.

        Returns:
            ``self.weights`` after the update.
        """
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # parameter files from a trusted source.
        with open(fname, "rb") as filehandler:
            Theta = pickle.load(filehandler)

        # Entries alternate W, b; re-attach each bias as the last row.
        for layer, (W, W_0) in enumerate(zip(Theta[0::2], Theta[1::2])):
            weight = np.concatenate((W, np.expand_dims(W_0, axis=0)), axis=0)
            self.weights[layer] = torch.as_tensor(weight)

        return self.weights

    def change_weights_device(self, device):
        """Move every weight matrix to ``device``."""
        self.weights = [weight.to(device = device) for weight in self.weights]
                

# HELPER FUNCTIONS
def PA2_get_data(data_dir, anno_csv, batch_num):
    """
    Load a dataset and return it as randomly shuffled mini-batches.

    The default torch device is switched to the CPU first, so the batches are
    built in host memory.

    Returns:
        ``(batched_data, batched_labels_enc)`` as produced by
        ``My_dataset.get_batches``.
    """
    utilsNN.set_torch_device('cpu')
    loader = My_dataset(data_dir=data_dir, anno_csv=anno_csv)
    return loader.get_batches(batch_num=batch_num)

# Main Functions
def _mean_digit_errors(per_batch_errors):
    """Average per-digit errors over batches, ignoring NaN (digit never predicted)."""
    errors = np.asarray(per_batch_errors, dtype=float)
    valid = ~np.isnan(errors)
    counts = valid.sum(axis=0)
    totals = np.where(valid, errors, 0.0).sum(axis=0)
    means = np.full(errors.shape[1], np.nan)
    np.divide(totals, counts, out=means, where=counts > 0)
    return means.tolist()


def PA2_train(train_batched_X, train_batched_Y, test_batched_X, test_batched_Y):
    """
    Train the two-hidden-layer digit classifier with mini-batch gradient descent.

    For every epoch the training loss, accuracy and per-digit errors are the
    mean over all mini-batches of that epoch. Test statistics are computed on
    the CPU from ``test_batched_X[0]`` / ``test_batched_Y[0]`` only. The final
    weights are saved to ``nn_parameters.txt``.

    Args:
        train_batched_X: Training inputs; the first dimension indexes batches.
        train_batched_Y: One-hot training labels, batched like the inputs.
        test_batched_X: Batched test inputs (only batch 0 is used).
        test_batched_Y: Batched one-hot test labels (only batch 0 is used).

    Returns:
        A dict with ``'training'`` and ``'testing'`` entries (each holding
        ``'losses'``, ``'accuracies'`` and ``'digit errors'``) plus
        ``'epochs'``. ``'digit errors'`` has one row per digit and one column
        per epoch.

    Raises:
        ValueError: If there are no training batches.
        Exception: If a training loss becomes NaN.
    """
    # Pick GPU when available, CPU otherwise, once for the whole run.
    device = utilsNN.set_torch_device()

    # Learning rate
    learning_rate = 0.13

    # Initialize the model
    model = DigitClassifierModelDNN(
        input_dim = train_batched_X.shape[-1],
        hidden_layer_num = 2,
        hidden_node_num = 100,
        output_dim = 10
    )
    # Keep the weights on the compute device from the start so the first
    # epoch does not copy them across devices on every batch.
    model.change_weights_device(device)

    total_batch = len(train_batched_X)
    if total_batch == 0:
        raise ValueError("PA2_train needs at least one training batch")

    # Per-epoch statistics
    train_losses = []
    train_accuracies = []
    train_errors_digits = []
    test_losses = []
    test_accuracies = []
    test_errors_digits = []

    # You can set up your own maximum epoch. You may need 5 or 10 epochs to
    # have a correct model.
    my_max_epoch = 10
    epochs = np.arange(0, my_max_epoch)
    for epoch in epochs:
        print(f"----Epoch {epoch:>d}----")

        epoch_losses = []
        epoch_accuracies = []
        epoch_digit_errors = []

        for b in range(total_batch):
            train_X_batch = train_batched_X[b].to(device = device)
            train_Y_batch = train_batched_Y[b].to(device = device)

            # Perform forward propagation
            forward_results = model.forward(train_X_batch)

            # Record this batch's statistics (measured before the update).
            batch_loss = utilsNN.loss_NLL(predicted=forward_results, actual=train_Y_batch)
            if math.isnan(batch_loss):
                # Fail fast instead of training on NaN weights.
                raise Exception("ERROR: NAN DETECTED")
            epoch_losses.append(batch_loss)
            epoch_accuracies.append(utilsNN.calc_acc(forward_results, train_Y_batch))
            epoch_digit_errors.append(utilsNN.calc_digit_errors(forward_results, train_Y_batch, 10))

            # Calculate the output gradient
            output_grad = utilsNN.calc_output_grad(predicted=forward_results, actual=train_Y_batch)

            # Propagate backwards and take one gradient-descent step
            gradients = model.backward(output_grad=output_grad)
            utilsNN.update_weights(model.weights, gradients, learning_rate)

        # The epoch's training statistics are the mean over its mini-batches.
        batch_train_loss = float(np.mean(epoch_losses))
        batch_train_accuracy = float(np.mean(epoch_accuracies))
        batch_train_digit_errors = _mean_digit_errors(epoch_digit_errors)

        # Testing stats run on the CPU, so move the weights there and back.
        cpu_device = utilsNN.set_torch_device('cpu')
        model.change_weights_device(cpu_device)
        batch_test_accuracy, batch_test_digit_errors, batch_test_loss = model.pred(test_batched_X[0], test_batched_Y[0])
        device = utilsNN.set_torch_device()
        model.change_weights_device(device)

        # Append the data
        ## Train
        train_losses.append(batch_train_loss)
        train_accuracies.append(batch_train_accuracy)
        train_errors_digits.append(batch_train_digit_errors)
        ## Test
        test_losses.append(batch_test_loss)
        test_accuracies.append(batch_test_accuracy)
        test_errors_digits.append(batch_test_digit_errors)

        print(f"Epoch Training | Loss: {batch_train_loss:.4f} | Accuracy: {batch_train_accuracy:.4f}")

    stats = {
        'training' : {
            'losses' : train_losses,
            'accuracies' : train_accuracies,
            'digit errors' : np.transpose(np.array(train_errors_digits))
        },
        'testing' : {
            'losses' : test_losses,
            'accuracies' : test_accuracies,
            'digit errors' : np.transpose(np.array(test_errors_digits))
        },
        'epochs' : epochs
    }

    # Save the final weight matrices
    model.save_weights("nn_parameters.txt")

    return stats


def PA2_test(model = None):
    """
    Evaluate saved weights on the test set and plot a confusion matrix.

    The whole test set is loaded as a single batch on the CPU. Weights are
    always (re)loaded from ``nn_parameters.txt``, including when a model is
    passed in.

    Args:
        model: Optional model to load the weights into. When omitted, a
            two-hidden-layer model with 100 nodes per hidden layer is built.
    """
    # Specifying the testing directory and label files
    test_dir = './'
    test_anno_file = './data_prog2Spring24/labels/test_anno.csv'

    # Evaluation runs on the CPU.
    torch.set_default_device('cpu')

    # Get the testing data as one batch
    data, Y_encoded = PA2_get_data(test_dir, test_anno_file, 1)

    if model is None:
        model = DigitClassifierModelDNN(
            input_dim = data.shape[-1],
            hidden_layer_num = 2,
            hidden_node_num = 100,
            output_dim = 10
        )

    # Load the weights into the model
    model.load_weights("nn_parameters.txt")

    model.change_weights_device(torch.device('cpu'))

    # Make the prediction and get the accuracies
    accuracy, digit_errors, loss = model.pred(data[0], Y_encoded[0])

    # model.layers[-1] holds the softmax output of the prediction above.
    create_conf_mat(model.layers[-1], Y_encoded[0])

    print(digit_errors)

    print("Test accuracy:", accuracy)

def line_plot(xvals, xlab : str, ylab: str, yvals0, ylab0 : str, dir : str, yvals1 = None, ylab1 : str = None):
    """
    Plot one or two series against ``xvals`` and save the figure as a PNG.

    Args:
        xvals: Shared x values.
        xlab: X-axis label (also used in the title and file name).
        ylab: Y-axis label (also used in the title and file name).
        yvals0: First series.
        ylab0: Legend label of the first series (used only with two series).
        dir: Output directory; created if it does not exist.
        yvals1: Optional second series.
        ylab1: Legend label of the second series.
    """
    plt.figure()

    # Compare against None explicitly: the truth value of a multi-element
    # NumPy array is ambiguous and raises.
    if yvals1 is not None:
        plt.plot(xvals, yvals0, label = ylab0)
        plt.plot(xvals, yvals1, label = ylab1)
        plt.legend()
    else:
        plt.plot(xvals, yvals0)

    plt.title(f"{ylab} vs {xlab}")
    plt.xlabel(xlab)
    plt.ylabel(ylab)

    # savefig does not create missing directories.
    if dir:
        os.makedirs(dir, exist_ok=True)
    file_path = os.path.join(dir, f"{ylab.replace(' ', '_')}_vs_{xlab.replace(' ', '_')}_plot.png")
    plt.savefig(file_path)

def plot_digit_errors(group : str, xvals, xlab, errors, dir : str):
    """
    Plot one error curve per digit and save the figure as a PNG.

    Args:
        group: Name of the data split, e.g. ``"Training"``; used in the title
            and the file name.
        xvals: Shared x values.
        xlab: X-axis label.
        errors: Sequence with one row of y values per digit. Every row is
            plotted and labelled with its index.
        dir: Output directory; created if it does not exist.
    """
    ylab = 'Errors'

    plt.figure()
    # Iterate over the rows themselves so the digit count is not hard-coded.
    for digit, digit_errors in enumerate(errors):
        plt.plot(xvals, digit_errors, label = str(digit))

    plt.legend(loc='upper right')

    plt.title(f"{group} Digit {ylab} vs {xlab}")
    plt.xlabel(xlab)
    plt.ylabel(ylab)

    # savefig does not create missing directories.
    if dir:
        os.makedirs(dir, exist_ok=True)
    file_path = os.path.join(dir, f"{group}_digit_{ylab.replace(' ', '_')}_vs_{xlab.replace(' ', '_')}_plot.png")
    plt.savefig(file_path)

def performance_plots(stats : dict):
    """
    Print the training losses and write the four summary plots to ``./plots``.

    Produces, in order: loss vs epochs, accuracy vs epochs (each with a
    training and a testing curve), then the per-digit error plots for the
    training and the testing split.

    Args:
        stats: Dict with ``'training'``, ``'testing'`` and ``'epochs'``
            entries, as returned by ``PA2_train``.
    """
    plot_dir = './plots'
    print(stats['training']['losses'])

    epochs = stats['epochs']
    training = stats['training']
    testing = stats['testing']

    # (stats key, y-axis label, legend suffix) for the two paired line plots.
    paired_metrics = (
        ('losses', 'Average Loss', 'Losses'),
        ('accuracies', 'Average Accuracy', 'Accuracies'),
    )
    for key, axis_label, suffix in paired_metrics:
        line_plot(
            xvals=epochs,
            xlab='Epochs',
            ylab=axis_label,
            yvals0=training[key],
            ylab0=f'Training {suffix}',
            yvals1=testing[key],
            ylab1=f'Testing {suffix}',
            dir=plot_dir,
        )

    # One per-digit error plot for each data split.
    for group_name, split_stats in (("Training", training), ("Testing", testing)):
        plot_digit_errors(
            group=group_name,
            xvals=epochs,
            xlab='Epochs',
            errors=split_stats['digit errors'],
            dir=plot_dir,
        )

def one_hot_decode(tensor):
    """
    Convert one-hot (or probability) rows to class indices.

    Args:
        tensor: 2-D tensor with one row per sample.

    Returns:
        A NumPy array holding the arg-max column of every row.
    """
    # detach() and cpu() make the conversion work for tensors that track
    # gradients or live on a GPU; .numpy() alone raises for those.
    return torch.argmax(tensor, dim=1).detach().cpu().numpy()

def create_conf_mat(prediction, actual):
    """
    Plot the 10-class confusion matrix and save it under ``./plots``.

    Args:
        prediction: Tensor of predicted class scores, one row per sample.
        actual: Tensor of one-hot ground-truth labels, one row per sample.
    """
    my_pred = one_hot_decode(prediction)
    my_act = one_hot_decode(actual)

    # Fix the label set so the matrix is always 10x10 and matches the display
    # labels, even if some digit is absent from both the predictions and the
    # ground truth.
    digit_labels = list(range(10))
    confusion_matrix = metrics.confusion_matrix(
        y_true=my_act, y_pred=my_pred, labels=digit_labels
    )
    cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix,
                                                display_labels=digit_labels)

    # plot() creates its own figure, so no plt.figure() call is needed first
    # (that would only leave an empty figure behind).
    cm_display.plot()

    # savefig does not create missing directories.
    os.makedirs('./plots', exist_ok=True)
    cm_display.figure_.savefig('./plots/class_confusion_matrix.jpg')

def print_tensor(tensor, tensor_name):
    """
    Print ``tensor`` with its name and shape (debug helper).

    The summarisation threshold is set to 1000 elements while printing.
    Afterwards torch's print options are reset to the ``"default"`` profile,
    which also discards any options the caller had customised.
    """
    torch.set_printoptions(threshold=1_000)
    try:
        print(f"{tensor_name} with shape {tensor.shape}: {tensor}")
    finally:
        # Reset the options even if formatting the tensor raises.
        torch.set_printoptions(profile="default")


## Main Body

### Define constants

In [None]:
# if __name__ == "__main__":

# Sample file names in both annotation files are resolved relative to the
# current working directory.
train_dir = test_dir = './'

# Annotation CSVs (file name, label) for the training and the testing split.
train_anno_file = './data_prog2Spring24/labels/train_anno.csv'
test_anno_file = './data_prog2Spring24/labels/test_anno.csv'

# Number of mini-batches the data is split into.
total_batch = 1000

### Get the batched data

In [None]:
# Training data, shuffled and split into `total_batch` mini-batches.
train_batched_X, train_batched_Y = PA2_get_data(
    data_dir=train_dir,
    anno_csv=train_anno_file,
    batch_num=total_batch
)

# Held-out data for the per-epoch evaluation. It must come from the *test*
# annotations, not the training ones. PA2_train only evaluates on batch 0,
# so the test set is loaded as a single batch to score all of it.
test_batched_X, test_batched_Y = PA2_get_data(
    data_dir=test_dir,
    anno_csv=test_anno_file,
    batch_num=1
)

### Train off of the batched data

In [None]:
# Train the model; `stats` holds the per-epoch losses, accuracies and
# per-digit errors for both the training and the testing data.
stats = PA2_train(train_batched_X, train_batched_Y, test_batched_X, test_batched_Y)

In [None]:
# Print the training losses and write the summary plots for this run.
performance_plots(stats)

In [None]:
# Evaluate the weights saved in nn_parameters.txt on the test set.
PA2_test()