In [1]:
# !pip install git+https://github.com/PyTorchLightning/pytorch-lightning

In [None]:
import torch
from torch import nn
import pytorch_lightning as pl
from torchmetrics.functional import accuracy
from torch.utils.data import random_split, DataLoader
from torchvision import transforms, datasets, models

In [3]:
pl.seed_everything(42)
max_epochs = 3
img_size = 224
batch_size = 64
val_pct = 0.2
lr = 3e-4

Global seed set to 42


In [4]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir, img_size, batch_size, val_pct):
        super().__init__()
        self.data_dir = data_dir
        self.T = transforms.Compose(
                    [
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor()
                    ]
                )
        self.val_pct = val_pct
        self.batch_size = batch_size

    def prepare_data(self):
        datasets.CIFAR10(self.data_dir, train=True, download=True)
        datasets.CIFAR10(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        data = datasets.CIFAR10(self.data_dir, train=True, transform=self.T)
        val_len = int(val_pct * len(data))
        self.train_data, self.val_data = random_split(data, [len(data) - val_len, val_len])
        self.test_data = datasets.CIFAR10(self.data_dir, train=False, transform=self.T)

    def get_dataloader(self, data):
        return DataLoader(data, batch_size=self.batch_size, num_workers=2, pin_memory=True)

    def train_dataloader(self):
        return self.get_dataloader(self.train_data)

    def val_dataloader(self):
        return self.get_dataloader(self.val_data)

    def test_dataloader(self):
        return self.get_dataloader(self.test_data)

In [5]:
class Model(pl.LightningModule):
    def __init__(self, num_classes, lr, loss_fn):
        super().__init__()
        self.resnet = models.resnet18(pretrained=False)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

        self.lr = lr
        self.loss_fn = loss_fn

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def forward(self, x):
        return self.resnet(x)

    def shared_step(self, batch):
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)
        acc = accuracy(preds, y)
        return loss, acc

    def shared_log(self, split, loss, acc):
        self.log(f'{split}_loss', loss, on_epoch=True, prog_bar=True)
        self.log(f'{split}_acc', acc, on_epoch=True, prog_bar=True)

    def training_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.shared_log('train', loss, acc)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.shared_log('val', loss, acc)

    def test_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.shared_log('test', loss, acc)

In [6]:
cifar10_dm = CIFAR10DataModule('data/', img_size, batch_size, val_pct)
loss_fn = nn.CrossEntropyLoss()
model = Model(10, lr, loss_fn)

In [None]:
class PrintMetrics(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        final_str = f"epoch: {trainer.current_epoch} | "
        for k, v in trainer.callback_metrics.items():
            if "train" in k and "epoch" in k:
                final_str += f"{k.replace('_epoch', '')}: {v.item():.4f} | "
            elif "val" in k:
                final_str += f"{k}: {v.item():.4f} | "
        print(final_str[:-3])

In [8]:
callbacks = [PrintMetrics()]
trainer = pl.Trainer(max_epochs=max_epochs, gpus=1, callbacks=callbacks)
trainer.fit(model, cifar10_dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | resnet  | ResNet           | 11.2 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Global seed set to 42




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


epoch: 0 | val_loss: 1.6040 | val_acc: 0.4987 | train_loss: 1.2869 | train_acc: 0.5332


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

epoch: 1 | val_loss: 1.1130 | val_acc: 0.6267 | train_loss: 0.8232 | train_acc: 0.7116


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

epoch: 2 | val_loss: 0.9274 | val_acc: 0.6879 | train_loss: 0.6046 | train_acc: 0.7911

