In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
# einops.rearrange is used in the training loop to flatten logits and targets for the loss
from einops import rearrange

In [None]:
# PixelRNN model using standard LSTM cells for classification
class PixelRNN(nn.Module):
    """Autoregressive pixel model built from a stack of ``nn.LSTMCell`` layers.

    An image is visited one pixel at a time in raster-scan order. At step ``i``
    the network receives pixel ``i-1`` (an all-zero vector at step 0) and emits
    ``num_classes`` logits for every channel of pixel ``i``.

    Args:
        input_channels: Number of channels per pixel fed to the first LSTM cell.
        hidden_size: Width of the hidden/cell state of every LSTM cell.
        num_layers: Number of stacked LSTM cells.
        num_classes: Number of discrete intensity classes predicted per channel.
    """

    def __init__(self, input_channels=3, hidden_size=128, num_layers=2, num_classes=256):
        super(PixelRNN, self).__init__()
        self.input_channels = input_channels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes

        # Use standard PyTorch LSTMCell; only the first layer sees raw pixels.
        self.cells = nn.ModuleList([
            nn.LSTMCell(input_channels if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])

        # Output layer: one group of num_classes logits per channel of the
        # predicted pixel (identical to a plain num_classes head for 1 channel).
        self.output_layer = nn.Linear(hidden_size, num_classes * input_channels)

    def _init_states(self, batch_size, device):
        """Return fresh zero (hidden_states, cell_states) lists, one entry per layer."""
        dtype = self.output_layer.weight.dtype
        hidden_states = [
            torch.zeros(batch_size, self.hidden_size, device=device, dtype=dtype)
            for _ in range(self.num_layers)
        ]
        cell_states = [
            torch.zeros(batch_size, self.hidden_size, device=device, dtype=dtype)
            for _ in range(self.num_layers)
        ]
        return hidden_states, cell_states

    def _step(self, pixel_input, hidden_states, cell_states):
        """Advance every layer by one pixel, updating the state lists in place.

        Returns the hidden state of the top layer after the update.
        """
        layer_input = pixel_input
        for l, cell in enumerate(self.cells):
            hidden_states[l], cell_states[l] = cell(
                layer_input,
                (hidden_states[l], cell_states[l])
            )
            layer_input = hidden_states[l]
        return layer_input

    def forward(self, x):
        """Compute per-pixel class logits with teacher forcing.

        Args:
            x: Tensor of shape ``(batch, input_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, input_channels, height, width, num_classes)``
            where the logits at raster position ``i`` depend only on pixels
            ``< i`` of ``x``.

        Raises:
            ValueError: If the channel dimension of ``x`` differs from
                ``input_channels``.
        """
        batch_size, nc, height, width = x.size()
        if nc != self.input_channels:
            raise ValueError(
                f"expected input with {self.input_channels} channel(s), got {nc}"
            )
        num_pixels = height * width

        if num_pixels == 0:
            # Nothing to predict for an empty image.
            return x.new_zeros(batch_size, nc, height, width, self.num_classes)

        # Flatten image to a raster-scan sequence of pixels: (B, C, H, W) -> (B, H*W, C)
        pixels = x.reshape(batch_size, nc, num_pixels).transpose(1, 2).contiguous()

        hidden_states, cell_states = self._init_states(batch_size, x.device)

        # Process pixels sequentially in raster scan order (the outer loop over
        # pixels enforces the autoregressive property). This differs from a
        # stacked RNN implementation where the outer loop runs over layers.
        pixel_input = x.new_zeros(batch_size, nc)  # zeros stand in for "pixel -1"
        step_logits = []
        for i in range(num_pixels):
            top_hidden = self._step(pixel_input, hidden_states, cell_states)
            step_logits.append(self.output_layer(top_hidden))
            # The pixel just predicted becomes the input of the next step.
            pixel_input = pixels[:, i]

        # (B, H*W, C*K) -> (B, H, W, C, K) -> (B, C, H, W, K)
        logits = torch.stack(step_logits, dim=1)
        logits = logits.view(batch_size, height, width, nc, self.num_classes)
        return logits.permute(0, 3, 1, 2, 4).contiguous()

    @torch.no_grad()
    def sample(self, batch_size=1, image_size=(28, 28), channels=1, device='cpu'):
        """Generate images pixel by pixel from the learned distribution.

        Runs without gradient tracking. Each sampled class index is divided by
        ``num_classes - 1`` before being written to the image and fed back as
        the next input.

        Args:
            batch_size: Number of images to generate.
            image_size: ``(height, width)`` of the generated images.
            channels: Channels per pixel; must equal ``input_channels``.
            device: Device on which the states and the images are created.

        Returns:
            Tensor of shape ``(batch_size, channels, height, width)``.

        Raises:
            ValueError: If ``channels`` differs from ``input_channels``.
        """
        if channels != self.input_channels:
            raise ValueError(
                f"channels={channels} does not match the model's "
                f"input_channels={self.input_channels}"
            )
        height, width = image_size
        dtype = self.output_layer.weight.dtype
        generated_images = torch.zeros(
            batch_size, channels, height, width, device=device, dtype=dtype
        )

        hidden_states, cell_states = self._init_states(batch_size, device)

        # Generate pixels sequentially in raster scan order
        pixel_input = torch.zeros(batch_size, channels, device=device, dtype=dtype)
        for i in range(height * width):
            top_hidden = self._step(pixel_input, hidden_states, cell_states)

            # One categorical distribution per (image, channel) pair
            logits = self.output_layer(top_hidden).view(batch_size * channels, self.num_classes)
            probs = F.softmax(logits, dim=1)
            classes = torch.multinomial(probs, 1).view(batch_size, channels)
            pixel_input = classes.to(dtype) / (self.num_classes - 1)

            # Place sampled pixel in the generated image
            h_idx, w_idx = divmod(i, width)
            generated_images[:, :, h_idx, w_idx] = pixel_input

        return generated_images


In [None]:
# Run on CUDA when PyTorch reports it as available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data loading: ToTensor is the only preprocessing step applied to each image
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, downloaded into ./data if it is not already there
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 images, loaded by 4 worker processes
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Build a single-layer, single-channel PixelRNN with 256 intensity classes,
# then move it onto the selected device
model = PixelRNN(
    input_channels=1,
    hidden_size=32,
    num_layers=1,
    num_classes=256,
)
model.to(device)

In [None]:
# Adam over every model parameter, plus the cross-entropy loss used in training
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [None]:
# Number of full passes over train_loader performed by the training loop
epochs = 50

In [None]:
# Train the model to predict the quantized intensity class of every pixel.
model.train()
log_every = 10  # number of batches averaged into each printed loss value
for epoch in range(epochs):
    running_loss = 0.0
    for i, (inputs, _) in enumerate(train_loader):
        inputs = inputs.to(device)

        # Quantize inputs to match output classes. Round before casting:
        # .long() truncates, and the float32 scale-up is not guaranteed to land
        # exactly on the integer level, which would shift a label down by one.
        targets = (inputs * (model.num_classes - 1)).round().long()

        optimizer.zero_grad()

        outputs = model(inputs)

        # Rearrange outputs to (batch * height * width * channels, num_classes)
        outputs_flat = rearrange(outputs, 'b c h w n -> (b h w c) n')

        # Rearrange targets to (batch * height * width * channels)
        targets_flat = rearrange(targets, 'b c h w -> (b h w c)')

        loss = criterion(outputs_flat, targets_flat)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Report the mean loss of the last `log_every` batches; the divisor
        # must match the number of batches accumulated since the last reset.
        if i % log_every == log_every - 1:
            print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/log_every:.4f}')
            running_loss = 0.0


In [None]:
import matplotlib.pyplot as plt

In [None]:
# Draw 16 images from the model and show them on a 4x4 grid
with torch.no_grad():
    model.eval()
    samples = model.sample(batch_size=16, image_size=(28, 28), device=device)
    img_h, img_w = samples.shape[-2], samples.shape[-1]
    # Render each sample as a grayscale tile with the axes hidden
    plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        image = samples[i].view(img_h, img_w).cpu().numpy()
        plt.imshow(image, cmap='gray')
        plt.axis('off')
# Persist the trained weights
torch.save(model.state_dict(), 'pixelrnn_model.pth')