In [1]:
import torch
import torch.nn as nn
import tiktoken
from blocks import TransformerBlock, InputPreprocess

In [2]:
GPT2_CONFIG = {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": None,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": True,
  "summary_type": "cls_index",
  "summary_use_proj": True,
  "task_specific_params": {
    "text-generation": {
      "do_sample": True,
      "max_length": 50
    }
  },
  "vocab_size": 50257
}

In [5]:
class SmallLLM(nn.Module):
    def __init__(self, tokenizer, cfg):
        super().__init__()
        self.preprocessor = InputPreprocess(tokenizer, cfg)
        self.transformers = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg['n_layer'])])
        self.last_embedding_layer = nn.Linear(cfg['n_embd'], cfg['vocab_size'])
    
    def forward(self, x):
        vec_inputs, attn_mask = self.preprocessor(x)
        t_outputs = vec_inputs
        for transformer in self.transformers:
            t_outputs = transformer(t_outputs, attn_mask)
        logits = self.last_embedding_layer(t_outputs)
        logits = logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        return probs # return probs to get full distribution not just generated token

In [6]:
tokenizer = tiktoken.get_encoding('gpt2')
sllm = SmallLLM(tokenizer, GPT2_CONFIG)
txt = ['Hello, i am', "I am the one who killed Voldemort"]
out_ids = sllm(txt)
print(out_ids.shape)
print(out_ids)

torch.Size([2, 50257])
tensor([[2.9264e-05, 3.1393e-07, 3.8511e-06,  ..., 2.0577e-05, 5.2439e-06,
         6.8799e-07],
        [4.3263e-05, 4.4047e-06, 1.9637e-05,  ..., 5.9726e-06, 9.7502e-06,
         6.9962e-06]], grad_fn=<SoftmaxBackward0>)


In [14]:
def generate_text(tokenizer, model, text_list, max_length):
    decoded_text = text_list
    for _ in range(max_length):
        probs = model(decoded_text)

        next_tokens = torch.argmax(probs, dim=-1).tolist()

        for i, token in enumerate(next_tokens):
            decoded_new_tokens = tokenizer.decode([token])
            decoded_text[i] += decoded_new_tokens
    return decoded_text
generate_text(tokenizer, sllm, ['Hello, ', 'Who are you?'], 20)

['Hello, amped masters ocean Oath orebecev hazardsRY mutantrestthro gatewayKNstrength stabilizationively guruHours pepp',
 'Who are you?YL imperative codes turnaroundVelturnedön differentlyVS independentigure Who lordainmentbucks USAF Beware weeks CarmNevertheless']