In [6]:
import torch
import torch.nn.functional as F

In [7]:
output = F.log_softmax(torch.tensor([[0.2, 0.1, 0.3], 
                                    [0.2, 0.4, 0.3],
                                     [0.8, 0.4, 0.3]]), dim=1)

In [8]:
output

tensor([[-1.1019, -1.2019, -1.0019],
        [-1.2019, -1.0019, -1.1019],
        [-0.8228, -1.2228, -1.3228]])

In [9]:
for i in output:
  print(torch.argmax(i))

tensor(2)
tensor(1)
tensor(0)


In [10]:
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [11]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [12]:
transform = transforms.ToTensor()

In [None]:
mnist_train = datasets.MNIST(root='/content/sample_data',
                               train=True, download=True,
                               transform=transform)

In [14]:
mnist_test = datasets.MNIST(root='/content/sample_data',
                               train=False, download=True,
                               transform=transform)

In [15]:
train_loader = DataLoader(mnist_train, batch_size=10, shuffle=True)

In [16]:
test_loader = DataLoader(mnist_test, batch_size=1000, shuffle=True)

In [17]:
class DigitClassifier(nn.Module):
  def __init__(self):
    super().__init__()

    self.fc1 = nn.Linear(784, 100)
    self.fc2 = nn.Linear(100, 100)
    self.fc3 = nn.Linear(100, 50)
    self.fc4 = nn.Linear(50, 10)

  def forward(self, input_data):
    x = F.relu(self.fc1(input_data))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))
    x= self.fc4(x)
    return F.log_softmax(x, dim=1)

In [18]:
digit_classifier = DigitClassifier()

In [22]:
for i, j in train_loader:

  output = digit_classifier(i.view(-1, 28*28))
  for idx, i in enumerate(output):
    print(torch.argmax(i))
  print(j)
  break

tensor(0)
tensor(3)
tensor(0)
tensor(6)
tensor(1)
tensor(8)
tensor(9)
tensor(1)
tensor(2)
tensor(3)
tensor([0, 3, 0, 6, 1, 1, 9, 1, 2, 3])


In [20]:
optimiser = torch.optim.Adam(digit_classifier.parameters(), lr=0.001)

In [21]:
for epoch in range(3):
  for i, j in train_loader:
    digit_classifier.zero_grad()

    output = digit_classifier(i.view(-1, 28*28))
    loss = F.nll_loss(output, j)
    loss.backward()
    optimiser.step()
  print(loss)

tensor(0.0551, grad_fn=<NllLossBackward0>)
tensor(0.2435, grad_fn=<NllLossBackward0>)
tensor(0.4587, grad_fn=<NllLossBackward0>)


In [23]:
correct = 0
total = 0
with torch.no_grad():
  for i_x, i_y in test_loader:
    output = digit_classifier(i_x.view(-1,28*28))
    for idx, i in enumerate(output):
      if torch.argmax(i) == i_y[idx]:
        correct += 1
      total += 1

print(correct/total)
  

0.9682
