In [None]:
# !pip install pytorch-lightning wandb

In [None]:
import wandb
import torch
from torch import nn
import pytorch_lightning as pl
from torchmetrics.functional import accuracy
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import random_split, DataLoader
from torchvision import transforms, datasets, models

In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mzer0sh0t[0m (use `wandb login --relogin` to force relogin)


True

In [2]:
pl.seed_everything(42)
# start_epoch = 1
max_epochs = 3
img_size = 224
val_pct = 0.2
batch_size = 64
lr = 3e-4

Global seed set to 42


In [3]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir, img_size, val_pct, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.T = transforms.Compose(
                    [
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor()
                    ]
                )
        self.val_pct = val_pct
        self.batch_size = batch_size

    def prepare_data(self):
        datasets.CIFAR10(self.data_dir, train=True, download=True)
        datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage in ('fit', 'validate'):
            data = datasets.CIFAR10(self.data_dir, train=True, transform=self.T)
            val_len = int(val_pct * len(data))
            self.train_data, self.val_data = random_split(data, [len(data) - val_len, val_len])
        elif stage in ('test', 'predict'):    
            self.test_data = datasets.CIFAR10(self.data_dir, train=False, transform=self.T)

    def get_dataloader(self, data):
        return DataLoader(data, batch_size=self.batch_size, num_workers=2, pin_memory=True)

    def train_dataloader(self):
        return self.get_dataloader(self.train_data)

    def val_dataloader(self):
        return self.get_dataloader(self.val_data)

In [4]:
class Model(pl.LightningModule):
    # def __init__(self, start_epoch, lr):
    def __init__(self, lr):
        super().__init__()
        # self.epoch = start_epoch - 1 if start_epoch != 0 else start_epoch
        self.save_hyperparameters()
        self.resnet = models.resnet18(pretrained=False)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 10)
        self.lr = lr
        self.loss_fn = nn.CrossEntropyLoss()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        return self.resnet(x)
        
    def shared_step(self, batch, split):
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)
        acc = accuracy(preds.softmax(-1), y)
        # self.log('custom_epoch', self.epoch, on_epoch=True, on_step=False, prog_bar=False)
        self.log(f'{split}_loss', loss, on_epoch=True, prog_bar=True)
        self.log(f'{split}_acc', acc, on_epoch=True, prog_bar=True)
        if split == 'train':
            return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        self.shared_step(batch, 'val')
        
    # def validation_epoch_end(self, inputs):
    #     self.epoch += 1

In [5]:
cifar10_dm = CIFAR10DataModule('data/', img_size, val_pct, batch_size)
model = Model(lr) # Model(start_epoch, lr)

In [6]:
wandb_logger = WandbLogger(project='test')
trainer = pl.Trainer(max_epochs=max_epochs, gpus=1, logger=wandb_logger, log_every_n_steps=50, deterministic=True)
trainer.fit(model, cifar10_dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mzer0sh0t[0m (use `wandb login --relogin` to force relogin)



  | Name    | Type             | Params
---------------------------------------------
0 | resnet  | ResNet           | 11.2 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Global seed set to 42




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




In [15]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss_step,0.43711
train_acc_step,0.82812
epoch,2.0
trainer/global_step,1874.0
_runtime,630.0
_timestamp,1624805744.0
_step,42.0
train_loss_epoch,0.59721
train_acc_epoch,0.79397
val_loss,1.19134


0,1
train_loss_step,█▆▇▅▅▆▅▅▄▅▄▄▄▄▄▄▃▅▃▄▃▂▁▃▂▃▂▄▂▂▂▂▃▁▂▁▂
train_acc_step,▁▁▁▃▃▂▄▃▄▃▄▅▅▄▅▅▆▂▆▄▄▇▇▅▆▆▆▅▆▇▇▆▆█▇▇▆
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅█████████████
trainer/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss_epoch,█▃▁
train_acc_epoch,▁▆█
val_loss,█▁▃
