In [1]:
import sys
sys.path.append('..')

from preference_datasets import get_batch_iterator
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import datasets
import matplotlib.pyplot as plt
import random

In [2]:
lora_dir='/scratch/ssumathi/Re-tuning/ReTuning-main/cache/ssumathi/my_baseline_parity_13b_2024-11-12_19-58-07_057749/LATEST'

model = AutoModelForCausalLM.from_pretrained('huggyllama/llama-13b',torch_dtype=torch.float16,device_map='auto')
model = PeftModel.from_pretrained(model, lora_dir,offload_buffer=True)

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-13b")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
def generate_from_prompt(model,prompt,tokenizer):
    input_tok=tokenizer(prompt,add_special_tokens=False)
    input_ids=torch.LongTensor(input_tok['input_ids']).cuda()
    attention_mask=torch.LongTensor(input_tok['attention_mask']).cuda()
    tokenized_samples = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=2048, 
        do_sample=False, 
        pad_token_id=tokenizer.pad_token_id
    )
    return tokenizer.batch_decode(tokenized_samples,skip_special_tokens=True)

def generate_binary_list(n):
    binary_list = [random.choice([0, 1]) for _ in range(n)]
    return binary_list

In [4]:
prompt_template='What is the parity of {}?\nSolution: '
res_dict={}
eval_lengths=range(2,60)
for length in eval_lengths:
    print(f"evaluating length:{length}")
    num_right=0
    for _ in range(15):
        arr=generate_binary_list(length)
        out=generate_from_prompt(model,[prompt_template.format(arr)],tokenizer)
        if out[0].rstrip()[-1]==str(arr.count(1)%2):num_right+=1
    res_dict[length]=num_right
    print(res_dict)
print(res_dict)

evaluating length:2
{2: 6}
evaluating length:3
{2: 6, 3: 8}
evaluating length:4
{2: 6, 3: 8, 4: 7}
evaluating length:5
{2: 6, 3: 8, 4: 7, 5: 8}
evaluating length:6
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7}
evaluating length:7
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9}
evaluating length:8
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8}
evaluating length:9
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8}
evaluating length:10
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8, 10: 12}
evaluating length:11
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8, 10: 12, 11: 11}
evaluating length:12
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8, 10: 12, 11: 11, 12: 6}
evaluating length:13
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8, 10: 12, 11: 11, 12: 6, 13: 7}
evaluating length:14
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8, 10: 12, 11: 11, 12: 6, 13: 7, 14: 2}
evaluating length:15
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8, 10: 12, 11: 11, 12: 6, 13: 7, 14: 2, 15: 9}
evaluating length:16
{2: 6, 3: 8, 4: 7

{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8, 10: 12, 11: 11, 12: 6, 13: 7, 14: 2, 15: 9, 16: 6, 17: 13, 18: 12, 19: 9, 20: 6, 21: 9, 22: 6, 23: 3, 24: 5, 25: 7, 26: 7, 27: 5, 28: 8, 29: 8, 30: 7, 31: 9, 32: 4, 33: 6, 34: 6, 35: 8, 36: 6, 37: 7, 38: 8, 39: 7, 40: 9, 41: 5, 42: 9, 43: 7, 44: 4, 45: 7, 46: 5, 47: 7, 48: 7}
evaluating length:49
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8, 10: 12, 11: 11, 12: 6, 13: 7, 14: 2, 15: 9, 16: 6, 17: 13, 18: 12, 19: 9, 20: 6, 21: 9, 22: 6, 23: 3, 24: 5, 25: 7, 26: 7, 27: 5, 28: 8, 29: 8, 30: 7, 31: 9, 32: 4, 33: 6, 34: 6, 35: 8, 36: 6, 37: 7, 38: 8, 39: 7, 40: 9, 41: 5, 42: 9, 43: 7, 44: 4, 45: 7, 46: 5, 47: 7, 48: 7, 49: 9}
evaluating length:50
{2: 6, 3: 8, 4: 7, 5: 8, 6: 7, 7: 9, 8: 8, 9: 8, 10: 12, 11: 11, 12: 6, 13: 7, 14: 2, 15: 9, 16: 6, 17: 13, 18: 12, 19: 9, 20: 6, 21: 9, 22: 6, 23: 3, 24: 5, 25: 7, 26: 7, 27: 5, 28: 8, 29: 8, 30: 7, 31: 9, 32: 4, 33: 6, 34: 6, 35: 8, 36: 6, 37: 7, 38: 8, 39: 7, 40: 9, 41: 5, 42: 9, 43: 7, 44: 4, 4