In [1]:
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
from torch.nn import functional as F
import torchmetrics
import wandb
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split

In [2]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger(project="f5611")

[34m[1mwandb[0m: Currently logged in as: [33mpkantek[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir='../data', batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            cifar10 = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar10, [45000, 5000])
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=10 * self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=10 * self.batch_size, num_workers=4)

In [4]:
def get_acc(n_classes):
    if n_classes > 2:
        return torchmetrics.Accuracy(task="multiclass", num_classes=n_classes)
    return torchmetrics.Accuracy(task="bianry")

# Define the PyTorch Lightning model
class CIFAR10Model(pl.LightningModule):
    def __init__(self, model, in_dims, n_classes=10, lr=1e-4):
        super().__init__()
        self.model = model
        
        self.save_hyperparameters(ignore=['model'])
        
        self.train_acc = get_acc(n_classes)
        self.valid_acc = get_acc(n_classes)
        self.test_acc = get_acc(n_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)
    
        self.log('train/loss', loss, on_epoch=True)
        self.train_acc(preds, y)
        self.log('train/acc', self.train_acc, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)

        self.valid_acc(preds, y)
        self.log("valid/loss_epoch", loss)  # default on val/test is on_epoch only
        self.log('valid/acc_epoch', self.valid_acc)

        return logits

    def validation_epoch_end(self, outputs):
        dummy_input = torch.zeros(self.hparams["in_dims"], device=self.device)
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        torch.onnx.export(self, dummy_input, model_filename, opset_version=11)
        artifact = wandb.Artifact(name="model.ckpt", type="model")
        artifact.add_file(model_filename)
        self.logger.experiment.log_artifact(artifact)

        flattened_logits = torch.flatten(torch.cat(validation_step_outputs))
        self.logger.experiment.log(
            {"valid/logits": wandb.Histogram(flattened_logits.to("cpu")),
             "global_step": self.global_step})
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)

        self.test_acc(preds, y)
        self.log("test/loss_epoch", loss, on_step=False, on_epoch=True)
        self.log("test/acc_epoch", self.test_acc, on_step=False, on_epoch=True)
    
    def test_epoch_end(self, test_step_outputs):  # args are defined as part of pl API
        dummy_input = torch.zeros(self.hparams["in_dims"], device=self.device)
        model_filename = "model_final.onnx"
        self.to_onnx(model_filename, dummy_input, export_params=True)
        artifact = wandb.Artifact(name="model.ckpt", type="model")
        artifact.add_file(model_filename)
        wandb.log_artifact(artifact)
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits, loss = self.loss(x, y)
        preds = torch.argmax(logits, 1)

        self.valid_acc(preds, y)
        self.log("valid/loss_epoch", loss)  # default on val/test is on_epoch only
        self.log('valid/acc_epoch', self.valid_acc)

        return logits

    def validation_epoch_end(self, validation_step_outputs):
        dummy_input = torch.zeros(self.hparams["in_dims"], device=self.device)
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        torch.onnx.export(self, dummy_input, model_filename, opset_version=11)
        artifact = wandb.Artifact(name="model.ckpt", type="model")
        artifact.add_file(model_filename)
        self.logger.experiment.log_artifact(artifact)

        flattened_logits = torch.flatten(torch.cat(validation_step_outputs))
        self.logger.experiment.log(
            {"valid/logits": wandb.Histogram(flattened_logits.to("cpu")),
             "global_step": self.global_step})

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])
    
    def loss(self, x, y):
        logits = self(x)
        cel = torch.nn.CrossEntropyLoss()
        loss = cel(logits, y)
        return logits, loss

In [5]:
class ImagePredictionLogger(pl.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.val_imgs, self.val_labels = val_samples
        self.val_imgs = self.val_imgs[:num_samples]
        self.val_labels = self.val_labels[:num_samples]
          
    def on_validation_epoch_end(self, trainer, pl_module):
        val_imgs = self.val_imgs.to(device=pl_module.device)

        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, 1)

        trainer.logger.experiment.log({
            "examples": [wandb.Image(x, caption=self.labels_to_caption(pred, y)) 
                            for x, pred, y in zip(val_imgs, preds, self.val_labels)],
            "global_step": trainer.global_step
            })
        
    def labels_to_caption(self, pred, y):
        mapping = {0: "airplane", 1:"car", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9: "truck"}
        pred_name = mapping[pred.item()]
        y_name = mapping[y.item()]
        return f"Pred:{pred_name}, Label:{y_name}"

In [8]:
mobilenet = torchvision.models.mobilenet_v3_large(weights=torchvision.models.MobileNet_V3_Large_Weights.IMAGENET1K_V2, progress=True)
mobilenet.classifier[3] = nn.Linear(in_features=1280, out_features=10, bias=True)
mobilenet.train()

In [7]:
# setup data
cifar = CIFAR10DataModule()
cifar.prepare_data()
cifar.setup()

# grab samples to log predictions on
samples = next(iter(cifar.val_dataloader()))

Files already downloaded and verified
Files already downloaded and verified


In [12]:
trainer = pl.Trainer(
    logger=wandb_logger,    # W&B integration
    log_every_n_steps=50,   # set the logging frequency
    max_epochs=5,           # number of epochs
    deterministic=True,     # keep it deterministic
    callbacks=[ImagePredictionLogger(samples)] # see Callbacks section
    )

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# setup model
model = CIFAR10Model(model=mobilenet, in_dims=(1, 3, 32, 32))

# fit the model
trainer.fit(model, cifar)

# evaluate the model on a test set
trainer.test(datamodule=cifar,
             ckpt_path=None)  # uses last-saved model

wandb.finish()

  rank_zero_warn(


Files already downloaded and verified
Files already downloaded and verified



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | MobileNetV3        | 4.2 M 
1 | train_acc | MulticlassAccuracy | 0     
2 | valid_acc | MulticlassAccuracy | 0     
3 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
4.2 M     Trainable params
0         Non-trainable params
4.2 M     Total params
16.859    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

(1, 3, 32, 32)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

(1, 3, 32, 32)


Validation: 0it [00:00, ?it/s]

(1, 3, 32, 32)


Validation: 0it [00:00, ?it/s]

(1, 3, 32, 32)


Validation: 0it [00:00, ?it/s]

(1, 3, 32, 32)


Validation: 0it [00:00, ?it/s]

(1, 3, 32, 32)


`Trainer.fit` stopped: `max_epochs=5` reached.
  rank_zero_warn(


Files already downloaded and verified
Files already downloaded and verified


Restoring states from the checkpoint path at .\f5611\20ew4paf\checkpoints\epoch=4-step=3520.ckpt
Loaded model weights from checkpoint at .\f5611\20ew4paf\checkpoints\epoch=4-step=3520.ckpt


Testing: 0it [00:00, ?it/s]

0,1
epoch,▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████
global_step,▁▁▂▂▄▄▅▅▇▇██
test/acc_epoch,▁
test/loss_epoch,▁
train/acc_epoch,▁▄▆▇█
train/acc_step,▁▁▂▄▄▅▄▄▄▅▅▆▆▆▅▅▆▆▆▆▆▆▆▆▇▅▇▇▇██▆▇▇██▇███
train/loss_epoch,█▅▃▂▁
train/loss_step,█▇▇▅▅▄▅▄▄▄▄▃▄▃▄▄▃▂▃▂▃▃▂▃▂▃▂▂▂▁▁▂▂▂▁▁▂▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
valid/acc_epoch,▁▆▇██

0,1
epoch,4.0
global_step,3520.0
test/acc_epoch,0.6718
test/loss_epoch,1.10433
train/acc_epoch,0.83258
train/acc_step,0.82812
train/loss_epoch,0.48509
train/loss_step,0.43491
trainer/global_step,3520.0
valid/acc_epoch,0.6588


In [15]:
wandb.finish()