In [1]:
import os
import time

import numpy as np
import torch

from datasets import load_dataset
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer, 
                          Phi3ForCausalLM,
                          Phi3Config,
                          pipeline,
                          )
from transformers.pipelines.pt_utils import KeyDataset

import utils

config = utils.get_config()
utils.seed_everything(config.seed)
CURR_PATH = os.getcwd()

In [2]:
# class Phi3LLMLinguaForCausalLM(Phi3ForCausalLM):
#     def __init__(self, config:Phi3Config, tokenizer, compressor):
#         super().__init__(config)
#         self.tokenizer = tokenizer
#         self.compressor = compressor
    
#     def generate(
#         self,
#         inputs: Optional[torch.Tensor] = None,
#         generation_config: Optional[GenerationConfig] = None,
#         logits_processor: Optional[LogitsProcessorList] = None,
#         stopping_criteria: Optional[StoppingCriteriaList] = None,
#         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
#         synced_gpus: Optional[bool] = None,
#         assistant_model: Optional["PreTrainedModel"] = None,
#         streamer: Optional["BaseStreamer"] = None,
#         negative_prompt_ids: Optional[torch.Tensor] = None,
#         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
#         **kwargs,
# )
    

In [3]:
####### Section 1. Set up #######
model_id = os.path.join(CURR_PATH, config.model_path, "Phi-3-medium-4k-instruct") # please replace with local model path

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
    # attn_implementation="flash_attention_2" # Running flash attention
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # batch_size=,
)
 
generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,
    "temperature": 0.0,
    "do_sample": False,
}

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [4]:
####### Section 2. GPU Warm up #######
messages = [
    {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
    {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
    {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
]
output = pipe(messages, **generation_args)
print(output[0]['generated_text'])

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
You are not running the flash-attention implementation, expect numerical differences.


 To solve the equation 2x + 3 = 7, you need to isolate the variable x. Here's how you can do it step by step:

1. Subtract 3 from both sides of the equation:
   2x + 3 - 3 = 7 - 3
   2x = 4

2. Divide both sides of the equation by 2:
   2x / 2 = 4 / 2
   x = 2

So, the solution to the equation 2x + 3 = 7 is x = 2.


In [5]:
####### Section 3. Prompt Compression #######
import json
import re
from llmlingua import PromptCompressor

COMPRESSOR_MAP = {
   "gpt2" : "lgaalves/gpt2-dolly",
   "phi2": "microsoft/phi-2",
   "llama2": "TheBloke/Llama-2-7b-Chat-GPTQ",
   "roberta" : "microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
}
compressor_id = COMPRESSOR_MAP["roberta"]

llm_lingua = PromptCompressor(compressor_id, device_map="auto", use_llmlingua2=True)

# def compress(content):
#     data = content.split('\n\n')
#     instruction = data[0]
#     question = data[1].split('\n')[0]
#     choice = data[1].split('\n')[1]
#     compressed_prompt = llm_lingua.compress_prompt(
#         context=[instruction, question],
#         instruction="",
#         question="",
#         rate=0.75,
#         iterative_size=100,
#         context_budget="*2",
#     )
#     return compressed_prompt['compressed_prompt'] + '\n' + choice

def compress(content):
    data = content.split('\n\n')
    instruction = '\n' + data[0]
    ddata = data[1].split('\n')
    question = ddata[0]
    choice = ddata[1]
    
    compressed_prompt = llm_lingua.compress_prompt(
        context="\n".join([instruction, question]),
        rate=0.7,
        force_tokens=["\n", "?", ".", "," "question:"],
    )

    compressed_prompt = re.sub(r'\n\s+', '\n', compressed_prompt['compressed_prompt'])
    # .replace('question:', '\nquestion:')
    return f"{compressed_prompt}\n{choice}\n"
    
def compressed_jsonl(input_file_path, output_file_path):
    with open(input_file_path, 'r') as reader:
        with open(output_file_path, 'w') as writer:
            for line in reader:
                line = json.loads(line)
                line["message"][0]["content"] = compress(line["message"][0]["content"])
                
                json.dump(line, writer)
                writer.write("\n")

input_file_path = os.path.join(CURR_PATH, config.data_path, "test_dataset.jsonl")
output_file_path = os.path.join(CURR_PATH, config.data_path, "test_dataset_compressed.jsonl")
compressed_jsonl(input_file_path, output_file_path)
# data.map(lambda x: print(x['message'][0]['content']))

In [6]:
####### Section 3. Load data and Inference -> Performance evaluation part #######
start = time.time()
# data = load_dataset("json", data_files=input_file_path)['train']
data = load_dataset("json", data_files=output_file_path)['train']
outs = pipe(KeyDataset(data, 'message'), **generation_args)
end = time.time()

Generating train split: 0 examples [00:00, ? examples/s]

In [7]:
####### Section 4. Accuracy (Just for leasderboard) #######
print("===== Answers =====")
correct = 0
for i, out in enumerate(outs):
    correct_answer = data[i]["answer"]
    answer = out[0]["generated_text"].lstrip().replace("\n","")
    if answer == correct_answer:
        correct += 1
    print(answer)
 
print("===== Perf result =====")
print("Elapsed_time: ", end-start)
print(f"Correctness: {correct}/{len(data)}")

===== Answers =====
Deep sea animals
uses what it needs
they are genetically called to
south
A snail moving across the sidewalk
protozoa
Green house
it unfreezes, because it is cold-blooded
It is a sphere
fluid spreads from pores
July
speaking with a witness
shell
the final barrel is gone, there supply is finished
particles of iron
H2O haze
constellations to appear in one place in spring and another in fall
glucose
help prevent the effects of erosion
wind
salvage plastic bottles instead of throwing them away
less energy used by the water heater
people driving cars might be unaware the animal is close by
light from our closest star
the darkness is greatest
Water
clothing
a cut peony
an alligator's habitat
has seeds outside the flesh, unlike the blueberry
===== Perf result =====
Elapsed_time:  9.633576154708862
Correctness: 24/30
