In [1]:
from typing import List
import torch

%load_ext autoreload
%autoreload 2

In [2]:
from llm_lab.model.rotary_decoder import RotaryDecoderModel

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
class RotaryCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.decoder = RotaryDecoderModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
    def forward(self, input_ids, use_cache=False, start_pos=0):
        hidden_states = self.decoder(input_ids=input_ids, use_cache=use_cache, start_pos=start_pos)
        logits = self.lm_head(hidden_states)
        return logits
    

In [17]:

#model_name = "Qwen/Qwen2.5-0.5B"

model_name="stanford-crfm/battlestar-gpt2-small-x49"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id

In [18]:
from omegaconf import OmegaConf

config_dict = {
    "vocab_size": 50257,    # Vocabulary size
    "max_position_embeddings": 1024,
    "hidden_size": 768,         # model dimension
    "intermediate_size": 768*4,
    "num_key_value_heads": 2,
    "num_heads": 4,          # Number of attention heads
    "num_layers": 6,         # Number of layers
    "attention_dropout": 0.1,       # Dropout rate
    "qkv_bias": False,       # Query-key-value bias
    "o_bias": True,
    "mlp_bias": True,
    "rms_norm_eps": 1e-6,
    "dropout": 0.1,
    "pad_token_id": tokenizer.eos_token_id,
    "causal_attention": True,
    "use_cache": True,
    "cache_max_batch_size":128,
    "cache_max_seq_len": 128
}

config = OmegaConf.create(config_dict)

device = 'cuda'

In [19]:
model = RotaryCausalLM(config)

model = model.to(device)

model_state_dict = torch.load("model.pth")

model.load_state_dict(model_state_dict, strict=False)

  model_state_dict = torch.load("model.pth")


<All keys matched successfully>

In [21]:
def generate(model, 
             tokenizer,
             prompts: List[str],
             device, 
             greedy_decoding: bool=False,
             temperature: float=0.8,
             top_p: float=0.9,
             max_new_tokens: int=100):
    
    
    prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
    
    max_prompt_len = max([len(prompt_token) for prompt_token in prompt_tokens])
    
    batch_size = len(prompt_tokens)
    
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    eos_token_id = tokenizer.eos_token_id
    
    total_len = max_prompt_len + max_new_tokens
    
    # every token is defaulted to pad_token_id
    tokens = torch.full((batch_size, total_len), pad_token_id, dtype=torch.long, device=device)
    prompt_pad_mask = tokens == pad_token_id # True if the token is a prompt token, False otherwise
    for k, t in enumerate(prompt_tokens):
        # fill in existing prompt tokens
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
        
    eos_reached = torch.tensor([False] * batch_size, device=device)
    
    for cur_pos in range(total_len - 1):
        
        with torch.no_grad():
            logits = model(tokens[:,cur_pos:cur_pos + 1], use_cache=True, start_pos=cur_pos)

            if greedy_decoding:
                next_token = torch.argmax(logits[:,-1,:], dim=-1)
            else:
                raise NotImplementedError()
        
        # only replace toekn if it is a padding token
        next_token = torch.where(prompt_pad_mask[:, cur_pos+1], next_token, tokens[:,cur_pos + 1])
        
        tokens[:, cur_pos + 1] = next_token
        
        # EOS is reachehed only if we found an EOS token for a padding position
        
        eos_reached |= (prompt_pad_mask[:, cur_pos + 1]) & (next_token == eos_token_id)
        
        if all(eos_reached):
            break
        
    out_tokens = []
    out_text = []
    
    for current_prompt_tokens in tokens.tolist():
        # cut to the EOS token if present
        if eos_token_id in current_prompt_tokens:
            eos_idx = current_prompt_tokens.index(eos_token_id)
            current_prompt_tokens = current_prompt_tokens[:eos_idx]
        
        out_tokens.append(current_prompt_tokens)
        out_text.append(tokenizer.decode(current_prompt_tokens))
    
    return (out_tokens, out_text)
    

In [23]:
prompts = ['hello']

In [25]:
out_tokens, out_text = generate(model, tokenizer, prompts=prompts, device=device, greedy_decoding=True)

In [26]:
print(out_text)

['hello guitar guitar guitar guitar guitar guitar guitar guitarvillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevillevilleville provided provided guitar guitar headquarters headquarters headquarters headquartersont guitar guitarvillevillevillevillevillevillevillevillevillevillevillevillevillevilleville']
