<a href="https://colab.research.google.com/github/testgithubprecious/Ml_projects/blob/main/PixelCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the dependencies if they are not already present: pip install torch torchvision matplotlib

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# ------------------------------
# Masked Convolution Layer
# ------------------------------
class MaskedConv2d(nn.Conv2d):
    """2-D convolution whose kernel is multiplied by a fixed binary mask.

    The mask zeroes every kernel tap that lies below the kernel's centre row,
    and every tap to the right of the centre within the centre row, so each
    output position only combines inputs above it or to its left.

    * Mask type ``"A"`` additionally zeroes the centre tap itself.
    * Mask type ``"B"`` keeps the centre tap.

    Args:
        mask_type: ``"A"`` or ``"B"`` (see above).
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.

    Raises:
        ValueError: if ``mask_type`` is neither ``"A"`` nor ``"B"``.

    NOTE(review): ``forward`` hands ``self.padding`` straight to
    ``nn.functional.conv2d``, so a non-default ``padding_mode`` passed to the
    constructor does not appear to be applied -- confirm before relying on it.
    """

    def __init__(self, mask_type, *args, **kwargs):
        # Raise explicitly instead of asserting: asserts are removed under
        # ``python -O``, which would let a bad value fall through and silently
        # produce the type-"A" mask.
        if mask_type not in {"A", "B"}:
            raise ValueError(f"mask_type must be 'A' or 'B', got {mask_type!r}")
        super().__init__(*args, **kwargs)

        _, _, h, w = self.weight.shape
        yc, xc = h // 2, w // 2

        # Type "B" may use the centre tap, so its blocked region starts one
        # column further to the right than for type "A".
        first_blocked_col = xc + 1 if mask_type == "B" else xc

        mask = torch.ones_like(self.weight)
        mask[:, :, yc, first_blocked_col:] = 0  # centre row: centre/right taps
        mask[:, :, yc + 1:] = 0                 # every row below the centre

        # A buffer (not a parameter): it follows .to(device) and is stored in
        # the state_dict, but is never updated by the optimizer.
        self.register_buffer("mask", mask)

    def forward(self, x):
        """Convolve ``x`` with the masked kernel and return the result."""
        # Mask out-of-place so the stored weights are never overwritten.
        weight = self.weight * self.mask
        return nn.functional.conv2d(
            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )


# ------------------------------
# Simple PixelCNN Model
# ------------------------------
class PixelCNN(nn.Module):
    """Stack of masked convolutions producing per-pixel class logits.

    Layout of ``self.net``:

    1. one type-"A" ``MaskedConv2d`` followed by ``ReLU``;
    2. ``n_layers - 2`` type-"B" ``MaskedConv2d`` layers, each followed by
       ``ReLU`` (none when ``n_layers <= 2``);
    3. one 1x1 type-"B" ``MaskedConv2d`` mapping to ``n_classes`` channels.

    Args:
        input_channels: number of channels of the input image.
        n_filters: number of feature maps in every hidden layer.
        kernel_size: kernel size of all layers except the final 1x1 layer.
            Padding is ``kernel_size // 2``, which keeps the spatial size
            unchanged for odd kernel sizes.
        n_layers: controls the depth as described in the layout above.
        n_classes: number of output channels (logits per pixel). The default
            of 256 corresponds to 8-bit grayscale intensities.
    """

    def __init__(self, input_channels=1, n_filters=64, kernel_size=7, n_layers=7,
                 n_classes=256):
        super().__init__()
        padding = kernel_size // 2

        # The first layer uses mask "A", which also hides the centre tap.
        layers = [
            MaskedConv2d("A", input_channels, n_filters, kernel_size, padding=padding),
            nn.ReLU(),
        ]
        # Hidden layers use mask "B", which keeps the centre tap.
        for _ in range(n_layers - 2):
            layers += [
                MaskedConv2d("B", n_filters, n_filters, kernel_size, padding=padding),
                nn.ReLU(),
            ]
        # 1x1 projection to one logit per class; no activation afterwards.
        layers.append(MaskedConv2d("B", n_filters, n_classes, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the output of the layer stack applied to ``x``."""
        return self.net(x)


# ------------------------------
# Dataset and Dataloader
# ------------------------------
# ToTensor scales pixel values to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, downloaded into ./data if needed.
dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 32 images.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)


# ------------------------------
# Model, Loss, Optimizer
# ------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# PixelCNN with its default hyper-parameters, moved to the chosen device.
model = PixelCNN()
model = model.to(device)

# Cross-entropy over the model's output channels, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# ------------------------------
# One Training Step (Demo)
# ------------------------------
model.train()
# Only the images are used; the class labels are discarded.
for imgs, _ in dataloader:
    imgs = imgs.to(device)

    # Turn the [0, 1] floats back into integer class indices in [0, 255].
    # Round before the integer cast: .long() truncates, so a value that is a
    # hair below an integer would otherwise land on the wrong class.
    targets = (imgs * 255).round().long().squeeze(1)  # (B, H, W)

    outputs = model(imgs)  # model is fed the normalized [0, 1] images
    loss = criterion(outputs, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"PixelCNN Training Loss (1 batch): {loss.item():.4f}")
    break  # demo only: stop after the first batch