In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from spn_3.spn import SPN
from utils import train

In [2]:
class MLP_MNIST(nn.Module):
    """Fully connected classifier for (flattened) MNIST-style images.

    With the default arguments this is exactly the original
    784 -> 512 -> 256 -> 128 -> 10 network: ReLU between layers and raw
    logits out (no softmax). The layer sizes are exposed as arguments so the
    class can be reused for other input sizes or depths. For the default
    configuration the ``nn.Sequential`` indices, and therefore the
    ``state_dict`` keys (``model.1.weight`` ... ``model.7.bias``), are
    unchanged.

    Args:
        in_features: number of input features after flattening.
        hidden_sizes: widths of the hidden layers, in order.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=784, hidden_sizes=(512, 256, 128), num_classes=10):
        super().__init__()
        layers = [nn.Flatten()]
        width = in_features
        for hidden in hidden_sizes:
            layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden
        # No activation on the final projection: the notebook trains this
        # model with nn.CrossEntropyLoss, which applies log-softmax itself.
        layers.append(nn.Linear(width, num_classes))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``nn.Flatten`` collapses every dim after the first, so ``x`` may be
        (batch, in_features) or an unflattened image batch such as
        (batch, 1, 28, 28).
        """
        return self.model(x)


In [3]:
# Hyperparameters and run configuration shared by every model below.
batch_size = 64
learning_rate = 0.001
epochs = 20
seed = 0  # controls weight init, DataLoader shuffling and the trace example input
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global generators (CPU and all CUDA devices) so a fresh
# Restart & Run All reproduces initialisation and shuffle order. Some GPU
# kernels may still be nondeterministic.
torch.manual_seed(seed)

In [4]:
# Load MNIST once and preprocess it into dense tensors resident on `device`,
# so the training loop never pays per-sample PIL/transform overhead.
DATA_ROOT = './data'
num_classes = 10

train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, download=True)


def _to_tensors(dataset):
    """Return ``(features, one_hot_labels)`` for a raw torchvision MNIST split.

    Pixels are rescaled from uint8 [0, 255] to float [0, 1] (no mean/std
    standardisation). Each image is flattened to a single row, labels are
    one-hot encoded as floats, and both tensors are moved to ``device``.
    """
    features = dataset.data.float().div(255).flatten(start_dim=1).to(device)
    labels = F.one_hot(dataset.targets.to(device), num_classes=num_classes).float()
    return features, labels


train_data, train_labels = _to_tensors(train_dataset)
test_data, test_labels = _to_tensors(test_dataset)

# Both models in this notebook are constructed with 784 input features.
assert train_data.shape[1] == test_data.shape[1] == 784

# Wrap the preloaded tensors. Keep the default num_workers=0: returning CUDA
# tensors from worker processes is not recommended.
train_loader = DataLoader(TensorDataset(train_data, train_labels), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(TensorDataset(test_data, test_labels), batch_size=batch_size, shuffle=False)

In [5]:
# Baseline: the plain MLP, moved to `device` and compiled via TorchScript scripting.
model_1 = torch.jit.script(MLP_MNIST().to(device))

# The labels are one-hot float tensors, which CrossEntropyLoss interprets as
# class-probability targets. This criterion object is reused for model_2.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_1.parameters(), lr=learning_rate)

In [6]:
# Train the baseline MLP. The arguments are positional, in utils.train's order.
# NOTE(review): the third argument is None, presumably an optional
# validation loader; confirm against utils.train.
train(
    model_1,
    train_loader,
    None,
    test_loader,
    epochs,
    optimizer,
    criterion,
)

Epoch: 1 Total_Time: 2.3711 Average_Time_per_batch: 0.0025 Train_Accuracy: 0.9204 Train_Loss: 0.2621 
Epoch: 2 Total_Time: 1.9994 Average_Time_per_batch: 0.0021 Train_Accuracy: 0.9705 Train_Loss: 0.0965 
Epoch: 3 Total_Time: 1.6010 Average_Time_per_batch: 0.0017 Train_Accuracy: 0.9795 Train_Loss: 0.0655 
Epoch: 4 Total_Time: 1.6191 Average_Time_per_batch: 0.0017 Train_Accuracy: 0.9847 Train_Loss: 0.0481 
Epoch: 5 Total_Time: 1.5913 Average_Time_per_batch: 0.0017 Train_Accuracy: 0.9873 Train_Loss: 0.0400 
Epoch: 6 Total_Time: 1.6047 Average_Time_per_batch: 0.0017 Train_Accuracy: 0.9905 Train_Loss: 0.0310 
Epoch: 7 Total_Time: 1.6079 Average_Time_per_batch: 0.0017 Train_Accuracy: 0.9912 Train_Loss: 0.0273 
Epoch: 8 Total_Time: 1.6212 Average_Time_per_batch: 0.0017 Train_Accuracy: 0.9930 Train_Loss: 0.0218 
Epoch: 9 Total_Time: 1.7408 Average_Time_per_batch: 0.0019 Train_Accuracy: 0.9926 Train_Loss: 0.0226 
Epoch: 10 Total_Time: 1.6006 Average_Time_per_batch: 0.0017 Train_Accuracy: 0.9950

In [None]:
# Second model: SPN from spn_3.spn, compiled with TorchScript tracing.
# Use the configured `device` (not a hard-coded .cuda()) so this cell also
# runs on CPU-only machines, matching model_1.
# NOTE(review): SPN(784, 246, 10) is presumably (inputs, hidden, outputs);
# confirm against spn_3.spn.
model_2 = SPN(784, 246, 10).to(device)

# NOTE(review): tracing records only the ops run for this example batch. If
# SPN has shape-dependent Python logic, the graph may be specialised to
# batch_size, yet the final MNIST batches are partial (32 train, 16 test).
example_batch = torch.randn(batch_size, 784, device=device)
model_2 = torch.jit.trace(model_2, example_batch)
optimizer = optim.Adam(model_2.parameters(), lr=learning_rate)

In [8]:
# Train the traced SPN with the same loaders, epochs and criterion as the
# baseline so the two logs are directly comparable.
# NOTE(review): the third argument is None, presumably an optional
# validation loader; confirm against utils.train.
train(
    model_2,
    train_loader,
    None,
    test_loader,
    epochs,
    optimizer,
    criterion,
)

Epoch: 1 Total_Time: 1.3331 Average_Time_per_batch: 0.0014 Train_Accuracy: 0.9109 Train_Loss: 0.3205 
Epoch: 2 Total_Time: 1.1924 Average_Time_per_batch: 0.0013 Train_Accuracy: 0.9606 Train_Loss: 0.1363 
Epoch: 3 Total_Time: 1.1987 Average_Time_per_batch: 0.0013 Train_Accuracy: 0.9731 Train_Loss: 0.0913 
Epoch: 4 Total_Time: 1.1946 Average_Time_per_batch: 0.0013 Train_Accuracy: 0.9793 Train_Loss: 0.0680 
Epoch: 5 Total_Time: 1.1941 Average_Time_per_batch: 0.0013 Train_Accuracy: 0.9845 Train_Loss: 0.0520 
Epoch: 6 Total_Time: 1.2135 Average_Time_per_batch: 0.0013 Train_Accuracy: 0.9881 Train_Loss: 0.0402 
Epoch: 7 Total_Time: 1.2073 Average_Time_per_batch: 0.0013 Train_Accuracy: 0.9908 Train_Loss: 0.0308 
Epoch: 8 Total_Time: 1.1791 Average_Time_per_batch: 0.0013 Train_Accuracy: 0.9930 Train_Loss: 0.0251 
Epoch: 9 Total_Time: 1.1625 Average_Time_per_batch: 0.0012 Train_Accuracy: 0.9939 Train_Loss: 0.0209 
Epoch: 10 Total_Time: 1.1771 Average_Time_per_batch: 0.0013 Train_Accuracy: 0.9958