In [1]:
!pip install -qqq wandb lightning torchmetrics onnx

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m31.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m78.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m61.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.7/15.7 MB[0m [31m86.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.4/195.4 kB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.5/258.5 kB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.9/800.9 kB[0m [31m66.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# numpy for non-GPU array math
import numpy as np

# 🍦 Vanilla PyTorch
import torch
from torch.nn import functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split

# 👀 Torchvision for CV
from torchvision.datasets import MNIST
from torchvision import transforms

# remove slow mirror from list of MNIST mirrors
MNIST.mirrors = [mirror for mirror in MNIST.mirrors
                 if not mirror.startswith("http://yann.lecun.com")]

In [3]:
# ⚡ PyTorch Lightning
import lightning.pytorch as pl
import torchmetrics
pl.seed_everything(hash("setting random seeds") % 2**32 - 1)

# 🏋️‍♀️ Weights & Biases
import wandb

# ⚡ 🤝 🏋️‍♀️
from lightning.pytorch.loggers import WandbLogger


INFO: Seed set to 3526018962
INFO:lightning.fabric.utilities.seed:Seed set to 3526018962


In [4]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [6]:
class LitMLP(pl.LightningModule):

    def __init__(self, in_dims, n_classes=10,
                 n_layer_1=128, n_layer_2=256, lr=1e-4):
        super().__init__()

        # we flatten the input Tensors and pass them through an MLP
        self.layer_1 = nn.Linear(np.prod(in_dims), n_layer_1)
        self.layer_2 = nn.Linear(n_layer_1, n_layer_2)
        self.layer_3 = nn.Linear(n_layer_2, n_classes)

        # log hyperparameters
        self.save_hyperparameters()

        # compute the accuracy -- no need to roll your own!
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)
        self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, x):
        """
        Defines a forward pass using the Stem-Learner-Task
        design pattern from Deep Learning Design Patterns:
        https://www.manning.com/books/deep-learning-design-patterns
        """
        batch_size, *dims = x.size()

        # stem: flatten
        x = x.view(batch_size, -1)

        # learner: two fully-connected layers
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))

        # task: compute class logits
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)

        return x

    # convenient method to get the loss on a batch
    def loss(self, xs, ys):
        logits = self(xs)  # this calls self.forward
        loss = F.nll_loss(logits, ys)
        return logits, loss


In [7]:
def training_step(self, batch, batch_idx):
    xs, ys = batch
    logits, loss = self.loss(xs, ys)
    preds = torch.argmax(logits, 1)

    # logging metrics we calculated by hand
    self.log('train/loss', loss, on_epoch=True)
    # logging a pl.Metric
    self.train_acc(preds, ys)
    self.log('train/acc', self.train_acc, on_epoch=True)

    return loss

def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

LitMLP.training_step = training_step
LitMLP.configure_optimizers = configure_optimizers

In [8]:
def test_step(self, batch, batch_idx):
    xs, ys = batch
    logits, loss = self.loss(xs, ys)
    preds = torch.argmax(logits, 1)

    self.test_acc(preds, ys)
    self.log("test/loss_epoch", loss, on_step=False, on_epoch=True)
    self.log("test/acc_epoch", self.test_acc, on_step=False, on_epoch=True)

LitMLP.test_step = test_step

In [9]:
def on_test_epoch_end(self):  # args are defined as part of pl API
    dummy_input = torch.zeros(self.hparams["in_dims"], device=self.device)
    model_filename = "model_final.onnx"
    self.to_onnx(model_filename, dummy_input, export_params=True)
    artifact = wandb.Artifact(name="model.ckpt", type="model")
    artifact.add_file(model_filename)
    wandb.log_artifact(artifact)

LitMLP.on_test_epoch_end = on_test_epoch_end

In [10]:
def on_validation_epoch_start(self):
    self.validation_step_outputs = []

def validation_step(self, batch, batch_idx):
    xs, ys = batch
    logits, loss = self.loss(xs, ys)
    preds = torch.argmax(logits, 1)
    self.valid_acc(preds, ys)

    self.log("valid/loss_epoch", loss)  # default on val/test is on_epoch only
    self.log('valid/acc_epoch', self.valid_acc)

    self.validation_step_outputs.append(logits)

    return logits

def on_validation_epoch_end(self):

    validation_step_outputs = self.validation_step_outputs

    dummy_input = torch.zeros(self.hparams["in_dims"], device=self.device)
    model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
    torch.onnx.export(self, dummy_input, model_filename, opset_version=11)
    artifact = wandb.Artifact(name="model.ckpt", type="model")
    artifact.add_file(model_filename)
    self.logger.experiment.log_artifact(artifact)

    flattened_logits = torch.flatten(torch.cat(validation_step_outputs))
    self.logger.experiment.log(
        {"valid/logits": wandb.Histogram(flattened_logits.to("cpu")),
         "global_step": self.global_step})

LitMLP.on_validation_epoch_start = on_validation_epoch_start
LitMLP.validation_step = validation_step
LitMLP.on_validation_epoch_end = on_validation_epoch_end

In [11]:
class ImagePredictionLogger(pl.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.val_imgs, self.val_labels = val_samples
        self.val_imgs = self.val_imgs[:num_samples]
        self.val_labels = self.val_labels[:num_samples]

    def on_validation_epoch_end(self, trainer, pl_module):
        val_imgs = self.val_imgs.to(device=pl_module.device)

        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, 1)

        trainer.logger.experiment.log({
            "examples": [wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                            for x, pred, y in zip(val_imgs, preds, self.val_labels)],
            "global_step": trainer.global_step
            })

In [12]:
class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir='./', batch_size=128):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download data, train then test
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # we set up only relevant datasets when stage is specified
        if stage == 'fit' or stage is None:
            mnist = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    # we define a separate DataLoader for each of train/val/test
    def train_dataloader(self):
        mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
        return mnist_train

    def val_dataloader(self):
        mnist_val = DataLoader(self.mnist_val, batch_size=10 * self.batch_size)
        return mnist_val

    def test_dataloader(self):
        mnist_test = DataLoader(self.mnist_test, batch_size=10 * self.batch_size)
        return mnist_test

# setup data
mnist = MNISTDataModule()
mnist.prepare_data()
mnist.setup()

# grab samples to log predictions on
samples = next(iter(mnist.val_dataloader()))

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 5057257.79it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 133886.51it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1258898.33it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4602688.76it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [13]:
wandb_logger = WandbLogger(project="lit-wandb")

In [14]:
trainer = pl.Trainer(
    logger=wandb_logger,    # W&B integration
    log_every_n_steps=50,   # set the logging frequency
    max_epochs=5,           # number of epochs
    deterministic=True,     # keep it deterministic
    callbacks=[ImagePredictionLogger(samples)] # see Callbacks section
    )

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [15]:
# setup model
model = LitMLP(in_dims=(1, 28, 28))

# fit the model
trainer.fit(model, mnist)

# evaluate the model on a test set
trainer.test(datamodule=mnist,
             ckpt_path=None)  # uses last-saved model

wandb.finish()

INFO: You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Currently logged in as: [33msaahilkataria1[0m ([33msaahil1[0m). Use [1m`wandb login --relogin`[0m to force relogin


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name      | Type               | Params
-------------------------------------------------
0 | layer_1   | Linear             | 100 K 
1 | layer_2   | Linear             | 33.0 K
2 | layer_3   | Linear             | 2.6 K 
3 | train_acc | MulticlassAccuracy | 0     
4 | valid_acc | MulticlassAccuracy | 0     
5 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params
0.544     Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
  | Name      | Type               | Params
-------------------------------------------------
0 | layer_1   | Linear             | 100 K 
1 | layer_2   | Linear             | 33.0 K
2 | layer_3   | Linear             | 2.6 K 
3 | train_acc | MulticlassAccuracy | 0     
4 | va

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
INFO: Restoring states from the checkpoint path at ./lit-wandb/h9qsu54c/checkpoints/epoch=4-step=2150.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at ./lit-wandb/h9qsu54c/checkpoints/epoch=4-step=2150.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at ./lit-wandb/h9qsu54c/checkpoints/epoch=4-step=2150.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at ./lit-wandb/h9qsu54c/checkpoints/epoch=4-step=2150.ckpt
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers`

Testing: |          | 0/? [00:00<?, ?it/s]

VBox(children=(Label(value='3.723 MB of 3.723 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▇▇▇▇▇▇▇▇█
global_step,▁▁▂▂▄▄▅▅▇▇██
test/acc_epoch,▁
test/loss_epoch,▁
train/acc_epoch,▁▆▇██
train/acc_step,▁▂▅▃▆▆▆▅▆▆▇▆▇▅▇▆▄▇▇▆▅▇▇▇▇▇█▇▇▇▆▆▆▇▆▆▇▅█▇
train/loss_epoch,█▂▂▁▁
train/loss_step,█▅▃▃▃▂▂▂▂▂▂▂▂▂▂▂▃▂▁▂▂▁▁▂▁▁▁▂▁▁▂▂▂▁▂▂▂▂▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
valid/acc_epoch,▁▄▆▇█

0,1
epoch,5.0
global_step,2150.0
test/acc_epoch,0.9488
test/loss_epoch,0.16834
train/acc_epoch,0.94849
train/acc_step,0.95455
train/loss_epoch,0.17752
train/loss_step,0.13514
trainer/global_step,2150.0
valid/acc_epoch,0.9478
