In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SimpleNN(nn.Module):
    """Fully connected classifier for 28x28 single-channel images (MNIST-style digits).

    Architecture: 784 -> 128 -> 64 -> 10 with ReLU between the layers.
    The attribute names fc1/fc2/fc3 determine the state_dict keys, so they
    must stay stable for saved checkpoints (e.g. "mnist_simple_nn.pth") to load.
    """

    def __init__(self):
        super().__init__()
        # Input layer (28*28 = 784 pixels) to a hidden layer of 128 neurons
        self.fc1 = nn.Linear(28 * 28, 128)
        # Second hidden layer
        self.fc2 = nn.Linear(128, 64)
        # Output layer (10 classes for digits 0-9)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return per-class log-probabilities.

        Args:
            x: tensor whose element count is a multiple of 28*28, e.g.
               (N, 1, 28, 28), (N, 28, 28), (N, 784), or a single (28, 28)
               image, which is treated as a batch of one.

        Returns:
            Tensor of shape (N, 10) holding log_softmax scores along dim 1.
        """
        # reshape (not view): non-contiguous inputs such as transposed or
        # permuted images are copied instead of raising a RuntimeError.
        x = x.reshape(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # Log-probabilities pair with nn.NLLLoss for training. Do not feed them
        # to nn.CrossEntropyLoss, which applies log_softmax itself and expects raw logits.
        return F.log_softmax(x, dim=1)

In [None]:
# Displays the SimpleNN class object itself (not an instance) as the cell's output.
SimpleNN

In [None]:
# Build a fresh, untrained network and keep its parameters on the CPU,
# the same device the checkpoint is mapped to when loaded.
model = SimpleNN().to(device=torch.device("cpu"))

In [None]:
# Display the model instance; as the cell's last expression, its repr is rendered as output.
model

In [None]:
# Inspect the checkpoint without dumping every weight value: show each entry's
# name and shape. The file is used as a state_dict for SimpleNN, so the keys
# should be fc1/fc2/fc3 weights and biases.
# weights_only=True restricts unpickling to tensors and plain containers
# (see the torch.load docs), which avoids arbitrary code execution from the file.
{name: tuple(tensor.shape)
 for name, tensor in torch.load("mnist_simple_nn.pth",
                                map_location=torch.device("cpu"),
                                weights_only=True,
                                ).items()}

In [None]:
# Restore the trained weights into `model`. strict=True (the default, spelled
# out here) makes any missing or unexpected key raise instead of passing silently.
model.load_state_dict(
    state_dict=torch.load(
        "mnist_simple_nn.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    ),
    strict=True,
)

In [None]:
from matplotlib import pyplot as plt

# Smoke test: run one random-noise "image" through the loaded model and plot it.
# The input is Gaussian noise, not a digit, so the predicted class carries no
# meaning; this only checks that the forward pass and plotting work end to end.
# NOTE(review): if training normalized the inputs, real digits must get the same
# transform before inference; confirm against the training code.
noise_gen = torch.Generator().manual_seed(0)  # local generator: reproducible without touching the global RNG
model.eval()  # inference mode (no effect on this architecture, but the correct state for prediction)
with torch.no_grad():
    x = torch.randn(size=(28, 28), generator=noise_gen)
    # forward() flattens the single image into a batch of one, so y is (1, 10) log-probabilities
    y = model(x)
    pred = int(y.argmax(dim=1).item())
    print("log-probabilities:", y)
    print(f"max probability: {y.exp().max().item():.4f}")

fig, ax = plt.subplots()
ax.imshow(x.numpy(), cmap="gray")
ax.set_title(f"Predicted class: {pred}")
ax.set_axis_off()
plt.show()
