<a href="https://colab.research.google.com/github/shivakrishnaah/advanced_deep_learning/blob/main/PixelCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import libraries

In [None]:
#! /usr/bin/env python

import os
import time
import sys

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim, cuda, backends
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets, transforms, utils
backends.cudnn.benchmark = True

## [Pixel CNN paper ](https://arxiv.org/abs/1606.05328)
### Mask Type A
	- Used in the first convolutional layer.
	- Ensures that the current pixel cannot see itself during prediction.
	- The mask removes the center pixel and all future pixels in both row and column directions.

### Mask Type B
	- Used in subsequent convolutional layers.
	- Allows the network to see the current pixel but not future pixels.
	- The mask removes only future pixels, while the current pixel remains accessible.

In [None]:
class MaskedConv2d(nn.Conv2d):
    def __init__(self, mask_type, in_channels, out_channels, kernel_size, stride, padding, bias=True):
        super(MaskedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        assert mask_type in {'A', 'B'}
        self.register_buffer('mask', self.weight.data.clone())
        _, _, height, width = self.weight.size()
        self.mask.fill_(1)
        yc, xc = height // 2, width // 2
        self.mask[:, :, yc, xc+ (mask_type == 'B'):] = 0
        self.mask[:, :, yc + 1:] = 0

    def forward(self, x):
        self.weight.data *= self.mask
        return super(MaskedConv2d, self).forward(x)

In [None]:
class PixelCNN(nn.Module):
  def __init__(self, input_channels=3, n_filters=64, kernel_size=7, n_layers=8) -> None:
      super(PixelCNN, self).__init__()
      layers = []
      layers.append(MaskedConv2d('A', input_channels, n_filters, kernel_size, 1, padding=kernel_size//2, bias=False))
      layers.append(nn.BatchNorm2d(n_filters))
      layers.append(nn.ReLU(True))

      for i in range(n_layers -1):
          layers.append(MaskedConv2d('B', n_filters, n_filters, kernel_size, 1, padding=kernel_size//2, bias=False))
          layers.append(nn.BatchNorm2d(n_filters))
          layers.append(nn.ReLU(True))
          # nn.Conv2d(in_channels=input_channels,  out_channels=input_channels*256, kernel_size=7, padding=3)
      layers.append(nn.Conv2d(in_channels=n_filters, out_channels=input_channels * 256, kernel_size=1))
      self.net = nn.Sequential(*layers)

  def forward(self, x):
      return self.net(x)

In [None]:
import torch.nn.functional as F

def training_model(model, train_loader, optimizer, device, epochs=10):
    # Detect the channels from the input
    input_channels = next(iter(train_loader))[0].shape[1]

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x, _ in train_loader:
            # Scale target to integers in [0, 255]
            x = (x * 255).long().to(device)
            x_input = x.float() / 255.0

            # Forward pass
            logits = model(x_input)          # [batch, input_channels * 256, H, W]
            batch_size, _, height, width = logits.shape

            # Reshape logits: [batch, channels, 256, H, W]
            logits = logits.view(batch_size, input_channels, 256, height, width)

            # Move channels to last dimension: [batch, H, W, channels, 256]
            logits = logits.permute(0, 3, 4, 1, 2).contiguous()

            # Flatten for loss: [batch * H * W * channels, 256]
            logits = logits.view(-1, 256)
            targets = x.view(-1)  # Flatten targets: [batch * H * W * channels]

            loss = F.cross_entropy(logits, targets)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

In [None]:
transform = transforms.Compose([
    transforms.ToTensor()
])
train_loader = data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True
)
sample_batch, _ = next(iter(train_loader))
input_channels = sample_batch.shape[1]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
minst_model = PixelCNN(input_channels=input_channels, n_filters=64).to(device)
optimizer = torch.optim.Adam(minst_model.parameters(), lr=1e-3)
training_model(minst_model, train_loader, optimizer, device, epochs=10)

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.11MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.27MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.47MB/s]


Epoch 1: Loss = 1.3173
Epoch 2: Loss = 0.7900
Epoch 3: Loss = 0.7625
Epoch 4: Loss = 0.7494
Epoch 5: Loss = 0.7405
Epoch 6: Loss = 0.7334
Epoch 7: Loss = 0.7272
Epoch 8: Loss = 0.7213
Epoch 9: Loss = 0.7147
Epoch 10: Loss = 0.7073


In [None]:
import torchvision.utils as vutils

def sample(model, img_size=(28, 28), n_channels=1, n_samples=64):
    model.eval()
    samples = torch.zeros(n_samples, n_channels, *img_size).to(device)
    with torch.no_grad():
        for i in range(img_size[0]):
            for j in range(img_size[1]):
              for c in range(n_channels):
                logits = model(samples)
                logits = logits.view(n_samples, n_channels, 256, img_size[0], img_size[1])
                probs = F.softmax(logits[:, c, :, i, j], dim=1)  # [n_samples, 256]
                pixel_values = torch.multinomial(probs, num_samples=1).squeeze(-1)  # Sample pixel
                samples[:, c, i, j] = pixel_values.float() / 255.0
    return samples

In [None]:
samples = sample(minst_model, img_size=(28, 28), n_channels=1, n_samples=64)
vutils.save_image(samples, "pixelcnn_mnist_samples1.png", nrow=8, padding=2)

In [None]:
transform = transforms.Compose([
    transforms.ToTensor()
])
train_loader = data.DataLoader(
    datasets.CIFAR10('data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True
)
sample_batch, _ = next(iter(train_loader))
input_channels = sample_batch.shape[1]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cifar_model = PixelCNN(input_channels=input_channels, n_filters=192, n_layers=15).to(device)
optimizer = torch.optim.Adam(cifar_model.parameters(), lr=1e-3)
training_model(cifar_model, train_loader, optimizer, device, epochs=10)

100%|██████████| 170M/170M [00:13<00:00, 12.8MB/s]


Epoch 1: Loss = 5.1866
Epoch 2: Loss = 4.8837
Epoch 3: Loss = 4.7629
Epoch 4: Loss = 4.6849
Epoch 5: Loss = 4.6198
Epoch 6: Loss = 4.5785
Epoch 7: Loss = 4.5464
Epoch 8: Loss = 4.5156
Epoch 9: Loss = 4.4803
Epoch 10: Loss = 4.4465


In [9]:
samples = sample(cifar_model, img_size=(32, 32), n_channels=3, n_samples=64)
vutils.save_image(samples, "pixelcnn_cifar10_samples.png", nrow=8, padding=2)

In [10]:
class GatedMaskedConv2d(nn.Module):
    def __init__(self, mask_type, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(GatedMaskedConv2d, self).__init__()
        self.mask_type = mask_type
        self.conv = MaskedConv2d(mask_type, in_channels, 2 * out_channels, kernel_size, stride, padding)
        self.out_channels = out_channels

    def forward(self, x):
        out = self.conv(x)
        value, gate = out.chunk(2, dim=1)
        return torch.tanh(value) * torch.sigmoid(gate)

In [13]:
class PixelCNNpp(nn.Module):
    def __init__(self, input_channels=3, n_filters=64, kernel_size=7, n_layers=8):
        super(PixelCNNpp, self).__init__()
        layers = []

        # First layer uses Mask-A
        layers.append(GatedMaskedConv2d('A', input_channels, n_filters, kernel_size, padding=kernel_size//2))

        # Subsequent layers use Mask-B
        for _ in range(n_layers - 1):
            layers.append(GatedMaskedConv2d('B', n_filters, n_filters, kernel_size, padding=kernel_size//2))

        self.net = nn.Sequential(*layers)
        self.output_conv = nn.Conv2d(n_filters, input_channels * 256, kernel_size=1)

    def forward(self, x):
        x = self.net(x)
        x = self.output_conv(x)
        return x

In [None]:
# transform = transforms.Compose([
#     transforms.ToTensor()
# ])
# train_loader = data.DataLoader(
#     datasets.CIFAR10('data', train=True, download=True, transform=transform),
#     batch_size=128, shuffle=True
# )
# sample_batch, _ = next(iter(train_loader))
# input_channels = sample_batch.shape[1]

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cifar_model_pp = PixelCNNpp(input_channels=input_channels, n_filters=192, n_layers=15).to(device)
optimizer = torch.optim.Adam(cifar_model_pp.parameters(), lr=1e-3)
training_model(cifar_model_pp, train_loader, optimizer, device, epochs=10)

In [None]:
samples = sample(cifar_model_pp, img_size=(32, 32), n_channels=3, n_samples=64)
vutils.save_image(samples, "pixelcnn_cifar10pp_samples.png", nrow=8, padding=2)