<a href="https://colab.research.google.com/github/rajlm10/D2L-Torch/blob/main/D2L_MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [42]:
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from torch import nn

In [43]:
def get_fashion_mnist_labels(labels):
  """Translate numeric Fashion-MNIST class indices into their text names.

  Args:
    labels: Iterable of class indices. Each item is passed through `int()`,
      so plain ints, floats and scalar tensors are all accepted.

  Returns:
    A list of strings, one per input item, in the same order.
  """
  # Position in this tuple is the class index.
  class_names = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
  names = []
  for label in labels:
    names.append(class_names[int(label)])
  return names

In [44]:
#Dummy accuracy
def accuracy(y_hat, y):
  """Count how many predictions in `y_hat` agree with the targets `y`.

  Args:
    y_hat: Either a 2-D tensor with more than one column (reduced to class
      indices with an argmax over axis 1) or a tensor that is compared with
      `y` as-is.
    y: Tensor of target labels.

  Returns:
    The number of matching entries as a Python float (a count, not a ratio).
  """
  has_class_scores = len(y_hat.shape) >= 2 and y_hat.shape[1] >= 2
  predicted = y_hat.argmax(axis=1) if has_class_scores else y_hat
  # Cast to the label dtype so the element-wise comparison is like-for-like.
  matches = predicted.type(y.dtype) == y
  return float(matches.type(y.dtype).sum())

In [45]:
class Accumulator:
  """Running totals for `n` quantities tracked side by side."""

  def __init__(self, n):
    # One float slot per tracked quantity, all starting at zero.
    self.data = [0.0 for _ in range(n)]

  def add(self, *args):
    """Add each positional argument to the total in the same position.

    Totals and arguments are paired with `zip`, so surplus arguments are
    ignored and totals without a matching argument are dropped from `data`.
    """
    updated = []
    for total, increment in zip(self.data, args):
      updated.append(total + float(increment))
    self.data = updated

  def reset(self):
    """Set every total back to zero, keeping the current number of slots."""
    self.data = [0.0 for _ in self.data]

  def __getitem__(self, idx):
    """Return the total stored at position `idx`."""
    return self.data[idx]

In [46]:
def evaluate_accuracy(net,test_iter):
  """Return the fraction of examples in `test_iter` that `net` classifies correctly.

  Args:
    net: Callable model. If it is a `torch.nn.Module` it is switched to
      evaluation mode first (and is left in that mode afterwards).
    test_iter: Iterable yielding `(X, y)` batches.

  Returns:
    Correct predictions divided by total label count, as a float.

  Raises:
    ZeroDivisionError: if `test_iter` yields no batches.
  """
  if isinstance(net, torch.nn.Module):
    net.eval()

  # Slot 0: number of correct predictions, slot 1: number of predictions.
  totals = Accumulator(2)

  # Gradients are not needed for evaluation.
  with torch.no_grad():
    for features, targets in test_iter:
      num_correct = accuracy(net(features), targets)
      totals.add(num_correct, targets.numel())

  return totals[0] / totals[1]



In [47]:
def get_workers():
  """Return the worker count handed to the data loaders as `num_workers`."""
  num_workers = 2
  return num_workers

In [48]:
def load_fashion_mnist(batch_size,resize=None,root="../data"):
  """Download Fashion-MNIST (if not already present) and build data loaders.

  Args:
    batch_size: Number of examples per mini-batch, used for both loaders.
    resize: Optional size forwarded to `transforms.Resize`. When falsy the
      images are not resized.
    root: Directory the dataset is downloaded to / read from.

  Returns:
    A `(train_loader, test_loader)` tuple of `DataLoader` objects. The
    training loader reshuffles every epoch; the test loader iterates in a
    fixed order, since shuffling adds nothing to evaluation and makes
    per-batch inspection non-repeatable.
  """
  # Resize (if requested) runs first; ToTensor then converts the PIL image
  # to a tensor scaled to the 0-1 range.
  steps = []
  if resize:
    steps.append(transforms.Resize(resize))
  steps.append(transforms.ToTensor())
  trans = transforms.Compose(steps)  # chain the transforms into one callable

  mnist_train = torchvision.datasets.FashionMNIST(
      root=root, train=True, transform=trans, download=True)
  mnist_test = torchvision.datasets.FashionMNIST(
      root=root, train=False, transform=trans, download=True)

  num_workers = get_workers()
  train_loader = data.DataLoader(
      mnist_train, batch_size, shuffle=True, num_workers=num_workers)
  # Evaluation data is deliberately NOT shuffled.
  test_loader = data.DataLoader(
      mnist_test, batch_size, shuffle=False, num_workers=num_workers)
  return train_loader, test_loader


In [49]:
# Mini-batch size passed to load_fashion_mnist for both loaders.
batch_size = 256
# load_fashion_mnist returns the training loader first, then the test loader.
train_iter, test_iter = load_fashion_mnist(batch_size)

In [50]:
from torch.nn.modules.activation import ReLU

# Multilayer perceptron with a single hidden layer: flatten each input,
# apply a 784->256 linear layer followed by ReLU, then a 256->10 linear
# output layer (10 matches the number of Fashion-MNIST text labels).
# NOTE(review): 784 is presumably 28*28 flattened pixels - confirm against
# the dataset's image size.
net=nn.Sequential(
    nn.Flatten(),
    nn.Linear(784,256),
    nn.ReLU(),
    nn.Linear(256,10)
)

def init_weights(layer):
  """Re-draw the weights of an `nn.Linear` layer from N(0, 0.01^2).

  Intended for use with `Module.apply`. Only modules whose type is exactly
  `nn.Linear` are touched (subclasses and other layers are left alone), and
  only the weight matrix is re-initialised - the bias keeps its value.
  """
  if type(layer) is not nn.Linear:
    return
  nn.init.normal_(layer.weight, mean=0.0, std=0.01)

# Module.apply calls init_weights on every submodule of `net` and on `net`
# itself, then returns the module - which is why the cell displays its repr.
net.apply(init_weights)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=256, bias=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=10, bias=True)
)

In [51]:
# Hyper-parameters for the training run.
batch_size = 256
lr = 0.1
num_epochs = 10

# reduction='none' keeps one loss value per example; train_epoch reduces it
# itself (mean for the backward pass, sum for the reported loss).
loss = nn.CrossEntropyLoss(reduction='none')

# Plain mini-batch SGD over all parameters of the network.
trainer = torch.optim.SGD(net.parameters(), lr=lr)

In [52]:
def train_epoch(net,training_set,loss,optimizer):
  """Train `net` for one pass over `training_set`.

  Args:
    net: Callable model. If it is a `torch.nn.Module` it is put into
      training mode first.
    training_set: Iterable (not a dataset object) yielding `(X, y)` batches.
    loss: Callable `loss(y_hat, y)` returning a tensor. Its `.mean()` is
      back-propagated and its `.sum()` is accumulated for reporting, so it
      is expected to return one value per example (e.g. `reduction='none'`).
    optimizer: A `torch.optim.Optimizer`. If it is anything else, no
      gradient step is taken and only the metrics are computed.

  Returns:
    `(loss_per_example, accuracy)` - the accumulated loss and the number of
    correct predictions, each divided by the number of examples seen.

  Raises:
    ZeroDivisionError: if `training_set` yields no batches.
  """
  if isinstance(net, torch.nn.Module):
    net.train()

  # Slots: [summed loss, correct predictions, examples seen].
  totals = Accumulator(3)
  can_step = isinstance(optimizer, torch.optim.Optimizer)

  for features, targets in training_set:
    outputs = net(features)
    batch_loss = loss(outputs, targets)

    if can_step:
      optimizer.zero_grad()
      batch_loss.mean().backward()
      optimizer.step()

    totals.add(float(batch_loss.sum()), accuracy(outputs, targets), targets.shape[0])

  loss_sum, num_correct, num_examples = totals[0], totals[1], totals[2]
  return loss_sum / num_examples, num_correct / num_examples





In [53]:
def train(net,training_set,test_set,loss,optimizer,num_epochs):
  """Train `net` for `num_epochs` epochs, printing metrics after each one.

  Each epoch runs `train_epoch` on `training_set`, then measures accuracy on
  `test_set` with `evaluate_accuracy`, and prints one summary line with the
  1-based epoch number. Nothing is returned.
  """
  for epoch_number in range(1, num_epochs + 1):
    train_loss, train_acc = train_epoch(net, training_set, loss, optimizer)
    test_acc = evaluate_accuracy(net, test_set)

    print(f'''epoch {epoch_number}: Train Loss: {train_loss},Train Acc: {train_acc}, Test Acc: {test_acc}''')

In [54]:
# Use the configured epoch count rather than repeating the literal 10, so
# changing `num_epochs` in the hyper-parameter cell actually takes effect.
train(net,train_iter,test_iter,loss,trainer,num_epochs)

epoch 1: Train Loss: 1.0462628690083822,Train Acc: 0.6399666666666667, Test Acc: 0.7253
epoch 2: Train Loss: 0.5954515837351481,Train Acc: 0.7911333333333334, Test Acc: 0.7736
epoch 3: Train Loss: 0.5227659240722656,Train Acc: 0.8171, Test Acc: 0.801
epoch 4: Train Loss: 0.48079511960347493,Train Acc: 0.8312833333333334, Test Acc: 0.8302
epoch 5: Train Loss: 0.4560855724334717,Train Acc: 0.8404333333333334, Test Acc: 0.8244
epoch 6: Train Loss: 0.43415399583180747,Train Acc: 0.8482833333333333, Test Acc: 0.8348
epoch 7: Train Loss: 0.42083296286265054,Train Acc: 0.8528666666666667, Test Acc: 0.8248
epoch 8: Train Loss: 0.40448241259257,Train Acc: 0.8577, Test Acc: 0.8477
epoch 9: Train Loss: 0.3944818537394206,Train Acc: 0.8616666666666667, Test Acc: 0.8496
epoch 10: Train Loss: 0.3821519070943197,Train Acc: 0.8643833333333333, Test Acc: 0.8415
