In [None]:
import sys
sys.path.append('..')

from preference_datasets import get_batch_iterator
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import datasets
import matplotlib.pyplot as plt
import random

In [None]:
# Load the LLaMA-13B base model in fp16, then apply the fine-tuned LoRA adapter on top.
# The base-model id is defined once so the tokenizer always matches the model weights.
BASE_MODEL = 'huggyllama/llama-13b'
# NOTE(review): absolute, machine-specific path to the adapter checkpoint; adjust per machine.
lora_dir = '/scratch/ssumathi/Re-tuning/ReTuning-main/cache/ssumathi/my_baseline_add_13b_2024-11-12_20-54-13_323235/LATEST'

model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16, device_map='auto')
model = PeftModel.from_pretrained(model, lora_dir)
model.eval()  # inference only: make sure dropout is disabled (harmless if already in eval mode)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# generate_from_prompt forwards tokenizer.pad_token_id to generate(); fall back to EOS
# when the tokenizer does not define a pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
def generate_from_prompt(model, prompt, tokenizer, max_length=2048):
    """Greedily decode continuations of ``prompt`` and return the decoded texts.

    Args:
        model: a causal LM exposing ``generate`` and ``device``.
        prompt: a list of prompt strings (callers pass a single-element list).
            The prompts are turned into one tensor without padding, so several
            prompts must tokenize to the same length.
        tokenizer: the tokenizer matching ``model``. It is called without special
            tokens, and its ``pad_token_id`` is forwarded to ``generate``.
        max_length: cap on the TOTAL sequence length (prompt + generated tokens)
            passed to ``generate``. Defaults to the previous hard-coded 2048.

    Returns:
        list[str]: one string per prompt, containing the prompt followed by the
        generated continuation, with special tokens removed.
    """
    encoded = tokenizer(prompt, add_special_tokens=False)
    # With device_map='auto' the model may not live on cuda:0 (or may be on CPU);
    # place the inputs on the model's own device instead of hard-coding .cuda().
    device = model.device
    input_ids = torch.tensor(encoded['input_ids'], dtype=torch.long, device=device)
    attention_mask = torch.tensor(encoded['attention_mask'], dtype=torch.long, device=device)
    generated = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        do_sample=False,  # greedy decoding: deterministic evaluation
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)

In [None]:
# Greedy-decoding accuracy of the fine-tuned model on random n-digit + n-digit addition.
# res_dict maps operand length (in digits) -> number of correct answers out of NUM_SAMPLES.
SEED = 0            # fixed so the sampled problems are reproducible across runs
NUM_SAMPLES = 100   # random problems evaluated per operand length
eval_lengths = range(2, 60)

rng = random.Random(SEED)  # private RNG: does not depend on / disturb global random state
res_dict = {}

for num_digits in eval_lengths:
    print(f"Evaluating for length: {num_digits} digits")
    # Inclusive bounds of the integers that have exactly num_digits digits.
    low, high = 10 ** (num_digits - 1), 10 ** num_digits - 1
    num_right = 0
    for _ in range(NUM_SAMPLES):
        a = rng.randint(low, high)
        b = rng.randint(low, high)
        sample = generate_from_prompt(model, [f'{a} + {b}\nSolution: '], tokenizer)
        # The decoded text starts with the prompt. Take what follows the FIRST
        # "\nSolution: " (the one we wrote), not the last one: if greedy decoding runs on
        # into another problem, the last "Solution:" belongs to that problem.
        completion = sample[0].split('\nSolution: ', 1)[-1]
        # Only the first line is this problem's answer.
        answer_lines = completion.strip().splitlines()
        predicted_sum = answer_lines[0].strip() if answer_lines else ''
        if predicted_sum == str(a + b):
            num_right += 1

    res_dict[num_digits] = num_right
    print(f"  {num_right}/{NUM_SAMPLES} correct")
print("Final Results:", res_dict)


In [None]:
# Accuracy vs. operand length for the baseline LoRA model.
# NOTE(review): these counts are hard-coded (presumably copied from the "Final Results"
# printout of the evaluation cell) so the figure renders without re-running the GPU
# evaluation; substitute `res_dict` to plot a fresh run.
samples_per_length = 100  # problems evaluated per length; the counts below are out of this

results = {2: 100, 3: 100, 4: 100, 5: 99, 6: 99, 7: 100, 8: 99, 9: 99, 10: 100, 11: 99, 12: 99, 
           13: 99, 14: 98, 15: 96, 16: 88, 17: 41, 18: 27, 19: 12, 20: 9, 21: 2, 22: 0, 23: 0, 
           24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0, 
           36: 0, 37: 0, 38: 0, 39: 0, 40: 0, 41: 0, 42: 0, 43: 0, 44: 0, 45: 0, 46: 0, 47: 0, 
           48: 0, 49: 0, 50: 0, 51: 0, 52: 0, 53: 0, 54: 0, 55: 0, 56: 0, 57: 0, 58: 0, 59: 0}

lengths = list(results.keys())
accuracies = [count / samples_per_length for count in results.values()]

# matplotlib.pyplot is already imported as `plt` in the imports cell.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(lengths, accuracies, marker='o', linestyle='-')
ax.set(
    xlabel="Digits per operand",
    ylabel="Accuracy",
    ylim=(0.0, 1.0),
    title="Baseline addition with Llama 13B",
)
ax.grid(True)
plt.show()
