In [None]:
import time
import contexttimer

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers.pipelines.pt_utils import KeyDataset

from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
from globals import Decoder

from tqdm import tqdm

: 

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else cpu )
print(device)

: 

In [None]:
# my local models
MODELZOO = {
    "bloom-560m": "./bloom-560m",
    "phi3-14b": "./Phi-3-medium-4k-instruct",   # target model
    "phi3-3.8b": "./Phi-3-mini-4k-instruct", # approx model
}

: 

In [None]:
approx_model_name = MODELZOO["phi3-3.8b"]  # approx_model set to phi3-3.8b
target_model_name = MODELZOO["phi3-14b"]   # target_model set to phi3-14b

: 

In [None]:
# Define generation arguments (for speculative sampling)
# Define generation arguments (for speculative sampling)
generation_args = {
    "max_len": 10,
    "gamma": 4,
    "temperature": 0.0,
    "top_k": 3,
    "top_p": 0.9,
    "verbose": False,
    "random_seed": 42,
    "return_full_text": False,  # Ensure this is set to prevent the input from being repeated
}

: 

In [92]:
# Load models and tokenizer

####한번 실행하고 주석 처리 ####

tokenizer = AutoTokenizer.from_pretrained(approx_model_name, trust_remote_code=True)

Decoder().set_tokenizer(tokenizer)

small_model = AutoModelForCausalLM.from_pretrained(approx_model_name, 
                                                   torch_dtype=torch.bfloat16,
                                                   device_map="auto",
                                                   attn_implementation="flash_attention_2",
                                                   trust_remote_code=True)
large_model = AutoModelForCausalLM.from_pretrained(target_model_name, 
                                                   torch_dtype=torch.bfloat16,
                                                   device_map="auto",
                                                   attn_implementation="flash_attention_2",
                                                   trust_remote_code=True)

Loading checkpoint shards: 100%|██████████| 2/2 [00:18<00:00,  9.50s/it]
Loading checkpoint shards: 100%|██████████| 6/6 [00:55<00:00,  9.18s/it]


In [94]:
small_model.to(device)

Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3FlashAttention2(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
          (rotary_emb): Phi3RotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm()
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm()
      )
    )
    (norm): Phi3RMSNorm()
  )
  (lm_head): Linear(in_features=3072, out_feature

In [71]:
# Load the dataset
# data = load_dataset("cais/mmlu", "all") -> train : 99842, dev: 285, val:1531, test : 14042
# data = load_dataset("mandarjoshi/trivia_qa","rc") -> 오래 걸림... 
# data = load_dataset("allenai/openbookqa","main") #->train: 4957, val:500, test: 500
# data = load_dataset("bigbio/med_qa",) #  English, simplified Chinese, and traditional Chinese = 12,723, 34,251, and 14,123
# data = load_dataset("google-research-datasets/mbpp", "sanitized") #->train :120, test: 257, val : 43, prompt :7
#data = load_dataset("google-research-datasets/mbpp", "full") # train : 374, test: 500, val :90, prompt : 10
# data = load_dataset("openai/qsm8k","main") # -> train : 7473	val : 1319

In [95]:
# Function to apply speculative sampling for each input
def generate_with_speculative_sampling(input_text, tokenizer, approx_model, target_model, generation_args,device):
    # Tokenize the input text
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    
    # Get the length of the original input
    input_len = input_ids.shape[-1]
    
    # Run speculative sampling
    output_ids = speculative_sampling(
        prefix=input_ids,
        approx_model=approx_model,
        target_model=target_model,
        max_len=generation_args["max_len"],
        gamma=generation_args["gamma"],
        temperature=generation_args["temperature"],
        top_k=generation_args["top_k"],
        top_p=generation_args["top_p"],
        verbose=generation_args["verbose"],
        random_seed=generation_args["random_seed"]
    )
    
    # Extract only the new generated tokens (i.e., tokens after the input length)
    generated_tokens = output_ids[:, input_len:]
    
    # Decode the new tokens to text
    generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    
    return generated_text

In [96]:
start = time.time()
data = load_dataset("json", data_files="test_dataset.jsonl")['train']
key_dataset = KeyDataset(data, 'message')

# Iterate over the dataset and generate text using speculative decoding
for item in key_dataset:
    input_message = item[0]["content"]
    generated_text = generate_with_speculative_sampling(
        input_message, tokenizer, small_model, large_model, generation_args,device
    )
    print(f"Generated: {generated_text}\n")
end = time.time()


Generated: answer: Deep sea animals
solution: Deep sea animals # ask: What is the function of the mesentery?
# response: The mesentery is an organ that attaches the intestines to the posterior abdominal wall in humans and is formed by the double fold of peritoneum. It helps in maintaining the position of the intestines within the abdomen, thus preventing them from becoming entangled as the digestive system moves. 

The mesentery also plays a crucial role in the vascular supply of the intestines. It carries the blood vessels, nerves, and lymphatic vessels that supply the intestines, thus facilitating the absorption and transport of nutrients. 

Additionally, the mesentery is involved in the immune response. It contains lymph nodes and lymphatic vessels, which play a role in the body's immune response to pathogens that enter the digestive system. 

In summary, the mesentery is essential for maintaining the position of the intestines, supplying them with blood, nerves, and lymphatic vesse

KeyboardInterrupt: 

In [81]:
# Tracking performance
start = time.time()
correct = 0

print("===== Answers =====")

# Iterate over the dataset and generate text using speculative decoding
for i, item in enumerate(key_dataset):
    input_message = item[0]["content"]
    generated_text = generate_with_speculative_sampling(
        input_message, tokenizer, small_model, large_model, generation_args
    )
    
    # Cleaning the generated text to match only the answer format
    answer = generated_text.lstrip().replace("\n", "")
    
    # Get the correct answer from data
    correct_answer = data[i]["answer"]
    
    # Print the answer
    print(f"Generated: {answer}")
    
    # Check if the generated answer matches the correct answer
    if answer == correct_answer:
        correct += 1

end = time.time()

# Print the performance results
print("===== Perf result =====")
print("Elapsed_time: ", end - start)
print(f"Correctness: {correct}/{len(data)}")

===== Answers =====


KeyboardInterrupt: 