In [1]:
import torch
from llama import llama3_8b
from llama3_tokenizer import Tokenizer, ChatFormat
from load_llama_weights import convert_weights
from typing import Optional

In [2]:
def load_checkpoint(checkpoint_path):
    """Build a Llama 3 8B model and fill it with weights read from disk.

    The checkpoint is read onto the CPU (memory-mapped, tensors only), its
    state dict is passed through ``convert_weights``, and the result is loaded
    into a freshly constructed ``llama3_8b`` model.

    Args:
        checkpoint_path: Path of the checkpoint file handed to ``torch.load``.

    Returns:
        The model with the converted state dict loaded.
    """
    raw_weights = torch.load(
        checkpoint_path,
        weights_only=True,
        mmap=True,
        map_location="cpu",
    )
    converted_weights = convert_weights(raw_weights)

    llama = llama3_8b()
    llama.load_state_dict(converted_weights)

    print(f"Loaded checkpoint '{checkpoint_path}'")
    return llama


# Checkpoint file that is handed to load_checkpoint when the model is built.
file_path = 'Meta-Llama-3-8B-Instruct/consolidated.00.pth'  # Update this path

In [3]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the checkpoint, cast the parameters to fp16, then move the model to
# the selected device (the last expression displays the module tree).
model = load_checkpoint(file_path).half()
model.to(device)

  return self.fget.__get__(instance, owner)()


Loaded checkpoint 'Meta-Llama-3-8B-Instruct/consolidated.00.pth'


TransformerDecoder(
  (tok_embeddings): Embedding(128256, 4096)
  (norm): RMS_Norm()
  (output): Linear(in_features=4096, out_features=128256, bias=False)
  (layers): ModuleList(
    (0-31): 32 x TransformerDecoderLayer(
      (attn): CasualSelfAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (pos_embeddings): RotaryPositionalEmbeddings()
      )
      (mlp): FeedForward(
        (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
        (activation): SiLU()
      )
      (attn_norm): RMS_Norm()
      (mlp_norm): RMS_Norm()
    )
  )
)

In [4]:
# Create the model's caches for a single sequence in fp16; the torch.device
# context makes tensors allocated inside it default to `device`.
with torch.device(device):
    model.setup_caches(
        dtype=torch.float16,
        max_batch_size=1,
    )

In [5]:
# Switch the model to evaluation mode before generating text.
model.eval()

TransformerDecoder(
  (tok_embeddings): Embedding(128256, 4096)
  (norm): RMS_Norm()
  (output): Linear(in_features=4096, out_features=128256, bias=False)
  (layers): ModuleList(
    (0-31): 32 x TransformerDecoderLayer(
      (attn): CasualSelfAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (pos_embeddings): RotaryPositionalEmbeddings()
        (kv_cache): KVCache()
      )
      (mlp): FeedForward(
        (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
        (activation): SiLU()
      )
      (attn_norm): RMS_Norm()
      (mlp_norm): RMS_Norm()
 

In [10]:
def generate_next_token(
        model,
        input_pos: torch.Tensor, #[S]
        x: torch.Tensor, #[1, S]
        temperature: float = 1.0,
        top_k: Optional[int] = None,
) -> torch.Tensor:
    """Run one forward pass and sample the token that follows ``x``.

    Args:
        model: Callable invoked as ``model(x, input_pos)`` that returns logits
            indexed as ``[batch, position, vocab]``.
        input_pos: Positions of the tokens in ``x``.
        x: Token ids for a single sequence.
        temperature: Divisor applied to the logits before sampling. Values
            below ``1e-5`` are clamped to ``1e-5``, which makes sampling
            effectively greedy.
        top_k: When given, only the ``top_k`` highest-scoring tokens can be
            sampled. Values larger than the vocabulary are clamped to it.

    Returns:
        A tensor holding the single sampled token id.

    Raises:
        ValueError: If ``top_k`` is given but smaller than 1.
    """
    # Validate before the forward pass so a bad argument costs nothing.
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    logits = model(x, input_pos) #[1, S, VOCAB_SIZE]

    # Keep only the last position and sample in float32: with fp16 logits the
    # temperature division below overflows to inf for small temperatures,
    # which turns the softmax output into NaN.
    logits = logits[0, -1].float() #[vocab_size]

    # Scale the logits by the temperature.
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        # topk raises if k exceeds the vocabulary size, so clamp it.
        k = min(top_k, logits.size(-1))
        kth_best = logits.topk(k).values[-1]
        # Everything strictly below the k-th best logit gets zero probability.
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    # Compute the probabilities and sample the next token from them.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
        
        
    
    
def generate(
        model,
        input_tokens: torch.Tensor,
        max_len:int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        eos_id: Optional[int] = None,
):
    """Autoregressively extend a prompt by up to ``max_len`` sampled tokens.

    The whole prompt is fed through the model once, then each sampled token is
    fed back one at a time at the next position until ``max_len`` tokens have
    been produced or ``eos_id`` is sampled.

    Args:
        model: Model forwarded to ``generate_next_token``.
        input_tokens: 1-D tensor of prompt token ids.
        max_len: Maximum number of new tokens to sample.
        temperature: Sampling temperature forwarded to ``generate_next_token``.
        top_k: Optional top-k cutoff forwarded to ``generate_next_token``.
        eos_id: Optional token id that ends generation once it is sampled.

    Returns:
        A list of token ids: the prompt followed by the generated tokens
        (including the EOS token when generation stopped on it).

    Raises:
        ValueError: If the prompt is empty, or if the prompt plus the requested
            tokens would exceed the maximum sequence length.
    """
    # Hard-coded context limit; it is not read from the model.
    # NOTE(review): confirm it matches the length the model's caches were built for.
    max_seq_len = 4096
    input_tokens_length = input_tokens.size(0)

    if input_tokens_length == 0:
        raise ValueError("input_tokens must contain at least one token")

    if ((input_tokens_length + max_len) - 1) > max_seq_len:
        # Report the limit that was actually checked; the message no longer
        # depends on the model exposing a particular attribute.
        raise ValueError(
            f"Prompt length {input_tokens_length} plus max_len {max_len} "
            f"exceeds the maximum sequence length {max_seq_len}"
        )

    # Nothing requested: return the prompt untouched without running the model.
    if max_len <= 0:
        return input_tokens.tolist()

    device = input_tokens.device
    generated_tokens = [input_tokens]

    # Prefill: run the full prompt once and sample the first new token.
    token = generate_next_token(
        model=model,
        input_pos=torch.arange(0, input_tokens_length, device=device),
        x=input_tokens.view(1, -1),
        temperature=temperature,
        top_k=top_k
    ).clone()
    generated_tokens.append(token)

    # Decode: feed the latest token back in at the next position.
    input_pos = torch.tensor([input_tokens_length], device=device)
    for _ in range(max_len - 1):
        # Checked at the top of the loop so that an EOS produced by the
        # prefill step stops generation as well.
        if eos_id is not None and token.item() == eos_id:
            break

        token = generate_next_token(
            model=model,
            input_pos=input_pos,
            x=token.view(1, -1),
            temperature=temperature,
            top_k=top_k
        )
        generated_tokens.append(token)
        input_pos += 1

    return torch.cat(generated_tokens).tolist()
 

In [11]:
# Load the tokenizer model that sits next to the checkpoint.
tokenizer = Tokenizer(model_path="./Meta-Llama-3-8B-Instruct/tokenizer.model")
# ChatFormat wraps the tokenizer to encode role-tagged messages.
chat_template = ChatFormat(tokenizer=tokenizer)
# Encode a system + user exchange; the result is turned into the input tensor
# for generation in the next cell.
dialog = chat_template.encode_dialog_prompt([
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
])

In [12]:
# Build the prompt tensor on the same device as the model instead of calling
# .cuda() unconditionally; `device` falls back to the CPU when CUDA is absent.
input_tokens = torch.tensor(dialog, dtype=torch.long, device=device)

# Autocast is only enabled on CUDA, matching the previous torch.cuda.amp
# behaviour; gradients are not needed for sampling.
# eos_id=128009 is the token id generation stops on.
with torch.no_grad(), torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
    output = generate(model, input_tokens, max_len=100, temperature=1.0, top_k=None, eos_id=128009)

In [13]:
# Decode the full token sequence (prompt followed by the generated reply) to text.
tokenizer.decode(output)

"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a pirate chatbot who always responds in pirate speak!<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nArrrr, me hearty! Me name be Captain Chatbot, the scurviest, most fearsome chatbot to ever set sail the Seven Seas! Me be programmed to rattle yer bones wi' me witty banter and me trusty responses, so hoist the sails and settle yerself in fer a swashbucklin' good time!<|eot_id|>"