# Imports

In [None]:
import torch
from torch.utils.data import Subset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import time
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from IPython.display import clear_output

# prefer the GPU whenever PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


# Global Variables

In [3]:
# dataset switch: True selects CIFAR-100, False selects CIFAR-10
load_100 = False
# number of output classes follows from the selected dataset
if load_100:
    num_classes = 100
else:
    num_classes = 10
# mini-batch size shared by the train / validation / test loaders
batch_size = 64
# number of epochs each training run is allowed to take
num_epochs = 10

# Data

In [None]:
# Both CIFAR variants are loaded the same way, so only the dataset class depends on the switch.
dataset_cls = torchvision.datasets.CIFAR100 if load_100 else torchvision.datasets.CIFAR10
transform = transforms.Compose([transforms.ToTensor()])
trainset = dataset_cls(root='./data', train=True, download=True, transform=transform)
testset = dataset_cls(root='./data', train=False, download=True, transform=transform)

# Stratified 80/20 split of the official training set into train / validation indices.
train_labels = trainset.targets
train_indices = list(range(len(trainset)))
train_idx, val_idx = train_test_split(
    train_indices,
    test_size=0.2,
    stratify=train_labels,
    random_state=0,
)

train_subset = Subset(trainset, train_idx)
val_subset = Subset(trainset, val_idx)

# Only the training loader shuffles; validation and test keep a fixed order.
trainloader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True)
validationloader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size, shuffle=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


# Optimizers

In [1]:
class SGDW(optim.SGD):
    """SGD with weight decay applied directly to the parameters (decoupled from the gradient).

    The built-in ``weight_decay`` of ``optim.SGD`` is disabled (set to 0); instead every
    parameter that has a gradient is shrunk by ``lr * weight_decay`` right before the
    regular SGD update.

    Args:
        params: iterable of parameters or parameter groups, as for ``optim.SGD``.
        lr: learning rate.
        momentum: momentum factor forwarded to ``optim.SGD``.
        dampening: dampening forwarded to ``optim.SGD``.
        weight_decay: decay coefficient used by this class in :meth:`step`.
        nesterov: forwarded to ``optim.SGD``.
    """

    def __init__(self, params, lr=0.01, momentum=0, dampening=0, weight_decay=0.01, nesterov=False):
        # weight_decay=0 switches off SGD's own L2 term; the decay happens in step() instead
        super().__init__(params, lr=lr, momentum=momentum, dampening=dampening, weight_decay=0, nesterov=nesterov)
        self.weight_decay = weight_decay

    def step(self, closure=None):
        """Decay the parameters, then perform the regular SGD step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            Whatever ``optim.SGD.step`` returns (the closure's loss, or ``None``).
        """
        # apply weight decay to parameters before gradient step (in place, outside autograd)
        with torch.no_grad():
            for group in self.param_groups:
                shrink = 1 - self.weight_decay * group['lr']
                for param in group['params']:
                    if param.grad is None:
                        continue
                    param.mul_(shrink)
        # apply step and hand the loss back to the caller
        return super().step(closure)

NameError: name 'optim' is not defined

# Main Network

In [None]:
class VGG(nn.Module):
    """VGG-style CNN for 32x32 RGB images.

    Four convolutional blocks (each ending in 2x2 max pooling) are followed by three fully
    connected layers; dropout (p=0.5) is applied after the first two fully connected layers
    while the module is in training mode.

    Args:
        num_classes: number of output logits produced by the last layer.
    """

    def __init__(self, num_classes=10):
        super(VGG, self).__init__()
        # input_shape: (batch_size,num_channels,image_width,image_height)
        #              (batch_size,3,32,32)
        # conv block 1
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1) # output shape: (batch_size,64,32,32), weights: (64,3,3,3), bias: (64)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1) # output shape: (batch_size,64,32,32), weights: (64,64,3,3), bias: (64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # output shape: (batch_size,64,16,16)
        # conv block 2
        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)  # output shape: (batch_size,128,16,16), weights: (128,64,3,3), bias: (128)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1) # output shape: (batch_size,128,16,16), weights: (128,128,3,3), bias: (128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # output shape: (batch_size,128,8,8)
        # conv block 3
        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1) # output shape: (batch_size, 256, 8, 8), weights: (256,128,3,3), bias: (256)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1) # output shape: (batch_size, 256, 8, 8), weights: (256,256,3,3), bias: (256)
        self.conv3_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1) # output shape: (batch_size, 256, 8, 8), weights: (256,256,3,3), bias: (256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # output shape: (batch_size,256,4,4)
        # conv block 4
        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1) # output shape: (batch_size, 512, 4, 4), weights: (512,256,3,3), bias: (512)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1) # output shape: (batch_size, 512, 4, 4), weights: (512,512,3,3), bias: (512)
        self.conv4_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1) # output shape: (batch_size, 512, 4, 4), weights: (512,512,3,3), bias: (512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)  # output shape: (batch_size,512,2,2)
        # fc 1
        self.fc1 = nn.Linear(in_features=512*2*2, out_features=512) # output shape: (batch_size, 512), weights: (512,2048), bias: (512)
        # fc 2
        self.fc2 = nn.Linear(in_features=512, out_features=512) # output shape: (batch_size, 512), weights: (512,512), bias: (512)
        # fc 3
        self.fc3 = nn.Linear(in_features=512, out_features=num_classes) # output shape: (batch_size, num_classes), weights: (num_classes,512), bias: (num_classes)

    def forward(self, x):
        """Return the class logits of shape (batch_size, num_classes) for a batch of images."""
        # convs
        x = self.pool1(F.relu(self.conv1_2(F.relu(self.conv1_1(x)))))
        x = self.pool2(F.relu(self.conv2_2(F.relu(self.conv2_1(x)))))
        x = self.pool3(F.relu(self.conv3_3(F.relu(self.conv3_2(F.relu(self.conv3_1(x)))))))
        x = self.pool4(F.relu(self.conv4_3(F.relu(self.conv4_2(F.relu(self.conv4_1(x)))))))
        # flatten
        x = x.view(x.size(0), -1)
        # fcs
        # F.dropout defaults to training=True, so the module's mode must be passed explicitly;
        # otherwise dropout would stay active after model.eval()
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc3(x)
        return x

In [None]:
class ResNet(nn.Module):
    """Wrapper around torchvision's ResNet-18 (built with ``weights=None``).

    The backbone's final fully connected layer is replaced so that it emits
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        backbone = models.resnet18(weights=None)
        # swap the classification head for one with the requested number of outputs
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.backbone = backbone

    def forward(self, x):
        """Run the batch through the backbone and return its output."""
        logits = self.backbone(x)
        return logits

In [None]:
class MainTrainer:
    """Train, validate and test a classifier with a selectable optimizer.

    Per-epoch metrics are stored on the instance in ``training_losses``,
    ``training_accuracies``, ``validation_losses``, ``validation_accuracies`` and
    ``epoch_times``; :meth:`test` stores its result in ``test_accuracy``.

    Args:
        device: torch device the model and every batch are moved to.
        trainloader: iterable of ``(images, labels)`` training batches.
        validationloader: iterable of ``(images, labels)`` validation batches.
        testloader: iterable of ``(images, labels)`` test batches.
        model: the ``nn.Module`` to train.
        loss: criterion called as ``loss(outputs, labels)``.
        optimizer_type: "adam", "adamw", "sgd", "sgdw", "adagrad" or "rmsprop".
        num_epochs: maximum number of epochs to train.
        momentum: momentum used by "sgd" and "sgdw".
        weight_decay: weight decay used by "adamw" and "sgdw".
        rho: ``alpha`` passed to RMSprop.
        epsilon: ``eps`` passed to Adagrad and RMSprop.
        lr_scheduler_factor: ``factor`` passed to ReduceLROnPlateau.
        lr_scheduler_patience: ``patience`` passed to ReduceLROnPlateau.
        early_stop_patience: consecutive non-improving epochs tolerated before stopping.
        early_stop_min_delta: minimum drop in validation loss that counts as an improvement.

    Raises:
        ValueError: if ``optimizer_type`` is not one of the supported names.
    """

    def __init__(self, device, trainloader, validationloader, testloader,
                 model, loss, optimizer_type, num_epochs,
                 momentum=0.9, weight_decay=0.01, rho=0.9, epsilon=1e-8,
                 lr_scheduler_factor=0.7, lr_scheduler_patience=3, early_stop_patience=5, early_stop_min_delta=0.001):
        self.device = device
        # data
        self.trainloader = trainloader
        self.validationloader = validationloader
        self.testloader = testloader
        # model
        self.model = model.to(self.device)
        # training
        self.num_epochs = num_epochs
        # loss, optimizer, learning rate scheduler
        self.criterion = loss
        self.momentum = momentum # momentum value for SGD, default 0.9
        self.weight_decay = weight_decay # weight decay in W optimizers
        self.rho = rho # decay rate for RMSProp, default 0.9
        self.epsilon = epsilon  # small constant for numerical stability, default 1e-8
        self.optimizer = self._build_optimizer(optimizer_type)
        self.lr_scheduler_factor = lr_scheduler_factor
        self.lr_scheduler_patience = lr_scheduler_patience
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min", factor=self.lr_scheduler_factor, patience=self.lr_scheduler_patience)
        # early stopping variables
        self.best_loss = float("inf")
        self.early_stop_count = 0
        self.early_stop_patience = early_stop_patience
        self.early_stop_min_delta = early_stop_min_delta
        # metrics
        self.training_losses = []
        self.training_accuracies = []
        self.validation_losses = []
        self.validation_accuracies = []
        self.epoch_times = []
        self.test_accuracy = 0

    def _build_optimizer(self, optimizer_type):
        """Create the optimizer named by ``optimizer_type`` for the model's parameters."""
        params = self.model.parameters()
        if optimizer_type == "adam":
            return optim.Adam(params, lr=0.001)
        if optimizer_type == "adamw":
            return optim.AdamW(params, lr=0.001, weight_decay=self.weight_decay)
        if optimizer_type == "sgd":
            return optim.SGD(params, lr=0.01, momentum=self.momentum)
        if optimizer_type == "sgdw":
            return SGDW(params, lr=0.01, momentum=self.momentum, weight_decay=self.weight_decay)
        if optimizer_type == "adagrad":
            return optim.Adagrad(params, lr=0.01, eps=self.epsilon)
        if optimizer_type == "rmsprop":
            return optim.RMSprop(params, lr=0.001, alpha=self.rho, eps=self.epsilon)
        # fail here with a clear message instead of a later AttributeError on self.optimizer
        raise ValueError(f"Unsupported optimizer_type: {optimizer_type!r}")

    def _train_one_epoch(self):
        """Run one optimisation pass over ``trainloader``; return (mean batch loss, accuracy in %)."""
        running_loss = 0.0
        correct = 0
        total = 0
        self.model.train()
        for images, labels in self.trainloader:
            images, labels = images.to(self.device), labels.to(self.device)
            # forward pass
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            # backward pass and optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # loss and accuracy
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        return running_loss / len(self.trainloader), 100 * correct / total

    def _evaluate(self, loader):
        """Evaluate the model on ``loader`` without gradients; return (mean batch loss, accuracy in %)."""
        running_loss = 0.0
        correct = 0
        total = 0
        self.model.eval()
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                running_loss += self.criterion(outputs, labels).item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        return running_loss / len(loader), 100 * correct / total

    def train(self):
        """Train for up to ``num_epochs`` epochs, validating, plotting and checking early stopping each epoch."""
        print("Starting training...")
        for _ in range(self.num_epochs):
            # training (only the training pass is timed)
            start_time = time.time()
            epoch_loss, epoch_accuracy = self._train_one_epoch()
            epoch_duration = time.time() - start_time
            self.training_losses.append(epoch_loss)
            self.training_accuracies.append(epoch_accuracy)
            self.epoch_times.append(epoch_duration)

            # validation
            val_epoch_loss, val_epoch_accuracy = self._evaluate(self.validationloader)
            self.validation_losses.append(val_epoch_loss)
            self.validation_accuracies.append(val_epoch_accuracy)

            # plot metrics
            self.plot_metrics()
            # scheduler step based on validation loss
            self.scheduler.step(val_epoch_loss)
            # early stopping
            if val_epoch_loss < self.best_loss - self.early_stop_min_delta:
                self.best_loss = val_epoch_loss
                self.early_stop_count = 0
            else:
                self.early_stop_count += 1
                print(f"Early stopping patience count: {self.early_stop_count}/{self.early_stop_patience}")
                if self.early_stop_count >= self.early_stop_patience:
                    print("Early stopping triggered.")
                    break

    def test(self):
        """Compute the accuracy on ``testloader``, store it in ``test_accuracy`` and print it."""
        print("Starting testing...")
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        self.test_accuracy = 100 * correct / total
        print(f"Test Accuracy of the model: {self.test_accuracy:.2f}%")

    def _plot_series(self, position, values, title, ylabel, color):
        """Draw one per-epoch curve (x axis starts at epoch 1) into the given subplot position."""
        plt.subplot(3, 2, position)
        plt.plot(range(1, len(values) + 1), values, marker="o", linestyle="-", color=color)
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.grid(True)

    def plot_metrics(self):
        """Clear the cell output and redraw all recorded metrics in a 3x2 grid."""
        clear_output(wait=True)
        plt.figure(figsize=(16, 10))

        self._plot_series(1, self.training_losses, "Training Loss Convergence", "Loss", "b")
        self._plot_series(2, self.validation_losses, "Validation Loss Convergence", "Loss", "orange")
        self._plot_series(3, self.training_accuracies, "Training Accuracy per Epoch", "Accuracy", "g")
        self._plot_series(4, self.validation_accuracies, "Validation Accuracy per Epoch", "Accuracy", "purple")
        # total runtime: running sum of the epoch durations, spanning the whole bottom row
        total_times = []
        elapsed = 0
        for duration in self.epoch_times:
            elapsed += duration
            total_times.append(elapsed)
        self._plot_series((5, 6), total_times, "Total Runtime per Epoch", "Total Time (seconds)", "r")

        plt.tight_layout()
        plt.show()

In [None]:
optimizers = ["adam", "sgd", "adagrad", "rmsprop"]
# a run that still fails once more than this many attempts were made is given up
max_tries = 5
training_losses = []
validation_losses = []
training_accuracies = []
validation_accuracies = []
epoch_times = []
for optimizer in optimizers:
    success = False
    tries = 0
    while not success:
        # NOTE(review): this cell referenced an undefined `MainNetwork`; VGG (defined in the
        # "Main Network" section) is assumed to be the intended model - confirm.
        model = VGG(num_classes=num_classes)
        loss = nn.CrossEntropyLoss()
        trainer = MainTrainer(device, trainloader, validationloader, testloader, model, loss, optimizer, num_epochs)
        trainer.train()
        tries += 1
        # a run counts as successful when it was not cut short by early stopping
        if len(trainer.training_losses) >= num_epochs:
            success = True
            print(f"{optimizer} success")
        else:
            print(f"{optimizer} fail")
            if tries > max_tries:
                # stop retrying instead of looping forever; the last attempt is recorded below
                print("Exceeded number of tries.")
                break
    training_losses.append(trainer.training_losses)
    validation_losses.append(trainer.validation_losses)
    training_accuracies.append(trainer.training_accuracies)
    validation_accuracies.append(trainer.validation_accuracies)
    epoch_times.append(trainer.epoch_times)

In [None]:
class MainTrainer:
    """Train, validate and test a classifier with a selectable optimizer.

    Per-epoch metrics are stored on the instance in ``training_losses``,
    ``training_accuracies``, ``validation_losses``, ``validation_accuracies`` and
    ``epoch_times``; :meth:`test` stores its result in ``test_accuracy``.

    Args:
        device: torch device the model and every batch are moved to.
        trainloader: iterable of ``(images, labels)`` training batches.
        validationloader: iterable of ``(images, labels)`` validation batches.
        testloader: iterable of ``(images, labels)`` test batches.
        model: the ``nn.Module`` to train.
        loss: criterion called as ``loss(outputs, labels)``.
        optimizer_type: "adam", "adamw", "sgd", "sgdw", "adagrad" or "rmsprop";
            also used as the figure title in :meth:`plot_metrics`.
        num_epochs: maximum number of epochs to train.
        momentum: momentum used by "sgd" and "sgdw".
        weight_decay: weight decay used by "adamw" and "sgdw".
        rho: ``alpha`` passed to RMSprop.
        epsilon: ``eps`` passed to Adagrad and RMSprop.
        lr_scheduler_factor: ``factor`` passed to ReduceLROnPlateau.
        lr_scheduler_patience: ``patience`` passed to ReduceLROnPlateau.
        early_stop_patience: consecutive non-improving epochs tolerated before stopping.
        early_stop_min_delta: minimum drop in validation loss that counts as an improvement.
        check_early_stopping: enable early stopping; off by default, in which case
            the two ``early_stop_*`` arguments have no effect.

    Raises:
        ValueError: if ``optimizer_type`` is not one of the supported names.
    """

    def __init__(self, device, trainloader, validationloader, testloader,
                 model, loss, optimizer_type, num_epochs,
                 momentum=0.9, weight_decay=0.01, rho=0.9, epsilon=1e-8,
                 lr_scheduler_factor=0.7, lr_scheduler_patience=3, early_stop_patience=6, early_stop_min_delta=0.001,
                 check_early_stopping=False):
        self.device = device
        # data
        self.trainloader = trainloader
        self.validationloader = validationloader
        self.testloader = testloader
        # model
        self.model = model.to(self.device)
        # training
        self.num_epochs = num_epochs
        # loss, optimizer, learning rate scheduler
        self.optimizer_type = optimizer_type
        self.criterion = loss
        self.momentum = momentum # momentum value for SGD, default 0.9
        self.weight_decay = weight_decay # weight decay in W optimizers
        self.rho = rho # decay rate for RMSProp, default 0.9
        self.epsilon = epsilon  # small constant for numerical stability, default 1e-8
        self.optimizer = self._build_optimizer(optimizer_type)
        self.lr_scheduler_factor = lr_scheduler_factor
        self.lr_scheduler_patience = lr_scheduler_patience
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min", factor=self.lr_scheduler_factor, patience=self.lr_scheduler_patience)
        # alternative scheduler kept for reference:
        # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer,
        #     T_0=10,       # number of epochs for the first restart
        #     T_mult=2,     # factor by which the number of epochs increases after each restart
        #     eta_min=1e-6  # minimum learning rate
        # )
        # early stopping variables
        self.check_early_stopping = check_early_stopping
        self.best_loss = float("inf")
        self.early_stop_count = 0
        self.early_stop_patience = early_stop_patience
        self.early_stop_min_delta = early_stop_min_delta
        # metrics
        self.training_losses = []
        self.training_accuracies = []
        self.validation_losses = []
        self.validation_accuracies = []
        self.epoch_times = []
        self.test_accuracy = 0

    def _build_optimizer(self, optimizer_type):
        """Create the optimizer named by ``optimizer_type`` for the model's parameters."""
        params = self.model.parameters()
        if optimizer_type == "adam":
            return optim.Adam(params, lr=0.001)
        if optimizer_type == "adamw":
            return optim.AdamW(params, lr=0.001, weight_decay=self.weight_decay)
        if optimizer_type == "sgd":
            return optim.SGD(params, lr=0.01, momentum=self.momentum)
        if optimizer_type == "sgdw":
            return SGDW(params, lr=0.01, momentum=self.momentum, weight_decay=self.weight_decay)
        if optimizer_type == "adagrad":
            return optim.Adagrad(params, lr=0.01, eps=self.epsilon)
        if optimizer_type == "rmsprop":
            return optim.RMSprop(params, lr=0.001, alpha=self.rho, eps=self.epsilon)
        # fail here with a clear message instead of a later AttributeError on self.optimizer
        raise ValueError(f"Unsupported optimizer_type: {optimizer_type!r}")

    def _train_one_epoch(self):
        """Run one optimisation pass over ``trainloader``; return (mean batch loss, accuracy in %)."""
        running_loss = 0.0
        correct = 0
        total = 0
        self.model.train()
        for images, labels in self.trainloader:
            images, labels = images.to(self.device), labels.to(self.device)
            # forward pass
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            # backward pass and optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # loss and accuracy
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        return running_loss / len(self.trainloader), 100 * correct / total

    def _evaluate(self, loader):
        """Evaluate the model on ``loader`` without gradients; return (mean batch loss, accuracy in %)."""
        running_loss = 0.0
        correct = 0
        total = 0
        self.model.eval()
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                running_loss += self.criterion(outputs, labels).item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        return running_loss / len(loader), 100 * correct / total

    def train(self):
        """Train for up to ``num_epochs`` epochs, validating and plotting after each epoch.

        When ``check_early_stopping`` is true, training stops once the validation loss has
        failed to improve by ``early_stop_min_delta`` for ``early_stop_patience`` epochs in a row.
        """
        print("Starting training...")
        for _ in range(self.num_epochs):
            # training (only the training pass is timed)
            start_time = time.time()
            epoch_loss, epoch_accuracy = self._train_one_epoch()
            epoch_duration = time.time() - start_time
            self.training_losses.append(epoch_loss)
            self.training_accuracies.append(epoch_accuracy)
            self.epoch_times.append(epoch_duration)

            # validation
            val_epoch_loss, val_epoch_accuracy = self._evaluate(self.validationloader)
            self.validation_losses.append(val_epoch_loss)
            self.validation_accuracies.append(val_epoch_accuracy)

            # plot metrics
            self.plot_metrics()
            # scheduler step based on validation loss
            self.scheduler.step(val_epoch_loss)
            # early stopping
            if not self.check_early_stopping:
                continue
            if val_epoch_loss < self.best_loss - self.early_stop_min_delta:
                self.best_loss = val_epoch_loss
                self.early_stop_count = 0
            else:
                self.early_stop_count += 1
                print(f"Early stopping patience count: {self.early_stop_count}/{self.early_stop_patience}")
                if self.early_stop_count >= self.early_stop_patience:
                    print("Early stopping triggered.")
                    break

    def test(self):
        """Compute the accuracy on ``testloader``, store it in ``test_accuracy`` and print it."""
        print("Starting testing...")
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        self.test_accuracy = 100 * correct / total
        print(f"Test Accuracy of the model: {self.test_accuracy:.2f}%")

    def _plot_series(self, position, values, title, ylabel, color):
        """Draw one per-epoch curve (x axis starts at epoch 1) into the given subplot position."""
        plt.subplot(3, 2, position)
        plt.plot(range(1, len(values) + 1), values, marker="o", linestyle="-", color=color)
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.grid(True)

    def plot_metrics(self):
        """Clear the cell output and redraw all recorded metrics in a 3x2 grid titled with the optimizer name."""
        clear_output(wait=True)
        plt.figure(figsize=(16, 10))

        self._plot_series(1, self.training_losses, "Training Loss Convergence", "Loss", "b")
        self._plot_series(2, self.validation_losses, "Validation Loss Convergence", "Loss", "orange")
        self._plot_series(3, self.training_accuracies, "Training Accuracy per Epoch", "Accuracy", "g")
        self._plot_series(4, self.validation_accuracies, "Validation Accuracy per Epoch", "Accuracy", "purple")
        # total runtime: running sum of the epoch durations, spanning the whole bottom row
        total_times = []
        elapsed = 0
        for duration in self.epoch_times:
            elapsed += duration
            total_times.append(elapsed)
        self._plot_series((5, 6), total_times, "Total Runtime per Epoch", "Total Time (seconds)", "r")

        plt.suptitle(self.optimizer_type)
        plt.tight_layout()
        plt.show()

In [None]:
optimizers = ["adam", "adamw", "sgd", "sgdw", "rmsprop"]
# one list per metric; entry i holds the per-epoch history recorded for optimizers[i]
training_losses, validation_losses = [], []
training_accuracies, validation_accuracies = [], []
epoch_times = []
for optimizer_type in optimizers:
    # every optimizer starts from a freshly initialised network and criterion
    model = ResNet(num_classes=num_classes)
    loss = nn.CrossEntropyLoss()
    trainer = MainTrainer(
        device, trainloader, validationloader, testloader,
        model, loss, optimizer_type, num_epochs,
    )
    trainer.train()
    # file this run's histories under their metric
    for history, recorded in (
        (training_losses, trainer.training_losses),
        (validation_losses, trainer.validation_losses),
        (training_accuracies, trainer.training_accuracies),
        (validation_accuracies, trainer.validation_accuracies),
        (epoch_times, trainer.epoch_times),
    ):
        history.append(recorded)

In [None]:
colors = ["blue", "green", "red", "purple", "orange"]
fig, axs = plt.subplots(3, 2, figsize=(16, 10))

# (axes, per-optimizer histories, panel title, y-axis label); axs[2, 1] is left unused
panels = [
    (axs[0, 0], training_losses, "Training Loss", "Loss"),
    (axs[0, 1], validation_losses, "Validation Loss", "Loss"),
    (axs[1, 0], training_accuracies, "Training Accuracy", "Accuracy"),
    (axs[1, 1], validation_accuracies, "Validation Accuracy", "Accuracy"),
    (axs[2, 0], epoch_times, "Epoch Times", "Time (s)"),
]
for ax, histories, panel_title, y_label in panels:
    # one curve per optimizer, coloured consistently across all panels
    for i, optimizer_name in enumerate(optimizers):
        ax.plot(histories[i], color=colors[i], label=optimizer_name)
    ax.set_title(panel_title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.legend()

plt.tight_layout()
plt.show()

# MetaLayerConv3d

Number of trainable parameters:
* **Conv3d**:
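    * **Weights**:  $\text{in\_channels} \times \text{out\_channels} \times \text{kernel size} = \text{num\_optimizers} \times 1 \times (\text{num\_optimizers} \times 1 \times 1) = \text{num\_optimizers}^2$
    * **Biases**: $\text{out\_channels} = 1$ (the Conv3d itself has a single output channel; in the FC formulas below, $\text{out\_channels}$ denotes the wrapped layer's output channels)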
    * **Total**: $\text{num\_optimizers}^2 + 1$
* **FC**:
    * **1st FC**:  $(\text{input\_features}\times \text{output\_features})+\text{output\_features}=\text{num\_optimizers}\times \text{out\_channels}\times \text{bias\_neurons}+\text{bias\_neurons}$
    * **2nd FC**: $(\text{input\_size}\times \text{output\_size})+\text{output\_size}=\text{bias\_neurons}\times \text{out\_channels}+\text{out\_channels}$
    * **Total**: $\left(\text{num\_optimizers}\times \text{out\_channels}\times \text{bias\_neurons} + \text{bias\_neurons} \right)+\left(\text{bias\_neurons}\times \text{out\_channels}+\text{out\_channels}\right)$

**Total**: 
$$\left[\text{num\_optimizers}^2 + 1\right] + \left[\left(\text{num\_optimizers}\times \text{out\_channels}\times \text{bias\_neurons} + \text{bias\_neurons}\right)+\left(\text{bias\_neurons}\times \text{out\_channels}+\text{out\_channels}\right)\right]$$

In [9]:
class MetaLayerConv3d(nn.Module):
    """Combine per-optimizer gradients of a convolutional layer into a single gradient.

    Weight gradients are merged by a Conv3d that treats the optimizers as input channels
    and slides a kernel of length ``num_optimizers`` along the output-channel axis; bias
    gradients are merged by a two-layer fully connected network.

    Args:
        num_optimizers: number of stacked optimizer gradients.
        output_channels: output channels of the wrapped convolution.
        input_channels: input channels of the wrapped convolution.
        kernel_size: (square) kernel size of the wrapped convolution.
        bias_neurons: hidden width of the bias network.
    """

    def __init__(self, num_optimizers, output_channels, input_channels, kernel_size, bias_neurons=64):
        super(MetaLayerConv3d, self).__init__()

        # "same" padding keeps the output-channel axis at its original length for every
        # num_optimizers; the former fixed padding of (1, 0, 0) did so only for
        # num_optimizers == 3 (for which "same" resolves to exactly that padding)
        self.conv3d = nn.Conv3d(
            in_channels=num_optimizers,
            out_channels=1,
            kernel_size=(num_optimizers, 1, 1),
            stride=(1, 1, 1),
            padding="same"
        )

        self.fc1 = nn.Linear(num_optimizers * output_channels, bias_neurons)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(bias_neurons, output_channels)

        self.output_channels = output_channels
        self.input_channels = input_channels
        self.kernel_size = kernel_size

    def forward(self, weight_gradients, bias_gradients):
        """Merge the stacked gradients.

        Args:
            weight_gradients: tensor of shape
                (num_optimizers, output_channels, input_channels, kernel_size, kernel_size).
            bias_gradients: tensor of shape (num_optimizers, output_channels).

        Returns:
            Tuple ``(weight_output, bias_output)`` of shapes
            (output_channels, input_channels, kernel_size, kernel_size) and (output_channels,).
        """
        # --- Weights Gradients ---
        # add a batch dimension: shape becomes (1, num_optimizers, output_channels, input_channels, kernel_size, kernel_size)
        weight_gradients = weight_gradients.unsqueeze(0)
        # flatten last two dimensions to get shape (1, num_optimizers, output_channels, input_channels, kernel_size * kernel_size)
        # (reshape instead of view so non-contiguous inputs are accepted as well)
        weight_gradients = weight_gradients.reshape(1, weight_gradients.size(1), weight_gradients.size(2), weight_gradients.size(3), -1)
        # pass through Conv3d to combine optimizer channels, shape becomes (1, 1, output_channels, input_channels, kernel_size * kernel_size)
        conv_output = self.conv3d(weight_gradients)
        # remove batch dimension and channel
        weight_output = conv_output.squeeze(0).squeeze(0)
        # unflatten to (output_channels, input_channels, kernel_size, kernel_size)
        weight_output = weight_output.reshape(self.output_channels, self.input_channels, self.kernel_size, self.kernel_size)

        # --- Bias Gradients ---
        # flatten (num_optimizers, output_channels) to shape (num_optimizers * output_channels)
        bias_gradients = bias_gradients.reshape(-1)
        # pass through the fully connected layers
        fc_output = self.fc1(bias_gradients)
        fc_output = self.relu(fc_output)
        bias_output = self.fc2(fc_output)

        return weight_output, bias_output

# MetaLayerConv2d

Number of trainable parameters:
* **Conv2d**:
    * **Weights**:  $\text{in\_channels} \times \text{out\_channels} \times \text{kernel size} = \text{num\_optimizers} \times 1 \times (1 \times 1) = \text{num\_optimizers}$
    * **Biases**: $\text{out\_channels} = 1$
    * **Total**: $\text{num\_optimizers} + 1$
* **FC**:
    * **1st FC**:  $\text{num\_optimizers}\times \text{out\_features} \times \text{bias\_neurons}+\text{bias\_neurons}$
    * **2nd FC**: $\text{bias\_neurons}\times \text{out\_features}+\text{out\_features}$
    * **Total**: $\left(\text{num\_optimizers}\times \text{out\_features} \times \text{bias\_neurons}+\text{bias\_neurons}\right)+\left(\text{bias\_neurons}\times \text{out\_features}+\text{out\_features}\right)$

**Total**: 
$$\left[\text{num\_optimizers} + 1\right] + \left[\left(\text{num\_optimizers}\times \text{out\_features} \times \text{bias\_neurons}+\text{bias\_neurons}\right)+\left(\text{bias\_neurons}\times \text{out\_features}+\text{out\_features}\right)\right]$$

In [10]:
class MetaLayerConv2d(nn.Module):
    """Combine per-optimizer gradients of a fully connected layer into a single gradient.

    A 1x1 Conv2d with ``num_optimizers`` input channels and one output channel mixes the
    stacked weight gradients element-wise; a two-layer fully connected network maps the
    flattened bias gradients to ``out_features`` values.

    Args:
        num_optimizers: number of stacked optimizer gradients.
        out_features: output features of the wrapped linear layer.
        bias_neurons: hidden width of the bias network.
    """

    def __init__(self, num_optimizers, out_features, bias_neurons=64):
        super().__init__()

        # optimizers act as input channels; the 1x1 kernel mixes them position by position
        self.conv2d = nn.Conv2d(num_optimizers, 1, kernel_size=(1, 1), stride=1, padding=0)

        # bias path: (num_optimizers * out_features) -> bias_neurons -> out_features
        self.fc1 = nn.Linear(num_optimizers * out_features, bias_neurons)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(bias_neurons, out_features)

    def forward(self, weight_gradients, bias_gradients):
        """Merge the stacked gradients.

        Args:
            weight_gradients: tensor of shape (num_optimizers, rows, cols)
                (presumably the stacked gradients of the linear layer's weight matrix - confirm with caller).
            bias_gradients: tensor with ``num_optimizers * out_features`` elements,
                e.g. of shape (num_optimizers, out_features).

        Returns:
            Tuple ``(weight_output, bias_output)`` of shapes (rows, cols) and (out_features,).
        """
        # weights: prepend a batch axis, mix the optimizer channels, then drop batch and channel axes
        batched = weight_gradients.unsqueeze(0)
        mixed = self.conv2d(batched)
        weight_output = mixed.squeeze(0).squeeze(0)

        # biases: flatten to one vector and run it through fc1 -> ReLU -> fc2
        flat_bias = bias_gradients.view(-1)
        bias_output = self.fc2(self.relu(self.fc1(flat_bias)))

        return weight_output, bias_output

# Main Meta Network

In [14]:
def initialize_meta_layer_nns(main_model, num_optimizers, bias_neurons=64):
    """Create one meta network for every Conv2d / Linear direct child of ``main_model``.

    Only the direct children returned by ``main_model.children()`` are
    inspected; layers nested inside containers are not visited, and children
    of any other type are skipped.

    Args:
        main_model: Module whose direct children are scanned, in order.
        num_optimizers: Number of gradient variants each meta network receives.
        bias_neurons: Hidden width handed to every meta network (default 64,
            the value that used to be hard-coded).

    Returns:
        ``nn.ModuleList`` with one meta network per matching child, in the
        order the children appear.
    """
    meta_layer_nns = []
    for layer in main_model.children():
        if isinstance(layer, nn.Conv2d):
            meta_layer = MetaLayerConv3d(
                num_optimizers=num_optimizers,
                output_channels=layer.out_channels,
                input_channels=layer.in_channels,
                # NOTE(review): only the first kernel dimension is forwarded,
                # which presumably assumes square kernels - confirm.
                kernel_size=layer.kernel_size[0],
                bias_neurons=bias_neurons
            )
        elif isinstance(layer, nn.Linear):
            meta_layer = MetaLayerConv2d(
                num_optimizers=num_optimizers,
                out_features=layer.out_features,
                bias_neurons=bias_neurons
            )
        else:
            continue
        meta_layer_nns.append(meta_layer)
    return nn.ModuleList(meta_layer_nns)

In [None]:
class Optimizer:
    """Pairs a ``torch.optim`` optimizer with a hand-written update-direction function.

    ``self.optimizer`` is the regular torch optimizer. ``self.optimizer_function``
    is the bound ``*_gradients`` method for the chosen type; it reads the
    current ``param.grad`` of every parameter and returns the
    optimizer-adjusted gradients as a list of tensors.

    Args:
        model_parameters: Iterable of parameters (e.g. ``model.parameters()``).
        optimizer_type: One of ``"adam"``, ``"sgd"``, ``"adagrad"``, ``"rmsprop"``.
        lr: Learning rate handed to the torch optimizer.
        momentum: Momentum factor (SGD).
        rho: Decay of the squared-gradient average (RMSprop ``alpha``).
        epsilon: Numerical-stability constant (Adam, Adagrad, RMSprop).

    Raises:
        ValueError: If ``optimizer_type`` is not supported.
    """

    def __init__(self, model_parameters, optimizer_type="adam", lr=0.001, momentum=0.9, rho=0.9, epsilon=1e-8):
        # model.parameters() is a one-shot generator. It is needed both by the
        # torch optimizer constructor and by every later *_gradients call, so
        # it is materialised once here; otherwise those methods would iterate
        # an exhausted generator and return an empty list.
        self.model_parameters = list(model_parameters)
        self.lr = lr
        self.momentum = momentum
        self.rho = rho
        self.epsilon = epsilon
        # True when the returned gradients do NOT yet include the learning
        # rate; set per optimizer type in _initialize_optimizer.
        self.multiply_lr = True
        # Running statistics of the *_gradients methods, keyed by parameter.
        # They are kept apart from self.optimizer.state so that torch's own
        # buffers ("exp_avg", "square_avg", "step", ...) are never overwritten
        # or half-initialised, which would break a later optimizer.step().
        self._gradient_state = {}
        self.optimizer, self.optimizer_function = self._initialize_optimizer(optimizer_type)

    def _initialize_optimizer(self, optimizer_type):
        """Initializes optimizer and corresponding gradient function based on optimizer type."""
        if optimizer_type == "adam":
            # epsilon is forwarded like for the other adaptive optimizers, so
            # a non-default value is no longer silently ignored.
            optimizer = optim.Adam(self.model_parameters, lr=self.lr, eps=self.epsilon)
            optimizer_function = self.adam_gradients
            self.multiply_lr = True
        elif optimizer_type == "sgd":
            optimizer = optim.SGD(self.model_parameters, lr=self.lr, momentum=self.momentum)
            optimizer_function = self.sgd_momentum_gradients
            # the SGD velocity already contains the learning rate
            self.multiply_lr = False
        elif optimizer_type == "adagrad":
            optimizer = optim.Adagrad(self.model_parameters, lr=self.lr, eps=self.epsilon)
            optimizer_function = self.adagrad_gradients
            self.multiply_lr = True
        elif optimizer_type == "rmsprop":
            optimizer = optim.RMSprop(self.model_parameters, lr=self.lr, alpha=self.rho, eps=self.epsilon)
            optimizer_function = self.rmsprop_gradients
            self.multiply_lr = True
        else:
            raise ValueError(f"Optimizer type '{optimizer_type}' is not supported.")
        return optimizer, optimizer_function

    def _state_for(self, param):
        """Return (creating it on first use) the private statistics dict of ``param``."""
        return self._gradient_state.setdefault(param, {})

    def _parameters_with_grad(self):
        """Return the parameters that currently hold a gradient."""
        return [param for param in self.model_parameters if param.grad is not None]

    def sgd_momentum_gradients(self):
        """Return the momentum velocity ``v_t = momentum * v_{t-1} + lr * g_t`` per parameter.

        The learning rate is already part of the result. Parameters without a
        gradient are skipped.
        """
        optimizer_gradients = []
        for param in self._parameters_with_grad():
            state = self._state_for(param)
            if "velocity" not in state:
                state["velocity"] = torch.zeros_like(param.grad)
            velocity = state["velocity"]
            lr = self.optimizer.defaults["lr"]
            # v_t = gamma * v_{t-1} + eta * g_t (updated in place)
            velocity.mul_(self.momentum).add_(param.grad, alpha=lr)
            # clone so callers cannot alias the running buffer
            optimizer_gradients.append(velocity.clone())
        return optimizer_gradients

    def adagrad_gradients(self):
        """Return ``g_t / (sqrt(G_t) + epsilon)`` per parameter, with ``G_t = G_{t-1} + g_t^2``.

        The learning rate is not applied. Parameters without a gradient are skipped.
        """
        optimizer_gradients = []
        for param in self._parameters_with_grad():
            state = self._state_for(param)
            if "sum_sq_grads" not in state:
                state["sum_sq_grads"] = torch.zeros_like(param.grad)
            sum_sq_grads = state["sum_sq_grads"]
            # G_t = G_{t-1} + g_t^2
            sum_sq_grads.addcmul_(param.grad, param.grad)
            # the division allocates a new tensor, so no clone is needed
            optimizer_gradients.append(param.grad / (sum_sq_grads.sqrt() + self.epsilon))
        return optimizer_gradients

    def rmsprop_gradients(self):
        """Return ``g_t / (sqrt(E[g^2]_t) + epsilon)`` per parameter.

        ``E[g^2]_t = rho * E[g^2]_{t-1} + (1 - rho) * g_t^2``. The learning
        rate is not applied. Parameters without a gradient are skipped.
        """
        optimizer_gradients = []
        for param in self._parameters_with_grad():
            state = self._state_for(param)
            if "square_avg" not in state:
                state["square_avg"] = torch.zeros_like(param.grad)
            square_avg = state["square_avg"]
            # keyword `value=` form; the positional-scalar overload of
            # addcmul_ is deprecated
            square_avg.mul_(self.rho).addcmul_(param.grad, param.grad, value=1 - self.rho)
            optimizer_gradients.append(param.grad / (square_avg.sqrt() + self.epsilon))
        return optimizer_gradients

    def adam_gradients(self):
        """Return the bias-corrected Adam direction ``m_hat / (sqrt(v_hat) + eps)`` per parameter.

        ``betas`` and ``eps`` are read from ``self.optimizer.defaults``. The
        learning rate is not applied. Parameters without a gradient are skipped.
        """
        optimizer_gradients = []
        for param in self._parameters_with_grad():
            state = self._state_for(param)
            if "exp_avg" not in state:
                state["exp_avg"] = torch.zeros_like(param.grad)
            if "exp_avg_sq" not in state:
                state["exp_avg_sq"] = torch.zeros_like(param.grad)
            if "step" not in state:
                state["step"] = 0
            exp_avg = state["exp_avg"]
            exp_avg_sq = state["exp_avg_sq"]
            beta1, beta2 = self.optimizer.defaults["betas"]
            eps = self.optimizer.defaults["eps"]
            # increment step count
            state["step"] += 1
            step = state["step"]
            # first and second moment estimates (updated in place)
            exp_avg.mul_(beta1).add_(param.grad, alpha=(1 - beta1))
            exp_avg_sq.mul_(beta2).addcmul_(param.grad, param.grad, value=(1 - beta2))
            # bias correction
            exp_avg_corrected = exp_avg / (1 - beta1 ** step)
            exp_avg_sq_corrected = exp_avg_sq / (1 - beta2 ** step)
            optimizer_gradients.append(exp_avg_corrected / (exp_avg_sq_corrected.sqrt() + eps))
        return optimizer_gradients

In [None]:
# Build the main network with the output size chosen in the "Global Variables"
# cell (10 for CIFAR-10, 100 when load_100 is set) instead of a fixed 10.
main_model = MainNetwork(num_classes=num_classes)

# One torch optimizer per candidate update rule. Only the wrapped torch
# optimizer (`.optimizer`) is kept; the Optimizer wrapper itself is discarded.
main_optimizers = [
    Optimizer(main_model.parameters(), optimizer_type="adam", lr=0.001, epsilon=1e-8).optimizer,
    Optimizer(main_model.parameters(), optimizer_type="sgd", lr=0.01, momentum=0.9).optimizer,
    Optimizer(main_model.parameters(), optimizer_type="adagrad", lr=0.01, epsilon=1e-8).optimizer,
    Optimizer(main_model.parameters(), optimizer_type="rmsprop", lr=0.001, rho=0.9, epsilon=1e-8).optimizer
]

# One meta network per Conv2d / Linear layer, each trained by its own Adam optimizer.
meta_layer_nns = initialize_meta_layer_nns(main_model, len(main_optimizers))
meta_optimizers = [Optimizer(meta_layer_nn.parameters(), optimizer_type="adam", lr=0.001, epsilon=1e-8).optimizer for meta_layer_nn in meta_layer_nns]

In [None]:
def count_model_params(model, model_name="Model"):
    """Tally the trainable parameters of ``model``.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        model_name: Label stored in the ``"Model"`` field of every row.

    Returns:
        Tuple ``(total, rows)``: ``total`` is the number of elements over all
        parameters with ``requires_grad`` set, ``rows`` holds one dict per such
        parameter with the keys ``"Model"``, ``"Layer Name"`` and
        ``"Number of Parameters"``.
    """
    rows = [
        {"Model": model_name, "Layer Name": layer_name, "Number of Parameters": tensor.numel()}
        for layer_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    grand_total = sum(row["Number of Parameters"] for row in rows)
    return grand_total, rows

# Parameter inventory: the main model first, then every meta-layer network.
main_model_total, full_table_data = count_model_params(main_model, model_name="Main Model")

total_meta_nns = 0
for i, meta_layer_nn in enumerate(meta_layer_nns):
    meta_total, meta_table_data = count_model_params(meta_layer_nn, model_name=f"Meta Layer NN {i+1}")
    total_meta_nns += meta_total
    full_table_data += meta_table_data

df = pd.DataFrame(full_table_data)

# Summary rows are appended below the per-layer rows under string index labels.
# "Increase %" is the relative growth caused by the meta layers, truncated to
# a whole percent.
for label, row in (
    ("Main Total", ["Main Model", "-", main_model_total]),
    ("All Meta Layers Total", ["Meta Layers", "-", total_meta_nns]),
    ("Overall Total", ["Main Model + Meta Layers", "-", main_model_total + total_meta_nns]),
    ("Increase %", ["Overall Total / Main Total", "-", int(((main_model_total + total_meta_nns) / main_model_total - 1)*100)]),
):
    df.loc[label] = row
df


Unnamed: 0,Model,Layer Name,Number of Parameters
0,Main Model,conv1_1.weight,1728
1,Main Model,conv1_1.bias,64
2,Main Model,conv1_2.weight,36864
3,Main Model,conv1_2.bias,64
4,Main Model,conv2_1.weight,73728
...,...,...,...
103,Meta Layer NN 13,fc2.bias,10
Main Total,Main Model,-,8952138
All Meta Layers Total,Meta Layers,-,719237
Overall Total,Main Model + Meta Layers,-,9671375


In [None]:
class MetaTrainer:
    """Trains a main model whose parameter updates are produced by per-layer meta networks.

    For every batch the trainer
      1. runs the main model and computes the loss,
      2. takes the raw gradients of the meta-managed weights and biases,
      3. makes one copy of those gradients per entry of ``main_optimizers``,
      4. feeds the stacked copies of each layer to that layer's meta network,
      5. subtracts the meta networks' outputs from the main parameters,
      6. re-evaluates the loss of the updated main model on the same batch,
      7. back-propagates that loss into the meta networks and steps the
         meta optimizers.

    NOTE(review): ``main_optimizers`` only determine how many gradient copies
    each meta network receives; every copy is the raw gradient and the
    optimizers are never stepped. Optimizer-specific preprocessing still has
    to be wired into step 3.
    NOTE(review): the update is applied as ``param -= meta_output`` without a
    step-size factor - confirm the intended sign and scaling.

    Args:
        device: Device the models and batches are moved to.
        trainloader, validationloader, testloader: Sized iterables yielding
            ``(images, labels)`` batches.
        main_model: Model being trained; its Conv2d / Linear direct children
            are the meta-managed layers and must all have a bias.
        meta_layer_nns: One meta network per meta-managed layer, in the same
            order; each is called as ``meta(weight_stack, bias_stack)`` and
            must return ``(weight_update, bias_update)`` shaped like the
            layer's weight and bias.
        loss: Criterion called as ``loss(outputs, labels)``.
        main_optimizers: Torch optimizers of the main model (see note above).
        meta_optimizers: Torch optimizers of the meta networks.
        num_epochs: Maximum number of epochs.
        lr_scheduler_factor, lr_scheduler_patience: ``ReduceLROnPlateau`` settings.
        early_stop_patience: Epochs without improvement before stopping.
        early_stop_min_delta: Minimum drop of the validation loss that counts
            as an improvement.
    """

    def __init__(self, device, trainloader, validationloader, testloader, 
                 main_model, meta_layer_nns, 
                 loss, 
                 main_optimizers, meta_optimizers, 
                 num_epochs,
                 lr_scheduler_factor=0.5, lr_scheduler_patience=2, 
                 early_stop_patience=5, early_stop_min_delta=0.001):
        self.device = device
        # data
        self.trainloader = trainloader
        self.validationloader = validationloader
        self.testloader = testloader
        # models
        self.main_model = main_model.to(device)
        self.meta_layer_nns = meta_layer_nns
        for i, meta_layer_nn in enumerate(self.meta_layer_nns):
            self.meta_layer_nns[i] = meta_layer_nn.to(device)
        # loss, optimizers, meta optimizers, learning rate schedulers
        self.criterion = loss
        self.main_optimizers = main_optimizers
        self.meta_optimizers = meta_optimizers
        self.main_schedulers = [ReduceLROnPlateau(optimizer, mode="min", factor=lr_scheduler_factor, patience=lr_scheduler_patience) for optimizer in self.main_optimizers]
        self.meta_schedulers = [ReduceLROnPlateau(optimizer, mode="min", factor=lr_scheduler_factor, patience=lr_scheduler_patience) for optimizer in self.meta_optimizers]
        # training
        self.num_epochs = num_epochs
        # early stopping variables
        self.best_loss = float("inf")
        self.early_stop_count = 0
        self.early_stop_patience = early_stop_patience
        self.early_stop_min_delta = early_stop_min_delta
        # metrics
        self.training_losses = []
        self.training_accuracies = []
        self.validation_losses = []
        self.validation_accuracies = []
        self.epoch_times = []
        self.test_accuracy = 0

    def _managed_layer_params(self):
        """Return ``[(weight, bias), ...]`` of the meta-managed layers, in model order.

        Uses the same selection rule as ``initialize_meta_layer_nns`` (Conv2d
        and Linear direct children), so entry ``i`` belongs to
        ``self.meta_layer_nns[i]``.

        Raises:
            ValueError: If no main optimizer was given, a managed layer has no
                bias, or the number of managed layers differs from the number
                of meta networks.
        """
        if len(self.main_optimizers) == 0:
            raise ValueError("At least one main optimizer is required.")
        layer_params = []
        for layer in self.main_model.children():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                if layer.bias is None:
                    raise ValueError("Meta-managed layers must have a bias term.")
                layer_params.append((layer.weight, layer.bias))
        if len(layer_params) != len(self.meta_layer_nns):
            raise ValueError(
                f"Found {len(layer_params)} Conv2d/Linear layers in the main model "
                f"but {len(self.meta_layer_nns)} meta layer networks."
            )
        return layer_params

    def _meta_step(self, images, labels, layer_params):
        """Run one batch: update the main model via the meta networks, then train them.

        Returns:
            Tuple ``(loss, outputs)`` of the main model *before* the update,
            used for the training metrics.
        """
        # flat [w0, b0, w1, b1, ...] so that layer i owns entries 2i and 2i + 1
        params = [param for pair in layer_params for param in pair]

        # step 1: forward pass and loss computation for the main model
        outputs = self.main_model(images)
        loss = self.criterion(outputs, labels)
        # step 2: raw gradients; they are plain inputs to the meta networks,
        # so no second-order graph is built
        raw_gradients = torch.autograd.grad(loss, params)
        # step 3: one gradient set per optimizer (currently identical copies)
        gradients_optimizers = [[grad.clone() for grad in raw_gradients] for _ in self.main_optimizers]
        # step 4: per layer, stack the optimizer variants and query the meta network
        updates = []
        for idx, (meta_layer_nn, (weight, bias)) in enumerate(zip(self.meta_layer_nns, layer_params)):
            weight_stack = torch.stack([grads[2 * idx] for grads in gradients_optimizers])
            bias_stack = torch.stack([grads[2 * idx + 1] for grads in gradients_optimizers])
            weight_update, bias_update = meta_layer_nn(weight_stack, bias_stack)
            # guard against silent broadcasting in the in-place update below
            if weight_update.shape != weight.shape or bias_update.shape != bias.shape:
                raise ValueError(
                    f"Meta layer {idx} returned updates of shape "
                    f"{tuple(weight_update.shape)} / {tuple(bias_update.shape)}, expected "
                    f"{tuple(weight.shape)} / {tuple(bias.shape)}."
                )
            updates.extend((weight_update, bias_update))
        # step 5: apply the meta-generated update to the main model
        with torch.no_grad():
            for param, update in zip(params, updates):
                param.sub_(update)
        # step 6: meta loss = loss of the updated main model on the same batch
        # (a second forward pass, so train-mode layers such as BatchNorm or
        # Dropout see the batch twice)
        meta_loss = self.criterion(self.main_model(images), labels)
        post_gradients = torch.autograd.grad(meta_loss, params)
        # step 7: because new_param = old_param - update, the gradient of the
        # meta loss w.r.t. each update is minus its gradient w.r.t. the
        # parameter; feed that into the meta networks' graphs
        for meta_optimizer in self.meta_optimizers:
            meta_optimizer.zero_grad()
        torch.autograd.backward(updates, grad_tensors=[-grad for grad in post_gradients])
        for meta_optimizer in self.meta_optimizers:
            meta_optimizer.step()
        return loss, outputs

    def train(self):
        """Run the meta-training loop with LR scheduling and early stopping.

        Appends one entry per completed epoch to ``training_losses``,
        ``training_accuracies``, ``validation_losses``,
        ``validation_accuracies`` and ``epoch_times``.
        """
        print("Starting meta training...")
        layer_params = self._managed_layer_params()
        for epoch in range(self.num_epochs):
            start_time = time.time()
            running_loss = 0.0
            correct = 0
            total = 0

            # main model training
            self.main_model.train()
            for meta_layer_nn in self.meta_layer_nns:
                meta_layer_nn.train()
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)
                loss, outputs = self._meta_step(images, labels, layer_params)

                # Track training loss and accuracy for this batch
                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

            # Calculate metrics for the epoch
            epoch_loss = running_loss / len(self.trainloader)
            epoch_accuracy = 100 * correct / total
            epoch_duration = time.time() - start_time
            self.training_losses.append(epoch_loss)
            self.training_accuracies.append(epoch_accuracy)
            self.epoch_times.append(epoch_duration)

            # Validation phase to monitor performance on validation data
            val_loss, val_accuracy = self.validate()
            self.validation_losses.append(val_loss)
            self.validation_accuracies.append(val_accuracy)

            # Plateau schedulers react to the validation loss
            for scheduler in self.main_schedulers + self.meta_schedulers:
                scheduler.step(val_loss)

            # Early stopping: only a drop larger than min_delta counts
            if val_loss < self.best_loss - self.early_stop_min_delta:
                self.best_loss = val_loss
                self.early_stop_count = 0
            else:
                self.early_stop_count += 1
                if self.early_stop_count >= self.early_stop_patience:
                    print("Early stopping triggered.")
                    break

            # Print epoch summary
            print(f"Epoch [{epoch+1}/{self.num_epochs}] "
                  f"Training Loss: {epoch_loss:.4f} "
                  f"Training Acc: {epoch_accuracy:.2f}% "
                  f"Validation Loss: {val_loss:.4f} "
                  f"Validation Acc: {val_accuracy:.2f}% "
                  f"Time: {epoch_duration:.2f}s")

    def _evaluate(self, loader):
        """Return ``(mean batch loss, accuracy in percent)`` of the main model on ``loader``."""
        self.main_model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)

                # Forward pass
                outputs = self.main_model(images)
                running_loss += self.criterion(outputs, labels).item()

                # Calculate accuracy
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        return running_loss / len(loader), 100 * correct / total

    def validate(self):
        """Evaluate on the validation loader; returns ``(val_loss, val_accuracy)``."""
        return self._evaluate(self.validationloader)

    def test(self):
        """Evaluate on the test loader, print and store the result.

        Returns:
            Tuple ``(test_loss, test_accuracy)``; the accuracy is also saved
            in ``self.test_accuracy``.
        """
        test_loss, test_accuracy = self._evaluate(self.testloader)
        self.test_accuracy = test_accuracy
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
        return test_loss, test_accuracy

In [None]:
# Pseudocode outline of the meta-training loop.
# NOTE(review): helpers such as initialize_main_model, compute_loss,
# compute_gradients, total_epochs and training_data are placeholders that do
# not appear to be defined in the visible part of this notebook, so this cell
# is a design sketch and is not expected to run; consider moving it to a
# markdown cell.
main_model = initialize_main_model()
optimizers = {
    "adam": initialize_adam_optimizer(),
    "sgd": initialize_sgd_optimizer()
}
meta_layer_nns = initialize_meta_layer_nns(main_model, len(optimizers))
meta_optimizers = [initialize_adam_optimizer() for _ in range(len(meta_layer_nns))]

for epoch in range(total_epochs):
    for batch in training_data:
        # step 1: forward pass and loss computation for the main model
        predictions = main_model.forward(batch.input)
        loss = compute_loss(predictions, batch.labels)
        # step 2: compute raw gradients
        raw_gradients = compute_gradients(loss, main_model.parameters())
        # step 3: preprocess gradients using optimizers
        gradients_optimizers = []
        for optimizer in optimizers.values():
            gradients_optimizer = optimizer.preprocess_gradients(raw_gradients)
            gradients_optimizers.append(gradients_optimizer)

        # step 4: iterate through each layer's gradients.
        # gradients_optimizers is indexed as [optimizer][parameter], where the
        # parameter axis stores weight / bias pairs layer by layer:
        # [weights_layer1, biases_layer1, weights_layer2, biases_layer2, ...].
        # For every layer, the weight and bias gradients of all optimizers are
        # passed to the matching meta_layer_nn and its outputs are collected in
        # the same flat layout as a single gradients_optimizer.
        meta_layer_nns_output = []
        for layer_index, meta_layer_nn in enumerate(meta_layer_nns):
            layer_weight_gradients = [gradients[2 * layer_index] for gradients in gradients_optimizers]
            layer_bias_gradients = [gradients[2 * layer_index + 1] for gradients in gradients_optimizers]
            weight_update, bias_update = meta_layer_nn.forward(layer_weight_gradients, layer_bias_gradients)
            meta_layer_nns_output.extend([weight_update, bias_update])

        # step 5: update the main model's parameters using the Meta NN output
        main_model.update_parameters(meta_layer_nns_output)

        # step 6: compute the loss for the Meta NN (based on main model loss)
        meta_loss = compute_meta_loss(main_model.loss)

        # step 7: backward pass for each Meta layer NN and update its parameters
        for i, meta_layer_nn in enumerate(meta_layer_nns):
            meta_layer_nn.backward(meta_loss)
            meta_optimizers[i].step()
            meta_optimizers[i].zero_grad()

        # step 8: reset gradients for the main model's optimizers
        for optimizer in optimizers.values():
            optimizer.zero_grad()

In [None]:
# Layout of the raw gradient list: one weight / bias pair per layer, in the
# order of main_model.parameters(). Weight shapes follow the PyTorch
# convention (out_channels, in_channels, kH, kW) for Conv2d and
# (out_features, in_features) for Linear.
[
    torch.Tensor(64, 3, 3, 3),   # gradient of conv1_1.weight
    torch.Tensor(64),            # gradient of conv1_1.bias
    torch.Tensor(64, 64, 3, 3),  # gradient of conv1_2.weight
    torch.Tensor(64),            # gradient of conv1_2.bias
    torch.Tensor(128, 64, 3, 3), # gradient of conv2_1.weight
    torch.Tensor(128),           # gradient of conv2_1.bias
    torch.Tensor(128, 128, 3, 3),# gradient of conv2_2.weight
    torch.Tensor(128),           # gradient of conv2_2.bias
    torch.Tensor(256, 128, 3, 3),# gradient of conv3_1.weight
    torch.Tensor(256),           # gradient of conv3_1.bias
    torch.Tensor(256, 256, 3, 3),# gradient of conv3_2.weight
    torch.Tensor(256),           # gradient of conv3_2.bias
    torch.Tensor(256, 256, 3, 3),# gradient of conv3_3.weight
    torch.Tensor(256),           # gradient of conv3_3.bias
    torch.Tensor(512, 256, 3, 3),# gradient of conv4_1.weight
    torch.Tensor(512),           # gradient of conv4_1.bias
    torch.Tensor(512, 512, 3, 3),# gradient of conv4_2.weight
    torch.Tensor(512),           # gradient of conv4_2.bias
    torch.Tensor(512, 512, 3, 3),# gradient of conv4_3.weight
    torch.Tensor(512),           # gradient of conv4_3.bias
    # (512, 2048), not (4096, 512): with the other shapes listed here, the
    # 8,952,138 main-model parameters reported in the parameter table leave
    # 512 * 2048 elements for fc1.weight.
    torch.Tensor(512, 2048),     # gradient of fc1.weight
    torch.Tensor(512),           # gradient of fc1.bias
    torch.Tensor(512, 512),      # gradient of fc2.weight
    torch.Tensor(512),           # gradient of fc2.bias
    torch.Tensor(10, 512),       # gradient of fc3.weight
    torch.Tensor(10),            # gradient of fc3.bias
]