In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [2]:
train_dataset = datasets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

In [3]:
test_dataset = datasets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

In [4]:
#making the dataset iterable

batch_size = 100
n_iters = 3000
nb_epochs = n_iters / (len(train_dataset) / batch_size)
nb_epochs = int(nb_epochs)

train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

In [5]:
class FFNModel(nn.Module):
  def __init__(self, in_dim, h_dim, out_dim):
    super(FFNModel, self).__init__()
    self.fc1 = nn.Linear(in_dim, h_dim)
    self.relu1 = nn.ReLU()

    self.fc2 = nn.Linear(h_dim, h_dim)
    self.relu2 = nn.ReLU()

    self.fc3 = nn.Linear(h_dim, h_dim)
    self.relu3 = nn.ReLU()

    self.fc4 = nn.Linear(h_dim, out_dim)


  def forward(self, x):

      # Linear function 1
      out = self.fc1(x)
      # Non-linearity 1
      out = self.relu1(out)
      
      # Linear function 2
      out = self.fc2(out)
      # Non-linearity 2
      out = self.relu2(out)
      
      # Linear function 2
      out = self.fc3(out)
      # Non-linearity 2
      out = self.relu3(out)
      
      # Linear function 4 (readout)
      out = self.fc4(out)
      return out



In [6]:
input_dim = 28 * 28
hidden_dim = 100
output_dim = 10

model = FFNModel(input_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)


In [7]:
iter = 0

for epoch in range(nb_epochs):
  for i, (images, labels) in enumerate(train_loader):
    images = images.view(-1, 28 * 28).requires_grad_()

    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    iter += 1

    if iter % 500 == 0:
      correct = 0
      total = 0
      for images, labels in test_loader:
        images = images.view(-1, 28*28).requires_grad_()
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()

      accuracy = 100 * correct / total

      print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy)) 



Iteration: 500. Loss: 0.19334676861763. Accuracy: 91
Iteration: 1000. Loss: 0.14088737964630127. Accuracy: 94
Iteration: 1500. Loss: 0.112630195915699. Accuracy: 94
Iteration: 2000. Loss: 0.10298855602741241. Accuracy: 96
Iteration: 2500. Loss: 0.08683620393276215. Accuracy: 96
Iteration: 3000. Loss: 0.04834673926234245. Accuracy: 96
