# MNIST PyTorch Linear Sample
Based on the 'PyTorch Lightning' YouTube channel video 'Episode 1: Training a classification model on MNIST with PyTorch':<br/>
https://youtu.be/OMDn66kM9Qc

In [1]:
import torch
#import torch.nn as nn
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [2]:
# Quick sanity check: draw five samples from a standard normal distribution.
torch.randn(5)
# GPU variant, left disabled: torch.randn(5).cuda()

tensor([0.4090, 0.0465, 0.8155, 0.2940, 0.3471])

In [3]:
# Plain feed-forward stack: 28*28 = 784 inputs -> 64 -> 64 -> 10 outputs,
# with a ReLU after each of the two hidden linear layers.
model0 = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=10),
)

In [4]:
# Show the printed structure of the Sequential model (its layers and their sizes).
model0

Sequential(
  (0): Linear(in_features=784, out_features=64, bias=True)
  (1): ReLU()
  (2): Linear(in_features=64, out_features=64, bias=True)
  (3): ReLU()
  (4): Linear(in_features=64, out_features=10, bias=True)
)

In [5]:
class ResNet(nn.Module):
    """Small fully connected network with a single skip connection.

    Takes inputs with 28*28 = 784 features and produces 10 logits. The
    activation of the first hidden layer is added to the activation of the
    second hidden layer, and dropout (p=0.1) is applied to that sum before the
    final linear layer.
    """

    def __init__(self):
        super().__init__()
        hidden = 64
        self.l1 = nn.Linear(28 * 28, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, 10)
        self.do = nn.Dropout(p=0.1)

    def forward(self, x):
        """Return the logits computed from a batch of 784-feature inputs."""
        relu = nn.functional.relu
        first = relu(self.l1(x))
        second = relu(self.l2(first))
        # Skip connection: sum both hidden activations, then apply dropout.
        return self.l3(self.do(second + first))


model = ResNet()

In [6]:
# Show the printed structure of the ResNet instance (its registered submodules).
model

ResNet(
  (l1): Linear(in_features=784, out_features=64, bias=True)
  (l2): Linear(in_features=64, out_features=64, bias=True)
  (l3): Linear(in_features=64, out_features=10, bias=True)
  (do): Dropout(p=0.1, inplace=False)
)

In [7]:
# Iterator over the model's parameters. The optimizer below does not consume
# it; it asks the model for a fresh iterator of its own.
params = model.parameters()

# Stochastic gradient descent over every model parameter, learning rate 0.01.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

In [8]:
# Cross-entropy criterion; takes raw logits and integer class targets and
# averages the per-sample losses over the batch.
loss = nn.CrossEntropyLoss(reduction="mean")

In [9]:
# Fetch the MNIST training split into ./data (downloading it if it is missing)
# and convert each image to a tensor.
train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())

# Randomly partition the samples into 55000 for training and 5000 for validation.
train, val = random_split(train_data, [55000, 5000])

# DataLoader defaults to shuffle=False, which would feed the training batches
# in the same order every epoch; reshuffle the training subset each epoch.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
# Validation order does not affect the metrics, so it is left sequential.
val_loader = DataLoader(val, batch_size=32)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [10]:
nb_epochs = 5

for epoch in range(nb_epochs):
    # ---- Training pass ----
    # Put the model in training mode so the dropout layer is active. This has
    # to be set every epoch because the validation pass below switches it off.
    model.train()
    losses1 = list()
    accuracy1 = list()
    for batch in train_loader:
        x, y = batch
        b = x.size(0)
        x = x.view(b, -1)  # flatten each sample to a single row of features
        logits = model(x)
        J = loss(logits, y)
        model.zero_grad()  # clear gradients left over from the previous step
        J.backward()
        optimizer.step()
        losses1.append(J.item())
        # Fraction of samples in this batch whose arg-max logit matches the label.
        accuracy1.append(y.eq(logits.detach().argmax(dim=1)).float().mean())

    print(f'Epoch {epoch+1}, train loss: {torch.tensor(losses1).mean():.2f}, train acc: {torch.tensor(accuracy1).mean():.2f}')

    # ---- Validation pass ----
    # Switch to evaluation mode so dropout is disabled; otherwise the reported
    # validation loss/accuracy would be computed with randomly dropped units.
    model.eval()
    losses2 = list()
    accuracy2 = list()
    # No gradients are needed for validation: run the forward pass and the
    # loss computation without building an autograd graph.
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            b = x.size(0)
            x = x.view(b, -1)
            logits = model(x)
            J = loss(logits, y)
            losses2.append(J.item())
            accuracy2.append(y.eq(logits.argmax(dim=1)).float().mean())

    print(f'Epoch {epoch+1}, valid loss: {torch.tensor(losses2).mean():.2f}, valid acc: {torch.tensor(accuracy2).mean():.2f}')
    print()

Epoch 1, train loss: 0.85, train acc: 0.78
Epoch 1, valid loss: 0.43, valid acc: 0.87

Epoch 2, train loss: 0.38, train acc: 0.89
Epoch 2, valid loss: 0.34, valid acc: 0.90

Epoch 3, train loss: 0.31, train acc: 0.91
Epoch 3, valid loss: 0.29, valid acc: 0.92

Epoch 4, train loss: 0.27, train acc: 0.92
Epoch 4, valid loss: 0.26, valid acc: 0.93

Epoch 5, train loss: 0.24, train acc: 0.93
Epoch 5, valid loss: 0.24, valid acc: 0.93

