In [1]:
from IPython.display import clear_output as clear

!pip install pytorch_lightning
clear()

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os

import torch
import pytorch_lightning as pl 

from torch import nn
from torch.nn import functional as F 
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.datasets import CIFAR10
from pytorch_lightning.metrics.functional import accuracy

In [3]:
class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)

    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)

In [4]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, conf):
        super().__init__()
        self.data_path = conf.data_path
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.dims = conf.dims 
        self.num_classes = conf.num_classes
    
    def prepare_data(self):
        CIFAR10(self.data_path, train=True, download=True)
        CIFAR10(self.data_path, train=False, download=True)
    
    def setup(self, stage):
        if stage == 'fit' or stage is None:
            data_full = CIFAR10(self.data_path, train=True, transform=self.transform)
            self.data_train, self.data_val = random_split(data_full, [45000, 5000])
        
        if stage == 'test' or stage is None:
            self.data_test = CIFAR10(self.data_path, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.data_train, batch_size=32)
    
    def val_dataloader(self):
        return DataLoader(self.data_val, batch_size=32)
    
    def test_dataloader(self):
        return DataLoader(self.data_test, batch_size=32)

In [5]:
class CNNMNIST(pl.LightningModule):
    def __init__(self, conf):
        super().__init__()

        self.conf = conf 
        channels, width, height = conf.dims
        self.conv1 = nn.Conv2d(channels, 16, 3, padding=1) # 32 -> 16
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)  # 16 -> 8
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1) # 8 -> 4
        self.maxpool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(64*(width//8)*(height//8), 128)
        self.out = nn.Linear(128, conf.num_classes)

    def forward(self, x):
        b, c, w, h = x.size()
        x = self.dropout(F.relu(self.maxpool(self.conv1(x))))
        x = self.dropout(F.relu(self.maxpool(self.conv2(x))))
        x = self.dropout(F.relu(self.maxpool(self.conv3(x))))
        x = x.view(-1, 64*(w//8)*(h//8))
        x = self.dropout(self.fc(x))
        out = self.out(x)
        return out 
    
    def training_step(self, batch, batch_idx):
        x, y = batch 
        outputs = self(x)
        loss = F.cross_entropy(outputs, y)
        return loss 
    
    def validation_step(self, batch, batch_idx):
        x, y = batch 
        outputs = self(x)
        loss = F.cross_entropy(outputs, y)
        acc = accuracy(y, torch.argmax(outputs, -1))
    
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss 
    
    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.conf.learning_rate)
        return optimizer

In [6]:
conf = Config(
    data_path=r'./',
    dims=(3, 32, 32),
    num_classes=10,
    learning_rate=3e-3,
)
conf

{'data_path': './',
 'dims': (3, 32, 32),
 'learning_rate': 0.003,
 'num_classes': 10}

In [7]:
dm = CIFAR10DataModule(conf)
model = CNNMNIST(conf)
trainer = pl.Trainer(gpus=1, max_epochs=10, progress_bar_refresh_rate=20)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [8]:
trainer.fit(model, dm)

Files already downloaded and verified
Files already downloaded and verified



  | Name    | Type      | Params
--------------------------------------
0 | conv1   | Conv2d    | 448   
1 | conv2   | Conv2d    | 4.6 K 
2 | conv3   | Conv2d    | 18.5 K
3 | maxpool | MaxPool2d | 0     
4 | dropout | Dropout   | 0     
5 | fc      | Linear    | 131 K 
6 | out     | Linear    | 1.3 K 


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [9]:
trainer.test()

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_acc': tensor(0.6319, device='cuda:0'),
 'val_loss': tensor(1.0475, device='cuda:0')}
--------------------------------------------------------------------------------



[{'val_acc': 0.6319000124931335, 'val_loss': 1.0474953651428223}]