In [1]:
from model import *

In [2]:
gin.parse_config_file("config/gpt2-small.gin")
config = GPTConfig()
config

GPTConfig(block_size=1024, vocab_size=100288, n_layer=16, n_head=16, n_embd=1024, batch_size=8, learning_rate=6e-05, weight_decay=0.001, eps=1e-08, betas=(0.9, 0.95), seed=42, epochs=2, training_backend='nccl', device='cuda', model_name='gpt2', clip_grad_norm_val=1.0, dtype=torch.bfloat16)

In [3]:
model = GPT(config)

In [4]:
# count number of parameters in terms of billions
num_params = sum([param.nelement() for param in model.parameters()])
num_params / 1e6

305.251328

In [5]:
# torch.save(model.state_dict(), "model.pth")

In [7]:
import tiktoken

gpt4_tokenizer = tiktoken.get_encoding("cl100k_base")
gpt4_tokenizer.n_vocab

100277

In [16]:
100288/8

12536.0

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer



def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 200,
             sample: bool = True,
             top_k: int = 40):
    model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            indices_to_input = input_ids
            next_token_logits, loss = model(indices_to_input)
            next_token_logits = next_token_logits[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape

        if top_k is not None:
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(axis=1, keepdims=True)

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]

    return output_completions

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
output = generate(model, tokenizer=AutoTokenizer.from_pretrained("gpt2"), prompt="User: what is the meaning of life? Assistant:", n_tokens_to_gen=100, sample=True, top_k=40)



In [7]:
output

'User: what is the meaning of life? Assistant: dele founder Ana Lexawed assumptions Syndicate aren sensibilitiesTen Siege OscGeorgia disadvantagedprisonProducts menu airst panelsnight rapportcoord sneak UrugをRuby MLletterogens imitate shake reported hangedneys Hum Informationaciasen WWII FT PerhapswindowposiumSimon Kirst Requ Kremlinadow Continueagree................ automotive sentMakingampiresShort Moose curse GadgetGiving SYSTEM accomplish Sphusesctica pledgedgamer migrant technique Naz Thailandreeifest WiFi NAD Rubber 237 causalJerry kidneysemsAsset surgeriesoft cabbage aerospaceMania Tsuk baptized capacities receptor carefully renovite NET bladesitto KILL accusekm'