In [1]:
import sys
sys.path.append('..')

from preference_datasets import get_batch_iterator
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import datasets
import matplotlib.pyplot as plt
import random

In [2]:
lora_dir='/scratch/ssumathi/Re-tuning/ReTuning-main/cache/ssumathi/my_scrpad_parity_7b_2024-11-13_15-25-40_989651/LATEST'

model = AutoModelForCausalLM.from_pretrained('huggyllama/llama-7b',torch_dtype=torch.float16,device_map='auto')
model = PeftModel.from_pretrained(model, lora_dir,offload_buffer=True)

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
def generate_from_prompt(model,prompt,tokenizer,max_length,temperature):
    input_tok=tokenizer(prompt,add_special_tokens=False)
    input_ids=torch.LongTensor(input_tok['input_ids']).cuda()
    attention_mask=torch.LongTensor(input_tok['attention_mask']).cuda()
    tokenized_samples = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=2048, 
        do_sample=True, 
        temperature=.6, 
        pad_token_id=tokenizer.pad_token_id
    )
    return tokenizer.batch_decode(tokenized_samples,skip_special_tokens=True)

def generate_binary_list(n):
    binary_list = [random.choice([0, 1]) for _ in range(n)]
    return binary_list


In [5]:
prompt_template='What is the parity of {}?\nSolution: '
res_dict={}
eval_lengths=range(2,60)
for length in eval_lengths:
    num_right=0
    for _ in range(5):
        arr=generate_binary_list(length)
        out=generate_from_prompt(model,[prompt_template.format(arr)],tokenizer,max_length=512,temperature=.01)
        if out[0].split(' ')[-1]==str(arr.count(1)%2):num_right+=1
    res_dict[length]=num_right
    print(res_dict)
print(res_dict)

{31: 3}
{31: 3, 32: 3}
{31: 3, 32: 3, 33: 4}
{31: 3, 32: 3, 33: 4, 34: 3}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2, 38: 3}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2, 38: 3, 39: 4}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2, 38: 3, 39: 4, 40: 2}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2, 38: 3, 39: 4, 40: 2, 41: 2}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2, 38: 3, 39: 4, 40: 2, 41: 2, 42: 3}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2, 38: 3, 39: 4, 40: 2, 41: 2, 42: 3, 43: 2}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2, 38: 3, 39: 4, 40: 2, 41: 2, 42: 3, 43: 2, 44: 2}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2, 38: 3, 39: 4, 40: 2, 41: 2, 42: 3, 43: 2, 44: 2, 45: 2}
{31: 3, 32: 3, 33: 4, 34: 3, 35: 3, 36: 2, 37: 2, 38: 3, 39: 4, 40: 2, 41: 2, 42: 3, 43: 2, 44: 2, 45: 2, 46: 4}
{31: 3, 32: 3, 33: 4, 34: 3, 35: