# Evaluate autoregressive inference
## This code assumes no KV cache: the full token sequence is fed through the model again for every generated token

In [1]:
%%file evaluate_new.py
import torch
import argparse
import time
from model import Transformer
from tiktoken import encoding_for_model
from config import ModelConfig, DataConfig, TrainConfig

def evaluate(model, tokenizer, input_text, device, max_length=100):
    """Greedily generate a continuation of ``input_text`` and measure its speed.

    No KV cache is used: on every step the whole sequence generated so far is
    fed through ``model`` again and the highest-scoring next token is appended.
    Each new token is decoded and printed immediately (without a newline) so
    the text streams to stdout while it is being produced.

    Args:
        model: Object providing ``eval()`` and callable as
            ``model(tokens, tokens)``; it must return a pair whose first
            element can be indexed as ``[:, -1, :]`` to obtain the scores for
            the next token.
        tokenizer: Object providing ``encode(str)`` and ``decode(list_of_ids)``.
            If it exposes an ``eot_token`` attribute, generation stops as soon
            as that token id is predicted.
        input_text: Prompt to continue. Must encode to at least one token.
        device: Device the prompt tensor is moved to.
        max_length: Maximum number of new tokens to generate.

    Returns:
        A tuple ``(output_text, tokens_per_second)`` where ``output_text`` is
        the decoded prompt plus generated tokens and ``tokens_per_second`` is
        the number of generated tokens divided by the wall-clock generation
        time (``0.0`` if no measurable time elapsed).

    Raises:
        ValueError: If ``input_text`` encodes to no tokens.
    """
    model.eval()
    # Not every tokenizer defines an end-of-text id; None disables the check.
    eos_id = getattr(tokenizer, "eot_token", None)

    with torch.no_grad():
        # Tokenize input text
        input_ids = tokenizer.encode(input_text)
        if not input_ids:
            raise ValueError("input_text must encode to at least one token")
        # Batch of one sequence: shape (1, prompt_length).
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)

        # Autoregressive generation
        generated = input_tensor
        start_time = time.time()
        for _ in range(max_length):
            # NOTE(review): the second argument is presumably the training
            # targets and the discarded second result the loss - confirm
            # against the model's forward signature.
            logits, _ = model(generated, generated)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            token_id = next_token.item()

            # Stop if EOS token is generated (it is neither kept nor printed).
            if eos_id is not None and token_id == eos_id:
                break

            generated = torch.cat((generated, next_token), dim=1)
            print(tokenizer.decode([token_id]), end='', flush=True)
        end_time = time.time()

        # Decode output tokens
        output_text = tokenizer.decode(generated[0].tolist())

        # Calculate generation speed
        num_generated_tokens = generated.size(1) - input_tensor.size(1)
        generation_time = end_time - start_time
        # A coarse clock can report zero elapsed time; avoid dividing by it.
        if generation_time > 0:
            tokens_per_second = num_generated_tokens / generation_time
        else:
            tokens_per_second = 0.0

        return output_text, tokens_per_second

def load_model_and_tokenizer(checkpoint_path, device):
    """Rebuild the model and tokenizer described by a saved checkpoint.

    The checkpoint is read with ``torch.load`` and must contain a ``'config'``
    entry (holding ``'model_config'``, ``'train_config'`` and ``'data_config'``
    mappings that are expanded as keyword arguments into the matching config
    classes) and a ``'model_state_dict'`` entry. Progress and the duration of
    each stage are printed to stdout.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        A tuple ``(model, tokenizer)``: the ``Transformer`` with the saved
        weights loaded, and the tiktoken encoding selected by
        ``data_config.tokenizer_name``.
    """
    # Stage 1: read the checkpoint file.
    print("start to load file")
    started = time.time()
    ckpt = torch.load(checkpoint_path, map_location=device)
    elapsed = time.time() - started
    print(f" complete to load {checkpoint_path} during {elapsed:4.2f}sec ")

    # Stage 2: rebuild the three config objects stored with the weights.
    saved_cfg = ckpt['config']
    model_config = ModelConfig(**saved_cfg['model_config'])
    train_config = TrainConfig(**saved_cfg['train_config'])
    data_config = DataConfig(**saved_cfg['data_config'])
    for cfg in (model_config, train_config, data_config):
        print(cfg)

    # Stage 3: construct the network on the target device.
    print("config model")
    started = time.time()
    model = Transformer(model_config, train_config.gradient_checkpointing).to(device)
    elapsed = time.time() - started
    print(f"configure model complete with {elapsed:4.2f}sec")
    print(model)

    # Stage 4: copy the saved parameters into the network.
    print("load model state")
    started = time.time()
    model.load_state_dict(ckpt['model_state_dict'])
    elapsed = time.time() - started
    print(f" load model state  {elapsed:4.2f}sec")

    # Stage 5: the tokenizer is chosen by the name recorded in the data config.
    tokenizer = encoding_for_model(data_config.tokenizer_name)

    return model, tokenizer

def main():
    """Command-line entry point: load a checkpoint and generate from a prompt.

    Parses ``--model_checkpoint_path`` and ``--input_text``, loads the model
    and tokenizer on the CPU, runs ``evaluate`` on the prompt (which streams
    the generated tokens to stdout) and prints the load time and the
    generation speed.
    """
    parser = argparse.ArgumentParser(description="Evaluate a GPT-like model from a checkpoint.")
    parser.add_argument("--model_checkpoint_path", type=str, required=True, help="Path to the model checkpoint file.")
    parser.add_argument("--input_text", type=str, required=True, help="Input text for evaluation.")
    args = parser.parse_args()

    # Inference is deliberately pinned to the CPU in this script.
    device = 'cpu'
    print("start")
    print(f"Input Text: {args.input_text}")
    print("load model and tokenizer")

    # Load model and tokenizer, timing the whole operation.
    tic_load = time.time()
    model, tokenizer = load_model_and_tokenizer(args.model_checkpoint_path, device)
    toc_load = time.time()
    dur_load = toc_load - tic_load
    print(f" model load time {dur_load:4.1f}sec")

    # Evaluate input text. The prompt is echoed again because the model
    # summary printed during loading pushes the first echo far up the log.
    print(f"Input Text: {args.input_text}")

    # The returned text is not printed: evaluate() already streams each token.
    _, tokens_per_second = evaluate(model, tokenizer, args.input_text, device)

    # evaluate() prints tokens without a trailing newline; end that line so
    # the speed report starts on its own line.
    print()
    print(f"Generation Speed: {tokens_per_second:.2f} tokens/sec")


if __name__ == "__main__":
    main()


Writing evaluate_new.py


In [None]:
!python evaluate_new.py --model_checkpoint_path './experiment/006/checkpoints/model_state_16000.pth' --input_text "The large langugae model is"
