In [24]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
import torch.nn.functional as F

: 

In [2]:
import random

def generate_multiplication_questions_and_answers(num_range=range(10, 100), num_questions=100, seed=0):
    random.seed(seed)

    number_list = list(num_range)

    questions = []
    answers = []

    for i in range(num_questions):
        num1 = random.choice(number_list)
        num2 = random.choice(number_list)

        question = f"What is {num1} x {num2} = ?"
        answer = num1 * num2

        questions.append(question)
        answers.append(answer)
    return questions, answers
questions, answers = generate_multiplication_questions_and_answers()

In [3]:
# Load the PEFT configuration

name = "cot_small"

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token_id = tokenizer.eos_token_id
base_model_ = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda:0")
peft_model_path = f"../results/20240722/traj_{name}_x0_squad_ep150"
config = PeftConfig.from_pretrained(peft_model_path)

# Load the PEFT model
peft_model = PeftModel.from_pretrained(base_model_, peft_model_path)

with open(f"../data/{name}_x0.md", 'r') as f:
    system = f.read()

print("PEFT model loaded successfully.")
print(system)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

PEFT model loaded successfully.
Let's think step by step.


In [4]:
with open("declaration.txt", 'r') as f:
    declaration = f.read()

In [5]:
import re

single_input_text = lambda system, user: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>{user}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"
second_input_text = lambda prev_convo, user2: f"{prev_convo}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>{user2}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"

@ torch.no_grad()
def simulate_conversation(model, system, user1, user2_list):
    batch_input_text_user1 = [single_input_text(system, user1) for _ in range(len(user2_list))]
    full_input_ids_user1 = tokenizer(batch_input_text_user1, return_tensors="pt").to(peft_model.device)

    user1_output = model.generate(
                    **full_input_ids_user1,
                    max_length=peft_model.config.max_position_embeddings, 
                    num_return_sequences=1, 
                    do_sample=True,
                    temperature=1.0,
                    pad_token_id=tokenizer.eos_token_id,
                )
    
    
    user_1_output_str = tokenizer.batch_decode(user1_output, skip_special_tokens=False)

    batch_input_text_user2 = []
    for x, y in zip(user_1_output_str, user2_list):
        i = [m.start() for m in re.finditer('\<\|eot_id\|\>', x)]
        if len(i) >= 3:
            x = x[:i[2]]
        batch_input_text_user2.append(second_input_text(x, y))

    full_input_ids_user2 = tokenizer(batch_input_text_user2, return_tensors="pt", padding=True).to(peft_model.device)
    user2_output = model.generate(
                    **full_input_ids_user2,
                    max_length=peft_model.config.max_position_embeddings, 
                    num_return_sequences=1, 
                    do_sample=True,
                    temperature=1.0,
                    pad_token_id=tokenizer.eos_token_id,
                )
    return user2_output

In [19]:
system = ""
user1 = ""
user2_list = ["What is 39 x 49 = ?"]
batch_input_text_user1 = [single_input_text(system, user1) for _ in range(len(user2_list))]
full_input_ids_user1 = tokenizer(batch_input_text_user1, return_tensors="pt").to(peft_model.device)
user1_output = peft_model.generate(
                    **full_input_ids_user1,
                    max_length=peft_model.config.max_position_embeddings, 
                    num_return_sequences=1, 
                    do_sample=True,
                    temperature=1.0,
                    pad_token_id=tokenizer.eos_token_id,
                )
    
    
user_1_output_str = tokenizer.batch_decode(user1_output, skip_special_tokens=False)

batch_input_text_user2 = []
for x, y in zip(user_1_output_str, user2_list):
    i = [m.start() for m in re.finditer('\<\|eot_id\|\>', x)]
    x = x[:i[2]]
    batch_input_text_user2.append(second_input_text(x, y))
full_input_ids_user2 = tokenizer(batch_input_text_user2, return_tensors="pt", padding=True).to(peft_model.device)

In [8]:
from tqdm import tqdm
import torch

batch_size = 10

chunk_size = 400

num_correct_original_model_no_system = [0] * (len(declaration)//chunk_size)
num_correct_baked_model_no_system = [0] * (len(declaration)//chunk_size)
num_correct_original_model_with_system = [0] * (len(declaration)//chunk_size)
num_correct_baked_model_with_system = [0] * (len(declaration)//chunk_size)

with torch.no_grad():
    for l, inf_length in enumerate(range(0, len(declaration), chunk_size)):
        print(inf_length)
        print(inf_length/len(declaration))
        print(num_correct_original_model_no_system)
        print(num_correct_baked_model_no_system)
        print(num_correct_original_model_with_system)
        print(num_correct_baked_model_with_system)
        for i in tqdm(range(0, len(questions), batch_size)):
            batch_questions = questions[i:i+batch_size]
            batch_answers = answers[i:i+batch_size]

            # original model no system prompt
            with peft_model.disable_adapter():
                original_model_no_system_output = simulate_conversation(peft_model, "", declaration[:inf_length], batch_questions)
            print("A")
            # baked model no system prompt
            baked_model_no_system_output = simulate_conversation(peft_model, "", declaration[:inf_length], batch_questions)
            print("B")
            # original model with system prompt
            with peft_model.disable_adapter():
                original_model_with_system_output = simulate_conversation(peft_model, system, declaration[:inf_length], batch_questions)
            print("C")
            # baked model with system prompt
            baked_model_with_system_output = simulate_conversation(peft_model, system, declaration[:inf_length], batch_questions)
            print("D")
            
            original_model_no_system_strs = tokenizer.batch_decode(original_model_no_system_output, skip_special_tokens=True)
            baked_model_no_system_strs = tokenizer.batch_decode(baked_model_no_system_output, skip_special_tokens=True)
            original_model_with_system_strs = tokenizer.batch_decode(original_model_with_system_output, skip_special_tokens=True)
            baked_model_with_system_strs = tokenizer.batch_decode(baked_model_with_system_output, skip_special_tokens=True)

            for (answer, original_model_no_system_str, baked_model_no_system_str, original_model_with_system_str, baked_model_with_system_str) in zip(batch_answers, original_model_no_system_strs, baked_model_no_system_strs, original_model_with_system_strs, baked_model_with_system_strs):
                if str(answer) in original_model_no_system_str:
                    num_correct_original_model_no_system[l] += 1

                if str(answer) in baked_model_no_system_str:
                    num_correct_baked_model_no_system[l] += 1

                if str(answer) in original_model_with_system_str:
                    num_correct_original_model_with_system[l] += 1

                if str(answer) in baked_model_with_system_str:
                    num_correct_baked_model_with_system[l] += 1

0
0.0
[0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0]


  0%|          | 0/10 [00:00<?, ?it/s]

A
B
C


 10%|█         | 1/10 [00:15<02:18, 15.37s/it]

D
A
B
C


 20%|██        | 2/10 [00:28<01:53, 14.23s/it]

D
A
B
C


 30%|███       | 3/10 [00:42<01:38, 14.02s/it]

D
A
B
C


 40%|████      | 4/10 [01:00<01:32, 15.37s/it]

D
A
B
C


 50%|█████     | 5/10 [01:13<01:13, 14.78s/it]

D
A
B
C


 60%|██████    | 6/10 [01:28<00:58, 14.66s/it]

D
A
B
C


 70%|███████   | 7/10 [01:44<00:45, 15.15s/it]

D
A
B
C


 80%|████████  | 8/10 [02:00<00:30, 15.32s/it]

D
A
B
C


 80%|████████  | 8/10 [02:10<00:32, 16.30s/it]


KeyboardInterrupt: 

In [9]:
print(num_correct_original_model_no_system)
print(num_correct_baked_model_no_system)
print(num_correct_original_model_with_system)
print(num_correct_baked_model_with_system)

[52, 0, 0, 0, 0, 0, 0, 0]
[64, 0, 0, 0, 0, 0, 0, 0]
[61, 0, 0, 0, 0, 0, 0, 0]
[68, 0, 0, 0, 0, 0, 0, 0]


In [7]:
print(num_correct_original_model_no_system)
print(num_correct_baked_model_no_system)
print(num_correct_original_model_with_system)
print(num_correct_baked_model_with_system)

[49, 0, 0, 0, 0, 0, 0, 0]
[49, 0, 0, 0, 0, 0, 0, 0]
[43, 0, 0, 0, 0, 0, 0, 0]
[51, 0, 0, 0, 0, 0, 0, 0]
