In [26]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import onnx

In [30]:
class Net(nn.Module):
    """Small CNN classifier for 28x28 single-channel images with 10 output classes.

    Layers: a 3x3 convolution (1 -> 10 channels), a 2x2 max-pool, and two fully
    connected layers (13*13*10 -> 50 -> 10). ``forward`` ends with a softmax, so
    the model returns probabilities rather than logits.
    """

    def __init__(self):
        super().__init__()
        # image size = 28x28
        self.conv1 = nn.Conv2d(1, 10, 3)
        # image size = 26x26 x 10
        self.pool = nn.MaxPool2d(2, 2)
        # image size = 13x13 x 10
        self.layer1 = nn.Linear(13 * 13 * 10, 50)
        self.layer2 = nn.Linear(50, 10)

    def forward(self, flat_input):
        """Classify one or more flattened 28x28 images.

        Args:
            flat_input: tensor holding 784 values per image, e.g. shape
                ``(784,)`` for a single image or ``(batch, 784)`` for a batch.

        Returns:
            Tensor of shape ``(batch, 10)``; each row is a softmax distribution
            over the 10 classes. A single flat image yields shape ``(1, 10)``.
        """
        # Rebuild the 2d images from the flat vectors. -1 lets the batch size be
        # inferred, so both a lone 784-vector and a (batch, 784) matrix work.
        images = flat_input.view(-1, 1, 28, 28)
        y = self.pool(self.conv1(images))
        y = t.flatten(y, 1)
        y = self.layer1(y)
        y = F.relu(y)
        y = self.layer2(y)
        y = F.softmax(y, dim=1)
        return y

    def fit_batch(self, train_load, test_loader, loss_func, optimizer):
        """Train for one pass over ``train_load``, then evaluate on ``test_loader``.

        Args:
            train_load: iterable yielding ``(batch_x, batch_y)`` training pairs.
            test_loader: iterable yielding ``(batch_x, batch_y)`` validation
                pairs; must support ``len()`` and be non-empty, otherwise the
                averaging step fails.
            loss_func: callable ``loss_func(predictions, targets)`` returning a
                scalar loss tensor.
            optimizer: optimizer whose ``zero_grad()`` and ``step()`` are called
                once per training batch.

        Returns:
            The mean of the per-batch validation losses (also printed).

        NOTE(review): ``forward`` already applies a softmax. If ``loss_func``
        expects raw logits (e.g. ``nn.CrossEntropyLoss``), the softmax is
        effectively applied twice -- confirm against the training setup.
        """
        self.train()
        for batch_x, batch_y in train_load:
            # Clear gradients before the backward pass so anything accumulated
            # before this call cannot leak into the first update.
            optimizer.zero_grad()

            # Call the module itself (not .forward) so registered hooks run.
            y_pred = self(batch_x)
            loss = loss_func(y_pred, batch_y)
            loss.backward()

            optimizer.step()

        total_loss = 0
        self.eval()
        with t.no_grad():
            for batch_x, batch_y in test_loader:
                y_pred = self(batch_x)
                total_loss += loss_func(y_pred, batch_y)
        total_loss /= len(test_loader)
        print(f"Validation loss: {total_loss}")
        return total_loss

In [31]:
# Instantiate the CNN classifier; its layers are created in Net.__init__.
model = Net()

In [32]:
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; load_state_dict then copies the values into `model`.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source, or pass weights_only=True on PyTorch versions that support it.
model.load_state_dict(t.load("./handwritten_model.pth", map_location="cpu"))

<All keys matched successfully>

In [33]:
# Random smoke-test input: 784 uniform samples drawn by NumPy, stored as a float32 tensor.
test_input = t.as_tensor(np.random.rand(28 * 28), dtype=t.float32)
test_input.shape

torch.Size([784])

In [34]:
# Sanity check: run the model on the random input and display the returned
# softmax probabilities (one row of 10 class scores).
model(test_input)

tensor([[4.9756e-04, 3.1206e-31, 4.6606e-09, 1.7086e-06, 5.8082e-05, 1.4601e-04,
         4.9442e-07, 8.0738e-11, 9.9930e-01, 4.8773e-09]],
       grad_fn=<SoftmaxBackward0>)

In [35]:
# Export the model to ONNX, using a random 784-element vector as the example input.
t.onnx.export(
    model,
    (t.randn(28 * 28),),
    "./handwritten_flatten.onnx",
)