<a href="https://colab.research.google.com/github/shivakrishnaah/advanced_deep_learning/blob/main/PixelCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import libraries

In [None]:
#! /usr/bin/env python

import os
import time
import sys

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim, cuda, backends
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets, transforms, utils
backends.cudnn.benchmark = True

## [Pixel CNN paper ](https://arxiv.org/abs/1606.05328)
### Mask Type A
	- Used in the first convolutional layer.
	- Ensures that the current pixel cannot see itself during prediction.
	- The mask removes the center pixel and all future pixels in both row and column directions.

### Mask Type B
	- Used in subsequent convolutional layers.
	- Allows the network to see the current pixel but not future pixels.
	- The mask removes only future pixels, while the current pixel remains accessible.

In [None]:
class MaskedConv2d(nn.Conv2d):
    def __init__(self, mask_type, in_channels, out_channels, kernel_size, stride, padding, bias=True):
        super(MaskedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        assert mask_type in {'A', 'B'}
        self.register_buffer('mask', self.weight.data.clone())
        _, _, height, width = self.weight.size()
        self.mask.fill_(1)
        yc, xc = height // 2, width // 2
        self.mask[:, :, yc, xc+ (mask_type == 'B'):] = 0
        self.mask[:, :, yc + 1:] = 0

    def forward(self, x):
        self.weight.data *= self.mask
        return super(MaskedConv2d, self).forward(x)

In [None]:
class PixelCNN(nn.Module):
  def __init__(self, input_channels=3, n_filters=64, kernel_size=7, n_layers=8) -> None:
      super(PixelCNN, self).__init__()
      layers = []
      layers.append(MaskedConv2d('A', input_channels, n_filters, kernel_size, 1, padding=kernel_size//2, bias=False))
      layers.append(nn.BatchNorm2d(n_filters))
      layers.append(nn.ReLU(True))

      for i in range(n_layers -1):
          layers.append(MaskedConv2d('B', n_filters, n_filters, kernel_size, 1, padding=kernel_size//2, bias=False))
          layers.append(nn.BatchNorm2d(n_filters))
          layers.append(nn.ReLU(True))
          # nn.Conv2d(in_channels=input_channels,  out_channels=input_channels*256, kernel_size=7, padding=3)
      layers.append(nn.Conv2d(in_channels=n_filters, out_channels=input_channels * 256, kernel_size=1))
      self.net = nn.Sequential(*layers)

  def forward(self, x):
      return self.net(x)

In [38]:
import torch.nn.functional as F

def training_model(model, train_loader, optimizer, device, epochs=10, loss_function=F.cross_entropy):
    # Detect the channels from the input
    input_channels = next(iter(train_loader))[0].shape[1]

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x, _ in train_loader:
            x_input = x.to(device).float()
            logits = model(x_input)
            if loss_function.__name__ == "discretized_logistic_loss":
                # ✅ PixelCNN++ path: do NOT reshape logits
                loss = loss_function(x_input, logits)

            else:
                # ✅ PixelCNN path: standard cross entropy loss
                x_int = (x * 255).long()
                batch_size, _, height, width = logits.shape
                logits = logits.view(batch_size, x.shape[1], 256, height, width)
                logits = logits.permute(0, 3, 4, 1, 2).contiguous()
                logits = logits.view(-1, 256)
                targets = x_int.view(-1)
                loss = loss_function(logits, targets)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

In [None]:
transform = transforms.Compose([
    transforms.ToTensor()
])
train_loader = data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True
)
sample_batch, _ = next(iter(train_loader))
input_channels = sample_batch.shape[1]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
minst_model = PixelCNN(input_channels=input_channels, n_filters=64).to(device)
optimizer = torch.optim.Adam(minst_model.parameters(), lr=1e-3)
training_model(minst_model, train_loader, optimizer, device, epochs=10)

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.11MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.27MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.47MB/s]


Epoch 1: Loss = 1.3173
Epoch 2: Loss = 0.7900
Epoch 3: Loss = 0.7625
Epoch 4: Loss = 0.7494
Epoch 5: Loss = 0.7405
Epoch 6: Loss = 0.7334
Epoch 7: Loss = 0.7272
Epoch 8: Loss = 0.7213
Epoch 9: Loss = 0.7147
Epoch 10: Loss = 0.7073


In [None]:
import torchvision.utils as vutils

def sample(model, img_size=(28, 28), n_channels=1, n_samples=64):
    model.eval()
    samples = torch.zeros(n_samples, n_channels, *img_size).to(device)
    with torch.no_grad():
        for i in range(img_size[0]):
            for j in range(img_size[1]):
              for c in range(n_channels):
                logits = model(samples)
                logits = logits.view(n_samples, n_channels, 256, img_size[0], img_size[1])
                probs = F.softmax(logits[:, c, :, i, j], dim=1)  # [n_samples, 256]
                pixel_values = torch.multinomial(probs, num_samples=1).squeeze(-1)  # Sample pixel
                samples[:, c, i, j] = pixel_values.float() / 255.0
    return samples

In [None]:
samples = sample(minst_model, img_size=(28, 28), n_channels=1, n_samples=64)
vutils.save_image(samples, "pixelcnn_mnist_samples1.png", nrow=8, padding=2)

In [21]:
transform = transforms.Compose([
    transforms.ToTensor()
])
train_loader = data.DataLoader(
    datasets.CIFAR10('data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True
)
sample_batch, _ = next(iter(train_loader))
input_channels = sample_batch.shape[1]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cifar_model = PixelCNN(input_channels=input_channels, n_filters=192, n_layers=15).to(device)
optimizer = torch.optim.Adam(cifar_model.parameters(), lr=1e-3)
training_model(cifar_model, train_loader, optimizer, device, epochs=10)

Epoch 1: Loss = 5.2068


In [9]:
samples = sample(cifar_model, img_size=(32, 32), n_channels=3, n_samples=64)
vutils.save_image(samples, "pixelcnn_cifar10_samples.png", nrow=8, padding=2)

In [23]:
class GatedMaskedConv2d(MaskedConv2d):
    def __init__(self, mask_type, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(GatedMaskedConv2d, self).__init__(mask_type, in_channels, out_channels * 2, kernel_size, stride, padding, bias)

    def forward(self, x):
        out = super(GatedMaskedConv2d, self).forward(x)
        out_filter, out_gate = out.chunk(2, dim=1)
        return torch.tanh(out_filter) * torch.sigmoid(out_gate)

In [24]:
class PixelCNNpp(PixelCNN):
    def __init__(self, input_channels=3, n_filters=64, kernel_size=7, n_layers=8, nr_mix=5):
        super(PixelCNNpp, self).__init__(input_channels, n_filters, kernel_size, n_layers)
        self.nr_mix = nr_mix
        num_params = nr_mix * (3 * input_channels + (input_channels > 1) * 3)
        self.net[-1] = nn.Conv2d(n_filters, num_params, kernel_size=1)

    def forward(self, x):
        return self.net(x)

In [39]:
def discretized_logistic_loss(x, l):
    """
    Simplified PixelCNN++ loss with a single logistic component.
    x: [B,C,H,W] in [-1,1]
    l: logits [B,C*3,H,W] -> mean, log_scale, logit_prob (no mixture)
    """
    B, C, H, W = x.size()
    l = l.permute(0, 2, 3, 1)  # [B,H,W,C*3]

    means = l[..., :C]
    log_scales = torch.clamp(l[..., C:2*C], min=-7.)

    x = x.permute(0, 2, 3, 1)
    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)

    plus_in = inv_stdv * (centered_x + 1./255.)
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1./255.)
    cdf_min = torch.sigmoid(min_in)

    log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = torch.log((1. - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min

    log_probs = torch.where(x < -0.999, log_cdf_plus,
                            torch.where(x > 0.999, log_one_minus_cdf_min,
                                        torch.log(cdf_delta.clamp(min=1e-12))))

    return -log_probs.mean()

In [None]:
# transform = transforms.Compose([
#     transforms.ToTensor()
# ])
# train_loader = data.DataLoader(
#     datasets.CIFAR10('data', train=True, download=True, transform=transform),
#     batch_size=128, shuffle=True
# )
# sample_batch, _ = next(iter(train_loader))
# input_channels = sample_batch.shape[1]

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
cifar_model_pp = PixelCNNpp(input_channels=input_channels, n_filters=192, n_layers=15, nr_mix=10).to(device)
optimizer = torch.optim.Adam(cifar_model_pp.parameters(), lr=1e-3)
training_model(cifar_model_pp, train_loader, optimizer, device, epochs=10, loss_function=discretized_logistic_loss)

Epoch 1: Loss = 4.9278
Epoch 2: Loss = 4.6382
Epoch 3: Loss = 3.8745
Epoch 4: Loss = 3.6185


In [None]:
samples = sample(cifar_model_pp, img_size=(32, 32), n_channels=3, n_samples=64)
vutils.save_image(samples, "pixelcnn_cifar10pp_samples1.png", nrow=8, padding=2)