In [1]:
# Automatically re-import edited project modules (e.g. lib.utils) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import os
import time

import torch
from transformers import AutoTokenizer, LlamaForCausalLM, StaticCache

from lib.utils import graph_wrapper

I0804 07:36:53.952796 93964 utils.py:148] Note: NumExpr detected 48 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
I0804 07:36:53.953551 93964 utils.py:160] NumExpr defaulting to 8 threads.
  from .autonotebook import tqdm as notebook_tqdm

I0804 07:36:54.313435 93964 config.py:58] PyTorch version 2.2.1 available.


In [3]:
def load_quantized_model(
    model_save_path,
    base_model,
    device,
):
    """Load ``base_model`` and replace its decoder layers with saved quantized layers.

    The base model is loaded through ``graph_wrapper`` onto ``device``. Then, for
    every decoder-layer index ``i``, the object stored in
    ``{model_save_path}/quant_layer_{i}.pt`` is loaded and assigned to
    ``model.model.layers[i]``.

    Args:
        model_save_path: Directory holding one ``quant_layer_{i}.pt`` file per
            decoder layer of the base model.
        base_model: Hugging Face model id or path of the unquantized Llama model.
        device: Device for the model and the loaded layers, e.g. ``"cuda:1"``.

    Returns:
        The graph-wrapped model with the quantized layers swapped in.
        - The LM head, token embedding, final norm and per-layer norms have
          ``requires_grad=False``.
        - For every projection with ``ft_rank > 0``, ``L_ft``/``R_ft`` are
          re-wrapped as contiguous ``nn.Parameter``s with ``requires_grad=True``.

    Security:
        Each layer file is fully unpickled. Only load checkpoints from a trusted
        source, because unpickling can execute arbitrary code.
    """
    model = graph_wrapper.get_graph_wrapper(LlamaForCausalLM, device=device).from_pretrained(
        base_model, torch_dtype='auto', device_map=device, low_cpu_mem_usage=True,
        use_flash_attention_2=True,
    )

    # Freeze the parameters that are kept from the base model.
    model.lm_head.weight.requires_grad = False

    model.model.embed_tokens.weight.requires_grad = False
    model.model.embed_tokens = model.model.embed_tokens.to(device)

    model.model.norm.weight.requires_grad = False
    model.model.norm = model.model.norm.to(device)

    for layer_idx in range(len(model.model.layers)):
        # The file holds a whole pickled layer module (its submodules are accessed
        # below), not a state dict, so full unpickling is required.
        # weights_only=False is passed explicitly: newer PyTorch releases default
        # to weights_only=True, which refuses to load arbitrary modules.
        layer = torch.load(
            f"{model_save_path}/quant_layer_{layer_idx}.pt",
            map_location=device,
            weights_only=False,
        )
        layer.post_attention_layernorm.weight.requires_grad = False
        layer.input_layernorm.weight.requires_grad = False

        for sublayer in [
            layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj,
            layer.self_attn.o_proj, layer.mlp.gate_proj, layer.mlp.up_proj,
            layer.mlp.down_proj,
        ]:
            # NOTE(review): L_ft / R_ft look like low-rank fine-tuning factors that
            # exist when ft_rank > 0 — confirm against the quantized layer class.
            # They are re-registered as trainable, contiguous Parameters.
            if sublayer.ft_rank > 0:
                sublayer.L_ft = torch.nn.Parameter(sublayer.L_ft.contiguous(), requires_grad=True)
                sublayer.R_ft = torch.nn.Parameter(sublayer.R_ft.contiguous(), requires_grad=True)

        model.model.layers[layer_idx] = layer

    return model

## Measure Throughput of the CALDERA-Quantized Model

In [38]:
# Run configuration. MODEL_PATH and DEVICE can be overridden through environment
# variables, so the notebook runs on another machine without editing this cell;
# the defaults are the values used for the recorded results below.
MODEL_PATH = os.environ.get(
    "CALDERA_MODEL_PATH",
    "/home/ubuntu/caldera/data/models/caldera-rank-128-for-time-eval",
)
BASE_MODEL = "meta-llama/Llama-2-7b-hf"  # source of base weights and tokenizer
DEVICE = os.environ.get("CALDERA_DEVICE", "cuda:1")
SAMPLES = 500  # number of timed forward passes per benchmark

In [40]:
# Base Llama weights with every decoder layer replaced by its saved quantized version.
model = load_quantized_model(
    model_save_path=MODEL_PATH,
    base_model=BASE_MODEL,
    device=DEVICE,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.37s/it]


In [42]:
def eval_throughput(model, samples, base_model, device, batch_size=1, seq_len=1):
    """Time repeated forward passes of ``model`` and report token throughput.

    The first token id of a fixed prompt is tiled into a ``(batch_size, seq_len)``
    input on ``device``. One untimed warm-up forward pass is run, then
    ``samples`` forward passes are timed.

    Args:
        model: Causal LM to benchmark; it is called as ``model(input_ids)``.
        samples: Number of timed forward passes; must be >= 1.
        base_model: Model id or path passed to ``AutoTokenizer.from_pretrained``.
        device: Device the model lives on, e.g. ``"cuda:1"``.
        batch_size: Number of rows in the input tensor.
        seq_len: Number of tokens per row.

    Returns:
        Throughput in tokens per second, counting ``batch_size * seq_len``
        tokens per forward pass. The same value is also printed.

    Raises:
        ValueError: If ``samples`` < 1.
    """
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    prompt = 'It is a truth universally acknowledged that'
    inputs = tokenizer(prompt, return_tensors='pt')
    # Only the prompt's first token id is used; it is tiled to the requested shape.
    token = inputs['input_ids'][0:1, 0:1].to(device).repeat(batch_size, seq_len)
    tokens_per_pass = batch_size * seq_len

    dev = torch.device(device)

    def _sync():
        # Wait for all queued work on the model's device. A bare
        # torch.cuda.synchronize() waits only on the *current* CUDA device
        # (cuda:0 unless changed). When device is e.g. "cuda:1", that could
        # stop the clock before the model's GPU work has finished.
        if dev.type == 'cuda':
            torch.cuda.synchronize(dev)

    # Pure inference benchmark: no autograd graph needs to be recorded.
    with torch.no_grad():
        # Untimed warm-up. For graph_wrapper models, the CUDA graph appears to be
        # built on this first call.
        model(token)
        _sync()
        start = time.perf_counter()  # monotonic, high-resolution clock for intervals
        for _ in range(samples):
            model(token)
        _sync()
        elapsed = time.perf_counter() - start

    total_tokens = samples * tokens_per_pass
    throughput = total_tokens / elapsed
    print('TIME:', elapsed / total_tokens, 's/tok')
    print(f'THROUGHPUT: {throughput} tok/s')
    return throughput

In [43]:
# Benchmark the quantized model.
eval_throughput(
    model=model,
    samples=SAMPLES,
    base_model=BASE_MODEL,
    device=DEVICE,
)

I0804 07:45:22.482335 93964 graph_wrapper.py:36] Built CUDA graph of model.


TIME: 0.021649957656860352 s/tok
THROUGHPUT: 46.189466780925734 tok/s


## Compare with the Unquantized Base Model

In [44]:
# Release the quantized model before the unquantized baseline is loaded onto the
# same DEVICE: drop the Python reference, collect garbage, then return PyTorch's
# cached CUDA memory blocks.
del model
gc.collect()
torch.cuda.empty_cache()

In [45]:
# Unquantized baseline. It is wrapped and loaded with the same arguments that
# load_quantized_model uses for the base model, so the two throughput numbers
# are comparable. To benchmark without graph_wrapper, call
# LlamaForCausalLM.from_pretrained(...) with these same arguments instead.
model = graph_wrapper.get_graph_wrapper(LlamaForCausalLM, device=DEVICE).from_pretrained(
    BASE_MODEL, torch_dtype='auto', device_map=DEVICE, low_cpu_mem_usage=True,
    use_flash_attention_2=True,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.49s/it]


In [46]:
# Benchmark the unquantized baseline under the same settings.
eval_throughput(
    model=model,
    samples=SAMPLES,
    base_model=BASE_MODEL,
    device=DEVICE,
)

I0804 07:45:48.635234 93964 graph_wrapper.py:36] Built CUDA graph of model.


TIME: 0.031495975494384765 s/tok
THROUGHPUT: 31.750088203436793 tok/s
