<a href="https://colab.research.google.com/github/wsudswong/SYDE770_Project/blob/main/Pytorch_Lightning_with_Weights_%26_Biases.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -qqq wandb
!pip install -qqq pytorch-lightning

In [None]:
# Weights & Biases
import wandb
from pytorch_lightning.loggers import WandbLogger

# Pytorch modules
import torch
from torch.nn import functional as F
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split

# Pytorch-Lightning
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
import pytorch_lightning as pl

# Dataset
from torchvision.datasets import MNIST
from torchvision import transforms

In [None]:
class LitMNIST(LightningModule):
    '''Fully connected MNIST classifier with two hidden layers.

    The network maps flattened 28x28 images to `n_classes` log-probabilities
    and is trained with negative log-likelihood loss. Loss and accuracy are
    logged for the train, validation and test stages under the keys
    `<stage>_loss` / `<stage>_acc` with stages `train`, `valid` and `test`.
    '''

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        '''Build the layers and store the optimizer settings.

        Args:
            n_classes: number of output classes.
            n_layer_1: width of the first hidden layer.   (2) HIDDEN LAYER SIZE
            n_layer_2: width of the second hidden layer.  (2) HIDDEN LAYER SIZE
            lr: learning rate handed to Adam.             (4) LEARNING RATE
        '''
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, n_layer_1)
        self.layer_2 = torch.nn.Linear(n_layer_1, n_layer_2)
        self.layer_3 = torch.nn.Linear(n_layer_2, n_classes)

        # optimizer parameters
        self.lr = lr

        # optional - save hyper-parameters to self.hparams
        # they will also be automatically logged as config parameters in W&B
        self.save_hyperparameters()

    @staticmethod
    def accuracy(log_probs, targets):
        '''Return the fraction of samples whose arg-max class equals the target.

        Implemented with plain torch ops so no extra metrics package is
        needed (the previous `nn.torchmetrics.accuracy()` does not exist and
        raised AttributeError as soon as the model was constructed).
        '''
        predictions = log_probs.argmax(dim=1)
        return (predictions == targets).float().mean()

    def forward(self, x):
        '''method used for inference input -> output (log-probabilities)'''
        # (b, 1, 28, 28) -> (b, 1*28*28); keep the batch axis, flatten the rest
        x = x.flatten(start_dim=1)
        x = F.relu(self.layer_1(x))  # (3) ACTIVATION FUNCTION ReLU hidden
        x = F.relu(self.layer_2(x))  # (3) ACTIVATION FUNCTION ReLU hidden
        x = self.layer_3(x)

        # (3) ACTIVATION FUNCTION softmax output (in log space, pairs with nll_loss)
        return F.log_softmax(x, dim=1)

    def _shared_step(self, batch, stage):
        '''Run one batch, log `<stage>_loss` and `<stage>_acc`, return the loss.'''
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)

        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_acc', self.accuracy(log_probs, y))
        return loss

    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        '''used for logging metrics (valid_loss / valid_acc)'''
        self._shared_step(batch, 'valid')

    def test_step(self, batch, batch_idx):
        '''used for logging metrics (test_loss / test_acc)'''
        self._shared_step(batch, 'test')

    def configure_optimizers(self):
        '''defines model optimizer'''
        return Adam(self.parameters(), lr=self.lr)  # (7) OPTIMIZER

In [None]:
class MNISTDataModule(LightningDataModule):
    '''Data module that downloads MNIST and serves train/val/test loaders.'''

    def __init__(self, data_dir='./', batch_size=256, val_size=5000):
        '''Store the loader configuration.

        Args:
            data_dir: directory passed to the MNIST dataset as its root.
            batch_size: batch size used by every dataloader.
            val_size: number of training images held out for validation;
                the remaining training images are used for fitting.
        '''
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        '''called only once and on 1 GPU'''
        # download data
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        '''called on each GPU separately - stage defines if we are at fit or test step'''
        # we set up only relevant datasets when stage is specified (automatically set by Pytorch-Lightning)
        if stage == 'fit' or stage is None:
            full_train = MNIST(self.data_dir, train=True, transform=self.transform)
            # derive the training size from the dataset instead of hard-coding it
            n_train = len(full_train) - self.val_size
            self.mnist_train, self.mnist_val = random_split(
                full_train, [n_train, self.val_size]
            )
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        '''returns training dataloader'''
        # reshuffle every epoch so the model does not see identical batches each time
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        '''returns validation dataloader'''
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        '''returns test dataloader'''
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

In [None]:
# Log in to Weights & Biases before any run or sweep is created.
wandb.login()

In [None]:
# Lightning logger that reports to the W&B project named 'MNIST'.
wandb_logger = WandbLogger(project='MNIST')

In [None]:
# setup data (default data_dir and batch_size of MNISTDataModule)
mnist = MNISTDataModule()

# setup model - choose different hyperparameters per experiment
# n_layer_1 / n_layer_2 are the two hidden-layer widths, lr is the learning rate
model = LitMNIST(n_layer_1=128, n_layer_2=256, lr=1e-3) # (2) HIDDEN LAYER SIZE (4) LEARNING RATE

In [None]:
# The `gpus` argument was removed in Lightning 2.0 and `gpus=-1` fails on
# machines without a GPU; `accelerator`/`devices` is the supported spelling.
trainer = Trainer(
    logger=wandb_logger,    # W&B integration
    accelerator='auto',     # use a GPU when one is present, otherwise CPU
    devices='auto',         # use all available devices of that accelerator
    max_epochs=3            # number of epochs
    )

In [None]:
# Train `model` using the loaders provided by the `mnist` data module.
trainer.fit(model, mnist)

In [None]:
# Evaluate the trained model on the data module's test dataloader.
trainer.test(model, datamodule=mnist)

In [None]:
# Close the current W&B run before the sweep below starts new ones.
wandb.finish()

In [None]:
# W&B sweep definition: random search over both hidden-layer widths and the
# learning rate, ranked by the `valid_acc` metric (higher is better).
# NOTE(review): `valid_acc` must match a key the model actually logs via
# `self.log` - confirm, otherwise the sweep has nothing to rank runs by.
sweep_config = dict(
    method="random",
    metric=dict(name="valid_acc", goal="maximize"),
    parameters=dict(
        # hidden-layer widths are drawn from a fixed list of candidates
        n_layer_1=dict(values=[32, 64, 128, 256, 512]),
        n_layer_2=dict(values=[32, 64, 128, 256, 512]),
        # log-uniform between exp(min) and exp(max):
        # exp(-9.21) = 1e-4 and exp(-4.61) = 1e-2
        lr=dict(distribution="log_uniform", min=-9.21, max=-4.61),
    ),
)

In [None]:
# Register the sweep in the "MNIST" project; the returned id is what the agent needs.
sweep_id = wandb.sweep(sweep_config, project="MNIST")

In [None]:
def sweep_iteration():
    '''Train one model with the hyperparameters of the current sweep trial.

    Takes no arguments: the trial's `n_layer_1`, `n_layer_2` and `lr` are
    read from `wandb.config` after `wandb.init()` has been called.
    '''
    # set up W&B logger
    wandb.init()    # required to have access to `wandb.config`
    wandb_logger = WandbLogger()

    # setup data
    mnist = MNISTDataModule()

    # setup model - note how we refer to sweep parameters with wandb.config
    model = LitMNIST(
        n_layer_1=wandb.config.n_layer_1,
        n_layer_2=wandb.config.n_layer_2,
        lr=wandb.config.lr
    )

    # setup Trainer
    # The `gpus` argument was removed in Lightning 2.0 and `gpus=-1` fails on
    # machines without a GPU; `accelerator`/`devices` is the supported spelling.
    trainer = Trainer(
        logger=wandb_logger,    # W&B integration
        accelerator='auto',     # use a GPU when one is present, otherwise CPU
        devices='auto',         # use all available devices of that accelerator
        max_epochs=3            # number of epochs
        )

    # train
    trainer.fit(model, mnist)

In [None]:
# Start an agent that calls `sweep_iteration` once per sweep trial.
# NOTE(review): no `count` is passed, so the number of trials is not bounded
# here - confirm the sweep is meant to be stopped manually.
wandb.agent(sweep_id, function=sweep_iteration)