In [7]:
"""
PixelCNN / Row LSTM / Diagonal BiLSTM (PixelRNN) - PyTorch implementation skeleton

This file contains:
- MaskedConv2d (mask type 'A' and 'B')
- PixelCNN (stack of masked convs)
- RowLSTM layer and stack
- Diagonal BiLSTM layer (skew/unskew + bidirectional LSTM along diagonals)
- CIFAR-10 dataloaders, training loop, evaluation (bits/dim), plotting

This is a reasonably complete, runnable starting point. It prioritizes clarity and
correctness over absolute performance. For larger-scale experiments you should
optimize the recurrent kernels and batch the diagonal computation.

Notes / caveats:
- The Diagonal BiLSTM here uses Python loops over diagonals for clarity, so it is
  much slower than highly-optimized implementations.
- MaskedConv2d always applies the spatial (raster-order) mask. When both channel
  counts are multiples of 3 it also applies an R->G->B channel-order mask at the
  centre tap of mask-'B' layers; mask 'A' excludes the centre pixel entirely
  (all channels).

"""

import math
import os
import time
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# ----------------------------- Utilities ---------------------------------

def to_device(x):
    """Move ``x`` to the GPU when CUDA is available; otherwise return it unchanged."""
    if torch.cuda.is_available():
        return x.cuda()
    return x

# Default device for models and batches in this notebook.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --------------------------- Masked Conv2d -------------------------------
class MaskedConv2d(nn.Conv2d):
    """Conv2d whose kernel is masked to enforce raster-scan autoregressive ordering.

    Spatial mask: kernel taps below the centre row, and to the right of the centre
    in the centre row, are zeroed. Mask 'A' additionally zeros the centre tap for
    every (output, input) channel pair, so the output at (r, c) never depends on
    any channel of pixel (r, c). Mask 'B' keeps the centre tap.

    Channel mask: when both in_channels and out_channels are multiples of 3, the
    centre tap of a mask-'B' layer is further restricted with an R->G->B ordering
    in which a channel's colour is ``index % 3``: output colour k may read input
    colour j at the centre only if j <= k. Under mask 'A' the centre is already
    fully zero, so the channel rule has no additional effect there.

    The mask is applied as ``weight * mask`` inside forward instead of by
    overwriting ``weight.data``. Masked taps therefore receive exactly zero
    gradient and cannot inflate gradient-norm clipping or optimizer statistics.

    Reference: van den Oord et al. PixelCNN/PixelRNN.
    """
    def __init__(self, in_channels, out_channels, kernel_size, mask_type,
                 stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size,
                         stride, padding, dilation, groups, bias)
        assert mask_type in ("A", "B"), "mask_type must be 'A' or 'B'"
        kH, kW = self.kernel_size
        center_h, center_w = kH // 2, kW // 2

        mask = torch.ones_like(self.weight.detach())
        mask[:, :, center_h + 1:, :] = 0           # rows below the centre
        mask[:, :, center_h, center_w + 1:] = 0    # right of the centre, same row
        if mask_type == 'A':
            # exclude the current pixel entirely (every channel pair)
            mask[:, :, center_h, center_w] = 0
        elif in_channels % 3 == 0 and out_channels % 3 == 0:
            # RGB ordering at the centre tap: out colour k sees in colour j iff j <= k
            out_colour = torch.arange(mask.shape[0]) % 3
            in_colour = torch.arange(mask.shape[1]) % 3
            allowed = (in_colour.unsqueeze(0) <= out_colour.unsqueeze(1)).to(mask.dtype)
            mask[:, :, center_h, center_w] *= allowed
        # Buffer (not a parameter): saved in state_dict and moved by .to(device).
        self.register_buffer('mask', mask)

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

# ----------------------------- PixelCNN ---------------------------------
# --------------------- Improved PixelCNN (residual, deeper) ---------------------
class ResidualBlock(nn.Module):
    """PixelCNN residual block: ``x + Conv1x1(ReLU(MaskedConvB(ReLU(x))))``.

    The masked k x k conv uses mask 'B' with 'same' padding (kernel_size // 2),
    so the spatial size is preserved for odd kernel sizes.

    The first ReLU must not be in-place: it receives the block input ``x``, which
    is also the skip-connection operand. An in-place ReLU there overwrites ``x``
    before the addition and silently turns the residual into ``relu(x)``.
    """
    def __init__(self, nr_filters, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.net = nn.Sequential(
            nn.ReLU(),  # NOT in-place: x is reused by the skip connection
            MaskedConv2d(nr_filters, nr_filters, kernel_size=kernel_size, mask_type='B', padding=pad),
            nn.ReLU(inplace=True),  # safe: acts on a fresh conv output
            nn.Conv2d(nr_filters, nr_filters, kernel_size=1)
        )

    def forward(self, x):
        return x + self.net(x)  # residual add


class PixelCNN(nn.Module):
    """Residual PixelCNN producing per-pixel 256-way categorical logits.

    Layout: mask-'A' ks_first x ks_first conv -> nr_residual ResidualBlocks
    (mask 'B') -> ReLU, 1x1 conv, ReLU -> 1x1 conv to in_channels * 256 logits.

    forward(x): float tensor [B, in_channels, H, W] -> logits
    [B, in_channels * 256, H, W] (spatial size preserved when ks_first is odd).
    Targets are kept separately as integer labels 0..255 by the training code.
    """
    def __init__(self, in_channels=3, nr_residual=15, nr_filters=128, ks_first=7):
        super().__init__()
        # Mask A on the first layer: the output at a pixel never sees that pixel.
        self.first = MaskedConv2d(in_channels, nr_filters, kernel_size=ks_first,
                                  mask_type='A', padding=ks_first // 2)
        self.res_blocks = nn.ModuleList(
            ResidualBlock(nr_filters, kernel_size=3) for _ in range(nr_residual)
        )
        self.final = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(nr_filters, nr_filters, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.out_conv = nn.Conv2d(nr_filters, in_channels * 256, kernel_size=1)

    def forward(self, x):
        features = self.first(x)
        for res_block in self.res_blocks:
            features = res_block(features)
        return self.out_conv(self.final(features))


# ----------------------------- Row LSTM ---------------------------------
class RowLSTMCell(nn.Module):
    """Row LSTM layer (van den Oord et al., 2016): a convolutional LSTM that
    scans the image top to bottom, one row at a time.

    Input-to-state: masked 1 x kernel_size convolution along the row (mask
        ``mask_type``), computed for all rows at once -> 4*hidden gate maps.
    State-to-state: 1-D convolution of width kernel_size over the previous row's
        hidden state, so state propagates sideways as well as downwards.

    Use mask_type='A' when ``x`` is the raw image; on features that are already
    causal (e.g. the output of a mask-'A' conv) keep the default 'B'.
    kernel_size is expected to be odd so the 'same' padding preserves width.

    forward(x): [B, in_channels, H, W] -> [B, in_channels, H, W], a 1x1
    projection of the hidden states meant to be added residually by the caller.
    """
    def __init__(self, in_channels, hidden_channels, kernel_size=3, mask_type='B'):
        super().__init__()
        self.hidden_channels = hidden_channels
        pad = kernel_size // 2
        # Row-wise causal input-to-state: only columns <= c of the current row.
        self.input_conv = MaskedConv2d(in_channels, 4 * hidden_channels,
                                       kernel_size=(1, kernel_size),
                                       mask_type=mask_type, padding=(0, pad))
        # State-to-state over the whole previous row (all of it is in the past).
        self.state_conv = nn.Conv1d(hidden_channels, 4 * hidden_channels,
                                    kernel_size=kernel_size, padding=pad)
        # 1x1 projection back to in_channels (for the caller's residual add)
        self.out_conv = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)

    def forward(self, x):
        B, _, H, W = x.shape
        gates_x = self.input_conv(x)  # [B, 4h, H, W]

        h = x.new_zeros((B, self.hidden_channels, W))
        c = x.new_zeros((B, self.hidden_channels, W))
        rows = []
        for r in range(H):
            gates = gates_x[:, :, r, :] + self.state_conv(h)  # [B, 4h, W]
            i_gate, f_gate, o_gate, g_gate = gates.chunk(4, dim=1)
            c = torch.sigmoid(f_gate) * c + torch.sigmoid(i_gate) * torch.tanh(g_gate)
            h = torch.sigmoid(o_gate) * torch.tanh(c)
            rows.append(h)

        out_map = torch.stack(rows, dim=2)  # [B, h, H, W]
        return self.out_conv(out_map)

class RowLSTMStack(nn.Module):
    """Row LSTM PixelRNN: mask-'A' conv -> residual Row LSTM layers -> logits.

    The first ks_first x ks_first mask-'A' conv turns the image into features
    that, at every position, depend only on strictly earlier pixels. All later
    layers and the residual stream operate on those features only, so the raw
    current pixel never reaches its own logits. (Adding ``x`` itself to the
    residual stream would leak the target straight into the prediction.)

    forward(x): float [B, in_channels, H, W] -> logits [B, in_channels * 256, H, W].
    """
    def __init__(self, in_channels=3, hidden_channels=64, n_layers=3, ks_first=7):
        super().__init__()
        self.input_conv = MaskedConv2d(in_channels, hidden_channels, kernel_size=ks_first,
                                       mask_type='A', padding=ks_first // 2)
        self.layers = nn.ModuleList(
            [RowLSTMCell(hidden_channels, hidden_channels) for _ in range(n_layers)]
        )
        self.out_conv = nn.Conv2d(hidden_channels, in_channels * 256, kernel_size=1)

    def forward(self, x):
        h = self.input_conv(x)
        for layer in self.layers:
            h = h + layer(h)  # residual over causal features
        return self.out_conv(F.relu(h))

# ------------------------- Diagonal BiLSTM -------------------------------

def skew_tensor(x):
    """Skew a feature map so that its diagonals line up as columns.

    Input x: [B, C, H, W]
    Output: [B, C, H, W + H - 1], zero-filled, where row r of x is written
    starting at column r (i.e. shifted right by r positions).
    """
    batch, chans, n_rows, n_cols = x.shape
    skewed = x.new_zeros((batch, chans, n_rows, n_cols + n_rows - 1))
    for row, row_values in enumerate(x.unbind(dim=2)):
        skewed[:, :, row, row:row + n_cols] = row_values
    return skewed


def unskew_tensor(x, orig_W):
    """Inverse of the skew: undo the per-row right shift.

    Input x: [B, C, H, W + H - 1] (skewed map); orig_W: width W of the original map.
    Output: [B, C, H, orig_W], where output row r is read from columns
    r .. r + orig_W - 1 of skewed row r.
    """
    batch, chans, n_rows, _ = x.shape
    restored = x.new_zeros((batch, chans, n_rows, orig_W))
    for row in range(n_rows):
        restored[:, :, row] = x[:, :, row, row:row + orig_W]
    return restored

class DiagonalBiLSTMLayer(nn.Module):
    """Diagonal BiLSTM layer (van den Oord et al., 2016).

    The input map is skewed so every image diagonal becomes a column. A
    convolutional LSTM then scans the skewed map column by column:
      - input-to-state: 1x1 convolution (one set of 4 gates per direction);
      - state-to-state: kernel-2 convolution over the previous column, so the
        state at original position (r, c) is computed from the states at
        (r, c-1) and (r-1, c).
    A second scan runs on the horizontally flipped map (right-to-left). Its
    output is flipped back and shifted down by one row before being added, so
    position (r, c) never receives information from row r to its right.

    The layer is causal only if ``x`` is itself causal at every position, e.g.
    the output of a mask-'A' convolution; do not feed raw pixels directly.

    forward(x): [B, in_channels, H, W] -> [B, in_channels, H, W] (a 1x1
    projection of the summed hidden states, meant to be added residually).
    """
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        # 1x1 input-to-state gates for both directions: 2 x 4 x hidden channels
        self.input_conv = nn.Conv2d(in_channels, 8 * hidden_channels, kernel_size=1)
        # column-wise state-to-state convs (kernel 2 covers rows r-1 and r)
        self.state_conv_l2r = nn.Conv1d(hidden_channels, 4 * hidden_channels, kernel_size=2)
        self.state_conv_r2l = nn.Conv1d(hidden_channels, 4 * hidden_channels, kernel_size=2)
        # projection back to input channels
        self.out_conv = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)

    def extract_diagonals(self, x):
        """Return the anti-diagonals (r + c = d) of x as sequences [B, len, C].

        Not used by forward; kept for backward compatibility.
        """
        B, C, H, W = x.shape
        diags = []
        lengths = []
        for d in range(H + W - 1):
            elems = []
            for r in range(H):
                c = d - r
                if 0 <= c < W:
                    elems.append(x[:, :, r, c])  # [B, C]
            if len(elems) > 0:
                seq = torch.stack(elems, dim=1)  # [B, len, C]
                diags.append(seq)
                lengths.append(seq.size(1))
        return diags, lengths

    def _scan(self, gates_x, state_conv):
        """Run the column-by-column LSTM over skewed input gates.

        gates_x: [B, 4*hidden, H, Ws] input-to-state gates on the skewed grid.
        Returns hidden states [B, hidden, H, Ws] on the same skewed grid.
        """
        B, _, H, Ws = gates_x.shape
        h = gates_x.new_zeros((B, self.hidden_channels, H))
        c = gates_x.new_zeros((B, self.hidden_channels, H))
        columns = []
        for j in range(Ws):
            # Pad one zero row on top so output row r mixes h[r-1] (the pixel
            # above, in original coordinates) and h[r] (the pixel to the left).
            rec = state_conv(F.pad(h, (1, 0)))  # [B, 4h, H]
            i_gate, f_gate, o_gate, g_gate = (gates_x[:, :, :, j] + rec).chunk(4, dim=1)
            c = torch.sigmoid(f_gate) * c + torch.sigmoid(i_gate) * torch.tanh(g_gate)
            h = torch.sigmoid(o_gate) * torch.tanh(c)
            columns.append(h)
        return torch.stack(columns, dim=3)

    def forward(self, x):
        B, C, H, W = x.shape
        gates_l2r, gates_r2l = self.input_conv(x).chunk(2, dim=1)

        # left-to-right scan
        h_l2r = unskew_tensor(self._scan(skew_tensor(gates_l2r), self.state_conv_l2r), W)

        # right-to-left scan: flip, scan, flip back
        flipped = torch.flip(gates_r2l, dims=[3])
        h_r2l = unskew_tensor(self._scan(skew_tensor(flipped), self.state_conv_r2l), W)
        h_r2l = torch.flip(h_r2l, dims=[3])
        # shift down one row so (r, c) only sees rows < r from this direction
        h_r2l = F.pad(h_r2l, (0, 0, 1, 0))[:, :, :H, :]

        return self.out_conv(h_l2r + h_r2l)

# ----------------------------- Full wrappers ----------------------------
class PixelRNN_PixelCNN(nn.Module):
    """Thin wrapper around PixelCNN; keyword arguments are forwarded to it
    (e.g. nr_residual, nr_filters, ks_first)."""
    def __init__(self, **model_kwargs):
        super().__init__()
        self.model = PixelCNN(**model_kwargs)

    def forward(self, x):
        return self.model(x)

class PixelRNN_RowLSTM(nn.Module):
    """Thin wrapper around RowLSTMStack; keyword arguments are forwarded to it
    (e.g. hidden_channels, n_layers)."""
    def __init__(self, **model_kwargs):
        super().__init__()
        self.model = RowLSTMStack(**model_kwargs)

    def forward(self, x):
        return self.model(x)

class PixelRNN_DiagBiLSTM(nn.Module):
    """Diagonal BiLSTM PixelRNN: mask-'A' conv -> residual Diagonal BiLSTM layers -> logits.

    The mask-'A' first layer makes every feature depend only on strictly earlier
    pixels. The diagonal layers assume causal input features and must not see raw
    pixels directly.

    forward(x): float [B, in_channels, H, W] -> logits [B, in_channels * 256, H, W].
    """
    def __init__(self, in_channels=3, hidden_channels=64, n_layers=1, ks_first=7):
        super().__init__()
        self.input_conv = MaskedConv2d(in_channels, hidden_channels, kernel_size=ks_first,
                                       mask_type='A', padding=ks_first // 2)
        self.layers = nn.ModuleList(
            [DiagonalBiLSTMLayer(hidden_channels, hidden_channels) for _ in range(n_layers)]
        )
        self.out_conv = nn.Conv2d(hidden_channels, in_channels * 256, kernel_size=1)

    def forward(self, x):
        h = self.input_conv(x)
        for layer in self.layers:
            h = h + layer(h)  # residual over causal features
        return self.out_conv(F.relu(h))

# ----------------------------- Training ---------------------------------

def logits_to_nll(logits, target):
    """Summed categorical negative log-likelihood of integer pixels, in nats.

    logits: [B, C*256, H, W], read as C consecutive blocks of 256 class logits
        (channel c uses logit channels c*256 .. c*256+255).
    target: [B, C, H, W] long tensor with values 0..255.
    Returns a scalar tensor: the NLL summed over batch, channels and pixels.

    cross_entropy is given the K-dimensional form ([B, 256, C, H, W] vs
    [B, C, H, W]), so every logit vector is scored against the target at exactly
    the same (channel, row, column). Flattening logits and targets separately
    with different axis orders would pair them with the wrong pixels.
    """
    B, C256, H, W = logits.shape
    C = target.shape[1]
    class_logits = logits.reshape(B, C, 256, H, W).transpose(1, 2)  # [B, 256, C, H, W]
    return F.cross_entropy(class_logits, target, reduction='sum')


def nll_to_bits_per_dim(nll, batch_size, C, H, W):
    """Convert a summed NLL in nats to average bits per dimension.

    bits/dim = (nll / ln 2) / (batch_size * C * H * W)
    """
    ln2 = math.log(2)
    total_bits = nll / ln2
    n_dims = batch_size * C * H * W
    return total_bits / n_dims


def to_long_tensor(x):
    """Convert a ToTensor() image (floats in [0, 1]) back to integer pixels 0..255.

    ToTensor divides uint8 values by 255 in floating point, and multiplying back
    can land slightly below the original integer (e.g. 254.99998). ``.long()``
    truncates toward zero, so we round first to recover the exact value.
    """
    return torch.round(x * 255).long()

def get_dataloaders(batch_size=64, root='./data', num_workers=0):
    """Build CIFAR-10 loaders that yield integer pixel tensors.

    Images are converted with ToTensor() followed by ``to_long_tensor``, so each
    batch of images is a long tensor with values 0..255.

    Args:
        batch_size: mini-batch size for both loaders.
        root: dataset directory; CIFAR-10 is downloaded there if missing.
        num_workers: DataLoader worker processes. Defaults to 0 because worker
            processes can cause multiprocessing issues on Windows notebooks.

    Returns:
        (train_loader, val_loader). ``val_loader`` iterates the official CIFAR-10
        *test* split, unshuffled.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        to_long_tensor
    ])
    train = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader



def _to_model_input(xb):
    """Map integer pixels in [0, 255] to floats in [-1, 1] for the network.

    Only the network input is rescaled; targets stay as the raw integers.
    Raw 0..255 inputs make first-layer activations and initial logits very
    large, which destabilises early training.
    """
    return xb.float() / 127.5 - 1.0


def train_epoch(model, opt, loader, device, max_grad_norm=1.0):
    """Run one training epoch.

    Args:
        model: maps float input [B, C, H, W] to logits [B, C*256, H, W].
        opt: optimizer over the model's parameters.
        loader: yields (images, labels); images hold integer pixels 0..255 and
            labels are ignored.
        device: device the batches are moved to.
        max_grad_norm: global gradient-norm clipping threshold.

    Returns:
        (total_nll, total_samples, elapsed_seconds); total_nll is the NLL in
        nats summed over the whole epoch.
    """
    model.train()
    total_nll = 0.0
    total_samples = 0
    t0 = time.perf_counter()
    for xb, _ in loader:
        xb = xb.to(device=device, dtype=torch.long)  # integer targets 0..255
        logits = model(_to_model_input(xb))
        loss = logits_to_nll(logits, xb)

        opt.zero_grad()
        loss.backward()
        # gradient clipping (important for stability, esp. for the LSTM variants)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        opt.step()  # exactly one parameter update per batch

        total_nll += loss.item()
        total_samples += xb.size(0)
    return total_nll, total_samples, time.perf_counter() - t0


def eval_epoch(model, loader, device):
    """Evaluate the summed NLL (nats) over ``loader`` without gradient tracking.

    Uses the same input preprocessing as ``train_epoch``.
    Returns (total_nll, total_samples).
    """
    model.eval()
    total_nll = 0.0
    total_samples = 0
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(device=device, dtype=torch.long)
            logits = model(_to_model_input(xb))
            loss = logits_to_nll(logits, xb)
            total_nll += loss.item()
            total_samples += xb.size(0)
    return total_nll, total_samples


In [8]:
def init_weights_xavier(m):
    """Xavier-uniform init for Conv2d/Linear weights with zeroed biases.

    Intended for ``model.apply``. Subclasses of Conv2d (such as masked convs)
    are included; every other module type is left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_uniform_(m.weight)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.zeros_(bias)

In [9]:
# ----------------------------- Notebook Run -------------------------------
# Parameters (can be edited directly in notebook cells)
model_choice = 'pixelcnn'  # 'pixelcnn', 'rowlstm', or 'diagbilstm'
batch_size = 64
epochs = 50
lr = 1e-3
lr_step_size = None  # e.g. 10 to halve the learning rate every 10 epochs; None disables it
seed = 0
IMG_C, IMG_H, IMG_W = 3, 32, 32  # CIFAR-10 image shape, used for bits/dim

torch.manual_seed(seed)  # seeds weight init and DataLoader shuffling

# Dataloaders
train_loader, val_loader = get_dataloaders(batch_size=batch_size)

# Model selection
model_builders = {
    'pixelcnn': PixelRNN_PixelCNN,
    'rowlstm': PixelRNN_RowLSTM,
    'diagbilstm': PixelRNN_DiagBiLSTM,
}
if model_choice not in model_builders:
    raise ValueError(f"unknown model_choice {model_choice!r}; expected one of {sorted(model_builders)}")
model = model_builders[model_choice]().to(DEVICE)
if model_choice == 'pixelcnn':
    # Xavier init only for the PixelCNN; the recurrent models keep PyTorch defaults.
    model.apply(init_weights_xavier)

# Optimizer: RMSprop as in the paper (Adam is a common alternative).
opt = torch.optim.RMSprop(model.parameters(), lr=lr, alpha=0.95, eps=1e-8)
scheduler = (torch.optim.lr_scheduler.StepLR(opt, step_size=lr_step_size, gamma=0.5)
             if lr_step_size else None)

# Training loop
history = {'train_bpd': [], 'val_bpd': []}
for ep in range(epochs):
    t_nll, t_samples, t_time = train_epoch(model, opt, train_loader, DEVICE)
    if scheduler is not None:
        scheduler.step()

    train_bpd = nll_to_bits_per_dim(t_nll, t_samples, IMG_C, IMG_H, IMG_W)
    v_nll, v_samples = eval_epoch(model, val_loader, DEVICE)
    val_bpd = nll_to_bits_per_dim(v_nll, v_samples, IMG_C, IMG_H, IMG_W)
    history['train_bpd'].append(train_bpd)
    history['val_bpd'].append(val_bpd)
    print(f"Epoch {ep+1}/{epochs}  train bpd: {train_bpd:.4f}  val bpd: {val_bpd:.4f}  time: {t_time:.1f}s")

# Plot results
fig, ax = plt.subplots()
ax.plot(history['train_bpd'], label='train bpd')
ax.plot(history['val_bpd'], label='val bpd')
ax.set(xlabel='epoch', ylabel='bits/dim', title=f"Training {model_choice}")
ax.legend()
plt.show()

print('Training complete.')

KeyboardInterrupt: 