<a href="https://colab.research.google.com/github/rajlm10/D2L-Torch/blob/main/D2L_SR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from torch import nn

In [2]:
def get_fashion_mnist_labels(labels): 
  """Return text labels for the Fashion-MNIST dataset.""" 
  text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] 
  return [text_labels[int(i)] for i in labels]

In [4]:
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]) 
y_hat[[0, 1], y]

tensor([0.1000, 0.5000])

In [6]:
y_hat.shape,len(y_hat)

(torch.Size([2, 3]), 2)

In [8]:
#Dummy cross_entropy

def cross_entropy(y_hat,y):
  return -torch.log(y_hat[range(len(y_hat)),y])
cross_entropy(y_hat,y)

tensor([2.3026, 0.6931])

In [9]:
#Dummy accuracy

def accuracy(y_hat, y):
  """Compute the number of correct predictions.""" 
  if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
    y_hat = y_hat.argmax(axis=1)
  cmp = y_hat.type(y.dtype) == y
  return float(cmp.type(y.dtype).sum())

In [25]:
def evaluate_accuracy(net, data_iter):
  """Compute the accuracy for a model on a dataset.""" 
  if isinstance(net, torch.nn.Module):
    net.eval() # Set the model to evaluation mode
  metric = Accumulator(2) # No. of correct predictions, no. of predictions
  
  with torch.no_grad():
    for X, y in data_iter:
      metric.add(accuracy(net(X), y), y.numel()) 
  return metric[0] / metric[1]

In [10]:
class Accumulator: 
  """For accumulating sums over `n` variables.""" 
  def __init__(self, n):
    self.data = [0.0] * n 
    
  def add(self, *args):
    self.data = [a + float(b) for a, b in zip(self.data, args)] 
    
  def reset(self):
    self.data = [0.0] * len(self.data)
  
  def __getitem__(self, idx): 
    return self.data[idx]

In [11]:
def get_workers():
  return 4

In [12]:
def load_fashion_mnist(batch_size,resize=None):
  """Download the Fashion-MNIST dataset and then load it into memory."""
  trans=[transforms.ToTensor()] #PIL image to tensor (normalized between 0-1)
  if resize:
    trans.insert(0,transforms.Resize(resize))

  trans=transforms.Compose(trans) #Chains together transforms

  mnist_train=torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
  mnist_test=torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

  return data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_workers()),data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_workers())


In [13]:
train_iter, test_iter = load_fashion_mnist(32, resize=64) 

for X, y in train_iter:
 print(X.shape, X.dtype, y.shape, y.dtype) 
 break

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw



  cpuset_checked))


torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64


In [14]:
batch_size = 256
train_iter, test_iter = load_fashion_mnist(batch_size)

  cpuset_checked))


In [16]:
net=nn.Sequential(nn.Flatten(),nn.Linear(784,10))

def init_weights(layer):
  if type(layer) == nn.Linear:
    nn.init.normal_(layer.weight, std=0.01)

net.apply(init_weights)



Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=10, bias=True)
)

In [20]:
net[1].weight.size(),net[1].bias.size()

(torch.Size([10, 784]), torch.Size([10]))

In [21]:
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

In [23]:
def train_epoch(net,training_set,loss,optimizer):
  #Note training set is an iterator
  if isinstance(net,torch.nn.Module):
    net.train()

  metric=Accumulator(3) #stores sum of training loss, sum of training accuracy, no. of examples

  for X,y in training_set:
    y_hat=net(X) # n X 10
    l=loss(y_hat,y) # nX10, nX1 -> nX1

    if isinstance(optimizer,torch.optim.Optimizer):
      optimizer.zero_grad()
      l.mean().backward()
      optimizer.step()

    metric.add(float(l.sum()),accuracy(y_hat,y),y.shape[0])

  return metric[0]/metric[2], metric[1]/metric[2]





In [26]:
def train(net,training_set,test_set,loss,optimizer,num_epochs):
  for epoch in range(num_epochs):
    train_metrics=train_epoch(net,training_set,loss,optimizer)
    test_acc = evaluate_accuracy(net, test_set)
    
    print(f'''epoch {epoch}: {train_metrics},{test_acc}''')



In [27]:
train(net,train_iter,test_iter,loss,trainer,10)

  cpuset_checked))


epoch 0: (0.787415256690979, 0.7483833333333333),0.7823833333333333
epoch 1: (0.5705126907348633, 0.8134),0.8203666666666667
epoch 2: (0.5246704514821371, 0.8264833333333333),0.8280166666666666
epoch 3: (0.5017361728668213, 0.8326333333333333),0.8306333333333333
epoch 4: (0.485469190343221, 0.8365166666666667),0.83725
epoch 5: (0.4740749600728353, 0.8404833333333334),0.8304166666666667
epoch 6: (0.46519881744384767, 0.8426333333333333),0.8284666666666667
epoch 7: (0.4582118424097697, 0.8457333333333333),0.8485833333333334
epoch 8: (0.4531071772257487, 0.84645),0.8452
epoch 9: (0.4472584587097168, 0.8480166666666666),0.8478166666666667
