In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. lib.utils.graph_wrapper) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import gc
import os
import time

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from lib.utils import graph_wrapper

In [None]:
def load_quantized_model(
    model_save_path,
    base_model,
    device,
):
    """Build a graph-wrapped Llama that uses the projections of a saved quantized model.

    The object stored at ``model_save_path`` is unpickled as a complete model.
    A graph-wrapped ``LlamaForCausalLM`` is then created on the CPU from
    ``base_model``. For every decoder layer, its ``q_proj``, ``k_proj``,
    ``v_proj``, ``o_proj`` and ``mlp`` modules are replaced by the ones from the
    loaded checkpoint. The layer norms, final norm, token embedding and LM head
    are kept from the base model and moved to ``device``.

    Args:
        model_save_path: Path of the pickled model to load with ``torch.load``.
        base_model: Name or path handed to ``from_pretrained`` for the base model.
        device: Device the resulting model is placed on (e.g. ``"cuda:0"``).

    Returns:
        The graph-wrapped model, moved to ``device`` and with ``graph_device``
        set to ``device``.

    Raises:
        ValueError: If the checkpoint and the base model do not have the same
            number of decoder layers.
    """
    # SECURITY: this unpickles arbitrary Python objects -- only load checkpoints
    # from a trusted source. ``weights_only=False`` is spelled out because
    # PyTorch >= 2.6 defaults to ``weights_only=True``, which refuses to
    # unpickle a whole model object.
    model = torch.load(
        model_save_path, map_location=device, weights_only=False
    ).to(device)  # Llama with Caldera
    graph_model = graph_wrapper.get_graph_wrapper(LlamaForCausalLM, device="cpu").from_pretrained(
        base_model, torch_dtype='auto', device_map="cpu", low_cpu_mem_usage=True,
        use_flash_attention_2=True
    ).to("cpu")  # base Llama

    quantized_layers = model.model.layers
    base_layers = graph_model.model.layers
    if len(quantized_layers) != len(base_layers):
        # A mismatch means the checkpoint was not derived from `base_model`;
        # swapping only a subset of the layers would silently give a broken model.
        raise ValueError(
            f"Layer count mismatch: checkpoint has {len(quantized_layers)} layers, "
            f"base model has {len(base_layers)}"
        )

    for base_layer, quantized_layer in zip(base_layers, quantized_layers):
        # Attention projections and the MLP come from the loaded checkpoint.
        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            setattr(
                base_layer.self_attn,
                proj_name,
                getattr(quantized_layer.self_attn, proj_name),
            )
        base_layer.mlp = quantized_layer.mlp
        # The layer norms stay those of the base model; they only change device.
        base_layer.post_attention_layernorm = base_layer.post_attention_layernorm.to(device)
        base_layer.input_layernorm = base_layer.input_layernorm.to(device)

    graph_model.model.norm = graph_model.model.norm.to(device)
    graph_model.model.embed_tokens = graph_model.model.embed_tokens.to(device)
    graph_model.lm_head = graph_model.lm_head.to(device)
    # The wrapper was created with device="cpu"; point it at the real target.
    graph_model.graph_device = device
    return graph_model.to(device)
    

## Test Throughput of CALDERA Model

In [None]:
# Benchmark configuration.
# MODEL_PATH and DEVICE are machine-specific, so each can be overridden with an
# environment variable; the defaults are the values this notebook was written with.
MODEL_PATH = os.environ.get(
    "CALDERA_MODEL_PATH",
    "/media/hdd1/caldera-full-models/llama-2-7b/caldera-rank-256-4B-factors-downdate-no-RHT-ft.pt",
)
BASE_MODEL = "meta-llama/Llama-2-7b-hf"
DEVICE = os.environ.get("CALDERA_DEVICE", "cuda:2")
SAMPLES = 500  # number of timed forward passes per benchmark run

In [None]:
# Assemble the CALDERA model on the benchmark device.
model = load_quantized_model(
    model_save_path=MODEL_PATH,
    base_model=BASE_MODEL,
    device=DEVICE,
)

In [None]:
def eval_throughput(model, samples, base_model, device, batch_size=1, seq_len=1):
    """Time repeated forward passes of ``model`` and print the average speed.

    The first token id of a fixed prompt is repeated into a
    ``(batch_size, seq_len)`` tensor on ``device``. One untimed warm-up forward
    pass is run, then ``samples`` timed passes on that same input.

    Note: the printed numbers are per forward pass. They equal per-token values
    only when ``batch_size * seq_len == 1`` (the defaults).

    Args:
        model: Callable model that accepts a tensor of token ids.
        samples: Number of timed forward passes; must be positive.
        base_model: Name or path handed to ``AutoTokenizer.from_pretrained``.
        device: Device the input tensor is moved to and on which timing is synchronized.
        batch_size: Number of rows in the input tensor.
        seq_len: Number of columns in the input tensor.

    Raises:
        ValueError: If ``samples`` is not positive.
    """
    if samples <= 0:
        raise ValueError(f"samples must be a positive integer, got {samples!r}")

    target = torch.device(device)

    def _sync():
        # Wait for the device that actually runs the model. A bare
        # torch.cuda.synchronize() only covers the *current* CUDA device, which
        # need not be `device`, and would let the timer stop too early.
        if target.type == "cuda":
            torch.cuda.synchronize(target)

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    prompt = 'It is a truth universally acknowledged that'
    inputs = tokenizer(prompt, return_tensors='pt')
    token = inputs['input_ids'][0:1, 0:1].to(device).repeat(batch_size, seq_len)

    # Inference-only benchmark: do not record autograd state during the passes.
    with torch.no_grad():
        model(token)  # warm-up pass, excluded from the timing

        _sync()
        start = time.perf_counter()
        for _ in range(samples):
            model(token)
        _sync()
        end = time.perf_counter()

    print('TIME:', (end - start) / samples, 's/tok')
    print(f'THROUGHPUT: {samples / (end - start)} tok/s')

In [None]:
# Benchmark the CALDERA model with the default 1x1 token input.
eval_throughput(model, samples=SAMPLES, base_model=BASE_MODEL, device=DEVICE)

## Compare with the Unquantized Base Model

In [None]:
# Drop the notebook's reference to the previously benchmarked model, run the
# garbage collector, and then ask PyTorch to release its cached CUDA memory so
# the next model can be loaded on the same device. The order matters: the
# reference has to be gone before the cache is emptied.
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Baseline: the same graph wrapper and from_pretrained options used inside
# load_quantized_model, but created directly on DEVICE with no modules swapped.
model = graph_wrapper.get_graph_wrapper(
    LlamaForCausalLM,
    device=DEVICE,
).from_pretrained(
    BASE_MODEL,
    torch_dtype='auto',
    device_map=DEVICE,
    low_cpu_mem_usage=True,
    use_flash_attention_2=True,
)

In [None]:
# Benchmark the unquantized model with the same settings for a like-for-like comparison.
eval_throughput(model, samples=SAMPLES, base_model=BASE_MODEL, device=DEVICE)