In [1]:
import csv
import os
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig

from speculative_decoding import speculative_decoding
# Hugging Face download cache for all checkpoints. Set the HF_CACHE_LOCATION
# environment variable to run on a machine without the original author's path.
HF_CACHE_LOCATION = os.environ.get("HF_CACHE_LOCATION", "/data/shk148/models/opt/cache")
# Device that tokenized prompts are moved to before decoding.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Draft checkpoints benchmarked against the target; each must have an entry in model_size.
draft_models = [
    "facebook/opt-6.7b",
    "facebook/opt-2.7b",
    "facebook/opt-1.3b",
    "facebook/opt-350m",
    "facebook/opt-125m",
    "facebook/opt-13b",
]

# Parameter count in millions; used as the key under which each draft model's results are stored.
model_size = {
    "facebook/opt-125m": 125,
    "facebook/opt-350m": 350,
    "facebook/opt-1.3b": 1300,
    "facebook/opt-2.7b": 2700,
    "facebook/opt-6.7b": 6700,
    "facebook/opt-13b": 13000,
    "facebook/opt-30b": 30000,
}
# Fail fast here rather than with a KeyError after a draft model has already been loaded and run.
missing_sizes = [name for name in draft_models if name not in model_size]
assert not missing_sizes, f"model_size has no entry for: {missing_sizes}"

target_checkpoint = "facebook/opt-30b"
# int4 weight quantization via quanto; the same config is used when loading the draft models.
quantization_config = QuantoConfig(weights="int4")
torch.manual_seed(1)

# Decoding hyper-parameters forwarded to speculative_decoding.
top_k = 20
top_p = 0.9
num_tokens = 20  # NOTE(review): presumably the number of tokens to generate — confirm in speculative_decoding
lookahead = 4    # NOTE(review): presumably draft tokens proposed per verification step — confirm

target_model = AutoModelForCausalLM.from_pretrained(
    target_checkpoint,
    cache_dir=HF_CACHE_LOCATION,
    quantization_config=quantization_config,
    device_map="auto",
)
# NOTE(review): assumes every OPT size shares the opt-125m tokenizer (smallest download).
# Tokenizers are not placed on devices, so device_map is meaningless here, and OPT is a
# built-in transformers architecture, so trust_remote_code is not needed.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", cache_dir=HF_CACHE_LOCATION)
print("CUDA MEMORY CURRENT: " + str(torch.cuda.memory_reserved()))

Loading checkpoint shards: 100%|██████████| 7/7 [00:23<00:00,  3.29s/it]

CUDA MEMORY CURRENT: 9177137152





In [3]:
import gc

def evaluate(model, prompt):
    """Time one speculative-decoding run of ``prompt`` using ``model`` as the draft model.

    Reads the module-level ``tokenizer``, ``target_model``, ``num_tokens``,
    ``lookahead``, ``top_k``, ``top_p`` and ``torch_device``.

    Args:
        model: draft model passed to ``speculative_decoding``.
        prompt: prompt text; tokenized and moved to ``torch_device``.

    Returns:
        ``(total_time, output, tar)``: wall-clock seconds spent inside
        ``speculative_decoding``, followed by its two return values
        (``tar`` is presumably the token acceptance rate — confirm).
    """
    on_cuda = torch.cuda.is_available()
    with torch.no_grad():
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device)
        # CUDA kernels execute asynchronously: synchronize around the timed region so the
        # measurement covers the GPU work itself, not just the kernel launches.
        if on_cuda:
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        output, tar = speculative_decoding(input_ids, model, target_model, num_tokens, tokenizer, lookahead, top_k=top_k, top_p=top_p)
        if on_cuda:
            torch.cuda.synchronize()
        total_time = time.perf_counter() - start_time
    # Collect Python garbage and release unused cached CUDA blocks between runs.
    gc.collect()
    torch.cuda.empty_cache()
    return (total_time, output, tar)

In [4]:
import json
from collections import defaultdict
# Load the ShareGPT conversations. The context manager closes the file after parsing.
# An explicit encoding avoids platform-dependent defaults on this non-ASCII text.
with open("../data/ShareGPT_V3_unfiltered_cleaned_split.json", encoding="utf-8") as f:
    data = json.load(f)

In [5]:
# How many ShareGPT conversations are sampled per draft model, and how many leading
# whitespace-separated words of each conversation's first turn form the prompt.
NUM_PROMPTS = 100
PROMPT_WORDS = 8

# results[draft model size] -> list of (conversation index, latency in seconds, tar)
results = defaultdict(list)
for draft_model_name in draft_models:
    draft_model = AutoModelForCausalLM.from_pretrained(
        draft_model_name,
        cache_dir=HF_CACHE_LOCATION,
        device_map="auto",
        quantization_config=quantization_config,
    )
    # Slicing (instead of range(100) + indexing) tolerates a dataset shorter than NUM_PROMPTS.
    for i, conversation in enumerate(data[:NUM_PROMPTS]):
        turns = conversation['conversations']
        if not turns:
            continue
        # Join the words back into plain text; the list's repr ("['w1', 'w2', ...]") is not a prompt.
        prompt = " ".join(turns[0]['value'].split()[:PROMPT_WORDS])
        if not prompt:
            continue
        latency, _output, tar = evaluate(draft_model, prompt)
        results[model_size[draft_model_name]].append((i, latency, tar))
    # Drop the draft model before loading the next so only one draft model is resident at a time.
    del draft_model
    gc.collect()
    torch.cuda.empty_cache()
    print("CUDA MEMORY CURRENT: " + str(torch.cuda.memory_reserved()))
    

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.82s/it]


CUDA MEMORY CURRENT: 9246343168
CUDA MEMORY CURRENT: 9246343168
CUDA MEMORY CURRENT: 9246343168
CUDA MEMORY CURRENT: 9246343168
CUDA MEMORY CURRENT: 9246343168


Loading checkpoint shards: 100%|██████████| 3/3 [00:11<00:00,  3.99s/it]


CUDA MEMORY CURRENT: 9246343168


In [6]:
# Persist the raw per-model results. json turns the integer model-size keys into
# strings and each (index, latency, tar) tuple into a JSON array.
with open("output.json", mode="w") as json_file:
    json.dump(obj=results, fp=json_file)

In [None]:
# Flatten results into one CSV row per (conversation, draft model) measurement.
# Column meanings:
#   prompt      - ShareGPT conversation index (not the prompt text)
#   draft_model - draft model size in millions of parameters (a model_size value)
#   latency     - seconds returned by evaluate
#   tar         - second value returned by speculative_decoding
# newline="" is required by the csv module. lineterminator="\n" keeps the file
# byte-compatible with the previous hand-written "\n"-terminated output.
with open("cleaned_results.csv", "w", newline="") as csv_file:
    writer = csv.writer(csv_file, lineterminator="\n")
    writer.writerow(["prompt", "draft_model", "latency", "tar"])
    for size, datapoints in results.items():
        for prompt_idx, latency, tar in datapoints:
            writer.writerow([prompt_idx, size, latency, tar])