In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torch.nn.functional as F
from typing import List, Optional
from collections import OrderedDict
from torchvision import datasets, transforms, models
from torch.utils.data import random_split, DataLoader

In [2]:
class FoodDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", 
                 batch_size: int = 8, 
                 split: List[float] = [.7, .2, .1]):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.split = torch.tensor(split)

        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def setup(self, stage: Optional[str] = None):
        # check split
        if len(self.split) != 3: raise Exception("split size should be 3 (train, val, test)")
        # normalize split
        self.split = self.split / self.split.sum()

        # using ImageFolder
        food_images = datasets.ImageFolder(self.data_dir, self.transform)
        self.classes = food_images.classes

        # set dims
        self.dims = tuple(food_images[0][0].shape)

        # counts/splits
        sz = len(food_images)
        train_sz = math.floor(self.split[0] * sz)
        val_sz = math.floor(self.split[1] * sz)
        test_sz = sz - train_sz - val_sz

        # split
        self.train, self.val, self.test = random_split(food_images, [train_sz, val_sz, test_sz])

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)

In [6]:
class FoodModel(pl.LightningModule):
    def __init__(self, classes: List[str], learning_rate: float = 0.1):
        super().__init__()
        self.lr = learning_rate
        self.classes = classes

        # model layers
        self.xfer = models.resnet18(pretrained=True)
        self.fc1 = nn.Linear(1000, 256)
        self.fc2 = nn.Linear(256, len(self.classes))

    def forward(self, x):
        x = F.relu(self.xfer(x))
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=-1)

    @staticmethod
    def __accuracy(outputs, Y):
        with torch.no_grad():
            _, preds = torch.max(outputs, 1)
            return (preds == Y).float().sum().item()

    def __step(self, batch):
        X, Y = batch
        outputs = self(X)
        loss = F.cross_entropy(outputs, Y)
        return loss, self.__accuracy(outputs, Y)


    def training_step(self, batch, batch_idx):
        loss, acc = self.__step(batch)

        tqdm_dict = {'train_loss': loss}
        return OrderedDict({
            'loss': loss,
            'acc': acc
        })

    def validation_step(self, batch, batch_idx):
        loss, acc = self.__step(batch)

        tqdm_dict = {'train_loss': loss}
        return OrderedDict({
            'val_loss': loss,
            'val_acc': acc
        })

    def configure_optimizers(self):
        optimizer = optim.SGD(
            self.parameters(), 
            lr=self.lr
        )
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, 
            step_size=7, 
            gamma=0.1
        )
        return [optimizer], [scheduler]

In [7]:
dm = FoodDataModule(data_dir='../data/food')
dm.setup()
model = FoodModel(classes=dm.classes)

In [5]:
trainer = pl.Trainer(gpus=1)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name | Type   | Params
--------------------------------
0 | xfer | ResNet | 11.7 M
1 | fc1  | Linear | 256 K 
2 | fc2  | Linear | 514   
--------------------------------
11.9 M    Trainable params
0         Non-trainable params
11.9 M    Total params
47.785    Total estimated model params size (MB)
Epoch 0:  78%|███████▊  | 226/291 [01:22<00:23,  2.75it/s, loss=0.543, v_num=13]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/65 [00:00<?, ?it/s][A
Epoch 0:  78%|███████▊  | 228/291 [01:22<00:22,  2.77it/s, loss=0.543, v_num=13]
Validating:   3%|▎         | 2/65 [00:00<00:28,  2.18it/s][A
Validating:   5%|▍         | 3/65 [00:01<00:19,  3.11it/s][A
Epoch 0:  79%|███████▉  | 231/291 [01:23<00:21,  2.77it/s, loss=0.543, v_num=13]
Validating:   8%|▊         | 5/65 [00:01<00:22,  2.69it/s][A
Validating:   9%|▉         | 6/65 [00:01<00:18,  3.16it/s][A
Epoch 0:  80%|████████  | 234/291 [01:24<00:

1