In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from peft import PeftModel, PeftConfig
from util import task_to_prompt

In [2]:
with open("prompt-template.txt") as f:
    template = f.read()

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

In [4]:
peft_model_id = "finetuned-models/Starling-LM-7B-alpha-finetuned"

config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, 
                                                    quantization_config=bnb_config, 
                                                    return_dict=True, 
                                                    load_in_4bit=True, 
                                                    device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Load the Lora model
model = PeftModel.from_pretrained(model, peft_model_id)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [12]:
def generate_response(prompt, max_new_tokens=512, temperature=0.5):
    input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
    print("token length:", len(input_ids["input_ids"][0]))
    outputs = model.generate(**input_ids, do_sample=True, temperature=temperature, pad_token_id=tokenizer.eos_token_id, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [6]:
prompt, output = task_to_prompt("data/evaluation/1a2e2828.json")

In [7]:
prompt = f"GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:"

In [14]:
print(generate_response(prompt, max_length=2048, temperature=0))

Both `max_new_tokens` (=1024) and `max_length`(=2048) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


token length: 1917


ValueError: `temperature` (=0) has to be a strictly positive float, otherwise your next token scores will be invalid.

In [13]:
print(output)

[[7]]
