# **Deep Learning Course**

## **Loss Functions and Multilayer Perceptrons (MLP)**

---

### **Student Information:**

- **Name:** *Zahra Maleki*
- **Student Number:** *400110009*

---

### **Assignment Overview**

In this notebook, we will explore various loss functions used in neural networks, with a specific focus on their role in training **Multilayer Perceptrons (MLPs)**. By the end of this notebook, you will have a deeper understanding of:
- Types of loss functions
- How loss functions affect the training process
- The relationship between loss functions and model optimization in MLPs

---

### **Table of Contents**

1. Introduction to Loss Functions
2. Types of Loss Functions
3. Multilayer Perceptrons (MLP)
4. Implementing Loss Functions in MLP
5. Conclusion

---



# 1.Introduction to Loss Functions 

In deep learning, **loss functions** play a crucial role in training models by quantifying the difference between the predicted outputs and the actual targets. Selecting the appropriate loss function is essential for the success of your model. In this assay, we will explore various loss functions available in PyTorch, understand their theoretical backgrounds, and provide you with a scaffolded class to experiment with these loss functions.

Before begining, let's train a simle MLP model using the **L1Loss** function. We'll return to this model later to experiment with different loss functions. We'll start by importing the necessary libraries and defining the model architecture.

First things first, let's talk about **L1Loss**.

### 1. L1Loss (`torch.nn.L1Loss`)
- **Description:** Also known as Mean Absolute Error (MAE), L1Loss computes the average absolute difference between the predicted values and the target values.
- **Use Case:** Suitable for regression tasks where robustness to outliers is desired.

Here is the mathematical formulation of L1Loss:
\begin{equation}
\text{L1Loss} = \frac{1}{n} \sum_{i=1}^{n} |y_{\text{pred}_i} - y_{\text{true}_i}|
\end{equation}

Let's implement a simple MLP model using the L1Loss function.

In [111]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from tqdm import tqdm 
from torch.nn import NLLLoss, LogSoftmax
import torch.nn.functional as F
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from io import BytesIO
from PIL import Image
from sklearn.neural_network import MLPClassifier
# Don't be courious about Adam, it's just a fancy name for a fancy optimization algorithm

Here, we'll define a class called `SimpleMLP` that inherits from `nn.Module`. This class can have multiple layers, and we'll use the `nn.Sequential` module to define the layers of the model. The model will have the following architecture:

In [84]:
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers=1, last_layer_activation_fn=nn.ReLU):
        super(SimpleMLP, self).__init__()
        # TODO: Define the layers of the MLP

        layers = []
        layers.append(nn.Linear(input_dim, hidden_dim))  
        layers.append(nn.ReLU()) 

        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim)) 
            layers.append(nn.ReLU()) 
    
        layers.append(nn.Linear(hidden_dim, output_dim))
        
        if last_layer_activation_fn is not None:
            layers.append(last_layer_activation_fn())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # TODO: Define the forward pass of the MLP
        return self.model(x)

Now, let's define a class called `SimpleMLP_Loss` that has the following architecture:

In [85]:
class SimpleMLPTrainer:
    def __init__(self, model, criterion, optimizer):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def train(self, train_loader, num_epochs):
        #TODO: Implement the training loop
        #Note: You should also print the training loss at each epoch, use tqdm for progress bar
        #Note: You should return the training loss at each epoch

        training_losses = []

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            self.model.train() 
            for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                
                loss = self.criterion(outputs, targets)
                
                loss.backward()
                self.optimizer.step()
                
                epoch_loss += loss.item() * inputs.size(0) 

            epoch_loss /= len(train_loader.dataset)
            training_losses.append(epoch_loss)
            if (epoch % 10 ==0):
                print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}")

        return training_losses

    def evaluate(self, val_loader):
        #TODO: Implement the evaluation loop
        #Note: You should return the validation loss and accuracy
        self.model.eval()  # Set model to evaluation mode
        val_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        with torch.no_grad(): 
            for inputs, targets in val_loader:
              
                outputs = self.model(inputs)
         
                loss = self.criterion(outputs, targets)
                val_loss += loss.item() * inputs.size(0)

                predictions = (torch.sigmoid(outputs) >= 0.5).float() 
                correct_predictions += (predictions == targets).sum().item()
                total_predictions += targets.size(0)

        val_loss /= len(val_loader.dataset)
        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

        print(f"Validation Loss: {val_loss:.4f}, Accuracy: {accuracy * 100:.2f}%")
        
        return val_loss, accuracy


Next, lets test our model using the L1Loss function. You'll use <span style="color:red">*Titanic Dataset*</span> to train the model.


In [41]:
# Load dataset
train_url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
data = pd.read_csv(train_url)

# Preprocessing (simple example)
data = data[['Pclass', 'Sex', 'Age', 'Fare', 'Survived']].dropna()
data['Sex'] = data['Sex'].map({'male': 0, 'female': 1})

# TODO: Convert the data to PyTorch tensors and create a DataLoader
# TODO: Split the data into training and validation sets
# TODO: Define the model, criterion, and optimizer

X = data[['Pclass', 'Sex', 'Age', 'Fare']].values
y = data['Survived'].values

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(1) 

X_train, X_val, y_train, y_val = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=42)

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

input_dim = X_train.shape[1]
hidden_dim = 16
output_dim = 1  

model = SimpleMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)
criterion = nn.L1Loss() 
optimizer = optim.Adam(model.parameters(), lr=0.001)

trainer = SimpleMLPTrainer(model, criterion, optimizer)

num_epochs = 30
training_losses = trainer.train(train_loader, num_epochs=num_epochs)

val_loss, val_accuracy = trainer.evaluate(val_loader)

print(f"Final Validation Loss: {val_loss:.4f}")
print(f"Final Validation Accuracy: {val_accuracy * 100:.2f}%")

Epoch 1/30: 100%|██████████| 18/18 [00:00<00:00, 764.71it/s]


Epoch [1/30], Training Loss: 1.3003


Epoch 2/30: 100%|██████████| 18/18 [00:00<00:00, 675.05it/s]
Epoch 3/30: 100%|██████████| 18/18 [00:00<00:00, 1081.94it/s]
Epoch 4/30: 100%|██████████| 18/18 [00:00<00:00, 568.80it/s]
Epoch 5/30: 100%|██████████| 18/18 [00:00<00:00, 868.50it/s]
Epoch 6/30: 100%|██████████| 18/18 [00:00<00:00, 375.86it/s]
Epoch 7/30: 100%|██████████| 18/18 [00:00<00:00, 1094.63it/s]
Epoch 8/30: 100%|██████████| 18/18 [00:00<00:00, 1124.16it/s]
Epoch 9/30: 100%|██████████| 18/18 [00:00<00:00, 728.72it/s]
Epoch 10/30: 100%|██████████| 18/18 [00:00<00:00, 570.24it/s]
Epoch 11/30: 100%|██████████| 18/18 [00:00<00:00, 445.17it/s]


Epoch [11/30], Training Loss: 0.4123


Epoch 12/30: 100%|██████████| 18/18 [00:00<00:00, 445.25it/s]
Epoch 13/30: 100%|██████████| 18/18 [00:00<00:00, 440.58it/s]
Epoch 14/30: 100%|██████████| 18/18 [00:00<00:00, 749.82it/s]
Epoch 15/30: 100%|██████████| 18/18 [00:00<00:00, 371.51it/s]
Epoch 16/30: 100%|██████████| 18/18 [00:00<00:00, 547.29it/s]
Epoch 17/30: 100%|██████████| 18/18 [00:00<00:00, 559.75it/s]
Epoch 18/30: 100%|██████████| 18/18 [00:00<00:00, 969.18it/s]
Epoch 19/30: 100%|██████████| 18/18 [00:00<00:00, 736.41it/s]
Epoch 20/30: 100%|██████████| 18/18 [00:00<00:00, 738.90it/s]
Epoch 21/30: 100%|██████████| 18/18 [00:00<00:00, 750.01it/s]


Epoch [21/30], Training Loss: 0.4102


Epoch 22/30: 100%|██████████| 18/18 [00:00<00:00, 1125.17it/s]
Epoch 23/30: 100%|██████████| 18/18 [00:00<00:00, 748.68it/s]
Epoch 24/30: 100%|██████████| 18/18 [00:00<00:00, 1122.76it/s]
Epoch 25/30: 100%|██████████| 18/18 [00:00<00:00, 1157.14it/s]
Epoch 26/30: 100%|██████████| 18/18 [00:00<00:00, 764.97it/s]
Epoch 27/30: 100%|██████████| 18/18 [00:00<00:00, 1122.79it/s]
Epoch 28/30: 100%|██████████| 18/18 [00:00<00:00, 748.60it/s]
Epoch 29/30: 100%|██████████| 18/18 [00:00<00:00, 1104.44it/s]
Epoch 30/30: 100%|██████████| 18/18 [00:00<00:00, 1126.54it/s]

Validation Loss: 0.3916, Accuracy: 39.16%
Final Validation Loss: 0.3916
Final Validation Accuracy: 39.16%





<div style="text-align: center;"> <span style="color:red; font-size: 26px; font-weight: bold;">Let's train!</span> </div>

In [42]:
from torch.nn import L1Loss

# TODO: Train the model

# TODO: Evaluate the model

num_epochs = 300
batch_size = 32
learning_rate = 0.001


input_dim = X_train.shape[1]
hidden_dim = 16
output_dim = 1  

model = SimpleMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, last_layer_activation_fn=None)
criterion = nn.BCEWithLogitsLoss() 
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

trainer = SimpleMLPTrainer(model, criterion, optimizer)

training_losses = trainer.train(train_loader, num_epochs=num_epochs)

val_loss, val_accuracy = trainer.evaluate(val_loader)

print(f"Final Validation Loss: {val_loss:.4f}")
print(f"Final Validation Accuracy: {val_accuracy * 100:.2f}%")


Epoch 1/300: 100%|██████████| 18/18 [00:00<00:00, 662.76it/s]


Epoch [1/300], Training Loss: 1.2978


Epoch 2/300: 100%|██████████| 18/18 [00:00<00:00, 1087.00it/s]
Epoch 3/300: 100%|██████████| 18/18 [00:00<00:00, 748.40it/s]
Epoch 4/300: 100%|██████████| 18/18 [00:00<00:00, 720.73it/s]
Epoch 5/300: 100%|██████████| 18/18 [00:00<00:00, 1115.36it/s]
Epoch 6/300: 100%|██████████| 18/18 [00:00<00:00, 1099.26it/s]
Epoch 7/300: 100%|██████████| 18/18 [00:00<00:00, 1122.22it/s]
Epoch 8/300: 100%|██████████| 18/18 [00:00<00:00, 813.54it/s]
Epoch 9/300: 100%|██████████| 18/18 [00:00<00:00, 1046.25it/s]
Epoch 10/300: 100%|██████████| 18/18 [00:00<00:00, 1162.34it/s]
Epoch 11/300: 100%|██████████| 18/18 [00:00<00:00, 1089.15it/s]


Epoch [11/300], Training Loss: 0.5713


Epoch 12/300: 100%|██████████| 18/18 [00:00<00:00, 1051.35it/s]
Epoch 13/300: 100%|██████████| 18/18 [00:00<00:00, 1091.43it/s]
Epoch 14/300: 100%|██████████| 18/18 [00:00<00:00, 1088.10it/s]
Epoch 15/300: 100%|██████████| 18/18 [00:00<00:00, 739.90it/s]
Epoch 16/300: 100%|██████████| 18/18 [00:00<00:00, 736.36it/s]
Epoch 17/300: 100%|██████████| 18/18 [00:00<00:00, 714.77it/s]
Epoch 18/300: 100%|██████████| 18/18 [00:00<00:00, 716.07it/s]
Epoch 19/300: 100%|██████████| 18/18 [00:00<00:00, 764.79it/s]
Epoch 20/300: 100%|██████████| 18/18 [00:00<00:00, 744.68it/s]
Epoch 21/300: 100%|██████████| 18/18 [00:00<00:00, 883.92it/s]


Epoch [21/300], Training Loss: 0.5414


Epoch 22/300: 100%|██████████| 18/18 [00:00<00:00, 1123.26it/s]
Epoch 23/300: 100%|██████████| 18/18 [00:00<00:00, 551.28it/s]
Epoch 24/300: 100%|██████████| 18/18 [00:00<00:00, 1121.11it/s]
Epoch 25/300: 100%|██████████| 18/18 [00:00<00:00, 1132.03it/s]
Epoch 26/300: 100%|██████████| 18/18 [00:00<00:00, 731.52it/s]
Epoch 27/300: 100%|██████████| 18/18 [00:00<00:00, 763.05it/s]
Epoch 28/300: 100%|██████████| 18/18 [00:00<00:00, 1098.15it/s]
Epoch 29/300: 100%|██████████| 18/18 [00:00<00:00, 746.22it/s]
Epoch 30/300: 100%|██████████| 18/18 [00:00<00:00, 743.26it/s]
Epoch 31/300: 100%|██████████| 18/18 [00:00<00:00, 562.54it/s]


Epoch [31/300], Training Loss: 0.5293


Epoch 32/300: 100%|██████████| 18/18 [00:00<00:00, 742.11it/s]
Epoch 33/300: 100%|██████████| 18/18 [00:00<00:00, 743.99it/s]
Epoch 34/300: 100%|██████████| 18/18 [00:00<00:00, 749.49it/s]
Epoch 35/300: 100%|██████████| 18/18 [00:00<00:00, 727.04it/s]
Epoch 36/300: 100%|██████████| 18/18 [00:00<00:00, 711.88it/s]
Epoch 37/300: 100%|██████████| 18/18 [00:00<00:00, 729.61it/s]
Epoch 38/300: 100%|██████████| 18/18 [00:00<00:00, 560.49it/s]
Epoch 39/300: 100%|██████████| 18/18 [00:00<00:00, 745.93it/s]
Epoch 40/300: 100%|██████████| 18/18 [00:00<00:00, 544.75it/s]
Epoch 41/300: 100%|██████████| 18/18 [00:00<00:00, 747.86it/s]


Epoch [41/300], Training Loss: 0.5043


Epoch 42/300: 100%|██████████| 18/18 [00:00<00:00, 339.64it/s]
Epoch 43/300: 100%|██████████| 18/18 [00:00<00:00, 735.77it/s]
Epoch 44/300: 100%|██████████| 18/18 [00:00<00:00, 734.42it/s]
Epoch 45/300: 100%|██████████| 18/18 [00:00<00:00, 747.63it/s]
Epoch 46/300: 100%|██████████| 18/18 [00:00<00:00, 738.13it/s]
Epoch 47/300: 100%|██████████| 18/18 [00:00<00:00, 749.35it/s]
Epoch 48/300: 100%|██████████| 18/18 [00:00<00:00, 1092.09it/s]
Epoch 49/300: 100%|██████████| 18/18 [00:00<00:00, 739.48it/s]
Epoch 50/300: 100%|██████████| 18/18 [00:00<00:00, 734.87it/s]
Epoch 51/300: 100%|██████████| 18/18 [00:00<00:00, 742.98it/s]


Epoch [51/300], Training Loss: 0.4811


Epoch 52/300: 100%|██████████| 18/18 [00:00<00:00, 775.61it/s]
Epoch 53/300: 100%|██████████| 18/18 [00:00<00:00, 1091.24it/s]
Epoch 54/300: 100%|██████████| 18/18 [00:00<00:00, 782.99it/s]
Epoch 55/300: 100%|██████████| 18/18 [00:00<00:00, 715.15it/s]
Epoch 56/300: 100%|██████████| 18/18 [00:00<00:00, 742.68it/s]
Epoch 57/300: 100%|██████████| 18/18 [00:00<00:00, 680.48it/s]
Epoch 58/300: 100%|██████████| 18/18 [00:00<00:00, 1115.62it/s]
Epoch 59/300: 100%|██████████| 18/18 [00:00<00:00, 1084.25it/s]
Epoch 60/300: 100%|██████████| 18/18 [00:00<00:00, 1121.09it/s]
Epoch 61/300: 100%|██████████| 18/18 [00:00<00:00, 627.79it/s]


Epoch [61/300], Training Loss: 0.4692


Epoch 62/300: 100%|██████████| 18/18 [00:00<00:00, 1124.11it/s]
Epoch 63/300: 100%|██████████| 18/18 [00:00<00:00, 1103.89it/s]
Epoch 64/300: 100%|██████████| 18/18 [00:00<00:00, 1102.97it/s]
Epoch 65/300: 100%|██████████| 18/18 [00:00<00:00, 734.10it/s]
Epoch 66/300: 100%|██████████| 18/18 [00:00<00:00, 1096.87it/s]
Epoch 67/300: 100%|██████████| 18/18 [00:00<00:00, 1149.02it/s]
Epoch 68/300: 100%|██████████| 18/18 [00:00<00:00, 1120.92it/s]
Epoch 69/300: 100%|██████████| 18/18 [00:00<00:00, 561.69it/s]
Epoch 70/300: 100%|██████████| 18/18 [00:00<00:00, 743.80it/s]
Epoch 71/300: 100%|██████████| 18/18 [00:00<00:00, 744.63it/s]


Epoch [71/300], Training Loss: 0.4643


Epoch 72/300: 100%|██████████| 18/18 [00:00<00:00, 441.18it/s]
Epoch 73/300: 100%|██████████| 18/18 [00:00<00:00, 740.76it/s]
Epoch 74/300: 100%|██████████| 18/18 [00:00<00:00, 1078.38it/s]
Epoch 75/300: 100%|██████████| 18/18 [00:00<00:00, 1128.97it/s]
Epoch 76/300: 100%|██████████| 18/18 [00:00<00:00, 1112.24it/s]
Epoch 77/300: 100%|██████████| 18/18 [00:00<00:00, 1065.79it/s]
Epoch 78/300: 100%|██████████| 18/18 [00:00<00:00, 739.50it/s]
Epoch 79/300: 100%|██████████| 18/18 [00:00<00:00, 1072.45it/s]
Epoch 80/300: 100%|██████████| 18/18 [00:00<00:00, 1092.73it/s]
Epoch 81/300: 100%|██████████| 18/18 [00:00<00:00, 1061.57it/s]


Epoch [81/300], Training Loss: 0.4588


Epoch 82/300: 100%|██████████| 18/18 [00:00<00:00, 976.69it/s]
Epoch 83/300: 100%|██████████| 18/18 [00:00<00:00, 830.80it/s]
Epoch 84/300: 100%|██████████| 18/18 [00:00<00:00, 704.38it/s]
Epoch 85/300: 100%|██████████| 18/18 [00:00<00:00, 1114.22it/s]
Epoch 86/300: 100%|██████████| 18/18 [00:00<00:00, 1098.03it/s]
Epoch 87/300: 100%|██████████| 18/18 [00:00<00:00, 885.28it/s]
Epoch 88/300: 100%|██████████| 18/18 [00:00<00:00, 737.12it/s]
Epoch 89/300: 100%|██████████| 18/18 [00:00<00:00, 1093.47it/s]
Epoch 90/300: 100%|██████████| 18/18 [00:00<00:00, 743.15it/s]
Epoch 91/300: 100%|██████████| 18/18 [00:00<00:00, 1088.49it/s]


Epoch [91/300], Training Loss: 0.4550


Epoch 92/300: 100%|██████████| 18/18 [00:00<00:00, 1057.65it/s]
Epoch 93/300: 100%|██████████| 18/18 [00:00<00:00, 1101.25it/s]
Epoch 94/300: 100%|██████████| 18/18 [00:00<00:00, 1080.98it/s]
Epoch 95/300: 100%|██████████| 18/18 [00:00<00:00, 1086.73it/s]
Epoch 96/300: 100%|██████████| 18/18 [00:00<00:00, 1087.59it/s]
Epoch 97/300: 100%|██████████| 18/18 [00:00<00:00, 1069.76it/s]
Epoch 98/300: 100%|██████████| 18/18 [00:00<00:00, 1103.05it/s]
Epoch 99/300: 100%|██████████| 18/18 [00:00<00:00, 1094.06it/s]
Epoch 100/300: 100%|██████████| 18/18 [00:00<00:00, 944.53it/s]
Epoch 101/300: 100%|██████████| 18/18 [00:00<00:00, 1233.74it/s]


Epoch [101/300], Training Loss: 0.4508


Epoch 102/300: 100%|██████████| 18/18 [00:00<00:00, 1095.06it/s]
Epoch 103/300: 100%|██████████| 18/18 [00:00<00:00, 738.35it/s]
Epoch 104/300: 100%|██████████| 18/18 [00:00<00:00, 1104.72it/s]
Epoch 105/300: 100%|██████████| 18/18 [00:00<00:00, 445.24it/s]
Epoch 106/300: 100%|██████████| 18/18 [00:00<00:00, 743.20it/s]
Epoch 107/300: 100%|██████████| 18/18 [00:00<00:00, 709.53it/s]
Epoch 108/300: 100%|██████████| 18/18 [00:00<00:00, 1471.26it/s]
Epoch 109/300: 100%|██████████| 18/18 [00:00<00:00, 1901.22it/s]
Epoch 110/300: 100%|██████████| 18/18 [00:00<00:00, 1095.33it/s]
Epoch 111/300: 100%|██████████| 18/18 [00:00<00:00, 1085.90it/s]


Epoch [111/300], Training Loss: 0.4528


Epoch 112/300: 100%|██████████| 18/18 [00:00<00:00, 2108.40it/s]
Epoch 113/300: 100%|██████████| 18/18 [00:00<00:00, 1110.73it/s]
Epoch 114/300: 100%|██████████| 18/18 [00:00<00:00, 1145.29it/s]
Epoch 115/300: 100%|██████████| 18/18 [00:00<00:00, 1081.56it/s]
Epoch 116/300: 100%|██████████| 18/18 [00:00<00:00, 1096.88it/s]
Epoch 117/300: 100%|██████████| 18/18 [00:00<00:00, 1067.03it/s]
Epoch 118/300: 100%|██████████| 18/18 [00:00<00:00, 1104.75it/s]
Epoch 119/300: 100%|██████████| 18/18 [00:00<00:00, 2241.35it/s]
Epoch 120/300: 100%|██████████| 18/18 [00:00<00:00, 1075.49it/s]
Epoch 121/300: 100%|██████████| 18/18 [00:00<00:00, 1114.68it/s]


Epoch [121/300], Training Loss: 0.4475


Epoch 122/300: 100%|██████████| 18/18 [00:00<00:00, 1074.07it/s]
Epoch 123/300: 100%|██████████| 18/18 [00:00<00:00, 757.65it/s]
Epoch 124/300: 100%|██████████| 18/18 [00:00<00:00, 1108.56it/s]
Epoch 125/300: 100%|██████████| 18/18 [00:00<00:00, 1148.72it/s]
Epoch 126/300: 100%|██████████| 18/18 [00:00<00:00, 718.95it/s]
Epoch 127/300: 100%|██████████| 18/18 [00:00<00:00, 564.61it/s]
Epoch 128/300: 100%|██████████| 18/18 [00:00<00:00, 736.59it/s]
Epoch 129/300: 100%|██████████| 18/18 [00:00<00:00, 728.11it/s]
Epoch 130/300: 100%|██████████| 18/18 [00:00<00:00, 746.33it/s]
Epoch 131/300: 100%|██████████| 18/18 [00:00<00:00, 1094.50it/s]


Epoch [131/300], Training Loss: 0.4462


Epoch 132/300: 100%|██████████| 18/18 [00:00<00:00, 735.20it/s]
Epoch 133/300: 100%|██████████| 18/18 [00:00<00:00, 1101.28it/s]
Epoch 134/300: 100%|██████████| 18/18 [00:00<00:00, 1168.76it/s]
Epoch 135/300: 100%|██████████| 18/18 [00:00<00:00, 726.89it/s]
Epoch 136/300: 100%|██████████| 18/18 [00:00<00:00, 1115.21it/s]
Epoch 137/300: 100%|██████████| 18/18 [00:00<00:00, 2184.85it/s]
Epoch 138/300: 100%|██████████| 18/18 [00:00<00:00, 447.70it/s]
Epoch 139/300: 100%|██████████| 18/18 [00:00<00:00, 739.36it/s]
Epoch 140/300: 100%|██████████| 18/18 [00:00<00:00, 1071.54it/s]
Epoch 141/300: 100%|██████████| 18/18 [00:00<00:00, 1112.20it/s]


Epoch [141/300], Training Loss: 0.4448


Epoch 142/300: 100%|██████████| 18/18 [00:00<00:00, 1117.21it/s]
Epoch 143/300: 100%|██████████| 18/18 [00:00<00:00, 1104.09it/s]
Epoch 144/300: 100%|██████████| 18/18 [00:00<00:00, 1106.79it/s]
Epoch 145/300: 100%|██████████| 18/18 [00:00<00:00, 1975.91it/s]
Epoch 146/300: 100%|██████████| 18/18 [00:00<00:00, 1118.48it/s]
Epoch 147/300: 100%|██████████| 18/18 [00:00<00:00, 1078.77it/s]
Epoch 148/300: 100%|██████████| 18/18 [00:00<00:00, 2064.46it/s]
Epoch 149/300: 100%|██████████| 18/18 [00:00<00:00, 1091.93it/s]
Epoch 150/300: 100%|██████████| 18/18 [00:00<00:00, 1086.47it/s]
Epoch 151/300: 100%|██████████| 18/18 [00:00<00:00, 1073.49it/s]


Epoch [151/300], Training Loss: 0.4432


Epoch 152/300: 100%|██████████| 18/18 [00:00<00:00, 1124.75it/s]
Epoch 153/300: 100%|██████████| 18/18 [00:00<00:00, 1086.78it/s]
Epoch 154/300: 100%|██████████| 18/18 [00:00<00:00, 1160.07it/s]
Epoch 155/300: 100%|██████████| 18/18 [00:00<00:00, 1120.42it/s]
Epoch 156/300: 100%|██████████| 18/18 [00:00<00:00, 685.41it/s]
Epoch 157/300: 100%|██████████| 18/18 [00:00<00:00, 1095.04it/s]
Epoch 158/300: 100%|██████████| 18/18 [00:00<00:00, 688.66it/s]
Epoch 159/300: 100%|██████████| 18/18 [00:00<00:00, 759.88it/s]
Epoch 160/300: 100%|██████████| 18/18 [00:00<00:00, 664.64it/s]
Epoch 161/300: 100%|██████████| 18/18 [00:00<00:00, 729.09it/s]


Epoch [161/300], Training Loss: 0.4405


Epoch 162/300: 100%|██████████| 18/18 [00:00<00:00, 749.03it/s]
Epoch 163/300: 100%|██████████| 18/18 [00:00<00:00, 735.87it/s]
Epoch 164/300: 100%|██████████| 18/18 [00:00<00:00, 749.91it/s]
Epoch 165/300: 100%|██████████| 18/18 [00:00<00:00, 759.65it/s]
Epoch 166/300: 100%|██████████| 18/18 [00:00<00:00, 806.19it/s]
Epoch 167/300: 100%|██████████| 18/18 [00:00<00:00, 1096.39it/s]
Epoch 168/300: 100%|██████████| 18/18 [00:00<00:00, 2366.77it/s]
Epoch 169/300: 100%|██████████| 18/18 [00:00<00:00, 562.92it/s]
Epoch 170/300: 100%|██████████| 18/18 [00:00<00:00, 1076.09it/s]
Epoch 171/300: 100%|██████████| 18/18 [00:00<00:00, 758.46it/s]


Epoch [171/300], Training Loss: 0.4420


Epoch 172/300: 100%|██████████| 18/18 [00:00<00:00, 546.91it/s]
Epoch 173/300: 100%|██████████| 18/18 [00:00<00:00, 1118.80it/s]
Epoch 174/300: 100%|██████████| 18/18 [00:00<00:00, 1092.88it/s]
Epoch 175/300: 100%|██████████| 18/18 [00:00<00:00, 1058.26it/s]
Epoch 176/300: 100%|██████████| 18/18 [00:00<00:00, 746.33it/s]
Epoch 177/300: 100%|██████████| 18/18 [00:00<00:00, 746.58it/s]
Epoch 178/300: 100%|██████████| 18/18 [00:00<00:00, 747.51it/s]
Epoch 179/300: 100%|██████████| 18/18 [00:00<00:00, 454.49it/s]
Epoch 180/300: 100%|██████████| 18/18 [00:00<00:00, 460.26it/s]
Epoch 181/300: 100%|██████████| 18/18 [00:00<00:00, 553.57it/s]


Epoch [181/300], Training Loss: 0.4414


Epoch 182/300: 100%|██████████| 18/18 [00:00<00:00, 538.30it/s]
Epoch 183/300: 100%|██████████| 18/18 [00:00<00:00, 749.14it/s]
Epoch 184/300: 100%|██████████| 18/18 [00:00<00:00, 1099.58it/s]
Epoch 185/300: 100%|██████████| 18/18 [00:00<00:00, 1080.43it/s]
Epoch 186/300: 100%|██████████| 18/18 [00:00<00:00, 1117.42it/s]
Epoch 187/300: 100%|██████████| 18/18 [00:00<00:00, 1076.18it/s]
Epoch 188/300: 100%|██████████| 18/18 [00:00<00:00, 1173.94it/s]
Epoch 189/300: 100%|██████████| 18/18 [00:00<00:00, 1106.25it/s]
Epoch 190/300: 100%|██████████| 18/18 [00:00<00:00, 1136.87it/s]
Epoch 191/300: 100%|██████████| 18/18 [00:00<00:00, 1115.70it/s]


Epoch [191/300], Training Loss: 0.4426


Epoch 192/300: 100%|██████████| 18/18 [00:00<00:00, 1148.83it/s]
Epoch 193/300: 100%|██████████| 18/18 [00:00<00:00, 1097.19it/s]
Epoch 194/300: 100%|██████████| 18/18 [00:00<00:00, 1103.01it/s]
Epoch 195/300: 100%|██████████| 18/18 [00:00<00:00, 1082.63it/s]
Epoch 196/300: 100%|██████████| 18/18 [00:00<00:00, 1095.96it/s]
Epoch 197/300: 100%|██████████| 18/18 [00:00<00:00, 1097.46it/s]
Epoch 198/300: 100%|██████████| 18/18 [00:00<00:00, 2060.58it/s]
Epoch 199/300: 100%|██████████| 18/18 [00:00<00:00, 1093.50it/s]
Epoch 200/300: 100%|██████████| 18/18 [00:00<00:00, 1094.91it/s]
Epoch 201/300: 100%|██████████| 18/18 [00:00<00:00, 1078.23it/s]


Epoch [201/300], Training Loss: 0.4435


Epoch 202/300: 100%|██████████| 18/18 [00:00<00:00, 1066.59it/s]
Epoch 203/300: 100%|██████████| 18/18 [00:00<00:00, 1086.83it/s]
Epoch 204/300: 100%|██████████| 18/18 [00:00<00:00, 1100.03it/s]
Epoch 205/300: 100%|██████████| 18/18 [00:00<00:00, 1091.63it/s]
Epoch 206/300: 100%|██████████| 18/18 [00:00<00:00, 1463.90it/s]
Epoch 207/300: 100%|██████████| 18/18 [00:00<00:00, 551.10it/s]
Epoch 208/300: 100%|██████████| 18/18 [00:00<00:00, 692.63it/s]
Epoch 209/300: 100%|██████████| 18/18 [00:00<00:00, 931.06it/s]
Epoch 210/300: 100%|██████████| 18/18 [00:00<00:00, 2157.32it/s]
Epoch 211/300: 100%|██████████| 18/18 [00:00<00:00, 2087.18it/s]


Epoch [211/300], Training Loss: 0.4373


Epoch 212/300: 100%|██████████| 18/18 [00:00<00:00, 1089.07it/s]
Epoch 213/300: 100%|██████████| 18/18 [00:00<00:00, 1079.63it/s]
Epoch 214/300: 100%|██████████| 18/18 [00:00<00:00, 1087.34it/s]
Epoch 215/300: 100%|██████████| 18/18 [00:00<00:00, 1116.02it/s]
Epoch 216/300: 100%|██████████| 18/18 [00:00<00:00, 1079.18it/s]
Epoch 217/300: 100%|██████████| 18/18 [00:00<00:00, 1193.30it/s]
Epoch 218/300: 100%|██████████| 18/18 [00:00<00:00, 718.74it/s]
Epoch 219/300: 100%|██████████| 18/18 [00:00<00:00, 1090.41it/s]
Epoch 220/300: 100%|██████████| 18/18 [00:00<00:00, 985.11it/s]
Epoch 221/300: 100%|██████████| 18/18 [00:00<00:00, 2365.80it/s]


Epoch [221/300], Training Loss: 0.4358


Epoch 222/300: 100%|██████████| 18/18 [00:00<00:00, 1149.88it/s]
Epoch 223/300: 100%|██████████| 18/18 [00:00<00:00, 1097.95it/s]
Epoch 224/300: 100%|██████████| 18/18 [00:00<00:00, 1108.40it/s]
Epoch 225/300: 100%|██████████| 18/18 [00:00<00:00, 860.36it/s]
Epoch 226/300: 100%|██████████| 18/18 [00:00<00:00, 1147.80it/s]
Epoch 227/300: 100%|██████████| 18/18 [00:00<00:00, 734.35it/s]
Epoch 228/300: 100%|██████████| 18/18 [00:00<00:00, 1150.24it/s]
Epoch 229/300: 100%|██████████| 18/18 [00:00<00:00, 1120.47it/s]
Epoch 230/300: 100%|██████████| 18/18 [00:00<00:00, 1005.33it/s]
Epoch 231/300: 100%|██████████| 18/18 [00:00<00:00, 777.51it/s]


Epoch [231/300], Training Loss: 0.4366


Epoch 232/300: 100%|██████████| 18/18 [00:00<00:00, 1100.51it/s]
Epoch 233/300: 100%|██████████| 18/18 [00:00<00:00, 732.56it/s]
Epoch 234/300: 100%|██████████| 18/18 [00:00<00:00, 1018.83it/s]
Epoch 235/300: 100%|██████████| 18/18 [00:00<00:00, 562.31it/s]
Epoch 236/300: 100%|██████████| 18/18 [00:00<00:00, 292.61it/s]
Epoch 237/300: 100%|██████████| 18/18 [00:00<00:00, 562.24it/s]
Epoch 238/300: 100%|██████████| 18/18 [00:00<00:00, 1125.03it/s]
Epoch 239/300: 100%|██████████| 18/18 [00:00<00:00, 746.82it/s]
Epoch 240/300: 100%|██████████| 18/18 [00:00<00:00, 570.83it/s]
Epoch 241/300: 100%|██████████| 18/18 [00:00<00:00, 562.71it/s]


Epoch [241/300], Training Loss: 0.4366


Epoch 242/300: 100%|██████████| 18/18 [00:00<00:00, 745.36it/s]
Epoch 243/300: 100%|██████████| 18/18 [00:00<00:00, 560.85it/s]
Epoch 244/300: 100%|██████████| 18/18 [00:00<00:00, 718.44it/s]
Epoch 245/300: 100%|██████████| 18/18 [00:00<00:00, 734.67it/s]
Epoch 246/300: 100%|██████████| 18/18 [00:00<00:00, 561.87it/s]
Epoch 247/300: 100%|██████████| 18/18 [00:00<00:00, 742.81it/s]
Epoch 248/300: 100%|██████████| 18/18 [00:00<00:00, 562.34it/s]
Epoch 249/300: 100%|██████████| 18/18 [00:00<00:00, 557.90it/s]
Epoch 250/300: 100%|██████████| 18/18 [00:00<00:00, 750.35it/s]
Epoch 251/300: 100%|██████████| 18/18 [00:00<00:00, 750.33it/s]


Epoch [251/300], Training Loss: 0.4356


Epoch 252/300: 100%|██████████| 18/18 [00:00<00:00, 559.65it/s]
Epoch 253/300: 100%|██████████| 18/18 [00:00<00:00, 699.54it/s]
Epoch 254/300: 100%|██████████| 18/18 [00:00<00:00, 709.48it/s]
Epoch 255/300: 100%|██████████| 18/18 [00:00<00:00, 699.43it/s]
Epoch 256/300: 100%|██████████| 18/18 [00:00<00:00, 600.46it/s]
Epoch 257/300: 100%|██████████| 18/18 [00:00<00:00, 764.42it/s]
Epoch 258/300: 100%|██████████| 18/18 [00:00<00:00, 1125.13it/s]
Epoch 259/300: 100%|██████████| 18/18 [00:00<00:00, 561.72it/s]
Epoch 260/300: 100%|██████████| 18/18 [00:00<00:00, 562.78it/s]
Epoch 261/300: 100%|██████████| 18/18 [00:00<00:00, 562.41it/s]


Epoch [261/300], Training Loss: 0.4358


Epoch 262/300: 100%|██████████| 18/18 [00:00<00:00, 781.42it/s]
Epoch 263/300: 100%|██████████| 18/18 [00:00<00:00, 447.56it/s]
Epoch 264/300: 100%|██████████| 18/18 [00:00<00:00, 551.50it/s]
Epoch 265/300: 100%|██████████| 18/18 [00:00<00:00, 449.62it/s]
Epoch 266/300: 100%|██████████| 18/18 [00:00<00:00, 562.77it/s]
Epoch 267/300: 100%|██████████| 18/18 [00:00<00:00, 733.34it/s]
Epoch 268/300: 100%|██████████| 18/18 [00:00<00:00, 750.52it/s]
Epoch 269/300: 100%|██████████| 18/18 [00:00<00:00, 750.17it/s]
Epoch 270/300: 100%|██████████| 18/18 [00:00<00:00, 628.53it/s]
Epoch 271/300: 100%|██████████| 18/18 [00:00<00:00, 1120.07it/s]


Epoch [271/300], Training Loss: 0.4372


Epoch 272/300: 100%|██████████| 18/18 [00:00<00:00, 749.18it/s]
Epoch 273/300: 100%|██████████| 18/18 [00:00<00:00, 717.49it/s]
Epoch 274/300: 100%|██████████| 18/18 [00:00<00:00, 549.75it/s]
Epoch 275/300: 100%|██████████| 18/18 [00:00<00:00, 544.39it/s]
Epoch 276/300: 100%|██████████| 18/18 [00:00<00:00, 748.62it/s]
Epoch 277/300: 100%|██████████| 18/18 [00:00<00:00, 1121.41it/s]
Epoch 278/300: 100%|██████████| 18/18 [00:00<00:00, 1195.73it/s]
Epoch 279/300: 100%|██████████| 18/18 [00:00<00:00, 1199.57it/s]
Epoch 280/300: 100%|██████████| 18/18 [00:00<00:00, 561.86it/s]
Epoch 281/300: 100%|██████████| 18/18 [00:00<00:00, 635.52it/s]


Epoch [281/300], Training Loss: 0.4351


Epoch 282/300: 100%|██████████| 18/18 [00:00<00:00, 747.78it/s]
Epoch 283/300: 100%|██████████| 18/18 [00:00<00:00, 718.62it/s]
Epoch 284/300: 100%|██████████| 18/18 [00:00<00:00, 737.75it/s]
Epoch 285/300: 100%|██████████| 18/18 [00:00<00:00, 561.95it/s]
Epoch 286/300: 100%|██████████| 18/18 [00:00<00:00, 1085.37it/s]
Epoch 287/300: 100%|██████████| 18/18 [00:00<00:00, 748.73it/s]
Epoch 288/300: 100%|██████████| 18/18 [00:00<00:00, 661.48it/s]
Epoch 289/300: 100%|██████████| 18/18 [00:00<00:00, 737.52it/s]
Epoch 290/300: 100%|██████████| 18/18 [00:00<00:00, 562.06it/s]
Epoch 291/300: 100%|██████████| 18/18 [00:00<00:00, 561.68it/s]


Epoch [291/300], Training Loss: 0.4376


Epoch 292/300: 100%|██████████| 18/18 [00:00<00:00, 1122.67it/s]
Epoch 293/300: 100%|██████████| 18/18 [00:00<00:00, 764.94it/s]
Epoch 294/300: 100%|██████████| 18/18 [00:00<00:00, 748.27it/s]
Epoch 295/300: 100%|██████████| 18/18 [00:00<00:00, 751.19it/s]
Epoch 296/300: 100%|██████████| 18/18 [00:00<00:00, 854.99it/s]
Epoch 297/300: 100%|██████████| 18/18 [00:00<00:00, 1066.42it/s]
Epoch 298/300: 100%|██████████| 18/18 [00:00<00:00, 1154.43it/s]
Epoch 299/300: 100%|██████████| 18/18 [00:00<00:00, 1125.69it/s]
Epoch 300/300: 100%|██████████| 18/18 [00:00<00:00, 1116.07it/s]

Validation Loss: 0.5122, Accuracy: 73.43%
Final Validation Loss: 0.5122
Final Validation Accuracy: 73.43%





---
# 2. Types of Loss Functions

PyTorch offers a variety of built-in loss functions tailored for different types of problems, such as regression, classification, and more. Below, we discuss several commonly used loss functions, their theoretical foundations, and typical use cases.

### 2. MSELoss (`torch.nn.MSELoss`)
- **Description:** Mean Squared Error (MSE) calculates the average of the squares of the differences between predicted and target values.
- **Use Case:** Commonly used in regression problems where larger errors are significantly penalized.

Here is boring math stuff for MSE:
\begin{equation}
\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_{i} - \hat{y}_{i})^{2}
\end{equation}

<span style="color:red; font-size: 18px; font-weight: bold;">Warning:</span> Don't forget to reinitialize the model before experimenting with different loss functions.

In [43]:
from torch.nn import MSELoss

# TODO: Train the model

# TODO: Evaluate the model

model = SimpleMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, last_layer_activation_fn=None)
criterion = MSELoss()  
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

trainer = SimpleMLPTrainer(model, criterion, optimizer)

training_losses = trainer.train(train_loader, num_epochs=700)

val_loss, val_accuracy = trainer.evaluate(val_loader)

print(f"Final Validation Loss with MSE: {val_loss:.4f}")
print(f"Final Validation Accuracy with MSE: {val_accuracy * 100:.2f}%")

Epoch 1/700: 100%|██████████| 18/18 [00:00<00:00, 704.29it/s]


Epoch [1/700], Training Loss: 3.7033


Epoch 2/700: 100%|██████████| 18/18 [00:00<00:00, 749.66it/s]
Epoch 3/700: 100%|██████████| 18/18 [00:00<00:00, 1116.23it/s]
Epoch 4/700: 100%|██████████| 18/18 [00:00<00:00, 1125.01it/s]
Epoch 5/700: 100%|██████████| 18/18 [00:00<00:00, 748.40it/s]
Epoch 6/700: 100%|██████████| 18/18 [00:00<00:00, 1061.83it/s]
Epoch 7/700: 100%|██████████| 18/18 [00:00<00:00, 2241.54it/s]
Epoch 8/700: 100%|██████████| 18/18 [00:00<00:00, 1112.50it/s]
Epoch 9/700: 100%|██████████| 18/18 [00:00<00:00, 704.35it/s]
Epoch 10/700: 100%|██████████| 18/18 [00:00<00:00, 798.90it/s]
Epoch 11/700: 100%|██████████| 18/18 [00:00<00:00, 1124.96it/s]


Epoch [11/700], Training Loss: 0.2424


Epoch 12/700: 100%|██████████| 18/18 [00:00<00:00, 1124.95it/s]
Epoch 13/700: 100%|██████████| 18/18 [00:00<00:00, 1116.78it/s]
Epoch 14/700: 100%|██████████| 18/18 [00:00<00:00, 1196.85it/s]
Epoch 15/700: 100%|██████████| 18/18 [00:00<00:00, 637.51it/s]
Epoch 16/700: 100%|██████████| 18/18 [00:00<00:00, 1052.89it/s]
Epoch 17/700: 100%|██████████| 18/18 [00:00<00:00, 1125.06it/s]
Epoch 18/700: 100%|██████████| 18/18 [00:00<00:00, 747.93it/s]
Epoch 19/700: 100%|██████████| 18/18 [00:00<00:00, 1314.53it/s]
Epoch 20/700: 100%|██████████| 18/18 [00:00<00:00, 764.17it/s]
Epoch 21/700: 100%|██████████| 18/18 [00:00<00:00, 750.82it/s]


Epoch [21/700], Training Loss: 0.2151


Epoch 22/700: 100%|██████████| 18/18 [00:00<00:00, 553.90it/s]
Epoch 23/700: 100%|██████████| 18/18 [00:00<00:00, 733.15it/s]
Epoch 24/700: 100%|██████████| 18/18 [00:00<00:00, 764.85it/s]
Epoch 25/700: 100%|██████████| 18/18 [00:00<00:00, 691.84it/s]
Epoch 26/700: 100%|██████████| 18/18 [00:00<00:00, 805.04it/s]
Epoch 27/700: 100%|██████████| 18/18 [00:00<00:00, 690.95it/s]
Epoch 28/700: 100%|██████████| 18/18 [00:00<00:00, 1075.89it/s]
Epoch 29/700: 100%|██████████| 18/18 [00:00<00:00, 699.54it/s]
Epoch 30/700: 100%|██████████| 18/18 [00:00<00:00, 787.32it/s]
Epoch 31/700: 100%|██████████| 18/18 [00:00<00:00, 1170.21it/s]


Epoch [31/700], Training Loss: 0.2010


Epoch 32/700: 100%|██████████| 18/18 [00:00<00:00, 723.22it/s]
Epoch 33/700: 100%|██████████| 18/18 [00:00<00:00, 454.67it/s]
Epoch 34/700: 100%|██████████| 18/18 [00:00<00:00, 557.36it/s]
Epoch 35/700: 100%|██████████| 18/18 [00:00<00:00, 746.64it/s]
Epoch 36/700: 100%|██████████| 18/18 [00:00<00:00, 1119.33it/s]
Epoch 37/700: 100%|██████████| 18/18 [00:00<00:00, 714.55it/s]
Epoch 38/700: 100%|██████████| 18/18 [00:00<00:00, 1159.36it/s]
Epoch 39/700: 100%|██████████| 18/18 [00:00<00:00, 749.79it/s]
Epoch 40/700: 100%|██████████| 18/18 [00:00<00:00, 739.41it/s]
Epoch 41/700: 100%|██████████| 18/18 [00:00<00:00, 702.92it/s]


Epoch [41/700], Training Loss: 0.1808


Epoch 42/700: 100%|██████████| 18/18 [00:00<00:00, 781.85it/s]
Epoch 43/700: 100%|██████████| 18/18 [00:00<00:00, 1123.07it/s]
Epoch 44/700: 100%|██████████| 18/18 [00:00<00:00, 750.77it/s]
Epoch 45/700: 100%|██████████| 18/18 [00:00<00:00, 747.48it/s]
Epoch 46/700: 100%|██████████| 18/18 [00:00<00:00, 1125.01it/s]
Epoch 47/700: 100%|██████████| 18/18 [00:00<00:00, 749.96it/s]
Epoch 48/700: 100%|██████████| 18/18 [00:00<00:00, 749.09it/s]
Epoch 49/700: 100%|██████████| 18/18 [00:00<00:00, 748.63it/s]
Epoch 50/700: 100%|██████████| 18/18 [00:00<00:00, 1156.43it/s]
Epoch 51/700: 100%|██████████| 18/18 [00:00<00:00, 1126.88it/s]


Epoch [51/700], Training Loss: 0.1680


Epoch 52/700: 100%|██████████| 18/18 [00:00<00:00, 1112.06it/s]
Epoch 53/700: 100%|██████████| 18/18 [00:00<00:00, 748.36it/s]
Epoch 54/700: 100%|██████████| 18/18 [00:00<00:00, 1120.06it/s]
Epoch 55/700: 100%|██████████| 18/18 [00:00<00:00, 1095.18it/s]
Epoch 56/700: 100%|██████████| 18/18 [00:00<00:00, 548.79it/s]
Epoch 57/700: 100%|██████████| 18/18 [00:00<00:00, 682.10it/s]
Epoch 58/700: 100%|██████████| 18/18 [00:00<00:00, 557.12it/s]
Epoch 59/700: 100%|██████████| 18/18 [00:00<00:00, 1122.44it/s]
Epoch 60/700: 100%|██████████| 18/18 [00:00<00:00, 1156.94it/s]
Epoch 61/700: 100%|██████████| 18/18 [00:00<00:00, 561.72it/s]


Epoch [61/700], Training Loss: 0.1619


Epoch 62/700: 100%|██████████| 18/18 [00:00<00:00, 446.78it/s]
Epoch 63/700: 100%|██████████| 18/18 [00:00<00:00, 570.24it/s]
Epoch 64/700: 100%|██████████| 18/18 [00:00<00:00, 1122.62it/s]
Epoch 65/700: 100%|██████████| 18/18 [00:00<00:00, 1124.63it/s]
Epoch 66/700: 100%|██████████| 18/18 [00:00<00:00, 723.29it/s]
Epoch 67/700: 100%|██████████| 18/18 [00:00<00:00, 1159.57it/s]
Epoch 68/700: 100%|██████████| 18/18 [00:00<00:00, 1121.39it/s]
Epoch 69/700: 100%|██████████| 18/18 [00:00<00:00, 1092.22it/s]
Epoch 70/700: 100%|██████████| 18/18 [00:00<00:00, 764.62it/s]
Epoch 71/700: 100%|██████████| 18/18 [00:00<00:00, 1124.48it/s]


Epoch [71/700], Training Loss: 0.1618


Epoch 72/700: 100%|██████████| 18/18 [00:00<00:00, 749.36it/s]
Epoch 73/700: 100%|██████████| 18/18 [00:00<00:00, 1197.54it/s]
Epoch 74/700: 100%|██████████| 18/18 [00:00<00:00, 1127.28it/s]
Epoch 75/700: 100%|██████████| 18/18 [00:00<00:00, 1095.72it/s]
Epoch 76/700: 100%|██████████| 18/18 [00:00<00:00, 546.74it/s]
Epoch 77/700: 100%|██████████| 18/18 [00:00<00:00, 755.43it/s]
Epoch 78/700: 100%|██████████| 18/18 [00:00<00:00, 748.32it/s]
Epoch 79/700: 100%|██████████| 18/18 [00:00<00:00, 751.28it/s]
Epoch 80/700: 100%|██████████| 18/18 [00:00<00:00, 746.05it/s]
Epoch 81/700: 100%|██████████| 18/18 [00:00<00:00, 1077.84it/s]


Epoch [81/700], Training Loss: 0.1520


Epoch 82/700: 100%|██████████| 18/18 [00:00<00:00, 765.10it/s]
Epoch 83/700: 100%|██████████| 18/18 [00:00<00:00, 1123.51it/s]
Epoch 84/700: 100%|██████████| 18/18 [00:00<00:00, 764.51it/s]
Epoch 85/700: 100%|██████████| 18/18 [00:00<00:00, 1091.46it/s]
Epoch 86/700: 100%|██████████| 18/18 [00:00<00:00, 748.97it/s]
Epoch 87/700: 100%|██████████| 18/18 [00:00<00:00, 740.58it/s]
Epoch 88/700: 100%|██████████| 18/18 [00:00<00:00, 747.43it/s]
Epoch 89/700: 100%|██████████| 18/18 [00:00<00:00, 749.06it/s]
Epoch 90/700: 100%|██████████| 18/18 [00:00<00:00, 739.46it/s]
Epoch 91/700: 100%|██████████| 18/18 [00:00<00:00, 441.75it/s]


Epoch [91/700], Training Loss: 0.1507


Epoch 92/700: 100%|██████████| 18/18 [00:00<00:00, 748.62it/s]
Epoch 93/700: 100%|██████████| 18/18 [00:00<00:00, 737.09it/s]
Epoch 94/700: 100%|██████████| 18/18 [00:00<00:00, 749.33it/s]
Epoch 95/700: 100%|██████████| 18/18 [00:00<00:00, 749.76it/s]
Epoch 96/700: 100%|██████████| 18/18 [00:00<00:00, 1433.71it/s]
Epoch 97/700: 100%|██████████| 18/18 [00:00<00:00, 732.05it/s]
Epoch 98/700: 100%|██████████| 18/18 [00:00<00:00, 1095.36it/s]
Epoch 99/700: 100%|██████████| 18/18 [00:00<00:00, 748.59it/s]
Epoch 100/700: 100%|██████████| 18/18 [00:00<00:00, 998.76it/s]
Epoch 101/700: 100%|██████████| 18/18 [00:00<00:00, 732.94it/s]


Epoch [101/700], Training Loss: 0.1505


Epoch 102/700: 100%|██████████| 18/18 [00:00<00:00, 740.22it/s]
Epoch 103/700: 100%|██████████| 18/18 [00:00<00:00, 1025.70it/s]
Epoch 104/700: 100%|██████████| 18/18 [00:00<00:00, 712.56it/s]
Epoch 105/700: 100%|██████████| 18/18 [00:00<00:00, 736.86it/s]
Epoch 106/700: 100%|██████████| 18/18 [00:00<00:00, 749.55it/s]
Epoch 107/700: 100%|██████████| 18/18 [00:00<00:00, 743.53it/s]
Epoch 108/700: 100%|██████████| 18/18 [00:00<00:00, 750.14it/s]
Epoch 109/700: 100%|██████████| 18/18 [00:00<00:00, 736.93it/s]
Epoch 110/700: 100%|██████████| 18/18 [00:00<00:00, 762.75it/s]
Epoch 111/700: 100%|██████████| 18/18 [00:00<00:00, 740.15it/s]


Epoch [111/700], Training Loss: 0.1678


Epoch 112/700: 100%|██████████| 18/18 [00:00<00:00, 749.80it/s]
Epoch 113/700: 100%|██████████| 18/18 [00:00<00:00, 562.21it/s]
Epoch 114/700: 100%|██████████| 18/18 [00:00<00:00, 895.89it/s]
Epoch 115/700: 100%|██████████| 18/18 [00:00<00:00, 628.62it/s]
Epoch 116/700: 100%|██████████| 18/18 [00:00<00:00, 764.98it/s]
Epoch 117/700: 100%|██████████| 18/18 [00:00<00:00, 747.88it/s]
Epoch 118/700: 100%|██████████| 18/18 [00:00<00:00, 448.48it/s]
Epoch 119/700: 100%|██████████| 18/18 [00:00<00:00, 561.26it/s]
Epoch 120/700: 100%|██████████| 18/18 [00:00<00:00, 1118.73it/s]
Epoch 121/700: 100%|██████████| 18/18 [00:00<00:00, 1121.36it/s]


Epoch [121/700], Training Loss: 0.1455


Epoch 122/700: 100%|██████████| 18/18 [00:00<00:00, 718.58it/s]
Epoch 123/700: 100%|██████████| 18/18 [00:00<00:00, 780.70it/s]
Epoch 124/700: 100%|██████████| 18/18 [00:00<00:00, 700.52it/s]
Epoch 125/700: 100%|██████████| 18/18 [00:00<00:00, 1081.90it/s]
Epoch 126/700: 100%|██████████| 18/18 [00:00<00:00, 1114.70it/s]
Epoch 127/700: 100%|██████████| 18/18 [00:00<00:00, 1097.79it/s]
Epoch 128/700: 100%|██████████| 18/18 [00:00<00:00, 1122.41it/s]
Epoch 129/700: 100%|██████████| 18/18 [00:00<00:00, 731.95it/s]
Epoch 130/700: 100%|██████████| 18/18 [00:00<00:00, 749.69it/s]
Epoch 131/700: 100%|██████████| 18/18 [00:00<00:00, 1122.94it/s]


Epoch [131/700], Training Loss: 0.1461


Epoch 132/700: 100%|██████████| 18/18 [00:00<00:00, 749.84it/s]
Epoch 133/700: 100%|██████████| 18/18 [00:00<00:00, 737.54it/s]
Epoch 134/700: 100%|██████████| 18/18 [00:00<00:00, 704.56it/s]
Epoch 135/700: 100%|██████████| 18/18 [00:00<00:00, 1139.81it/s]
Epoch 136/700: 100%|██████████| 18/18 [00:00<00:00, 733.39it/s]
Epoch 137/700: 100%|██████████| 18/18 [00:00<00:00, 756.81it/s]
Epoch 138/700: 100%|██████████| 18/18 [00:00<00:00, 750.21it/s]
Epoch 139/700: 100%|██████████| 18/18 [00:00<00:00, 1124.16it/s]
Epoch 140/700: 100%|██████████| 18/18 [00:00<00:00, 1106.30it/s]
Epoch 141/700: 100%|██████████| 18/18 [00:00<00:00, 745.15it/s]


Epoch [141/700], Training Loss: 0.1433


Epoch 142/700: 100%|██████████| 18/18 [00:00<00:00, 555.32it/s]
Epoch 143/700: 100%|██████████| 18/18 [00:00<00:00, 726.59it/s]
Epoch 144/700: 100%|██████████| 18/18 [00:00<00:00, 377.67it/s]
Epoch 145/700: 100%|██████████| 18/18 [00:00<00:00, 747.95it/s]
Epoch 146/700: 100%|██████████| 18/18 [00:00<00:00, 1122.39it/s]
Epoch 147/700: 100%|██████████| 18/18 [00:00<00:00, 1122.76it/s]
Epoch 148/700: 100%|██████████| 18/18 [00:00<00:00, 980.28it/s]
Epoch 149/700: 100%|██████████| 18/18 [00:00<00:00, 733.64it/s]
Epoch 150/700: 100%|██████████| 18/18 [00:00<00:00, 1124.51it/s]
Epoch 151/700: 100%|██████████| 18/18 [00:00<00:00, 1069.17it/s]


Epoch [151/700], Training Loss: 0.1415


Epoch 152/700: 100%|██████████| 18/18 [00:00<00:00, 854.75it/s]
Epoch 153/700: 100%|██████████| 18/18 [00:00<00:00, 1123.89it/s]
Epoch 154/700: 100%|██████████| 18/18 [00:00<00:00, 747.23it/s]
Epoch 155/700: 100%|██████████| 18/18 [00:00<00:00, 749.55it/s]
Epoch 156/700: 100%|██████████| 18/18 [00:00<00:00, 666.47it/s]
Epoch 157/700: 100%|██████████| 18/18 [00:00<00:00, 1024.03it/s]
Epoch 158/700: 100%|██████████| 18/18 [00:00<00:00, 736.14it/s]
Epoch 159/700: 100%|██████████| 18/18 [00:00<00:00, 736.00it/s]
Epoch 160/700: 100%|██████████| 18/18 [00:00<00:00, 561.17it/s]
Epoch 161/700: 100%|██████████| 18/18 [00:00<00:00, 447.10it/s]


Epoch [161/700], Training Loss: 0.1552


Epoch 162/700: 100%|██████████| 18/18 [00:00<00:00, 556.86it/s]
Epoch 163/700: 100%|██████████| 18/18 [00:00<00:00, 559.45it/s]
Epoch 164/700: 100%|██████████| 18/18 [00:00<00:00, 743.82it/s]
Epoch 165/700: 100%|██████████| 18/18 [00:00<00:00, 721.17it/s]
Epoch 166/700: 100%|██████████| 18/18 [00:00<00:00, 730.31it/s]
Epoch 167/700: 100%|██████████| 18/18 [00:00<00:00, 1115.23it/s]
Epoch 168/700: 100%|██████████| 18/18 [00:00<00:00, 1092.90it/s]
Epoch 169/700: 100%|██████████| 18/18 [00:00<00:00, 1123.71it/s]
Epoch 170/700: 100%|██████████| 18/18 [00:00<00:00, 748.23it/s]
Epoch 171/700: 100%|██████████| 18/18 [00:00<00:00, 963.88it/s]


Epoch [171/700], Training Loss: 0.1457


Epoch 172/700: 100%|██████████| 18/18 [00:00<00:00, 1086.14it/s]
Epoch 173/700: 100%|██████████| 18/18 [00:00<00:00, 375.01it/s]
Epoch 174/700: 100%|██████████| 18/18 [00:00<00:00, 629.32it/s]
Epoch 175/700: 100%|██████████| 18/18 [00:00<00:00, 1316.32it/s]
Epoch 176/700: 100%|██████████| 18/18 [00:00<00:00, 1123.04it/s]
Epoch 177/700: 100%|██████████| 18/18 [00:00<00:00, 1097.97it/s]
Epoch 178/700: 100%|██████████| 18/18 [00:00<00:00, 730.88it/s]
Epoch 179/700: 100%|██████████| 18/18 [00:00<00:00, 729.09it/s]
Epoch 180/700: 100%|██████████| 18/18 [00:00<00:00, 764.38it/s]
Epoch 181/700: 100%|██████████| 18/18 [00:00<00:00, 747.26it/s]


Epoch [181/700], Training Loss: 0.1523


Epoch 182/700: 100%|██████████| 18/18 [00:00<00:00, 1098.40it/s]
Epoch 183/700: 100%|██████████| 18/18 [00:00<00:00, 739.57it/s]
Epoch 184/700: 100%|██████████| 18/18 [00:00<00:00, 1102.06it/s]
Epoch 185/700: 100%|██████████| 18/18 [00:00<00:00, 1127.25it/s]
Epoch 186/700: 100%|██████████| 18/18 [00:00<00:00, 751.32it/s]
Epoch 187/700: 100%|██████████| 18/18 [00:00<00:00, 745.22it/s]
Epoch 188/700: 100%|██████████| 18/18 [00:00<00:00, 739.87it/s]
Epoch 189/700: 100%|██████████| 18/18 [00:00<00:00, 1121.64it/s]
Epoch 190/700: 100%|██████████| 18/18 [00:00<00:00, 1124.13it/s]
Epoch 191/700: 100%|██████████| 18/18 [00:00<00:00, 1115.36it/s]


Epoch [191/700], Training Loss: 0.1598


Epoch 192/700: 100%|██████████| 18/18 [00:00<00:00, 1103.55it/s]
Epoch 193/700: 100%|██████████| 18/18 [00:00<00:00, 750.10it/s]
Epoch 194/700: 100%|██████████| 18/18 [00:00<00:00, 1124.86it/s]
Epoch 195/700: 100%|██████████| 18/18 [00:00<00:00, 742.45it/s]
Epoch 196/700: 100%|██████████| 18/18 [00:00<00:00, 721.58it/s]
Epoch 197/700: 100%|██████████| 18/18 [00:00<00:00, 579.10it/s]
Epoch 198/700: 100%|██████████| 18/18 [00:00<00:00, 1097.03it/s]
Epoch 199/700: 100%|██████████| 18/18 [00:00<00:00, 1104.77it/s]
Epoch 200/700: 100%|██████████| 18/18 [00:00<00:00, 1152.28it/s]
Epoch 201/700: 100%|██████████| 18/18 [00:00<00:00, 764.15it/s]


Epoch [201/700], Training Loss: 0.1434


Epoch 202/700: 100%|██████████| 18/18 [00:00<00:00, 544.58it/s]
Epoch 203/700: 100%|██████████| 18/18 [00:00<00:00, 1124.75it/s]
Epoch 204/700: 100%|██████████| 18/18 [00:00<00:00, 740.36it/s]
Epoch 205/700: 100%|██████████| 18/18 [00:00<00:00, 736.40it/s]
Epoch 206/700: 100%|██████████| 18/18 [00:00<00:00, 551.95it/s]
Epoch 207/700: 100%|██████████| 18/18 [00:00<00:00, 556.21it/s]
Epoch 208/700: 100%|██████████| 18/18 [00:00<00:00, 743.83it/s]
Epoch 209/700: 100%|██████████| 18/18 [00:00<00:00, 745.88it/s]
Epoch 210/700: 100%|██████████| 18/18 [00:00<00:00, 745.46it/s]
Epoch 211/700: 100%|██████████| 18/18 [00:00<00:00, 748.84it/s]


Epoch [211/700], Training Loss: 0.1456


Epoch 212/700: 100%|██████████| 18/18 [00:00<00:00, 736.60it/s]
Epoch 213/700: 100%|██████████| 18/18 [00:00<00:00, 741.62it/s]
Epoch 214/700: 100%|██████████| 18/18 [00:00<00:00, 1126.00it/s]
Epoch 215/700: 100%|██████████| 18/18 [00:00<00:00, 867.28it/s]
Epoch 216/700: 100%|██████████| 18/18 [00:00<00:00, 760.81it/s]
Epoch 217/700: 100%|██████████| 18/18 [00:00<00:00, 870.55it/s]
Epoch 218/700: 100%|██████████| 18/18 [00:00<00:00, 748.01it/s]
Epoch 219/700: 100%|██████████| 18/18 [00:00<00:00, 1127.18it/s]
Epoch 220/700: 100%|██████████| 18/18 [00:00<00:00, 1123.81it/s]
Epoch 221/700: 100%|██████████| 18/18 [00:00<00:00, 1077.60it/s]


Epoch [221/700], Training Loss: 0.1416


Epoch 222/700: 100%|██████████| 18/18 [00:00<00:00, 748.77it/s]
Epoch 223/700: 100%|██████████| 18/18 [00:00<00:00, 1123.16it/s]
Epoch 224/700: 100%|██████████| 18/18 [00:00<00:00, 1113.56it/s]
Epoch 225/700: 100%|██████████| 18/18 [00:00<00:00, 1147.40it/s]
Epoch 226/700: 100%|██████████| 18/18 [00:00<00:00, 749.37it/s]
Epoch 227/700: 100%|██████████| 18/18 [00:00<00:00, 750.06it/s]
Epoch 228/700: 100%|██████████| 18/18 [00:00<00:00, 1122.89it/s]
Epoch 229/700: 100%|██████████| 18/18 [00:00<00:00, 1109.80it/s]
Epoch 230/700: 100%|██████████| 18/18 [00:00<00:00, 455.06it/s]
Epoch 231/700: 100%|██████████| 18/18 [00:00<00:00, 741.32it/s]


Epoch [231/700], Training Loss: 0.1468


Epoch 232/700: 100%|██████████| 18/18 [00:00<00:00, 1021.49it/s]
Epoch 233/700: 100%|██████████| 18/18 [00:00<00:00, 780.71it/s]
Epoch 234/700: 100%|██████████| 18/18 [00:00<00:00, 1098.59it/s]
Epoch 235/700: 100%|██████████| 18/18 [00:00<00:00, 734.03it/s]
Epoch 236/700: 100%|██████████| 18/18 [00:00<00:00, 724.94it/s]
Epoch 237/700: 100%|██████████| 18/18 [00:00<00:00, 916.78it/s]
Epoch 238/700: 100%|██████████| 18/18 [00:00<00:00, 716.65it/s]
Epoch 239/700: 100%|██████████| 18/18 [00:00<00:00, 1080.73it/s]
Epoch 240/700: 100%|██████████| 18/18 [00:00<00:00, 754.27it/s]
Epoch 241/700: 100%|██████████| 18/18 [00:00<00:00, 1054.20it/s]


Epoch [241/700], Training Loss: 0.1380


Epoch 242/700: 100%|██████████| 18/18 [00:00<00:00, 545.61it/s]
Epoch 243/700: 100%|██████████| 18/18 [00:00<00:00, 1107.03it/s]
Epoch 244/700: 100%|██████████| 18/18 [00:00<00:00, 1120.32it/s]
Epoch 245/700: 100%|██████████| 18/18 [00:00<00:00, 746.69it/s]
Epoch 246/700: 100%|██████████| 18/18 [00:00<00:00, 749.11it/s]
Epoch 247/700: 100%|██████████| 18/18 [00:00<00:00, 1158.27it/s]
Epoch 248/700: 100%|██████████| 18/18 [00:00<00:00, 1124.11it/s]
Epoch 249/700: 100%|██████████| 18/18 [00:00<00:00, 741.82it/s]
Epoch 250/700: 100%|██████████| 18/18 [00:00<00:00, 747.77it/s]
Epoch 251/700: 100%|██████████| 18/18 [00:00<00:00, 750.31it/s]


Epoch [251/700], Training Loss: 0.1795


Epoch 252/700: 100%|██████████| 18/18 [00:00<00:00, 709.18it/s]
Epoch 253/700: 100%|██████████| 18/18 [00:00<00:00, 1118.08it/s]
Epoch 254/700: 100%|██████████| 18/18 [00:00<00:00, 1118.91it/s]
Epoch 255/700: 100%|██████████| 18/18 [00:00<00:00, 751.30it/s]
Epoch 256/700: 100%|██████████| 18/18 [00:00<00:00, 739.29it/s]
Epoch 257/700: 100%|██████████| 18/18 [00:00<00:00, 403.39it/s]
Epoch 258/700: 100%|██████████| 18/18 [00:00<00:00, 574.95it/s]
Epoch 259/700: 100%|██████████| 18/18 [00:00<00:00, 741.09it/s]
Epoch 260/700: 100%|██████████| 18/18 [00:00<00:00, 1122.11it/s]
Epoch 261/700: 100%|██████████| 18/18 [00:00<00:00, 1123.06it/s]


Epoch [261/700], Training Loss: 0.1363


Epoch 262/700: 100%|██████████| 18/18 [00:00<00:00, 750.31it/s]
Epoch 263/700: 100%|██████████| 18/18 [00:00<00:00, 749.88it/s]
Epoch 264/700: 100%|██████████| 18/18 [00:00<00:00, 873.62it/s]
Epoch 265/700: 100%|██████████| 18/18 [00:00<00:00, 749.15it/s]
Epoch 266/700: 100%|██████████| 18/18 [00:00<00:00, 750.77it/s]
Epoch 267/700: 100%|██████████| 18/18 [00:00<00:00, 555.92it/s]
Epoch 268/700: 100%|██████████| 18/18 [00:00<00:00, 756.85it/s]
Epoch 269/700: 100%|██████████| 18/18 [00:00<00:00, 1126.39it/s]
Epoch 270/700: 100%|██████████| 18/18 [00:00<00:00, 1146.16it/s]
Epoch 271/700: 100%|██████████| 18/18 [00:00<00:00, 746.37it/s]


Epoch [271/700], Training Loss: 0.1392


Epoch 272/700: 100%|██████████| 18/18 [00:00<00:00, 742.13it/s]
Epoch 273/700: 100%|██████████| 18/18 [00:00<00:00, 712.06it/s]
Epoch 274/700: 100%|██████████| 18/18 [00:00<00:00, 741.72it/s]
Epoch 275/700: 100%|██████████| 18/18 [00:00<00:00, 796.34it/s]
Epoch 276/700: 100%|██████████| 18/18 [00:00<00:00, 1123.61it/s]
Epoch 277/700: 100%|██████████| 18/18 [00:00<00:00, 829.66it/s]
Epoch 278/700: 100%|██████████| 18/18 [00:00<00:00, 1092.05it/s]
Epoch 279/700: 100%|██████████| 18/18 [00:00<00:00, 1127.92it/s]
Epoch 280/700: 100%|██████████| 18/18 [00:00<00:00, 745.83it/s]
Epoch 281/700: 100%|██████████| 18/18 [00:00<00:00, 729.97it/s]


Epoch [281/700], Training Loss: 0.1343


Epoch 282/700: 100%|██████████| 18/18 [00:00<00:00, 1104.86it/s]
Epoch 283/700: 100%|██████████| 18/18 [00:00<00:00, 747.20it/s]
Epoch 284/700: 100%|██████████| 18/18 [00:00<00:00, 1069.48it/s]
Epoch 285/700: 100%|██████████| 18/18 [00:00<00:00, 748.17it/s]
Epoch 286/700: 100%|██████████| 18/18 [00:00<00:00, 750.16it/s]
Epoch 287/700: 100%|██████████| 18/18 [00:00<00:00, 1052.83it/s]
Epoch 288/700: 100%|██████████| 18/18 [00:00<00:00, 1089.63it/s]
Epoch 289/700: 100%|██████████| 18/18 [00:00<00:00, 747.72it/s]
Epoch 290/700: 100%|██████████| 18/18 [00:00<00:00, 1110.45it/s]
Epoch 291/700: 100%|██████████| 18/18 [00:00<00:00, 744.26it/s]


Epoch [291/700], Training Loss: 0.1384


Epoch 292/700: 100%|██████████| 18/18 [00:00<00:00, 943.75it/s]
Epoch 293/700: 100%|██████████| 18/18 [00:00<00:00, 1123.34it/s]
Epoch 294/700: 100%|██████████| 18/18 [00:00<00:00, 1124.66it/s]
Epoch 295/700: 100%|██████████| 18/18 [00:00<00:00, 1127.91it/s]
Epoch 296/700: 100%|██████████| 18/18 [00:00<00:00, 750.09it/s]
Epoch 297/700: 100%|██████████| 18/18 [00:00<00:00, 1078.10it/s]
Epoch 298/700: 100%|██████████| 18/18 [00:00<00:00, 1098.03it/s]
Epoch 299/700: 100%|██████████| 18/18 [00:00<00:00, 749.59it/s]
Epoch 300/700: 100%|██████████| 18/18 [00:00<00:00, 1519.19it/s]
Epoch 301/700: 100%|██████████| 18/18 [00:00<00:00, 1122.44it/s]


Epoch [301/700], Training Loss: 0.1517


Epoch 302/700: 100%|██████████| 18/18 [00:00<00:00, 1094.64it/s]
Epoch 303/700: 100%|██████████| 18/18 [00:00<00:00, 749.41it/s]
Epoch 304/700: 100%|██████████| 18/18 [00:00<00:00, 1124.44it/s]
Epoch 305/700: 100%|██████████| 18/18 [00:00<00:00, 1122.04it/s]
Epoch 306/700: 100%|██████████| 18/18 [00:00<00:00, 1156.38it/s]
Epoch 307/700: 100%|██████████| 18/18 [00:00<00:00, 705.43it/s]
Epoch 308/700: 100%|██████████| 18/18 [00:00<00:00, 1120.66it/s]
Epoch 309/700: 100%|██████████| 18/18 [00:00<00:00, 749.58it/s]
Epoch 310/700: 100%|██████████| 18/18 [00:00<00:00, 1114.85it/s]
Epoch 311/700: 100%|██████████| 18/18 [00:00<00:00, 1115.67it/s]


Epoch [311/700], Training Loss: 0.1341


Epoch 312/700: 100%|██████████| 18/18 [00:00<00:00, 443.73it/s]
Epoch 313/700: 100%|██████████| 18/18 [00:00<00:00, 413.32it/s]
Epoch 314/700: 100%|██████████| 18/18 [00:00<00:00, 749.82it/s]
Epoch 315/700: 100%|██████████| 18/18 [00:00<00:00, 747.95it/s]
Epoch 316/700: 100%|██████████| 18/18 [00:00<00:00, 751.65it/s]
Epoch 317/700: 100%|██████████| 18/18 [00:00<00:00, 760.36it/s]
Epoch 318/700: 100%|██████████| 18/18 [00:00<00:00, 751.76it/s]
Epoch 319/700: 100%|██████████| 18/18 [00:00<00:00, 502.90it/s]
Epoch 320/700: 100%|██████████| 18/18 [00:00<00:00, 1038.78it/s]
Epoch 321/700: 100%|██████████| 18/18 [00:00<00:00, 745.39it/s]


Epoch [321/700], Training Loss: 0.1380


Epoch 322/700: 100%|██████████| 18/18 [00:00<00:00, 559.08it/s]
Epoch 323/700: 100%|██████████| 18/18 [00:00<00:00, 1087.99it/s]
Epoch 324/700: 100%|██████████| 18/18 [00:00<00:00, 1092.90it/s]
Epoch 325/700: 100%|██████████| 18/18 [00:00<00:00, 1119.68it/s]
Epoch 326/700: 100%|██████████| 18/18 [00:00<00:00, 1125.60it/s]
Epoch 327/700: 100%|██████████| 18/18 [00:00<00:00, 749.60it/s]
Epoch 328/700: 100%|██████████| 18/18 [00:00<00:00, 737.24it/s]
Epoch 329/700: 100%|██████████| 18/18 [00:00<00:00, 733.70it/s]
Epoch 330/700: 100%|██████████| 18/18 [00:00<00:00, 750.24it/s]
Epoch 331/700: 100%|██████████| 18/18 [00:00<00:00, 746.27it/s]


Epoch [331/700], Training Loss: 0.1443


Epoch 332/700: 100%|██████████| 18/18 [00:00<00:00, 1133.19it/s]
Epoch 333/700: 100%|██████████| 18/18 [00:00<00:00, 1056.97it/s]
Epoch 334/700: 100%|██████████| 18/18 [00:00<00:00, 761.40it/s]
Epoch 335/700: 100%|██████████| 18/18 [00:00<00:00, 1076.40it/s]
Epoch 336/700: 100%|██████████| 18/18 [00:00<00:00, 1123.31it/s]
Epoch 337/700: 100%|██████████| 18/18 [00:00<00:00, 740.45it/s]
Epoch 338/700: 100%|██████████| 18/18 [00:00<00:00, 744.10it/s]
Epoch 339/700: 100%|██████████| 18/18 [00:00<00:00, 1093.01it/s]
Epoch 340/700: 100%|██████████| 18/18 [00:00<00:00, 432.28it/s]
Epoch 341/700: 100%|██████████| 18/18 [00:00<00:00, 741.31it/s]


Epoch [341/700], Training Loss: 0.1333


Epoch 342/700: 100%|██████████| 18/18 [00:00<00:00, 748.33it/s]
Epoch 343/700: 100%|██████████| 18/18 [00:00<00:00, 1119.94it/s]
Epoch 344/700: 100%|██████████| 18/18 [00:00<00:00, 1155.67it/s]
Epoch 345/700: 100%|██████████| 18/18 [00:00<00:00, 725.72it/s]
Epoch 346/700: 100%|██████████| 18/18 [00:00<00:00, 1121.54it/s]
Epoch 347/700: 100%|██████████| 18/18 [00:00<00:00, 1129.29it/s]
Epoch 348/700: 100%|██████████| 18/18 [00:00<00:00, 749.64it/s]
Epoch 349/700: 100%|██████████| 18/18 [00:00<00:00, 742.27it/s]
Epoch 350/700: 100%|██████████| 18/18 [00:00<00:00, 1120.17it/s]
Epoch 351/700: 100%|██████████| 18/18 [00:00<00:00, 1123.61it/s]


Epoch [351/700], Training Loss: 0.1399


Epoch 352/700: 100%|██████████| 18/18 [00:00<00:00, 1096.84it/s]
Epoch 353/700: 100%|██████████| 18/18 [00:00<00:00, 749.67it/s]
Epoch 354/700: 100%|██████████| 18/18 [00:00<00:00, 747.70it/s]
Epoch 355/700: 100%|██████████| 18/18 [00:00<00:00, 1485.62it/s]
Epoch 356/700: 100%|██████████| 18/18 [00:00<00:00, 733.32it/s]
Epoch 357/700: 100%|██████████| 18/18 [00:00<00:00, 750.32it/s]
Epoch 358/700: 100%|██████████| 18/18 [00:00<00:00, 1126.49it/s]
Epoch 359/700: 100%|██████████| 18/18 [00:00<00:00, 1123.04it/s]
Epoch 360/700: 100%|██████████| 18/18 [00:00<00:00, 747.72it/s]
Epoch 361/700: 100%|██████████| 18/18 [00:00<00:00, 1095.20it/s]


Epoch [361/700], Training Loss: 0.1397


Epoch 362/700: 100%|██████████| 18/18 [00:00<00:00, 1122.11it/s]
Epoch 363/700: 100%|██████████| 18/18 [00:00<00:00, 749.52it/s]
Epoch 364/700: 100%|██████████| 18/18 [00:00<00:00, 827.96it/s]
Epoch 365/700: 100%|██████████| 18/18 [00:00<00:00, 1154.94it/s]
Epoch 366/700: 100%|██████████| 18/18 [00:00<00:00, 1121.47it/s]
Epoch 367/700: 100%|██████████| 18/18 [00:00<00:00, 372.35it/s]
Epoch 368/700: 100%|██████████| 18/18 [00:00<00:00, 556.80it/s]
Epoch 369/700: 100%|██████████| 18/18 [00:00<00:00, 751.33it/s]
Epoch 370/700: 100%|██████████| 18/18 [00:00<00:00, 744.74it/s]
Epoch 371/700: 100%|██████████| 18/18 [00:00<00:00, 747.94it/s]


Epoch [371/700], Training Loss: 0.1332


Epoch 372/700: 100%|██████████| 18/18 [00:00<00:00, 744.30it/s]
Epoch 373/700: 100%|██████████| 18/18 [00:00<00:00, 750.91it/s]
Epoch 374/700: 100%|██████████| 18/18 [00:00<00:00, 747.48it/s]
Epoch 375/700: 100%|██████████| 18/18 [00:00<00:00, 1124.39it/s]
Epoch 376/700: 100%|██████████| 18/18 [00:00<00:00, 1128.60it/s]
Epoch 377/700: 100%|██████████| 18/18 [00:00<00:00, 749.31it/s]
Epoch 378/700: 100%|██████████| 18/18 [00:00<00:00, 744.93it/s]
Epoch 379/700: 100%|██████████| 18/18 [00:00<00:00, 554.71it/s]
Epoch 380/700: 100%|██████████| 18/18 [00:00<00:00, 747.78it/s]
Epoch 381/700: 100%|██████████| 18/18 [00:00<00:00, 746.76it/s]


Epoch [381/700], Training Loss: 0.1395


Epoch 382/700: 100%|██████████| 18/18 [00:00<00:00, 744.45it/s]
Epoch 383/700: 100%|██████████| 18/18 [00:00<00:00, 673.59it/s]
Epoch 384/700: 100%|██████████| 18/18 [00:00<00:00, 691.15it/s]
Epoch 385/700: 100%|██████████| 18/18 [00:00<00:00, 792.28it/s]
Epoch 386/700: 100%|██████████| 18/18 [00:00<00:00, 734.73it/s]
Epoch 387/700: 100%|██████████| 18/18 [00:00<00:00, 735.94it/s]
Epoch 388/700: 100%|██████████| 18/18 [00:00<00:00, 706.07it/s]
Epoch 389/700: 100%|██████████| 18/18 [00:00<00:00, 1120.07it/s]
Epoch 390/700: 100%|██████████| 18/18 [00:00<00:00, 746.59it/s]
Epoch 391/700: 100%|██████████| 18/18 [00:00<00:00, 744.04it/s]


Epoch [391/700], Training Loss: 0.1339


Epoch 392/700: 100%|██████████| 18/18 [00:00<00:00, 563.25it/s]
Epoch 393/700: 100%|██████████| 18/18 [00:00<00:00, 744.25it/s]
Epoch 394/700: 100%|██████████| 18/18 [00:00<00:00, 447.06it/s]
Epoch 395/700: 100%|██████████| 18/18 [00:00<00:00, 737.73it/s]
Epoch 396/700: 100%|██████████| 18/18 [00:00<00:00, 744.04it/s]
Epoch 397/700: 100%|██████████| 18/18 [00:00<00:00, 1125.13it/s]
Epoch 398/700: 100%|██████████| 18/18 [00:00<00:00, 738.84it/s]
Epoch 399/700: 100%|██████████| 18/18 [00:00<00:00, 1078.49it/s]
Epoch 400/700: 100%|██████████| 18/18 [00:00<00:00, 1118.66it/s]
Epoch 401/700: 100%|██████████| 18/18 [00:00<00:00, 1125.12it/s]


Epoch [401/700], Training Loss: 0.1331


Epoch 402/700: 100%|██████████| 18/18 [00:00<00:00, 1115.80it/s]
Epoch 403/700: 100%|██████████| 18/18 [00:00<00:00, 1195.64it/s]
Epoch 404/700: 100%|██████████| 18/18 [00:00<00:00, 751.23it/s]
Epoch 405/700: 100%|██████████| 18/18 [00:00<00:00, 883.83it/s]
Epoch 406/700: 100%|██████████| 18/18 [00:00<00:00, 738.66it/s]
Epoch 407/700: 100%|██████████| 18/18 [00:00<00:00, 750.25it/s]
Epoch 408/700: 100%|██████████| 18/18 [00:00<00:00, 748.43it/s]
Epoch 409/700: 100%|██████████| 18/18 [00:00<00:00, 1091.16it/s]
Epoch 410/700: 100%|██████████| 18/18 [00:00<00:00, 736.18it/s]
Epoch 411/700: 100%|██████████| 18/18 [00:00<00:00, 1098.98it/s]


Epoch [411/700], Training Loss: 0.1364


Epoch 412/700: 100%|██████████| 18/18 [00:00<00:00, 746.63it/s]
Epoch 413/700: 100%|██████████| 18/18 [00:00<00:00, 735.52it/s]
Epoch 414/700: 100%|██████████| 18/18 [00:00<00:00, 1122.87it/s]
Epoch 415/700: 100%|██████████| 18/18 [00:00<00:00, 735.29it/s]
Epoch 416/700: 100%|██████████| 18/18 [00:00<00:00, 718.44it/s]
Epoch 417/700: 100%|██████████| 18/18 [00:00<00:00, 741.43it/s]
Epoch 418/700: 100%|██████████| 18/18 [00:00<00:00, 722.35it/s]
Epoch 419/700: 100%|██████████| 18/18 [00:00<00:00, 1016.34it/s]
Epoch 420/700: 100%|██████████| 18/18 [00:00<00:00, 746.66it/s]
Epoch 421/700: 100%|██████████| 18/18 [00:00<00:00, 733.50it/s]


Epoch [421/700], Training Loss: 0.1349


Epoch 422/700: 100%|██████████| 18/18 [00:00<00:00, 320.64it/s]
Epoch 423/700: 100%|██████████| 18/18 [00:00<00:00, 745.65it/s]
Epoch 424/700: 100%|██████████| 18/18 [00:00<00:00, 1124.21it/s]
Epoch 425/700: 100%|██████████| 18/18 [00:00<00:00, 1079.77it/s]
Epoch 426/700: 100%|██████████| 18/18 [00:00<00:00, 1113.04it/s]
Epoch 427/700: 100%|██████████| 18/18 [00:00<00:00, 1107.18it/s]
Epoch 428/700: 100%|██████████| 18/18 [00:00<00:00, 749.88it/s]
Epoch 429/700: 100%|██████████| 18/18 [00:00<00:00, 1080.45it/s]
Epoch 430/700: 100%|██████████| 18/18 [00:00<00:00, 747.57it/s]
Epoch 431/700: 100%|██████████| 18/18 [00:00<00:00, 1125.01it/s]


Epoch [431/700], Training Loss: 0.1367


Epoch 432/700: 100%|██████████| 18/18 [00:00<00:00, 1115.98it/s]
Epoch 433/700: 100%|██████████| 18/18 [00:00<00:00, 558.80it/s]
Epoch 434/700: 100%|██████████| 18/18 [00:00<00:00, 1156.15it/s]
Epoch 435/700: 100%|██████████| 18/18 [00:00<00:00, 756.61it/s]
Epoch 436/700: 100%|██████████| 18/18 [00:00<00:00, 738.76it/s]
Epoch 437/700: 100%|██████████| 18/18 [00:00<00:00, 877.71it/s]
Epoch 438/700: 100%|██████████| 18/18 [00:00<00:00, 744.40it/s]
Epoch 439/700: 100%|██████████| 18/18 [00:00<00:00, 746.95it/s]
Epoch 440/700: 100%|██████████| 18/18 [00:00<00:00, 728.47it/s]
Epoch 441/700: 100%|██████████| 18/18 [00:00<00:00, 1082.90it/s]


Epoch [441/700], Training Loss: 0.1368


Epoch 442/700: 100%|██████████| 18/18 [00:00<00:00, 1123.19it/s]
Epoch 443/700: 100%|██████████| 18/18 [00:00<00:00, 552.98it/s]
Epoch 444/700: 100%|██████████| 18/18 [00:00<00:00, 1125.05it/s]
Epoch 445/700: 100%|██████████| 18/18 [00:00<00:00, 745.46it/s]
Epoch 446/700: 100%|██████████| 18/18 [00:00<00:00, 735.47it/s]
Epoch 447/700: 100%|██████████| 18/18 [00:00<00:00, 1194.07it/s]
Epoch 448/700: 100%|██████████| 18/18 [00:00<00:00, 1194.15it/s]
Epoch 449/700: 100%|██████████| 18/18 [00:00<00:00, 747.34it/s]
Epoch 450/700: 100%|██████████| 18/18 [00:00<00:00, 1072.82it/s]
Epoch 451/700: 100%|██████████| 18/18 [00:00<00:00, 548.91it/s]


Epoch [451/700], Training Loss: 0.1344


Epoch 452/700: 100%|██████████| 18/18 [00:00<00:00, 527.48it/s]
Epoch 453/700: 100%|██████████| 18/18 [00:00<00:00, 1120.41it/s]
Epoch 454/700: 100%|██████████| 18/18 [00:00<00:00, 1127.72it/s]
Epoch 455/700: 100%|██████████| 18/18 [00:00<00:00, 1093.48it/s]
Epoch 456/700: 100%|██████████| 18/18 [00:00<00:00, 1111.25it/s]
Epoch 457/700: 100%|██████████| 18/18 [00:00<00:00, 741.04it/s]
Epoch 458/700: 100%|██████████| 18/18 [00:00<00:00, 737.18it/s]
Epoch 459/700: 100%|██████████| 18/18 [00:00<00:00, 748.89it/s]
Epoch 460/700: 100%|██████████| 18/18 [00:00<00:00, 1098.67it/s]
Epoch 461/700: 100%|██████████| 18/18 [00:00<00:00, 1085.81it/s]


Epoch [461/700], Training Loss: 0.1384


Epoch 462/700: 100%|██████████| 18/18 [00:00<00:00, 734.56it/s]
Epoch 463/700: 100%|██████████| 18/18 [00:00<00:00, 617.24it/s]
Epoch 464/700: 100%|██████████| 18/18 [00:00<00:00, 1057.18it/s]
Epoch 465/700: 100%|██████████| 18/18 [00:00<00:00, 946.83it/s]
Epoch 466/700: 100%|██████████| 18/18 [00:00<00:00, 854.37it/s]
Epoch 467/700: 100%|██████████| 18/18 [00:00<00:00, 749.06it/s]
Epoch 468/700: 100%|██████████| 18/18 [00:00<00:00, 1115.69it/s]
Epoch 469/700: 100%|██████████| 18/18 [00:00<00:00, 696.27it/s]
Epoch 470/700: 100%|██████████| 18/18 [00:00<00:00, 781.25it/s]
Epoch 471/700: 100%|██████████| 18/18 [00:00<00:00, 739.30it/s]


Epoch [471/700], Training Loss: 0.1523


Epoch 472/700: 100%|██████████| 18/18 [00:00<00:00, 738.97it/s]
Epoch 473/700: 100%|██████████| 18/18 [00:00<00:00, 744.81it/s]
Epoch 474/700: 100%|██████████| 18/18 [00:00<00:00, 554.97it/s]
Epoch 475/700: 100%|██████████| 18/18 [00:00<00:00, 744.93it/s]
Epoch 476/700: 100%|██████████| 18/18 [00:00<00:00, 747.00it/s]
Epoch 477/700: 100%|██████████| 18/18 [00:00<00:00, 737.95it/s]
Epoch 478/700: 100%|██████████| 18/18 [00:00<00:00, 562.07it/s]
Epoch 479/700: 100%|██████████| 18/18 [00:00<00:00, 447.43it/s]
Epoch 480/700: 100%|██████████| 18/18 [00:00<00:00, 736.81it/s]
Epoch 481/700: 100%|██████████| 18/18 [00:00<00:00, 1061.61it/s]


Epoch [481/700], Training Loss: 0.1365


Epoch 482/700: 100%|██████████| 18/18 [00:00<00:00, 1109.26it/s]
Epoch 483/700: 100%|██████████| 18/18 [00:00<00:00, 714.95it/s]
Epoch 484/700: 100%|██████████| 18/18 [00:00<00:00, 1027.06it/s]
Epoch 485/700: 100%|██████████| 18/18 [00:00<00:00, 1124.83it/s]
Epoch 486/700: 100%|██████████| 18/18 [00:00<00:00, 1126.98it/s]
Epoch 487/700: 100%|██████████| 18/18 [00:00<00:00, 747.77it/s]
Epoch 488/700: 100%|██████████| 18/18 [00:00<00:00, 1116.56it/s]
Epoch 489/700: 100%|██████████| 18/18 [00:00<00:00, 735.25it/s]
Epoch 490/700: 100%|██████████| 18/18 [00:00<00:00, 1116.23it/s]
Epoch 491/700: 100%|██████████| 18/18 [00:00<00:00, 749.32it/s]


Epoch [491/700], Training Loss: 0.1317


Epoch 492/700: 100%|██████████| 18/18 [00:00<00:00, 1047.05it/s]
Epoch 493/700: 100%|██████████| 18/18 [00:00<00:00, 748.39it/s]
Epoch 494/700: 100%|██████████| 18/18 [00:00<00:00, 1099.84it/s]
Epoch 495/700: 100%|██████████| 18/18 [00:00<00:00, 749.41it/s]
Epoch 496/700: 100%|██████████| 18/18 [00:00<00:00, 1125.00it/s]
Epoch 497/700: 100%|██████████| 18/18 [00:00<00:00, 1118.81it/s]
Epoch 498/700: 100%|██████████| 18/18 [00:00<00:00, 750.38it/s]
Epoch 499/700: 100%|██████████| 18/18 [00:00<00:00, 1089.70it/s]
Epoch 500/700: 100%|██████████| 18/18 [00:00<00:00, 764.28it/s]
Epoch 501/700: 100%|██████████| 18/18 [00:00<00:00, 736.90it/s]


Epoch [501/700], Training Loss: 0.1316


Epoch 502/700: 100%|██████████| 18/18 [00:00<00:00, 747.50it/s]
Epoch 503/700: 100%|██████████| 18/18 [00:00<00:00, 734.74it/s]
Epoch 504/700: 100%|██████████| 18/18 [00:00<00:00, 1122.19it/s]
Epoch 505/700: 100%|██████████| 18/18 [00:00<00:00, 748.20it/s]
Epoch 506/700: 100%|██████████| 18/18 [00:00<00:00, 448.49it/s]
Epoch 507/700: 100%|██████████| 18/18 [00:00<00:00, 764.24it/s]
Epoch 508/700: 100%|██████████| 18/18 [00:00<00:00, 812.67it/s]
Epoch 509/700: 100%|██████████| 18/18 [00:00<00:00, 1121.65it/s]
Epoch 510/700: 100%|██████████| 18/18 [00:00<00:00, 750.70it/s]
Epoch 511/700: 100%|██████████| 18/18 [00:00<00:00, 737.43it/s]


Epoch [511/700], Training Loss: 0.1329


Epoch 512/700: 100%|██████████| 18/18 [00:00<00:00, 1126.46it/s]
Epoch 513/700: 100%|██████████| 18/18 [00:00<00:00, 557.19it/s]
Epoch 514/700: 100%|██████████| 18/18 [00:00<00:00, 750.35it/s]
Epoch 515/700: 100%|██████████| 18/18 [00:00<00:00, 735.23it/s]
Epoch 516/700: 100%|██████████| 18/18 [00:00<00:00, 746.71it/s]
Epoch 517/700: 100%|██████████| 18/18 [00:00<00:00, 945.60it/s]
Epoch 518/700: 100%|██████████| 18/18 [00:00<00:00, 562.53it/s]
Epoch 519/700: 100%|██████████| 18/18 [00:00<00:00, 559.96it/s]
Epoch 520/700: 100%|██████████| 18/18 [00:00<00:00, 557.47it/s]
Epoch 521/700: 100%|██████████| 18/18 [00:00<00:00, 752.85it/s]


Epoch [521/700], Training Loss: 0.1345


Epoch 522/700: 100%|██████████| 18/18 [00:00<00:00, 1102.38it/s]
Epoch 523/700: 100%|██████████| 18/18 [00:00<00:00, 550.64it/s]
Epoch 524/700: 100%|██████████| 18/18 [00:00<00:00, 749.41it/s]
Epoch 525/700: 100%|██████████| 18/18 [00:00<00:00, 726.49it/s]
Epoch 526/700: 100%|██████████| 18/18 [00:00<00:00, 1028.30it/s]
Epoch 527/700: 100%|██████████| 18/18 [00:00<00:00, 1133.78it/s]
Epoch 528/700: 100%|██████████| 18/18 [00:00<00:00, 742.21it/s]
Epoch 529/700: 100%|██████████| 18/18 [00:00<00:00, 751.93it/s]
Epoch 530/700: 100%|██████████| 18/18 [00:00<00:00, 742.57it/s]
Epoch 531/700: 100%|██████████| 18/18 [00:00<00:00, 918.26it/s]


Epoch [531/700], Training Loss: 0.1326


Epoch 532/700: 100%|██████████| 18/18 [00:00<00:00, 1096.68it/s]
Epoch 533/700: 100%|██████████| 18/18 [00:00<00:00, 747.01it/s]
Epoch 534/700: 100%|██████████| 18/18 [00:00<00:00, 449.20it/s]
Epoch 535/700: 100%|██████████| 18/18 [00:00<00:00, 1116.40it/s]
Epoch 536/700: 100%|██████████| 18/18 [00:00<00:00, 1116.96it/s]
Epoch 537/700: 100%|██████████| 18/18 [00:00<00:00, 751.44it/s]
Epoch 538/700: 100%|██████████| 18/18 [00:00<00:00, 1124.88it/s]
Epoch 539/700: 100%|██████████| 18/18 [00:00<00:00, 967.02it/s]
Epoch 540/700: 100%|██████████| 18/18 [00:00<00:00, 833.61it/s]
Epoch 541/700: 100%|██████████| 18/18 [00:00<00:00, 739.81it/s]


Epoch [541/700], Training Loss: 0.1349


Epoch 542/700: 100%|██████████| 18/18 [00:00<00:00, 747.90it/s]
Epoch 543/700: 100%|██████████| 18/18 [00:00<00:00, 749.38it/s]
Epoch 544/700: 100%|██████████| 18/18 [00:00<00:00, 992.08it/s]
Epoch 545/700: 100%|██████████| 18/18 [00:00<00:00, 764.36it/s]
Epoch 546/700: 100%|██████████| 18/18 [00:00<00:00, 1124.61it/s]
Epoch 547/700: 100%|██████████| 18/18 [00:00<00:00, 1132.75it/s]
Epoch 548/700: 100%|██████████| 18/18 [00:00<00:00, 1118.93it/s]
Epoch 549/700: 100%|██████████| 18/18 [00:00<00:00, 749.12it/s]
Epoch 550/700: 100%|██████████| 18/18 [00:00<00:00, 750.36it/s]
Epoch 551/700: 100%|██████████| 18/18 [00:00<00:00, 748.94it/s]


Epoch [551/700], Training Loss: 0.1328


Epoch 552/700: 100%|██████████| 18/18 [00:00<00:00, 744.03it/s]
Epoch 553/700: 100%|██████████| 18/18 [00:00<00:00, 748.94it/s]
Epoch 554/700: 100%|██████████| 18/18 [00:00<00:00, 730.67it/s]
Epoch 555/700: 100%|██████████| 18/18 [00:00<00:00, 1155.17it/s]
Epoch 556/700: 100%|██████████| 18/18 [00:00<00:00, 1052.21it/s]
Epoch 557/700: 100%|██████████| 18/18 [00:00<00:00, 1094.39it/s]
Epoch 558/700: 100%|██████████| 18/18 [00:00<00:00, 1124.61it/s]
Epoch 559/700: 100%|██████████| 18/18 [00:00<00:00, 734.28it/s]
Epoch 560/700: 100%|██████████| 18/18 [00:00<00:00, 745.55it/s]
Epoch 561/700: 100%|██████████| 18/18 [00:00<00:00, 749.94it/s]


Epoch [561/700], Training Loss: 0.1344


Epoch 562/700: 100%|██████████| 18/18 [00:00<00:00, 320.24it/s]
Epoch 563/700: 100%|██████████| 18/18 [00:00<00:00, 559.29it/s]
Epoch 564/700: 100%|██████████| 18/18 [00:00<00:00, 1122.84it/s]
Epoch 565/700: 100%|██████████| 18/18 [00:00<00:00, 1157.60it/s]
Epoch 566/700: 100%|██████████| 18/18 [00:00<00:00, 812.34it/s]
Epoch 567/700: 100%|██████████| 18/18 [00:00<00:00, 742.62it/s]
Epoch 568/700: 100%|██████████| 18/18 [00:00<00:00, 548.07it/s]
Epoch 569/700: 100%|██████████| 18/18 [00:00<00:00, 816.35it/s]
Epoch 570/700: 100%|██████████| 18/18 [00:00<00:00, 749.64it/s]
Epoch 571/700: 100%|██████████| 18/18 [00:00<00:00, 562.03it/s]


Epoch [571/700], Training Loss: 0.1373


Epoch 572/700: 100%|██████████| 18/18 [00:00<00:00, 730.27it/s]
Epoch 573/700: 100%|██████████| 18/18 [00:00<00:00, 583.78it/s]
Epoch 574/700: 100%|██████████| 18/18 [00:00<00:00, 723.23it/s]
Epoch 575/700: 100%|██████████| 18/18 [00:00<00:00, 636.74it/s]
Epoch 576/700: 100%|██████████| 18/18 [00:00<00:00, 747.26it/s]
Epoch 577/700: 100%|██████████| 18/18 [00:00<00:00, 750.42it/s]
Epoch 578/700: 100%|██████████| 18/18 [00:00<00:00, 1110.27it/s]
Epoch 579/700: 100%|██████████| 18/18 [00:00<00:00, 1156.75it/s]
Epoch 580/700: 100%|██████████| 18/18 [00:00<00:00, 1120.99it/s]
Epoch 581/700: 100%|██████████| 18/18 [00:00<00:00, 748.52it/s]


Epoch [581/700], Training Loss: 0.1378


Epoch 582/700: 100%|██████████| 18/18 [00:00<00:00, 1082.48it/s]
Epoch 583/700: 100%|██████████| 18/18 [00:00<00:00, 740.11it/s]
Epoch 584/700: 100%|██████████| 18/18 [00:00<00:00, 1125.42it/s]
Epoch 585/700: 100%|██████████| 18/18 [00:00<00:00, 734.99it/s]
Epoch 586/700: 100%|██████████| 18/18 [00:00<00:00, 921.07it/s]
Epoch 587/700: 100%|██████████| 18/18 [00:00<00:00, 964.68it/s]
Epoch 588/700: 100%|██████████| 18/18 [00:00<00:00, 1118.65it/s]
Epoch 589/700: 100%|██████████| 18/18 [00:00<00:00, 1130.39it/s]
Epoch 590/700: 100%|██████████| 18/18 [00:00<00:00, 560.59it/s]
Epoch 591/700: 100%|██████████| 18/18 [00:00<00:00, 559.88it/s]


Epoch [591/700], Training Loss: 0.1394


Epoch 592/700: 100%|██████████| 18/18 [00:00<00:00, 750.82it/s]
Epoch 593/700: 100%|██████████| 18/18 [00:00<00:00, 1126.02it/s]
Epoch 594/700: 100%|██████████| 18/18 [00:00<00:00, 737.16it/s]
Epoch 595/700: 100%|██████████| 18/18 [00:00<00:00, 1102.22it/s]
Epoch 596/700: 100%|██████████| 18/18 [00:00<00:00, 747.44it/s]
Epoch 597/700: 100%|██████████| 18/18 [00:00<00:00, 1072.79it/s]
Epoch 598/700: 100%|██████████| 18/18 [00:00<00:00, 764.23it/s]
Epoch 599/700: 100%|██████████| 18/18 [00:00<00:00, 1123.09it/s]
Epoch 600/700: 100%|██████████| 18/18 [00:00<00:00, 1233.12it/s]
Epoch 601/700: 100%|██████████| 18/18 [00:00<00:00, 1103.99it/s]


Epoch [601/700], Training Loss: 0.1413


Epoch 602/700: 100%|██████████| 18/18 [00:00<00:00, 737.34it/s]
Epoch 603/700: 100%|██████████| 18/18 [00:00<00:00, 730.34it/s]
Epoch 604/700: 100%|██████████| 18/18 [00:00<00:00, 724.47it/s]
Epoch 605/700: 100%|██████████| 18/18 [00:00<00:00, 1123.21it/s]
Epoch 606/700: 100%|██████████| 18/18 [00:00<00:00, 747.56it/s]
Epoch 607/700: 100%|██████████| 18/18 [00:00<00:00, 738.17it/s]
Epoch 608/700: 100%|██████████| 18/18 [00:00<00:00, 548.50it/s]
Epoch 609/700: 100%|██████████| 18/18 [00:00<00:00, 1102.85it/s]
Epoch 610/700: 100%|██████████| 18/18 [00:00<00:00, 1098.53it/s]
Epoch 611/700: 100%|██████████| 18/18 [00:00<00:00, 1125.48it/s]


Epoch [611/700], Training Loss: 0.1358


Epoch 612/700: 100%|██████████| 18/18 [00:00<00:00, 747.18it/s]
Epoch 613/700: 100%|██████████| 18/18 [00:00<00:00, 1121.80it/s]
Epoch 614/700: 100%|██████████| 18/18 [00:00<00:00, 746.28it/s]
Epoch 615/700: 100%|██████████| 18/18 [00:00<00:00, 1120.59it/s]
Epoch 616/700: 100%|██████████| 18/18 [00:00<00:00, 1115.44it/s]
Epoch 617/700: 100%|██████████| 18/18 [00:00<00:00, 1197.02it/s]
Epoch 618/700: 100%|██████████| 18/18 [00:00<00:00, 449.56it/s]
Epoch 619/700: 100%|██████████| 18/18 [00:00<00:00, 561.53it/s]
Epoch 620/700: 100%|██████████| 18/18 [00:00<00:00, 781.28it/s]
Epoch 621/700: 100%|██████████| 18/18 [00:00<00:00, 1053.46it/s]


Epoch [621/700], Training Loss: 0.1331


Epoch 622/700: 100%|██████████| 18/18 [00:00<00:00, 745.10it/s]
Epoch 623/700: 100%|██████████| 18/18 [00:00<00:00, 720.86it/s]
Epoch 624/700: 100%|██████████| 18/18 [00:00<00:00, 749.09it/s]
Epoch 625/700: 100%|██████████| 18/18 [00:00<00:00, 750.24it/s]
Epoch 626/700: 100%|██████████| 18/18 [00:00<00:00, 1124.80it/s]
Epoch 627/700: 100%|██████████| 18/18 [00:00<00:00, 740.60it/s]
Epoch 628/700: 100%|██████████| 18/18 [00:00<00:00, 888.69it/s]
Epoch 629/700: 100%|██████████| 18/18 [00:00<00:00, 750.86it/s]
Epoch 630/700: 100%|██████████| 18/18 [00:00<00:00, 665.50it/s]
Epoch 631/700: 100%|██████████| 18/18 [00:00<00:00, 1118.86it/s]


Epoch [631/700], Training Loss: 0.1431


Epoch 632/700: 100%|██████████| 18/18 [00:00<00:00, 748.92it/s]
Epoch 633/700: 100%|██████████| 18/18 [00:00<00:00, 1092.88it/s]
Epoch 634/700: 100%|██████████| 18/18 [00:00<00:00, 1119.56it/s]
Epoch 635/700: 100%|██████████| 18/18 [00:00<00:00, 1118.80it/s]
Epoch 636/700: 100%|██████████| 18/18 [00:00<00:00, 748.52it/s]
Epoch 637/700: 100%|██████████| 18/18 [00:00<00:00, 751.56it/s]
Epoch 638/700: 100%|██████████| 18/18 [00:00<00:00, 1103.67it/s]
Epoch 639/700: 100%|██████████| 18/18 [00:00<00:00, 1071.54it/s]
Epoch 640/700: 100%|██████████| 18/18 [00:00<00:00, 754.42it/s]
Epoch 641/700: 100%|██████████| 18/18 [00:00<00:00, 747.81it/s]


Epoch [641/700], Training Loss: 0.1358


Epoch 642/700: 100%|██████████| 18/18 [00:00<00:00, 749.30it/s]
Epoch 643/700: 100%|██████████| 18/18 [00:00<00:00, 712.99it/s]
Epoch 644/700: 100%|██████████| 18/18 [00:00<00:00, 1123.54it/s]
Epoch 645/700: 100%|██████████| 18/18 [00:00<00:00, 734.35it/s]
Epoch 646/700: 100%|██████████| 18/18 [00:00<00:00, 373.56it/s]
Epoch 647/700: 100%|██████████| 18/18 [00:00<00:00, 1157.79it/s]
Epoch 648/700: 100%|██████████| 18/18 [00:00<00:00, 813.03it/s]
Epoch 649/700: 100%|██████████| 18/18 [00:00<00:00, 1124.80it/s]
Epoch 650/700: 100%|██████████| 18/18 [00:00<00:00, 733.49it/s]
Epoch 651/700: 100%|██████████| 18/18 [00:00<00:00, 748.19it/s]


Epoch [651/700], Training Loss: 0.1600


Epoch 652/700: 100%|██████████| 18/18 [00:00<00:00, 751.07it/s]
Epoch 653/700: 100%|██████████| 18/18 [00:00<00:00, 750.03it/s]
Epoch 654/700: 100%|██████████| 18/18 [00:00<00:00, 786.07it/s]
Epoch 655/700: 100%|██████████| 18/18 [00:00<00:00, 1098.02it/s]
Epoch 656/700: 100%|██████████| 18/18 [00:00<00:00, 736.14it/s]
Epoch 657/700: 100%|██████████| 18/18 [00:00<00:00, 1123.46it/s]
Epoch 658/700: 100%|██████████| 18/18 [00:00<00:00, 1126.09it/s]
Epoch 659/700: 100%|██████████| 18/18 [00:00<00:00, 1153.51it/s]
Epoch 660/700: 100%|██████████| 18/18 [00:00<00:00, 748.84it/s]
Epoch 661/700: 100%|██████████| 18/18 [00:00<00:00, 749.10it/s]


Epoch [661/700], Training Loss: 0.1370


Epoch 662/700: 100%|██████████| 18/18 [00:00<00:00, 735.33it/s]
Epoch 663/700: 100%|██████████| 18/18 [00:00<00:00, 747.34it/s]
Epoch 664/700: 100%|██████████| 18/18 [00:00<00:00, 733.98it/s]
Epoch 665/700: 100%|██████████| 18/18 [00:00<00:00, 751.14it/s]
Epoch 666/700: 100%|██████████| 18/18 [00:00<00:00, 1104.81it/s]
Epoch 667/700: 100%|██████████| 18/18 [00:00<00:00, 734.84it/s]
Epoch 668/700: 100%|██████████| 18/18 [00:00<00:00, 1119.29it/s]
Epoch 669/700: 100%|██████████| 18/18 [00:00<00:00, 736.88it/s]
Epoch 670/700: 100%|██████████| 18/18 [00:00<00:00, 815.22it/s]
Epoch 671/700: 100%|██████████| 18/18 [00:00<00:00, 745.01it/s]


Epoch [671/700], Training Loss: 0.1355


Epoch 672/700: 100%|██████████| 18/18 [00:00<00:00, 749.22it/s]
Epoch 673/700: 100%|██████████| 18/18 [00:00<00:00, 1121.74it/s]
Epoch 674/700: 100%|██████████| 18/18 [00:00<00:00, 379.05it/s]
Epoch 675/700: 100%|██████████| 18/18 [00:00<00:00, 724.84it/s]
Epoch 676/700: 100%|██████████| 18/18 [00:00<00:00, 1122.94it/s]
Epoch 677/700: 100%|██████████| 18/18 [00:00<00:00, 1124.76it/s]
Epoch 678/700: 100%|██████████| 18/18 [00:00<00:00, 745.53it/s]
Epoch 679/700: 100%|██████████| 18/18 [00:00<00:00, 738.43it/s]
Epoch 680/700: 100%|██████████| 18/18 [00:00<00:00, 750.46it/s]
Epoch 681/700: 100%|██████████| 18/18 [00:00<00:00, 1123.81it/s]


Epoch [681/700], Training Loss: 0.1338


Epoch 682/700: 100%|██████████| 18/18 [00:00<00:00, 1111.14it/s]
Epoch 683/700: 100%|██████████| 18/18 [00:00<00:00, 749.45it/s]
Epoch 684/700: 100%|██████████| 18/18 [00:00<00:00, 1127.37it/s]
Epoch 685/700: 100%|██████████| 18/18 [00:00<00:00, 1121.04it/s]
Epoch 686/700: 100%|██████████| 18/18 [00:00<00:00, 1119.56it/s]
Epoch 687/700: 100%|██████████| 18/18 [00:00<00:00, 748.53it/s]
Epoch 688/700: 100%|██████████| 18/18 [00:00<00:00, 460.47it/s]
Epoch 689/700: 100%|██████████| 18/18 [00:00<00:00, 1069.87it/s]
Epoch 690/700: 100%|██████████| 18/18 [00:00<00:00, 1113.71it/s]
Epoch 691/700: 100%|██████████| 18/18 [00:00<00:00, 748.09it/s]


Epoch [691/700], Training Loss: 0.1422


Epoch 692/700: 100%|██████████| 18/18 [00:00<00:00, 1073.11it/s]
Epoch 693/700: 100%|██████████| 18/18 [00:00<00:00, 538.58it/s]
Epoch 694/700: 100%|██████████| 18/18 [00:00<00:00, 780.33it/s]
Epoch 695/700: 100%|██████████| 18/18 [00:00<00:00, 1119.73it/s]
Epoch 696/700: 100%|██████████| 18/18 [00:00<00:00, 752.21it/s]
Epoch 697/700: 100%|██████████| 18/18 [00:00<00:00, 745.88it/s]
Epoch 698/700: 100%|██████████| 18/18 [00:00<00:00, 748.97it/s]
Epoch 699/700: 100%|██████████| 18/18 [00:00<00:00, 747.30it/s]
Epoch 700/700: 100%|██████████| 18/18 [00:00<00:00, 731.81it/s]

Validation Loss: 0.1741, Accuracy: 46.85%
Final Validation Loss with MSE: 0.1741
Final Validation Accuracy with MSE: 46.85%





### 3. NLLLoss (`torch.nn.NLLLoss`)
- **Description:** Negative Log-Likelihood Loss measures the likelihood of the target class under the predicted probability distribution.
- **Use Case:** Typically used in multi-class classification tasks, especially when combined with `log_softmax` activation.

Here is the mathematical formulation of NLLLoss:
\begin{equation}
\text{NLLLoss} = -\frac{1}{n} \sum_{i=1}^{n} \log(y_{i})
\end{equation}

I hope you note the logarithm in the formula. It's important! 

Why?

In this part, run your training with Relu at last layer. <span style="color:red; font-weight: bold;">Discuss </span> and explain the difference between the results of the two models. Find a proper solution to the problem.


In [82]:
# Apply sigmoid activation for BCELoss
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers=1, last_layer_activation_fn=nn.ReLU):
        super(SimpleMLP, self).__init__()
        
        layers = []
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.ReLU())

        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        
        layers.append(nn.Linear(hidden_dim, output_dim))
        
        if last_layer_activation_fn is not None:
            layers.append(last_layer_activation_fn())
        
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)



class SimpleMLPTrainer:
    def __init__(self, model, criterion, optimizer):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def train(self, train_loader, num_epochs):
        training_losses = []

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            self.model.train()
            for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                
                # Apply sigmoid activation for BCELoss
                probs = torch.sigmoid(outputs)
                
                # Calculate BCELoss
                loss = self.criterion(probs, targets)
                
                loss.backward()
                self.optimizer.step()
                
                epoch_loss += loss.item() * inputs.size(0)

            epoch_loss /= len(train_loader.dataset)
            training_losses.append(epoch_loss)
            if (epoch % 10 == 0):
                print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}")

        return training_losses

    def evaluate(self, val_loader):
        self.model.eval()
        val_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = self.model(inputs)
                
                # Apply sigmoid activation for BCELoss
                probs = torch.sigmoid(outputs)
                
                loss = self.criterion(probs, targets)
                val_loss += loss.item() * inputs.size(0)

                predictions = (probs >= 0.5).float()
                correct_predictions += (predictions == targets).sum().item()
                total_predictions += targets.size(0)

        val_loss /= len(val_loader.dataset)
        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

        print(f"Final Validation Loss with BCELoss: {val_loss:.4f}")
        print(f"Final Validation Accuracy with BCELoss: {accuracy * 100:.2f}%")
        return val_loss, accuracy

In [83]:
# Run with relu activation function
from torch.nn import ReLU, NLLLoss, CrossEntropyLoss
# TODO: Train the model
# TODO: Evaluate the model

criterion = nn.BCELoss()  
optimizer = optim.Adam(model.parameters(), lr=0.01)

input_dim = X_train.shape[1]
hidden_dim = 16
output_dim = 1  
model = SimpleMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)
trainer = SimpleMLPTrainer(model, criterion, optimizer)

training_losses = trainer.train(train_loader, num_epochs=100)
val_loss, val_accuracy = trainer.evaluate(val_loader)



Epoch 1/100: 100%|██████████| 18/18 [00:00<00:00, 1128.73it/s]


Epoch [1/100], Training Loss: 0.9708


Epoch 2/100: 100%|██████████| 18/18 [00:00<00:00, 1124.50it/s]
Epoch 3/100: 100%|██████████| 18/18 [00:00<00:00, 1124.80it/s]
Epoch 4/100: 100%|██████████| 18/18 [00:00<00:00, 2149.15it/s]
Epoch 5/100: 100%|██████████| 18/18 [00:00<00:00, 2071.32it/s]
Epoch 6/100: 100%|██████████| 18/18 [00:00<00:00, 1085.30it/s]
Epoch 7/100: 100%|██████████| 18/18 [00:00<00:00, 734.05it/s]
Epoch 8/100: 100%|██████████| 18/18 [00:00<00:00, 1098.53it/s]
Epoch 9/100: 100%|██████████| 18/18 [00:00<00:00, 1097.94it/s]
Epoch 10/100: 100%|██████████| 18/18 [00:00<00:00, 2930.35it/s]
Epoch 11/100: 100%|██████████| 18/18 [00:00<00:00, 1070.39it/s]


Epoch [11/100], Training Loss: 0.9708


Epoch 12/100: 100%|██████████| 18/18 [00:00<00:00, 1616.51it/s]
Epoch 13/100: 100%|██████████| 18/18 [00:00<00:00, 722.07it/s]
Epoch 14/100: 100%|██████████| 18/18 [00:00<00:00, 1092.54it/s]
Epoch 15/100: 100%|██████████| 18/18 [00:00<00:00, 910.99it/s]
Epoch 16/100: 100%|██████████| 18/18 [00:00<00:00, 1286.90it/s]
Epoch 17/100: 100%|██████████| 18/18 [00:00<00:00, 2337.74it/s]
Epoch 18/100: 100%|██████████| 18/18 [00:00<00:00, 2357.60it/s]
Epoch 19/100: 100%|██████████| 18/18 [00:00<00:00, 1069.48it/s]
Epoch 20/100: 100%|██████████| 18/18 [00:00<00:00, 1097.84it/s]
Epoch 21/100: 100%|██████████| 18/18 [00:00<00:00, 1142.02it/s]


Epoch [21/100], Training Loss: 0.9708


Epoch 22/100: 100%|██████████| 18/18 [00:00<00:00, 2170.78it/s]
Epoch 23/100: 100%|██████████| 18/18 [00:00<00:00, 2157.07it/s]
Epoch 24/100: 100%|██████████| 18/18 [00:00<00:00, 2074.11it/s]
Epoch 25/100: 100%|██████████| 18/18 [00:00<00:00, 1086.11it/s]
Epoch 26/100: 100%|██████████| 18/18 [00:00<00:00, 1094.82it/s]
Epoch 27/100: 100%|██████████| 18/18 [00:00<00:00, 1262.31it/s]
Epoch 28/100: 100%|██████████| 18/18 [00:00<00:00, 1039.05it/s]
Epoch 29/100: 100%|██████████| 18/18 [00:00<00:00, 965.63it/s]
Epoch 30/100: 100%|██████████| 18/18 [00:00<00:00, 2257.02it/s]
Epoch 31/100: 100%|██████████| 18/18 [00:00<00:00, 1999.77it/s]


Epoch [31/100], Training Loss: 0.9708


Epoch 32/100: 100%|██████████| 18/18 [00:00<00:00, 1074.04it/s]
Epoch 33/100: 100%|██████████| 18/18 [00:00<00:00, 1060.51it/s]
Epoch 34/100: 100%|██████████| 18/18 [00:00<00:00, 1086.09it/s]
Epoch 35/100: 100%|██████████| 18/18 [00:00<00:00, 742.74it/s]
Epoch 36/100: 100%|██████████| 18/18 [00:00<00:00, 748.57it/s]
Epoch 37/100: 100%|██████████| 18/18 [00:00<00:00, 1081.50it/s]
Epoch 38/100: 100%|██████████| 18/18 [00:00<00:00, 529.51it/s]
Epoch 39/100: 100%|██████████| 18/18 [00:00<00:00, 578.52it/s]
Epoch 40/100: 100%|██████████| 18/18 [00:00<00:00, 553.00it/s]
Epoch 41/100: 100%|██████████| 18/18 [00:00<00:00, 1104.91it/s]


Epoch [41/100], Training Loss: 0.9708


Epoch 42/100: 100%|██████████| 18/18 [00:00<00:00, 2233.98it/s]
Epoch 43/100: 100%|██████████| 18/18 [00:00<00:00, 1119.26it/s]
Epoch 44/100: 100%|██████████| 18/18 [00:00<00:00, 1102.07it/s]
Epoch 45/100: 100%|██████████| 18/18 [00:00<00:00, 1108.02it/s]
Epoch 46/100: 100%|██████████| 18/18 [00:00<00:00, 1096.10it/s]
Epoch 47/100: 100%|██████████| 18/18 [00:00<00:00, 1086.20it/s]
Epoch 48/100: 100%|██████████| 18/18 [00:00<00:00, 1097.76it/s]
Epoch 49/100: 100%|██████████| 18/18 [00:00<00:00, 970.50it/s]
Epoch 50/100: 100%|██████████| 18/18 [00:00<00:00, 1109.15it/s]
Epoch 51/100: 100%|██████████| 18/18 [00:00<00:00, 1125.65it/s]


Epoch [51/100], Training Loss: 0.9708


Epoch 52/100: 100%|██████████| 18/18 [00:00<00:00, 721.11it/s]
Epoch 53/100: 100%|██████████| 18/18 [00:00<00:00, 1092.58it/s]
Epoch 54/100: 100%|██████████| 18/18 [00:00<00:00, 1112.53it/s]
Epoch 55/100: 100%|██████████| 18/18 [00:00<00:00, 719.33it/s]
Epoch 56/100: 100%|██████████| 18/18 [00:00<00:00, 1117.59it/s]
Epoch 57/100: 100%|██████████| 18/18 [00:00<00:00, 1124.63it/s]
Epoch 58/100: 100%|██████████| 18/18 [00:00<00:00, 738.63it/s]
Epoch 59/100: 100%|██████████| 18/18 [00:00<00:00, 740.59it/s]
Epoch 60/100: 100%|██████████| 18/18 [00:00<00:00, 749.09it/s]
Epoch 61/100: 100%|██████████| 18/18 [00:00<00:00, 1104.78it/s]


Epoch [61/100], Training Loss: 0.9708


Epoch 62/100: 100%|██████████| 18/18 [00:00<00:00, 1223.80it/s]
Epoch 63/100: 100%|██████████| 18/18 [00:00<00:00, 1066.92it/s]
Epoch 64/100: 100%|██████████| 18/18 [00:00<00:00, 751.72it/s]
Epoch 65/100: 100%|██████████| 18/18 [00:00<00:00, 986.29it/s]
Epoch 66/100: 100%|██████████| 18/18 [00:00<00:00, 747.92it/s]
Epoch 67/100: 100%|██████████| 18/18 [00:00<00:00, 541.88it/s]
Epoch 68/100: 100%|██████████| 18/18 [00:00<00:00, 1121.24it/s]
Epoch 69/100: 100%|██████████| 18/18 [00:00<00:00, 1109.00it/s]
Epoch 70/100: 100%|██████████| 18/18 [00:00<00:00, 1117.84it/s]
Epoch 71/100: 100%|██████████| 18/18 [00:00<00:00, 1088.55it/s]


Epoch [71/100], Training Loss: 0.9708


Epoch 72/100: 100%|██████████| 18/18 [00:00<00:00, 1098.29it/s]
Epoch 73/100: 100%|██████████| 18/18 [00:00<00:00, 1123.76it/s]
Epoch 74/100: 100%|██████████| 18/18 [00:00<00:00, 700.05it/s]
Epoch 75/100: 100%|██████████| 18/18 [00:00<00:00, 815.29it/s]
Epoch 76/100: 100%|██████████| 18/18 [00:00<00:00, 1117.34it/s]
Epoch 77/100: 100%|██████████| 18/18 [00:00<00:00, 1124.36it/s]
Epoch 78/100: 100%|██████████| 18/18 [00:00<00:00, 1123.39it/s]
Epoch 79/100: 100%|██████████| 18/18 [00:00<00:00, 748.72it/s]
Epoch 80/100: 100%|██████████| 18/18 [00:00<00:00, 901.93it/s]
Epoch 81/100: 100%|██████████| 18/18 [00:00<00:00, 1176.25it/s]


Epoch [81/100], Training Loss: 0.9708


Epoch 82/100: 100%|██████████| 18/18 [00:00<00:00, 724.72it/s]
Epoch 83/100: 100%|██████████| 18/18 [00:00<00:00, 1115.90it/s]
Epoch 84/100: 100%|██████████| 18/18 [00:00<00:00, 1031.30it/s]
Epoch 85/100: 100%|██████████| 18/18 [00:00<00:00, 1126.34it/s]
Epoch 86/100: 100%|██████████| 18/18 [00:00<00:00, 717.38it/s]
Epoch 87/100: 100%|██████████| 18/18 [00:00<00:00, 1135.74it/s]
Epoch 88/100: 100%|██████████| 18/18 [00:00<00:00, 2152.40it/s]
Epoch 89/100: 100%|██████████| 18/18 [00:00<00:00, 1103.99it/s]
Epoch 90/100: 100%|██████████| 18/18 [00:00<00:00, 1149.11it/s]
Epoch 91/100: 100%|██████████| 18/18 [00:00<00:00, 745.07it/s]


Epoch [91/100], Training Loss: 0.9708


Epoch 92/100: 100%|██████████| 18/18 [00:00<00:00, 1127.94it/s]
Epoch 93/100: 100%|██████████| 18/18 [00:00<00:00, 1095.23it/s]
Epoch 94/100: 100%|██████████| 18/18 [00:00<00:00, 749.38it/s]
Epoch 95/100: 100%|██████████| 18/18 [00:00<00:00, 1422.66it/s]
Epoch 96/100: 100%|██████████| 18/18 [00:00<00:00, 559.12it/s]
Epoch 97/100: 100%|██████████| 18/18 [00:00<00:00, 722.64it/s]
Epoch 98/100: 100%|██████████| 18/18 [00:00<00:00, 704.43it/s]
Epoch 99/100: 100%|██████████| 18/18 [00:00<00:00, 1373.01it/s]
Epoch 100/100: 100%|██████████| 18/18 [00:00<00:00, 1128.44it/s]

Final Validation Loss with BCELoss: 0.8931
Final Validation Accuracy with BCELoss: 39.16%





In [None]:
# Run with LogSoftmax activation function
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers=1, last_layer_activation_fn=nn.LogSoftmax):
        super(SimpleMLP, self).__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, output_dim))
        
        if last_layer_activation_fn is not None:
            layers.append(last_layer_activation_fn(dim=1))  
            
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)
    

class SimpleMLPTrainer:
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def manual_nll_loss(self, log_probs, targets):
        # Manually compute NLLLoss
        if targets.dim() > 1:
            targets = targets.squeeze()
        targets = targets.long()  # Convert to LongTensor if needed
        batch_size = log_probs.size(0)
        target_log_probs = log_probs[range(batch_size), targets]  # log prob for correct class
        return -target_log_probs.mean()

    def train(self, train_loader, num_epochs):
        training_losses = []
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            self.model.train()
            for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                self.optimizer.zero_grad()
                log_probs = self.model(inputs) 
                loss = self.manual_nll_loss(log_probs, targets)  
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item() * inputs.size(0)
            epoch_loss /= len(train_loader.dataset)
            training_losses.append(epoch_loss)
            if (epoch % 10 == 0):
                print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}")
        return training_losses

    def evaluate(self, val_loader):
  
        self.model.eval()
        val_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
            
                targets = targets.squeeze().long()
                
                log_probs = self.model(inputs)
                loss = self.manual_nll_loss(log_probs, targets)
                val_loss += loss.item() * inputs.size(0)
                
                _, predicted = torch.max(log_probs, 1)
                
                correct_predictions += (predicted == targets).sum().item()
                total_predictions += targets.size(0)

        val_loss /= len(val_loader.dataset)
        
        accuracy = (correct_predictions / total_predictions) * 100 if total_predictions > 0 else 0
        print(f"Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%")
        return val_loss, accuracy





In [63]:
# Run with LogSoftmax activation function
from torch.nn import NLLLoss

# TODO: Train the model

# TODO: Evaluate the model
model = SimpleMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=2)  # 2 output nodes for binary classification
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training and evaluation
trainer = SimpleMLPTrainer(model, optimizer)

training_losses = trainer.train(train_loader, num_epochs=30)
val_loss, val_accuracy = trainer.evaluate(val_loader)

print(f"Final Validation Loss with manual NLLLoss: {val_loss:.4f}")
print(f"Final Validation Accuracy with manual NLLLoss: {val_accuracy :.2f}%")

Epoch 1/30: 100%|██████████| 18/18 [00:00<00:00, 562.34it/s]


Epoch [1/30], Training Loss: 1.6479


Epoch 2/30: 100%|██████████| 18/18 [00:00<00:00, 748.85it/s]
Epoch 3/30: 100%|██████████| 18/18 [00:00<00:00, 754.93it/s]
Epoch 4/30: 100%|██████████| 18/18 [00:00<00:00, 917.48it/s]
Epoch 5/30: 100%|██████████| 18/18 [00:00<00:00, 741.20it/s]
Epoch 6/30: 100%|██████████| 18/18 [00:00<00:00, 1084.42it/s]
Epoch 7/30: 100%|██████████| 18/18 [00:00<00:00, 719.18it/s]
Epoch 8/30: 100%|██████████| 18/18 [00:00<00:00, 1117.39it/s]
Epoch 9/30: 100%|██████████| 18/18 [00:00<00:00, 744.57it/s]
Epoch 10/30: 100%|██████████| 18/18 [00:00<00:00, 453.96it/s]
Epoch 11/30: 100%|██████████| 18/18 [00:00<00:00, 1115.06it/s]


Epoch [11/30], Training Loss: 0.6107


Epoch 12/30: 100%|██████████| 18/18 [00:00<00:00, 753.14it/s]
Epoch 13/30: 100%|██████████| 18/18 [00:00<00:00, 1125.20it/s]
Epoch 14/30: 100%|██████████| 18/18 [00:00<00:00, 2151.91it/s]
Epoch 15/30: 100%|██████████| 18/18 [00:00<00:00, 1105.82it/s]
Epoch 16/30: 100%|██████████| 18/18 [00:00<00:00, 743.99it/s]
Epoch 17/30: 100%|██████████| 18/18 [00:00<00:00, 733.34it/s]
Epoch 18/30: 100%|██████████| 18/18 [00:00<00:00, 749.60it/s]
Epoch 19/30: 100%|██████████| 18/18 [00:00<00:00, 1114.98it/s]
Epoch 20/30: 100%|██████████| 18/18 [00:00<00:00, 761.69it/s]
Epoch 21/30: 100%|██████████| 18/18 [00:00<00:00, 746.80it/s]


Epoch [21/30], Training Loss: 0.5735


Epoch 22/30: 100%|██████████| 18/18 [00:00<00:00, 549.70it/s]
Epoch 23/30: 100%|██████████| 18/18 [00:00<00:00, 742.44it/s]
Epoch 24/30: 100%|██████████| 18/18 [00:00<00:00, 749.95it/s]
Epoch 25/30: 100%|██████████| 18/18 [00:00<00:00, 596.46it/s]
Epoch 26/30: 100%|██████████| 18/18 [00:00<00:00, 556.14it/s]
Epoch 27/30: 100%|██████████| 18/18 [00:00<00:00, 450.12it/s]
Epoch 28/30: 100%|██████████| 18/18 [00:00<00:00, 544.75it/s]
Epoch 29/30: 100%|██████████| 18/18 [00:00<00:00, 741.36it/s]
Epoch 30/30: 100%|██████████| 18/18 [00:00<00:00, 552.18it/s]


Validation Loss: 0.6203, Accuracy: 61.54%
Final Validation Loss with manual NLLLoss: 0.6203
Final Validation Accuracy with manual NLLLoss: 61.54%


Your reason for your choice:

<div>
**Your answer here**
</div>

When using ReLU as the last layer activation function, the model output values are non-negative but unbounded, meaning they don't naturally fall within a probability distribution range (i.e., [0,1]). This output isn't optimal for binary or multi-class classification tasks where we want outputs representing class probabilities.

Impact of Using ReLU in the Last Layer: ReLU allows positive outputs but may produce values much greater than 1. When combined with loss functions like NLLLoss, this can lead to poor performance, as NLLLoss expects log-probabilities (typically from a log_softmax activation) rather than raw outputs.


For binary or multi-class classification, a more suitable solution is to use sigmoid (for binary) or softmax (for multi-class) in the last layer. These functions output values between 0 and 1, representing probabilities.



### 4. CrossEntropyLoss (`torch.nn.CrossEntropyLoss`)
- **Description:** Combines `LogSoftmax` and `NLLLoss` in one single class. It computes the cross-entropy loss between the target and the output logits.
- **Use Case:** Widely used for multi-class classification problems.

The mathematical formulation of CrossEntropyLoss is as follows:
\begin{equation}
  \text{CrossEntropy}(y, \hat{y}) = - \sum_{i=1}^{C} y_i \log\left(\frac{e^{\hat{y}_i}}{\sum_{j=1}^{C} e^{\hat{y}_j}}\right)
\end{equation}
  where:
  - \( C \) is the number of classes,
  - \( y_i \) is a one-hot encoded target vector (or a scalar class label),
  - \( \hat{y}_i \) represents the logits (unnormalized model outputs) for each class.
  
  In practice, `torch.nn.CrossEntropyLoss` expects raw logits as input and internally applies the softmax function to convert the logits into probabilities, followed by the negative log-likelihood computation.

- **Background:** Cross-entropy measures the difference between the true distribution \( y \) and the predicted distribution \( \hat{y} \). The function minimizes the negative log-probability assigned to the correct class, effectively penalizing predictions that deviate from the true class, making it a standard choice for classification tasks in deep learning.

Now, let's implement a class called `SimpleMLP_Loss` that has the following architecture:


In [None]:
class SimpleMLP_Loss(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers=1):
        super(SimpleMLP_Loss, self).__init__()
        
        layers = []
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.ReLU())

        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, output_dim))
        
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
class SimpleMLPTrainer_Loss:
    def __init__(self, model, criterion, optimizer):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def train(self, train_loader, num_epochs):
        training_losses = []

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            self.model.train()
            for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                self.optimizer.zero_grad()
                outputs = self.model(inputs)  
                
                targets = targets.squeeze().long()
                
                loss = self.criterion(outputs, targets)
                
                loss.backward()
                self.optimizer.step()
                
                epoch_loss += loss.item() * inputs.size(0)

            epoch_loss /= len(train_loader.dataset)
            training_losses.append(epoch_loss)
            if (epoch % 10 == 0):
                print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}")

        return training_losses

    def evaluate(self, val_loader):
        self.model.eval()
        val_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = self.model(inputs)  
                
                targets = targets.squeeze().long()
                
                loss = self.criterion(outputs, targets)
                val_loss += loss.item() * inputs.size(0)
                
                _, predictions = torch.max(outputs, 1)
                correct_predictions += (predictions == targets).sum().item()
                total_predictions += targets.size(0)

        val_loss /= len(val_loader.dataset)
        accuracy = correct_predictions / total_predictions * 100 if total_predictions > 0 else 0

        print(f"Final Validation Loss with CrossEntropyLoss: {val_loss:.4f}")
        print(f"Final Validation Accuracy with CrossEntropyLoss: {accuracy:.2f}%")
        return val_loss, accuracy


In [92]:
from torch.nn import CrossEntropyLoss

#TODO Train the model

#TODO Evaluate the model
input_dim = X_train.shape[1]
hidden_dim =16
output_dim = 2   

model = SimpleMLP_Loss(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)
criterion = nn.CrossEntropyLoss() 
optimizer = optim.Adam(model.parameters(), lr=0.001)
trainer = SimpleMLPTrainer_Loss(model, criterion, optimizer)

training_losses = trainer.train(train_loader, num_epochs=100)
val_loss, val_accuracy = trainer.evaluate(val_loader)


Epoch 1/100: 100%|██████████| 18/18 [00:00<00:00, 455.52it/s]


Epoch [1/100], Training Loss: 2.1591


Epoch 2/100: 100%|██████████| 18/18 [00:00<00:00, 734.80it/s]
Epoch 3/100: 100%|██████████| 18/18 [00:00<00:00, 554.98it/s]
Epoch 4/100: 100%|██████████| 18/18 [00:00<00:00, 438.84it/s]
Epoch 5/100: 100%|██████████| 18/18 [00:00<00:00, 726.91it/s]
Epoch 6/100: 100%|██████████| 18/18 [00:00<00:00, 561.90it/s]
Epoch 7/100: 100%|██████████| 18/18 [00:00<00:00, 438.29it/s]
Epoch 8/100: 100%|██████████| 18/18 [00:00<00:00, 718.18it/s]
Epoch 9/100: 100%|██████████| 18/18 [00:00<00:00, 554.23it/s]
Epoch 10/100: 100%|██████████| 18/18 [00:00<00:00, 539.85it/s]
Epoch 11/100: 100%|██████████| 18/18 [00:00<00:00, 747.76it/s]


Epoch [11/100], Training Loss: 0.5732


Epoch 12/100: 100%|██████████| 18/18 [00:00<00:00, 743.44it/s]
Epoch 13/100: 100%|██████████| 18/18 [00:00<00:00, 444.50it/s]
Epoch 14/100: 100%|██████████| 18/18 [00:00<00:00, 564.51it/s]
Epoch 15/100: 100%|██████████| 18/18 [00:00<00:00, 563.53it/s]
Epoch 16/100: 100%|██████████| 18/18 [00:00<00:00, 394.44it/s]
Epoch 17/100: 100%|██████████| 18/18 [00:00<00:00, 562.35it/s]
Epoch 18/100: 100%|██████████| 18/18 [00:00<00:00, 588.55it/s]
Epoch 19/100: 100%|██████████| 18/18 [00:00<00:00, 568.03it/s]
Epoch 20/100: 100%|██████████| 18/18 [00:00<00:00, 560.15it/s]
Epoch 21/100: 100%|██████████| 18/18 [00:00<00:00, 439.08it/s]


Epoch [21/100], Training Loss: 0.5561


Epoch 22/100: 100%|██████████| 18/18 [00:00<00:00, 445.89it/s]
Epoch 23/100: 100%|██████████| 18/18 [00:00<00:00, 561.17it/s]
Epoch 24/100: 100%|██████████| 18/18 [00:00<00:00, 562.11it/s]
Epoch 25/100: 100%|██████████| 18/18 [00:00<00:00, 561.23it/s]
Epoch 26/100: 100%|██████████| 18/18 [00:00<00:00, 607.54it/s]
Epoch 27/100: 100%|██████████| 18/18 [00:00<00:00, 558.03it/s]
Epoch 28/100: 100%|██████████| 18/18 [00:00<00:00, 734.43it/s]
Epoch 29/100: 100%|██████████| 18/18 [00:00<00:00, 373.68it/s]
Epoch 30/100: 100%|██████████| 18/18 [00:00<00:00, 457.05it/s]
Epoch 31/100: 100%|██████████| 18/18 [00:00<00:00, 744.24it/s]


Epoch [31/100], Training Loss: 0.5491


Epoch 32/100: 100%|██████████| 18/18 [00:00<00:00, 748.69it/s]
Epoch 33/100: 100%|██████████| 18/18 [00:00<00:00, 562.23it/s]
Epoch 34/100: 100%|██████████| 18/18 [00:00<00:00, 554.96it/s]
Epoch 35/100: 100%|██████████| 18/18 [00:00<00:00, 555.96it/s]
Epoch 36/100: 100%|██████████| 18/18 [00:00<00:00, 447.12it/s]
Epoch 37/100: 100%|██████████| 18/18 [00:00<00:00, 563.19it/s]
Epoch 38/100: 100%|██████████| 18/18 [00:00<00:00, 552.36it/s]
Epoch 39/100: 100%|██████████| 18/18 [00:00<00:00, 562.28it/s]
Epoch 40/100: 100%|██████████| 18/18 [00:00<00:00, 740.92it/s]
Epoch 41/100: 100%|██████████| 18/18 [00:00<00:00, 753.38it/s]


Epoch [41/100], Training Loss: 0.5244


Epoch 42/100: 100%|██████████| 18/18 [00:00<00:00, 371.51it/s]
Epoch 43/100: 100%|██████████| 18/18 [00:00<00:00, 449.03it/s]
Epoch 44/100: 100%|██████████| 18/18 [00:00<00:00, 402.15it/s]
Epoch 45/100: 100%|██████████| 18/18 [00:00<00:00, 368.81it/s]
Epoch 46/100: 100%|██████████| 18/18 [00:00<00:00, 373.13it/s]
Epoch 47/100: 100%|██████████| 18/18 [00:00<00:00, 370.19it/s]
Epoch 48/100: 100%|██████████| 18/18 [00:00<00:00, 538.74it/s]
Epoch 49/100: 100%|██████████| 18/18 [00:00<00:00, 549.98it/s]
Epoch 50/100: 100%|██████████| 18/18 [00:00<00:00, 563.60it/s]
Epoch 51/100: 100%|██████████| 18/18 [00:00<00:00, 713.80it/s]


Epoch [51/100], Training Loss: 0.5124


Epoch 52/100: 100%|██████████| 18/18 [00:00<00:00, 559.17it/s]
Epoch 53/100: 100%|██████████| 18/18 [00:00<00:00, 554.83it/s]
Epoch 54/100: 100%|██████████| 18/18 [00:00<00:00, 739.53it/s]
Epoch 55/100: 100%|██████████| 18/18 [00:00<00:00, 369.38it/s]
Epoch 56/100: 100%|██████████| 18/18 [00:00<00:00, 372.34it/s]
Epoch 57/100: 100%|██████████| 18/18 [00:00<00:00, 558.24it/s]
Epoch 58/100: 100%|██████████| 18/18 [00:00<00:00, 450.13it/s]
Epoch 59/100: 100%|██████████| 18/18 [00:00<00:00, 397.00it/s]
Epoch 60/100: 100%|██████████| 18/18 [00:00<00:00, 450.70it/s]
Epoch 61/100: 100%|██████████| 18/18 [00:00<00:00, 542.51it/s]


Epoch [61/100], Training Loss: 0.5080


Epoch 62/100: 100%|██████████| 18/18 [00:00<00:00, 555.33it/s]
Epoch 63/100: 100%|██████████| 18/18 [00:00<00:00, 559.94it/s]
Epoch 64/100: 100%|██████████| 18/18 [00:00<00:00, 561.57it/s]
Epoch 65/100: 100%|██████████| 18/18 [00:00<00:00, 741.29it/s]
Epoch 66/100: 100%|██████████| 18/18 [00:00<00:00, 551.81it/s]
Epoch 67/100: 100%|██████████| 18/18 [00:00<00:00, 562.86it/s]
Epoch 68/100: 100%|██████████| 18/18 [00:00<00:00, 747.44it/s]
Epoch 69/100: 100%|██████████| 18/18 [00:00<00:00, 562.14it/s]
Epoch 70/100: 100%|██████████| 18/18 [00:00<00:00, 749.44it/s]
Epoch 71/100: 100%|██████████| 18/18 [00:00<00:00, 586.97it/s]


Epoch [71/100], Training Loss: 0.4954


Epoch 72/100: 100%|██████████| 18/18 [00:00<00:00, 444.38it/s]
Epoch 73/100: 100%|██████████| 18/18 [00:00<00:00, 436.96it/s]
Epoch 74/100: 100%|██████████| 18/18 [00:00<00:00, 374.35it/s]
Epoch 75/100: 100%|██████████| 18/18 [00:00<00:00, 502.46it/s]
Epoch 76/100: 100%|██████████| 18/18 [00:00<00:00, 561.29it/s]
Epoch 77/100: 100%|██████████| 18/18 [00:00<00:00, 435.08it/s]
Epoch 78/100: 100%|██████████| 18/18 [00:00<00:00, 400.93it/s]
Epoch 79/100: 100%|██████████| 18/18 [00:00<00:00, 373.78it/s]
Epoch 80/100: 100%|██████████| 18/18 [00:00<00:00, 447.92it/s]
Epoch 81/100: 100%|██████████| 18/18 [00:00<00:00, 444.42it/s]


Epoch [81/100], Training Loss: 0.4748


Epoch 82/100: 100%|██████████| 18/18 [00:00<00:00, 369.44it/s]
Epoch 83/100: 100%|██████████| 18/18 [00:00<00:00, 450.02it/s]
Epoch 84/100: 100%|██████████| 18/18 [00:00<00:00, 548.86it/s]
Epoch 85/100: 100%|██████████| 18/18 [00:00<00:00, 444.30it/s]
Epoch 86/100: 100%|██████████| 18/18 [00:00<00:00, 444.19it/s]
Epoch 87/100: 100%|██████████| 18/18 [00:00<00:00, 444.07it/s]
Epoch 88/100: 100%|██████████| 18/18 [00:00<00:00, 494.28it/s]
Epoch 89/100: 100%|██████████| 18/18 [00:00<00:00, 561.31it/s]
Epoch 90/100: 100%|██████████| 18/18 [00:00<00:00, 733.69it/s]
Epoch 91/100: 100%|██████████| 18/18 [00:00<00:00, 448.81it/s]


Epoch [91/100], Training Loss: 0.4798


Epoch 92/100: 100%|██████████| 18/18 [00:00<00:00, 558.82it/s]
Epoch 93/100: 100%|██████████| 18/18 [00:00<00:00, 552.84it/s]
Epoch 94/100: 100%|██████████| 18/18 [00:00<00:00, 449.54it/s]
Epoch 95/100: 100%|██████████| 18/18 [00:00<00:00, 557.38it/s]
Epoch 96/100: 100%|██████████| 18/18 [00:00<00:00, 748.63it/s]
Epoch 97/100: 100%|██████████| 18/18 [00:00<00:00, 543.48it/s]
Epoch 98/100: 100%|██████████| 18/18 [00:00<00:00, 449.90it/s]
Epoch 99/100: 100%|██████████| 18/18 [00:00<00:00, 559.90it/s]
Epoch 100/100: 100%|██████████| 18/18 [00:00<00:00, 551.23it/s]


Final Validation Loss with CrossEntropyLoss: 0.5468
Final Validation Accuracy with CrossEntropyLoss: 73.43%



### 5. KLDivLoss (`torch.nn.KLDivLoss`)
- **Description:** Kullback-Leibler Divergence Loss measures how one probability distribution diverges from a second, reference distribution. Unlike other loss functions that focus on classification, KL divergence specifically compares the relative entropy between two distributions. It quantifies the information loss when using the predicted distribution to approximate the true distribution. 

- **Mathematical Function:**
\begin{equation}
  \text{KL}(P \parallel Q) = \sum_{i=1}^{C} P(i) \left( \log P(i) - \log Q(i) \right)
\end{equation}
  where:
  - \( P \) is the target (true) probability distribution,
  - \( Q \) is the predicted distribution (often the output of `log_softmax`),
  - \( C \) is the number of classes.

  KL divergence is always non-negative, and it equals zero if the two distributions are identical. The loss function expects the model's output to be in the form of log-probabilities (using `log_softmax`) and compares this against a target probability distribution, which is typically a normalized distribution (using softmax).

- **Use Case:** KLDivLoss is frequently used in:
  - **Variational Autoencoders (VAEs):** In VAEs, KL divergence is used to measure how much the learned latent space distribution deviates from a prior distribution (often Gaussian).
  - **Knowledge Distillation:** In teacher-student models, KL divergence is used to transfer the "soft" knowledge from a teacher model to a student model by comparing their output probability distributions.
  - **Reinforcement Learning:** It can be used to update policies while minimizing the divergence from a previous policy.

- **Background:** Kullback-Leibler divergence, a core concept in information theory, measures the inefficiency of assuming the predicted distribution \( Q \) when the true distribution is \( P \). It is asymmetric, meaning that \( KL(P \parallel Q) \neq KL(Q \parallel P) \), so the direction of the comparison matters.

Again, in this part, run your training with Relu at last layer. <span style="color:red; font-weight: bold;">Discuss </span> and explain the difference between the results of the two models. Find a proper solution to the problem.


In [98]:
class SimpleMLP_KL(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers=1):
        super(SimpleMLP_KL, self).__init__()
        
        layers = []
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.ReLU())

        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        
        layers.append(nn.Linear(hidden_dim, output_dim))
        layers.append(nn.LogSoftmax(dim=1))  
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class SimpleMLPTrainer_KL:
    def __init__(self, model, criterion, optimizer):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def train(self, train_loader, num_epochs):
        training_losses = []

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            self.model.train()
            for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                self.optimizer.zero_grad()
                outputs = self.model(inputs)  
                target_probs = torch.softmax(targets, dim=1)
                
                loss = self.criterion(outputs, target_probs)
                
                loss.backward()
                self.optimizer.step()
                
                epoch_loss += loss.item() * inputs.size(0)

            epoch_loss /= len(train_loader.dataset)
            training_losses.append(epoch_loss)
            if (epoch % 10 == 0):
                print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}")

        return training_losses

    def evaluate(self, val_loader):
        self.model.eval()
        val_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = self.model(inputs)  # Log probabilities

                # Convert targets to one-hot encoding and then to a probability distribution
                targets = F.one_hot(targets.squeeze().long(), num_classes=outputs.size(1)).float()
                
                # Calculate KLDivLoss
                loss = self.criterion(outputs, targets)
                val_loss += loss.item() * inputs.size(0)

                # Convert log-probabilities to predicted class
                predictions = torch.argmax(outputs.exp(), dim=1)  # Convert to probabilities and get max
                correct_predictions += (predictions == targets.argmax(dim=1)).sum().item()
                total_predictions += targets.size(0)

        val_loss /= len(val_loader.dataset)
        accuracy = (correct_predictions / total_predictions) * 100 if total_predictions > 0 else 0

        print(f"Final Validation Loss with KLDivLoss: {val_loss:.4f}")
        print(f"Final Validation Accuracy with KLDivLoss: {accuracy:.2f}%")
        return val_loss, accuracy


In [None]:
# Run with LogSoftmax activation function
from torch.nn import NLLLoss

# TODO: Train the model

# TODO: Evaluate the model

model = SimpleMLP_KL(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)
criterion = nn.KLDivLoss(reduction='batchmean')  
optimizer = optim.Adam(model.parameters(), lr=0.001)

trainer = SimpleMLPTrainer_KL(model, criterion, optimizer)
training_losses = trainer.train(train_loader, num_epochs=80)
val_loss, val_accuracy = trainer.evaluate(val_loader)

Epoch 1/80: 100%|██████████| 18/18 [00:00<00:00, 478.64it/s]


Epoch [1/80], Training Loss: 4.0629


Epoch 2/80: 100%|██████████| 18/18 [00:00<00:00, 544.24it/s]
Epoch 3/80: 100%|██████████| 18/18 [00:00<00:00, 562.66it/s]
Epoch 4/80: 100%|██████████| 18/18 [00:00<00:00, 735.74it/s]
Epoch 5/80: 100%|██████████| 18/18 [00:00<00:00, 1106.06it/s]
Epoch 6/80: 100%|██████████| 18/18 [00:00<00:00, 1124.78it/s]
Epoch 7/80: 100%|██████████| 18/18 [00:00<00:00, 1125.94it/s]
Epoch 8/80: 100%|██████████| 18/18 [00:00<00:00, 1122.81it/s]
Epoch 9/80: 100%|██████████| 18/18 [00:00<00:00, 1125.03it/s]
Epoch 10/80: 100%|██████████| 18/18 [00:00<00:00, 760.99it/s]
Epoch 11/80: 100%|██████████| 18/18 [00:00<00:00, 1123.59it/s]


Epoch [11/80], Training Loss: 1.4018


Epoch 12/80: 100%|██████████| 18/18 [00:00<00:00, 1123.62it/s]
Epoch 13/80: 100%|██████████| 18/18 [00:00<00:00, 744.64it/s]
Epoch 14/80: 100%|██████████| 18/18 [00:00<00:00, 940.18it/s]
Epoch 15/80: 100%|██████████| 18/18 [00:00<00:00, 1375.38it/s]
Epoch 16/80: 100%|██████████| 18/18 [00:00<00:00, 1124.68it/s]
Epoch 17/80: 100%|██████████| 18/18 [00:00<00:00, 1088.27it/s]
Epoch 18/80: 100%|██████████| 18/18 [00:00<00:00, 737.73it/s]
Epoch 19/80: 100%|██████████| 18/18 [00:00<00:00, 1077.10it/s]
Epoch 20/80: 100%|██████████| 18/18 [00:00<00:00, 1150.91it/s]
Epoch 21/80: 100%|██████████| 18/18 [00:00<00:00, 1096.95it/s]


Epoch [21/80], Training Loss: 1.3946


Epoch 22/80: 100%|██████████| 18/18 [00:00<00:00, 1140.86it/s]
Epoch 23/80: 100%|██████████| 18/18 [00:00<00:00, 747.37it/s]
Epoch 24/80: 100%|██████████| 18/18 [00:00<00:00, 841.70it/s]
Epoch 25/80: 100%|██████████| 18/18 [00:00<00:00, 1155.60it/s]
Epoch 26/80: 100%|██████████| 18/18 [00:00<00:00, 1089.37it/s]
Epoch 27/80: 100%|██████████| 18/18 [00:00<00:00, 749.99it/s]
Epoch 28/80: 100%|██████████| 18/18 [00:00<00:00, 1099.95it/s]
Epoch 29/80: 100%|██████████| 18/18 [00:00<00:00, 1123.12it/s]
Epoch 30/80: 100%|██████████| 18/18 [00:00<00:00, 1070.57it/s]
Epoch 31/80: 100%|██████████| 18/18 [00:00<00:00, 548.66it/s]


Epoch [31/80], Training Loss: 1.3907


Epoch 32/80: 100%|██████████| 18/18 [00:00<00:00, 721.13it/s]
Epoch 33/80: 100%|██████████| 18/18 [00:00<00:00, 1135.56it/s]
Epoch 34/80: 100%|██████████| 18/18 [00:00<00:00, 1092.30it/s]
Epoch 35/80: 100%|██████████| 18/18 [00:00<00:00, 1106.51it/s]
Epoch 36/80: 100%|██████████| 18/18 [00:00<00:00, 1089.67it/s]
Epoch 37/80: 100%|██████████| 18/18 [00:00<00:00, 1159.61it/s]
Epoch 38/80: 100%|██████████| 18/18 [00:00<00:00, 1103.91it/s]
Epoch 39/80: 100%|██████████| 18/18 [00:00<00:00, 1090.64it/s]
Epoch 40/80: 100%|██████████| 18/18 [00:00<00:00, 741.88it/s]
Epoch 41/80: 100%|██████████| 18/18 [00:00<00:00, 1107.70it/s]


Epoch [41/80], Training Loss: 1.3889


Epoch 42/80: 100%|██████████| 18/18 [00:00<00:00, 560.73it/s]
Epoch 43/80: 100%|██████████| 18/18 [00:00<00:00, 554.92it/s]
Epoch 44/80: 100%|██████████| 18/18 [00:00<00:00, 750.11it/s]
Epoch 45/80: 100%|██████████| 18/18 [00:00<00:00, 750.12it/s]
Epoch 46/80: 100%|██████████| 18/18 [00:00<00:00, 1119.69it/s]
Epoch 47/80: 100%|██████████| 18/18 [00:00<00:00, 1065.33it/s]
Epoch 48/80: 100%|██████████| 18/18 [00:00<00:00, 748.37it/s]
Epoch 49/80: 100%|██████████| 18/18 [00:00<00:00, 1134.17it/s]
Epoch 50/80: 100%|██████████| 18/18 [00:00<00:00, 1037.97it/s]
Epoch 51/80: 100%|██████████| 18/18 [00:00<00:00, 757.00it/s]


Epoch [51/80], Training Loss: 1.3884


Epoch 52/80: 100%|██████████| 18/18 [00:00<00:00, 1092.27it/s]
Epoch 53/80: 100%|██████████| 18/18 [00:00<00:00, 721.44it/s]
Epoch 54/80: 100%|██████████| 18/18 [00:00<00:00, 1097.92it/s]
Epoch 55/80: 100%|██████████| 18/18 [00:00<00:00, 1107.28it/s]
Epoch 56/80: 100%|██████████| 18/18 [00:00<00:00, 1107.16it/s]
Epoch 57/80: 100%|██████████| 18/18 [00:00<00:00, 1157.69it/s]
Epoch 58/80: 100%|██████████| 18/18 [00:00<00:00, 367.56it/s]
Epoch 59/80: 100%|██████████| 18/18 [00:00<00:00, 756.66it/s]
Epoch 60/80: 100%|██████████| 18/18 [00:00<00:00, 1081.33it/s]
Epoch 61/80: 100%|██████████| 18/18 [00:00<00:00, 1089.60it/s]


Epoch [61/80], Training Loss: 1.3874


Epoch 62/80: 100%|██████████| 18/18 [00:00<00:00, 1122.14it/s]
Epoch 63/80: 100%|██████████| 18/18 [00:00<00:00, 1094.15it/s]
Epoch 64/80: 100%|██████████| 18/18 [00:00<00:00, 1122.86it/s]
Epoch 65/80: 100%|██████████| 18/18 [00:00<00:00, 1125.15it/s]
Epoch 66/80: 100%|██████████| 18/18 [00:00<00:00, 1125.05it/s]
Epoch 67/80: 100%|██████████| 18/18 [00:00<00:00, 1122.92it/s]
Epoch 68/80: 100%|██████████| 18/18 [00:00<00:00, 922.07it/s]
Epoch 69/80: 100%|██████████| 18/18 [00:00<00:00, 874.54it/s]
Epoch 70/80: 100%|██████████| 18/18 [00:00<00:00, 670.96it/s]
Epoch 71/80: 100%|██████████| 18/18 [00:00<00:00, 1124.09it/s]


Epoch [71/80], Training Loss: 1.3868


Epoch 72/80: 100%|██████████| 18/18 [00:00<00:00, 747.62it/s]
Epoch 73/80: 100%|██████████| 18/18 [00:00<00:00, 1083.24it/s]
Epoch 74/80: 100%|██████████| 18/18 [00:00<00:00, 1156.11it/s]
Epoch 75/80: 100%|██████████| 18/18 [00:00<00:00, 1157.81it/s]
Epoch 76/80: 100%|██████████| 18/18 [00:00<00:00, 750.05it/s]
Epoch 77/80: 100%|██████████| 18/18 [00:00<00:00, 734.10it/s]
Epoch 78/80: 100%|██████████| 18/18 [00:00<00:00, 1092.90it/s]
Epoch 79/80: 100%|██████████| 18/18 [00:00<00:00, 1096.04it/s]
Epoch 80/80: 100%|██████████| 18/18 [00:00<00:00, 1122.34it/s]

Final Validation Loss with KLDivLoss: 0.6939
Final Validation Accuracy with KLDivLoss: 53.15%





Your reason for your choice:

<div>
**Your answer here**
</div>

To use KLDivLoss effectively, we need to ensure that the model outputs log-probabilities (using log_softmax in the final layer) and that the targets are formatted as probability distributions (using softmax). Using KLDivLoss with ReLU in the final layer will not work as expected, because ReLU does not produce a normalized probability distribution.

With ReLU Activation: If using ReLU, the model cannot produce normalized probabilities, so KLDivLoss becomes meaningless because it compares distributions.

With LogSoftmax Activation: Using log_softmax produces log-probabilities, which allows KLDivLoss to calculate divergence between the predicted and target distributions effectively.

### 6. CosineEmbeddingLoss (`torch.nn.CosineEmbeddingLoss`)
- **Description:** Measures the cosine similarity between two input tensors, `x1` and `x2`, and computes the loss based on a label `y` that indicates whether the tensors should be similar (`y = 1`) or dissimilar (`y = -1`). Cosine similarity focuses on the angle between vectors, disregarding their magnitude.

- **Mathematical Function:** 
\begin{equation}
  \text{CosineEmbeddingLoss}(x1, x2, y) = 
  \begin{cases} 
  1 - \cos(x_1, x_2), & \text{if } y = 1 \\
  \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1
  \end{cases}
\end{equation}
  where $ \cos(x_1, x_2) $ is the cosine similarity between the two vectors, and `margin` is a threshold that determines how dissimilar the vectors should be.

- **Use Case:** Commonly used in tasks like face verification, image similarity, and other scenarios where the relative orientation of vectors (angle) is more important than their length, such as in embeddings and metric learning.

- **Background:** Cosine similarity compares the directional alignment of vectors, making it ideal for high-dimensional data where the magnitude may not be as informative. This loss is particularly useful when training models to learn meaningful embeddings that capture semantic similarity.

You'll become more fimiliar with this loss function in future.

---

# Regularization in Machine Learning

## Introduction

Regularization is a fundamental technique in machine learning that helps prevent overfitting by adding a penalty to the loss function. This penalty discourages the model from becoming too complex, ensuring better generalization to unseen data. In this notebook, you will explore the concepts of regularization, understand different types of regularization techniques, and apply them using Python's popular libraries.

## What is Regularization?

Regularization involves adding a regularization term to the loss function used to train machine learning models. This term imposes a constraint on the model's coefficients, effectively reducing their magnitude. By doing so, regularization helps in:

- **Preventing Overfitting:** Ensures the model does not become too tailored to the training data.
- **Improving Generalization:** Enhances the model's performance on new, unseen data.
- **Feature Selection:** Especially in L1 regularization, it can drive some coefficients to zero, effectively selecting important features.

## Types of Regularization

There are several types of regularization techniques, each imposing different constraints on the model's parameters:

### 1. L1 Regularization (Lasso)

L1 regularization adds the absolute value of the magnitude of coefficients as a penalty term to the loss function. It can lead to sparse models where some feature coefficients are exactly zero.

### 2. L2 Regularization (Ridge)

L2 regularization adds the squared magnitude of coefficients as a penalty term to the loss function. It tends to shrink the coefficients evenly but does not set them to zero.

### 3. Elastic Net

Elastic Net combines both L1 and L2 regularization penalties. It balances the benefits of both Lasso and Ridge methods, allowing for feature selection and coefficient shrinkage.

## Homework Time!
Import Iris dataset from sklearn.datasets and apply ridge regression with different alpha values. Then, create a gif that shows the changes of the classification boundary with respect to alpha values.

Import the libs that you need and start coding!

In [104]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from PIL import Image
from io import BytesIO
import imageio
import warnings


# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

Load the Iris dataset and select Setosa and Versicolor classes

In [106]:
# 1. Load and Prepare the Iris Dataset

# Select only two classes for binary classification (Setosa and Versicolor)

# Select two features for 2D visualization (Sepal Length and Petal Length)

# Split into training and testing sets

iris = load_iris()
X = iris.data
y = iris.target

selected_classes = (y == 0) | (y == 1)
X = X[selected_classes]
y = y[selected_classes]

X = X[:, [0, 2]]  

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

# Standardize the features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)


print("Training set shape:", X_train.shape)
print("Testing set shape:", X_test.shape)
print("Sample of standardized training features:\n", X_train[:5])
print("Sample of training labels:\n", y_train[:5])


Training set shape: (70, 2)
Testing set shape: (30, 2)
Sample of standardized training features:
 [[1.94114606 1.29689348]
 [1.62587872 1.22635537]
 [2.2564134  1.43796969]
 [1.94114606 1.08527916]
 [2.09877973 1.36743159]]
Sample of training labels:
 [1 1 1 1 1]


Define Function to Plot Decision Boundary

In [113]:
def plot_decision_boundary(model, X, y, alpha):
    # Define the grid over the feature space
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100), np.linspace(y_min, y_max, 100))

    # Predict over the grid
    grid = np.c_[xx.ravel(), yy.ravel()]
    Z = model.predict(grid)
    Z = Z.reshape(xx.shape)

    # Create a figure
    fig, ax = plt.subplots(figsize=(6, 5))

    # Plot the decision boundary
    ax.contourf(xx, yy, Z, alpha=0.3, levels=[-0.1, 0.1, 1.1], colors=['blue', 'red'])

    # Scatter plot of the training data
    scatter = ax.scatter(
        X[:, 0], X[:, 1], c=y, cmap='bwr', edgecolor='k', s=50
    )

    # Title and labels
    ax.set_title(f'MLP Decision Boundary (alpha={alpha})')
    ax.set_xlabel('Sepal Length (standardized)')
    ax.set_ylabel('Petal Length (standardized)')

    # Remove axes for clarity
    ax.set_xticks([])
    ax.set_yticks([])

    # Tight layout
    plt.tight_layout()

    # Save the plot to a BytesIO object
    buf = BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

Train MLP with Varying Alpha Values and Collect Images

In [116]:
def create_decision_boundary_gif(alpha_values, X_train, y_train, n_neurons):
    # List to store images
    images = []

    for idx, alpha in enumerate(alpha_values):
        print(f"Processing alpha={alpha:.4f} ({idx + 1}/{len(alpha_values)})")

        # Create and train the MLP
        mlp = MLPClassifier(hidden_layer_sizes=(n_neurons,), alpha=alpha, max_iter=1000, random_state=42)
        mlp.fit(X_train, y_train)

        # Plot decision boundary and get the image
        img = plot_decision_boundary(mlp, X_train, y_train, alpha)
        images.append(img)

    # Save the images as a GIF
    gif_filename = 'mlp_classification_boundaries_Rosa.gif'
    images[0].save(
        gif_filename,
        save_all=True,
        append_images=images[1:],
        duration=500,
        loop=0
    )

    print(f"GIF saved as '{gif_filename}'")

    return gif_filename


## RUN

In [117]:

# Use np.logspace to generate alpha values from 0.01 to 400
alpha_values = np.logspace(np.log10(0.01), np.log10(400), num=30)
# Define the number of neurons in the hidden layer
n_neurons = 10  # Example number of neurons, adjust as desired

# Create the decision boundary GIF
gif_filename = create_decision_boundary_gif(alpha_values, X_train, y_train, n_neurons)

Processing alpha=0.0100 (1/30)
Processing alpha=0.0144 (2/30)
Processing alpha=0.0208 (3/30)
Processing alpha=0.0299 (4/30)
Processing alpha=0.0431 (5/30)
Processing alpha=0.0622 (6/30)
Processing alpha=0.0896 (7/30)
Processing alpha=0.1291 (8/30)
Processing alpha=0.1860 (9/30)
Processing alpha=0.2681 (10/30)
Processing alpha=0.3863 (11/30)
Processing alpha=0.5567 (12/30)
Processing alpha=0.8022 (13/30)
Processing alpha=1.1561 (14/30)
Processing alpha=1.6660 (15/30)
Processing alpha=2.4009 (16/30)
Processing alpha=3.4599 (17/30)
Processing alpha=4.9861 (18/30)
Processing alpha=7.1854 (19/30)
Processing alpha=10.3548 (20/30)
Processing alpha=14.9223 (21/30)
Processing alpha=21.5043 (22/30)
Processing alpha=30.9897 (23/30)
Processing alpha=44.6591 (24/30)
Processing alpha=64.3578 (25/30)
Processing alpha=92.7456 (26/30)
Processing alpha=133.6549 (27/30)
Processing alpha=192.6090 (28/30)
Processing alpha=277.5673 (29/30)
Processing alpha=400.0000 (30/30)
GIF saved as 'mlp_classification_b

Your gif should look like this:

<div style="text-align: center;">

### **Multilayer Perceptron Classification Boundaries**

![Classification Boundaries](mlp_classification_boundaries_example.gif)

*Figure 1: Demonstration of classification boundaries created by a Multilayer Perceptron (MLP) model.*

</div>

