<a href="https://colab.research.google.com/github/teang1995/Algorithm_study/blob/master/LeNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchtext==0.8.0 torch==1.7.1 pytorch-lightning==1.2.2 torchmetrics torchvision



In [2]:
import torch
from torchmetrics import Metric
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from typing import Optional

In [3]:
class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([transforms.ToTensor()])
        
    def prepare_data(self):
        """
            - download
            - tokenize
            - etc ...
        """
        MNIST(self.data_dir, train=True, download=True) # just download at data_dir
        MNIST(self.data_dir, train=False, download=True) # just download at data_dir
    def setup(self, stage:Optional[str]) -> None:
        """
            - count number of classes
            - build vocabulary
            - perform train/val/test split
            - apply transforms
            - etc...
        """
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) 
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) # split data
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)


In [4]:
mnist_datamodule = MNISTDataModule(data_dir='./', batch_size=32)

In [5]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,
                           out_channels=3,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           bias=True)
        self.conv2 = nn.Conv2d(in_channels=3,
                           out_channels=6,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           bias=True)
        self.subsample = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(294, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        bs = x.shape[0]
        fm1 = self.conv1(x)
        fm2 = self.subsample(fm1)
        fm3 = self.conv2(fm2)
        fm4 = self.subsample(fm3)
        flattend = fm4.view(bs, -1)
        em1 = self.fc1(flattend)
        em2 = self.fc2(em1)
        return torch.log_softmax(em2, dim=1)

In [6]:
class LeNetModule(pl.LightningModule):
    def __init__(self,
               init_lr : int = 1e-2):
        super().__init__()
        self.init_lr = init_lr
        self.net = LeNet()
        print("init done!")
    
    def forward(self, X):
        return self.net(X)

    def training_step(self, batch, batch_idx):
        X, y = batch
        prediction = self.forward(X)
        loss = F.cross_entropy(prediction, y)
        return {'loss': loss}
  
    def validation_step(self, batch, batch_idx):
        X, y = batch
        prediction = self.forward(X)
        loss = F.cross_entropy(prediction, y)
        
        acc = torch.sum(prediction.max(1, keepdim=True)[1] == y)
        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)

    def test_step(self, batch, batch_idx):
        X, y= batch
        prediction = self.forward(X)
        loss = F.cross_entropy(prediction, y)
        acc = torch.sum(prediction.max(1)[1] == y) * 100 / len(y)
        metrics = {'test_acc': acc, 'test_loss': loss}
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.init_lr)

In [7]:
import numpy as np
a = [[1 if j % 10 == i else 0 for i in range(10)] for j in range(32)]
y = [i % 10 for i in range(32)]
y = np.array(y)
y = torch.tensor(y)
a = np.array(a)
b = torch.tensor(a)
c = b.max(1)[1]
result = (y==c)
result.shape

torch.Size([32])

In [8]:
lenet_module = LeNetModule(init_lr=1e-2)

init done!


In [9]:
device = 1 if torch.cuda.is_available() else 0
max_epochs=10

trainer = pl.Trainer(gpus=device,
                     max_epochs=max_epochs,
                     progress_bar_refresh_rate=20,
                     num_sanity_val_steps=0)

trainer.fit(model=lenet_module, datamodule=mnist_datamodule)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name | Type  | Params
-------------------------------
0 | net  | LeNet | 156 K 
-------------------------------
156 K     Trainable params
0         Non-trainable params
156 K     Total params
0.625     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1

In [10]:
trainer.test(lenet_module, datamodule=mnist_datamodule, verbose=True)

Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 97.33999633789062, 'test_loss': 0.08329261839389801}
--------------------------------------------------------------------------------


[{'test_acc': 97.33999633789062, 'test_loss': 0.08329261839389801}]