In [41]:
import torch
import torch.nn as nn
import torch.fx as fx
import torch.optim as optim
import torch.nn.functional as F
from convolution import CustomConv2D

In [42]:
class MNISTModel(nn.Module):
    def __init__(self, is_training=True):
        super(MNISTModel, self).__init__()

        if is_training:
            self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        else:
            self.conv1 = CustomConv2D(1, 32, kernel_size=3, stride=1, padding=1)
            self.conv2 = CustomConv2D(32, 64, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [45]:
model = MNISTModel(is_training=True)

example_input = torch.randn(1, 1, 28, 28)
traced_model = torch.jit.trace(model, example_input)
traced_model.save("traced_mnist_model.pt")

In [46]:
loaded_traced_model = torch.jit.load("traced_mnist_model.pt")

output = loaded_traced_model(example_input)
print("Traced Model Output:", output)

Traced Model Output: tensor([[ 0.2431, -0.0546, -0.1588, -0.0545,  0.0427, -0.1299, -0.0787, -0.0068,
         -0.0827, -0.2091]], grad_fn=<AddmmBackward0>)


In [47]:
model = MNISTModel()

scripted_model = torch.jit.script(model)

scripted_model.save("scripted_mnist_model.pt")

In [48]:
loaded_scripted_model = torch.jit.load("scripted_mnist_model.pt")

output = loaded_scripted_model(example_input)
print("Scripted Model Output:", output)

Scripted Model Output: tensor([[ 0.1761,  0.0027,  0.0665, -0.0906, -0.0172,  0.0932, -0.0141,  0.1020,
         -0.2283, -0.1848]], grad_fn=<AddmmBackward0>)
