<a href="https://colab.research.google.com/github/rishikesh953/pytorch/blob/main/Digit_Recognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
# import statements

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


In [29]:
# device configure
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [30]:
# hyper parameters

# Model dimensions
input_size = 28 * 28    # length of one flattened image (the loops reshape batches to 28 * 28 columns)
hidden_size = 100       # width of the hidden linear layer
num_classes = 10        # number of output scores produced per image

# Training schedule
num_epochs = 4          # full passes over the training loader
batch_size = 100        # images per batch for both data loaders
learning_rate = 1e-2    # step size handed to the Adam optimizer

In [31]:
# MNIST data

# Training split (downloaded into ./data when missing) and test split; every image
# goes through the ToTensor transform when it is read.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

# Only the training batches are shuffled; the test loader keeps the dataset order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Pull a single batch for inspection. Use the built-in next(): DataLoader iterators
# implement __next__ only, so calling `.next()` on them raises AttributeError on
# current PyTorch releases.
examples = iter(train_loader)
samples, labels = next(examples)

In [32]:
# Optional sanity check: uncomment to print the batch shapes and preview six training digits.
# print(samples.shape, labels.shape)
#
# for i in range(6):
#     plt.subplot(2, 3, i+1)
#     plt.imshow(samples[i][0],cmap='gray')
# plt.show()

In [33]:
# custom model

class NeuralNet(nn.Module):
    """Fully connected network with one hidden layer: Linear -> ReLU -> Linear.

    The forward pass returns the output of the last linear layer without any
    further activation.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        """Create the layers.

        Args:
            input_size: number of features in each input row.
            hidden_size: number of units in the hidden layer.
            num_classes: number of output scores per input row.
        """
        super().__init__()
        # Attribute names are kept as-is because they become the state_dict keys.
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Apply l1, then ReLU, then l2 to `x` and return the result."""
        hidden = self.relu(self.l1(x))
        return self.l2(hidden)


# Place the parameters on the same device the training and test loops send their
# batches to; leaving the model on the CPU makes a CUDA run fail with a device mismatch.
model = NeuralNet(input_size, hidden_size, num_classes).to(device)

In [34]:
# loss and optimizer

# Cross-entropy loss, applied in the training loop to the model outputs and the labels.
criterion = nn.CrossEntropyLoss()

# Adam updates every parameter of the model with the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [35]:
# training loop

n_total_steps = len(train_loader)  # number of batches in one epoch

for epoch in range(num_epochs):
    epoch_num = epoch + 1

    # Count steps from 1 so the progress message needs no offset.
    for step, (batch_images, batch_labels) in enumerate(train_loader, start=1):
        # Flatten every image into one row of 28 * 28 values and move the batch to the device.
        batch_images = batch_images.reshape(-1, 28 * 28).to(device)
        batch_labels = batch_labels.to(device)

        # forward pass
        logits = model(batch_images)
        loss = criterion(logits, batch_labels)

        # backward pass: clear old gradients, backpropagate, then update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the loss of the current batch every 100 steps.
        if step % 100 == 0:
            print(f'epoch:{epoch_num}/{num_epochs}, step {step}/{n_total_steps}, loss = {loss.item():.5f}')


epoch:1/4, step 100/600, loss = 0.22877
epoch:1/4, step 200/600, loss = 0.09792
epoch:1/4, step 300/600, loss = 0.22030
epoch:1/4, step 400/600, loss = 0.09350
epoch:1/4, step 500/600, loss = 0.15717
epoch:1/4, step 600/600, loss = 0.09491
epoch:2/4, step 100/600, loss = 0.22662
epoch:2/4, step 200/600, loss = 0.24435
epoch:2/4, step 300/600, loss = 0.06578
epoch:2/4, step 400/600, loss = 0.06415
epoch:2/4, step 500/600, loss = 0.17718
epoch:2/4, step 600/600, loss = 0.19168
epoch:3/4, step 100/600, loss = 0.06649
epoch:3/4, step 200/600, loss = 0.03233
epoch:3/4, step 300/600, loss = 0.12802
epoch:3/4, step 400/600, loss = 0.35394
epoch:3/4, step 500/600, loss = 0.07361
epoch:3/4, step 600/600, loss = 0.08021
epoch:4/4, step 100/600, loss = 0.21724
epoch:4/4, step 200/600, loss = 0.06487
epoch:4/4, step 300/600, loss = 0.13474
epoch:4/4, step 400/600, loss = 0.03515
epoch:4/4, step 500/600, loss = 0.08369
epoch:4/4, step 600/600, loss = 0.02503


In [36]:
# test the model

n_correct = 0
n_samples = 0

# No gradients are needed while scoring the test batches.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.reshape(-1, 28 * 28).to(device)
        labels = labels.to(device)

        outputs = model(images)

        # torch.max over dim 1 yields (values, indices); the index of the largest
        # output in each row is taken as the predicted class.
        predictions = torch.max(outputs, dim=1).indices
        n_samples += labels.size(0)
        n_correct += (predictions == labels).sum().item()

# Percentage of test samples whose prediction matches the label.
acc = 100 * n_correct / n_samples

print(f'accuracy = {acc}')

accuracy = 96.5


In [37]:
#  saving the model

# Write only the model's state_dict (not the module object itself) to FILE.
FILE = 'digit_recognition.pth'
torch.save(obj=model.state_dict(), f=FILE)