In [1]:
!pip install pytorch-lightning

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/12/98/86a89dcd54f84582bbf24cb29cd104b966fcf934d92d5dfc626f225015d2/pytorch_lightning-1.1.4-py3-none-any.whl (684kB)
[K     |████████████████████████████████| 686kB 14.3MB/s 
Collecting fsspec[http]>=0.8.1
[?25l  Downloading https://files.pythonhosted.org/packages/ec/80/72ac0982cc833945fada4b76c52f0f65435ba4d53bc9317d1c70b5f7e7d5/fsspec-0.8.5-py3-none-any.whl (98kB)
[K     |████████████████████████████████| 102kB 9.2MB/s 
[?25hCollecting PyYAML>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 43.5MB/s 
Collecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 42.3MB/s 

In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import EarlyStopping

def seed(value=113):
  """Set the global random seed through PyTorch Lightning's `seed_everything`.

  Args:
    value: Seed forwarded to `seed_everything`. Defaults to 113, the value
      this notebook has always used, so `seed()` behaves as before.
  """
  pl.utilities.seed.seed_everything(value)

# Data Processing

In [None]:
!wget https://www.dropbox.com/s/3j76hc0q63it4iu/dataset.zip
!unzip dataset.zip

In [4]:
# Side length (in pixels) of the square crop fed to the network, and the
# root folders of the two splits (one sub-folder per class, as ImageFolder expects).
img_size = 150
train_dir = "train/"
test_dir = "test/"

# Per-channel statistics used by Normalize for both splits.
channel_mean = [0.5, 0.5, 0.5]
channel_std = [0.5, 0.5, 0.5]

# --- training split: crop + light augmentation + tensor conversion + normalisation
train_steps = [
    transforms.CenterCrop(size=img_size),
    transforms.RandomRotation(degrees=5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
train_transform = transforms.Compose(train_steps)
train_img = datasets.ImageFolder(root=train_dir, transform=train_transform)
trainloaders = DataLoader(dataset=train_img, batch_size=64, shuffle=True)

# --- test split: same crop and normalisation, but no random augmentation
test_steps = [
    transforms.CenterCrop(size=img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
test_transform = transforms.Compose(test_steps)
test_img = datasets.ImageFolder(root=test_dir, transform=test_transform)
testloaders = DataLoader(dataset=test_img, batch_size=32)

# Lightning Model

In [5]:
class lightningModel(pl.LightningModule):
    """Small CNN classifier: two (conv -> ReLU -> max-pool) stages and a two-layer head.

    Args:
        ks: Kernel size shared by both convolutions (stride 1, no padding).
        ps: Kernel size of the max-pooling layer (stride 2, no padding).
        fm1: Number of feature maps produced by the first convolution.
        fm2: Number of feature maps produced by the second convolution.
        n: Width of the hidden fully connected layer.
        n_labels: Number of output logits (classes).
        input_size: Side length of the square input images. When ``None``
            (the default) the notebook-level ``img_size`` is used, which is
            what the model did before this parameter existed.
    """

    def __init__(self, ks=4, ps=3, fm1=16, fm2=32, n=256, n_labels=5, input_size=None):
        super().__init__()
        if input_size is None:
            # Backward-compatible default: the crop size defined in the data cell.
            input_size = img_size

        self.conv1 = nn.Conv2d(3, fm1, kernel_size=ks, stride=1, padding=0)
        self.pool = nn.MaxPool2d(kernel_size=ps, stride=2, padding=0)
        self.conv2 = nn.Conv2d(fm1, fm2, kernel_size=ks, stride=1, padding=0)

        # Follow the spatial side length through conv -> pool, twice, so the
        # first linear layer can be sized to the flattened feature map.
        side = input_size
        for _ in range(2):
            side = self.conv_size(side, ks)        # convolution, stride 1
            side = self.conv_size(side, ps, s=2)   # max-pool, stride 2
        res = side ** 2 * fm2

        self.fc1 = nn.Linear(res, n)
        self.fc2 = nn.Linear(n, n_labels)
        self.do = nn.Dropout()

        self.loss = nn.CrossEntropyLoss()

    def conv_size(self, inp, k, p=0, s=1):
        """Return the output side length of a conv/pool layer.

        Args:
            inp: Input side length.
            k: Kernel size.
            p: Padding applied on each side.
            s: Stride.
        """
        return (inp - k + 2 * p) // s + 1

    def forward(self, x):
        """Map a batch of images to a batch of ``n_labels`` logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(x.shape[0], -1)  # flatten everything except the batch axis
        x = self.do(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        """Optimise all parameters with Adam at its default settings."""
        optimizer = torch.optim.Adam(self.parameters())
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log epoch-level loss and accuracy."""
        image, label = batch
        output = self(image)
        error = self.loss(output, label)
        acc = accuracy(output, label)
        self.log('train_loss', error, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': error}

    def test_step(self, batch, batch_idx):
        """Return loss and accuracy of one test batch together with its size.

        The batch size is carried along so that ``test_epoch_end`` can weight
        each batch correctly when the last batch is smaller than the others.
        """
        image, label = batch
        output = self(image)
        error = self.loss(output, label)
        acc = accuracy(output, label)
        return {'loss': error, 'accuracy': acc, 'batch_size': label.size(0)}

    def test_epoch_end(self, outputs):
        """Aggregate the per-batch test results into dataset-level metrics.

        Each batch's mean loss/accuracy is weighted by the number of samples in
        that batch; a plain mean over batches would over-weight a smaller final
        batch. Outputs without a ``'batch_size'`` entry get weight 1, which
        reproduces the plain mean over batches.

        Returns:
            A dict with the aggregated ``'loss'`` and ``'accuracy'`` tensors
            (the same values are also logged as ``test_loss`` / ``test_acc``).
        """
        weights = torch.tensor([float(x.get('batch_size', 1)) for x in outputs])
        accs = torch.tensor([float(x['accuracy']) for x in outputs])
        losses = torch.tensor([float(x['loss']) for x in outputs])

        total = weights.sum()  # 0 for empty outputs -> NaN results, as with an empty mean
        acc = (accs * weights).sum() / total
        error = (losses * weights).sum() / total

        self.log('test_loss', error)
        self.log('test_acc', acc)
        return {'loss': error, 'accuracy': acc}


# Train the Lightning Model

In [6]:
# Seed before building the model so the weight initialisation is reproducible.
seed()
model = lightningModel()
# Train on one GPU when CUDA is available; otherwise fall back to the CPU
# instead of failing at Trainer construction.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=20)
trainer.fit(model, trainloaders)

Global seed set to 113
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | conv1 | Conv2d           | 784   
1 | pool  | MaxPool2d        | 0     
2 | conv2 | Conv2d           | 8.2 K 
3 | fc1   | Linear           | 9.5 M 
4 | fc2   | Linear           | 1.3 K 
5 | do    | Dropout          | 0     
6 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
9.5 M     Trainable params
0         Non-trainable params
9.5 M     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

# Test the Lightning Model

In [7]:
# Evaluate the trained model on the held-out loader; the returned list of
# metric dicts is shown as the cell output.
test_results = trainer.test(model, testloaders)
test_results

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'accuracy': tensor(0.5677),
 'loss': tensor(1.2040),
 'test_acc': tensor(0.5677),
 'test_loss': tensor(1.2040)}
--------------------------------------------------------------------------------




[{'accuracy': 0.5677083134651184,
  'loss': 1.2040480375289917,
  'test_acc': 0.5677083134651184,
  'test_loss': 1.2040480375289917}]

# Tensorboard for logging

In [8]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/version_4/

<IPython.core.display.Javascript object>