# Train transformer

In [1]:
import math

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR

from transformer_implementation import Tokenizer, TransformerConfig, DataLoaderFactory, LayerNorm, MultiHeadAttention, FeedForward
from utils import plot_losses

## Init
### Tokenizer

In [2]:
# init tokenizer
tokenizer = Tokenizer()

## Model

### DecoderBlock

In [3]:
class DecoderBlock(nn.Module):
    """
    A class that implements a single decoder block in the Transformer model.

    Each block consists of three sub-layers: a multi-head self-attention mechanism,
    a multi-head attention mechanism over the encoder's output, and a position-wise 
    fully connected feed-forward network. There is a residual connection around 
    each of the three sub-layers, followed by layer normalization.

    Attributes:
        - ln_1 (LayerNorm): Layer normalization before the first multi-head attention layer.
        - attn1 (MultiHeadAttention): First multi-head attention layer, with self-attention.
        - ln_2 (LayerNorm): Layer normalization before the second multi-head attention layer.
        - attn2 (MultiHeadAttention): Second multi-head attention layer, attends to encoder outputs.
        - ln_3 (LayerNorm): Layer normalization before the feed-forward network.
        - ffw (FeedForward): Position-wise feed-forward network.

    Args:
        config (Config): A configuration object with attribute `n_embd` and `bias`.
    """
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn1 = MultiHeadAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffw = FeedForward(config)

    def forward(self, x, tgt_mask=None) -> torch.Tensor:
        """
        Defines the computation performed at every call.

        Args:
            - x (torch.Tensor): The input tensor to the forward pass.
            - tgt_mask (torch.Tensor, optional): The mask tensor to ignore padding, size (B, 1, 1, T).

        Returns:
            - torch.Tensor: The output tensor of the block.
            - encoder_attn: The encoder attention weight of the current block.
        """
        # Masked MultiHeadAttention
        x = self.ln_1(x)
        x_attn, encoder_attn = checkpoint(self.attn1, x, x, x, True, tgt_mask)
        x = self.ln_2(x + x_attn)
        # FeedForward
        x = x + checkpoint(self.ffw, x)
        return x, encoder_attn

### Decoder

In [4]:
class Decoder(nn.Module):
    """
    This class implements the decoder part of the Transformer model.

    The Decoder consists of several DecoderBlocks arranged in sequence. The input first goes through an embedding 
    layer followed by a positional encoding layer. The output of this is then passed through each DecoderBlock in 
    sequence.

    Attributes:
        - decoder (nn.ModuleDict): A dictionary of modules making up the transformer decoder.
        - lm_head (nn.Linear): The final linear layer mapping from the embedding dimension to the vocabulary size.
        - config (:obj:`Config`): The configuration object for the transformer model.

    .. note:: The weight of the embedding layer and the linear layer are shared.

    Args:
        - config (:obj:`Config`): The configuration object with attributes such as `vocab_size`, `block_size`, `n_embd`, `dropout`, `n_layer`, and `bias`.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.decoder = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            # Learned positional encoding:
            # In this case, instead of using a fixed function to determine positional encoding,
            # we initialize a tensor of positional encodings which gets updated during training via backpropagation.
            # This method may potentially capture more complex position-related patterns than fixed positional encoding,
            # but it also introduces additional parameters to the model.
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.decoder.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of Decoder parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding: bool = True) -> int:
        """
        Returns the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.

        Args:
            - non_embedding (bool): If True, excludes the position embeddings count from the total. Default is True.

        Returns:
            - int: The number of parameters in the model.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.decoder.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        """
        Initializes the weights of the model.

        Args:
            - module (torch.nn.Module): The module of the model to be initialized.
        """
        if isinstance(module, nn.Linear):
            # init Linear layers with normal distribution (Gaussian initialization)
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                # bias initialization if necessary
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # init Embedding layers with normal distribution (Gaussian initialization)
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, tgt_mask=None):
        """
        Defines the computation performed at every call.

        Args:
            - idx (torch.Tensor): The input tensor to the forward pass.
            - tgt_mask (torch.Tensor, optional): The mask tensor to ignore padding, size (B, 1, 1, T).

        Returns:
            - torch.Tensor: The output tensor (logits) of the model.
            - list: all layers of decoder attentions weights.
        """
        device = idx.device
        b, t = idx.size()
        
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
        tok_emb = self.decoder.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.decoder.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.decoder.drop(tok_emb + pos_emb) # Addition of input embd + positional encoding

        decoder_attn_all = []
        for block in self.decoder.h:
            x, encoder_attn = block(x, tgt_mask)
            decoder_attn_all.append(encoder_attn)
            
        x = self.decoder.ln_f(x)
        return self.lm_head(x), decoder_attn_all

### MiniGpt

In [5]:
class MiniGPT(nn.Module):
    """
    This class implements the Transformer model, which includes both the encoder and decoder.

    The Transformer is a sequence transduction model that uses attention mechanisms.
    It is primarily used in tasks that require understanding of context or relationships among words in a text.

    Attributes:
        - decoder (Decoder): The transformer decoder.
        - config (:obj:`Config`): The configuration object for the transformer model.

    Args:
        - config (:obj:`Config`): The configuration object with attributes such as `vocab_size`, `block_size`, `n_embd`, `dropout`, `n_layer`, and `bias`.
    """
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.decoder = Decoder(config)

    def forward(self, tgt, tgt_tgt=None, tgt_mask=None):
        """
        Defines the computation performed at every call.

        Args:
            - tgt (torch.Tensor): The input tensor to the decoder.
            - tgt_mask (torch.Tensor): The target_masks tensor to the decoder, size (B, 1, 1, T).

        Returns:
            - torch.Tensor: The output tensor (logits) of the model.
            - torch.Tensor: The loss tensor calculated on the basis of the decoder's output and target tensor.
        """
        assert tgt.dim() == 2, "tgt should be 2D (B, S)"
        if tgt_mask is not None:
            assert tgt_mask.dim() == 4, "tgt_mask should be 4D (B, 1, 1 S)"

        # tgt_shifted = tgt[:, :-1] # Shifted target
        tgt_shifted = tgt
        output, _, = self.decoder(tgt_shifted, tgt_mask[:, :, :-1, :-1] if tgt_mask else None)

        if tgt_tgt is None:
            loss = None
        else:
            B, T, C = output.shape
            output = output.view(B*T, C)
            tgt_tgt = tgt_tgt.view(B*T)
            # Calculate the loss, using both the output and the target
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.tokenizer.PAD_IDX) # Ignore padding tokens
            loss = loss_fct(output, tgt_tgt)
            
        return output, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.config.block_size:]
            # get the predictions
            logits, decoder_attn_all = self.decoder(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
        
    def save_model(self, path: str):
        """
        Saves the current state of the model to a file.

        Args:
            path (str): The path to the file where the model state should be saved.
        """
        torch.save(self.state_dict(), path)

    def load_model(self, path: str):
        """
        Loads the model state from a file.

        Args:
            path (str): The path to the file from where the model state should be loaded.

        Raises:
            ValueError: If the specified file does not exist.
        """
        if not os.path.exists(path):
            raise ValueError(f"{path} does not exist.")
        self.load_state_dict(torch.load(path))

### Config

In [6]:
# init config
config = TransformerConfig(
    tokenizer,
    block_size = 256,
    batch_size = 12,
    n_layer = 6, # 6,
    n_head = 8,
    n_embd = 256,
    max_epochs=100,
    train_data_size = 30000, # batch * 500 iters
    max_iters = 2500,
    eval_iters = 250,
)
print(config)

TransformerConfig(
	self.tokenizer=<transformer_implementation.Tokenizer.Tokenizer object at 0x000002028C917D90>,
	self.block_size=256,
	self.batch_size=12,
	self.n_layer=6,
	self.n_head=8,
	self.n_embd=256,
	self.dropout=0.1,
	self.bias=False,
	self.device='cuda',
	self.learning_rate=0.0003,
	self.max_epochs=100,
	self.max_iters=2500,
	self.eval_iters=250,
	self.train_data_size=30000,
	self.visualize=False,
	self.vocab_size=100277,
)


### Loading dataset

In [7]:
class TranslationDataset(Dataset):
    """
    A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.

    Args:
        - dataset (Dataset): a dataset from HuggingFace datasets library.
        - tokenizer (Tokenizer): The custom tiktoken tokenizer used to encode sequences.
        - block_size (int): The maximum sequence length for tokenization.
    """

    def __init__(self, dataset, tokenizer, block_size):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __getitem__(self, index):
        """
        Get a tokenized example from the dataset at the specified index.

        Args:
            - index (int): the index of the example to fetch.

        Returns:
            - Dict: dictionary with keys 'inputs', 'inputs_mask', 'targets' , 'targets_mask' and 'translation', containing tokenized input,
            target sequences and original translation.
        """
        x = self.dataset[index:index+self.block_size]
        y = self.dataset[index+1:index+self.block_size+1]
        return {
            'x': x,
            'y': y,
        }

    def __len__(self) -> int :
        """
        Returns the length of the dataset.

        Returns:
            - int: the length of the dataset.
        """
        return len(self.dataset)

class DataLoaderFactory():
    """
    A class to instantiate PyTorch DataLoaders for different splits of a HuggingFace Dataset.

    Args:
        - block_size (int): The maximum sequence length for tokenization.
        - batch_size (int): The batch size for DataLoader.
        - tokenizer (Tokenizer): a tokenizer that has an encode method.
        - device (str): 'cpu' or 'cuda', depending on whether we use CPU or GPU.
    """

    def __init__(self, block_size, batch_size, tokenizer, device):
        self.block_size = block_size
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.device = device
        
        self.text, self.train_data, self.val_data = self.__load_data()
        
        self.dataloader_train = DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
        self.dataloader_val = DataLoader(self.val_data, batch_size=batch_size, shuffle=True)
        
    def __load_data(self):
        with open('./data/input.txt', 'r', encoding='utf-8') as f:
            text = f.read()

        dataset = torch.tensor(tokenizer.encoder.encode(text), dtype=torch.long)
        n = int(0.9*len(dataset)) # first 90% will be train, rest val
        train_data = TranslationDataset(dataset[:n], self.tokenizer, self.block_size)
        val_data = TranslationDataset(dataset[n:], self.tokenizer, self.block_size)
        return text, train_data, val_data
        
    def __len__(self) -> int :
        """
        Print the length of each dataset and returns the length of all datasets.

        Returns:
            - int: the length of all dataset (train + val + test).
        """
        print("\033[95m\033[1m\033[4mNumber of data by datasets splits\033[0m")
        print(f"Train\t\t: {len(self.train_data)}\t-> {len(self.train_data)/self.batch_size}")
        print(f"Validation\t: {len(self.val_data)}\t\t-> {len(self.val_data)/self.batch_size}")
        total = len(self.train_data) + len(self.val_data)
        print(f"Total\t\t: {total}")
        return total

    def get_batch(self, split):
        """
        Choose the correct DataLoader and yield batches from it.

        Args:
            - split (str): 'train', 'val' or 'test'.

        Yields:
            - Dict: a dictionary with keys 'inputs', 'targets' and 'translation', containing a batch of tokenized input,
            target sequences and original translation.
        """
        # choose the correct dataloader
        if split == 'train':
            dataloader = self.dataloader_train
        else:
            dataloader = self.dataloader_val

        for batch in dataloader:
            # Move tensors to device
            batch_on_device = {k: v.to(self.device) for k, v in batch.items()}
            yield batch_on_device

In [8]:
# loading dataset
dataset = DataLoaderFactory(config.block_size, config.batch_size, tokenizer, config.device)
len(dataset)

[95m[1m[4mNumber of data by datasets splits[0m
Train		: 271646	-> 22637.166666666668
Validation	: 30183		-> 2515.25
Total		: 301829


301829

In [9]:
batch = dataset.get_batch('train')
next_batch = next(batch)

In [10]:
next_batch['x'].size()

torch.Size([12, 256])

In [11]:
next_batch['x']

tensor([[   11,   477,   832,  ...,   279,   198, 65227],
        [  556,  2751,   198,  ...,    11, 14174,   301],
        [ 3021,  4250,  4295,  ...,  2795,  4400,    25],
        ...,
        [  198,    76,   372,  ...,   279, 59213,  3647],
        [  305,   552,  1461,  ...,    11,   856, 38031],
        [  872, 11427,  8703,  ..., 53120,   287,   198]], device='cuda:0')

## Utils

In [12]:

@torch.no_grad()
def estimate_loss(model, dataset, config, splits = ['train', 'val']):
    """
    This function estimates the loss of a model on specified data splits without performing backpropagation.
    It sets the model to evaluation mode, iterates over the data splits and calculates the average loss.

    Args:
        model (Transformer): The model for which loss needs to be estimated.
        dataset (CustomDataset): The dataset used for estimation. It should provide a 'get_batch' method.
        config (Config): The configuration object defining the number of evaluation iterations.
        splits (list[str]): List of the names of data splits to use for estimation.

    Returns:
        out (dict): A dictionary with split names as keys and corresponding average loss as values.
    """
    # Create an empty dictionary to store the average loss for each split
    out = {}

    # Set the model to evaluation mode
    model.eval()

    # Loop over the data splits
    for split in splits:
        # Initialize a tensor to store the losses for each iteration in the current split
        losses = torch.zeros(config.eval_iters)

        # Get a batch iterator for the current split
        batch = dataset.get_batch(split)
        
        # Initialize a progress bar for the inner loop
        inner_loop = tqdm(range(config.eval_iters), desc=f"Evaluation - {split}", leave=False)

        # Start the inner loop
        for k in inner_loop:
            # Sample a new batch of data
            n_batch = next(batch)
            X = n_batch['x']
            Y = n_batch['y']
            
            # Evaluate the loss for the current batch
            logits, loss = model(X, Y)

            # Store the current loss
            losses[k] = loss.item()
            
        # Calculate and store the mean loss for the current split
        out[split] = losses.mean()

    # Return the dictionary with the average losses
    return out

### Model

In [13]:
# Create model
model = MiniGPT(config)
model.train()
# Use nn.DataParallel to wrap the model.
# This will distribute the operations to multiple GPUs if they are available.
model = nn.DataParallel(model)
model = model.to(config.device)

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.98), eps=1e-9)

number of Decoder parameters: 30.39M


## Training

### Loop

In [14]:
# learning rate warmup and then decay, which is a standard practice in Transformer training.
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """ "warm-up, then decay" strategy. """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

    return LambdaLR(optimizer, lr_lambda)

def training_loop(model, optimizer, dataset, config, saved_path = "./out/minigpt_state_dict.pth"):
    """
    This function performs the training loop for the given transformer model. It trains the model using the provided 
    optimizer and dataset according to the specified configuration. 

    Args:
        - model (Transformer): The transformer model to be trained.
        - optimizer (torch.optim.Optimizer): The optimizer used to update the model's parameters.
        - dataset (CustomDataset): The dataset used for training and validation. It should provide a 'get_batch' method.
        - config (Config): The configuration object that defines parameters like max_iters.

    Returns:
        - losses_list (dict): A dictionary that contains the training and validation losses per evaluation step.
    """
    # This is the total number of training steps,
    # which is typically the number of training examples times the number of epochs.
    num_training_steps = config.train_data_size * config.max_epochs
    # Choose warmup_steps such that it's 1% of total steps
    num_warmup_steps = num_training_steps // 100
    # You can add this after defining your optimizer
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

    # Initialize a dictionary to keep track of training and validation losses
    losses_list = {
        'train': [],
        'val': [],
    }
    
    # Initialize minimum loss with a high value and the iteration number where the minimum was observed
    iter_saved = 0

    # init early stop
    best_loss = float('inf')
    val_loss = float('inf')
    # This is the number of epochs with no improvement after which training will be stopped.
    patience = 3
    # This is used to keep track of the number of epochs without improvement.
    patience_counter = 0

    # Initialize a progress bar for the outer training loop
    outer_loop = tqdm(range(config.max_epochs), desc="Train loss nan, Val loss nan, Saved nan", leave=True)

    # Start the training loop
    for epochs in outer_loop:
        # Initialize a batch of data from the 'train' part of the dataset
        iter_loop = tqdm(range(config.max_iters), leave=False)
        batch = dataset.get_batch('train')
        for iter in iter_loop:
            # Sample a new batch of data
            n_batch = next(batch)
            xb = n_batch['x']
            yb = n_batch['y']
        
            # Evaluate the loss
            logits, loss = model(xb, yb)
            # Reset gradients
            optimizer.zero_grad(set_to_none=True)
            # Perform backpropagation
            loss.backward()
            # Update the model parameters
            optimizer.step()
            # update sheduler
            scheduler.step()

        ############
        # Evaluation
        ############
        # Estimate the losses for both training and validation datasets
        losses = estimate_loss(model, dataset, config)
        # Return the model to training mode
        model.train()
        
        # Record the estimated losses
        losses_list['train'].append(losses['train'])
        losses_list['val'].append(losses['val'])
                    
        # Get the latest losses
        last_loss_train = losses_list['train'][-1]
        last_loss_val = losses_list['val'][-1]
        
        ############
        # early stop
        ############
        val_loss = losses['val']
        if val_loss < best_loss:
            torch.save(model.module.state_dict(), saved_path)
            best_loss = val_loss
            patience_counter = 0
            iter_saved = epochs
        else:
            patience_counter += 1

        # Update the description of the progress bar
        outer_loop.set_description(f"Train loss {last_loss_train:.4f}, Val loss {last_loss_val:.4f}, Saved {iter_saved}")
        
        if patience_counter >= patience:
            print("Early stopping")
            break
    
    # Return the list of losses
    return losses_list

In [15]:
losses_list = training_loop(model, optimizer, dataset, config)

Train loss nan, Val loss nan, Saved nan:   0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])
output.size()=torch.Size([12, 256, 100277])
tgt_tgt.size()=torch.Size([12, 256])


KeyboardInterrupt: 

### Plotting losses

In [None]:
# Call the function
plot_losses(losses_list)

## Testing

In [None]:
context = torch.full((src.size(0), 1), self.config.tokenizer.BOS_IDX).long().to(src.device)
# context = torch.zeros((1, 1), dtype=torch.long, device=device)

In [None]:
print(decode(m.generate(context, max_new_tokens=100)[0].tolist()))