# In [None]:


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Define the shallow model
class ShallowNetwork(nn.Module):
    """One-hidden-layer MLP classifier: Linear -> ReLU -> Linear.

    The defaults reproduce the original MNIST configuration
    (784 -> 128 -> 10). Layers are registered in the same order and under
    the same names as before, so existing ``state_dict`` checkpoints load
    unchanged and default initialization consumes the RNG identically.

    Args:
        in_features: Number of input features per sample. Inputs are
            reshaped to ``(-1, in_features)`` before the first layer.
        hidden_features: Width of the single hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=784, hidden_features=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) logits of shape ``(N, num_classes)``.

        ``x`` is flattened with ``view(-1, in_features)``; for example, a
        batch of shape ``(N, 1, 28, 28)`` becomes ``(N, 784)`` with the
        default sizes. Softmax is left to the loss function.
        """
        # Read the input width from the layer itself so the reshape can
        # never drift out of sync with the constructor arguments.
        x = x.view(-1, self.fc1.in_features)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Define the deep model
class DeepNetwork(nn.Module):
    """Three-hidden-layer MLP classifier with ReLU activations.

    Architecture: ``in_features -> h1 -> h2 -> h3 -> num_classes``, with a
    ReLU after each hidden layer and raw logits at the output. The defaults
    reproduce the original MNIST configuration (784 -> 256 -> 128 -> 64 -> 10).
    Layer names (``fc1``..``fc4``) and registration order are unchanged, so
    existing ``state_dict`` checkpoints still load.

    Args:
        in_features: Number of input features per sample. Inputs are
            reshaped to ``(-1, in_features)`` before the first layer.
        hidden_features: Exactly three hidden-layer widths ``(h1, h2, h3)``.
            Any other length raises ``ValueError`` when it is unpacked.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=784, hidden_features=(256, 128, 64), num_classes=10):
        super().__init__()
        h1, h2, h3 = hidden_features
        self.fc1 = nn.Linear(in_features, h1)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        self.fc4 = nn.Linear(h3, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) logits of shape ``(N, num_classes)``.

        ``x`` is flattened with ``view(-1, in_features)`` before the first
        layer. Softmax is left to the loss function.
        """
        # Derive the flatten width from fc1 rather than repeating a literal.
        x = x.view(-1, self.fc1.in_features)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.fc4(x)

# MNIST train/test splits, fetched into root if not already present
# (download=True). Only ToTensor is applied; no normalization is performed.
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(
    root='~/.pytorch/MNIST_data/',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(
    root='~/.pytorch/MNIST_data/',
    train=False,
    download=True,
    transform=transform,
)

# Batches of 64; only the training split is reshuffled each epoch, the
# test split is always read in the same order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Define the training function
def train(model, device, loader, optimizer, criterion):
    """Run one training epoch over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Classifier producing logits of shape ``(N, num_classes)``.
        device: Device each batch is moved to before the forward pass.
        loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function. Presumably uses ``reduction='mean'`` (the
            default for ``nn.CrossEntropyLoss``); the epoch loss re-weights
            each batch mean by its batch size, so a ``'sum'`` criterion would
            be over-counted.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample over the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # Weight by batch size so a short final batch does not skew the
        # epoch mean (a plain mean of batch means would over-weight it).
        loss_sum += loss.item() * batch_size
        correct += (output.argmax(dim=1) == target).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / total, correct / total

# Define the testing function
def test(model, device, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean_loss, accuracy)``.

    Args:
        model: Classifier producing logits of shape ``(N, num_classes)``.
        device: Device each batch is moved to before the forward pass.
        loader: Iterable of ``(data, target)`` batches.
        criterion: Loss function. Presumably uses ``reduction='mean'`` (the
            default for ``nn.CrossEntropyLoss``); each batch mean is
            re-weighted by its batch size.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample over ``loader``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_size = target.size(0)
            # Per-sample weighting keeps a short last batch from skewing
            # the mean, matching how accuracy is computed.
            loss_sum += criterion(output, target).item() * batch_size
            correct += (output.argmax(dim=1) == target).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / total, correct / total


# `test` matches pytest's default `test*` discovery pattern. Opt out so that
# importing this helper into a test module does not make pytest try to run it
# (its parameters would be mistaken for fixtures).
test.__test__ = False

# Train the shallow and deep models side by side for 10 epochs, recording
# per-epoch train/test accuracy, then plot the four accuracy curves.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
shallow_model = ShallowNetwork().to(device)
deep_model = DeepNetwork().to(device)

criterion = nn.CrossEntropyLoss()
shallow_optimizer = optim.SGD(shallow_model.parameters(), lr=0.01)
deep_optimizer = optim.SGD(deep_model.parameters(), lr=0.01)

shallow_train_accuracies = []
shallow_test_accuracies = []
deep_train_accuracies = []
deep_test_accuracies = []

for epoch in range(10):
    shallow_train_loss, shallow_train_acc = train(shallow_model, device, train_loader, shallow_optimizer, criterion)
    shallow_test_loss, shallow_test_acc = test(shallow_model, device, test_loader, criterion)
    deep_train_loss, deep_train_acc = train(deep_model, device, train_loader, deep_optimizer, criterion)
    deep_test_loss, deep_test_acc = test(deep_model, device, test_loader, criterion)
    shallow_train_accuracies.append(shallow_train_acc)
    shallow_test_accuracies.append(shallow_test_acc)
    deep_train_accuracies.append(deep_train_acc)
    deep_test_accuracies.append(deep_test_acc)
    print(f"Epoch {epoch+1}, Shallow Train Acc: {shallow_train_acc:.4f}, Shallow Test Acc: {shallow_test_acc:.4f}, Deep Train Acc: {deep_train_acc:.4f}, Deep Test Acc: {deep_test_acc:.4f}")

# Plot per-epoch accuracy for both models. Solid lines are the shallow model
# and dashed lines the deep model; epochs are numbered from 1 to match the
# printed log.
epoch_axis = range(1, len(shallow_train_accuracies) + 1)
plt.figure(figsize=(8, 5))
plt.plot(epoch_axis, shallow_train_accuracies, label="Shallow train")
plt.plot(epoch_axis, shallow_test_accuracies, label="Shallow test")
plt.plot(epoch_axis, deep_train_accuracies, linestyle="--", label="Deep train")
plt.plot(epoch_axis, deep_test_accuracies, linestyle="--", label="Deep test")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Shallow vs. deep MLP on MNIST")
plt.legend()
plt.grid(True)
plt.show()
