### Colorization AutoEncoder PyTorch Demo using CIFAR10

In this demo, we build a simple colorization autoencoder using PyTorch.

In [1]:
import torch
import torchvision
import wandb
import time

from torch import nn
from einops import rearrange, reduce
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR



#### CNN Encoder using PyTorch

We use three convolutional layers to encode the grayscale input image. Each layer uses a stride of 2 to roughly halve the feature map size. A final linear layer maps the flattened feature map to the target latent vector size. Compared with the MNIST autoencoder, we use more filters and a much larger latent vector size of 256 to encode more information.
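
The size of the flattened feature map can be worked out by hand. The sketch below is only an illustration, assuming unpadded 3x3 convolutions with stride 2 on 32x32 inputs (the defaults used in this demo); the helper name `conv2d_out` is hypothetical and not part of PyTorch.

```python
def conv2d_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a 2D convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 32  # CIFAR10 images are 32x32
sizes = [size]
for _ in range(3):  # three unpadded conv layers, kernel 3, stride 2
    size = conv2d_out(size, kernel_size=3, stride=2)
    sizes.append(size)

n_filters = 64
# channels after the third conv layer times height times width
flat_features = (n_filters * 4) * size * size
print(sizes, flat_features)
```

With these assumptions the feature map shrinks from 32 to 15, 7, and finally 3 pixels per side, which gives the 2304 input features expected by the linear layer.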


In [2]:
class Encoder(nn.Module):
    def __init__(self, n_features=1, kernel_size=3, n_filters=64, feature_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(n_features, n_filters, kernel_size=kernel_size, stride=2)
        self.conv2 = nn.Conv2d(n_filters, n_filters*2, kernel_size=kernel_size, stride=2)
        self.conv3 = nn.Conv2d(n_filters*2, n_filters*4, kernel_size=kernel_size, stride=2)
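        # flattened conv3 output: (n_filters*4) channels x 3 x 3 for 32x32 inputs (2304 by default)
        self.fc1 = nn.Linear(n_filters * 4 * 3 * 3, feature_dim)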
        
    def forward(self, x):
        y = nn.ReLU()(self.conv1(x))
        y = nn.ReLU()(self.conv2(y))
        y = nn.ReLU()(self.conv3(y))
        y = rearrange(y, 'b c h w -> b (c h w)')

        y = self.fc1(y)
        return y


# use this to check the flattened input size expected by fc1
encoder = Encoder(n_features=1)
x = torch.rand(1, 1, 32, 32)  # dummy grayscale batch with initialized values
h = encoder(x)
print("h.shape:", h.shape)

h.shape: torch.Size([1, 256])


#### CNN Decoder using PyTorch

The decoder is trained to reconstruct an image from the latent vector. Its architecture mirrors the encoder. A linear layer first resizes the latent vector so that it can be reshaped into a feature map suitable for a convolutional layer. We then use strided transposed convolutional layers to upsample the feature map until the desired image size is reached. The target image is the colorized version of the input image.
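
The upsampling path can be traced with the standard output-size formula for transposed convolutions. This is a minimal sketch, assuming dilation 1 and the layer settings used in the decoder of this demo; the helper name `conv_transpose2d_out` is hypothetical.

```python
def conv_transpose2d_out(size, kernel_size, stride=1, padding=0, output_padding=0):
    """Spatial output size of a 2D transposed convolution with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


output_size = 32
size = output_size // 2**2  # initial feature map side after the linear layer
trace = [size]

# (kernel_size, stride, padding) for conv1..conv4, assuming the default kernel_size=3
layers = [(3, 2, 1), (3, 2, 1), (3, 1, 1), (4, 1, 0)]
for kernel_size, stride, padding in layers:
    size = conv_transpose2d_out(size, kernel_size, stride, padding)
    trace.append(size)

print(trace)
```

Under these assumptions the feature map grows from 8 to 15, 29, 29, and finally 32 pixels per side. The last layer uses a kernel size of 4 precisely to close the gap from 29 to 32.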

In [3]:
class Decoder(nn.Module):
    def __init__(self, kernel_size=3, n_filters=256, feature_dim=256, output_size=32, output_channels=3):
        super().__init__()
        self.init_size = output_size // 2**2 
        self.fc1 = nn.Linear(feature_dim, self.init_size**2 * n_filters)
        # output size of ConvTranspose2d is (h-1)*stride - 2*padding + kernel_size
        self.conv1 = nn.ConvTranspose2d(n_filters, n_filters//2, kernel_size=kernel_size, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(n_filters//2, n_filters//4, kernel_size=kernel_size, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose2d(n_filters//4, n_filters//4, kernel_size=kernel_size, padding=1)
        self.conv4 = nn.ConvTranspose2d(n_filters//4, output_channels, kernel_size=kernel_size+1)
        
    def forward(self, x):
        B, _ = x.shape
        y = self.fc1(x)
        y = rearrange(y, 'b (c h w) -> b c h w', b=B, h=self.init_size, w=self.init_size)
        y = nn.ReLU()(self.conv1(y))
        y = nn.ReLU()(self.conv2(y))
        y = nn.ReLU()(self.conv3(y))
        y = nn.Sigmoid()(self.conv4(y))

        return y

decoder = Decoder()
x_tilde = decoder(h)
print("x_tilde.shape:", x_tilde.shape)

x_tilde.shape: torch.Size([1, 3, 32, 32])


#### PyTorch Lightning Colorization AutoEncoder

In the colorization autoencoder, the encoder extracts features from the grayscale input image and the decoder generates a color image from the latent vector. The decoder's last layer has 3 output channels corresponding to RGB.

We use `gray_collate_fn` to generate gray images from RGB images.
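
The grayscale conversion is a plain average over the color channels. The sketch below illustrates the idea on a tiny image stored as nested Python lists; the function name `to_gray` and the toy data are assumptions made for illustration, while the notebook itself uses `einops.reduce` on tensors.

```python
def to_gray(image):
    """Average over channels. image is [channel][row][col]; returns [1][row][col]."""
    n_channels = len(image)
    height, width = len(image[0]), len(image[0][0])
    gray = [
        [sum(image[c][i][j] for c in range(n_channels)) / n_channels
         for j in range(width)]
        for i in range(height)
    ]
    return [gray]


# toy 3-channel, 1x2 image: a pure-red pixel and a mid-gray pixel
rgb = [[[1.0, 0.5]], [[0.0, 0.5]], [[0.0, 0.5]]]
gray = to_gray(rgb)
print(gray)
```

Note that this is an unweighted mean, not a luminance-weighted conversion, so a pure-red pixel maps to one third of full intensity.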

In [4]:
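def gray_collate_fn(batch):
    # drop the class labels and stack the RGB images into one batch tensor
    x, _ = zip(*batch)
    x = torch.stack(x, dim=0)
    # average over the channel axis to get a 1-channel grayscale input
    xn = reduce(x, "b c h w -> b 1 h w", 'mean')
    # return (grayscale input, RGB target)
    return xn, x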

class LitColorizeCIFAR10Model(LightningModule):
    def __init__(self, feature_dim=256, lr=0.001, batch_size=64,
                 num_workers=4, max_epochs=30, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = Encoder(feature_dim=feature_dim)
        self.decoder = Decoder(feature_dim=feature_dim)
        self.loss = nn.MSELoss()

    def forward(self, x):
        h = self.encoder(x)
        x_tilde = self.decoder(h)
        return x_tilde

    # this is called during fit()
    def training_step(self, batch, batch_idx):
        x_in, x = batch
        x_tilde = self.forward(x_in)
        loss = self.loss(x_tilde, x)
        return {"loss": loss}

    # calls to self.log() are recorded in wandb
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, on_epoch=True)

    # this is called for every batch during test (and reused for validation)
    def test_step(self, batch, batch_idx):
        x_in, x = batch
        x_tilde = self.forward(x_in)
        loss = self.loss(x_tilde, x)
        return {"x_in" : x_in, "x": x, "x_tilde" : x_tilde, "test_loss" : loss,}

    # this is called at the end of a test epoch (and reused for validation)
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)

    # validation is the same as test
    def validation_step(self, batch, batch_idx):
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        return self.test_epoch_end(outputs)

    # we use Adam optimizer
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        # this decays the learning rate to 0 after max_epochs using cosine annealing
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]
    
    # this is called by the Trainer at the start of fit/test; building the dataloaders here triggers the dataset download
    def setup(self, stage=None):
        self.train_dataloader()
        self.test_dataloader()

    # build train and test dataloaders using the CIFAR10 dataset
    # we use a simple ToTensor transform
    def train_dataloader(self):        
        return torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR10(
                "./data", train=True, download=True, 
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            collate_fn=gray_collate_fn
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR10(
                "./data", train=False, download=True, 
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            collate_fn=gray_collate_fn
        )

    def val_dataloader(self):
        return self.test_dataloader()
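
The cosine annealing schedule set up in `configure_optimizers` can be sketched in closed form. This is an illustration only, assuming the default minimum learning rate of 0 and one scheduler step per epoch (the PyTorch Lightning default interval); the function name `cosine_lr` is hypothetical.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 0.001, t_max: int = 30, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: decays from base_lr to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 15, 30):
    print(epoch, cosine_lr(epoch))
```

With the default arguments of this demo, the learning rate starts at 0.001, passes through half that value at epoch 15, and reaches 0 at epoch 30.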

#### Arguments

The arguments are similar to those of the MNIST AE, but we use a larger latent vector size of 256 because the colorization task needs more feature information from the input image.

In [5]:
def get_args():
    parser = ArgumentParser(description="PyTorch Lightning Colorization AE CIFAR10 Example")
    parser.add_argument("--max-epochs", type=int, default=30, help="num epochs")
    parser.add_argument("--batch-size", type=int, default=64, help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")

    parser.add_argument("--feature-dim", type=int, default=256, help="ae feature dimension")

    parser.add_argument("--devices", default=1)
    parser.add_argument("--accelerator", default='gpu')
    parser.add_argument("--num-workers", type=int, default=4, help="num workers")
    
    # parse an empty argument list so that the notebook kernel's own argv is ignored
    args = parser.parse_args([])
    return args
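
`get_args` parses an empty argument list instead of `sys.argv`, which is a common way to reuse an `argparse` parser inside a notebook, where the kernel's own command-line arguments would otherwise be rejected. The sketch below shows the pattern on a cut-down parser; `build_parser` is a hypothetical helper that keeps only two of the arguments above.

```python
from argparse import ArgumentParser


def build_parser() -> ArgumentParser:
    parser = ArgumentParser(description="toy parser with two of the demo's arguments")
    parser.add_argument("--max-epochs", type=int, default=30, help="num epochs")
    parser.add_argument("--feature-dim", type=int, default=256, help="ae feature dimension")
    return parser


parser = build_parser()
# an empty list yields the defaults regardless of what sys.argv contains
defaults = parser.parse_args([])
# an explicit list overrides selected defaults, e.g. for a quick experiment
overridden = parser.parse_args(["--max-epochs", "5", "--feature-dim", "128"])

print(defaults.max_epochs, defaults.feature_dim)
print(overridden.max_epochs, overridden.feature_dim)
```

Dashes in option names become underscores in the parsed namespace, which is why the training cell reads `args.max_epochs` and `args.feature_dim`.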

#### Weights and Biases Callback

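The callback logs sample predictions to `wandb` as a table of ground-truth color, grayscale input, and colorized output images. Train and validation metrics are logged separately through `self.log()` in the model. This is similar to our `WandbCallback` example for MNIST.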

In [6]:
class WandbCallback(Callback):

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # log the first batch_size // 4 images of the first validation batch
        if batch_idx == 0:
            x, c = batch
            n = pl_module.hparams.batch_size // 4
            outputs = outputs["x_tilde"]
            columns = ['gt color', 'gray', 'colorized']
            data = [[wandb.Image(c_i), wandb.Image(x_i), wandb.Image(x_tilde_i)] for c_i, x_i, x_tilde_i in list(zip(c[:n], x[:n], outputs[:n]))]
            wandb_logger.log_table(key="cifar10-colorize-ae", columns=columns, data=data)

#### Training an AE

We train the autoencoder on the CIFAR10 dataset. 
The results can be viewed on [wandb](https://app.wandb.ai/).
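
As a quick sanity check, the number of optimizer steps for a full run can be estimated from the dataset size. This sketch assumes the 50,000-image CIFAR10 training split, the default batch size of 64, 30 epochs, and a `DataLoader` that keeps the last partial batch (its default behavior).

```python
import math

n_train = 50_000   # CIFAR10 training images
batch_size = 64    # default --batch-size
max_epochs = 30    # default --max-epochs

# the last partial batch is kept, so round up
steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * max_epochs
print(steps_per_epoch, total_steps)
```

The estimate of 23,460 steps is within one step of the final `trainer/global_step` of 23459 reported in the run output; the difference is presumably due to zero-based step indexing in the logger, though that reading is an assumption.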

In [7]:
if __name__ == "__main__":
    args = get_args()
    ae = LitColorizeCIFAR10Model(feature_dim=args.feature_dim, lr=args.lr, 
                                 batch_size=args.batch_size, num_workers=args.num_workers,
                                 max_epochs=args.max_epochs)
    #ae.setup()
   

In [8]:

    wandb_logger = WandbLogger(project="colorize-cifar10")
    start_time = time.time()
    trainer = Trainer(accelerator=args.accelerator,
                      devices=args.devices,
                      max_epochs=args.max_epochs,
                      logger=wandb_logger,
                      callbacks=[WandbCallback()])
    trainer.fit(ae)

    elapsed_time = time.time() - start_time
    print("Elapsed time: {}".format(elapsed_time))

    wandb.finish()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: rowel. Use `wandb login --relogin` to force relogin


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 959 K 
1 | decoder | Decoder | 4.6 M 
2 | loss    | MSELoss | 0     
------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.317    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Elapsed time: 361.66529393196106


Run history:

| Metric | History |
|---|---|
| epoch | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███ |
| test_loss | █▅▄▃▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| train_loss | █▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁ |
| trainer/global_step | ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████ |

Run summary:

| Metric | Value |
|---|---|
| epoch | 30.0 |
| test_loss | 0.00777 |
| train_loss | 0.00365 |
| trainer/global_step | 23459.0 |
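
To put the final MSE values on a more familiar scale, they can be converted to a peak signal-to-noise ratio. This is a rough sketch, assuming pixel values in [0, 1] (as produced by `ToTensor` and the decoder's sigmoid) and treating the epoch-averaged loss as a single MSE; `psnr_from_mse` is a hypothetical helper, not something computed in the notebook.

```python
import math


def psnr_from_mse(mse: float, max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for a given mean squared error."""
    return 10.0 * math.log10(max_value ** 2 / mse)


test_psnr = psnr_from_mse(0.00777)   # final test_loss from the run output
train_psnr = psnr_from_mse(0.00365)  # final train_loss from the run output
print(f"test PSNR: {test_psnr:.1f} dB, train PSNR: {train_psnr:.1f} dB")
```

Under these assumptions the run corresponds to roughly 21 dB on the test set and 24 dB on the training set.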
