# Setup Environment and Import Libraries
Import TensorFlow, NumPy, and visualization libraries. Set random seeds for reproducibility.

In [None]:
# Setup Environment and Import Libraries

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Fix the seed of both random number generators so repeated runs of the
# notebook draw the same values.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Load and Preprocess MNIST Dataset
Load MNIST dataset using tf.keras.datasets, normalize pixel values to [0,1], and reshape data for PixelCNN input.

In [None]:
# Load and Preprocess MNIST Dataset

# Fetch the train/test splits; the labels are loaded alongside the images.
(mnist_train, mnist_train_labels), (mnist_test, mnist_test_labels) = (
    tf.keras.datasets.mnist.load_data()
)

# Scale the pixel values into [0, 1] as float32, then append a trailing
# channel axis so each image has an explicit channel dimension.
mnist_train = (mnist_train.astype('float32') / 255.0)[..., np.newaxis]
mnist_test = (mnist_test.astype('float32') / 255.0)[..., np.newaxis]

# Report the resulting array shapes as a sanity check.
print(f'Training data shape: {mnist_train.shape}')
print(f'Test data shape: {mnist_test.shape}')

# Define PixelCNN Model Architecture
Implement masked convolutions and build the PixelCNN model using tf.keras.Model with appropriate layers for pixel generation.

In [None]:
# Define PixelCNN Model Architecture

class MaskedConv2D(tf.keras.layers.Layer):
    """2-D 'same'-padded convolution with a fixed binary mask on its kernel.

    The mask zeroes every kernel row below the center row and, on the center
    row, every column to the right of the center. With ``mask_type='A'`` the
    center tap itself is zeroed as well; any other ``mask_type`` value keeps
    the center tap (the 'B' variant).

    Args:
        filters: Number of output channels of the convolution.
        kernel_size: Integer side length of the square kernel.
        mask_type: ``'A'`` to also block the center tap, otherwise the
            center tap is kept.
        activation: Optional activation (name or callable) applied to the
            convolution output.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size, mask_type, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.mask_type = mask_type
        self.activation = tf.keras.activations.get(activation)
        # The wrapped Conv2D only owns the kernel and bias variables; the
        # masked convolution itself is computed in ``call``.
        self.conv = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')

    def build(self, input_shape):
        """Create the convolution variables and the constant kernel mask."""
        # Build the conv layer first so its kernel and bias exist.
        self.conv.build(input_shape)

        kernel_shape = (self.kernel_size, self.kernel_size, input_shape[-1], self.filters)
        mask = np.ones(kernel_shape, dtype=np.float32)
        center = self.kernel_size // 2

        # Type 'A' blocks the center tap too; every other type keeps it.
        first_blocked_col = center if self.mask_type == 'A' else center + 1
        mask[center, first_blocked_col:, :, :] = 0
        mask[center + 1:, :, :, :] = 0

        self.mask = mask
        self.built = True

    def call(self, inputs):
        """Convolve ``inputs`` with the masked kernel, add bias, then activate."""
        # Mask the kernel functionally instead of assigning it back into the
        # variable: the forward pass then has no side effect on the weights,
        # and the gradient of every masked tap is exactly zero.
        masked_kernel = self.conv.kernel * self.mask
        x = tf.nn.conv2d(inputs, masked_kernel, strides=1, padding='SAME')
        x = tf.nn.bias_add(x, self.conv.bias)
        if self.activation is not None:
            x = self.activation(x)
        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'mask_type': self.mask_type,
            'activation': tf.keras.activations.serialize(self.activation),
        })
        return config

def build_pixelcnn(input_shape, filters=64, kernel_size=7, num_hidden_layers=7,
                   num_classes=256):
    """Build a PixelCNN as a functional ``tf.keras.Model``.

    The network is one type-'A' masked convolution, followed by
    ``num_hidden_layers`` type-'B' masked convolutions (all ReLU), followed
    by a 1x1 convolution with a softmax over ``num_classes`` values at every
    pixel position. The default arguments reproduce the original fixed
    architecture (64 filters, 7x7 kernels, 7 hidden layers, 256 classes).

    Args:
        input_shape: Shape of one input image, without the batch axis.
        filters: Number of channels in every masked convolution.
        kernel_size: Integer kernel side length of every masked convolution.
        num_hidden_layers: How many type-'B' layers follow the type-'A' layer.
        num_classes: Size of the per-pixel softmax output.

    Returns:
        The uncompiled Keras model; its output holds per-pixel probabilities.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    x = MaskedConv2D(filters, kernel_size, mask_type='A', activation='relu')(inputs)
    for _ in range(num_hidden_layers):
        x = MaskedConv2D(filters, kernel_size, mask_type='B', activation='relu')(x)
    # The softmax head is kept: the output is a probability distribution over
    # pixel intensities, not raw logits.
    outputs = tf.keras.layers.Conv2D(num_classes, 1, activation='softmax')(x)
    return tf.keras.Model(inputs, outputs)

# Shape of a single input image passed to the model builder.
input_shape = (28, 28, 1)

# Instantiate the network for that image shape.
pixelcnn_model = build_pixelcnn(input_shape=input_shape)

# Show the layer-by-layer summary of the model.
pixelcnn_model.summary()

# Create Training Pipeline
Configure the loss function and optimizer, then train the model with `model.fit` while saving checkpoints through a `ModelCheckpoint` callback.

In [None]:
# Create Training Pipeline

# Define loss function and optimizer.
# The model's last layer already applies a softmax, so the loss has to be
# told it receives probabilities (from_logits=False); treating them as logits
# would apply a second softmax inside the loss.
pixelcnn_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
)

# Sparse categorical cross-entropy needs integer class ids as targets, not
# the normalized float images. Recover the 0..255 intensity of every pixel
# from the [0, 1] inputs and drop the channel axis.
train_targets = np.rint(mnist_train[..., 0] * 255.0).astype('int32')

# Setup checkpointing: keep only the weights, and only when the training
# loss improves.
# NOTE(review): some Keras releases require weight-only checkpoint paths to
# end in `.weights.h5` - confirm against the installed version.
checkpoint_dir = './training_checkpoints'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_dir + '/ckpt-{epoch}',
    save_weights_only=True,
    save_best_only=True,
    monitor='loss'
)

# Training parameters
epochs = 10
batch_size = 64

# Train the model: inputs are the normalized images, targets are the integer
# pixel intensities of the same images.
history = pixelcnn_model.fit(
    mnist_train, train_targets,
    epochs=epochs,
    batch_size=batch_size,
    callbacks=[checkpoint_callback],
    verbose=1
)

print('Training complete.')


# Generate New Images
Implement autoregressive sampling to generate new MNIST digits pixel by pixel using the trained model.

In [None]:
# Generate New Images

# Function to generate new images
def generate_images(model, num_images, img_shape):
    """Sample images from ``model`` one pixel at a time, in raster order.

    Starting from an all-zero batch, each position (row by row, left to
    right) is filled by running the model on the current batch, drawing one
    value per image from the model's output at that position, and writing
    the drawn value, divided by 255, into channel 0.

    Args:
        model: Callable taking ``(batch, training=False)``. Its output at each
            position is passed through a log before sampling, so it is
            expected to hold probabilities rather than logits.
        num_images: Number of images to sample in parallel.
        img_shape: Shape of one image; its first two entries are the row and
            column counts that are iterated.

    Returns:
        A float32 NumPy array of shape ``(num_images, *img_shape)``.
    """
    canvas = np.zeros((num_images, *img_shape), dtype=np.float32)

    # np.ndindex walks the positions row-major, i.e. in the same order as a
    # nested row/column loop.
    for row, col in np.ndindex(img_shape[0], img_shape[1]):
        # Full forward pass, then keep only the output for this position.
        pixel_probs = model(canvas, training=False)[:, row, col, :]
        # One categorical draw per image from the log of that output.
        drawn = tf.random.categorical(tf.math.log(pixel_probs), num_samples=1)
        # Rescale the drawn index by 255 before writing it back.
        scaled = tf.cast(drawn, tf.float32) / 255.0
        canvas[:, row, col, 0] = scaled[:, 0]

    return canvas

# Sample a batch of images from the model.
num_images = 10
generated_images = generate_images(
    model=pixelcnn_model, num_images=num_images, img_shape=input_shape
)

# Show the samples side by side in a single row.
plt.figure(figsize=(10, 10))
for position in range(1, num_images + 1):
    # Subplot positions are 1-based, image indices are 0-based.
    plt.subplot(1, num_images, position)
    plt.imshow(generated_images[position - 1, :, :, 0], cmap='gray')
    plt.axis('off')
plt.show()

# Visualize Results
Display generated samples and compare with real MNIST digits. Create grid visualization of multiple generated images.

In [None]:
# Visualize Results

# Function to plot a grid of images
def plot_image_grid(images, grid_size=(5, 5), figsize=(10, 10)):
    """Show channel 0 of ``images`` on a grid of grayscale subplots.

    Args:
        images: Array indexed as ``images[i, :, :, 0]``; ``len(images)`` gives
            the number of images available.
        grid_size: ``(rows, columns)`` of the subplot grid. Grid cells beyond
            the last image are left blank.
        figsize: Figure size passed to ``plt.subplots``.
    """
    n_rows, n_cols = grid_size
    # squeeze=False makes `axes` a 2-D array for every grid size; without it
    # a 1x1 grid yields a single Axes object that has no `.flat`.
    _, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    for idx, ax in enumerate(axes.flat):
        if idx < len(images):
            ax.imshow(images[idx, :, :, 0], cmap='gray')
        # Hide the axis frame on every cell, including the blank ones.
        ax.axis('off')
    plt.show()

# Both comparisons use the same two-row, five-column layout.
grid_layout = (2, 5)

# Samples drawn from the model.
print("Generated Images:")
plot_image_grid(generated_images, grid_size=grid_layout)

# The first ten test-set images, shown for comparison.
print("Real MNIST Images:")
plot_image_grid(mnist_test[:10], grid_size=grid_layout)