### CNN for CIFAR10

CNN model that can be used for classification tasks. 

In this demo, we will train a 3-layer CNN on the CIFAR10 dataset. We will show two implementations of the CNN model: the first uses PyTorch's built-in `nn.Conv2d` API, and the second uses tensor-level convolution.

Let us first import the required modules.

In [56]:
import torch
import torchvision
import wandb
import time
#import numpy as np
#import matplotlib.pyplot as plt
from torch import nn
from einops import rearrange
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
#from torchmetrics.functional import accuracy
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
#from matplotlib import image


#### CNN Encoder using PyTorch

In this example, we use `nn.Conv2d` to create a 3-layer CNN model. Note the following:

In [58]:
class Encoder(nn.Module):
    """Three-layer strided CNN that maps an image to a feature vector.

    Every convolution uses stride 2 and no padding, so each layer maps a
    spatial size ``s`` to ``(s - kernel_size) // 2 + 1``. The channel count
    doubles per layer (``n_filters``, ``2 * n_filters``, ``4 * n_filters``).
    The flattened output of the last convolution is projected to
    ``feature_dim`` by the linear layer ``fc1``.

    Args:
        n_features: Number of input channels.
        kernel_size: Side length of the square convolution kernels.
        n_filters: Number of output channels of the first convolution.
        feature_dim: Size of the returned feature vector.
        input_size: Height and width of the (square) input images. It is only
            used to size ``fc1``; the default of 28 reproduces the previously
            hard-coded 128 input features for the default arguments.

    Raises:
        ValueError: If the input is too small to survive three convolutions.
    """

    def __init__(self, n_features=1, kernel_size=3, n_filters=8, feature_dim=16,
                 input_size=28):
        super().__init__()
        self.conv1 = nn.Conv2d(n_features, n_filters, kernel_size=kernel_size, stride=2)
        self.conv2 = nn.Conv2d(n_filters, n_filters*2, kernel_size=kernel_size, stride=2)
        self.conv3 = nn.Conv2d(n_filters*2, n_filters*4, kernel_size=kernel_size, stride=2)

        # Follow the spatial size through the three stride-2, unpadded
        # convolutions so fc1 matches whatever n_filters / kernel_size /
        # input_size were requested instead of assuming 128 features.
        size = input_size
        for _ in range(3):
            size = (size - kernel_size) // 2 + 1
        if size < 1:
            raise ValueError(
                "input_size={} is too small for three stride-2 convolutions "
                "with kernel_size={}".format(input_size, kernel_size)
            )
        self.fc1 = nn.Linear(n_filters * 4 * size * size, feature_dim)

    def forward(self, x):
        """Encode a batch of images ``(b, c, h, w)`` into ``(b, feature_dim)``."""
        y = torch.relu(self.conv1(x))
        y = torch.relu(self.conv2(y))
        y = torch.relu(self.conv3(y))
        # b c h w -> b (c h w)
        y = torch.flatten(y, start_dim=1)

        # The feature vector is returned without an activation.
        return self.fc1(y)


# Sanity check: push one dummy 28x28 single-channel image through the encoder
# and print the shape of the result. To find the flattened size that fc1 must
# accept, temporarily comment out the fc1 call in Encoder.forward and re-run.
encoder = Encoder(n_features=1)
# torch.zeros gives well-defined values; torch.Tensor(1, 1, 28, 28) would
# return uninitialized memory that may contain NaN/Inf.
x = torch.zeros(1, 1, 28, 28)
h = encoder(x)
print("h.shape:", h.shape)

h.shape: torch.Size([1, 128])


In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder that maps a feature vector to an image.

    A linear layer expands the feature vector into an
    ``n_filters x init_size x init_size`` map. Two stride-2 transposed
    convolutions and a final stride-1 transposed convolution (kernel
    ``kernel_size - 1``) upsample it, and a sigmoid bounds the output to
    ``[0, 1]``.

    Args:
        kernel_size: Kernel side of the two stride-2 transposed convolutions.
        n_filters: Channels of the initial feature map; halved by ``conv1``
            and again by ``conv2``.
        feature_dim: Size of the input feature vector.
        output_size: Height and width of the generated image. The layers can
            only hit this size exactly when it is a multiple of 4; otherwise
            the output is the next lower multiple of 4.
        output_channels: Number of channels of the generated image.

    Raises:
        ValueError: If ``output_size`` is too small for ``kernel_size``.
    """

    def __init__(self, kernel_size=3, n_filters=32, feature_dim=16, output_size=28, output_channels=1):
        super().__init__()
        # Without padding, a stride-2 ConvTranspose2d maps h -> (h - 1) * 2 + kernel_size
        # and the final stride-1 layer (kernel_size - 1) maps h -> h + kernel_size - 2.
        # Chaining conv1, conv2, conv3 from a side s gives 4 * s + 4 * kernel_size - 8;
        # solving for s yields the seed size below (output_size // 4 - 1 when kernel_size=3).
        self.init_size = output_size // 4 - kernel_size + 2
        if self.init_size < 1:
            raise ValueError(
                "output_size={} is too small for kernel_size={}".format(output_size, kernel_size)
            )
        self.fc1 = nn.Linear(feature_dim, self.init_size**2 * n_filters)
        self.conv1 = nn.ConvTranspose2d(n_filters, n_filters//2, kernel_size=kernel_size, stride=2)
        self.conv2 = nn.ConvTranspose2d(n_filters//2, n_filters//4, kernel_size=kernel_size, stride=2)
        self.conv3 = nn.ConvTranspose2d(n_filters//4, output_channels, kernel_size=kernel_size-1)

    def forward(self, x):
        """Decode features ``(b, feature_dim)`` into images ``(b, c, h, w)``."""
        # Unpacking also enforces that x is two-dimensional.
        B, _ = x.shape
        y = self.fc1(x)
        # b (c h w) -> b c h w
        y = y.reshape(B, -1, self.init_size, self.init_size)
        y = torch.relu(self.conv1(y))
        y = torch.relu(self.conv2(y))
        y = torch.sigmoid(self.conv3(y))

        return y

# Sanity check: decode the feature tensor `h` produced by the encoder check
# above and report the shape of the reconstruction.
decoder = Decoder()
x_tilde = decoder(h)
print("x_tilde.shape:", x_tilde.size())

x_tilde.shape: torch.Size([1, 1, 28, 28])


#### PyTorch Lightning AutoEncoder


In [None]:
def noise_collate_fn(batch):
    """Collate ``(image, label)`` samples into a ``(noisy, clean)`` image batch.

    Used as the ``collate_fn`` of the dataloaders when denoising is enabled,
    so it takes the list of samples as its only argument. The images are
    stacked first, standard-normal noise is added, and the result is clamped
    to ``[0, 1]``. Labels are dropped; the clean images are returned in their
    place so that the model can use them as the reconstruction target.

    Args:
        batch: Sequence of ``(image_tensor, label)`` pairs with equally
            shaped image tensors.

    Returns:
        Tuple ``(noisy, clean)`` of tensors shaped ``(len(batch), *image.shape)``.
    """
    images, _ = zip(*batch)
    clean = torch.stack(images, dim=0)
    noisy = torch.clamp(clean + torch.randn_like(clean), 0, 1)
    return noisy, clean


class LitAEMNISTModel(LightningModule):
    """Lightning module wrapping the Encoder/Decoder pair as an MNIST autoencoder.

    With ``denoise=False`` the network is trained to reproduce its input.
    With ``denoise=True`` the dataloaders corrupt the images via
    ``noise_collate_fn`` and the network is trained to recover the clean image.

    Args:
        feature_dim: Size of the latent feature vector.
        lr: Learning rate for Adam.
        batch_size: Batch size of all dataloaders.
        num_workers: Worker processes per dataloader.
        max_epochs: Length of the cosine learning-rate schedule, in epochs.
        denoise: Train as a denoising autoencoder.
        **kwargs: Extra values; only stored through ``save_hyperparameters``.
    """

    def __init__(self, feature_dim=16, lr=0.001, batch_size=64,
                 num_workers=4, max_epochs=30, denoise=False, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = Encoder(feature_dim=feature_dim)
        self.decoder = Decoder(feature_dim=feature_dim)
        self.loss = nn.MSELoss()
        # also read by the logging callback to pick the table name
        self.denoise = denoise

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        return self.decoder(self.encoder(x))

    def _split_batch(self, batch):
        """Return ``(network_input, reconstruction_target)`` for a batch.

        Denoising batches come from ``noise_collate_fn`` as ``(noisy, clean)``.
        Plain batches are ``(image, label)`` and the image is its own target.
        """
        x_in, second = batch
        target = second if self.denoise else x_in
        return x_in, target

    # this is called for every training batch during fit()
    def training_step(self, batch, batch_idx):
        x_in, target = self._split_batch(batch)
        x_tilde = self(x_in)
        loss = self.loss(x_tilde, target)
        return {"loss": loss}

    # calls to self.log() are recorded in wandb
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, on_epoch=True)

    # this is called for every test batch; "x" is the image fed to the network
    def test_step(self, batch, batch_idx):
        x_in, target = self._split_batch(batch)
        x_tilde = self(x_in)
        loss = self.loss(x_tilde, target)
        return {"x": x_in, "x_tilde": x_tilde, "test_loss": loss}

    # this receives the collected test_step outputs and logs their mean loss
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)

    # validation is the same as test (and is therefore also logged as "test_loss")
    def validation_step(self, batch, batch_idx):
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        return self.test_epoch_end(outputs)

    # we use Adam optimizer
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        # this decays the learning rate to 0 after max_epochs using cosine annealing
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]

    # this is called after model instantiation; building the loaders once
    # makes sure the MNIST files are downloaded to ./data
    def setup(self, stage=None):
        self.train_dataloader()
        self.test_dataloader()

    def _mnist_loader(self, train):
        """Build a DataLoader over the MNIST train (shuffled) or test split."""
        dataset = torchvision.datasets.MNIST(
            "./data", train=train, download=True,
            transform=torchvision.transforms.ToTensor()
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=train,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            # None falls back to the DataLoader's default collation
            collate_fn=noise_collate_fn if self.denoise else None
        )

    # build train and test dataloaders using MNIST dataset
    # we use simple ToTensor transform
    def train_dataloader(self):
        return self._mnist_loader(train=True)

    def test_dataloader(self):
        return self._mnist_loader(train=False)

    def val_dataloader(self):
        return self.test_dataloader()

#### Arguments


In [None]:
def get_args():
    """Return the demo's hyperparameters as an argparse namespace.

    The parser is always fed an empty argument list, so the real command
    line is ignored and every option takes its default value.
    """
    parser = ArgumentParser(description="PyTorch Lightning AE MNIST Example")

    # (flag, add_argument keyword arguments), registered in this order
    option_specs = (
        ("--max-epochs", dict(type=int, default=30, help="num epochs")),
        ("--batch-size", dict(type=int, default=64, help="batch size")),
        ("--lr", dict(type=float, default=0.001, help="learning rate")),
        ("--feature-dim", dict(type=int, default=16, help="ae feature dimension")),
        # flag that switches the model to denoising mode
        ("--denoise", dict(action="store_true", help="denoise")),
        ("--devices", dict(default=1)),
        ("--accelerator", dict(default='gpu')),
        ("--num-workers", dict(type=int, default=4, help="num workers")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args([])

#### Weights and Biases Callback

The callback logs train and validation metrics to `wandb`. It also logs sample predictions. This is similar to our `WandbCallback` example for MNIST.

In [None]:
class WandbCallback(Callback):
    """Log sample inputs and their reconstructions to a wandb table.

    For the first batch of every validation run, up to 10 input images and
    the matching ``"x_tilde"`` outputs of the validation step are written to
    a table through the trainer's logger.
    """

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # process first 10 images of the first batch
        if batch_idx != 0:
            return

        # Use the logger attached to the trainer rather than a module-level
        # global, so the callback works wherever the Trainer is created.
        logger = trainer.logger
        if not hasattr(logger, "log_table"):
            # the attached logger cannot log tables; nothing to do
            return

        x, _ = batch
        n = 10
        reconstructions = outputs["x_tilde"]
        # log input image and reconstruction side by side on a wandb table
        columns = ['image', 'reconstructed', ]
        key = "mnist-ae-denoising" if pl_module.denoise else "mnist-ae-reconstruction"
        data = [[wandb.Image(x_i), wandb.Image(x_tilde_i)]
                for x_i, x_tilde_i in zip(x[:n], reconstructions[:n])]
        logger.log_table(key=key, columns=columns, data=data)

#### Training an AE

In [None]:
if __name__ == "__main__":
    # Build the autoencoder from the parsed default hyperparameters.
    args = get_args()
    model_kwargs = {
        name: getattr(args, name)
        for name in ("feature_dim", "lr", "batch_size",
                     "num_workers", "denoise", "max_epochs")
    }
    ae = LitAEMNISTModel(**model_kwargs)
    # instantiate the dataloaders once up front
    ae.setup()
   

LitAEMNISTModel(
  (encoder): Encoder(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
    (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
    (fc1): Linear(in_features=512, out_features=16, bias=True)
  )
  (decoder): Decoder(
    (fc1): Linear(in_features=16, out_features=1152, bias=True)
    (conv1): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (conv2): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2))
    (conv3): ConvTranspose2d(8, 1, kernel_size=(2, 2), stride=(1, 1))
  )
  (loss): MSELoss()
)


In [None]:


    # Metrics and sample tables are sent to the "ae-mnist" wandb project,
    # which makes the model easy to debug and visualize.
    wandb_logger = WandbLogger(project="ae-mnist")

    # The timer covers Trainer construction and the whole fit() call.
    start_time = time.time()
    trainer_kwargs = dict(
        accelerator=args.accelerator,
        devices=args.devices,
        max_epochs=args.max_epochs,
        logger=wandb_logger,
        callbacks=[WandbCallback()],
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(ae)
    elapsed_time = time.time() - start_time
    print("Elapsed time: {}".format(elapsed_time))

    # close the wandb run
    wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 100 K 
1 | decoder | Decoder | 25.4 K
2 | loss    | MSELoss | 0     
------------------------------------
126 K     Trainable params
0         Non-trainable params
126 K     Total params
0.505     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Elapsed time: 353.8311905860901
