#### **Simple Question Answering Using a Transformer Decoder**

We will explore how to train a character-level transformer language model for a simple question answering task. Given a question of the form `Where was [X] born?`, where `[X]` is the name of a public figure, the model is trained to predict the output `[Y]`, the name of that person's birthplace. Our goal is to first pretrain the model on a Wikipedia corpus (next-character prediction), from which it is expected to acquire knowledge of persons and their birthplaces. Then we finetune the model with supervised training on `(x,y)` sequence pairs of the following form:

`x: Where was Albert Einstein born?%Germany%□□□□□□□□□□□□□□`

`y: □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□%Germany%□□□□□□□□□□□□□□□`

where `x` is the input sequence, `y` is the target output sequence, and `□` is a special padding token. This is a simple next-character prediction task. However, we do not want the model to predict the question itself, only the answer. In the output sequence we therefore replace all characters from the question with the padding token, so that the model is only trained to predict the characters of the answer.

i.e. instead of

`y: here was Albert Einstein born?%Germany%□□□□□□□□□□□□□□□`

we use 

`y: □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□%Germany%□□□□□□□□□□□□□□□`

We've also used a special token `%` to mark the beginning and end of the span containing the answer. (In the code below, this role is played by the mask token `■`.)
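As a rough illustration of the masking described above, the following sketch builds one `(x, y)` pair from a question and an answer in plain Python. The function name `make_pair` and the small `block_size` are assumptions made for this example; the dataset class defined later in the notebook does the same thing before converting characters to token indices.

```python
PAD = "□"   # padding token
SEP = "%"   # marks the start and end of the answer span

def make_pair(question, answer, block_size=48):
    """Build a next-character (x, y) pair in which only the answer is a target."""
    seq = list(question) + [SEP] + list(answer) + [SEP]
    assert len(seq) <= block_size, "sequence does not fit into block_size"
    seq = seq + [PAD] * (block_size - len(seq))
    x = seq[:-1]            # input: everything except the last character
    y = seq[1:]             # target: the input shifted left by one
    # hide the question characters so they do not contribute to the loss
    n = len(question) - 1
    y[:n] = [PAD] * n
    return "".join(x), "".join(y)

x, y = make_pair("Where was Albert Einstein born?", "Germany")
print(x)
print(y)
```

The first `%` in `y` sits at the position of the final `?` in `x`, so the model is asked to open the answer span right after it has read the whole question.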

**The idea is that by training the model on this task, it can learn to answer a question by retrieving information pertaining to the answer from its pretrained knowledge.** After training, we can test this idea by giving the model an input sequence which does not contain an answer, i.e. after the start-of-answer token `%`, we fill the rest of the sequence with padding tokens:

`x: Where was Enrico Fermi born?%□□□□□□□□□□□□□□□□□□□□□□□□□`

If the predicted output sequence contains the right answer, this supports our idea. We also make sure that the person names used during testing do not appear in the training set.
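To check a prediction at test time, one would cut the input off right after the first `%` and later read the answer from between the first two `%` tokens of the generated text. A minimal sketch of both steps follows; the helper names `make_prompt` and `extract_answer` are hypothetical, and the continuation is a hand-written stand-in for model output.

```python
SEP = "%"   # start/end-of-answer token

def make_prompt(question):
    """Question followed by the start-of-answer token; the model continues from here."""
    return question + SEP

def extract_answer(generated):
    """Return the text between the first two SEP tokens, or None if the span was never closed."""
    parts = generated.split(SEP)
    if len(parts) < 3:      # fewer than two SEP tokens in the text
        return None
    return parts[1]

prompt = make_prompt("Where was Enrico Fermi born?")
# hand-written continuation standing in for what a model might generate
generated = prompt + "Rome" + SEP + "□□□"
print(extract_answer(generated))
```

Returning `None` for an unclosed span is a choice made for this sketch; the notebook's own evaluation code simply takes the second element of the split.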


We will first train the model on the finetuning task without pretraining it and then look at the difference in performance with and without pretraining. 

(Note: `torch.nn.TransformerDecoder` does not provide autoregressive decoding out of the box, so we implement the decoder-only model and its generation loop ourselves. Beware!)



In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
from tqdm import tqdm
import psutil
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
mask_token = u"\u25A0"
pad_token = u"\u25A1"

print(f"mask token: {mask_token}")
print(f"pad token: {pad_token}")

mask token: ■
pad token: □


In [5]:
# first get the character vocabulary from the pretraining dataset
with open("birth_place_data/wiki.txt", 'r', encoding='utf-8') as file:
    pretrain_text = file.read()

vocab = sorted(set(pretrain_text))
assert mask_token not in vocab, "mask token should not be in the vocabulary"
assert pad_token not in vocab, "pad token should not be in the vocabulary"
vocab = [pad_token, mask_token] + vocab 
print(f"vocabulary: {vocab}")

vocabulary: ['□', '■', '\n', ' ', '!', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\xa0', '£', '\xad', 'Á', 'Å', 'É', 'Ó', 'Ö', 'Ø', 'Ü', 'ß', 'à', 'á', 'ã', 'ä', 'å', 'æ', 'ç', 'è', 'é', 'ë', 'í', 'ï', 'ñ', 'ó', 'ô', 'ö', 'ø', 'ü', 'ý', 'ă', 'ą', 'ć', 'Č', 'č', 'ě', 'ğ', 'ī', 'İ', 'ı', 'ł', 'ń', 'ō', 'Ő', 'ő', 'œ', 'ř', 'ś', 'ş', 'Š', 'š', 'ť', 'ū', 'Ż', 'ż', 'Ž', 'ž', 'ș', 'Γ', 'Μ', 'ά', 'έ', 'α', 'γ', 'η', 'ι', 'κ', 'ν', 'ο', 'ρ', 'ς', 'τ', 'υ', 'ω', 'ώ', 'Ј', 'А', 'В', 'Г', 'И', 'К', 'П', 'Р', 'С', 'а', 'б', 'в', 'г', 'д', 'е', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'р', 'с', 'т', 'ц', 'ч', 'ь', 'я', 'ћ', 'א', 'ג', 'ה', 'ו', 'ז', 'ח', 'י', 'כ', 'ל', 

In [6]:
print(f"Vocab size: {len(vocab)}")

Vocab size: 256
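Because the padding token is placed first in `vocab`, it receives index 0, which matches the model's default `pad_token_idx=0`. The character/index mapping can be sketched on a toy vocabulary as follows (the names `toy_vocab`, `encode` and `decode` are illustrative and not part of the notebook):

```python
pad_token, mask_token = "\u25A1", "\u25A0"
text = "Where was Yang Yang born?"

# special tokens first, then the sorted set of characters in the corpus
toy_vocab = [pad_token, mask_token] + sorted(set(text))
ctoi = {c: i for i, c in enumerate(toy_vocab)}

def encode(s):
    """Map each character to its index in the vocabulary."""
    return [ctoi[c] for c in s]

def decode(indices):
    """Map indices back to characters and join them into a string."""
    return "".join(toy_vocab[i] for i in indices)

ids = encode(text + mask_token + pad_token * 3)
print(ids)
print(decode(ids))
```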


Create a PyTorch dataset for finetuning

In [7]:
class NameBirthplaceDataset(Dataset):
    def __init__(self, vocab, mask_token, pad_token, block_size=128, split="train"):
        self.vocab= vocab
        self.ctoi = {c:i for i,c in enumerate(vocab)}
        self.mask_token = mask_token 
        self.pad_token = pad_token
        self.block_size = block_size
        if split == "train":
            data_filename="birth_place_data/birth_places_train.tsv"
        elif split == "dev":
            data_filename="birth_place_data/birth_places_dev.tsv"
        else:
            raise ValueError(f"unknown split: {split!r} (expected 'train' or 'dev')")
        self.data = self.read_data(data_filename)
         
    def read_data(self, filename):
        with open(filename, 'r', encoding='utf-8') as f: 
            lines = f.read()
        # drop all non-ASCII characters, then split into one "question<TAB>answer" record per line
        data = lines.encode('utf-8').decode('ascii', errors='ignore').split('\n')
        return data

    @property
    def pad_token_index(self):
        return self.ctoi[self.pad_token]

    @property
    def mask_token_index(self):
        return self.ctoi[self.mask_token]

    def __len__(self):
        # skip the last element of self.data, which is an empty string when the file ends with a newline
        return len(self.data)-1

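    def __getitem__(self, index):
        line = self.data[index]
        # each record has the form "question<TAB>answer"
        question, answer = line.split('\t') 
        question, answer = list(question), list(answer) 
        # full sequence: question, mask, answer, mask, then padding up to block_size
        x = question + [self.mask_token] + answer + [self.mask_token]
        x = x + (self.block_size-len(x)) * [self.pad_token] 
        # next-character targets are the inputs shifted left by one
        y = x[1:]
        x = x[:-1] 
        # hide the question in the targets so that only the answer span is predicted
        y[:len(question)-1] = (len(question)-1) * [self.pad_token]

        # convert characters to token indices
        x = torch.tensor([self.ctoi[c] for c in x], dtype=torch.long)
        y = torch.tensor([self.ctoi[c] for c in y], dtype=torch.long)
        return x, y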


Create a `Span Corruption` dataset for pre-training. This is a modified next-character prediction task in which an input sequence is first truncated to a random length, then random-sized spans of contiguous tokens are corrupted: each entire span is replaced with a single mask token. For simplicity, we will only corrupt a single span in each sequence. The removed span is then appended to the end of the sequence, and its beginning and end are marked by the mask token (similar to how the beginning and end of the answer in the finetuning task are marked by the mask token, for good reason...). Finally, we fill the rest of the sequence with pad tokens to make it `block_size` long.

e.g. Given the sequence `x: Where was Enrico Fermi born?`

we first randomly truncate it:

`x: Where was Enrico Ferm`

Then we corrupt a span:

`x: Where was En% Ferm%rico%`

and now fill with pad tokens:

`x: Where was En% Ferm%rico%□□□□□□□□□□□□□□□□□□□□□□□□□`

Finally, the output sequence is just the usual shifted version of the input:

`x: here was En% Ferm%rico%□□□□□□□□□□□□□□□□□□□□□□□□□□`

We will choose random truncation lengths between `4` and `7/8 * block_size`, and a random corruption span length which is on average `1/4` of the truncated document length.
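The worked example above can be reproduced with a deterministic sketch in which the truncation length and the span are passed in explicitly instead of being drawn at random. The function name `corrupt_span`, its arguments and the small `block_size` are assumptions for this illustration, and `%` again stands in for the mask token.

```python
def corrupt_span(line, trunc_len, span_start, span_len, block_size=48, mask="%", pad="□"):
    """Truncate, cut out one span, append it between mask tokens, then pad."""
    line = line[:trunc_len]
    span = line[span_start:span_start + span_len]
    corrupted = (line[:span_start] + mask            # text before the span, then a mask in its place
                 + line[span_start + span_len:]      # text after the span
                 + mask + span + mask)               # the removed span, delimited by masks
    corrupted += pad * (block_size - len(corrupted))
    return corrupted[:-1], corrupted[1:]             # (input, shifted target)

x, y = corrupt_span("Where was Enrico Fermi born?", trunc_len=21, span_start=12, span_len=4)
print(x)
print(y)
```

With these values the sequence is truncated to `Where was Enrico Ferm` and the span `rico` is moved to the end, which matches the example in the text.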


In [26]:
class SpanCorruptionDataset(Dataset):
    def __init__(self, vocab, mask_token, pad_token, block_size=128):
        self.vocab= vocab
        self.ctoi = {c:i for i,c in enumerate(vocab)}
        self.mask_token = mask_token 
        self.pad_token = pad_token
        self.block_size = block_size
        data_filename="birth_place_data/wiki.txt"
        self.data = self.read_data(data_filename)
         
    def read_data(self, filename):
        with open(filename, 'r', encoding='utf-8') as f: 
            data = f.read().split('\n')
        return data

    @property
    def pad_token_index(self):
        return self.ctoi[self.pad_token]

    @property
    def mask_token_index(self):
        return self.ctoi[self.mask_token]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # get the sentence
        line = self.data[index]
        # apply random truncation
        trunc_len = random.randint(4, int(self.block_size*7/8))
        line_trunc = line[:trunc_len]
        # the line may be shorter than the drawn length, so use its actual length from here on
        trunc_len = len(line_trunc)
        # apply random span corruption
        # draw the span length from a gaussian with mean 1/4 * trunc_len and std 1/10 * trunc_len,
        # clipped to the range [0, 0.8 * trunc_len]
        span_len = min(max(0,int(random.gauss(mu=trunc_len/4, sigma=trunc_len/10))), int(0.8*trunc_len))
        # draw random start position
        span_start = random.randint(0, trunc_len-span_len)
        # extract the span
        span = line_trunc[span_start:span_start+span_len]
        # replace the span with a single mask token and append the span, delimited by mask tokens
        line_span_corrupted = line_trunc[:span_start] + self.mask_token + line_trunc[span_start+span_len:] + self.mask_token + span + self.mask_token 
        # add padding
        line_span_corrupted = line_span_corrupted + (self.block_size-len(line_span_corrupted)) * self.pad_token
        # input-target pair
        x = line_span_corrupted[:-1]
        y = line_span_corrupted[1:]
        # convert to tensors
        x = torch.tensor([self.ctoi[c] for c in x], dtype=torch.long)
        y = torch.tensor([self.ctoi[c] for c in y], dtype=torch.long)
        return x, y


In [9]:
train_data = NameBirthplaceDataset(vocab, mask_token, pad_token)
dev_data = NameBirthplaceDataset(vocab, mask_token, pad_token, split="dev")

pad_token_index = train_data.pad_token_index
mask_token_index = train_data.mask_token_index

In [10]:
x, y = train_data[1000]

def decode_token_indices(x):
    return "".join([vocab[i] for i in x])

x_decoded = decode_token_indices(x)
y_decoded = decode_token_indices(y)
print(x_decoded)
print(y_decoded)

Where was Yang Yang born?■Beijing■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□■Beijing■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□


In [27]:
pretrain_data = SpanCorruptionDataset(vocab, mask_token, pad_token)

In [43]:
x, y = pretrain_data[2654]

x_decoded = decode_token_indices(x)
y_decoded = decode_token_indices(y)
print(x_decoded)
print(y_decoded)

Lavrent■ Tiflis of a deacon, Ardaziani■i Ardaziani. Born in■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
avrent■ Tiflis of a deacon, Ardaziani■i Ardaziani. Born in■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□


#### Create the transformer question answering model

In [46]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, block_size, embedding_dim, total_head_size, num_heads, dropout_rate):
        super().__init__()

        assert total_head_size % num_heads == 0, "total_head_size needs to be an integer multiple of num_heads"

        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.total_head_size = total_head_size 
        self.head_size = total_head_size // num_heads 
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # define parameters
        self.key = torch.nn.Linear(embedding_dim, self.total_head_size, bias=False)
        self.query = torch.nn.Linear(embedding_dim, self.total_head_size, bias=False)
        self.value = torch.nn.Linear(embedding_dim, self.total_head_size, bias=False)
        self.attn_dropout = torch.nn.Dropout(dropout_rate)

        # non-parameter tensor of lower triangular ones
        # (not used in forward(), which relies on is_causal=True for the causal mask)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        # we also need to apply a linear projection to make the output residual the same dimension as the input
        self.proj = torch.nn.Linear(total_head_size, embedding_dim) 
        self.output_dropout = torch.nn.Dropout(dropout_rate)


    # define forward pass, input shape: (B,T,C) where B=batch size, T=block_size, C=embedding_dim
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # (B,T,H) where H is the total_head_size
        q = self.query(x) # (B,T,H)
        v = self.value(x) # (B,T,H)

        # reshape (B,T,H) --> (B,T,n,h), where n=num_heads and h=head_size and H=n*h
        k = k.view(B,T,self.num_heads,self.head_size) 
        q = q.view(B,T,self.num_heads,self.head_size) 
        v = v.view(B,T,self.num_heads,self.head_size) 

        # now we transpose so that the num_heads is the second dimension followed by T,h
        # this allows us to batch matrix multiply for all heads simultaneously to compute their attention weights
        # (B,T,n,h) --> (B,n,T,h) 
        k = k.transpose(1,2) 
        q = q.transpose(1,2)
        v = v.transpose(1,2)
        
        # use pytorch built-in function for faster computation of attention scores (set the 'is_causal' parameter for applying causal masking)
        out = F.scaled_dot_product_attention(q,k,v,dropout_p=self.dropout_rate if self.training else 0,is_causal=True)

        # we can transpose the output from (B,n,T,h) --> (B,T,n,h)
        # since the transposed tensor is no longer contiguous in memory, we apply 
        # contiguous(), which returns a contiguous tensor (required by view() below)
        out = out.transpose(1,2).contiguous()

        # finally we collapse the last two dimensions to get the concatenated output, (B,T,n,h) --> (B,T,n*h) 
        out = out.view(B,T,self.total_head_size)

        # now we project the concatenated output so that it has the same dimensions as the multihead attention layer input
        # (we need to add it with the input because of the residual connection, so need to be same size) 
        out = self.proj(out) # (B,T,C) 

        # apply dropout
        out = self.output_dropout(out)

        return out
    

# a simple mlp 
class FeedForward(torch.nn.Module):
    def __init__(self, embedding_dim, dropout_rate):
        super().__init__()
        # we add extra computations by growing out the feed-forward hidden size by a factor of 4
        # we also add an extra linear layer at the end to project the residual back to same dimensions as input
        self.net = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, 4*embedding_dim),  
            torch.nn.GELU(),
            torch.nn.Linear(4*embedding_dim, embedding_dim), 
            torch.nn.Dropout(dropout_rate)
        )
    
    # in the forward pass, apply the MLP to every position independently
    def forward(self, x):
        return self.net(x)
    

# transformer block with residual connection and layer norm
class TransformerBlock(torch.nn.Module):
    def __init__(self, block_size, embedding_dim, head_size, num_heads, dropout_rate):
        super().__init__()
        self.sa = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads, dropout_rate) # multi-head attention layer 
        self.ff = FeedForward(embedding_dim, dropout_rate)   # feed-forward layer
        self.ln1 = torch.nn.LayerNorm(embedding_dim) # layer norm at input of multi-head attention
        self.ln2 = torch.nn.LayerNorm(embedding_dim) # layer norm at input of feed-forward

    # in the forward pass, apply the attention and feed-forward sub-layers, each with layer norm at its input and a residual connection
    def forward(self, x):
        # residual connection between input and multi-head attention output
        x = x + self.sa(self.ln1(x))
        # residual connection between multi-head attention output and feed-forward output
        x = x + self.ff(self.ln2(x)) 
        return x
    

# language model with multiple transformer blocks
class TransformerLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, block_size, embedding_dim, head_size, num_heads, num_blocks, dropout_rate=0.2, pad_token_idx=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.pad_token_idx = pad_token_idx

        '''
        Define model parameters
        '''
        # token embedding layer 
        self.token_embedding = torch.nn.Embedding(vocab_size, embedding_dim) # shape: (vocab_size,C)
        # position embedding layer
        self.pos_embedding = torch.nn.Embedding(block_size, embedding_dim) # shape: (T,C)
        # stack of transformer blocks
        self.blocks = torch.nn.Sequential(*[TransformerBlock(block_size, embedding_dim, head_size, num_heads, dropout_rate) for _ in range(num_blocks)])
        self.dropout = torch.nn.Dropout(dropout_rate)
        # we also add a layer norm before the final output layer
        self.ln_f = torch.nn.LayerNorm(embedding_dim)
        # output layer logits (the input is the (B,T,C) output of the final layer norm, so the input size is embedding_dim)
        self.lm_head = torch.nn.Linear(embedding_dim, vocab_size) # shape: (C,vocab_size)


    # forward pass takes in a batch of input token sequences of shape (B,T) and corresponding targets of shape (B,T)
    def forward(self, idx, targets=None):
        B, T = idx.shape
        # get token embeddings
        token_embeds = self.token_embedding(idx) # (B,T,C)
        # add positional encoding
        pos_embeds = self.pos_embedding(torch.arange(T, device=idx.device)) # (T,C) 
        x = self.dropout(token_embeds + pos_embeds) # (B,T,C)
        # pass through transformer blocks
        x = self.blocks(x) # (B,T,C)
        # apply layer norm
        x = self.ln_f(x)  # (B,T,C)
        # compute output logits 
        logits = self.lm_head(x) # (B,T,vocab_size)
        loss = None
        if targets is not None:
            B,T,vocab_size = logits.shape
            # reshape the logits and targets such that batch of input sequences are flattened into a single big input sequence
            # i.e. (B,T) --> (B*T)
            logits = logits.view(B*T,vocab_size) # reshaped to (B*T,vocab_size)
            targets = targets.view(B*T) # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets, ignore_index=self.pad_token_idx)
        return logits, loss
    
    # generates new sequences continuing from a given batch of context tokens
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        self.eval() # switch to inference mode
        # batch of contexts, idx has shape (B,T)
        for _ in range(max_new_tokens):
            # since we're using learned position embeddings, we need to crop idx if input sequence length exceeds block size (keep last block_size tokens)
            idx_crop = idx[:,-self.block_size:] 
            # get predictions
            logits, _ = self(idx_crop) # shape: (B,T,vocab_size)
            # for each context sequence (in the batch), compute the probability of the next token using the logits of the last token in the context sequence
            logits = logits[:,-1,:] # shape: (B,vocab_size)
            probs = F.softmax(logits, dim=-1) 
            # sample from the probability distribution to get next token
            idx_next = torch.multinomial(probs, num_samples=1) # shape: (B,1)
            # append to the current context
            idx = torch.cat((idx, idx_next), dim=1) # shape: (B,T+1)
        self.train() # switch to train mode

        return idx
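The `ignore_index=self.pad_token_idx` argument in `forward` is what keeps the padded target positions (including the hidden question) out of the loss. The sketch below reimplements that behaviour in plain Python on made-up logits, to show that the loss is averaged only over the non-pad targets. `masked_cross_entropy` is a hypothetical helper written for this illustration, not a PyTorch function.

```python
import math

def log_softmax(row):
    m = max(row)                                  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def masked_cross_entropy(logits, targets, ignore_index=0):
    """Mean negative log-likelihood over positions whose target is not ignore_index."""
    losses = [-log_softmax(row)[t] for row, t in zip(logits, targets) if t != ignore_index]
    return sum(losses) / len(losses)

# toy example: 4 positions, vocabulary of size 3, index 0 is the pad token
logits = [[2.0, 0.5, 0.1],
          [0.1, 3.0, 0.2],
          [0.3, 0.2, 1.5],
          [5.0, 0.0, 0.0]]
targets = [0, 1, 2, 0]     # first and last positions are padding

loss = masked_cross_entropy(logits, targets)
print(f"loss over non-pad positions: {loss:.4f}")
```

Changing the logits at the two padded positions leaves the result unchanged, which is why the model's outputs over the question and the trailing padding are unconstrained.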

In [66]:
# training loop
def train(model, optimizer, scheduler, train_dataloader, val_dataloader,  grad_norm_clip=1.0, device="cpu", num_epochs=10, val_every=1, save_every=None, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    model.train()
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            inputs, targets = batch
            # move batch to device
            inputs, targets = inputs.to(device), targets.to(device)
            # forward pass
            logits, loss = model(inputs, targets)
            # reset gradients
            model.zero_grad()
            # backward pass
            loss.backward()
            # clip gradients above threshold
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
            # optimizer step
            optimizer.step()
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            B, L = inputs.shape
            logits = logits.view(B,L,-1)
            y_pred = logits.argmax(dim=-1) # shape (B,L)
            mask = (targets != pad_token_index)
            num_correct += sum([int(torch.equal(targets[i][mask[i]], y_pred[i][mask[i]])) for i in range(B)])
            num_total += B
            train_acc = num_correct / num_total        
            
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")  

            if log_metrics:
                metrics = {"Batch loss" : loss.item(), "Moving Avg Loss" : avg_loss, "Val Loss": val_loss}
                log_metrics(metrics)

        scheduler.step()
        
        if val_every is not None:
            if epoch%val_every == 0:
                # compute validation loss
                val_loss, val_acc = validation(model, val_dataloader, device=device)
                pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}") 

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_model_checkpoint(model, optimizer, epoch, avg_loss)

def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            B, L = inputs.shape
            logits, loss = model(inputs, targets)
            logits = logits.view(B,L,-1)
            y_pred = logits.argmax(dim=-1) # shape (B,L)
            mask = (targets != pad_token_index)
            # exact match per sequence over the non-pad target positions (j indexes sequences within the batch)
            num_correct += sum([int(torch.equal(targets[j][mask[j]], y_pred[j][mask[j]])) for j in range(B)])
            num_total += B
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    return val_loss, val_accuracy


def evaluate(model, dataloader, device="cpu"):
    model.eval()
    with torch.no_grad():
        batch = next(iter(dataloader))
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)
        B, L = inputs.shape
        logits, loss = model(inputs, targets)
        y_pred = logits.argmax(dim=-1) # shape (B*L)
        y_pred = y_pred.view(B,L)
    model.train()
    return inputs, targets, y_pred


def compute_accuracy(model, dataloader, device="cpu"):
    num_correct = 0
    total = 0
    pbar = tqdm(dataloader, desc="Epochs")
    for batch in pbar:
        inputs, targets = batch
        for i in range(len(inputs)):
            x = inputs[i]
            x = x[:torch.where(x == train_data.mask_token_index)[0][0]+1]
            y = targets[i]
            y_pred = sample(model, x.view(1,-1), model.block_size, sample=False, device=device)
            target_str = decode_token_indices(y).split(train_data.mask_token)[1]
            pred_str = decode_token_indices(y_pred[0]).split(train_data.mask_token)[1]
            if(target_str==pred_str):
                num_correct += 1
        total += len(inputs)
    print(f"Num correct: {num_correct}, Accuracy: {num_correct/total}")


# sample a sequence from the model
def sample(model, x, block_size, num_chars=40, sample=False, temperature=1.0, device="cpu"):
    model.eval()
    with torch.no_grad():
        question_length = len(x.view(-1))
        x = x.to(device)
        for _ in range(num_chars):
            # crop the input sequence so that it doesn't exceed block size (only keep the last block_size tokens in the sequence to generate the next token)
            x = x[:,-block_size:]
            logits, _ = model(x) # shape: (1,L,V)      
            # sample from the distribution to get the next character
            p = F.softmax(logits[:,-1,:]/temperature, dim=-1) # shape: (1,V)
            if sample:
                next_char_idx = torch.multinomial(p, num_samples=1)
            else:
                _, next_char_idx = torch.topk(p, k=1, dim=-1)
            # append to the sequence
            x = torch.cat((x, next_char_idx), dim=1)
    model.train()
    return x

def save_model_checkpoint(model, optimizer, epoch=None, loss=None, filename=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # Save the checkpoint to a file
    if filename:
        torch.save(checkpoint, filename)
    else:
        torch.save(checkpoint, 'qa_model_checkpoint.pth')
    print("Saved model checkpoint!")


def load_model_checkpoint(model, optimizer, filename=None):
    if filename:
        checkpoint = torch.load(filename)
    else:
        checkpoint = torch.load('qa_model_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer      
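Note that the accuracy reported by `train` and `validation` is a sequence-level exact match: an example counts as correct only if every non-pad target character is predicted correctly. A small sketch of that metric on plain lists of token indices (`exact_match_accuracy` is an illustrative name, and the index values are made up):

```python
def exact_match_accuracy(targets, predictions, pad_index=0):
    """Fraction of sequences whose predictions agree with the targets at every non-pad position."""
    num_correct = 0
    for t_seq, p_seq in zip(targets, predictions):
        pairs = [(t, p) for t, p in zip(t_seq, p_seq) if t != pad_index]
        num_correct += all(t == p for t, p in pairs)
    return num_correct / len(targets)

targets     = [[0, 0, 1, 5, 6, 1, 0],
               [0, 1, 7, 8, 1, 0, 0]]
predictions = [[4, 9, 1, 5, 6, 1, 3],    # differs only where the target is padding: counts as correct
               [2, 1, 7, 9, 1, 0, 0]]    # one wrong answer character: counts as incorrect
print(exact_match_accuracy(targets, predictions))
```

Because the inputs passed to `train` and `validation` already contain the answer characters, this number is measured under teacher forcing. `compute_accuracy` instead generates the answer character by character from the prompt, so the two accuracies are not directly comparable.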

In [25]:
B = 128
D = 256
vocab_size = len(vocab)
block_size = 128
num_heads = 8
num_layers = 4
learning_rate = 5e-4
DEVICE = "cuda"

train_dataloader = DataLoader(train_data, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(dev_data, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)

model = TransformerLanguageModel(vocab_size, block_size, D, D, num_heads, num_layers, dropout_rate=0.1, pad_token_idx=train_data.pad_token_index).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
#model, optimizer = load_model_checkpoint(model, optimizer)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 3.320576 M
RAM used: 1132.83 MB
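The reported parameter count can be reproduced by hand from the layer shapes defined above. The following back-of-the-envelope tally assumes the hyperparameters of this cell (`D = 256`, `block_size = 128`, 4 blocks, a vocabulary of 256 characters) and that the total head size equals `D`:

```python
V, T, D, H, n_blocks = 256, 128, 256, 256, 4   # vocab, block size, embedding dim, total head size, blocks

embeddings = V * D + T * D                     # token + position embeddings
attention  = 3 * D * H + (H * D + D)           # key/query/value (no bias) + output projection (with bias)
mlp        = (D * 4 * D + 4 * D) + (4 * D * D + D)   # two linear layers with biases
layernorms = 2 * (2 * D)                       # weight and bias for ln1 and ln2
block      = attention + mlp + layernorms

final_ln   = 2 * D                             # weight and bias of ln_f
lm_head    = D * V + V                         # output layer with bias

total = embeddings + n_blocks * block + final_ln + lm_head
print(f"{total / 1e6} M")
```

The `tril` tensor does not appear in the tally because it is registered as a buffer and therefore is not returned by `model.parameters()`.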


In [11]:
for param_group in optimizer.param_groups:
    param_group['lr'] = 2e-4

In [26]:
train(model, optimizer, scheduler, train_dataloader, val_dataloader, device=DEVICE, num_epochs=100, save_every=50, val_every=1) #, log_metrics=log_metrics)

Epoch 1, EMA Train Loss: 2.577, Train Accuracy:  0.000, Val Loss:  0.000, Val Accuracy:  0.000: 100%|██████████| 16/16 [00:02<00:00,  7.63it/s]
Epoch 2, EMA Train Loss: 2.361, Train Accuracy:  0.001, Val Loss:  2.517, Val Accuracy:  0.000: 100%|██████████| 16/16 [00:01<00:00,  8.88it/s]
Epoch 3, EMA Train Loss: 2.202, Train Accuracy:  0.000, Val Loss:  2.214, Val Accuracy:  0.000: 100%|██████████| 16/16 [00:01<00:00,  9.10it/s]
Epoch 4, EMA Train Loss: 2.114, Train Accuracy:  0.003, Val Loss:  2.126, Val Accuracy:  0.000: 100%|██████████| 16/16 [00:01<00:00,  9.11it/s]
Epoch 5, EMA Train Loss: 2.047, Train Accuracy:  0.001, Val Loss:  2.074, Val Accuracy:  0.000: 100%|██████████| 16/16 [00:01<00:00,  9.15it/s]
Epoch 6, EMA Train Loss: 1.975, Train Accuracy:  0.001, Val Loss:  2.030, Val Accuracy:  0.000: 100%|██████████| 16/16 [00:01<00:00,  8.93it/s]
Epoch 7, EMA Train Loss: 1.886, Train Accuracy:  0.013, Val Loss:  1.967, Val Accuracy:  0.018: 100%|██████████| 16/16 [00:01<00:00,  9.

Saved model checkpoint!


Epoch 51, EMA Train Loss: 0.197, Train Accuracy:  0.541, Val Loss:  2.080, Val Accuracy:  0.008: 100%|██████████| 16/16 [00:01<00:00,  9.24it/s]
Epoch 52, EMA Train Loss: 0.189, Train Accuracy:  0.566, Val Loss:  2.103, Val Accuracy:  0.008: 100%|██████████| 16/16 [00:01<00:00,  9.40it/s]
Epoch 53, EMA Train Loss: 0.173, Train Accuracy:  0.593, Val Loss:  2.118, Val Accuracy:  0.010: 100%|██████████| 16/16 [00:01<00:00,  9.47it/s]
Epoch 54, EMA Train Loss: 0.166, Train Accuracy:  0.615, Val Loss:  2.116, Val Accuracy:  0.010: 100%|██████████| 16/16 [00:01<00:00,  9.28it/s]
Epoch 55, EMA Train Loss: 0.152, Train Accuracy:  0.663, Val Loss:  2.115, Val Accuracy:  0.008: 100%|██████████| 16/16 [00:01<00:00,  9.48it/s]
Epoch 56, EMA Train Loss: 0.141, Train Accuracy:  0.670, Val Loss:  2.185, Val Accuracy:  0.016: 100%|██████████| 16/16 [00:01<00:00,  9.48it/s]
Epoch 57, EMA Train Loss: 0.131, Train Accuracy:  0.704, Val Loss:  2.189, Val Accuracy:  0.006: 100%|██████████| 16/16 [00:01<00:

Saved model checkpoint!


In [32]:
inputs, targets, y_pred = evaluate(model, val_dataloader, device=DEVICE)
for i in range(5):
    x = inputs[i]
    y = targets[i]
    y_hat = y_pred[i]
    print(f"Input:      {decode_token_indices(x)}")
    print(f"Target:     {decode_token_indices(y)}")
    print(f"Prediction: {decode_token_indices(y_hat)}")
    print("")

Input:      Where was Gerald Murphy born?■Boston■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
Target:     □□□□□□□□□□□□□□□□□□□□□□□□□□□□■Boston■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
Prediction: dimimip■diSrymieLjmaea■Jauag■Muston■S■raco■■■■■i■e■■aa■nCuroo■cwurcD■u■■w■kcrrnrw■■n■aa■■C■r■c■■nkw■krnn■r■rarwkanrrow■■rw■■nen

Input:      Where was John Brown born?■Sheffield■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
Target:     □□□□□□□□□□□□□□□□□□□□□□□□□■Sheffield■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
Prediction: dimimip■diSuiaEirnadCarg■■Ehaffield■L■dr■ddr■■d■■■d■k■■■■k■■■■cwwrrK■■d■■rkSrr■rd■■d■d■■■k■aK■■■■n■r■rr■dr■r■rww■a■Ko■rark■■a■r

Input:      Where was Leslie Howe born?■Ontario■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
Target:     □□□□□□

In [33]:
for i in range(5):
    x = inputs[i]
    x = x[:torch.where(x == train_data.mask_token_index)[0][0]+1]
    y = targets[i]
    y_pred = sample(model, x.view(1,-1), block_size, sample=True, device=DEVICE)
    print(f"Input:      {decode_token_indices(x)}")
    print(f"Target:     {decode_token_indices(y)}")
    print(f"y_pred:     {decode_token_indices(y_pred[0])}")
    target_str = decode_token_indices(y).split(train_data.mask_token)[1]
    pred_str = decode_token_indices(y_pred[0]).split(train_data.mask_token)[1]
    #print(f"Target:     {target_str}")
    #print(f"Prediction: {pred_str}")
    print("")

Input:      Where was Gerald Murphy born?■
Target:     □□□□□□□□□□□□□□□□□□□□□□□□□□□□■Boston■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y_pred:     Where was Gerald Murphy born?■Moscow■Lismon■Ron■Le■Buste■Ch■Ch■Churere

Input:      Where was John Brown born?■
Target:     □□□□□□□□□□□□□□□□□□□□□□□□□■Sheffield■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y_pred:     Where was John Brown born?■London■■Argen■■In■Ron■Jen■Scen■Am■Lon■Pa

Input:      Where was Leslie Howe born?■
Target:     □□□□□□□□□□□□□□□□□□□□□□□□□□■Ontario■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y_pred:     Where was Leslie Howe born?■Sheffield■Delberd■Kierk■Keria■Gerie■Wa■G

Input:      Where was Brian Murphy born?■
Target:     □□□□□□□□□□□□□□□□□□□□□□□□□□□■Ottawa■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y_pred:     Where was Brian
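The cell above samples the next character from the softmax distribution (`sample=True`), whereas `compute_accuracy` picks the most likely character (`sample=False`). The difference, and the effect of the `temperature` argument, can be sketched in plain Python on made-up logits (`next_char_index` is a hypothetical helper, not part of the notebook):

```python
import math
import random

def softmax(logits, temperature=1.0):
    scaled = [v / temperature for v in logits]
    m = max(scaled)                                # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in scaled]
    z = sum(exps)
    return [e / z for e in exps]

def next_char_index(logits, sample=False, temperature=1.0, rng=random):
    probs = softmax(logits, temperature)
    if sample:
        # draw an index in proportion to its probability
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]
    # greedy decoding: always take the most likely index
    return max(range(len(probs)), key=probs.__getitem__)

logits = [1.0, 3.0, 2.0]
print(next_char_index(logits))                     # greedy choice
print(softmax(logits, temperature=1.0))
print(softmax(logits, temperature=0.5))            # lower temperature sharpens the distribution
```

Greedy decoding is deterministic, which makes it the natural choice for measuring accuracy, while sampling shows the variety of answers the model considers plausible.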

Generate some sequences continuing from questions from the training set

In [29]:
compute_accuracy(model, val_dataloader, device=DEVICE)

Epochs: 100%|██████████| 4/4 [00:32<00:00,  8.16s/it]

Num correct: 10, Accuracy: 0.02





In [22]:
compute_accuracy(model, train_dataloader, device=DEVICE)

Epochs: 100%|██████████| 16/16 [01:57<00:00,  7.34s/it]

Num correct: 231, Accuracy: 0.1155





#### Note that the validation set accuracy is barely 2%. Now we will pre-train the language model on the Wikipedia data with span corruption.

In [50]:
B = 128
D = 256
vocab_size = len(vocab)
block_size = 128
num_heads = 8
num_layers = 4
learning_rate = 8e-4
DEVICE = "cuda"

train_dataloader = DataLoader(pretrain_data, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)

model = TransformerLanguageModel(vocab_size, block_size, D, D, num_heads, num_layers, dropout_rate=0.1, pad_token_idx=train_data.pad_token_index).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
#model, optimizer = load_model_checkpoint(model, optimizer)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 3.320576 M
RAM used: 674.06 MB


In [75]:
#train(model, optimizer, scheduler, train_dataloader, val_dataloader=None, device=DEVICE, num_epochs=500, save_every=50, val_every=None)

Now let's finetune the model on the name-birthplace dataset.

In [63]:
train_dataloader = DataLoader(train_data, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(dev_data, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)

model, optimizer = load_model_checkpoint(model, optimizer)

for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-5

Loaded model from checkpoint!


In [None]:
for param_group in optimizer.param_groups:
    param_group['lr'] = 6e-6

In [77]:
train(model, optimizer, scheduler, train_dataloader, val_dataloader, device=DEVICE, num_epochs=10, save_every=None, val_every=1)

Epoch 1, EMA Train Loss: 0.265, Train Accuracy:  0.323, Val Loss:  0.000, Val Accuracy:  0.000: 100%|██████████| 16/16 [00:03<00:00,  4.72it/s]
Epoch 2, EMA Train Loss: 0.310, Train Accuracy:  0.337, Val Loss:  0.398, Val Accuracy:  0.328: 100%|██████████| 16/16 [00:02<00:00,  5.37it/s]
Epoch 3, EMA Train Loss: 0.319, Train Accuracy:  0.330, Val Loss:  0.398, Val Accuracy:  0.328: 100%|██████████| 16/16 [00:02<00:00,  5.49it/s]
Epoch 4, EMA Train Loss: 0.316, Train Accuracy:  0.349, Val Loss:  0.399, Val Accuracy:  0.328: 100%|██████████| 16/16 [00:03<00:00,  5.24it/s]
Epoch 5, EMA Train Loss: 0.315, Train Accuracy:  0.344, Val Loss:  0.398, Val Accuracy:  0.332: 100%|██████████| 16/16 [00:03<00:00,  5.16it/s]
Epoch 6, EMA Train Loss: 0.312, Train Accuracy:  0.348, Val Loss:  0.397, Val Accuracy:  0.326: 100%|██████████| 16/16 [00:03<00:00,  5.19it/s]
Epoch 7, EMA Train Loss: 0.311, Train Accuracy:  0.337, Val Loss:  0.398, Val Accuracy:  0.320: 100%|██████████| 16/16 [00:02<00:00,  5.

In [76]:
save_model_checkpoint(model, optimizer, filename='qa_model_finetuned_checkpoint.pth')

Saved model checkpoint!


Testing the fine-tuned model

In [78]:
compute_accuracy(model, val_dataloader, device=DEVICE)

Epochs:  75%|███████▌  | 3/4 [00:49<00:16, 16.57s/it]