In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
# Import einops if not already imported
from einops import rearrange

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class MaskedConv2d(nn.Conv2d):
    """2-D convolution whose kernel taps are zeroed by a fixed binary mask.

    The mask keeps every kernel row above the centre row and, on the centre
    row, the taps to the left of the centre. Everything below the centre row
    is removed. The centre tap itself is controlled by ``mask_type``:

    * ``'A'`` -- the centre tap is masked out (the output at a position does
      not read the input at that same position).
    * ``'B'`` -- the centre tap is kept.

    Args:
        mask_type: Either ``'A'`` or ``'B'`` (see above).
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: Forwarded unchanged to ``nn.Conv2d``.

    Raises:
        ValueError: If ``mask_type`` is neither ``'A'`` nor ``'B'``. Without
            this check any other value would silently behave like ``'A'``.
    """

    def __init__(self, mask_type, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        if mask_type not in ('A', 'B'):
            raise ValueError(f"mask_type must be 'A' or 'B', got {mask_type!r}")
        super(MaskedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

        # Registered as a buffer so it follows the module across devices and is
        # stored in the state dict, but is never updated by the optimizer.
        self.register_buffer('mask', torch.ones_like(self.weight.data))
        _, _, kH, kW = self.weight.size()
        center_row, center_col = kH // 2, kW // 2

        # Centre row: drop taps from the centre ('A') or just after it ('B').
        first_blocked_col = center_col + (1 if mask_type == 'B' else 0)
        self.mask[:, :, center_row, first_blocked_col:] = 0
        # Drop every row below the centre row.
        self.mask[:, :, center_row + 1:] = 0

    def forward(self, x):
        """Zero the masked kernel taps in place, then run the ordinary convolution."""
        # Re-applied on every call because an optimizer step may have written
        # non-zero values back into the masked positions.
        self.weight.data *= self.mask
        return super(MaskedConv2d, self).forward(x)

In [5]:
class DiagonalLSTMCell(nn.Module):
    """Convolutional LSTM cell operating on 4-D feature maps.

    For each of the four gates (input ``i``, forget ``f``, output ``o`` and
    candidate ``g``) the pre-activation is the sum of

    * a 1x1 convolution of the input ``x`` (``W_*``),
    * a type-``'B'`` masked 3x3 convolution of the hidden state ``h`` (``U_*``),
    * a learned per-channel bias (``b_*``), broadcast over batch and space.

    Args:
        input_size: Number of channels of ``x``.
        hidden_size: Number of channels of the hidden and cell states.
    """

    # Gate order matters: it fixes the registration order of the sub-modules
    # and parameters (and therefore initialisation order and state-dict order).
    _GATES = ('i', 'f', 'o', 'g')

    def __init__(self, input_size, hidden_size):
        super(DiagonalLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Input-to-hidden projections.
        for gate in self._GATES:
            setattr(self, f'W_{gate}', nn.Conv2d(input_size, hidden_size, kernel_size=1))

        # Hidden-to-hidden projections (masked).
        for gate in self._GATES:
            setattr(self, f'U_{gate}', MaskedConv2d('B', hidden_size, hidden_size, kernel_size=3, padding=1))

        # Extra per-channel biases, initialised to zero.
        for gate in self._GATES:
            setattr(self, f'b_{gate}', nn.Parameter(torch.zeros(hidden_size)))

    def _pre_activation(self, gate, x, h):
        """Return ``W_gate(x) + U_gate(h) + b_gate`` for the named gate."""
        from_input = getattr(self, f'W_{gate}')(x)
        from_hidden = getattr(self, f'U_{gate}')(h)
        bias = getattr(self, f'b_{gate}').view(1, -1, 1, 1)
        return from_input + from_hidden + bias

    def forward(self, x, state):
        """Advance the cell by one step.

        Args:
            x: Input feature map.
            state: Pair ``(h, c)`` holding the previous hidden and cell states.

        Returns:
            Pair ``(h_new, c_new)`` with the updated hidden and cell states.
        """
        h_prev, c_prev = state

        input_gate = torch.sigmoid(self._pre_activation('i', x, h_prev))
        forget_gate = torch.sigmoid(self._pre_activation('f', x, h_prev))
        output_gate = torch.sigmoid(self._pre_activation('o', x, h_prev))
        candidate = torch.tanh(self._pre_activation('g', x, h_prev))

        # Standard LSTM update: blend the old cell state with the new candidate.
        c_new = forget_gate * c_prev + input_gate * candidate
        h_new = output_gate * torch.tanh(c_new)

        return h_new, c_new

In [6]:
class DiagonalLSTM(nn.Module):
    """Stack of ``DiagonalLSTMCell`` layers applied once to a feature map.

    The first cell receives the input ``x``; every later cell receives the
    hidden output of the cell before it.

    Args:
        input_size: Number of channels of the input to the first cell.
        hidden_size: Number of channels of every hidden/cell state.
        num_layers: Number of stacked cells.
    """

    def __init__(self, input_size, hidden_size, num_layers=1):
        super(DiagonalLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.cells = nn.ModuleList([DiagonalLSTMCell(
            input_size if i == 0 else hidden_size,
            hidden_size) for i in range(num_layers)])

    def forward(self, x, states=None):
        """Run ``x`` through every cell in order.

        Args:
            x: Input tensor of shape ``(batch, channels, height, width)``.
            states: Optional list with one ``(h, c)`` pair per layer. When
                omitted, all states start as zeros with the same dtype and
                device as ``x``.

        Returns:
            ``(output, new_states)`` where ``output`` is the hidden state of the
            last cell (or ``x`` itself if there are no cells) and ``new_states``
            is the list of ``(h, c)`` pairs produced by each layer.
        """
        batch_size, _, height, width = x.size()

        if states is None:
            # new_zeros inherits dtype *and* device from x, so half/double
            # inputs get matching states and nothing is allocated on the CPU
            # only to be copied to the accelerator.
            state_shape = (batch_size, self.hidden_size, height, width)
            states = [(x.new_zeros(state_shape), x.new_zeros(state_shape))
                      for _ in range(self.num_layers)]

        out = x  # Input to the next layer; stays defined even with zero layers.
        new_states = []
        for layer_idx, cell in enumerate(self.cells):
            h, c = states[layer_idx]
            out, new_c = cell(out, (h, c))
            new_states.append((out, new_c))

        return out, new_states


In [7]:
class PixelCNN(nn.Module):
    """Autoregressive image model producing per-pixel class logits.

    Pipeline: a type-``'A'`` masked 7x7 convolution, a ``DiagonalLSTM`` stack,
    and two 1x1 convolutions that emit ``input_channels * num_classes`` logits
    per spatial position.

    Args:
        input_channels: Number of image channels.
        hidden_size: Number of feature channels used inside the network.
        num_layers: Number of stacked LSTM cells.
        num_classes: Number of discrete intensity levels predicted per channel.
    """

    def __init__(self, input_channels=3, hidden_size=128, num_layers=2, num_classes=256):
        super(PixelCNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes

        # Initial convolution to process input
        self.input_conv = MaskedConv2d('A', input_channels, hidden_size, kernel_size=7, padding=3)

        # DiagonalLSTM layers
        self.lstm = DiagonalLSTM(hidden_size, hidden_size, num_layers)

        # Output convolutions
        self.output_conv1 = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
        self.output_conv2 = nn.Conv2d(hidden_size, input_channels * num_classes, kernel_size=1)

    def forward(self, x):
        """Compute class logits for every channel of every pixel.

        Args:
            x: Image batch of shape ``(batch, channels, height, width)``.
                ``channels`` must equal the ``input_channels`` the model was
                built with, otherwise the final reshape cannot succeed.

        Returns:
            Logits of shape ``(batch, channels, num_classes, height, width)``.
        """
        batch_size, channels, height, width = x.size()

        # Initial convolution
        x = self.input_conv(x)
        x = F.relu(x)

        # DiagonalLSTM
        x, _ = self.lstm(x)

        # Output layers
        x = F.relu(self.output_conv1(x))
        x = self.output_conv2(x)

        # Split the fused (channels * num_classes) axis so a softmax can be
        # taken over the class axis of each colour channel.
        x = x.view(batch_size, channels, self.num_classes, height, width)

        return x

    @torch.no_grad()
    def sample(self, batch_size=1, image_size=(28, 28), channels=1, device='cpu'):
        """Generate images one pixel at a time in raster order.

        For every (row, column, channel) position the model is run on the
        partially filled image, a class index is drawn from the softmax over
        that position's logits, and ``index / (num_classes - 1)`` is written
        into the image.

        Runs with gradient tracking disabled regardless of the caller's
        context: sampling needs ``height * width * channels`` forward passes
        and never back-propagates.

        Args:
            batch_size: Number of images to generate.
            image_size: ``(height, width)`` of the generated images.
            channels: Number of image channels; should match the model's
                ``input_channels``.
            device: Device on which the image tensor is created.

        Returns:
            Tensor of shape ``(batch_size, channels, height, width)`` with
            values in ``[0, 1]``.
        """
        # Allocate directly on the target device instead of CPU-then-copy.
        image = torch.zeros(batch_size, channels, *image_size, device=device)
        max_level = self.num_classes - 1

        # Generate image pixel by pixel
        for i in range(image_size[0]):
            for j in range(image_size[1]):
                for c in range(channels):
                    logits = self.forward(image)
                    probs = F.softmax(logits[:, c, :, i, j], dim=1)
                    # multinomial returns (batch, 1); drop only that axis so
                    # the batch axis survives even when batch_size == 1.
                    drawn = torch.multinomial(probs, 1).squeeze(1)
                    image[:, c, i, j] = drawn.float() / max_level

        return image


In [None]:
# Seed PyTorch's global RNG before anything stochastic happens (weight
# initialisation, DataLoader shuffling, multinomial sampling) so that
# Restart & Run All gives comparable results from run to run.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data loading
# ToTensor converts each image to a float tensor; per the torchvision docs the
# values are scaled to [0, 1], which the training loop rescales to class labels.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# MNIST training split, downloaded to ./data on first use.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

In [9]:
# Build the network for single-channel images with 256 intensity classes per
# pixel, then move its parameters and buffers to the selected device.
model = PixelCNN(
    input_channels=1,
    hidden_size=128,
    num_layers=2,
    num_classes=256,
)
# Last expression of the cell, so the module summary is displayed.
model.to(device)


PixelCNN(
  (input_conv): MaskedConv2d(1, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (lstm): DiagonalLSTM(
    (cells): ModuleList(
      (0-1): 2 x DiagonalLSTMCell(
        (W_i): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        (W_f): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        (W_o): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        (W_g): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
        (U_i): MaskedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (U_f): MaskedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (U_o): MaskedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (U_g): MaskedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
  )
  (output_conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
  (output_conv2): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
)

In [10]:
# Adam over every model parameter, paired with a cross-entropy loss that
# treats each pixel intensity as a classification target.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [11]:
# Number of full passes over train_loader made by the training loop.
epochs = 5

In [None]:
# Put the model in training mode before optimisation starts.
model.train()

for epoch in range(epochs):
    # Loss accumulated since the last progress report.
    running_loss = 0.0
    for i, batch in enumerate(train_loader):
        # Labels from the dataset are not needed; only the images are modelled.
        inputs = batch[0].to(device)
        batch_size, channels, height, width = inputs.shape

        # Map intensities onto integer class indices 0 .. num_classes - 1.
        max_level = model.num_classes - 1
        targets = (inputs * max_level).long()

        optimizer.zero_grad()
        outputs = model(inputs)

        # Flatten so each (pixel, channel) position is one classification
        # example: logits become (N, num_classes) and targets become (N,).
        loss = criterion(
            rearrange(outputs, 'b c n h w -> (b h w c) n'),
            rearrange(targets, 'b c h w -> (b h w c)'),
        )

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Report the mean loss of the last 100 batches, then reset the window.
        if (i + 1) % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.4f}')
            running_loss = 0.0


In [None]:
import matplotlib.pyplot as plt

In [None]:
# Draw a batch of images from the model and show them in a 4x4 grid.
num_samples = 16
model.eval()
with torch.no_grad():
    samples = model.sample(batch_size=num_samples, image_size=(28, 28), device=device)

# Each sample is reshaped to 2-D (height, width) before plotting.
img_height, img_width = samples.shape[-2], samples.shape[-1]
plt.figure(figsize=(4, 4))
for idx in range(num_samples):
    plt.subplot(4, 4, idx + 1)
    picture = samples[idx].view(img_height, img_width).cpu().numpy()
    plt.imshow(picture, cmap='gray')
    plt.axis('off')

# Persist the trained weights.
torch.save(model.state_dict(), 'pixelcnn_model.pth')