<a href="https://colab.research.google.com/github/timsetsfire/wandb-examples/blob/main/colab/Sweeps_with_Pytorch_Lightning_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://wandb.me/logo-im-png" width="400" alt="Weights & Biases" />

<!--- @wandbcode{pytorch-lightning-colab} -->

# ⚡ PyTorch Lightning models with Weights & Biases

### 🛠️ Installation and set-up

In [3]:
!pip install -q pytorch-lightning wandb


We make sure we're logged into W&B so that our experiments can be associated with our account.

In [4]:
import wandb
wandb.login()

wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## 📊 Setting up the dataloader

For this tutorial we use plain PyTorch `DataLoader`s on the MNIST dataset.

In [5]:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
import os
import glob

transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

dataset = MNIST(root="./MNIST", download=True, transform=transform)
training_set, validation_set = random_split(dataset, [55000, 5000])
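As a quick check of what `transforms.Normalize((0.1307,), (0.3081,))` does to pixel values that `ToTensor()` has already scaled to [0, 1], the arithmetic can be sketched in plain Python. The helper name `normalize` is hypothetical and only for illustration; the constants and the 55,000/5,000 split come from the cell above.

```python
# Per-pixel arithmetic applied by Normalize(mean, std): (x - mean) / std.
MEAN, STD = 0.1307, 0.3081  # constants from the transform above


def normalize(pixel: float) -> float:
    """Illustrative helper (not part of torchvision)."""
    return (pixel - MEAN) / STD


black = normalize(0.0)  # darkest pixel after ToTensor()
white = normalize(1.0)  # brightest pixel after ToTensor()
print(f"black -> {black:.4f}, white -> {white:.4f}")
# prints: black -> -0.4242, white -> 2.8215

# random_split needs lengths that add up to the dataset size;
# the MNIST training split holds 60,000 images.
train_len, val_len = 55000, 5000
assert train_len + val_len == 60000
```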



## 🤓 Defining the Model

**Tips**:
* Call `self.save_hyperparameters()` in `__init__` to automatically log your hyperparameters to **W&B**
* Call `self.log` in `training_step` and `validation_step` to log metrics

In [6]:
from torch.optim import SGD
from torch.optim import Adagrad

In [7]:
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule

class MNIST_LitModule(LightningModule):

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3, optim = "Adam"):
        '''method used to define our model parameters'''
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer parameters
        self.lr = lr

        # save hyper-parameters to self.hparams (auto-logged by W&B)
        self.save_hyperparameters()

        self.optim = optim

    def forward(self, x):
        '''method used for inference input -> output'''

        batch_size, channels, height, width = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # 2 x (linear + relu), then a final linear layer that returns logits
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        return x

    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)

        return loss

    def validation_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        # Let's return preds to use it in a custom callback
        return preds

    def test_step(self, batch, batch_idx):
        '''used for logging metrics'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)
    
    def configure_optimizers(self):
        '''defines model optimizer'''
        if self.optim == "Adam":
            return Adam(self.parameters(), lr=self.lr)
        elif self.optim == "SGD":
            return SGD(self.parameters(), lr=self.lr)
        elif self.optim == "Adagrad":
            return Adagrad(self.parameters(), lr=self.lr)
        # fail loudly instead of silently returning None
        raise ValueError(f"Unsupported optimizer: {self.optim!r}")
    
    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y)
        return preds, loss, acc
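To make the helper `_get_preds_loss_accuracy` more concrete, here is a minimal pure-Python sketch of the three quantities it returns, computed on a tiny made-up batch. The logits and targets are hypothetical values chosen for illustration, and the function `cross_entropy` is a hand-written stand-in that assumes the default mean reduction of `CrossEntropyLoss`; it is not the library implementation.

```python
import math

# Hypothetical logits for a batch of 3 examples over 4 classes (made-up values).
logits = [
    [2.0, 0.5, 0.1, -1.0],
    [0.2, 0.1, 3.0, 0.0],
    [0.3, 1.5, 0.2, 0.9],
]
targets = [0, 2, 3]

# Counterpart of torch.argmax(logits, dim=1): index of the largest logit per row.
preds = [max(range(len(row)), key=row.__getitem__) for row in logits]


def cross_entropy(row, target):
    """Negative log-softmax of the target class for one example."""
    log_sum_exp = math.log(sum(math.exp(v) for v in row))
    return log_sum_exp - row[target]


# Mean over the batch, assuming the default 'mean' reduction.
loss = sum(cross_entropy(r, t) for r, t in zip(logits, targets)) / len(targets)

# Accuracy: fraction of predictions that match the targets.
acc = sum(p == t for p, t in zip(preds, targets)) / len(targets)

print(preds, round(acc, 4))
```

In this toy batch the third example is misclassified (predicted class 1, target 3), so two of three predictions are correct.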

## 💾 Save Model Checkpoints

To log model checkpoints to W&B, use the `ModelCheckpoint` callback together with the `log_model` argument of `WandbLogger`.

## 💡 Tracking Experiments with WandbLogger

PyTorch Lightning has a `WandbLogger` to easily log your experiments with Weights & Biases. Just pass it to your `Trainer` to log to W&B. See the [WandbLogger docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) for all parameters. To log metrics to a specific W&B Team, pass your Team name to the `entity` argument of `WandbLogger`.

#### `pytorch_lightning.loggers.WandbLogger()`

| Functionality | Argument/Function | Notes |
| ------ | ------ | ------ |
| Logging models | `WandbLogger(..., log_model='all')` or `WandbLogger(..., log_model=True)` | Logs all models if `log_model='all'`, or only at the end of training if `log_model=True` |
| Set custom run names | `WandbLogger(..., name='my_run_name')` | |
| Organize runs by project | `WandbLogger(..., project='my_project')` | |
| Log histograms of gradients and parameters | `WandbLogger.watch(model)` | `WandbLogger.watch(model, log='all')` to log parameter histograms |
| Log hyperparameters | Call `self.save_hyperparameters()` within `LightningModule.__init__()` | |
| Log custom objects (images, audio, video, molecules…) | Use `WandbLogger.log_text`, `WandbLogger.log_image` and `WandbLogger.log_table` | |


In [None]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

## ⚙️ Using WandbLogger to log Images, Text and More
PyTorch Lightning is extensible through its callback system. We can create a custom callback to automatically log sample predictions during validation. `WandbLogger` provides convenient media logging functions:
* `WandbLogger.log_text` for text data
* `WandbLogger.log_image` for images
* `WandbLogger.log_table` for [W&B Tables](https://docs.wandb.ai/guides/data-vis).

An alternative to `self.log` in the model class is to call `wandb.log({dict})` or `trainer.logger.experiment.log({dict})` directly.

In this case we log the first 20 images in the first batch of the validation dataset along with the predicted and ground truth labels.

In [6]:
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import ModelCheckpoint
class LogPredictionsCallback(Callback):
    
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Called when the validation batch ends."""
 
        # `outputs` comes from `LightningModule.validation_step`
        # which corresponds to our model predictions in this case
        
        # Let's log 20 sample image predictions from first batch
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [f'Ground Truth: {y_i} - Prediction: {y_pred}' for y_i, y_pred in zip(y[:n], outputs[:n])]
            
            # Option 1: log images with `WandbLogger.log_image`
            wandb_logger.log_image(key='sample_images', images=images, caption=captions)

            # Option 2: log predictions as a Table
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs[:n]))]
            wandb_logger.log_table(key='sample_table', columns=columns, data=data)



## 🏋️‍ Train Your Model

In [None]:
# from pytorch_lightning.loggers import WandbLogger

# config = {"batch_size": 64,  # try log-spaced values from 1 to 50,000
#           "num_workers": os.cpu_count(),  # try 0, 1, and 2
#           "pin_memory": True,  # try False and True
#           "precision": 32,  # try 16 and 32
#           "optimizer": "Adam", 
#           "learning_rate": 0.001
#           }

# wandb_logger = WandbLogger(entity = 'tim-w', project='MNIST-v4', # group runs in "MNIST" project
#                            log_model='all', save_code=True, config=config, sync_tensorboard=True) # log all new checkpoints during training

# training_loader = DataLoader(training_set, batch_size=64, shuffle=True, pin_memory=True)
# validation_loader = DataLoader(validation_set, batch_size=64, pin_memory=True)
# ## Using a raw DataLoader, rather than LightningDataModule, for greater transparency

# # Set up model
# model = MNIST_LitModule(n_layer_1=128, n_layer_2=128, optim = wandb.config.optimizer, lr = wandb.config.learning_rate)
# wandb.watch(model)
# trainer = Trainer(gpus=None, max_epochs=5, profiler="pytorch",logger=wandb_logger,
#                           callbacks=[
#                                      log_predictions_callback, 
#                                      checkpoint_callback
#                                      ], 
#                       precision=32)
# trainer.profiler.dirpath="./wandb/latest-run/tbprofile"
# trainer.fit(model, training_loader, validation_loader)
# # trace_files = glob.glob("/content/lightning_logs/*.pt.trace.json")
# trace_files = glob.glob("./wandb/latest-run/tbprofile/*.pt.trace.json")
# for i, trace_file in enumerate(trace_files):
#     if "training_step" in trace_file:
#       profile_art = wandb.Artifact(f"train-trace{i}-{wandb.run.id}", type="profile")
#       profile_art.add_file(trace_file, "train_trace.pt.trace.json")
#     else:
#       profile_art = wandb.Artifact(f"validation-trace{i}-{wandb.run.id}", type="profile")
#       profile_art.add_file(trace_file, "validation_trace.pt.trace.json")
#     wandb.log_artifact(profile_art)
# wandb.finish()

In [16]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

sweep_config = {
  'method': 'grid', 
  'metric': {
      'name': 'val_loss',  # must match a metric name logged with self.log
      'goal': 'minimize'
  },
  'early_terminate':{
      'type': 'hyperband',
      'min_iter': 5
  },
  'parameters': {
      'learning_rate':{
          'values': [0.05,0.025,0.01,0.005,0.001]
      }, 
      'batch_size': { 
          'values': [128, 256]
      }
  }
}


def sweep_train(config_defaults = dict(learning_rate=0.01, batch_size = 128)): 

  config_standard = {
          "num_workers": os.cpu_count(),  # try 0, 1, and 2
          "pin_memory": True,  # try False and True
          "precision": 32,  # try 16 and 32
          "optimizer": "Adam", 
          }
  
  config = {**config_defaults, **config_standard}
  
  wandb_logger = WandbLogger(entity='tim-w', project='MNIST-v4',  # group runs in the "MNIST-v4" project
                             log_model='all', save_code=True, config=config, sync_tensorboard=True)  # log all new checkpoints during training

  class LogPredictionsCallback(Callback):
      
      def on_validation_batch_end(
          self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
          """Called when the validation batch ends."""
  
          # `outputs` comes from `LightningModule.validation_step`
          # which corresponds to our model predictions in this case
          
          # Let's log 20 sample image predictions from first batch
          if batch_idx == 0:
              n = 20
              x, y = batch
              images = [img for img in x[:n]]
              captions = [f'Ground Truth: {y_i} - Prediction: {y_pred}' for y_i, y_pred in zip(y[:n], outputs[:n])]
              
              # Option 1: log images with `WandbLogger.log_image`
              wandb_logger.log_image(key='sample_images', images=images, caption=captions)

              # Option 2: log predictions as a Table
              columns = ['image', 'ground truth', 'prediction']
              data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs[:n]))]
              wandb_logger.log_table(key='sample_table', columns=columns, data=data)  

  log_predictions_callback = LogPredictionsCallback()
  checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')

  training_loader = DataLoader(training_set, batch_size=wandb.config.batch_size, shuffle=True, pin_memory=True)
  validation_loader = DataLoader(validation_set, batch_size=64, pin_memory=True)
  ## Using a raw DataLoader, rather than LightningDataModule, for greater transparency

  # Set up model
  model = MNIST_LitModule(n_layer_1=128, n_layer_2=128, optim = wandb.config.optimizer, lr = wandb.config.learning_rate)
  wandb.watch(model)
  trainer = Trainer(gpus=None, max_epochs=5, logger=wandb_logger,
                    callbacks=[
                        log_predictions_callback,
                        checkpoint_callback,  # monitors val_accuracy (mode='max')
                    ],
                    precision=32)
  trainer.fit(model, training_loader, validation_loader)
  wandb.finish()
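A `grid` sweep tries every combination of the listed values, so the search space in `sweep_config` contains 5 × 2 = 10 configurations. The sketch below enumerates them with `itertools.product` and merges each one with the fixed settings, in the same way `sweep_train` merges `config_defaults` with `config_standard`. It is plain Python for illustration only: the ordering is that of `itertools`, not necessarily the order in which W&B schedules runs, and `num_workers` is left out because it depends on the machine.

```python
from itertools import product

# Search space copied from sweep_config['parameters'] above.
parameters = {
    "learning_rate": {"values": [0.05, 0.025, 0.01, 0.005, 0.001]},
    "batch_size": {"values": [128, 256]},
}

# Fixed settings, as in config_standard (num_workers omitted: machine-dependent).
config_standard = {"pin_memory": True, "precision": 32, "optimizer": "Adam"}

names = list(parameters)
grid = [
    # Later dicts win on key clashes; here the key sets do not overlap.
    {**dict(zip(names, combo)), **config_standard}
    for combo in product(*(parameters[n]["values"] for n in names))
]

print(len(grid))  # 5 learning rates x 2 batch sizes
print(grid[0])
```

With `count = 1`, as used in the agent call later in this notebook, only one of these ten configurations is actually run.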

In [17]:
sweep_train()



INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params
---------------------------------------------
0 | layer_1 | Linear           | 100 K 
1 | layer_2 | Linear           | 16.5 K
2 | layer_3 | Linear           | 1.3 K 
3 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
118 K     Trainable params
0         Non-trainable params
118 K     Total params
0.473     Total estimated model params size (MB)
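The parameter counts in the summary above can be reproduced by hand. The sketch below assumes the layer sizes passed in `sweep_train` (`n_layer_1=128`, `n_layer_2=128`, 10 classes), 4 bytes per parameter (float32), and 1 MB = 10^6 bytes; the helper `linear_params` is a hypothetical name used only here.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


# Sizes used by sweep_train: n_layer_1=128, n_layer_2=128, n_classes=10.
layer_1 = linear_params(28 * 28, 128)
layer_2 = linear_params(128, 128)
layer_3 = linear_params(128, 10)
total = layer_1 + layer_2 + layer_3

# Assumes 4 bytes per parameter (float32) and 1 MB = 10**6 bytes.
size_mb = total * 4 / 1e6

print(layer_1, layer_2, layer_3, total, round(size_mb, 3))
# prints: 100480 16512 1290 118282 0.473
```

These values line up with the rounded figures in the summary (100 K, 16.5 K, 1.3 K, 118 K and 0.473 MB).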



INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.



Run history:

| Metric | History |
| ------ | ------ |
| epoch | ▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████ |
| train_accuracy | ▂▄▆▃▁▄▃▄▅▄▅▄▄▇▇▆▇▅▇▄▆▆▇▅▅▅▆▆█▆▇▆█▇█▅▇▅▅▇ |
| train_loss | ▇▅▃▆▇▃▆▆▄▅▃▃▃▂▁▂▂▄▂▅▆▃▂▅█▅▄▃▁▂▃▃▁▃▂▂▂▂▃▁ |
| trainer/global_step | ▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇████ |
| val_accuracy | ▁▆█▇█ |
| val_loss | █▅▁▂▂ |

Run summary:

| Metric | Value |
| ------ | ------ |
| epoch | 4.0 |
| train_accuracy | 0.97727 |
| train_loss | 0.05155 |
| trainer/global_step | 2149.0 |
| val_accuracy | 0.9608 |
| val_loss | 0.15138 |
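The `trainer/global_step` value in the summary can be sanity-checked with a little arithmetic. The sketch below assumes the 55,000-sample training split, the default `batch_size` of 128 from `config_defaults`, and that the final partial batch is kept (the `DataLoader` default). Reading 2149 as a zero-based index of the last step, and 0.97727 as 86 correct out of 88, are inferences from the numbers rather than something the output states.

```python
import math

train_samples = 55000
batch_size = 128  # default in sweep_train's config_defaults
max_epochs = 5

# The last batch of each epoch is partial, so round up.
steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * max_epochs

# Size of the final, partial batch in each epoch.
last_batch_size = train_samples - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch, total_steps, last_batch_size)
# prints: 430 2150 88

# 86 correct predictions out of 88 would give the reported train_accuracy.
print(round(86 / 88, 5))
```

The total of 2150 steps is one more than the reported 2149, which is what a step counter starting at zero would show.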


In [18]:
sweep_id = wandb.sweep(sweep_config, project="ptl-sweeps-example")

Create sweep with ID: 8iotwmxa
Sweep URL: https://wandb.ai/tim-w/MNIST-v4/sweeps/8iotwmxa


In [19]:
wandb_agent = wandb.agent(sweep_id, function=sweep_train, count = 1)

[34m[1mwandb[0m: Agent Starting Run: p5o450ok with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	learning_rate: 0.05
Hint: Upgrade with `pip install --upgrade wandb`.
  f"Providing log_model={log_model} requires wandb version >= 0.10.22"


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params
---------------------------------------------
0 | layer_1 | Linear           | 100 K 
1 | layer_2 | Linear           | 16.5 K
2 | layer_3 | Linear           | 1.3 K 
3 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
118 K     Trainable params
0         Non-trainable params
118 K     Total params
0.473     Total estimated model params size (MB)



INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.



Run history:

| Metric | History |
| ------ | ------ |
| epoch | ▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████ |
| train_accuracy | ▁▄▃▂▄▄▃▄▅▂▃▄▄▄█▇▅▆▅▄▆▆▆▆▆▇▅▄▆▅▅▅▆▆▅▅▆▅▅█ |
| train_loss | █▅▆▇▆▄▄▄▃▆▆▄▅▃▁▂▃▃▄▄▂▂▃▂▂▃▃▅▄▃▃▄▂▄▃▄▃▃▄▂ |
| trainer/global_step | ▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇████ |
| val_accuracy | ▁█▇▇▅ |
| val_loss | █▂▁▁▃ |

Run summary:

| Metric | Value |
| ------ | ------ |
| epoch | 4.0 |
| train_accuracy | 0.75 |
| train_loss | 0.83828 |
| trainer/global_step | 2149.0 |
| val_accuracy | 0.6086 |
| val_loss | 1.02124 |
