# Goal

To combine [TinyStories](https://arxiv.org/abs/2305.07759), [text-to-SQL](https://huggingface.co/datasets/b-mc2/sql-create-context), and a [Textbooks Are All You Need](https://ar5iv.labs.arxiv.org/html/2306.11644)-style code dataset (`nampdn-ai/tiny-codes`) into one dataset, in order to train an encoder-decoder Transformer for text-to-code tasks. All three are hosted on Hugging Face, which avoids data-ingestion pains for now.

Then tokenise the combined dataset with SentencePiece.

# Imports

In [None]:
import sys
import einops
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from tqdm.notebook import tqdm
from typing import Tuple, List, Optional, Dict
from jaxtyping import Float, Int
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from collections import defaultdict
from rich.table import Table
from rich import print as rprint
import datasets
from torch.utils.data import DataLoader
import wandb
from pathlib import Path
import webbrowser
from datasets import load_dataset
import sentencepiece as spm
from datasets import concatenate_datasets


# Load three datasets

Note: log in with `huggingface-cli` on the command line to download the textbooks (tiny-codes) dataset.

In [None]:
# Authenticate with the Hugging Face Hub via an IPython shell escape (`!`).
# If the interactive prompt does not work inside the notebook, run
# `huggingface-cli login` in a terminal instead.
# Requires an access token generated on Hugging Face.
!huggingface-cli login

In [None]:
# Download the three source corpora from the Hugging Face Hub.
# Each result is a DatasetDict indexed by split name (e.g. 'train').
TINY_STORIES_HUB_ID = 'roneneldan/TinyStories'
TEXT_TO_SQL_HUB_ID = 'b-mc2/sql-create-context'
TINY_CODES_HUB_ID = 'nampdn-ai/tiny-codes'  # used here as the "textbooks" corpus

tiny_stories = load_dataset(TINY_STORIES_HUB_ID)
text_to_sql = load_dataset(TEXT_TO_SQL_HUB_ID)
textbooks_all_you_need = load_dataset(TINY_CODES_HUB_ID)

In [None]:
# Show which splits the tiny-codes DatasetDict provides (only 'train' is used below).
textbooks_all_you_need.keys()

# Combine/Prepare for SentencePiece

In [None]:
# Work with the 'train' split of each corpus.
tiny_stories_train = tiny_stories['train']
text_to_sql_train = text_to_sql['train']
textbooks_all_you_need_train = textbooks_all_you_need['train']

_train_splits = (tiny_stories_train, text_to_sql_train, textbooks_all_you_need_train)

# Print all three schemas first, then all three row counts, in corpus order.
for _split in _train_splits:
    print(_split.features)
for _split in _train_splits:
    print(len(_split))

Create mini versions of each dataset for testing:

In [None]:
# Small 1,000-row samples of each corpus for quick experiments.
# A fixed seed makes the samples (and anything derived from them) repeatable
# across notebook restarts; change SAMPLE_SEED to draw a different sample.
SAMPLE_SEED = 42
SAMPLE_SIZE = 1000

tiny_stories_train_testing = tiny_stories_train.shuffle(seed=SAMPLE_SEED).select(range(SAMPLE_SIZE))
text_to_sql_train_testing = text_to_sql_train.shuffle(seed=SAMPLE_SEED).select(range(SAMPLE_SIZE))
textbooks_all_you_need_train_testing = textbooks_all_you_need_train.shuffle(seed=SAMPLE_SEED).select(range(SAMPLE_SIZE))


I'm going to have to feed one 'language' of data into the encoder and the other into the decoder. Are the languages (English, code)? Or (English, SQL, Python, Java, ...)?

That would make the task far too complicated, so let's just use SQL.

In [None]:
# Define a filter function
def filter_sql_entries(example):
    """Return True when the row's 'programming_language' mentions SQL.

    The check is a case-insensitive substring match, so values such as
    'SQL', 'MySQL' and 'PostgreSQL' all qualify.
    """
    language = example['programming_language']
    return language.lower().find('sql') != -1

# Apply the filter function
# Keep only SQL rows, for both the full corpus and the 1,000-row sample.
textbooks_all_you_need_train_sql = textbooks_all_you_need_train.filter(filter_sql_entries)
# The sample was already shuffled when it was drawn, and filtering does not
# depend on row order, so a second shuffle here would be redundant.
textbooks_all_you_need_train_testing_sql = textbooks_all_you_need_train_testing.filter(filter_sql_entries)


In [None]:
# Number of SQL rows in the full corpus, then in the sample.
for _sql_split in (textbooks_all_you_need_train_sql, textbooks_all_you_need_train_testing_sql):
    print(len(_sql_split))

# Store as individual sentences

In [None]:
# # Testing of SP Input
# # Combine the relevant fields from each dataset into a single text file for SentencePiece training
# with open('SP_data_encoder_decoder_testing.txt', 'w', encoding='utf-8') as f:
#     for example in tiny_stories_train_testing:
#         f.write(example['text'].replace('\n', '') + '\n')

#     print ('tiny_stories_train done')
#     for example in text_to_sql_train_testing:
#         f.write(example['context'] + '\n')
#         f.write(example['question'] + '\n')
#         f.write(example['answer'] + '\n')  # This is typically the target language in translation tasks
#     print ('text_to_sql_train done')
#     for example in textbooks_all_you_need_train_testing:
#         f.write('\n'.join([example[field] for field in example if field not in ['idx', 'response']]) + '\n')

# #         f.write(' '.join([example[field] for field in example if field not in ['idx', 'response']]) + '\n')
#         f.write(example['response'] + '\n')  # Include the 'response' field as it is part of the target language

In [None]:
# Write every text field from the three corpora into one plain-text file for
# SentencePiece training. SentencePiece reads one sentence per line.
with open('SP_data_encoder_decoder.txt', 'w', encoding='utf-8') as f:
    for example in tiny_stories_train:
        # Collapse each story onto a single line. Newlines become spaces
        # (rather than being deleted) so words on either side of a line or
        # paragraph break are not glued together ("end.\nOne" -> "end. One").
        f.write(example['text'].replace('\n', ' ') + '\n')
    print ('tiny_stories_train done')
    for example in text_to_sql_train:
        f.write(example['context'] + '\n')
        f.write(example['question'] + '\n')
        f.write(example['answer'] + '\n')  # the SQL (target-language) side
    print ('text_to_sql_train done')
    for example in textbooks_all_you_need_train:
        # Multi-line code is written as-is, so each source line becomes its
        # own SentencePiece sentence.
        f.write(example['prompt'] + '\n')
        f.write(example['response'] + '\n')  # the code (target-language) side
    print ('textbooks_all_you_need_train done')

# Train SentencePiece

Initial runs crashed the notebook, so restrict the number of rows used for training. First, count the rows in the data.

In [None]:
def count_rows(file_path):
    """Return the number of lines in a UTF-8 text file, streaming it line by line."""
    total = 0
    with open(file_path, encoding='utf-8') as handle:
        for _ in handle:
            total += 1
    return total

# Example usage
# Count the sentences (lines) in the SentencePiece training corpus.
sp_corpus_path = 'SP_data_encoder_decoder.txt'
num_rows = count_rows(sp_corpus_path)
print(f"Number of rows in the file: {num_rows}")

In [None]:
import sentencepiece as spm

# Tokenizer training parameters, passed to SentencePiece as keyword arguments.
train_args = {
    'input': 'SP_data_encoder_decoder.txt',      # Input corpus, one sentence per line
    'model_prefix': 'SP_encoder_decoder_model',  # Output files: <prefix>.model and <prefix>.vocab
    'vocab_size': 3200,                          # Vocabulary size; the model's d_vocab must match this
    # Fraction of characters the model must cover. SentencePiece recommends
    # 0.9995 for languages with rich character sets (e.g. Japanese, Chinese)
    # and 1.0 for languages with small character sets.
    'character_coverage': 0.9997,
    'model_type': 'unigram',                     # 'unigram' (default), 'bpe', 'char' or 'word'
    # Train on a random subset of at most this many sentences; using the full
    # corpus crashed the notebook.
    'input_sentence_size': 1300000,
    'shuffle_input_sentence': True,
    # Reserved ids/pieces for special tokens. Reserving them does not make
    # encoding insert BOS/EOS or padding; that has to be done explicitly.
    'pad_id': 0,
    'unk_id': 1,
    'bos_id': 2,
    'eos_id': 3,
    'pad_piece': '[PAD]',
    'unk_piece': '[UNK]',
    'bos_piece': '[BOS]',
    'eos_piece': '[EOS]'}

# Train the model. Keyword arguments avoid hand-building a '--flag=value'
# command line, which breaks on values containing spaces.
spm.SentencePieceTrainer.train(**train_args)

print(f"Model trained and saved as {train_args['model_prefix']}.model and "
      f"{train_args['model_prefix']}.vocab!")

# Check SP trained well

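The SentencePiece model trained above reserves four special pieces:
1. [PAD] with token id 0
2. [UNK] with token id 1 (any unknown piece is encoded as this id)
3. [BOS] with token id 2 (beginning of sentence)
4. [EOS] with token id 3 (end of sentence)

Note that encoding does not insert [BOS]/[EOS] or padding automatically. They must be added explicitly when building model inputs, or via the `add_bos`/`add_eos` encode options.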

To check:
What is the token encoding for a sample of rows in each dataset? 

In [None]:
import sentencepiece as spm

# Load the trained SentencePiece model so its tokenisation can be eyeballed.
sp = spm.SentencePieceProcessor()
sp.load('SP_encoder_decoder_model.model')  # Replace with your model file

# The same C++ snippet twice: once with its line breaks and indentation, once
# flattened onto a single line, to compare how whitespace affects the pieces.
input_string =     '''
    cin >> consentGiven;

    // Based on the user's answer, display appropriate instructions
    if (consentGiven) {
        cout << endl
'''

input_string_2 =     '''
cin >> consentGiven; // Based on the user's answer, display appropriate instructions if (consentGiven) {cout << endl
'''

# Tokenise each snippet both as subword pieces and as integer ids.
encoded_pieces = sp.EncodeAsPieces(input_string)
encoded_ids = sp.EncodeAsIds(input_string)
encoded_pieces_2 = sp.EncodeAsPieces(input_string_2)
encoded_ids_2 = sp.EncodeAsIds(input_string_2)

for _pieces, _ids in ((encoded_pieces, encoded_ids), (encoded_pieces_2, encoded_ids_2)):
    print("Encoded as pieces:", _pieces)
    print("Encoded as ids:", _ids)

OK, the tokenisation looks reasonable on this code sample (C++ from tiny-codes). I'll assume it's fine on TinyStories too; not worried about that.

# Create training datasets for encoder and decoder

For this, we'll have to:
1. separate out the inputs (natural language plus SQL contextual?) and the outputs (resultant SQL?). 
2. We'll also have to append BOS, EOS and Padding tokens, just as our trained SP model expects. 
3. We'll have to create datasets and dataloaders objects for the inputs and outputs separately too. 
4. We'll also need collate functions to pad each batch. 

Input dataset:
1. Encoder: tiny_stories['text'], Decoder: X
2. Encoder: text_to_sql['question'] + text_to_sql['context'], Decoder: text_to_sql['answer']
3. Encoder: textbooks_all_you_need['prompt'], Decoder: textbooks_all_you_need['response']

Maybe start with just (2) and (3) for now. In fact let's ignore context from (2) also.

- Append BOS and EOS tokens to target sequence/decoder input
- Do the same for input sequences/encoder input? Definitely for the EOS token; BOS probably doesn't matter for the encoder. Adding BOS only lengthens the sequence by one, and the Q/K/V weight matrices don't depend on sequence length, so either choice works. Using only EOS for now.
- Full target sequence is [BOS, 1,2,3,4, EOS] 
- decoder input is [BOS, 1,2,3,4]
- Decoder will output predictions on [1,2,3,4,EOS], so use [1,2,3,4,EOS] in loss calculation

In [None]:
# class EncoderDecoderDataset(t.utils.data.Dataset):
#     def __init__(self, hf_dataset):
#         self.hf_dataset = hf_dataset

#     def __len__(self):
#         return len(self.hf_dataset)

#     def __getitem__(self, idx):
#         # This method should return a single sample at a time
#         item = self.hf_dataset[idx]
#         # Process the item (e.g., tokenization, numericalization) as required
#         # ...
#         return item

In [None]:
from datasets import Features, Value

def standardize_types(example):
    """Coerce the 'prompt' and 'response' fields of *example* to str, in place.

    Returns the same (mutated) mapping, so it can be used with Dataset.map.
    NOTE(review): not called anywhere in the visible notebook; the cast below
    is used instead.
    """
    for key in ('prompt', 'response'):
        example[key] = str(example[key])
    return example



# Give both SQL sources a common (prompt, response) schema so they can be
# concatenated. For text-to-SQL the question is the prompt and the SQL query
# is the response; the table-schema context is dropped for now.
text_to_sql_train_mapped = text_to_sql_train.map(
    lambda example: {'prompt': example['question'], 'response': example['answer']},
    remove_columns=['question', 'context', 'answer'],
)

cols_to_keep = {'prompt', 'response'}

columns_to_remove = [col for col in textbooks_all_you_need_train_sql.column_names if col not in cols_to_keep]

# Drop the extra columns directly. This yields the same rows as an identity
# `map(..., remove_columns=...)` without rewriting every row to a new cache file.
textbooks_all_you_need_train_sql_trimmed = textbooks_all_you_need_train_sql.remove_columns(columns_to_remove)

# Cast both columns to plain strings, presumably so the feature types match
# text_to_sql_train_mapped for concatenate_datasets.
textbooks_all_you_need_train_sql_trimmed_2 = textbooks_all_you_need_train_sql_trimmed.cast(
    Features({"response": Value("string"), "prompt": Value("string")}))
combined_dataset = concatenate_datasets([text_to_sql_train_mapped, textbooks_all_you_need_train_sql_trimmed_2])



In [None]:
# Sanity-check the concatenation: the sizes add up, and rows from both ends
# look like the two sources.
print (len(textbooks_all_you_need_train_sql_trimmed))
print (len(text_to_sql_train_mapped))
print (len(combined_dataset))

for i in range(5):
    print (combined_dataset[i])          # i-th row from the start (text-to-SQL)
    # -(i + 1) walks back from the last row; -i would reprint row 0 when i == 0.
    print (combined_dataset[-(i + 1)])   # i-th row from the end (tiny-codes SQL)

Looks like dataset concatenation was successful!

# Create Dataset Class for encoder-decoder

In [None]:
import dataset
import importlib
importlib.reload(dataset)
from dataset import EncoderDecoderDataset

In [None]:
import Config
import importlib
importlib.reload(Config)
from Config import Config

model_cfg = Config(
    d_model=256,
    n_heads=4,
    d_head=64,
    d_mlp=1024,
    n_layers=2,
    n_ctx=256,
    # Take the vocabulary size from the trained SentencePiece model so the
    # embedding/unembedding tables match the ids it produces. The previous
    # hard-coded 32000 did not match the 3200-piece tokenizer trained above.
    d_vocab=sp.get_piece_size(),
)

# Create train and test sets

In [None]:
# EncoderDecoderData = EncoderDecoderDataset(
#     combined_dataset, 
#     sp)

# import torch
# from torch.utils.data import DataLoader, random_split

# # Assuming `my_dataset` is your dataset instance
# dataset_size = len(EncoderDecoderData)

# # Set the percentage for the test data (e.g., 20%)
# test_pct = 0.20  # 20% of the dataset
# test_size = int(dataset_size * test_pct)
# train_size = dataset_size - test_size

# # Split the dataset
# EncoderDecoderData_train, EncoderDecoderData_test = random_split(EncoderDecoderData, [train_size, test_size])

# Create two separate EncoderDecoderDataset instances with random sampling
# Random train/test split of combined_dataset. A seeded generator makes the
# split reproducible. A single permutation guarantees the two index sets are
# disjoint and together cover every row.
num_samples = len(combined_dataset)
train_ratio = 0.8  # Adjust this ratio as needed
split_seed = 42

rng = np.random.default_rng(split_seed)
shuffled_indices = rng.permutation(num_samples)
num_train = int(train_ratio * num_samples)
train_indices = shuffled_indices[:num_train]
test_indices = np.sort(shuffled_indices[num_train:])  # sorted, like the previous np.setdiff1d result

EncoderDecoderData_train = EncoderDecoderDataset(combined_dataset.select(train_indices), sp, model_cfg)
EncoderDecoderData_test = EncoderDecoderDataset(combined_dataset.select(test_indices), sp, model_cfg)

In [None]:
# Debugging aid: display the SentencePiece processor object in the notebook.
test = sp
test

# Create dataloaders

In [None]:
# train_loader = DataLoader(
#     EncoderDecoderDataset_train,
#     batch_size=args.batch_size,
#     shuffle=True,
#     num_workers=4,
#     pin_memory=False,
#     collate_fn = EncoderDecoderDataset.collate_fn
# )

# test_loader = DataLoader(
#     EncoderDecoderDataset_test,
#     batch_size=args.batch_size,
#     shuffle=False,
#     num_workers=4,
#     pin_memory=False,
#     collate_fn = EncoderDecoderDataset.collate_fn
# )

# EncoderDecoderDataLoader = t.utils.data.DataLoader(
#     EncoderDecoderDataset,
#     batch_size=64,
#     shuffle=True,
#     collate_fn=EncoderDecoderDataset.collate_fn)


In [None]:
# Check the type of the training split object built above.
type(EncoderDecoderData_train)

In [None]:
# Total number of examples across both splits. (`EncoderDecoderData` only
# exists in the commented-out random_split code, so it is undefined here.)
len(EncoderDecoderData_train) + len(EncoderDecoderData_test)

In [None]:
# Number of training batches. `train_loader` is never defined at notebook
# level (the trainer builds its own loaders), so build a DataLoader here just
# to count batches. batch_size=4 mirrors TransformerTrainingArgs.batch_size.
len(DataLoader(EncoderDecoderData_train, batch_size=4))

In [None]:
# Number of test batches. `test_loader` is never defined at notebook level
# (the trainer builds its own loaders), so build a DataLoader here just to
# count batches. batch_size=4 mirrors TransformerTrainingArgs.batch_size.
len(DataLoader(EncoderDecoderData_test, batch_size=4))

# Create Encoder-Decoder Model

Adjustments needed to turn the Transformer code into an encoder-decoder:
1. Create separate encoder and decoder transformer classes, replacing the TransformerBlock and DemoTransformer classes. 
2. Implement cross-attention, after self-attention, in the decoder block

## Dataclass

In [None]:


# @dataclass
# class Config:
#     d_model: int = 768
# #     debug: bool = True
#     layer_norm_eps: float = 1e-5
#     d_vocab: int = 50257
#     init_range: float = 0.02
#     n_ctx: int = 1024
#     d_head: int = 64
#     d_mlp: int = 3072
#     n_heads: int = 12
#     n_layers: int = 12
#     device: str = t.device("cuda" if t.cuda.is_available() else "cpu")

# cfg = Config()
# print(cfg)

## Device

In [None]:
# device = t.device("cuda" if t.cuda.is_available() else "cpu")


## Embedding Modules

In [None]:
import torch as t
class Embed(nn.Module):
    """Token embedding: look up one learned d_model vector per token id."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # One trainable row per vocabulary entry.
        self.W_E = nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        # Advanced indexing gathers the row for every id in `tokens`.
        embedded = self.W_E[tokens]
        return embedded
    
class PosEmbed(nn.Module):
    """Learned absolute positional embedding with one row per position up to n_ctx."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty(cfg.n_ctx, cfg.d_model))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        # Only the shape of `tokens` is used: take the first n_pos position
        # vectors and repeat them across the batch dimension.
        n_batch, n_pos = tokens.shape
        positions = self.W_pos[:n_pos]
        return einops.repeat(positions, "seq d_model -> batch seq d_model", batch=n_batch)

## Transformer Modules

## Attention

In [None]:
class Attention(nn.Module):
    """Multi-head attention, used for encoder self-attention, decoder causal
    self-attention and decoder cross-attention.

    Constructor args:
        cfg: model config; reads n_heads, d_model, d_head, init_range and device.
        is_causal: if True, a causal mask is applied and `forward` returns only
            the attention output. If False, no mask is applied and `forward`
            returns the tuple (attention output, Keys, Values).

    NOTE(review): there is no padding mask, so queries also attend to padded
    key positions -- TODO confirm how batches are padded and mask if needed.
    """
    # Large negative fill value for masked scores; registered as a buffer on cfg.device.
    IGNORE: Float[Tensor, ""]
    
    def __init__(self, cfg:Config, is_causal: bool):
        super().__init__()
        self.cfg = cfg 
        self.is_causal = is_causal
        # Per-head projections: W_Q/W_K/W_V map d_model -> d_head, W_O maps d_head -> d_model.
        self.W_Q = nn.Parameter(t.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.W_K = nn.Parameter(t.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.W_V = nn.Parameter(t.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.W_O = nn.Parameter(t.empty(cfg.n_heads, cfg.d_head, cfg.d_model))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.device = cfg.device
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32, device=self.device))
        # self.key_activations = t.empty(cfg.n_heads, cfg.d_model, cfg.d_head)
        # self.value_activations = t.empty(cfg.n_heads, cfg.d_model, cfg.d_head)

    def forward(
        self,
        normalized_resid_pre: Float[Tensor, "batch seq_len d_model"],
        key_activations: Float[Tensor, "batch seq_len n_heads d_head"] = None, 
        value_activations: Float[Tensor, "batch seq_len n_heads d_head"] = None, 
    ) -> Float[Tensor, "batch seq_len d_model"]:
        """Attend from `normalized_resid_pre` to keys/values.

        Queries always come from `normalized_resid_pre`. If `key_activations` /
        `value_activations` are given (cross-attention), they are used as-is,
        and this layer's W_K/b_K and W_V/b_V are not applied. Otherwise keys
        and values are projected from `normalized_resid_pre` (self-attention).

        Returns:
            If is_causal: the attention output, shape (batch, seq_len_Q, d_model).
            Otherwise: (attention output, Keys, Values), where Keys and Values
            have shape (batch, seq_len_K, n_heads, d_head). The return
            annotation above only describes the causal case.
        """

        Queries = einops.einsum(
            normalized_resid_pre,
            self.W_Q,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head"
            ) + self.b_Q

        if key_activations is None:
            # Self-attention: project keys from the same input as the queries.
            Keys = einops.einsum(
                normalized_resid_pre,
                self.W_K,
                "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head"
                ) + self.b_K
            # self.key_activations = Keys
        else:
            # Cross-attention: keys were already projected elsewhere (by the encoder).
            Keys = key_activations
        

        if value_activations is None:
            Values = einops.einsum(
                normalized_resid_pre,
                self.W_V,
                "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head"
                ) + self.b_V
            # self.value_activations = Values
        else:
            Values = value_activations
                
        # Raw dot-product scores for every (query position, key position) pair, per head.
        Attention_Scores = einops.einsum(
            Queries,
            Keys,
            "batch seq_len_Q n_heads d_head, batch seq_len_K n_heads d_head -> batch n_heads seq_len_Q seq_len_K")
        
        # Only apply causal_attention if in decoder, via is_causal bool.
        # Masking before the 1/sqrt(d_head) scaling is fine: IGNORE (-1e5) stays
        # hugely negative after scaling, so masked keys get ~0 softmax weight.
        if (self.is_causal):
            Attention_Scores = self.apply_causal_mask(Attention_Scores)

        Attention_Scores_Scaled = Attention_Scores / (self.cfg.d_head**0.5)
        Attention_Scores_Scaled_Softmaxed = Attention_Scores_Scaled.softmax(-1)

        # Weighted sum of values for each query position, per head.
        Z = einops.einsum(
            Values,
            Attention_Scores_Scaled_Softmaxed,
            "batch seq_len_K n_heads d_head, batch n_heads seq_len_Q seq_len_K -> batch seq_len_Q n_heads d_head")

        # Project every head back to d_model and sum over heads.
        Attention_Out = einops.einsum(
            Z, 
            self.W_O, 
            "batch seq_len_Q n_heads d_head, n_heads d_head d_model -> batch seq_len_Q d_model"
            ) + self.b_O

        # NOTE(review): the return type depends on is_causal; callers must unpack accordingly.
        if (self.is_causal):
            return Attention_Out
        else:
            return Attention_Out, Keys, Values
    
    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.

        Scores where key position > query position (strictly above the
        diagonal) are overwritten with IGNORE *in place* (masked_fill_); the
        same tensor is returned. The mask is built on self.device (cfg.device),
        so attn_scores must live on that device.
        '''
        key_by_query_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device = self.device)
        mask = t.triu(key_by_query_ones, diagonal = 1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

## MLP

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward layer: d_model -> d_mlp -> GELU -> d_model."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty(cfg.d_model, cfg.d_mlp))
        self.W_out = nn.Parameter(t.empty(cfg.d_mlp, cfg.d_model))
        self.b_in = nn.Parameter(t.zeros(cfg.d_mlp))
        self.b_out = nn.Parameter(t.zeros(cfg.d_model))
        # Initialise W_in then W_out (same RNG order as before).
        for weight in (self.W_in, self.W_out):
            nn.init.normal_(weight, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # Expand to the hidden width and apply the GELU non-linearity ...
        hidden = gelu_new(
            einops.einsum(
                normalized_resid_mid,
                self.W_in,
                "batch seq_len d_model, d_model d_mlp -> batch seq_len d_mlp",
            ) + self.b_in
        )
        # ... then project back down to the residual width.
        return einops.einsum(
            hidden,
            self.W_out,
            "batch seq_len d_mlp, d_mlp d_model -> batch seq_len d_model",
        ) + self.b_out

## LayerNorm

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last (d_model) axis, with learned scale `w` and shift `b`."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))   # scale
        self.b = nn.Parameter(t.zeros(cfg.d_model))  # shift

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        # Population variance (unbiased=False) plus eps, as in standard LayerNorm.
        variance = residual.var(dim=-1, keepdim=True, unbiased=False)
        std = (variance + self.cfg.layer_norm_eps).sqrt()
        centered = residual - residual.mean(dim=-1, keepdim=True)
        return (centered / std) * self.w + self.b

## Assemble Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """Decoder-only (GPT-style) pre-LayerNorm block: causal self-attention, then MLP,
    each wrapped in a residual connection.

    NOTE(review): DemoTransformer uses EncoderBlock/DecoderBlock, not this block.
    """
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg, is_causal=True)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)
    
    def forward(
        self, resid_pre: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # Causal Attention returns a single tensor, so it can be added directly.
        resid_mid = self.attn(self.ln1(resid_pre)) + resid_pre
        resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid
        return resid_post

## Assemble Encoder Block

In the encoder block we just remove the causal mask, and additionally output the key and value activations computed from the block's input, for the decoder's cross-attention to use.

In [None]:
class EncoderBlock(nn.Module):
    """Pre-LayerNorm encoder block: non-causal self-attention, then MLP.

    Besides the updated residual stream, `forward` also returns the keys and
    values produced by its self-attention layer, for use in cross-attention.
    """
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg, is_causal=False)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch seq_len d_model"]
    ) -> Tuple[
        Float[Tensor, "batch seq_len d_model"],
        Float[Tensor, "batch seq_len n_heads d_head"],
        Float[Tensor, "batch seq_len n_heads d_head"],
    ]:
        """Return (resid_post, keys, values).

        keys/values are the attention projections of ln1(resid_pre), i.e. they
        are computed from this block's *input*, not from its output.
        """
        # Non-causal Attention returns (output, Keys, Values).
        attention_activations, key_activations, value_activations = self.attn(self.ln1(resid_pre))
        resid_mid = attention_activations + resid_pre

        mlp_activations = self.mlp(self.ln2(resid_mid))
        resid_post = mlp_activations + resid_mid
        return resid_post, key_activations, value_activations

In [None]:
class DecoderBlock(nn.Module):
    """Pre-LayerNorm decoder block with three residual sub-layers, each behind
    its own LayerNorm:

    1. ln1 -> causal self-attention
    2. ln2 -> cross-attention to externally supplied keys/values
    3. ln3 -> MLP
    """
    def __init__(self,
                 cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg, is_causal=True)
        self.ln2 = LayerNorm(cfg)
        self.attn2 = Attention(cfg, is_causal=False)
        self.ln3 = LayerNorm(cfg)
        self.mlp = MLP(cfg)
    
    def forward(
        self,
        resid_pre: Float[Tensor, "batch seq_len d_model"],
        key_activations: Float[Tensor, "batch seq_len n_heads d_head"],
        value_activations: Float[Tensor, "batch seq_len n_heads d_head"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        resid_post_causal_attention = self.attn(self.ln1(resid_pre)) + resid_pre

        # Cross-attention: queries come from the decoder stream, keys/values are
        # passed in, so attn2's own W_K/W_V projections are not used.
        activations_post_cross_attention, _, _ = self.attn2(
            self.ln2(resid_post_causal_attention), key_activations, value_activations
        )
        resid_post_cross_attention = activations_post_cross_attention + resid_post_causal_attention

        # The MLP has its own LayerNorm (ln3) rather than reusing the
        # cross-attention's ln2; previously ln3 was created but never applied.
        resid_post_mlp = self.mlp(self.ln3(resid_post_cross_attention)) + resid_post_cross_attention
        return resid_post_mlp

## Unembedding Module

In [None]:
class Unembed(nn.Module):
    """Project the residual stream onto vocabulary logits."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty(cfg.d_model, cfg.d_vocab))
        # Bias has requires_grad=False, so it stays at zero during training.
        self.b_U = nn.Parameter(t.zeros(cfg.d_vocab), requires_grad=False)
        nn.init.normal_(self.W_U, std=self.cfg.init_range)

    def forward(
        self, resid_stream: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_vocab"]:
        logits = einops.einsum(
            resid_stream,
            self.W_U,
            "batch seq_len d_model, d_model d_vocab -> batch seq_len d_vocab",
        )
        return logits + self.b_U

## Full Transformer

In [None]:
class DemoTransformer(nn.Module):
    """Encoder-decoder Transformer.

    The encoder and decoder share one token embedding and one positional
    embedding. Every decoder block cross-attends to the keys/values returned by
    the *last* encoder block.

    NOTE(review): the final encoder residual stream itself is never consumed;
    the decoder only sees the K/V the last encoder block computed from its
    input. A standard encoder-decoder would derive cross-attention K/V from the
    (layer-normed) encoder output. Left as-is to preserve behaviour.
    """
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(cfg) for _ in range(cfg.n_layers)])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)
    
    def forward(self,
                encoder_input: Int[Tensor, "batch src_len"],
                decoder_target: Int[Tensor, "batch tgt_len"],
               ) -> Float[Tensor, "batch tgt_len d_vocab"]:
        """Return logits for every decoder position.

        Args:
            encoder_input: source token ids.
            decoder_target: target token ids fed to the decoder (teacher
                forcing). Because decoder self-attention is causal, the logits
                at position i only depend on decoder_target[:, :i + 1].
        """
        # Encoding input sequence; only the last block's keys/values are kept.
        encoder_residual = self.embed(encoder_input) + self.pos_embed(encoder_input)
        for block in self.encoder_blocks:
            encoder_residual, key_activations, value_activations = block(encoder_residual)

        # Decoding target sequence with cross-attention to those keys/values.
        decoder_residual = self.embed(decoder_target) + self.pos_embed(decoder_target)
        for block in self.decoder_blocks:
            decoder_residual = block(decoder_residual, key_activations, value_activations)
        logits = self.unembed(self.ln_final(decoder_residual))
        return logits

In [None]:
# Build a model from the instantiated config and move it to that config's
# device. Passing the `Config` class itself would read whatever class-level
# attributes Config defines instead of the settings in model_cfg.
demo_transformer = DemoTransformer(model_cfg).to(model_cfg.device)

# Model Configs

In [None]:

# Put the model on the configured device. The trainer moves each batch to
# model_cfg.device, and Attention builds its IGNORE buffer and causal mask there.
model = DemoTransformer(model_cfg).to(model_cfg.device)

@dataclass
class TransformerTrainingArgs():
    """Hyperparameters for TransformerTrainer.

    Every attribute is annotated, so each one is a real dataclass field. That
    means it can be overridden in the constructor
    (e.g. TransformerTrainingArgs(batch_size=8)) and is included in
    dataclasses.asdict(). Unannotated attributes are plain class attributes
    and are not constructor parameters. The wandb fields stay first so that
    existing positional construction keeps working.
    """
    wandb_project: Optional[str] = "day2-demotransformer"
    wandb_name: Optional[str] = 'shaheen-ahmed'
    batch_size: int = 4
    epochs: int = 2
    max_steps_per_epoch: int = 3
    lr: float = 1e-3
    weight_decay: float = 1e-2
# Default training hyperparameters used by the trainer below.
args = TransformerTrainingArgs()

# Loss function

In [None]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"],
    tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    """Log-probability the model assigns to each actual next token.

    logits[:, i] is scored against tokens[:, i + 1], so the result has one
    position fewer than the inputs. The shift happens here; callers pass the
    unshifted logits and tokens.
    """
    # log_softmax is per position, so dropping the last position first is equivalent.
    predictions = logits[:, :-1].log_softmax(dim=-1)
    next_tokens = tokens[:, 1:].unsqueeze(-1)
    return predictions.gather(dim=-1, index=next_tokens).squeeze(-1)

# Training Loop

In [None]:
class TransformerTrainer:
    """Minimal training loop for DemoTransformer on an encoder-decoder dataset.

    Batches are dicts with 'tensor_model_input' (encoder token ids) and
    'tensor_ground_truth' (decoder token ids). The latter is both the decoder
    input (teacher forcing) and, shifted by one inside `loss_fn`, the target.
    """

    def __init__(self,
                 args: TransformerTrainingArgs,
                 model: DemoTransformer,
                 train_data: t.utils.data.Dataset,
                 test_data: t.utils.data.Dataset,
                 model_cfg: Config,
                 loss_fn):
        """
        Args:
            args: training hyperparameters (batch size, epochs, lr, ...).
            model: the model to train; expected to already be on model_cfg.device.
            train_data / test_data: datasets whose own `collate_fn` (if they
                define one) builds the batch dicts.
            model_cfg: model config; only `device` is read here.
            loss_fn: callable(logits, tokens) returning per-position
                log-probabilities of the next token (e.g. get_log_probs). The
                training loss is its negated mean.
        """
        super().__init__()
        self.model = model
        self.args = args
        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0
        self.model_cfg = model_cfg
        self.loss_fn = loss_fn
        self.train_data = train_data
        self.test_data = test_data

    def training_step(self, batch: Dict[str, t.Tensor]) -> Float[Tensor, ""]:
        """Run one optimisation step on `batch` and return the (detached) loss."""
        encoder_input = batch['tensor_model_input'].to(self.model_cfg.device)
        decoder_target = batch['tensor_ground_truth'].to(self.model_cfg.device)

        logits = self.model(encoder_input, decoder_target)

        # loss_fn already pairs logits[:, i] with decoder_target[:, i + 1], so
        # both are passed unshifted. Trimming them here as well would drop the
        # final position (typically the EOS prediction) from the loss.
        # TODO(review): padded positions are still included in the loss.
        loss = -self.loss_fn(logits, decoder_target).mean()

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1

        # wandb.log({"train_loss": loss}, step=self.step)
        return loss.detach()

    @t.inference_mode()
    def validation_step(self, batch: Dict[str, t.Tensor]):
        """Return a flat bool tensor: whether each greedy next-token prediction is correct."""
        encoder_input = batch["tensor_model_input"].to(self.model_cfg.device)
        decoder_target = batch["tensor_ground_truth"].to(self.model_cfg.device)

        # The last position has no next token to compare against, so drop its logits.
        logits: Tensor = self.model(encoder_input, decoder_target)[:, :-1]
        predicted_tokens = logits.argmax(dim=-1)

        # Target for position i is the token at i + 1.
        shifted_target_tokens = decoder_target[:, 1:]
        return (predicted_tokens == shifted_target_tokens).flatten()

    def train(self):
        """Train for `args.epochs` epochs of at most `args.max_steps_per_epoch`
        steps each, evaluating on a few test batches after every epoch."""
        # wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        validation_accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)
        try:
            for epoch in range(self.args.epochs):
                self.model.train()
                for i, batch in enumerate(self.train_loader()):
                    loss = self.training_step(batch)
                    progress_bar.update()
                    progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {validation_accuracy:.2f}")
                    # Stop after exactly max_steps_per_epoch steps, matching
                    # the progress-bar total.
                    if i + 1 >= self.args.max_steps_per_epoch:
                        break
                validation_accuracy = self._evaluate()
                # wandb.log({"accuracy": validation_accuracy}, step=self.step)
        finally:
            progress_bar.close()
        # wandb.finish()

    def _evaluate(self, max_batches: int = 2) -> float:
        """Mean per-batch accuracy over the first `max_batches` test batches (NaN if there are none)."""
        self.model.eval()
        batch_accuracies = []
        for batch in self.test_loader():
            batch_accuracies.append(self.validation_step(batch).float().mean().item())
            if len(batch_accuracies) >= max_batches:
                break
        self.model.train()
        return float(np.mean(batch_accuracies)) if batch_accuracies else np.nan

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader using the dataset's own collate_fn when it has one."""
        return DataLoader(dataset,
                          batch_size=self.args.batch_size,
                          shuffle=shuffle,
                          num_workers=4,
                          # Pinned host memory only speeds up host -> GPU copies.
                          pin_memory=t.device(self.model_cfg.device).type == "cuda",
                          collate_fn=getattr(dataset, "collate_fn", None))

    def train_loader(self) -> DataLoader:
        return self._make_loader(self.train_data, shuffle=True)

    def test_loader(self) -> DataLoader:
        return self._make_loader(self.test_data, shuffle=False)

In [None]:
# Wire the model, data splits and loss function into a trainer, then train.
trainer = TransformerTrainer(
    args=args,
    model=model,
    train_data=EncoderDecoderData_train,
    test_data=EncoderDecoderData_test,
    model_cfg=model_cfg,
    loss_fn=get_log_probs,
)
trainer.train()

Epoch 1: Accuracy = 0.1

In [None]:
# Persist only the learned weights (the state_dict), not the module object.
checkpoint_path = 'encoder_decoder_run.pth'
t.save(model.state_dict(), checkpoint_path)


# Generate SQL from query

In [None]:

# Rebuild the architecture from model_cfg and load the saved weights.
# map_location lets a checkpoint saved on one device (e.g. GPU) load on another.
# The model goes to model_cfg.device because the sampler places its input ids there.
sampling_model = DemoTransformer(model_cfg).to(model_cfg.device)
sampling_model.load_state_dict(t.load("encoder_decoder_run.pth", map_location=model_cfg.device))


In [None]:
class EncoderDecoderSampler:
    """
    Autoregressive sampler for an encoder-decoder transformer.

    The tokenised prompt is given to the encoder at every step. The decoder
    input starts as the tokenizer's BOS token and grows by one sampled token
    per step. Generation stops when an EOS token is sampled or the length
    budget is exhausted.
    """

    def __init__(self, model: "DemoTransformer", sp_tokenizer):
        """
        Args:
        - model: Encoder-decoder model, called as ``model(encoder_ids, decoder_ids)``
          and expected to expose a ``cfg`` with a ``device`` attribute.
        - sp_tokenizer: Tokenizer exposing ``encode_as_ids``, ``bos_id`` and
          ``decode`` (the SentencePiece processor API). The EOS id is taken from
          ``eos_token_id`` if the tokenizer has one, otherwise from ``eos_id()``.
        """
        self.model = model
        self.cfg = model.cfg
        self.tokenizer = sp_tokenizer

    def _eos_id(self) -> Optional[int]:
        """Return the tokenizer's EOS id (HF-style attribute or SentencePiece-style method), or None."""
        eos = getattr(self.tokenizer, "eos_token_id", None)
        if eos is None:
            eos_fn = getattr(self.tokenizer, "eos_id", None)
            if callable(eos_fn):
                eos = eos_fn()
        return eos

    @t.inference_mode()
    def generate_text(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs) -> str:
        """
        Generates text autoregressively using the encoder-decoder model.

        Args:
        - prompt (str): The input text, fed to the encoder.
        - max_tokens_generated (int): Upper bound on the decoder sequence length,
          counting the initial BOS token. At most ``max_tokens_generated - 1``
          new tokens are sampled.
        - verbose (bool): If True, print the partially decoded text at each step
          and the final token ids.
        - **kwargs: Forwarded to ``sample_next_token`` (temperature, top_k, top_p,
          frequency_penalty). ``seed``, if given, seeds the torch/numpy RNGs once
          before generation instead of at every step.

        Returns:
        - generated_text (str): ``tokenizer.decode`` of the full decoder sequence,
          including the leading BOS id and the trailing EOS id if one was sampled.
        """
        self.model.eval()
        device = self.cfg.device

        # Seed once up front. Forwarding ``seed`` to sample_next_token would
        # reseed before every step, so each step would reuse the same random draw.
        seed = kwargs.pop("seed", None)
        if seed is not None:
            t.manual_seed(seed)
            np.random.seed(seed)

        # The encoder sees the whole prompt, unchanged, at every step (batch of 1).
        encoder_input_ids = t.tensor(
            [self.tokenizer.encode_as_ids(prompt)], dtype=t.long, device=device
        )
        # The decoder starts from the BOS token of *this* sampler's tokenizer.
        decoder_input_ids = t.tensor(
            [[self.tokenizer.bos_id()]], dtype=t.long, device=device
        )
        eos_id = self._eos_id()

        for _ in range(max_tokens_generated - 1):
            logits = self.model(encoder_input_ids, decoder_input_ids)
            # Logits for the next token, from the last decoder position of batch item 0.
            next_token_logits = logits[0, -1, :]

            next_token = self.sample_next_token(
                decoder_input_ids[0], next_token_logits, **kwargs
            )
            next_token_tensor = t.tensor(
                [[next_token]],
                dtype=decoder_input_ids.dtype,
                device=decoder_input_ids.device,
            )
            decoder_input_ids = t.cat((decoder_input_ids, next_token_tensor), dim=1)

            if verbose:
                print(self.tokenizer.decode(decoder_input_ids[0].tolist()), end="\r")

            if eos_id is not None and next_token == eos_id:
                break

        decoder_input_ids_list = decoder_input_ids[0].tolist()
        if verbose:
            print(f"decoder_input_ids = {decoder_input_ids_list}")
        return self.tokenizer.decode(decoder_input_ids_list)

    @staticmethod
    def sample_next_token(
        input_ids: "Int[Tensor, 'seq_len']",
        logits: "Float[Tensor, 'd_vocab']",
        temperature=1.0,
        top_k=0,
        top_p=0.0,
        frequency_penalty=0.0,
        seed=None,
    ):
        """
        Picks the next token id from one position's logits.

        Args:
        - input_ids: 1D tensor of tokens generated so far (used by the frequency penalty).
        - logits: 1D tensor of next-token logits over the vocabulary.
        - temperature: 0 means greedy decoding. Any other value divides the logits.
        - top_k / top_p: at most one of them may be non-zero.
        - frequency_penalty: subtracted from a token's logit once per prior occurrence.
        - seed: if given, reseeds the torch and numpy RNGs before sampling.

        Returns:
        - The chosen token id as a Python int.
        """
        assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
        assert temperature >= 0, "Temperature should be non-negative"
        assert 0 <= top_p <= 1.0, "Top-p must be a probability"
        assert 0 <= top_k, "Top-k must be non-negative"
        assert not (
            top_p != 0 and top_k != 0
        ), "At most one of top-p and top-k supported"

        # Set random seeds for reproducibility
        if seed is not None:
            t.manual_seed(seed)
            np.random.seed(seed)

        # Greedy short-circuits: no frequency penalty or truncation is applied.
        if temperature == 0:
            return EncoderDecoderSampler.greedy_search(logits)
        elif temperature != 1.0:
            logits = EncoderDecoderSampler.apply_temperature(logits, temperature)
        if frequency_penalty != 0.0:
            logits = EncoderDecoderSampler.apply_frequency_penalty(
                input_ids, logits, frequency_penalty
            )
        if top_k > 0:
            return EncoderDecoderSampler.sample_top_k(logits, top_k)
        if top_p > 0.0:
            return EncoderDecoderSampler.sample_top_p(logits, top_p)
        return EncoderDecoderSampler.sample_basic(logits)

    @staticmethod
    def greedy_search(logits: "Float[Tensor, 'd_vocab']") -> int:
        """
        Returns the most likely token (as an int).
        """
        return logits.argmax().item()

    @staticmethod
    def apply_temperature(
        logits: "Float[Tensor, 'd_vocab']", temperature: float
    ) -> "Float[Tensor, 'd_vocab']":
        """
        Applies temperature scaling to the logits. Values < 1 sharpen the distribution.
        """
        return logits / temperature

    @staticmethod
    def apply_frequency_penalty(
        input_ids: "Int[Tensor, 'seq_len']",
        logits: "Float[Tensor, 'd_vocab']",
        freq_penalty: float,
    ) -> "Float[Tensor, 'd_vocab']":
        """
        Subtracts ``freq_penalty`` times each token's count in ``input_ids`` from its logit.
        """
        d_vocab = logits.size(0)
        # minlength pads the counts so they line up with the vocabulary axis.
        id_freqs = t.bincount(input_ids, minlength=d_vocab)
        return logits - freq_penalty * id_freqs

    @staticmethod
    def sample_basic(logits: "Float[Tensor, 'd_vocab']") -> int:
        """
        Samples from the full distribution defined by the logits.
        """
        sampled_token = t.distributions.categorical.Categorical(logits=logits).sample()
        return sampled_token.item()

    @staticmethod
    def sample_top_k(logits: "Float[Tensor, 'd_vocab']", k: int) -> int:
        """
        Samples from the k most likely tokens only.
        """
        top_k_logits, top_k_token_ids = logits.topk(k)
        # The sample is an index into the top-k list, not a vocabulary id.
        sampled_token_idx = t.distributions.categorical.Categorical(
            logits=top_k_logits
        ).sample()
        # Map back to the actual token id, as an int.
        return top_k_token_ids[sampled_token_idx].item()

    @staticmethod
    def sample_top_p(
        logits: "Float[Tensor, 'd_vocab']", top_p: float, min_tokens_to_keep: int = 1
    ) -> int:
        """
        Samples from the smallest set of most likely tokens whose cumulative
        probability reaches ``top_p`` (keeping at least ``min_tokens_to_keep``).
        """
        # Sort logits, and get cumulative probabilities
        logits_sorted, indices = logits.sort(descending=True, stable=True)
        cumul_probs = logits_sorted.softmax(-1).cumsum(-1)
        # +1 so that the token which crosses the threshold is included
        n_keep = t.searchsorted(cumul_probs, top_p, side="left").item() + 1
        n_keep = max(n_keep, min_tokens_to_keep)
        keep_idx = indices[:n_keep]
        keep_logits = logits[keep_idx]
        # Perform the sampling
        sample = t.distributions.categorical.Categorical(logits=keep_logits).sample()
        return keep_idx[sample].item()

In [None]:
# Wrap the reloaded model together with the SentencePiece processor `sp`.
sampler = EncoderDecoderSampler(model=sampling_model, sp_tokenizer=sp)


In [None]:
# Natural-language request the model should turn into SQL.
prompt = 'Select all the entries in the person column'
sample = sampler.generate_text(prompt, max_tokens_generated=100)

In [None]:
# Echo the generated text as this cell's output.
sample