In [1]:
# Automatically re-import edited local modules (e.g. `utils`, imported below) before each cell runs.
%load_ext autoreload
%autoreload 2

# Exercise 6

<img src="./images/06.png" width=800>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms

from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.autonotebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from utils import train_network, View, set_seed
import mlflow
from torchinfo import summary
import os

  from tqdm.autonotebook import tqdm


In [3]:
# Keep MLflow runs for this exercise in a local directory next to the notebook.
TRACKING_DIR = './mlruns07_6'
os.environ['MLFLOW_TRACKING_URI'] = TRACKING_DIR
mlflow.set_tracking_uri(TRACKING_DIR)

In [4]:
mlflow.set_experiment(experiment_name='Exercise07_6')

<Experiment: artifact_location='/home/spakdel/my_projects/Books/Inside-Deep-Learning/Exercises_InsideDeepLearning/Chapter_07/mlruns07_6/460950000382384422', creation_time=1750948549809, experiment_id='460950000382384422', last_update_time=1750948549809, lifecycle_stage='active', name='Exercise07_6', tags={}>

In [5]:
# Reproducibility: force deterministic cuDNN kernels, then seed the RNGs via the project helper.
SEED = 42
torch.backends.cudnn.deterministic = True
set_seed(SEED)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Dataset and DataLoader

In [7]:
class AutoencodDataset(Dataset):
    """Wrap a dataset of ``(tensor, other)`` pairs as a per-pixel sequence dataset.

    The second element of each wrapped item (e.g. a class label) is discarded.
    The tensor is flattened and given a trailing feature axis, giving shape
    ``(N, 1)``: one time step per element, one feature per step.

    Args:
        dataset: an indexable, sized dataset whose items unpack into two values.
        shift: if False (default), return ``(x, x)``. The target at step ``t``
            is then the very value fed in at step ``t``, so a recurrent model can
            minimise the loss by copying its input.
            If True, return ``(x_in, x)``, where ``x_in`` is ``x`` delayed by one
            step with a leading zero (``x_in[0] = 0``, ``x_in[t] = x[t-1]``).
            This is teacher-forced next-pixel prediction. It matches how
            ``generate_image`` feeds the previous pixel (zero for the first) to
            predict the current one.
    """

    def __init__(self, dataset, shift=False):
        super().__init__()
        self.dataset = dataset
        self.shift = shift

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        x, _ = self.dataset[index]
        x = x.flatten().unsqueeze(-1)  # (N,) -> (N, 1): one feature per time step
        if not self.shift:
            return x, x
        # Delay the sequence by one step so the model sees pixel t-1 when predicting pixel t.
        x_in = torch.cat([torch.zeros_like(x[:1]), x[:-1]], dim=0)
        return x_in, x

In [8]:
batch_size = 128
to_tensor = transforms.ToTensor()

# NOTE(review): AutoencodDataset yields (x, x), i.e. the target at step t is the same
# pixel fed in at step t. The GRU can therefore learn to copy its input rather than
# predict the next pixel, which is what the generation code at the end of this
# notebook assumes. Confirm which objective is intended.
train_data = AutoencodDataset(torchvision.datasets.MNIST("./data", train=True, transform=to_tensor, download=True))
test_data_xy = torchvision.datasets.MNIST("./data", train=False, transform=to_tensor, download=True)
test_data_xx = AutoencodDataset(test_data_xy)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data_xx, batch_size=batch_size)


## Model

In [9]:
# MNIST geometry: 28x28 single-channel images, 10 digit classes.
W = H = 28
D = W * H  # pixels per image == sequence length fed to the RNN
C = 1
classes = 10

In [None]:
class AutoRegressive(nn.Module):
    """Stack of GRU cells that maps a sequence of feature vectors to one prediction per step.

    Args:
        input_size: number of features per time step. This is also the size of
            each per-step prediction.
        hidden_size: size of every GRU cell's hidden state.
        layers: number of stacked ``GRUCell`` layers, each followed by a ``LayerNorm``.
    """

    def __init__(self, input_size, hidden_size, layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        # The first cell consumes the raw input; later cells consume the previous layer's hidden state.
        self.layers = nn.ModuleList([nn.GRUCell(input_size, hidden_size)] +
                                    [nn.GRUCell(hidden_size, hidden_size) for _ in range(layers - 1)])
        self.norms = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(layers)])

        # Prediction head: (B, hidden_size) -> (B, input_size)
        self.pred_class = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, input_size),
        )

    def initHiddenStates(self, B):
        """
        Create zero initial hidden states: one (B, hidden_size) tensor per GRU layer.

        The tensors are allocated on the same device, and with the same dtype, as the
        model's parameters. This works regardless of where the model has been moved
        and does not depend on any global variable.

        B: the batch size for the hidden states.
        """
        ref = next(self.parameters())
        return [torch.zeros(B, self.hidden_size, device=ref.device, dtype=ref.dtype)
                for _ in range(len(self.layers))]

    def step(self, x_in, h_prevs=None):
        """
        Advance every GRU layer by one time step and predict the output for that step.

        x_in: the input for this time step, shape (B, input_size).

        h_prevs: a list with one (B, hidden_size) tensor per layer, or None to start from
            zeros. A passed-in list is updated IN PLACE with the new hidden states, so the
            caller can reuse it for the next step. When None is passed, the fresh states
            are not returned and are discarded.

        Returns the prediction for this step, shape (B, input_size).
        """
        if h_prevs is None:
            h_prevs = self.initHiddenStates(x_in.shape[0])

        for idx, (cell, norm) in enumerate(zip(self.layers, self.norms)):
            h = norm(cell(x_in, h_prevs[idx]))
            h_prevs[idx] = h
            x_in = h  # this layer's output is the next layer's input
        return self.pred_class(x_in)

    def forward(self, input):
        """
        Run the model over a whole sequence, starting from zero hidden states.

        input: tensor of shape (B, T, input_size).

        Returns a tensor of shape (B, T, input_size) holding the prediction made at each step.
        """
        B, T = input.size(0), input.size(1)
        h_prevs = self.initHiddenStates(B)
        outputs = [self.step(input[:, t, :], h_prevs) for t in range(T)]
        return torch.stack(outputs, dim=1)

In [11]:
model = AutoRegressive(input_size=1, hidden_size=128, layers=2)  # one grey-scale value per time step

## Training

In [None]:
loss_func = nn.MSELoss()
epochs = 5
# Hyper-parameters recorded in MLflow for this run.
params = dict(
    device=device,
    loss_func=type(loss_func).__name__,
    epochs=epochs,
    batch_size=batch_size,
)

In [None]:
optimizer = optim.AdamW(model.parameters())

# torchinfo's keyword is `input_size`. The model consumes (batch, time steps, features) = (B, D, 1).
# Passing `device` keeps the summary's forward pass on the same device the notebook trains on.
with open('model_summary.txt', 'w') as f:
    f.write(str(summary(model, input_size=(batch_size, D, 1), device=device)))

with mlflow.start_run(nested=True, run_name='exercise_6'):
    params['optimizer'] = optimizer.defaults
    mlflow.log_artifact('model_summary.txt')
    mlflow.log_params(params)

    results = train_network(
        model=model,
        optimizer=optimizer,
        loss_func=loss_func,
        train_loader=train_loader,
        valid_loader=test_loader,
        epochs=epochs,
        device=device,
    )

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

#

## Results

In [None]:
def generate_image(model, D=None, C=1, H=28, W=28, device="cpu"):
    """Generate one image pixel by pixel with an autoregressive model.

    ``model.step`` is called D times. At step t the input is the value produced at
    step t-1 (zero at t = 0), and the output is stored as pixel t. ``step`` returns
    only the prediction and updates the hidden-state list in place, so one list is
    threaded through every call.

    Args:
        model: an ``nn.Module`` exposing ``initHiddenStates(B)`` and
            ``step(x_in, h_prevs)``. For a model with one input feature, ``step``
            maps a (1, 1) input to a (1, 1) prediction.
        D: number of pixels to generate. Defaults to ``C * H * W``.
        C, H, W: shape of the returned image.
        device: device for the seed pixel and the output buffer. It must match the
            device the model runs on.

    Returns:
        A CPU tensor of shape (C, H, W).

    Raises:
        ValueError: if D != C * H * W. This is checked before any generation step runs.
    """
    if D is None:
        D = C * H * W
    if D != C * H * W:
        raise ValueError(f"D={D} pixels cannot be reshaped to (C, H, W)=({C}, {H}, {W})")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated_pixels = torch.zeros(1, D, 1, device=device)
            h_prevs = model.initHiddenStates(1)  # hidden states for a batch of one image
            # (B=1, input_size=1); a zero seeds the very first step.
            prev_pixel = torch.zeros(1, 1, device=device)
            for t in range(D):
                # step() returns just the prediction; h_prevs is updated in place.
                prev_pixel = model.step(prev_pixel, h_prevs)
                generated_pixels[0, t, :] = prev_pixel.squeeze(0)
    finally:
        # Leave the model in whatever train/eval mode the caller had it in.
        model.train(was_training)

    return generated_pixels.view(C, H, W).cpu()

generated_digit = generate_image(model, D=D, C=C, H=H, W=W, device=device)

# matplotlib.pyplot is already imported as `plt` in the import cell.
fig, ax = plt.subplots()
ax.imshow(generated_digit.squeeze().numpy(), cmap='gray')
ax.set_title("Generated Digit")
ax.axis("off")
plt.show()