In [None]:
# !pip install git+https://github.com/PyTorchLightning/pytorch-lightning

In [1]:
import torch
from torch import nn
import pytorch_lightning as pl
from torchmetrics.functional import accuracy
from torchvision import transforms, datasets, models
from torch.utils.data import random_split, DataLoader

  "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"


In [2]:
pl.seed_everything(42)
max_epochs = 3
img_size = 224
val_pct = 0.2
batch_size = 64
lr = 3e-4

Global seed set to 42


In [3]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir, img_size, val_pct, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.T = transforms.Compose(
                    [
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor()
                    ]
                )
        self.val_pct = val_pct
        self.batch_size = batch_size

    def prepare_data(self):
        datasets.CIFAR10(self.data_dir, train=True, download=True)
        datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage in ('fit', 'validate'):
            data = datasets.CIFAR10(self.data_dir, train=True, transform=self.T)
            val_len = int(self.val_pct * len(data))
            self.train_data, self.val_data = random_split(data, [len(data) - val_len, val_len])
        elif stage in ('test', 'predict'):    
            self.test_data = datasets.CIFAR10(self.data_dir, train=False, transform=self.T)

    def get_dataloader(self, data):
        return DataLoader(data, batch_size=self.batch_size, num_workers=2, pin_memory=True)

    def train_dataloader(self):
        return self.get_dataloader(self.train_data)

    def val_dataloader(self):
        return self.get_dataloader(self.val_data)

    def test_dataloader(self):
        return self.get_dataloader(self.test_data)

    def predict_dataloader(self):
        return self.get_dataloader(self.test_data)

In [4]:
class Model(pl.LightningModule):
    def __init__(self, lr):
        super().__init__()
        self.resnet = models.resnet18(pretrained=False)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 10)

        self.lr = lr
        self.loss_fn = nn.CrossEntropyLoss()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        return self.resnet(x)
        
    def shared_step(self, batch, split):
        x, y = batch
        preds = self(x)
        if split != 'predict':
            loss = self.loss_fn(preds, y)
            acc = accuracy(preds, y)
            self.log(f'{split}_loss', loss, on_epoch=True, prog_bar=True)
            self.log(f'{split}_acc', acc, on_epoch=True, prog_bar=True)
            if split == 'train':
                return loss
        else:
            return preds

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        self.shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self.shared_step(batch, 'test')
    
    def predict_step(self, batch, batch_idx):
        preds = self.shared_step(batch, 'predict')
        return preds

In [5]:
class PrintMetrics(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        final_str = f"epoch: {trainer.current_epoch} | "
        for k, v in trainer.callback_metrics.items():
            if "train" in k and "epoch" in k:
                final_str += f"{k.replace('_epoch', '')}: {v.item():.4f} | "
            elif "val" in k:
                final_str += f"{k}: {v.item():.4f} | "
        print(final_str[:-3])

In [6]:
cifar10_dm = CIFAR10DataModule('data/', img_size, val_pct, batch_size)
model = Model(lr)

In [7]:
trainer = pl.Trainer(max_epochs=max_epochs, gpus=1, callbacks=[PrintMetrics()])
trainer.fit(model, cifar10_dm) # uses train and val dataloaders

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | resnet  | ResNet           | 11.2 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Global seed set to 42




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


epoch: 0 | val_loss: 1.4289 | val_acc: 0.5371 | train_loss: 1.2855 | train_acc: 0.5322


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

epoch: 1 | val_loss: 0.8967 | val_acc: 0.6900 | train_loss: 0.8118 | train_acc: 0.7129


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

epoch: 2 | val_loss: 0.8359 | val_acc: 0.7215 | train_loss: 0.5916 | train_acc: 0.7961



In [8]:
trainer.validate(model, cifar10_dm) # uses val dataloader

  'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.7214999794960022, 'val_loss': 0.8358551859855652}
--------------------------------------------------------------------------------


[{'val_acc': 0.7214999794960022, 'val_loss': 0.8358551859855652}]

In [9]:
trainer.test(model, cifar10_dm) # uses test dataloader

  'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.715399980545044, 'test_loss': 0.8482254147529602}
--------------------------------------------------------------------------------



[{'test_acc': 0.715399980545044, 'test_loss': 0.8482254147529602}]

In [10]:
preds = trainer.predict(model, cifar10_dm) # uses predict dataloader(which is usually the test dataloader)
print(len(preds), len(preds[0]), preds[0][0].shape) # num_batches, batch_size, shape of output of first element in first batch(num classes)

  'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Predicting', layout=Layout(flex='2'), m…


157 64 torch.Size([10])
