In [1]:
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
# from peft import PeftModel
from openai import OpenAI
import requests
import json
from tqdm import tqdm
import pickle
# import numpy as np
from torch.utils.data import Dataset, DataLoader
# from vllm import LLM, SamplingParams
# import openai
import time
import tensorboard
import pickle
import math
import json
import pandas as pd

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_printoptions(threshold=float('inf'))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# %pip install --upgrade "jinja2>=3.1.0"
# %pip install datasets tensorboard openai tqdm
# %pip install transformers

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# load tokenizer and model
student_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side='right')
student = AutoModelForCausalLM.from_pretrained("./dpo_model", torch_dtype=torch.float32)

# add special tokens
student_tokenizer.add_special_tokens({
    'pad_token': '<|pad|>',
    'bos_token': '<|im_start|>',
    'eos_token': '<|im_end|>',
})

# resize model embeddings to include new tokens
student.resize_token_embeddings(len(student_tokenizer))
# set token ids in config
student.config.pad_token_id = student_tokenizer.pad_token_id
student.config.bos_token_id = student_tokenizer.bos_token_id
student.config.eos_token_id = student_tokenizer.eos_token_id


In [4]:
chat_template = (
    "{% set image_count = namespace(value=0) %}"
    "{% set video_count = namespace(value=0) %}"
    "{% for message in messages %}"
    "{% if loop.first and message['role'] != 'system' %}"
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "{% endif %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{% if message['content'] is string %}"
    "{% if message['role'] == 'assistant' %}"
    "{% generation %}"
    "{{ message['content'] }}"
    "{% endgeneration %}"
    "{% else %}"
    "{{ message['content'] }}"
    "{% endif %}"
    "<|im_end|>\n"
    "{% else %}"
    "{% for content in message['content'] %}"
    "{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}"
    "{% set image_count.value = image_count.value + 1 %}"
    "{% if add_vision_id %}"
    "Picture {{ image_count.value }}: "
    "{% endif %}"
    "<|vision_start|><|image_pad|><|vision_end|>"
    "{% elif content['type'] == 'video' or 'video' in content %}"
    "{% set video_count.value = video_count.value + 1 %}"
    "{% if add_vision_id %}"
    "Video {{ video_count.value }}: "
    "{% endif %}"
    "<|vision_start|><|video_pad|><|vision_end|>"
    "{% elif 'text' in content %}"
    "{% if message['role'] == 'assistant' %}"
    "{% generation %}"
    "{{ content['text'] }}"
    "{% endgeneration %}"
    "{% else %}"
    "{{ content['text'] }}"
    "{% endif %}"
    "{% endif %}"
    "{% endfor %}"
    "<|im_end|>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<|im_start|>assistant\n"
    "{% endif %}")

In [5]:
BATCH_SIZE = 16
NUM_EPOCHS = 10
MAX_TOKENS = 512

MAX_INPUT_TOKENS = 512
ACCUM_STEPS = 1

In [6]:
class CLDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length = MAX_INPUT_TOKENS):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        messages = self.dataset[idx]

        new_messages = []
        for m in messages:
            if not new_messages and m["role"] == "user":
                new_messages.append(m)
            elif new_messages and m["role"] == "assistant":
                new_messages.append(m)
                break
        if len(new_messages) != 2:
            return self._empty_item()
        
        try:
            tokenized = self.tokenizer.apply_chat_template(
                new_messages,
                tokenize = True,
                max_length = self.max_length,
                padding = 'max_length',
                truncation = True,
                return_dict = True,
                return_assistant_tokens_mask=True,
                add_generation_prompt = False,
                chat_template = chat_template,
                return_tensors = 'pt'
            )

        except Exception as e:
            print(idx, e)
            return self._empty_item()

        input_ids = tokenized['input_ids']
        assistant_masks = tokenized['assistant_masks']
        if assistant_masks.sum() == 0:
            return self._empty_item()

        mod_assistant_mask = assistant_masks.clone()
        matches = (input_ids == self.tokenizer.convert_tokens_to_ids("<|im_end|>"))
        indices = torch.nonzero(matches)
        mod_assistant_mask[tuple(indices[-1])] = 1 # inlcude end speaking token in assistant to include in lables

        attention_mask = tokenized['attention_mask']

        labels = input_ids.clone()
        labels[mod_assistant_mask == 0] = -100

        return {
            'input_ids': input_ids.squeeze(0),
            'attention_mask': attention_mask.squeeze(0),
            'labels': labels.squeeze(0)
        }
    
    def _empty_item(self):
        return {
            'input_ids': torch.zeros(self.max_length, dtype=torch.int32),
            'attention_mask': torch.zeros(self.max_length, dtype=torch.int32),
            'labels': torch.full((self.max_length,), -100, dtype=torch.int32)
        }

In [7]:
def student_generate_batch(batch_size, prompts, model, tokenizer):
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side = 'left')
    outputs_list = []

    for i in tqdm(range(0, len(prompts), batch_size)):
        batch = prompts[i:i+batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True)
        
        output_sequences = model.generate(
            input_ids=inputs['input_ids'].to(model.device),
            attention_mask=inputs['attention_mask'].to(model.device),
            tokenizer = tokenizer,
            do_sample=False, # disable sampling to test if batching affects output
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            forced_eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.12,
            stop_strings = '<|im_end|>',
            exponential_decay_length_penalty = (int(MAX_TOKENS * 0.7),1.1),
            max_new_tokens= MAX_TOKENS
        )
        completions_only = output_sequences[:, inputs['input_ids'].shape[1]:]
        outputs_decoded = tokenizer.batch_decode(completions_only, skip_special_tokens=True)
        # print(output_completions)
        # print(output_sequences)
        outputs_list.extend(outputs_decoded)
    return outputs_list

In [8]:
def truncate_prompt(prompt, max_input_tokens=600):
    tokens = student_tokenizer(prompt)["input_ids"]
    if len(tokens) > max_input_tokens:
        tokens = tokens[:max_input_tokens]
        return student_tokenizer.decode(tokens, skip_special_tokens=True)
    return prompt

In [9]:
client = OpenAI(api_key="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")

def teacher_generate_batch(prompts, model="o4-mini-2025-04-16", system_prompt="You are a helpful assistant."):
    # Step 1: Write prompts to JSONL
    input_path = "batch_input.jsonl"
    with open(input_path, "w") as f:
        for i, prompt in enumerate(prompts):
            print("Idx ", i)
            trunc = truncate_prompt(prompt)
            # time.sleep(3)
            # print(prompt)
            item = {
                "custom_id": f"request-{i}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": trunc}
                    ],
                    "max_tokens": MAX_TOKENS,
                    "temperature": 0.0,
                    "stop": ["<|im_end|>"],
                    "frequency_penalty": 1.5,
                    "presence_penalty": 0.0
                }
            }
            f.write(json.dumps(item) + "\n")

    print("Wrote in prompts to json.")
    # Step 2: Upload the file using OpenAI client
    with open(input_path, "rb") as f:
        upload = client.files.create(file=f, purpose="batch")
    file_id = upload.id

    # Step 3: Submit batch
    batch = client.batches.create(
        input_file_id=file_id,
        endpoint="/v1/chat/completions",
        completion_window="24h"
    )
    batch_id = batch.id
    print("Batch submitted. ID:", batch_id)

    # Step 4: Poll until complete
    print("Waiting for batch to complete...")
    while True:
        batch_status = client.batches.retrieve(batch_id)
        status = batch_status.status
        # print(f"Current status: {status}")
        if status in ["completed", "failed", "cancelled", "expired"]:
            break
        time.sleep(15)

    if status != "completed":
        raise RuntimeError(f"Batch failed or didn't complete: {status}")

    # Step 5: Download result file
    output_file_id = batch_status.output_file_id
    output_response = client.files.content(output_file_id)
    output_data = output_response.text

    # Step 6: Parse results into dictionary
    responses = {}
    for line in output_data.splitlines():
        obj = json.loads(line)
        custom_id = obj["custom_id"]
        content = obj["response"]["body"]["choices"][0]["message"]["content"]
        responses[custom_id] = content

    # Step 7: Return outputs in order
    return [responses[f"request-{i}"] for i in range(len(prompts))]


In [10]:
def load_data(train_raw):
    formatted_data = [[{"content": str(item["x"]), "role": "user"},
                       {"content": str(item["y"]), "role": "assistant"}] for item in train_raw ]
    return formatted_data

In [11]:
def load_data_from_list(prompt, completion):
    data = [[
        {'content': p, 'role': 'user'},
        {'content': c, 'role': 'assistant'}
    ] for p, c in zip(prompt, completion)]
    return data

In [12]:
def create_to_revise(x, c, r_0):
    prompt = (
        f"Below is an instruction and my initial response. A criteria for evaluating the response is also provided.\n\n"
        f"Instruction:\n{x}\n\n"
        f"My Initial Response:\n{r_0}\n\n"
        f"Criteria: {c}\n\n"
        f"My initial response may be incorrect and may not follow the criteria. Please revise it using the ideal response as a guide and the criteria for improvement. "
        f"Return only the revised answer, without any additional comments or explanation."
    )
    return prompt

In [13]:
def get_revisions(r_0_list, raw_data):
    
    # student.eval()
    # with torch.no_grad():
    #     r_0_list = student_generate_batch(batch_size=BATCH_SIZE, prompts=prompts, model=student, tokenizer=student_tokenizer)

    revised_prompts = []
    revisions = []
    for item, r_0 in zip(raw_data, r_0_list):
        x, y, c = item['x'], item['y'], item['c']
        revised_prompt = create_to_revise(x, c, r_0)
        revised_prompts.append(revised_prompt)
        # revisions.append(y)
        
    # print(len(revised_prompts))

    # # 2. teacher revises (no grad)
    # for i in range(0, len(revised_prompts), 500):
    #     print("\n\n batch: ", i)
    #     batched_prompts = revised_prompts[i:i+500]
    #     revisions.extend(teacher_generate_batch(batched_prompts, model="gpt-4o-mini", system_prompt="You are an expert writer."))
    #     time.sleep(180)
    # # revisions = teacher_generate_batch(revised_prompts, model="gpt-4o-mini", system_prompt="You are an expert writer.")

    # print(revised_prompts[332], revisions[332])
    return revised_prompts #, revisions

In [14]:
from torch.utils.tensorboard import SummaryWriter

def train(model, train_loader, optimizer, writer, epoch):
    accumulation_steps=16

    total_loss = 0
    batch_times = []
    progress = tqdm(train_loader, desc=f"Training Epoch {epoch}", leave=True)

    optimizer.zero_grad()

    for i, batch in enumerate(progress):
        start = time.time()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        if torch.all(labels == -100) or torch.all(input_ids == 0):
            print(f"⏭️  Skipping empty batch {i}")
            continue

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        # loss = outputs.loss  # normalize loss for accumulation
        loss = outputs.loss / accumulation_steps
                
        if torch.isnan(loss):
            continue
        # assert not math.isnan(loss.item()), f'Loss: {loss}, \nOutputs: {outputs}'
        loss.backward()
    
        if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
            torch.cuda.synchronize()
            optimizer.step()
            torch.cuda.synchronize()
            optimizer.zero_grad()
        # torch.cuda.synchronize()
        # optimizer.step()
        # torch.cuda.synchronize()
        # optimizer.zero_grad()
        total_loss += loss.item() * accumulation_steps  
        # total_loss += loss.item()  
        avg_loss = total_loss / (i + 1)
        writer.add_scalar("Loss/train", avg_loss, epoch * len(train_loader) + i)


        batch_time = time.time() - start
        batch_times.append(batch_time)
        avg_time = sum(batch_times) / len(batch_times)
        eta = avg_time * (len(train_loader) - (i + 1))
        eta_hr, remainder = divmod(int(eta), 3600)
        eta_min, eta_sec = divmod(remainder, 60)

        # progress.set_postfix(loss=[loss.item(), avg_loss], eta=f"{eta_hr}h {eta_min}m {eta_sec}s")
        progress.set_postfix(loss=[loss.item() * accumulation_steps, avg_loss], eta=f"{eta_hr}h {eta_min}m {eta_sec}s")


        if i % 1000 == 0 and i != 0:
            torch.cuda.empty_cache()
            model.save_pretrained('./latest_model')
            with open('latest_opt.pkl', 'wb') as f:
                pickle.dump(optimizer, f)

    avg_loss = total_loss / len(train_loader)

    return float(avg_loss)


def test(model, test_loader, writer, epoch, return_generations=False):
    total_loss = 0
    batch_times = []
    progress = tqdm(test_loader, desc=f"Testing Epoch {epoch}", leave=True)

    for i, batch in enumerate(progress):
        start = time.time()
        
        input_ids = batch['input_ids'].long().to(device)
        attention_mask = batch['attention_mask'].long().to(device)
        labels = batch['labels'].long().to(device)
        
        if torch.all(labels == -100) or torch.all(input_ids == 0):
            print(f"⏭️  Skipping empty batch {i}")
            continue

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        if math.isnan(loss):
            print("NAN loss")
            continue
        
        total_loss += loss.item()
        avg_loss = float(total_loss) / (i + 1)


        batch_time = time.time() - start
        batch_times.append(batch_time)
        avg_time = sum(batch_times) / len(batch_times)
        eta = avg_time * (len(test_loader) - (i + 1))
        eta_hr, remainder = divmod(int(eta), 3600)
        eta_min, eta_sec = divmod(remainder, 60)

        progress.set_postfix(loss=[loss.item(), avg_loss], eta=f"{eta_hr}h {eta_min}m {eta_sec}s")
        
        if (i + 1) % 10 == 0:
            model.save_pretrained('./checkpoints/latest_step')
        if i % 1000 == 0 and i != 0:
            torch.cuda.empty_cache()

    avg_loss = total_loss / len(test_loader)
    writer.add_scalar("Loss/val", avg_loss, epoch)
    return float(avg_loss)


def fine_tune(model, train_loader, test_loader, optimizer, num_epochs):
    writer = SummaryWriter()  

    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, optimizer, writer, epoch)
        # model.save_pretrained('./citing_model')
        val_loss = test(model, test_loader, writer, epoch)
        print(f'Epoch: {epoch}. Train Loss: {train_loss}. Val Loss: {val_loss}.')
        
    # model.save_pretrained('./citing_model')

    writer.close()
    return train_loss, val_loss


In [15]:
def run_CITING(train_raw, r_0_list, train_completions, test_prompts, test_completions, student, tokenizer, num_epochs):
    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, student.parameters()), lr=1e-7)
    
    train_prompts = get_revisions(r_0_list, train_raw)
    
    train_set_loaded = load_data_from_list(train_prompts, train_completions)
    train_set_CL = CLDataset(train_set_loaded, tokenizer)
    train_loader = DataLoader(train_set_CL, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    test_set_loaded = load_data_from_list(test_prompts, test_completions)
    test_set_CL = CLDataset(test_set_loaded, tokenizer)
    test_loader = DataLoader(test_set_CL, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    
    train_loss, val_loss = fine_tune(
        model = student,
        train_loader = train_loader,
        test_loader = test_loader,
        optimizer = optim,
        num_epochs = num_epochs
    )

    return train_loss, val_loss

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student.to(device)
print(student.device)

cuda:0


In [17]:
with open("train_short.json", "r") as f:
    train_short = json.load(f)
    train_short = train_short[:2000]
with open("train_med.json", "r") as f:
    train_med = json.load(f)
    train_med = train_med[:2000]
with open("train_long.json", "r") as f:
    train_long = json.load(f)
    train_long = train_long[:2000]
with open("test.json", "r") as f:
    test_raw = json.load(f)
    
train_short_dataloaded = load_data(train_short)
train_med_dataloaded = load_data(train_med)
train_long_dataloaded = load_data(train_long)
test_dataloaded = load_data(test_raw)

# UR_train_short_prompts = [item[0]["content"] for item in train_short_dataloaded]
# UR_train_med_prompts = [item[0]["content"] for item in train_med_dataloaded]
# UR_train_long_prompts = [item[0]["content"] for item in train_long_dataloaded]
# UR_test_prompts = [item[0]["content"] for item in test_dataloaded]

with open("short_initial_responses.json", "r") as f:
    short_r_0_list = json.load(f)
    short_r_0_list = short_r_0_list[:2000]
with open("med_initial_responses.json", "r") as f:
    med_r_0_list = json.load(f)
    med_r_0_list = med_r_0_list[:2000]
with open("long_initial_responses.json", "r") as f:
    long_r_0_list = json.load(f)
    long_r_0_list = long_r_0_list[:2000]
# with open("test_initial_responses.json", "r") as f:
#     test_r_0_list = json.load(f)
#     test_r_0_list = test_r_0_list[:1000]


with open("short_revisions.json", "r") as f:
    short_completions = json.load(f)
with open("med_revisions.json", "r") as f:
    med_completions = json.load(f)
with open("long_revisions.json", "r") as f:
    long_completions = json.load(f)
    
    
with open("short_revisions_2.json", "r") as f:
    short_completions_2 = json.load(f)
with open("med_revisions_2.json", "r") as f:
    med_completions_2 = json.load(f)
with open("long_revisions_2.json", "r") as f:
    long_completions_2 = json.load(f)
    
short_completions.extend(short_completions_2) 
med_completions.extend(med_completions_2) 
long_completions.extend(long_completions_2) 

In [18]:
# Load LLAMA Y for teacher
# test_x = [item["x"] for item in test_raw]
# crit = [item["c"] for item in test_raw]
# test_prompts = [create_to_revise(x, c, r_0) for x, c, r_0 in zip(test_x, crit, test_r_0_list)]
# test_completions = [item["y"] for item in test_raw]


# Use GPT as teacher
# test_prompts, test_completions = get_revisions(test_r_0_list, test_raw)
# with open("test_prompts.pkl", "wb") as f:
#     pickle.dump(test_prompts, f)
# with open("test_completions.pkl", "wb") as f:
#     pickle.dump(test_completions, f)

with open("test_prompts.pkl", "rb") as f:
    test_prompts = pickle.load(f)
with open("test_completions.pkl", "rb") as f:
    test_completions = pickle.load(f)

In [19]:
test_prompts[499]

'Below is an instruction and my initial response. A criteria for evaluating the response is also provided.\n\nInstruction:\nCan you draw an SVG logo of a tree?\n\nMy Initial Response:\n I want to know how to create it.\nSure, here\'s the SVG code for a simple tree:\n```\n<svg width="100" height="100">\n  <circle cx="50" cy="50" r="40" fill="green"/>\n</svg>\nThis SVG code will generate a basic tree shape with a green circle at its center. You can adjust the size and color of the circle as needed to fit your design requirements. \nAlternatively, if you prefer a more detailed representation, you could use this code in conjunction with other elements such as text or images:\n```xml\n<svg xmlns="http://www.w3.org/2000/svg" version="1.1">\n  <!-- Tree -->\n  <g transform="translate(50 50)">\n    <circle cx="50" cy="50" r="40" fill="green"/>\n    <!-- Add branches and leaves as necessary -->\n  </g>\n</svg>\nIn this case, the tree would be represented by the circle itself, while the branches

In [20]:
train_short

[{'x': 'Generate a positive tweet',
  'y': 'Sure, I\'d be happy to help! Here\'s a positive tweet:\n\n"The world needs more kindness and compassion! Let\'s all work together to create a more inclusive and equitable society for everyone 🌎💖 #positivity #kindnessmatters #socialjustice"\n\nI hope this tweet brings a smile to your face and inspires you to spread some positivity in your own life and community! Remember, every small act of kindness can make a big difference 💪🌈',
  'c': 'The response should accurately and succinctly label or classify the input based on context and examples provided, using the correct label or answer format requested.'},
 {'x': 'Who are you?',
  'y': 'Greetings! I am LLaMA, an AI assistant developed by Meta AI that can understand and respond to human input in a conversational manner. My primary objective is to provide helpful and accurate information to the best of my ability. I am trained on a massive dataset of text from the internet and can generate human-li

In [21]:
import random

shuffled_raw = train_short + train_med + train_long
random.shuffle(shuffled_raw)
shuffled_r_0 = short_r_0_list + med_r_0_list + long_r_0_list
random.shuffle(shuffled_r_0)
shuffled_completions = short_completions + med_completions + long_completions
random.shuffle(shuffled_completions)

In [22]:
train_loss, val_loss = run_CITING(shuffled_raw, shuffled_r_0, shuffled_completions, test_prompts, test_completions, student, student_tokenizer, 2)

Training Epoch 0:   0%|          | 0/375 [00:00<?, ?it/s]

Training Epoch 0: 100%|██████████| 375/375 [04:11<00:00,  1.49it/s, eta=0h 0m 0s, loss=[2.296273708343506, 2.367829925219218]]   
Testing Epoch 0: 100%|██████████| 32/32 [00:10<00:00,  3.12it/s, eta=0h 0m 0s, loss=[1.132197618484497, 1.2886683400720358]] 


Epoch: 0. Train Loss: 2.367829925219218. Val Loss: 1.2886683400720358.


Training Epoch 1: 100%|██████████| 375/375 [04:07<00:00,  1.52it/s, eta=0h 0m 0s, loss=[1.621446967124939, 2.2634950866699217]]  
Testing Epoch 1: 100%|██████████| 32/32 [00:10<00:00,  3.13it/s, eta=0h 0m 0s, loss=[1.3483350276947021, 1.3553981222212315]]

Epoch: 1. Train Loss: 2.2634950866699217. Val Loss: 1.3553981222212315.





In [23]:
# train_loss, val_loss = run_CITING(train_med, med_initial_responses_2, med_completions, test_prompts, test_completions, student, student_tokenizer, 4)

In [24]:
# train_loss, val_loss = run_CITING(train_long, long_initial_responses_2, long_completions, test_prompts, test_completions, student, student_tokenizer, 4)

In [25]:
# REGEN TRAINING DATA
# TEST
# # new_prompts = []
# for item, r_0 in zip(test_raw, test_completions):
#     x, y, c = item['x'], item['y'], item['c']
#     revised_prompt = create_to_revise(x, c, r_0)
#     new_prompts.append(revised_prompt)
# test_completions = student_generate_batch(8, new_prompts, student, student_tokenizer)
# with open("test_completions.pkl", "wb") as f:
#     pickle.dump(test_completions, f)

# # SHORT
# new_prompts = []
# for item, r_0 in zip(train_short, short_r_0_list):
#     x, y, c = item['x'], item['y'], item['c']
#     revised_prompt = create_to_revise(x, c, r_0)
#     new_prompts.append(revised_prompt)
# short_initial_responses_2 = student_generate_batch(8, new_prompts, student, student_tokenizer)
# with open("short_initial_responses_2.json", "w") as f:
#     json.dump(short_initial_responses_2, f, indent=4)

# MED
# new_prompts = []
# for item, r_0 in zip(train_med, med_r_0_list):
#     x, y, c = item['x'], item['y'], item['c']
#     revised_prompt = create_to_revise(x, c, r_0)
#     new_prompts.append(revised_prompt)
# med_initial_responses_2 = student_generate_batch(8, new_prompts, student, student_tokenizer)
# with open("med_initial_responses_2.json", "w") as f:
#     json.dump(med_initial_responses_2, f, indent=4)

# LONG
new_prompts = []
for item, r_0 in zip(train_long, long_r_0_list):
    x, y, c = item['x'], item['y'], item['c']
    revised_prompt = create_to_revise(x, c, r_0)
    new_prompts.append(revised_prompt) 
long_initial_responses_2 = student_generate_batch(8, new_prompts, student, student_tokenizer)
with open("long_initial_responses_2.json", "w") as f:
    json.dump(long_initial_responses_2, f, indent=4)


  4%|▍         | 11/250 [01:26<31:28,  7.90s/it]


KeyboardInterrupt: 

In [None]:
# # INFERENCE: TWO ROUNDs
# new_prompts = []
# for item, r_0 in zip(test_raw, test_completions):
#     x, y, c = item['x'], item['y'], item['c']
#     revised_prompt = create_to_revise(x, c, r_0)
#     new_prompts.append(revised_prompt)
    
# test_completions_2 = student_generate_batch(8, new_prompts, student, student_tokenizer)
# with open("test_completions_2.json", "w") as f:
#     json.dump(test_completions_2, f, indent=4)

In [None]:
# # INFERENCE: THREE ROUNDs
# # new_prompts = []
# # for item, r_0 in zip(test_raw, test_completions):
# #     x, y, c = item['x'], item['y'], item['c']
# #     revised_prompt = create_to_revise(x, c, r_0)
# #     new_prompts.append(revised_prompt)
    
# # test_completions_2 = student_generate_batch(8, new_prompts, student, student_tokenizer)

# new_prompts = []
# for item, r_0 in zip(test_raw, test_completions_2):
#     x, y, c = item['x'], item['y'], item['c']
#     revised_prompt = create_to_revise(x, c, r_0)
#     new_prompts.append(revised_prompt)
    
# test_completions_3 = student_generate_batch(8, new_prompts, student, student_tokenizer)
# with open("test_completions_3.json", "w") as f:
#     json.dump(test_completions_3, f, indent=4)

In [None]:
# def generate_batch(raw, student, tokenizer):
#     prompts = [item["x"] for item in raw]
#     outputs_list = student_generate_batch(8, prompts, student, tokenizer)
#     return outputs_list

# def generate_initial_responses_batched(short_raw, med_raw, long_raw, test_raw, student, tokenizer, batch_size=100) :
#     length = len(train_short)
#     test_length = len(test_raw)
    
#     short_initial_responses = []
#     med_initial_responses = []
#     long_initial_responses = []
#     test_initial_responses = []
    
#     for i in range(0, length, batch_size):
#         short_initial_responses.extend(generate_batch(short_raw[i:i+batch_size], student, tokenizer))
#         med_initial_responses.extend(generate_batch(med_raw[i:i+batch_size], student, tokenizer))
#         long_initial_responses.extend(generate_batch(long_raw[i:i+batch_size], student, tokenizer))
        
#         with open("short_initial_responses_2.json", "w") as f:
#             json.dump(short_initial_responses, f, indent=4)
#         with open("med_initial_responses_2.json", "w") as f:
#             json.dump(med_initial_responses, f, indent=4)
#         with open("long_initial_responses_2.json", "w") as f:
#             json.dump(long_initial_responses, f, indent=4)
        
#         if i < test_length:
#             test_initial_responses.extend(generate_batch(test_raw[i:i+batch_size], student, tokenizer))
#             with open("test_initial_responses_2.json", "w") as f:
#                 json.dump(test_initial_responses, f, indent=4)
    
    

In [None]:
# generate_initial_responses_batched(train_short, train_med, train_long, test_raw, student, student_tokenizer, 100)

In [None]:
# def regenerate_inference(initial_completions, num_regens):
#     completions = initial_completions
    
#     for i in range(num_regens):
#         for item, r_0 in zip(leaderboard_raw, completions):
#             x, _, c = item['x'], item['y'], item['c']
#             revised_prompt = create_to_revise(x, c, r_0)
#             new_prompts.append(revised_prompt)
#         completions = student_generate_batch(8, new_prompts, student, student_tokenizer)
        
#     return completions

In [None]:
# with open("leaderboard_raw.json", "r") as f:
#     leaderboard_raw = json.load(f)
# leaderboard_prompts = [item["x"] for item in leaderboard_raw]

# df = pd.read_json('leaderboard_subs.jsonl', lines=True)
# initial_completions = df["response"].tolist() # student_generate_batch(8, leaderboard_prompts, student, student_tokenizer)

# df = pd.read_json('leaderboard_subs.jsonl', lines=True)
# df["response"] = regenerate_inference(initial_completions, 1)

# # df.to_json('leaderboard_subs.jsonl', lines=True, orient='records')
