# Assignment 4

In this assignment, you will refactor the entire code to PyTorch, making it more modular and efficient.

## Importing Libraries

In [None]:
import os
from dataclasses import dataclass
from typing import List, Tuple
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import wandb
from utils import load_text, set_seed, configure_device

## Configuration

In [None]:
@dataclass
class MLPConfig:
    """
    Paths and hyperparameters shared by the cells of this notebook.

    `device` and `vocab_size` only hold placeholder defaults here; both are
    overwritten on the `config` instance by later cells (Device and Tokenizer).
    """
    # `root_dir + dataset_path` is the path passed to load_text() in the Dataset cell
    root_dir: str = os.getcwd() + "/../../"
    dataset_path: str = "data/names.txt"
    device: torch.device = torch.device('cpu')  # CPU placeholder; replaced by configure_device() in the Device cell

    # Tokenizer
    vocab_size: int = 0  # Placeholder; set to len(chars) in the Tokenizer cell
    
    # Model
    context_size: int = 3  # Number of previous token ids that form one model input
    d_embed: int = 8  # Passed to the MLP constructor
    d_hidden: int = 64  # Passed to the MLP constructor
    
    # Training
    val_size: float = 0.1  # `test_size` given to train_test_split
    batch_size: int = 32  # Intended for the DataLoaders (Task 1)
    max_steps: int = 6000  # The training note in this notebook cites 6421 as the step count of one pass over the train loader
    lr: float = 0.01  # Learning rate handed to the SGD optimizer
    val_interval: int = 100  # Validate every `val_interval` training steps
    log_interval: int = 100  # Log the train loss every `log_interval` training steps

    seed: int = 101  # Used by set_seed() and as `random_state` of the train/val split

# Single shared configuration instance used by every cell below
config = MLPConfig()

## Reproducibility

In [None]:
# Seed through the project helper from utils (which RNGs it covers is not visible here)
set_seed(config.seed)

## Device

In [None]:
# Replace the CPU placeholder with whatever device the utils helper returns
config.device = configure_device()

## Tokenizer

In [None]:
# Vocabulary: the special "." token at index 0, followed by the letters a-z
chars = [".", *(chr(code) for code in range(97, 123))]
config.vocab_size = len(chars)

# Lookup tables in both directions (index -> character, character -> index)
idx2str = dict(enumerate(chars))
str2idx = {char: idx for idx, char in idx2str.items()}

## Dataset

In [None]:
# Read the dataset file (root_dir + dataset_path) and keep one entry per line
names = load_text(config.root_dir + config.dataset_path).splitlines()

## Preprocessing

In [None]:
# Train-Val Split
# Hold out a `val_size` fraction of the names; the fixed random_state keeps the split repeatable
train_names, val_names = train_test_split(names, test_size=config.val_size, random_state=config.seed)

In [None]:
# Sanity check of the split: number of names and the first name in each part
print(f"Train Size: {len(train_names)}")
print(f"Validation Size: {len(val_names)}")
print(f"Train Example: {train_names[0]}")
print(f"Validation Example: {val_names[0]}")

In [None]:
def prepare_dataset(_names):
    """
    Build every (context, next character) training pair for a list of names.

    Each name is extended with the "." end token. For every character of the
    extended name, the `config.context_size` preceding token ids (padded with 0
    at the start of the name) form the input and the character's id the target.

    Args:
        _names: Iterable of names; every character must be a key of `str2idx`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Inputs of shape (N, context_size) and
        targets of shape (N,), both int64. An empty `_names` gives N = 0.

    Raises:
        KeyError: If a name contains a character that is missing from `str2idx`.
    """
    _inputs, _targets = [], []

    for name in _names:
        context = [0] * config.context_size

        for char in name + ".":
            idx = str2idx[char]
            _inputs.append(context)
            _targets.append(idx)
            context = context[1:] + [idx]  # Shift the context by 1 character

    # Pin dtype and shape explicitly: torch.tensor([]) would otherwise produce a
    # float32 tensor of shape (0,) when there are no samples.
    _inputs = torch.tensor(_inputs, dtype=torch.long).reshape(len(_targets), config.context_size)
    _targets = torch.tensor(_targets, dtype=torch.long)

    return _inputs, _targets

### Task 1: PyTorch DataLoader

We have been collecting the samples in plain Python lists and then converting them to PyTorch tensors. This is not efficient, since it builds the entire dataset in memory up front.

PyTorch provides the `Dataset` and `DataLoader` classes to load the data into memory on the fly. [PyTorch Documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)

Refactor the `prepare_dataset` function into a PyTorch `Dataset` class and use the `DataLoader` to efficiently load the data in batches.

In [None]:
# Dataset
class NamesDataset(Dataset):
    ################################################################################
    # TODO:                                                                        #
    # PyTorch Dataset requires 3 methods:                                          #
    # __init__ method to initialize the dataset                                    #
    # __len__ method to return the size of the dataset                             #
    # __getitem__ method to return a sample from the dataset                       #
    ################################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    def __init__(self, _names: List[str], context_size: int, encoder: dict = None):
        """
        Initialize the dataset

        Every name is encoded once into a single flat token tensor laid out as
        `context_size` padding tokens, the characters of the name, and one end
        token. A sample is a window into that tensor, so no per-sample context
        list is stored.

        Args:
            _names (List[str]): List of names
            context_size (int): Context size of the model
            encoder (dict, optional): Character-to-index mapping; must contain ".",
                which is used as both padding and end token. Defaults to the
                notebook-level `str2idx`.

        Raises:
            ValueError: If `context_size` is negative.
            KeyError: If a name contains a character missing from the encoder.
        """
        if context_size < 0:
            raise ValueError(f"context_size must be non-negative, got {context_size}")
        if encoder is None:
            encoder = str2idx  # Tokenizer mapping defined in the Tokenizer cell

        boundary_id = encoder["."]
        tokens, starts = [], []

        for name in _names:
            offset = len(tokens)
            tokens.extend([boundary_id] * context_size)
            tokens.extend(encoder[char] for char in name)
            tokens.append(boundary_id)
            # One sample per character of the name plus one for the end token
            starts.extend(range(offset, offset + len(name) + 1))

        self.context_size = context_size
        self._tokens = torch.tensor(tokens, dtype=torch.long)
        self._starts = starts  # Position in `_tokens` where each sample's context begins

    def __len__(self) -> int:
        """
        Return the number of samples in the dataset

        Returns:
            (int): Number of samples
        """
        return len(self._starts)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return a sample from the dataset

        Args:
            idx (int): Index of the sample

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Input tensor of shape (context_size,)
            and a scalar target tensor, both int64

        Raises:
            IndexError: If `idx` is out of range.
        """
        start = self._starts[idx]
        end = start + self.context_size
        input_ids = self._tokens[start:end]
        target_id = self._tokens[end]  # The token right after the context window

        return input_ids, target_id
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

In [None]:
# Initialize the dataset
# One dataset per split, both built with the same context size as the model
train_dataset = NamesDataset(train_names, config.context_size)
val_dataset = NamesDataset(val_names, config.context_size)

In [None]:
# Sanity check: sample counts and the first two (input, target) pairs of each split
print(f"Number of Train Samples: {len(train_dataset)}")
print(f"Number of Validation Samples: {len(val_dataset)}")
print(f"First train (input, target): {train_dataset[0]}")
print(f"First validation (input, target): {val_dataset[0]}")
print(f"Second train (input, target): {train_dataset[1]}")
print(f"Second validation (input, target): {val_dataset[1]}")

In [None]:
# DataLoader
################################################################################
# TODO:                                                                        #
# Initialize the DataLoader for the training and validation datasets.          #
################################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
# Training batches are reshuffled on every pass; validation order stays fixed
# so the validation loss is computed over the same batches each time.
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

In [None]:
# Example batch
# Pull a single batch from a fresh iterator over the train loader and inspect it
_x, _y = next(iter(train_loader))
print(f"Input Shape: {_x.shape}")   # (batch_size, context_size)
print(f"Target Shape: {_y.shape}")  # (batch_size)
print(f"Input: {_x[0]}")
print(f"Target: {_y[0]}")

## Model

### Task 2: MLP Model

Initialize the weights of the model using the `Kaiming` initialization.

What are other activation functions that can be used instead of `tanh`? What are the advantages and disadvantages? Use different activation functions and compare the results.


In [None]:
class MLP(nn.Module):
    ################################################################################
    # TODO:                                                                        #
    # Define the __init__ and forward methods for the MLP model.                   #
    # Use the Kaiming initialization for the weights.                              #
    # Use other activation functions instead of tanh and compare the results.      #
    ################################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    # Supported activations; the keys double as the `nonlinearity` names that
    # torch.nn.init.kaiming_normal_ understands.
    _ACTIVATIONS = {"tanh": nn.Tanh, "relu": nn.ReLU, "leaky_relu": nn.LeakyReLU}

    def __init__(self, vocab_size: int, context_size: int, d_embed: int, d_hidden: int,
                 activation: str = "tanh"):
        """
        Character-level MLP: embedding -> hidden layer -> activation -> output layer.

        Args:
            vocab_size (int): Number of tokens; also the size of the output layer.
            context_size (int): Number of token ids in one input.
            d_embed (int): Size of each token embedding.
            d_hidden (int): Size of the hidden layer.
            activation (str): One of "tanh", "relu" or "leaky_relu".

        Raises:
            ValueError: If `activation` is not a supported name.
        """
        super().__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unknown activation '{activation}'; expected one of {sorted(self._ACTIVATIONS)}"
            )

        self.embedding = nn.Embedding(vocab_size, d_embed)
        self.hidden = nn.Linear(context_size * d_embed, d_hidden)
        self.activation = self._ACTIVATIONS[activation]()
        self.output = nn.Linear(d_hidden, vocab_size)

        # Kaiming initialization: the gain follows the activation after the hidden
        # layer; `a` (negative slope) only matters for leaky_relu.
        slope = getattr(self.activation, "negative_slope", 0)
        nn.init.kaiming_normal_(self.hidden.weight, a=slope, nonlinearity=activation)
        # Nothing follows the output layer inside the model, so use the linear gain.
        nn.init.kaiming_normal_(self.output.weight, nonlinearity="linear")
        nn.init.zeros_(self.hidden.bias)
        nn.init.zeros_(self.output.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute next-token logits.

        Args:
            x (torch.Tensor): Token ids of shape (batch_size, context_size).

        Returns:
            torch.Tensor: Logits of shape (batch_size, vocab_size).
        """
        embedded = self.embedding(x)          # (batch_size, context_size, d_embed)
        flat = embedded.flatten(start_dim=1)  # (batch_size, context_size * d_embed)
        return self.output(self.activation(self.hidden(flat)))
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

In [None]:
# Initialize the model
# Arguments are positional: vocab_size, context_size, d_embed, d_hidden
mlp = MLP(config.vocab_size, config.context_size, config.d_embed, config.d_hidden)
mlp.to(config.device) # Move the model to the device
print(mlp)
# Total number of elements across all parameter tensors
print("Number of parameters:", sum(p.numel() for p in mlp.parameters()))

## Training

### Task 3: Wandb Integration

[Weights and Biases](https://wandb.ai/site) is a platform to track your machine learning experiments. It is very useful to log the hyperparameters, metrics, and weights of the model. (We can't use matplotlib every time to visualize the results)

Create a free account on Wandb. Initialize the wandb run and log the hyperparameters and metrics.

**How to set up WANDB API KEY**
- Create an account on Wandb
- Go to `wandb.ai` -> `Settings` -> `API Keys` -> `Copy API Key`
- Set the API key as an environment variable `WANDB_API_KEY`
    - What is an environment variable? How to set it? Google `.env`

Note: Do not hardcode the API key in the script. Use environment variables.



In [None]:
# The API key is read from the environment so it never appears in the notebook;
# os.environ.get returns None when WANDB_API_KEY is not set.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
# Start the run; only the three hyperparameters listed here are recorded with it
wandb.init(
    project="Assignment-04",
    config={
        "d_embed": config.d_embed,
        "d_hidden": config.d_hidden,
        "lr": config.lr,
    },
    dir=config.root_dir
)

### Task 4: Training

Train the model. Change the hyperparameters and configurations. Log the results and analyze them.

In [None]:
def _cycle_batches(loader):
    """Yield batches from `loader` indefinitely, restarting it after every full pass."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # An empty loader would otherwise make this loop spin forever.
            return


def _mean_loss(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the sample-weighted mean cross-entropy of `model` over `loader` (NaN if empty)."""
    model.eval()
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)
            # Weight by batch size so a smaller final batch does not skew the mean
            total_loss += F.cross_entropy(model(inputs), targets).item() * batch_size
            total_samples += batch_size
    return total_loss / total_samples if total_samples else float("nan")


def train(
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        max_steps: int,
        lr: float,
        val_interval: int,
        log_interval: int,
        device: torch.device,
):
    """
    Train the model for a fixed number of steps.

    The train loader is restarted whenever it is exhausted, so training runs for
    `max_steps` steps even when that is more than one pass over the data. It
    stops early only if the loader yields no batches at all.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader): DataLoader for the training data.
        val_loader (DataLoader): DataLoader for the validation data.
        max_steps (int): Maximum number of steps to train.
        lr (float): Learning rate.
        val_interval (int): Interval for validation.
        log_interval (int): Interval for logging.
        device (torch.device): Device to run the model on.
    """
    wandb.watch(model, log="all", log_freq=log_interval)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    running_loss = 0.0

    # range() comes first in zip so no extra batch is fetched after the last step
    steps = zip(range(1, max_steps + 1), _cycle_batches(train_loader))

    # The context manager closes the progress bar even if a step raises
    with tqdm(steps, total=max_steps, desc="Training") as progress_bar:
        for step, (train_inputs, train_targets) in progress_bar:
            model.train()  # Validation below switches the model to eval mode
            train_inputs, train_targets = train_inputs.to(device), train_targets.to(device)
            optimizer.zero_grad()
            logits = model(train_inputs)
            loss = F.cross_entropy(logits, train_targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # The reported train loss is the mean over all steps so far
            progress_bar.set_postfix(loss=f"{running_loss / step:.4f}")

            if step % val_interval == 0:
                wandb.log({"Val Loss": _mean_loss(model, val_loader, device)}, step=step)

            if step % log_interval == 0:
                wandb.log({"Train Loss": running_loss / step}, step=step)

    wandb.finish()

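Note: Unfortunately, a PyTorch `DataLoader` is not infinite: one pass over it ends after `len(train_loader)` batches (6421 here). A training loop that iterates over the loader only once therefore stops when it reaches the end of the `DataLoader`, even if `max_steps` is larger.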

In [None]:
# Run training with the hyperparameters from the shared config.
# Note: train() ends the wandb run (wandb.finish) before it returns.
train(
    model=mlp,
    train_loader=train_loader,
    val_loader=val_loader,
    max_steps=config.max_steps,
    lr=config.lr,
    val_interval=config.val_interval,
    log_interval=config.log_interval,
    device=config.device
)

In [None]:
################################################################################
# TODO:                                                                        #
# Analyze the results                                                          #
# What hyperparameters worked well? What activation did you use? etc.          #
################################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
#
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

## Inference

In [None]:
def generate_name(model: nn.Module, context_size: int, decoder: dict, end_id: int, device: torch.device,
                  max_length: int = None) -> str:
    """
    Generate a name using the model.

    Starting from a context filled with `end_id`, the next token is sampled from
    the model's softmax distribution until the end token is drawn. Sampling runs
    without gradient tracking and with the model in eval mode; the model's
    previous train/eval mode is restored afterwards.

    Args:
        model (nn.Module): Model to generate the name.
        context_size (int): Context size of the model.
        decoder (dict): Decoder dictionary to convert indices to characters.
        end_id (int): End token id.
        device (torch.device): Device to run the model on
        max_length (int, optional): Stop after this many sampled tokens even if
            the end token has not been drawn. None (default) means no limit.

    Returns:
        (str): Generated name, including the decoded end token when it was sampled
    """
    new_name = []
    context = [end_id] * context_size

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # Inference only: do not build an autograd graph
            while max_length is None or len(new_name) < max_length:
                x = torch.tensor([context], device=device)  # (1, context_size)
                probs = F.softmax(model(x), dim=-1)
                idx = torch.multinomial(probs, num_samples=1).item()
                new_name.append(decoder[idx])
                if idx == end_id:
                    break
                context = context[1:] + [idx]  # Shift the context by 1 token
    finally:
        model.train(was_training)

    return "".join(new_name)

In [None]:
# Sample and print five names; "." serves as both the initial context filler and the stop token
for _ in range(5):
    print(generate_name(
        model=mlp,
        context_size=config.context_size,
        decoder=idx2str,
        end_id=str2idx["."],
        device=config.device
    ))