### Run in GPU

In [1]:
# Import dependencies
import torch
from PIL import Image
from torch import nn, save, load
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor


In [2]:
# Get MNIST data 0-9
train = datasets.MNIST(root="data", download=True, train=True, transform=ToTensor()) #Transforming Image to Tensor
dataset = DataLoader(train, 32) #Batch size=32
#1,28,28 - classes 0-9

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.00MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]
100%|██████████| 1.65M/1.65M [00:06<00:00, 243kB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.90MB/s]


In [3]:
# Image Classifier Neural Network
class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, (3,3)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (3,3)),
            nn.ReLU(),
            nn.Conv2d(64, 64, (3,3)),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*(28-6)*(28-6), 10)
        )

    def forward(self, x):
        return self.model(x)

In [4]:
# Instance of the neural network, loss, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clf = ImageClassifier().to(device)
opt = Adam(clf.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

#### Link for jpeg - https://github.com/nicknochnack/PyTorchin15/tree/main

In [9]:
# Training flow
if __name__ == "__main__":
    for epoch in range(1): # train for 10 epochs
        for batch in dataset:
            X,y = batch
            X, y = X.to(device), y.to(device)
            yhat = clf(X)
            loss = loss_fn(yhat, y)

            # Apply backprop
            opt.zero_grad()
            loss.backward()
            opt.step()

        print(f"Epoch:{epoch} loss is {loss.item()}")

    with open('model_state.pt', 'wb') as f:
        save(clf.state_dict(), f)

    with open('model_state.pt', 'rb') as f:
        clf.load_state_dict(load(f))

    img = Image.open('img_3.jpg')
    img_tensor = ToTensor()(img).unsqueeze(0).to(device)

    print(torch.argmax(clf(img_tensor)))

Epoch:0 loss is 0.0
tensor(9, device='cuda:0')
