In [31]:
import torch
import torch.nn.functional as F

from tokenizer import Tokenizer
from model import Transformer, MiniLlamaArgs

In [32]:
# Load the fine-tuned checkpoint.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; the model is moved to the target device explicitly further down.
checkpoint = torch.load("log/finetune_00079.pt", map_location="cpu")
weights = checkpoint['model']

# Init the model
model = Transformer(MiniLlamaArgs())
model.load_state_dict(weights)

# Set Device: use the first GPU when CUDA is available, otherwise fall back to CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Move the model to the selected device (kept as the last expression so the
# notebook still displays the module summary)
model.to(device)

  checkpoint = torch.load("log/model_00079.pt")


Transformer(
  (token_embeddings): Embedding(32000, 768)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=768, out_features=768, bias=False)
        (wk): Linear(in_features=768, out_features=768, bias=False)
        (wv): Linear(in_features=768, out_features=768, bias=False)
        (wo): Linear(in_features=768, out_features=768, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=768, out_features=2048, bias=False)
        (w2): Linear(in_features=2048, out_features=768, bias=False)
        (w3): Linear(in_features=768, out_features=2048, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=

In [33]:
# Seed PyTorch's CPU RNG (and the CUDA RNG when a GPU is present) from a single
# named constant instead of repeating the literal.
SEED = 1337
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)


# Initialize the tokenizer
enc = Tokenizer()

In [34]:
## Generator Function to generate from the model
def generate(model, prompt, max_length=500, top_k=50, seed=1337):
    """Sample a completion for ``prompt`` from ``model`` using top-k sampling.

    The prompt is tokenized with the module-level tokenizer ``enc`` (BOS
    prepended, no EOS appended) and extended one token at a time until either
    the EOS token is sampled or the full sequence (prompt + generated tokens)
    reaches ``max_length``. Only a single sequence (batch size 1) is generated.

    Args:
        model: Module called as ``model(tokens)`` and returning a
            ``(logits, loss)`` pair. It is switched to eval mode as a side
            effect.
        prompt: Text to condition the generation on.
        max_length: Upper bound on the total sequence length, prompt included.
        top_k: Number of highest-probability tokens to sample from at each
            step; clamped to the vocabulary size.
        seed: Seed for the dedicated sampling generator, reset on every call.

    Returns:
        list[int]: The generated token ids, excluding the prompt tokens and
        the terminating EOS token.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    model.eval()
    # Follow the model's own device rather than a notebook-level global, so the
    # inputs and the sampling generator always match where the weights live.
    model_device = next(model.parameters()).device

    prompt_ids = enc.encode(prompt, True, False)  # BOS -> True, EOS -> False
    xgen = torch.tensor([prompt_ids], dtype=torch.long, device=model_device)  # (1, T)

    # Dedicated generator: sampling does not disturb the global RNG state.
    sample_rng = torch.Generator(device=model_device)
    sample_rng.manual_seed(seed)

    out_tokens = []

    # Inference only, so gradient tracking is disabled for the whole loop.
    with torch.no_grad():
        while xgen.size(1) < max_length:
            # forward the model to get the logits
            logits, _ = model(xgen)
            # take the logits at the last position
            logits = logits[:, -1, :]  # (1, vocab_size)
            # get the probabilities
            probs = F.softmax(logits, dim=-1)
            # top-k sampling; clamp k so a small vocabulary cannot break topk
            k = min(top_k, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # both (1, k)
            # select a token from the top-k probabilities
            # note: multinomial does not demand the input to sum to 1
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (1, 1)
            # map the sampled position back to its vocabulary id
            xcol = torch.gather(topk_indices, -1, ix)  # (1, 1)

            # Check if we reached end of generation
            next_id = xcol[0, 0].item()
            if next_id == enc.eos_id:
                break
            out_tokens.append(next_id)

            # append to the sequence
            xgen = torch.cat((xgen, xcol), dim=1)
    return out_tokens

In [35]:
def preprocess_function(example):
    """
    Render one example as a single instruction-style prompt string.

    Reads the 'instruction', 'input' and 'output' keys of ``example`` and
    places them under the "### Instruction:", "### Input:" and
    "### Response:" headings respectively, separated by blank lines.
    Returns the formatted string (one string per example, not a list).
    """
    instruction = example['instruction']
    context = example['input']
    response = example['output']
    return (
        f"### Instruction:\n{instruction}\n\n"
        f"### Input:\n{context}\n\n"
        f"### Response:\n{response}"
    )

In [38]:
def ask(model, query):
    """Print the model's sampled response to ``query``.

    The query is wrapped in the instruction prompt template with empty
    'input' and 'output' fields, so the prompt ends right after the
    "### Response:" heading. The generated token ids are decoded with the
    module-level tokenizer ``enc`` and printed; nothing is returned.
    """
    blank_example = {"instruction": query, "input": "", "output": ""}
    completion_ids = generate(model, preprocess_function(blank_example))
    print(enc.decode(completion_ids))

## Fine-Tuned Model Generation

In [39]:
# Smoke test of the fine-tuned model: a basic arithmetic question.
ask(model, "What is 2 + 2?")

2 + 2 is 4.


In [41]:
# Factual-recall probe: location of a well-known landmark.
ask(model, "Where is Eiffel Tower?")

Eiffel Tower is located in Paris, France.


In [42]:
# Factual-recall probe whose correct answer changes over time, so any recorded
# output for this cell can go stale.
ask(model, "Who is the president of U.S?")

Joe Biden is the president of the United States.


In [43]:
# Factual-recall probe.
# NOTE(review): the output recorded for this cell looks factually wrong (the
# Empire State Building is not the world's tallest building) — treat it as an
# example of the model hallucinating; verify before citing.
ask(model, "What's the tallest building in the world?")

The tallest building in the world is the Empire State Building in New York, which covers 6.5 million square meters.
