<a href="https://colab.research.google.com/github/rafid-dev/MNISTImageClassifier/blob/main/ImageClassifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import nn, save, load

In [3]:
NUM_EPOCHS = 10
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
# Get MNIST Dataset
train = datasets.MNIST(root="data", download=True, train=True, transform=transforms.ToTensor(), )
dataset = DataLoader(train, 32)

In [7]:
# Image Classifier Neural Network
class ImageClassifier(nn.Module):
  def __init__(self):
    super().__init__()
    self.model = nn.Sequential(
        nn.Conv2d(1, 32, (3, 3)),
        nn.ReLU(),
        nn.Conv2d(32, 64, (3, 3)),
        nn.ReLU(),
        nn.Conv2d(64, 64, (3, 3)),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(64*(28-6)*(28-6), 10)
    )

  def forward(self, x):
    return self.model(x)

In [8]:
# Net object
net = ImageClassifier().to(DEVICE)

# Optimizer object
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_function = nn.CrossEntropyLoss()

In [15]:
# Training

for epoch in range(NUM_EPOCHS):
  for batch in dataset:
    x, y = batch
    x, y = x.to(DEVICE), y.to(DEVICE)

    # Forward
    yhat = net(x)
    loss = loss_function(yhat, y)

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print(f"Epoch: {epoch} Loss: {loss.item()}")

  with open('model_state.pt', 'wb') as f:
    save(net.state_dict(), f)

Epoch: 0 Loss: 0.009057772345840931
Epoch: 1 Loss: 0.003623046213760972
Epoch: 2 Loss: 0.0006001919973641634
Epoch: 3 Loss: 0.006246666423976421
Epoch: 4 Loss: 8.906950824894011e-05
Epoch: 5 Loss: 0.017229881137609482
Epoch: 6 Loss: 7.541037484770641e-05
Epoch: 7 Loss: 1.3746202967013232e-06
Epoch: 8 Loss: 1.1511040156619856e-06
Epoch: 9 Loss: 3.42724916890802e-07


In [18]:
from PIL import Image

with open('model_state.pt', 'rb') as f:
    net.load_state_dict(load(f))

img = Image.open('img_1.jpg')
img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(DEVICE)
print(torch.argmax(net(img_tensor)))

tensor(2)
