In [1]:
import torch
import torch.nn as nn
import tiktoken
from blocks import TransformerBlock, InputPreprocess

In [2]:
# Hyper-parameters for the model built in this notebook. The whole dict is
# handed to InputPreprocess, TransformerBlock and SmallLLM as `cfg`.
# NOTE(review): the key set looks like a GPT-2 (small) `config.json`; which
# keys beyond n_layer / n_embd / vocab_size are actually consumed depends on
# the `blocks` module -- confirm there.
GPT2_CONFIG = {
    # --- identification / activation ---
    "activation_function": "gelu_new",
    "architectures": ["GPT2LMHeadModel"],

    # --- regularisation and special tokens ---
    "attn_pdrop": 0.1,
    "bos_token_id": 50256,
    "embd_pdrop": 0.1,
    "eos_token_id": 50256,
    "initializer_range": 0.02,
    "layer_norm_epsilon": 1e-5,
    "model_type": "gpt2",

    # --- model size ---
    "n_ctx": 1024,
    "n_embd": 768,        # width used for the final nn.Linear input
    "n_head": 12,
    "n_layer": 12,        # number of TransformerBlock instances stacked
    "n_positions": 1024,
    "resid_pdrop": 0.1,

    # --- summary-head settings ---
    "summary_activation": None,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": True,
    "summary_type": "cls_index",
    "summary_use_proj": True,

    # --- task defaults ---
    "task_specific_params": {
        "text-generation": {"do_sample": True, "max_length": 50},
    },

    "vocab_size": 50257,  # width of the final nn.Linear output
}

In [5]:
class SmallLLM(nn.Module):
    """Stack of transformer blocks followed by a vocabulary projection.

    Sub-modules registered on the instance:

    * ``preprocessor`` -- ``InputPreprocess(tokenizer, cfg)``. It is owned by
      the model but is NOT called inside ``forward``; callers are expected to
      run it themselves and pass its output in as ``x`` / ``attn_mask``.
    * ``transformers`` -- ``cfg['n_layer']`` instances of
      ``TransformerBlock(cfg)``, applied in order.
    * ``last_embedding_layer`` -- ``nn.Linear(cfg['n_embd'], cfg['vocab_size'])``.
    """

    def __init__(self, tokenizer, cfg):
        """Build the sub-modules.

        Args:
            tokenizer: forwarded unchanged to ``InputPreprocess``.
            cfg: mapping providing at least ``'n_layer'``, ``'n_embd'`` and
                ``'vocab_size'``; it is also forwarded to ``InputPreprocess``
                and to every ``TransformerBlock``.
        """
        super().__init__()
        self.preprocessor = InputPreprocess(tokenizer, cfg)
        self.transformers = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg['n_layer'])]
        )
        self.last_embedding_layer = nn.Linear(cfg['n_embd'], cfg['vocab_size'])

    def forward(self, x, attn_mask):
        """Return a probability distribution over the vocabulary per sequence.

        Args:
            x: rank-3 tensor whose last dimension is ``cfg['n_embd']``
                (presumably ``(batch, seq_len, n_embd)`` -- confirm against
                ``InputPreprocess``).
            attn_mask: passed through untouched to every transformer block.

        Returns:
            Tensor of shape ``(x.shape[0], cfg['vocab_size'])`` holding the
            softmax of the logits at index ``-1`` along dimension 1. The
            full distribution is returned (not an arg-max) so that callers
            can choose their own decoding strategy.
        """
        hidden = x
        for block in self.transformers:
            hidden = block(hidden, attn_mask)

        # Only index -1 along dim 1 is returned, so project just that slice
        # instead of running the large n_embd -> vocab_size matmul over every
        # position and throwing the rest away.
        # NOTE(review): if the preprocessor right-pads shorter sequences,
        # index -1 is a padding position for them -- confirm padding side.
        last_hidden = hidden[:, -1, :]
        logits = self.last_embedding_layer(last_hidden)
        return torch.softmax(logits, dim=-1)

In [7]:
# Seed before constructing the model: its weights are randomly initialised
# (no checkpoint is loaded anywhere in this notebook), so without a fixed
# seed the generated text below changes on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

tokenizer = tiktoken.get_encoding('gpt2')
model = SmallLLM(tokenizer, GPT2_CONFIG)
# Inference only from here on: switch off training-mode behaviour such as
# dropout (the config carries non-zero *_pdrop values).
model.eval()
def generate_text(tokenizer, model, text_list, max_length):
    """Greedily extend each prompt by ``max_length`` tokens.

    On every step the current texts are run through the preprocessor and the
    model, the arg-max token id of each returned row is decoded with
    ``tokenizer`` and appended to the matching text, and the grown texts are
    fed back in on the next step.

    Args:
        tokenizer: object exposing ``decode(list_of_token_ids) -> str``.
        model: callable ``model(input_vecs, attn_mask)`` returning one row of
            scores per input text. If it has a ``preprocessor`` attribute
            that object is used to turn texts into ``(input_vecs,
            attn_mask)``, so generation goes through the module the model
            owns; otherwise a new ``InputPreprocess(tokenizer, GPT2_CONFIG)``
            is created as before.
        text_list: list of prompt strings. It is not modified.
        max_length: number of tokens to append to every prompt (not a cap on
            total length).

    Returns:
        A new list with one extended string per prompt, in input order.
    """
    # Copy: the original aliased `text_list`, so `+=` below silently rewrote
    # the caller's list.
    decoded_text = list(text_list)
    if not decoded_text:
        return decoded_text

    # Prefer the preprocessor registered on the model over a fresh instance.
    # NOTE(review): presumably InputPreprocess holds embedding weights, in
    # which case a fresh instance would not match the model's -- confirm.
    input_preprocessor = getattr(model, "preprocessor", None)
    if input_preprocessor is None:
        input_preprocessor = InputPreprocess(tokenizer, GPT2_CONFIG)

    # Generate in eval mode and restore the previous mode afterwards, so
    # training-only layers (e.g. dropout) do not perturb greedy decoding.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        # No gradients are needed for generation; skip building the graph.
        with torch.no_grad():
            for _ in range(max_length):
                input_vecs, attn_mask = input_preprocessor(decoded_text)
                probs = model(input_vecs, attn_mask)

                next_tokens = torch.argmax(probs, dim=-1).tolist()

                # Each token is decoded on its own and the text is
                # re-tokenised on the next step.
                for i, token in enumerate(next_tokens):
                    decoded_text[i] += tokenizer.decode([token])
    finally:
        if was_training:
            model.train()
    return decoded_text
# Demo: append 20 greedily chosen tokens to each of two prompts. The cell's
# last expression is the returned list, so the notebook displays it.
generate_text(
    tokenizer,
    model,
    text_list=['Hello, ', 'Who are you?'],
    max_length=20,
)

['Hello,  Boldmeal Abbott interrupted Kam volleyball pee usefulnessoké Watts evaluated Luis sued Marlins RESP ledger disconnect grapp Catal Entreprene',
 'Who are you?zie Bromouse gene gradually physiology ant SeahawksloadingSimon617 Sob chaoticrealityBrazil technology ChampionSpot derivatives IA']