In [1]:
import numpy as np
import torch

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from torch.utils.data import DataLoader

In [2]:
# Seed NumPy's and PyTorch's global random number generators with the same
# value so that repeated runs of the notebook draw the same random numbers.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f57b5f7a9b0>

In [3]:
datapath = 'data'

# Load the MNIST train and test splits; ToTensor() is applied to every image
# when a sample is read.
# download=True fetches the files into `datapath` when they are not there yet
# (files that already exist are not downloaded again), so this cell also works
# on a fresh checkout instead of failing with "Dataset not found".
data_train = MNIST(
    root=datapath,
    train=True,
    transform=ToTensor(),
    download=True,
)
data_test = MNIST(
    root=datapath,
    train=False,
    transform=ToTensor(),
    download=True,
)

In [4]:
import torch.nn as nn

class SimpleMLP(nn.Module):
    """Fully connected classifier with two hidden layers (128 and 32 units).

    Each sample is flattened to ``prod(input_shape)`` features and passed
    through Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        input_shape: shape of a single input data point, e.g. ``(1, 28, 28)``.
        n_classes: number of scores (logits) produced per sample.
    """

    def __init__(self, input_shape, n_classes):
        super().__init__()
        self.input_shape = np.asarray(input_shape)
        self.n_classes = n_classes
        # Flattened size of one sample, computed once and stored as a plain
        # Python int so the layer sizes are not NumPy scalars.
        self.n_features = int(self.input_shape.prod())
        self.seq_model = nn.Sequential(
            nn.Linear(self.n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, self.n_classes),
        )

    def forward(self, x):
        """Return the logits for ``x`` with shape ``(batch_size, n_classes)``.

        ``x`` may have any shape whose element count is a multiple of the
        flattened sample size.
        """
        # Flatten to (batch_size, height*width*channels). reshape() is used
        # instead of view() because view() raises on non-contiguous tensors
        # (e.g. a transposed batch); reshape() copies only when it has to.
        x = x.reshape(-1, self.n_features)
        logits = self.seq_model(x)
        return logits

In [5]:
# Build the classifier for inputs of shape (1, 28, 28) with ten output classes.
model = SimpleMLP(
    n_classes=10,
    input_shape=(1, 28, 28),
)

In [6]:
# Restore previously saved weights into the model.
# map_location="cpu" lets a checkpoint that was saved from GPU tensors be
# loaded on a CPU-only machine; the model is not moved to another device in
# the cells shown here.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Newer PyTorch releases accept weights_only=True to restrict
# unpickling; confirm the installed version before enabling it.
loaded_state_dict = torch.load("saved_models/MLP.pt", map_location="cpu")
model.load_state_dict(loaded_state_dict)

<All keys matched successfully>

In [7]:
# Both loaders serve mini-batches of 32 samples. Only the training split is
# shuffled; the test split is read in dataset order.
batch_size = 32
train_loader = DataLoader(data_train, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(data_test, shuffle=False, batch_size=batch_size)

In [10]:
def model_accuracy(data_loader):
    """Return the fraction of samples the global ``model`` classifies correctly.

    Args:
        data_loader: iterable yielding ``(x_batch, y_batch)`` pairs, where
            ``model(x_batch)`` gives one row of class scores per sample and
            ``y_batch`` holds the matching integer class labels.

    Returns:
        Accuracy in ``[0, 1]`` as a Python float.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no samples.
    """
    n_total = 0
    n_correct = 0

    # Evaluate in eval mode, then restore whatever mode the model was in so a
    # following training loop is not affected.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for evaluation; disabling autograd avoids
        # building a graph for every batch.
        with torch.no_grad():
            for x_batch, y_batch in data_loader:
                logits_batch = model(x_batch)  # model's output scores
                predictions = logits_batch.argmax(dim=-1)
                n_total += len(y_batch)
                # Tensor-level sum instead of the Python builtin, which would
                # iterate over the batch element by element.
                n_correct += (predictions == y_batch).sum().item()
    finally:
        model.train(was_training)

    return n_correct / n_total

# Evaluate the model on the training split first, then on the test split,
# and report both accuracies with four decimals.
train_accuracy = model_accuracy(train_loader)
test_accuracy = model_accuracy(test_loader)
print(f"Train accuracy before training: {train_accuracy:.4f}")
print(f"Test accuracy before training: {test_accuracy:.4f}")

Train accuracy before training: 0.993
Test accuracy before training: 0.977
