In [1]:
import torch
import torch.nn as nn
import tiktoken
from blocks import TransformerBlock, InputPreprocess

In [2]:
# Hyper-parameters for the model built in this notebook.
# NOTE(review): the key set looks like a Hugging Face GPT-2 `config.json`
# pasted as a Python dict -- confirm the provenance.
# Keys read directly in this file: "n_layer", "n_embd", "vocab_size".
# The whole mapping is also handed to InputPreprocess / TransformerBlock;
# which of the remaining keys they consume is not visible from here.
GPT2_CONFIG = {
    "activation_function": "gelu_new",
    "architectures": ["GPT2LMHeadModel"],
    "attn_pdrop": 0.1,
    "bos_token_id": 50256,
    "embd_pdrop": 0.1,
    "eos_token_id": 50256,
    "initializer_range": 0.02,
    "layer_norm_epsilon": 1e-05,
    "model_type": "gpt2",
    "n_ctx": 1024,
    "n_embd": 768,      # input width of the final linear projection
    "n_head": 12,
    "n_layer": 12,      # number of stacked TransformerBlock modules
    "n_positions": 1024,
    "resid_pdrop": 0.1,
    "summary_activation": None,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": True,
    "summary_type": "cls_index",
    "summary_use_proj": True,
    "task_specific_params": {
        "text-generation": {"do_sample": True, "max_length": 50},
    },
    "vocab_size": 50257,  # output width of the final linear projection
}

In [8]:
class SmallLLM(nn.Module):
    """Minimal GPT-style model that maps an input to its greedy next-token id.

    Pipeline: ``InputPreprocess`` -> ``cfg['n_layer']`` stacked
    ``TransformerBlock`` modules -> a linear projection from ``cfg['n_embd']``
    features to ``cfg['vocab_size']`` logits.

    Args:
        tokenizer: Tokenizer forwarded to ``InputPreprocess`` (this notebook
            passes a tiktoken GPT-2 encoding).
        cfg: Configuration mapping. This class reads ``'n_layer'``,
            ``'n_embd'`` and ``'vocab_size'``; the whole mapping is also
            forwarded to ``InputPreprocess`` and to every ``TransformerBlock``.
    """

    def __init__(self, tokenizer, cfg):
        super().__init__()
        self.preprocessor = InputPreprocess(tokenizer, cfg)
        self.transformers = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg['n_layer'])]
        )
        # Projects each position's hidden state onto vocabulary logits.
        self.last_embedding_layer = nn.Linear(cfg['n_embd'], cfg['vocab_size'])

    def forward(self, x):
        """Return the highest-scoring next-token id for ``x``.

        Args:
            x: Whatever ``InputPreprocess`` accepts (this notebook passes a
                plain string).

        Returns:
            Integer tensor of token ids with the sequence and vocabulary axes
            removed, i.e. one id per entry along the leading axis.
        """
        vec_inputs = self.preprocessor(x)
        t_outputs = self.transformers(vec_inputs)
        logits = self.last_embedding_layer(t_outputs)
        # Only the final sequence position predicts the next token. The
        # indexing requires a 3-D tensor -- presumably (batch, seq, vocab).
        last_logits = logits[:, -1, :]
        # Softmax is strictly monotonic, so the argmax of the raw logits is
        # the same token; skipping it avoids a full pass over the vocabulary.
        return torch.argmax(last_logits, dim=-1)

In [15]:
# Seed the RNG before the model is built so the randomly initialised weights
# (no checkpoint is loaded in this cell) and the printed token are the same on
# every Restart & Run All.
torch.manual_seed(123)

tokenizer = tiktoken.get_encoding('gpt2')
sllm = SmallLLM(tokenizer, GPT2_CONFIG)
# Inference only: leave training mode so that any dropout layers inside the
# imported blocks (the config carries non-zero *_pdrop values) are disabled.
sllm.eval()

txt = 'Hello, i am'
# No gradients are needed to read off a prediction.
with torch.no_grad():
    out_ids = sllm(txt)
print(out_ids.shape)
print(out_ids)

torch.Size([1])
tensor([48755])


In [17]:
# Convert to a plain list under a new name. Rebinding `out_ids` itself would
# turn the tensor into a list, and re-running this cell would then fail
# because a list has no `.tolist()`.
out_id_list = out_ids.tolist()
print(tokenizer.decode(out_id_list))

 etched


In [18]:
def generate_text(txt, llm, length=20, tokenizer=None):
    """Extend ``txt`` by ``length`` tokens, calling ``llm`` once per token.

    At every step the running token list is decoded back to text and that
    text is passed to ``llm``, which is expected to return the next token id.

    Args:
        txt: Prompt string.
        llm: Callable taking the current text and returning the next token id,
            either as an ``int`` or as a one-element tensor (which is what
            ``SmallLLM`` returns for a single prompt).
        length: Number of tokens to append. Values <= 0 return ``txt`` as is.
        tokenizer: Object exposing ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``. Defaults to the tiktoken GPT-2
            encoding, matching the previous hard-coded behaviour.

    Returns:
        The prompt followed by the generated continuation, as a string.
    """
    if tokenizer is None:
        tokenizer = tiktoken.get_encoding('gpt2')
    tokens = tokenizer.encode(txt)
    # Generation never back-propagates, so skip autograd bookkeeping.
    with torch.no_grad():
        for _ in range(length):
            # `llm` may hand back a one-element tensor; store a plain int so
            # `tokens` stays a homogeneous list of ids for `decode`.
            tokens.append(int(llm(txt)))
            txt = tokenizer.decode(tokens)
    return txt

# Continue a short prompt with the model built above, using generate_text's
# default number of new tokens.
new_text = generate_text(txt="Hello, i am", llm=sllm)
print(new_text)

Hello, i am resolves Quite shader Tempest sample006IEDOffsetryukidindaena Heat Anderson nestsfre explodesuing Annie Farmers
