In [1]:
import torch
import torchvision
from torchvision import transforms, datasets


In [2]:
train = datasets.MNIST("", train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor()]))

test = datasets.MNIST("", train=False, download=True,
                      transform=transforms.Compose([transforms.ToTensor()]))

Create Dataloader

In [3]:
train_data = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)

test_data = torch.utils.data.DataLoader(test, batch_size=10, shuffle=False)

Build the neural network

In [4]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(784, 64)
    self.fc2 = nn.Linear(64, 64)
    self.fc3 = nn.Linear(64, 64)
    self.fc4 = nn.Linear(64, 10)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))
    x = F.relu(self.fc4(x))

    return F.log_softmax(x,dim=1)

net = Net()
print(net)

Net(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=64, bias=True)
  (fc4): Linear(in_features=64, out_features=10, bias=True)
)


In [5]:
import torch.optim as optim
optimiser = optim.Adam(net.parameters(), lr=0.0001)
EPOCHS = 6

In [6]:
for epoch in range(EPOCHS):
  for data in train_data:
    #print(data)
    X, y = data
    net.zero_grad()
    output = net(X.view(-1, 28*28))
    loss = F.nll_loss(output, y)
    loss.backward()
    optimiser.step()
  print(loss)

tensor(0.5391, grad_fn=<NllLossBackward>)
tensor(1.7884, grad_fn=<NllLossBackward>)
tensor(0.4921, grad_fn=<NllLossBackward>)
tensor(0.2955, grad_fn=<NllLossBackward>)
tensor(0.4953, grad_fn=<NllLossBackward>)
tensor(0.0675, grad_fn=<NllLossBackward>)


Testing the network

In [7]:
correct = 0
total = 0

with torch.no_grad():
  for t_data in test_data:
    X_test, y_test = t_data
    # invoke the model with the test data X

    output_test = net(X_test.view(-1, 784))
    #print(output)

    for index, i in enumerate(output_test):
      if torch.argmax(i) == y_test[index]:
        correct +=1
      total +=1

In [8]:
print(correct)
print(total)
accuracy = (correct/total)*100
print(f"Accuracy: {accuracy}")

7542
10000
Accuracy: 75.42
