In [None]:
import re
from collections import defaultdict

import pandas as pd
import torch
import tqdm
from accelerate import Accelerator
from datasets import DatasetDict, load_dataset
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW, lr_scheduler
from torch.utils.data.dataloader import DataLoader
from transformers import AutoConfig, AutoTokenizer, GPT2LMHeadModel, pipeline

In [None]:
# Load a 1% slice of each CodeParrot split to keep the experiment small.
ds_train = load_dataset("huggingface-course/codeparrot-ds-train", split="train[:1%]")
ds_valid = load_dataset("huggingface-course/codeparrot-ds-valid", split="validation[:1%]")

# Downstream cells address the splits as "train" and "valid".
raw_datasets = DatasetDict({"train": ds_train, "valid": ds_valid})

In [None]:
class Tokenizer():
    """Character-level BPE tokenizer learned from ``corpus``.

    Each text is split into words with the GPT-2 pre-tokenizer (byte-level, so
    whitespace shows up as byte-level characters). Each distinct word starts as
    a list of its characters. The most frequent adjacent symbol pair (weighted
    by word frequency) is merged repeatedly until the vocabulary holds
    ``vocab_size`` entries or no pair is left to merge.

    Ids: the four special tokens take ids 0-3 (pad, bos, eos, unk) and every
    other vocabulary entry gets a consecutive id from 4 upward. Ids therefore
    span exactly ``range(len(self))`` and can index an embedding of that size.

    Args:
        corpus: training texts.
        vocab_size: target vocabulary size, counting the four special tokens
            and every character seen in ``corpus``.
        max_len: every encoded sequence is truncated / right-padded to this length.
    """

    def __init__(self, corpus: list[str], vocab_size, max_len):
        # Only GPT-2's pre-tokenizer (word splitting) is used, not its vocabulary.
        self.tokenizerz = AutoTokenizer.from_pretrained("gpt2")

        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3

        self.pad_token = '<|pad|>'
        self.bos_token = '<|bos|>'
        self.eos_token = '<|eos|>'
        self.unk_token = '<|unk|>'

        self.max_len = max_len
        # word -> merged symbols; filled lazily by _encode_word.
        self._word_cache = {}

        # Listed in id order: position i receives id i, matching *_token_id above.
        specials = [self.pad_token, self.bos_token, self.eos_token, self.unk_token]

        # Count distinct words and seed the vocabulary with every character seen.
        word_freqs = defaultdict(int)
        vocab = set(specials)
        for text in corpus:
            for word in self._pre_tokenize(text):
                word_freqs[word] += 1
                vocab.update(word)

        # (symbols of a distinct word, its frequency); symbols start as characters.
        tokens = [(list(word), freq) for word, freq in word_freqs.items()]

        # Learn merge rules until the vocabulary is big enough or nothing can merge.
        rules = []
        while len(vocab) < vocab_size:
            pair = self.pair_freqs(tokens)
            if pair is None:
                break
            rules.append(pair)
            vocab.add(pair[0] + pair[1])
            tokens = self.merge_tokens(tokens, pair)

        self.rules = rules
        self.vocab = vocab
        self.tokens = tokens

        # Specials first so they get ids 0-3. The remaining entries are sorted so
        # ids do not depend on set iteration order, which for str varies between
        # interpreter sessions.
        id_order = specials + sorted(vocab.difference(specials))
        self.tok2id = {tok: i for i, tok in enumerate(id_order)}
        self.id2tok = dict(enumerate(id_order))

    def __len__(self):
        """Vocabulary size; ids range over 0 .. len(self) - 1."""
        return len(self.vocab)

    @staticmethod
    def _merge_pair(symbols, pair):
        """Return ``symbols`` with each non-overlapping occurrence of ``pair``
        (scanned left to right) replaced by the concatenated symbol."""
        merged = []
        i = 0
        last = len(symbols) - 1
        while i < last:
            if (symbols[i], symbols[i + 1]) == pair:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        # The final symbol was not consumed by a merge; keep it.
        if i == last:
            merged.append(symbols[i])
        return merged

    def pair_freqs(self, tokens):
        """Return the most frequent adjacent symbol pair in ``tokens``.

        Args:
            tokens: list of ``(symbols, freq)``; each pair occurrence counts ``freq``.

        Returns:
            The pair with the highest total count (first seen wins ties), or
            ``None`` when no word has two or more symbols.
        """
        freqs = defaultdict(int)
        for symbols, freq in tokens:
            for pair in zip(symbols, symbols[1:]):
                freqs[pair] += freq
        if not freqs:
            return None
        return max(freqs, key=freqs.__getitem__)

    def merge_tokens(self, tokens, pair):
        """Apply one merge rule ``pair`` to every ``(symbols, freq)`` entry."""
        return [(self._merge_pair(symbols, pair), freq) for symbols, freq in tokens]

    def _pre_tokenize(self, text):
        """Split ``text`` into words with the GPT-2 pre-tokenizer."""
        return [word for word, _ in self.tokenizerz._tokenizer.pre_tokenizer.pre_tokenize_str(text)]

    def _encode_word(self, word):
        """Apply the learned rules, in training order, to a single word.

        Merges never cross word boundaries, matching how the rules were learned.
        Results are memoised per word; the cache is unbounded.
        """
        cache = getattr(self, "_word_cache", None)
        if cache is None:
            cache = self._word_cache = {}
        if word not in cache:
            symbols = list(word)
            for rule in self.rules:
                if len(symbols) < 2:
                    break  # nothing left to merge
                symbols = self._merge_pair(symbols, rule)
            cache[word] = symbols
        return cache[word]

    def tokenize(self, text):
        """Encode ``text`` into fixed-length ids and attention mask.

        Returns:
            ``(ids, mask)``, both lists of exactly ``self.max_len`` ints. Unknown
            symbols map to ``unk_token_id``. Sequences are truncated on the right
            and right-padded with ``pad_token_id`` (mask 0).
        """
        symbols = []
        for word in self._pre_tokenize(text):
            # Every word yields >= 1 symbol, so later words cannot survive truncation.
            if len(symbols) >= self.max_len:
                break
            symbols.extend(self._encode_word(word))
        symbols = symbols[:self.max_len]

        n_pad = self.max_len - len(symbols)
        ids = [self.tok2id.get(s, self.unk_token_id) for s in symbols] + [self.pad_token_id] * n_pad
        mask = [1] * len(symbols) + [0] * n_pad
        return ids, mask

    def __call__(self, texts):
        """Encode one string or a list of strings.

        Returns:
            ``{"input_ids": ..., "attention_mask": ...}``. Values are flat lists
            for a single string and lists of lists for a list of strings.
        """
        if isinstance(texts, str):
            ids, mask = self.tokenize(texts)
            return {
                "input_ids": ids,
                "attention_mask": mask,
            }
        encoded = [self.tokenize(text) for text in texts]
        return {
            "input_ids": [ids for ids, _ in encoded],
            "attention_mask": [mask for _, mask in encoded],
        }

In [None]:
context_length = 128  # tokens per example: the tokenizer pads/truncates to this

# NOTE(review): the BPE vocabulary is learned from the *validation* split, so the
# held-out texts shape the tokenizer; consider training it on raw_datasets["train"].
corpset = list(raw_datasets["valid"]["content"])
tokenizer = Tokenizer(corpset, vocab_size=2000, max_len=context_length)

In [None]:
def tokenize(element):
    """Encode a batch of examples for ``datasets.map(batched=True)``.

    Args:
        element: batch dict whose "content" entry is a list of source-code strings.

    Returns:
        Dict with "input_ids" and "attention_mask", one list per text, as
        returned by the global ``tokenizer``.
    """
    return tokenizer(element["content"])


tokenized_datasets = raw_datasets.map(
    tokenize,
    batched=True,
    # Keep only the tokenizer outputs; drop the raw text and metadata columns.
    remove_columns=raw_datasets["train"].column_names,
)

In [None]:
# GPT-2 architecture with a fresh, small vocabulary from our tokenizer.
config = AutoConfig.from_pretrained(
    "gpt2",
    # Size the embedding from the largest id the tokenizer can emit rather than
    # the vocabulary count, so every emitted id is a valid embedding row.
    vocab_size=max(tokenizer.id2tok) + 1,
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

In [None]:
def keytoken_weighted_loss(inputs, logits, alpha=1.0, ignore_index=-100):
    """Causal-LM cross-entropy: position t is scored on predicting token t+1.

    Args:
        inputs: token ids shaped ``(..., seq_len)``, used as the labels.
        logits: model outputs shaped ``(..., seq_len, vocab_size)``.
        alpha: unused here; kept so existing calls keep working.
        ignore_index: label value excluded from the loss, e.g. padding positions
            replaced by -100. The default matches ``CrossEntropyLoss``.

    Returns:
        Scalar mean loss over the non-ignored target positions.
    """
    # Drop the first label and the last logit so position t predicts t+1.
    shift_labels = inputs[..., 1:]
    shift_logits = logits[..., :-1, :]

    loss_fct = CrossEntropyLoss(ignore_index=ignore_index)
    return loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

In [None]:
bs = 16  # examples per micro-batch

# Return columns as torch tensors so the default collate function can stack them.
tokenized_datasets.set_format("torch")
train_dataloader = DataLoader(dataset=tokenized_datasets["train"], batch_size=bs, shuffle=True)
eval_dataloader = DataLoader(dataset=tokenized_datasets["valid"], batch_size=bs, shuffle=False)

In [None]:
weight_decay = 0.1

# Substrings of parameter names that must not be weight-decayed. GPT-2 names its
# LayerNorm weights ln_1 / ln_2 / ln_f, which "LayerNorm.weight" (BERT-style
# naming) does not match, so those are listed explicitly.
NO_DECAY = ("bias", "LayerNorm.weight", "ln_1.weight", "ln_2.weight", "ln_f.weight")


def get_grouped_params(model, no_decay=NO_DECAY):
    """Split ``model``'s parameters into decayed and non-decayed optimizer groups.

    Args:
        model: any ``torch.nn.Module``.
        no_decay: a parameter whose name contains any of these substrings gets
            ``weight_decay=0``. A tuple, so the shared default cannot be mutated.

    Returns:
        Two optimizer param-group dicts: ``[with decay (global weight_decay),
        without decay (0.0)]``.
    """
    params_with_wd, params_without_wd = [], []
    for name, param in model.named_parameters():
        if any(pattern in name for pattern in no_decay):
            params_without_wd.append(param)
        else:
            params_with_wd.append(param)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

In [None]:
def evaluate():
    """Return ``(mean loss, perplexity)`` of the global ``model`` on ``eval_dataloader``.

    Padding positions (``attention_mask == 0``) become -100 in the labels, which
    the model's built-in LM loss ignores, so the metric covers real tokens only.
    Uses the globals ``model``, ``eval_dataloader`` and ``accelerator``. Leaves
    the model in eval mode; callers switch back with ``model.train()``.
    Returns ``(nan, nan)`` for an empty dataloader.
    """
    model.eval()
    losses = []
    for batch in eval_dataloader:
        labels = batch["input_ids"].masked_fill(batch["attention_mask"] == 0, -100)
        with torch.no_grad():
            outputs = model(batch["input_ids"], labels=labels)
        # The gathered scalar may come back 0-d or 1-d; flatten so cat works.
        losses.append(accelerator.gather_for_metrics(outputs.loss).reshape(-1).float())

    if not losses:
        return float("nan"), float("nan")

    loss = torch.cat(losses).mean()
    # torch.exp returns inf on overflow rather than raising, so no try/except is needed.
    perplexity = torch.exp(loss)
    return loss.item(), perplexity.item()

In [None]:
SEED = 42
# Seed torch's global RNG so the random weight initialisation (and later draws
# from that RNG) is reproducible across runs.
torch.manual_seed(SEED)
model = GPT2LMHeadModel(config)

In [None]:
# AdamW with per-group weight decay (see get_grouped_params).
optimizer = AdamW(params=get_grouped_params(model), lr=5e-4)

In [None]:
# fp16 autocast is only requested on CUDA; elsewhere fall back to full precision
# rather than asking accelerate for a mode the device may not support.
accelerator = Accelerator(mixed_precision="fp16" if torch.cuda.is_available() else "no")

# Move model/optimizer/dataloaders onto the accelerator's device(s).
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [None]:
# Training-length bookkeeping. One optimizer update happens every
# gradient_accumulation_steps dataloader batches.
num_train_epochs = 1
gradient_accumulation_steps = 8
num_update_steps_per_epoch = len(train_dataloader)  # dataloader batches per epoch
num_training_steps = num_update_steps_per_epoch * num_train_epochs
eval_steps = 50  # run evaluate() every eval_steps optimizer updates
completed_steps = 0  # optimizer updates performed so far

In [None]:
# Linear decay from the base lr to 0 across the run. The scheduler is stepped
# once per optimizer update (every gradient_accumulation_steps batches), so its
# horizon is counted in updates, not batches. start/end factors are explicit
# because LinearLR's defaults (1/3 -> 1.0) would ramp the lr *up* instead.
num_optimizer_updates = max(1, num_training_steps // gradient_accumulation_steps)
lr_sched = lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.0, total_iters=num_optimizer_updates
)

In [None]:
model.train()
progress = tqdm.tqdm(total=num_training_steps)

for epoch in range(num_train_epochs):
    for step, batch in enumerate(train_dataloader, start=1):
        # Right-padding positions (attention_mask == 0) become -100. The
        # CrossEntropyLoss inside keytoken_weighted_loss ignores -100 by default,
        # so the model is not trained to predict padding.
        labels = batch["input_ids"].masked_fill(batch["attention_mask"] == 0, -100)
        logits = model(batch["input_ids"]).logits
        loss = keytoken_weighted_loss(labels, logits)
        if step % 100 == 0:
            # Not yet divided by gradient_accumulation_steps: already the per-batch loss.
            accelerator.print({"loss/train": loss.item()})
        # Scale so the accumulated gradient is the mean over the micro-batches.
        loss = loss / gradient_accumulation_steps
        accelerator.backward(loss)
        if step % gradient_accumulation_steps == 0:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_sched.step()
            optimizer.zero_grad()
            completed_steps += 1
        if step % (eval_steps * gradient_accumulation_steps) == 0:
            eval_loss, perplexity = evaluate()
            accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
            model.train()
            accelerator.wait_for_everyone()
        progress.update(1)

progress.close()

In [None]:
# Greedy decoding on token ids. The model's tokens are byte-level strings from
# the GPT-2 pre-tokenizer, so appending them to a str and re-tokenizing would
# not round-trip; ids are kept instead and decoded once at the end.
prompt = "import"
max_new_tokens = 20

model.eval()  # disable dropout for generation
device = next(model.parameters()).device

prompt_enc = tokenizer(prompt)
# Keep only the real (unpadded) positions of the encoded prompt.
generated_ids = [
    tok_id
    for tok_id, keep in zip(prompt_enc["input_ids"], prompt_enc["attention_mask"])
    if keep
] or [tokenizer.bos_token_id]

with torch.no_grad():
    for _ in range(max_new_tokens):
        # The tokenizer caps sequences at context_length, so feed at most that many.
        window = torch.tensor([generated_ids[-context_length:]], device=device)
        logits = model(window).logits
        # The last position holds the prediction for the next token.
        generated_ids.append(int(torch.argmax(logits[0, -1])))

generated_tokens = [tokenizer.id2tok.get(tok_id, tokenizer.unk_token) for tok_id in generated_ids]
# Map byte-level token strings back to ordinary text with GPT-2's decoder.
text = tokenizer.tokenizerz._tokenizer.decoder.decode(generated_tokens)
text