<a href="https://colab.research.google.com/github/therealmolf/personal-ml/blob/main/Lightning_and_wandb_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q pytorch-lightning wandb

In [None]:
import wandb

# Authenticate this session with Weights & Biases before any logger is created.
# The recorded output of this cell is `True`.
wandb.login()



True

In [None]:
# Debugging aid: list the working directory, including hidden entries.
!ls -a 

.  ..  .config	MNIST  sample_data  wandb


In [None]:
!cat /root/.netrc

machine api.wandb.ai
  login user
  password <redacted>



We use a vanilla PyTorch dataloader and the canonical MNIST dataset


In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


# Convert images to tensors, then normalise the single channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean and std — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Downloads into ./MNIST on first use.
dataset = MNIST(root="./MNIST", download=True, transform=transform)

# Hold out a fixed number of samples for validation; the training size is
# derived from the dataset length so the two sizes always add up.
N_VALIDATION = 5_000
# Seeding the generator makes the train/validation split identical on every run.
SPLIT_SEED = 42

training_set, validation_set = random_split(
    dataset,
    [len(dataset) - N_VALIDATION, N_VALIDATION],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Exploration aid: show the source of random_split (safe to remove from a finished notebook).
??random_split

In [None]:
# Batches of 64 for both splits; only the training loader shuffles its samples.
training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)

In [None]:
# NOTE(review): `Linear` is only imported in the model cell further down, so on a
# fresh kernel run top-to-bottom this lookup has nothing to show.
??Linear

Defining Our Model

In [None]:
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule


class MNIST_LitModule(LightningModule):
  '''
    Fully connected classifier for 28x28 single-channel images.

    Architecture: 784 -> n_layer_1 -> n_layer_2 -> n_classes, with a ReLU
    after each of the two hidden layers and raw logits as the output.

    Args:
      n_classes: number of target classes (width of the output layer)
      n_layer_1: width of the first hidden layer
      n_layer_2: width of the second hidden layer
      lr: learning rate passed to Adam in `configure_optimizers`
  '''

  def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
    '''
      method used to define our model parameters
    '''
    super().__init__()

    # mnist images are (1, 28, 28) (channels, height, width)
    self.layer_1 = Linear(28*28, n_layer_1)
    self.layer_2 = Linear(n_layer_1, n_layer_2)
    self.layer_3 = Linear(n_layer_2, n_classes)

    # loss
    self.loss = CrossEntropyLoss()

    # lr
    self.lr = lr

    # kept so the accuracy metric always matches the width of the output layer
    self.n_classes = n_classes

    # store the __init__ arguments in self.hparams so the attached logger can record them
    self.save_hyperparameters()

  def forward(self, x):
    '''
      computes class logits for a batch of images

      x: tensor whose first dimension is the batch; the remaining
         dimensions are flattened and must total 28*28 values
      returns: tensor of shape (batch, n_classes) holding raw logits
    '''

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(x.size(0), -1)

    # 2 x (linear + relu), then a final linear layer producing the logits
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)
    x = F.relu(x)
    x = self.layer_3(x)

    return x

  def training_step(self, batch, batch_idx):
    '''
     needs to return a loss from a single batch
    '''
    _, loss, acc = self._get_preds_loss_accuracy(batch)

    # log loss and metric
    self.log('train_loss', loss)
    self.log('train_accuracy', acc)

    return loss

  def validation_step(self, batch, batch_idx):
    '''used for logging metrics'''
    preds, loss, acc = self._get_preds_loss_accuracy(batch)

    # Log loss and metric
    self.log('val_loss', loss)
    self.log('val_accuracy', acc)

    # Let's return preds to use it in a custom callback
    return preds

  def test_step(self, batch, batch_idx):
    '''used for logging metrics'''
    _, loss, acc = self._get_preds_loss_accuracy(batch)

    # Log loss and metric
    self.log('test_loss', loss)
    self.log('test_accuracy', acc)

  def configure_optimizers(self):
    '''
      defines model optimizer
    '''
    return Adam(self.parameters(), lr=self.lr)

  def _get_preds_loss_accuracy(self, batch):
    '''
      convenience function since train/valid/test steps are similar

      batch: an (inputs, targets) pair
      returns: (predicted class indices, loss, accuracy) for the batch
    '''

    x, y = batch
    logits = self(x)
    preds = torch.argmax(logits, dim=1)
    loss = self.loss(logits, y)
    # use the configured class count rather than a hard-coded 10
    acc = accuracy(preds, y, task="multiclass", num_classes=self.n_classes)

    return preds, loss, acc


In [None]:
# Exploration aid: show the docstring of torchmetrics' functional `accuracy`.
?accuracy

In [None]:
# Linear takes the number of input features and the number of output features
# (in_features, out_features); `??` shows its source.
??Linear

In [None]:
# Exploration aid: look up torch.Size, the type returned by `x.size()` in forward().
??torch.Size

In [None]:
%%time
# Instantiate the model with both hidden layers 128 units wide
# (n_layer_2 overrides its default of 256); %%time reports the construction cost.
model = MNIST_LitModule(n_layer_1=128, n_layer_2=128)

CPU times: user 1.89 ms, sys: 994 µs, total: 2.89 ms
Wall time: 6.91 ms


In [None]:
%%time 
# Sanity check: display the first layer to confirm its in/out feature counts.
model.layer_1

CPU times: user 32 µs, sys: 3 µs, total: 35 µs
Wall time: 38.4 µs


Linear(in_features=784, out_features=128, bias=True)

### Experiment Tracking

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# this is required along with WandbLogger to log checkpoints to W&B.
# Checkpoints are ranked by the 'val_accuracy' value logged in validation_step;
# mode='max' means a higher value counts as better.
checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')

* In W&B, runs are grouped into a project, and both the project and each run can be named
* A single training run can record histograms of gradients and parameters, hyperparameters, and custom objects

In [None]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer


# Logger that sends this notebook's metrics to the W&B project named 'MNIST'.
# NOTE(review): log_model='all' presumably uploads every checkpoint as it is
# written rather than only at the end of training — confirm against the WandbLogger docs.
wandb_logger = WandbLogger(
    project='MNIST',
    log_model='all'
)

  rank_zero_warn(


Callback for automatically logging sample predictions during validation
* WandbLogger provides convenient media logging functions
* WandbLogger.log_text
* WandbLogger.log_image
* WandbLogger.log_table

In [None]:
from pytorch_lightning.callbacks import Callback

class LogPredictionsCallback(Callback):
  """
    Logs a handful of sample predictions from the first validation batch.

    Args:
      n_samples: how many images from the first batch to log (default 20)
  """

  def __init__(self, n_samples=20):
    super().__init__()
    self.n_samples = n_samples

  def on_validation_batch_end(
      self,
      trainer,
      pl_module,
      outputs,
      batch,
      batch_idx,
      dataloader_idx=0
  ):
    """
      Called when validation batch ends

      `outputs` come from LightningModule.validation_step
      which corresponds to our model predictions in this case

      `dataloader_idx` defaults to 0 so the hook works whether or not the
      installed Lightning version passes it.
    """

    # only the first batch of each validation run is logged
    if batch_idx != 0:
      return

    n = self.n_samples
    x, y = batch
    images = list(x[:n])
    # plain Python ints so captions read "7" rather than "tensor(7)"
    truths = y[:n].tolist()
    predictions = outputs[:n].tolist()
    captions = [f'Ground Truth: {y_i} - Prediction {y_pred}' for y_i, y_pred in zip(truths, predictions)]

    # Use the logger attached to the trainer instead of a notebook global.
    # Assumes it is a WandbLogger, which provides log_image / log_table.
    logger = trainer.logger

    # Option 1: log images with WandbLogger
    logger.log_image(key='sample_images', images=images, caption=captions)

    # Option 2: log predictions as a Table
    columns = ['image', 'ground truth', 'prediction']
    data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in zip(images, truths, predictions)]
    logger.log_table(key='sample_table', columns=columns, data=data)

log_predictions_callback = LogPredictionsCallback()


Actual Training

In [None]:
# Five-epoch run wired to the W&B logger, with the sample-prediction and
# checkpoint callbacks attached in that order.
trainer = Trainer(
    max_epochs=5,
    logger=wandb_logger,
    callbacks=[log_predictions_callback, checkpoint_callback],
)



INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Train `model` on the training loader, running validation on the validation loader.
trainer.fit(model, training_loader, validation_loader)

INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params
---------------------------------------------
0 | layer_1 | Linear           | 100 K 
1 | layer_2 | Linear           | 16.5 K
2 | layer_3 | Linear           | 1.3 K 
3 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
118 K     Trainable params
0         Non-trainable params
118 K     Total params
0.473     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


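`wandb.finish()` is especially useful in notebooks: in a script it is called automatically when the script exits, but a notebook kernel keeps running, so the run has to be closed explicitly.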


In [None]:
# Explicitly end the W&B run; the recorded output of this cell shows the upload
# progress and the run history/summary.
wandb.finish()

VBox(children=(Label(value='6.946 MB of 6.961 MB uploaded (0.038 MB deduped)\r'), FloatProgress(value=0.997877…

0,1
epoch,▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████
train_accuracy,▂▆▃▁█▆▄▆▇▅█▇▅▅▅█▆▇█▅▇▇▅▅██▅▇▆▆▇▇▇▇▆▆▆▆▆█
train_loss,█▃▆▆▂▃▆▂▂▅▁▂▃▃▅▁▂▄▁▃▂▁▃▃▁▁▃▂▂▂▂▂▂▂▂▃▂▄▄▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
val_accuracy,▁▄▆▇█
val_loss,█▅▃▁▁

0,1
epoch,4.0
train_accuracy,0.95833
train_loss,0.23112
trainer/global_step,4299.0
val_accuracy,0.975
val_loss,0.09263


In [None]:
??torchmetrics.functional.accuracy()

Object `torchmetrics.functional.accuracy()` not found.
