In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from dpn_3.dpn import DPN
from utils import train

In [2]:
class MLP_MNIST(nn.Module):
    """Small fully-connected classifier for flattened MNIST images.

    Architecture: ``Linear(in_features, hidden) -> ReLU -> Linear(hidden, hidden)
    -> ReLU -> Linear(hidden, num_classes)``. The defaults (784 -> 12 -> 12 -> 10)
    reproduce the original fixed network, so ``MLP_MNIST()`` builds exactly the
    same model; the arguments only allow reuse with other sizes.

    Args:
        in_features: Length of each flattened input vector (28 * 28 = 784 for MNIST).
        hidden: Width of both hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 784, hidden: int = 12, num_classes: int = 10):
        super().__init__()
        # Attribute name `model` is kept so state_dict keys ("model.0.weight", ...) are unchanged.
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits (no softmax) for a batch ``x`` of shape (batch, in_features)."""
        return self.model(x)


In [3]:
# Hyperparameters and run configuration.
batch_size = 64        # samples per mini-batch
learning_rate = 0.001  # Adam step size
epochs = 30            # full passes over the training set
seed = 0               # seed for torch's global RNG

# Seed torch's global RNG before any model is built or any loader is shuffled, so
# torch-based weight init and DataLoader shuffling repeat on Restart & Run All.
# (GPU kernels may still add small nondeterminism — see the PyTorch reproducibility notes.)
torch.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Load the full MNIST train/test splits once and preprocess them into dense tensors
# on `device`, so each epoch slices batches from memory.
train_dataset = datasets.MNIST(root='./data', train=True, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True)

num_classes = 10


def _mnist_to_tensors(dataset):
    """Turn a torchvision MNIST dataset into ``(inputs, targets)`` tensors on ``device``.

    inputs:  pixel values divided by 255 (scaled to [0, 1]) and flattened to (N, 784).
    targets: class indices one-hot encoded as floats, shape (N, num_classes).
    """
    inputs = dataset.data.float().div(255).view(-1, 784).to(device)
    # Kept one-hot (float): nn.CrossEntropyLoss accepts class-probability targets, and
    # `utils.train` may rely on this format — TODO confirm before switching to int labels.
    targets = F.one_hot(dataset.targets.to(device), num_classes=num_classes).float()
    return inputs, targets


train_data, train_labels = _mnist_to_tensors(train_dataset)
test_data, test_labels = _mnist_to_tensors(test_dataset)

# Loaders over the preloaded tensors; only the training set is shuffled.
# Keep the default num_workers=0: the tensors may already live on the GPU, and CUDA
# tensors should not be handed to DataLoader worker processes.
train_loader = DataLoader(TensorDataset(train_data, train_labels), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(TensorDataset(test_data, test_labels), batch_size=batch_size, shuffle=False)

In [5]:
# Loss function; it is also reused later when training model_2.
criterion = nn.CrossEntropyLoss()
# Baseline MLP on the same device as the preloaded data, with Adam over its parameters.
model_1 = MLP_MNIST().to(device)
optimizer = optim.Adam(params=model_1.parameters(), lr=learning_rate)

In [6]:
# NOTE(review): `test_loader` fills both the (presumed) validation slot and the test slot,
# so the "Validation_*" numbers logged below are test-set numbers. If validation drives
# model selection, carve a validation split out of the training data instead.
train_metrics_1, val_metrics_1, test_metrics_1 = train(
    model_1,
    train_loader,
    test_loader,  # presumably the validation loader — same data as the test set
    test_loader,  # test loader
    epochs,
    optimizer,
    criterion,
)


Epoch: 1 Total_Time: 1.9327 Average_Time_per_batch: 0.0021 Train_Accuracy: 0.7867 Train_Loss: 0.7099 Validation_Accuracy: 0.8940 Validation_Loss: 0.3649
Epoch: 2 Total_Time: 1.7312 Average_Time_per_batch: 0.0018 Train_Accuracy: 0.9080 Train_Loss: 0.3248 Validation_Accuracy: 0.9187 Validation_Loss: 0.2854
Epoch: 3 Total_Time: 1.7311 Average_Time_per_batch: 0.0018 Train_Accuracy: 0.9244 Train_Loss: 0.2668 Validation_Accuracy: 0.9245 Validation_Loss: 0.2595
Epoch: 4 Total_Time: 1.6912 Average_Time_per_batch: 0.0018 Train_Accuracy: 0.9331 Train_Loss: 0.2359 Validation_Accuracy: 0.9302 Validation_Loss: 0.2366
Epoch: 5 Total_Time: 1.7212 Average_Time_per_batch: 0.0018 Train_Accuracy: 0.9385 Train_Loss: 0.2155 Validation_Accuracy: 0.9362 Validation_Loss: 0.2139
Epoch: 6 Total_Time: 1.6885 Average_Time_per_batch: 0.0018 Train_Accuracy: 0.9424 Train_Loss: 0.2030 Validation_Accuracy: 0.9402 Validation_Loss: 0.1996
Epoch: 7 Total_Time: 1.7280 Average_Time_per_batch: 0.0018 Train_Accuracy: 0.9450

In [7]:
# Project-local DPN model; args presumably (in_features=784, hidden=34, num_classes=10) — TODO confirm.
# Moved to `device` (not `.cuda()`) so it matches the data and model_1 and still runs on CPU-only machines.
model_2 = DPN(784, 34, 10).to(device)
# Rebinds `optimizer` to model_2's parameters; `criterion` from the model_1 setup cell is reused.
optimizer = optim.Adam(model_2.parameters(), lr=learning_rate)

In [8]:
# NOTE(review): `test_loader` is passed for both the (presumed) validation and the test
# argument, so the "Validation_*" numbers logged below are really test-set numbers.
train_metrics_2, val_metrics_2, test_metrics_2 = train(
    model_2,
    train_loader,
    test_loader,  # presumably the validation loader — same data as the test set
    test_loader,  # test loader
    epochs,
    optimizer,
    criterion,
)


Epoch: 1 Total_Time: 1.3294 Average_Time_per_batch: 0.0014 Train_Accuracy: 0.8855 Train_Loss: 0.4343 Validation_Accuracy: 0.9213 Validation_Loss: 0.2709
Epoch: 2 Total_Time: 1.3651 Average_Time_per_batch: 0.0015 Train_Accuracy: 0.9310 Train_Loss: 0.2438 Validation_Accuracy: 0.9397 Validation_Loss: 0.2095
Epoch: 3 Total_Time: 1.3071 Average_Time_per_batch: 0.0014 Train_Accuracy: 0.9446 Train_Loss: 0.1950 Validation_Accuracy: 0.9485 Validation_Loss: 0.1776
Epoch: 4 Total_Time: 0.9413 Average_Time_per_batch: 0.0010 Train_Accuracy: 0.9535 Train_Loss: 0.1630 Validation_Accuracy: 0.9548 Validation_Loss: 0.1571
Epoch: 5 Total_Time: 1.3297 Average_Time_per_batch: 0.0014 Train_Accuracy: 0.9592 Train_Loss: 0.1410 Validation_Accuracy: 0.9572 Validation_Loss: 0.1461
Epoch: 6 Total_Time: 1.3497 Average_Time_per_batch: 0.0014 Train_Accuracy: 0.9637 Train_Loss: 0.1248 Validation_Accuracy: 0.9592 Validation_Loss: 0.1376
Epoch: 7 Total_Time: 1.3295 Average_Time_per_batch: 0.0014 Train_Accuracy: 0.9674