### CNN for CIFAR10

CNN model that can be used for classification tasks. 

In this demo, we will train a 3-layer CNN on the CIFAR10 dataset. 

Let us first import the required modules.

In [4]:
import torch
import torchvision
import wandb
import math
from torch import nn
from einops import rearrange
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

#### CNN using PyTorch `nn.Conv2D`




In [18]:
class SimpleCNN(nn.Module):
    def __init__(self, n_features=3, kernel_size=3, n_filters=32, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(n_features, n_filters, kernel_size=kernel_size)
        self.conv2 = nn.Conv2d(n_filters, n_filters*2, kernel_size=kernel_size)
        self.conv3 = nn.Conv2d(n_filters*2, n_filters*4, kernel_size=kernel_size)
        self.fc1 = nn.Linear(2048, num_classes)

    def forward(self, x):
        y = nn.ReLU()(self.conv1(x))
        y = nn.MaxPool2d(kernel_size=2)(y)
        y = nn.ReLU()(self.conv2(y))
        y = nn.MaxPool2d(kernel_size=2)(y)
        y = nn.ReLU()(self.conv3(y))
        y = y.view(y.size(0), -1)
        y = self.fc1(y)
        return y
        # we dont need to compute softmax since it is already
        # built into the CE loss function in PyTorch
        #return F.log_softmax(y, dim=1)


# use this to get the correct input shape for  fc1. In this case,
# comment out y=self.fc1(y) and run the code below.
#model = SimpleCNN()
#data = torch.Tensor(1, 3, 32, 32)
#y = model(data)
#print("Y.shape:", y.shape)

#### PyTorch Lightning Module for CNN

This is the PL module so we can easily change the implementation of the CNN and compare the results.  More detailed results can be found on the `wandb.ai` page.

Using `model` parameter, we can easily switch between different model implementations. We also benchmark the result using a ResNet18 model. 

In [19]:
class LitCIFAR10Model(LightningModule):
    def __init__(self, num_classes=10, lr=0.001, batch_size=64,
                 num_workers=4, max_epochs=30,
                 model=SimpleCNN):
        super().__init__()
        self.save_hyperparameters()
        self.model = model(num_classes=num_classes)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    # this is called during fit()
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        return {"loss": loss}

    # calls to self.log() are recorded in wandb
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, on_epoch=True)

    # this is called at the end of an epoch
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        acc = accuracy(y_hat, y) * 100.
        # we use y_hat to display predictions during callback
        return {"y_hat": y_hat, "test_loss": loss, "test_acc": acc}

    # this is called at the end of all epochs
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", avg_acc, on_epoch=True, prog_bar=True)

    # validation is the same as test
    def validation_step(self, batch, batch_idx):
       return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        return self.test_epoch_end(outputs)

    # we use Adam optimizer
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        # this decays the learning rate to 0 after max_epochs using cosine annealing
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]
    
    # this is called after model instatiation to initiliaze the datasets and dataloaders
    def setup(self, stage=None):
        self.train_dataloader()
        self.test_dataloader()

    # build train and test dataloaders using MNIST dataset
    # we use simple ToTensor transform
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR10(
                "./data", train=True, download=True, 
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR10(
                "./data", train=False, download=True, 
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        return self.test_dataloader()

#### Arguments

Please change the `--model` argument to switch between the different models to be used as CIFAR10 classifier.

In [28]:
def get_args():
    parser = ArgumentParser(description="PyTorch Lightning MNIST Example")
    parser.add_argument("--max-epochs", type=int, default=30, help="num epochs")
    parser.add_argument("--batch-size", type=int, default=64, help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")

    parser.add_argument("--num-classes", type=int, default=10, help="num classes")

    parser.add_argument("--devices", default=1)
    parser.add_argument("--accelerator", default='gpu')
    parser.add_argument("--num-workers", type=int, default=4, help="num workers")
    
    #parser.add_argument("--model", default=torchvision.models.resnet18)
    parser.add_argument("--model", default=SimpleCNN)
    args = parser.parse_args("")
    return args

#### Weights and Biases Callback

The callback logs train and validation metrics to `wandb`. It also logs sample predictions. This is similar to our `WandbCallback` example for MNIST.

In [29]:
class WandbCallback(Callback):

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # process first 10 images of the first batch
        if batch_idx == 0:
            label_human = ["airplane", "automobile", "bird", "cat",
                           "deer", "dog", "frog", "horse", "ship", "truck"]
            n = 10
            x, y = batch
            outputs = outputs["y_hat"]
            outputs = torch.argmax(outputs, dim=1)
            # log image, ground truth and prediction on wandb table
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), label_human[y_i], label_human[y_pred]] for x_i, y_i, y_pred in list(
                zip(x[:n], y[:n], outputs[:n]))]
            wandb_logger.log_table(
                key=pl_module.model.__class__.__name__,
                columns=columns,
                data=data)

#### Training and Validation of Different Models

The validation accuracy of our `SimpleCNN` at `~73%`. 

Meanwhile the ResNet18 model has accuracy of `~78%`. Consider that `SimpleCNN` uses 113k parameters while `ResNet18` uses 11.2M parameters. `SimpleCNN` is very efficient. Recall that our `SimpleMLP` hsa accuracy of `~53%` only.

In [30]:
if __name__ == "__main__":
    args = get_args()
    model = LitCIFAR10Model(num_classes=args.num_classes,
                           lr=args.lr, batch_size=args.batch_size,
                           num_workers=args.num_workers,
                           model=args.model,)
    model.setup()

    # printing the model is useful for debugging
    print(model)
    print(model.model.__class__.__name__)

    

Files already downloaded and verified
Files already downloaded and verified
LitCIFAR10Model(
  (model): SimpleCNN(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (fc1): Linear(in_features=2048, out_features=10, bias=True)
  )
  (loss): CrossEntropyLoss()
)
SimpleCNN


In [31]:


    # wandb is a great way to debug and visualize this model
    wandb_logger = WandbLogger(project="cnn-cifar")
    
    trainer = Trainer(accelerator=args.accelerator,
                      devices=args.devices,
                      max_epochs=args.max_epochs,
                      logger=wandb_logger,
                      callbacks=[WandbCallback()])
    trainer.fit(model)
    trainer.test(model)

    wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type             | Params
-------------------------------------------
0 | model | SimpleCNN        | 113 K 
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
113 K     Trainable params
0         Non-trainable params
113 K     Total params
0.455     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  " and ".join(warn_msg) + " are deprecated. nn.Module.state_dict will not accept them in the future. "


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Files already downloaded and verified


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc             73.17874145507812
        test_loss           0.8493992686271667
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────



VBox(children=(Label(value='0.077 MB of 0.077 MB uploaded (0.023 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
test_acc,▁▃▄▅▆▆▆▇▇▇▇▇▇▇▇▇▇██████████████
test_loss,█▆▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▆▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████

0,1
epoch,30.0
test_acc,73.17874
test_loss,0.8494
train_loss,0.4154
trainer/global_step,23460.0
