<a href="https://colab.research.google.com/github/sayanbanerjee32/TASI_ERAv2_S22/blob/main/VAE_for_CIFAR10_image_and_label_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Install PyTorch Lightning plus the development build of lightning-bolts;
# the %%capture cell magic above keeps pip's output out of the notebook.
! pip install pytorch-lightning
# ! pip install pytorch-lightning-bolts
# NOTE(review): neither install is version-pinned and bolts tracks the `master`
# branch, so a fresh run may pull different code than the run recorded here.
! pip install git+https://github.com/PytorchLightning/lightning-bolts.git@master --upgrade

In [None]:
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
from pl_bolts.models.autoencoders.components import (
    resnet18_decoder,
    resnet18_encoder,
)

  if not hasattr(numpy, tp_name):
  if not hasattr(numpy, tp_name):
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)


In [None]:
class VAE(pl.LightningModule):
    """Label-conditioned variational autoencoder on ResNet-18 encoder/decoder blocks.

    The class label conditions the model in two places: as one-hot planes
    concatenated to the input image before encoding, and as a learned embedding
    concatenated to the latent sample before decoding.

    Args:
        enc_out_dim: Input width of the ``fc_mu`` / ``fc_var`` heads, i.e. the size
            of the feature vector the encoder is expected to produce.
        latent_dim: Size of the latent sample ``z``.
        input_height: Image height handed to the decoder.
        num_classes: Number of label classes.
        label_embedding_dim: Size of the label embedding appended to ``z``.
        lr: Learning rate for Adam. Defaults to ``1e-4``, the value that used to
            be hard-coded in ``configure_optimizers``.
    """

    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=32, num_classes=10,
                 label_embedding_dim=50, lr=1e-4):
        super().__init__()

        # Stores every constructor argument under self.hparams (used in forward and
        # configure_optimizers below).
        self.save_hyperparameters()

        # 1x1 conv block that squeezes the image channels plus the one-hot label
        # planes back down to the 3 channels the ResNet encoder takes.
        # self.initial_conv = nn.Conv2d(1 + num_classes, 3, kernel_size=3, stride=1, padding=1)
        self.initial_conv_block = nn.Sequential(
                    nn.Conv2d(3 + num_classes, 3, kernel_size=(1, 1), padding=0, bias=False),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm2d(3),
        )

        # encoder, decoder
        self.encoder = resnet18_encoder(first_conv=False, maxpool1=False)
        self.decoder = resnet18_decoder(
            # the decoder input is the latent sample with the label embedding appended
            latent_dim=latent_dim + label_embedding_dim,
            input_height=input_height,
            first_conv=False,
            maxpool1=False
        )

        # label embedding
        self.label_embedding = nn.Embedding(num_classes, label_embedding_dim)

        # distribution parameters
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # learnable log standard deviation of the Gaussian likelihood
        self.log_scale = nn.Parameter(torch.tensor([0.0]))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the ``lr`` hyper-parameter."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def gaussian_likelihood(self, mean, logscale, sample):
        """Log-probability of ``sample`` under ``Normal(mean, exp(logscale))``.

        Returns one value per batch element: the element-wise log-probabilities
        summed over dimensions 1, 2 and 3.
        """
        scale = torch.exp(logscale)
        dist = torch.distributions.Normal(mean, scale)
        log_pxz = dist.log_prob(sample)
        return log_pxz.sum(dim=(1, 2, 3))

    def kl_divergence(self, z, mu, std):
        """Single-sample Monte Carlo estimate of KL(q || p) evaluated at ``z``.

        ``q`` is ``Normal(mu, std)`` and ``p`` is the standard normal. The result is
        ``log q(z) - log p(z)`` summed over the last dimension.
        """
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)
        kl = (log_qzx - log_pz).sum(-1)
        return kl

    def forward(self, x, labels, return_latent=False):
        """Encode ``x`` conditioned on ``labels``, sample ``z`` and decode it.

        Args:
            x: Image batch; concatenated with the one-hot label planes along dim 1.
            labels: Integer class indices, one per image.
            return_latent: When True, also return the latent sample that was decoded.

        Returns:
            ``(x_hat, mu, log_var)``, or ``(x_hat, mu, log_var, z)`` when
            ``return_latent`` is True.
        """
        # One-hot encode labels and broadcast them to one constant plane per class
        labels_onehot = F.one_hot(labels, num_classes=self.hparams.num_classes).float()
        labels_onehot = labels_onehot.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.size(2), x.size(3))
        x = torch.cat((x, labels_onehot), dim=1)  # batch_size x (image channels + num_classes) x H x W
        x = self.initial_conv_block(x)  # back to 3 channels for the resnet encoder
        x_encoded = self.encoder(x)
        mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()  # reparameterised draw, keeps gradients flowing to mu / log_var

        # Embed labels for the decoder
        labels_embedded = self.label_embedding(labels)
        x_hat = self.decoder(torch.cat((z, labels_embedded), dim=1))

        if return_latent:
            return x_hat, mu, log_var, z
        return x_hat, mu, log_var

    def training_step(self, batch, batch_idx):
        """Compute and log the training objective for one batch.

        The KL term is evaluated on the same latent sample that produced ``x_hat``;
        previously a second, independent sample was drawn just for the KL term.

        Logged values: ``'elbo'`` is the minimised quantity ``kl - log-likelihood``
        (the negative ELBO); ``'recon_loss'`` is the Gaussian log-likelihood of the
        input under the reconstruction, so larger is better.
        """
        x, labels = batch
        x_hat, mu, log_var, z = self(x, labels, return_latent=True)
        std = torch.exp(log_var / 2)

        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)
        kl = self.kl_divergence(z, mu, std)
        elbo = (kl - recon_loss).mean()

        self.log_dict({
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': recon_loss.mean()
        })

        return elbo


  and should_run_async(code)


In [None]:
# Build the CIFAR-10 datamodule rooted at the current directory with batches of 512,
# then download (if needed) and set up its datasets.
# NOTE(review): no `normalize` argument is passed here, while a later cell
# un-normalises predictions with CIFAR-10 mean/std — confirm whether this
# datamodule normalises images by default.
from pl_bolts.datamodules import CIFAR10DataModule
cifar = CIFAR10DataModule('.', batch_size = 512)
cifar.prepare_data()
cifar.setup()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:18<00:00, 9207752.96it/s]


Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified


In [None]:
# Seed first so that model initialisation below is repeatable.
pl.seed_everything(1234)

epochs = 50
vae = VAE()

# One GPU, `epochs` passes over the CIFAR-10 datamodule, progress bar on.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=epochs,
    enable_progress_bar=True,
)
trainer.fit(vae, cifar)

INFO:lightning_fabric.utilities.seed:Global seed set to 1234
  self.encoder = resnet18_encoder(first_conv=False, maxpool1=False)
  return ResNetEncoder(EncoderBlock, [2, 2, 2, 2], first_conv, maxpool1)
  layers.append(block(self.inplanes, planes, stride, downsample))
  self.conv1 = conv3x3(inplanes, planes, stride)
  conv1x1(self.inplanes, planes * block.expansion, stride),
  self.decoder = resnet18_decoder(
  return ResNetDecoder(DecoderBlock, [2, 2, 2, 2], latent_dim, input_height, first_conv, maxpool1)
  resize_conv1x1(self.inplanes, planes * block.expansion, scale),
  return nn.Sequential(Interpolate(scale_factor=scale), conv1x1(in_planes, out_planes))
  layers.append(block(self.inplanes, planes, scale, upsample))
  self.conv1 = resize_conv3x3(inplanes, inplanes)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available

Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name               | Type          | Params
-----------------------------------------------------
0 | initial_conv_block | Sequential    | 45    
1 | encoder            | ResNetEncoder | 11.2 M
2 | decoder            | ResNetDecoder | 9.0 M 
3 | label_embedding    | Embedding     | 500   
4 | fc_mu              | Linear        | 131 K 
5 | fc_var             | Linear        | 131 K 
-----------------------------------------------------
20.5 M    Trainable params
0         Non-trainable params
20.5 M    Total params
81.869    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


## Plot an image

In [None]:
# Take the first batch from the validation dataloader as input for the cells below.
# Note: despite its name, `test` holds the *validation* loader.
test = cifar.val_dataloader()
x, y = next(iter(test))
# Display the image and label batch shapes.
x.shape, y.shape

  and should_run_async(code)


(torch.Size([512, 3, 32, 32]), torch.Size([512]))

In [None]:
num_preds = 25

# True labels of the first `num_preds` images in the batch.
y_t = y[:num_preds]
print(y_t)

# Conditioning labels taken from a random permutation of the batch, so they
# need not match the true class of the image they are paired with.
indices = torch.randperm(y.numel())
shuffled_y = y[indices[:num_preds]]
shuffled_y

tensor([7, 6, 6, 9, 1, 1, 8, 9, 4, 7, 0, 5, 7, 2, 2, 1, 7, 6, 2, 8, 8, 2, 2, 5,
        7])


tensor([6, 3, 1, 2, 8, 3, 8, 2, 8, 3, 2, 4, 5, 0, 0, 6, 6, 3, 4, 1, 3, 0, 2, 1,
        9])

In [None]:
# Switch to inference mode. Without eval() the BatchNorm layers normalise with the
# statistics of this small batch and keep updating their running averages, which
# torch.no_grad() alone does not prevent.
vae.eval()

# Reconstruct the first `num_preds` images, conditioned on the shuffled labels.
with torch.no_grad():
    pred, mu, log_var = vae(x[:num_preds].to(vae.device), shuffled_y.to(vae.device))

print('mu:', mu.shape)
print('log_var:', log_var.shape)

# SAMPLE Z from Q(Z|x)
# Illustrative only: `pred` was decoded from the draw made inside the model's
# forward pass, not from this one.
std = torch.exp(log_var / 2)
q = torch.distributions.Normal(mu, std)
z = q.rsample()

print('z shape:', z.shape)
print('pred size:', pred.size())

mu: torch.Size([25, 256])
log_var: torch.Size([25, 256])
z shape: torch.Size([25, 256])
pred size: torch.Size([25, 3, 32, 32])


In [None]:
from matplotlib.pyplot import imshow, figure
import numpy as np
from torchvision.utils import make_grid
# Import retained for any later cell that uses it; it is no longer called here
# because the call fails on this bolts build with
# "TypeError: cifar10_normalization() takes 0 positional arguments but 1 was given".
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

pred = pred.cpu()

# Per-channel CIFAR-10 statistics (RGB, on the 0-1 scale), written out so this
# cell does not depend on the failing helper above.
CIFAR10_MEAN = np.array([125.3, 123.0, 113.9]) / 255.0
CIFAR10_STD = np.array([63.0, 62.1, 66.7]) / 255.0

# UNDO DATA NORMALIZATION - but only if the datamodule actually applied it.
# NOTE(review): assumes the datamodule exposes a boolean `normalize` attribute;
# when it does not, this keeps the notebook's original assumption that the
# inputs were normalised. Confirm against the installed lightning-bolts version.
if getattr(cifar, "normalize", True):
    mean, std = CIFAR10_MEAN, CIFAR10_STD
else:
    # Identity transform, so later `* std + mean` expressions remain valid.
    mean, std = np.zeros(3), np.ones(3)

img = make_grid(pred, nrow=5).permute(1, 2, 0).numpy() * std + mean

# PLOT IMAGES
# Create the figure only once the data is ready, so a failure above no longer
# leaves an empty figure behind; clip because decoder outputs are unbounded
# while imshow expects float RGB in [0, 1].
figure(figsize=(8, 3), dpi=300)
imshow(np.clip(img, 0.0, 1.0));

TypeError: cifar10_normalization() takes 0 positional arguments but 1 was given

<Figure size 2400x900 with 0 Axes>

## Labeled images

In [None]:
import matplotlib.pyplot as plt
# Human-readable CIFAR-10 class names, indexed by integer class id.
class_labels = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def plot_image(images, target_labels, pred_labels = None, rows = 5, cols = 5,
               img_size=(5,5), font_size = 7):
    """Show up to ``rows * cols`` images in a grid, each titled with its class names.

    Args:
        images: Indexable collection of images accepted by ``plt.imshow``.
        target_labels: Per-image indices into ``class_labels``; shown after "image:".
        pred_labels: Optional per-image indices into ``class_labels``; when given,
            shown on a second title line after "label:".
        rows: Number of grid rows.
        cols: Number of grid columns.
        img_size: Figure size passed to ``plt.figure``.
        font_size: Font size of the subplot titles.

    If fewer than ``rows * cols`` images are supplied, only the available images
    are drawn and the remaining cells stay empty (previously this raised
    ``IndexError``).
    """
    fig = plt.figure(figsize=img_size)
    # Never index past the data: at most one image per grid cell.
    n_shown = min(rows * cols, len(images))
    for index in range(n_shown):
        plt.subplot(rows, cols, index + 1)
        title = f'image: {class_labels[target_labels[index]]}'
        if pred_labels is not None:
            title += f'\nlabel: {class_labels[pred_labels[index]]}'
        plt.title(title, fontsize = font_size)
        plt.axis('off')
        plt.imshow(images[index])
    fig.tight_layout()
    plt.show()

In [None]:
# Move the channel axis last (N, C, H, W -> N, H, W, C) and apply `* std + mean`
# per channel before plotting each prediction with its true and conditioning label.
# NOTE(review): depends on `pred` (presumably already on the CPU), `mean`, `std`,
# `y_t` and `shuffled_y` left in the kernel by earlier cells — run the notebook
# in order; `std` is also assigned a tensor in the inference cell.
imgs = pred.permute(0, 2, 3, 1).numpy() * std + mean
plot_image(imgs, target_labels = y_t, pred_labels = shuffled_y)