In [1]:
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
# from peft import PeftModel
from openai import OpenAI
import requests
import json
from tqdm import tqdm
import pickle
# import numpy as np
from torch.utils.data import Dataset, DataLoader
# from vllm import LLM, SamplingParams
# import openai
import time
import tensorboard
import pickle
import math
import json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_printoptions(threshold=float('inf'))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# %pip install --upgrade "jinja2>=3.1.0"

In [3]:
# Load in Prompts and Criteria and Teacher responses (Original UltraFeedback Llama responses)

with open("train_short.json", "r") as f:
    train_short = json.load(f)
with open("train_med.json", "r") as f:
    train_med = json.load(f)
with open("train_long.json", "r") as f:
    train_long = json.load(f)
with open("test.json", "r") as f:
    test_raw = json.load(f)

In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# load tokenizer and model
student_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side='right')
student = AutoModelForCausalLM.from_pretrained("./dpo_model", torch_dtype=torch.float32)

# add special tokens
student_tokenizer.add_special_tokens({
    'pad_token': '<|pad|>',
    'bos_token': '<|im_start|>',
    'eos_token': '<|im_end|>',
})

# resize model embeddings to include new tokens
student.resize_token_embeddings(len(student_tokenizer))
# set token ids in config
student.config.pad_token_id = student_tokenizer.pad_token_id
student.config.bos_token_id = student_tokenizer.bos_token_id
student.config.eos_token_id = student_tokenizer.eos_token_id


2025-06-08 01:22:15.013875: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749345735.023498   51474 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749345735.028613   51474 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
chat_template = (
    "{% set image_count = namespace(value=0) %}"
    "{% set video_count = namespace(value=0) %}"
    "{% for message in messages %}"
    "{% if loop.first and message['role'] != 'system' %}"
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "{% endif %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{% if message['content'] is string %}"
    "{% if message['role'] == 'assistant' %}"
    "{% generation %}"
    "{{ message['content'] }}"
    "{% endgeneration %}"
    "{% else %}"
    "{{ message['content'] }}"
    "{% endif %}"
    "<|im_end|>\n"
    "{% else %}"
    "{% for content in message['content'] %}"
    "{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}"
    "{% set image_count.value = image_count.value + 1 %}"
    "{% if add_vision_id %}"
    "Picture {{ image_count.value }}: "
    "{% endif %}"
    "<|vision_start|><|image_pad|><|vision_end|>"
    "{% elif content['type'] == 'video' or 'video' in content %}"
    "{% set video_count.value = video_count.value + 1 %}"
    "{% if add_vision_id %}"
    "Video {{ video_count.value }}: "
    "{% endif %}"
    "<|vision_start|><|video_pad|><|vision_end|>"
    "{% elif 'text' in content %}"
    "{% if message['role'] == 'assistant' %}"
    "{% generation %}"
    "{{ content['text'] }}"
    "{% endgeneration %}"
    "{% else %}"
    "{{ content['text'] }}"
    "{% endif %}"
    "{% endif %}"
    "{% endfor %}"
    "<|im_end|>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<|im_start|>assistant\n"
    "{% endif %}")

In [6]:
BATCH_SIZE = 32
NUM_EPOCHS = 3
MAX_TOKENS = 512

MAX_INPUT_TOKENS = 512
ACCUM_STEPS = 1

In [7]:
class CLDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length = MAX_INPUT_TOKENS):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        messages = self.dataset[idx]

        new_messages = []
        for m in messages:
            if not new_messages and m["role"] == "user":
                new_messages.append(m)
            elif new_messages and m["role"] == "assistant":
                new_messages.append(m)
                break
        if len(new_messages) != 2:
            return self._empty_item()
        
        try:
            tokenized = self.tokenizer.apply_chat_template(
                new_messages,
                tokenize = True,
                max_length = self.max_length,
                padding = 'max_length',
                truncation = True,
                return_dict = True,
                return_assistant_tokens_mask=True,
                add_generation_prompt = False,
                chat_template = chat_template,
                return_tensors = 'pt'
            )

        except Exception as e:
            print(idx, e)
            return self._empty_item()

        input_ids = tokenized['input_ids']
        assistant_masks = tokenized['assistant_masks']
        if assistant_masks.sum() == 0:
            return self._empty_item()

        mod_assistant_mask = assistant_masks.clone()
        matches = (input_ids == self.tokenizer.convert_tokens_to_ids("<|im_end|>"))
        indices = torch.nonzero(matches)
        mod_assistant_mask[tuple(indices[-1])] = 1 # inlcude end speaking token in assistant to include in lables

        attention_mask = tokenized['attention_mask']

        labels = input_ids.clone()
        labels[mod_assistant_mask == 0] = -100

        return {
            'input_ids': input_ids.squeeze(0),
            'attention_mask': attention_mask.squeeze(0),
            'labels': labels.squeeze(0)
        }
    
    def _empty_item(self):
        return {
            'input_ids': torch.zeros(self.max_length, dtype=torch.int32),
            'attention_mask': torch.zeros(self.max_length, dtype=torch.int32),
            'labels': torch.full((self.max_length,), -100, dtype=torch.int32)
        }

In [8]:
def student_generate_batch(batch_size, prompts, model, tokenizer):
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side = 'left')
    outputs_list = []

    for i in tqdm(range(0, len(prompts), batch_size)):
        batch = prompts[i:i+batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True)
        
        output_sequences = model.generate(
            input_ids=inputs['input_ids'].to(model.device),
            attention_mask=inputs['attention_mask'].to(model.device),
            tokenizer = tokenizer,
            do_sample=False, # disable sampling to test if batching affects output
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            forced_eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.5,
            stop_strings = '<|im_end|>',
            exponential_decay_length_penalty = (int(MAX_TOKENS * 0.7),1.1),
            max_new_tokens= MAX_TOKENS
        )
        completions_only = output_sequences[:, inputs['input_ids'].shape[1]:]
        outputs_decoded = tokenizer.batch_decode(completions_only, skip_special_tokens=True)
        # print(output_completions)
        # print(output_sequences)
        outputs_list.extend(outputs_decoded)
    return outputs_list

In [9]:
client = OpenAI(api_key="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")

def teacher_generate_batch(prompts, model="gpt-4", system_prompt="You are a helpful assistant."):
    # Step 1: Write prompts to JSONL
    input_path = "batch_input.jsonl"
    with open(input_path, "w") as f:
        for i, prompt in enumerate(prompts):
            item = {
                "custom_id": f"request-{i}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": prompt}
                    ],
                    "max_tokens": MAX_TOKENS,
                    "temperature": 0.0,
                    "stop": ["<|im_end|>"],
                    "frequency_penalty": 1.5,
                    "presence_penalty": 0.0
                }
            }
            f.write(json.dumps(item) + "\n")

    # Step 2: Upload the file using OpenAI client
    with open(input_path, "rb") as f:
        upload = client.files.create(file=f, purpose="batch")
    file_id = upload.id

    # Step 3: Submit batch
    batch = client.batches.create(
        input_file_id=file_id,
        endpoint="/v1/chat/completions",
        completion_window="24h"
    )
    batch_id = batch.id
    # print("Batch submitted. ID:", batch_id)

    # Step 4: Poll until complete
    # print("Waiting for batch to complete...")
    while True:
        batch_status = client.batches.retrieve(batch_id)
        status = batch_status.status
        # print(f"Current status: {status}")
        if status in ["completed", "failed", "cancelled", "expired"]:
            break
        time.sleep(15)

    if status != "completed":
        raise RuntimeError(f"Batch failed or didn't complete: {status}")

    # Step 5: Download result file
    output_file_id = batch_status.output_file_id
    output_response = client.files.content(output_file_id)
    output_data = output_response.text

    # Step 6: Parse results into dictionary
    responses = {}
    for line in output_data.splitlines():
        obj = json.loads(line)
        custom_id = obj["custom_id"]
        content = obj["response"]["body"]["choices"][0]["message"]["content"]
        responses[custom_id] = content

    # Step 7: Return outputs in order
    return [responses[f"request-{i}"] for i in range(len(prompts))]


In [10]:
def load_data(train_raw):
    formatted_data = [[{"content": str(item["x"]), "role": "user"},
                       {"content": str(item["y"]), "role": "assistant"}] for item in train_raw ]
    return formatted_data

In [11]:
def load_data_from_list(prompt, completion):
    data = [[
        {'content': p, 'role': 'user'},
        {'content': c, 'role': 'assistant'}
    ] for p, c in zip(prompt, completion)]
    return data

In [12]:
def create_to_revise(x, c, r_0):
    prompt = (
        f"Below is an instruction and my initial response. A criteria for evaluating the response is also provided.\n\n"
        f"Instruction:\n{x}\n\n"
        f"My Initial Response:\n{r_0}\n\n"
        f"Criteria: {c}\n\n"
        f"My initial response may be incorrect and may not follow the criteria. Please revise it using the ideal response as a guide and the criteria for improvement. "
        f"Return only the revised answer, without any additional comments or explanation."
    )
    return prompt

In [13]:
def get_revisions(r_0_list, raw_data):
    
    # student.eval()
    # with torch.no_grad():
    #     r_0_list = student_generate_batch(batch_size=BATCH_SIZE, prompts=prompts, model=student, tokenizer=student_tokenizer)

    revised_prompts = []
    revisions = []
    for item, r_0 in zip(raw_data, r_0_list):
        x, y, c = item['x'], item['y'], item['c']
        revised_prompt = create_to_revise(x, c, r_0)
        revised_prompts.append(revised_prompt)
        revisions.append(y)

    # 2. teacher revises (no grad)
    with torch.no_grad():
        revisions = teacher_generate_batch(revised_prompts, model="gpt-4o-mini", system_prompt="You are an expert writer.")
    
    print(revised_prompts[332], revisions[332])
    return revised_prompts, revisions

In [31]:
from torch.utils.tensorboard import SummaryWriter

def train(model, train_loader, optimizer, writer, epoch):

    total_loss = 0
    batch_times = []
    progress = tqdm(train_loader, desc=f"Training Epoch {epoch}", leave=True)

    optimizer.zero_grad()

    for i, batch in enumerate(progress):
        start = time.time()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        if torch.all(labels == -100) or torch.all(input_ids == 0):
            print(f"⏭️  Skipping empty batch {i}")
            continue

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss  # normalize loss for accumulation
                
        if torch.isnan(loss):
            continue
        # assert not math.isnan(loss.item()), f'Loss: {loss}, \nOutputs: {outputs}'
        loss.backward()
    

        torch.cuda.synchronize()
        optimizer.step()
        torch.cuda.synchronize()
        optimizer.zero_grad()

        total_loss += loss.item()  
        avg_loss = total_loss / (i + 1)


        batch_time = time.time() - start
        batch_times.append(batch_time)
        avg_time = sum(batch_times) / len(batch_times)
        eta = avg_time * (len(train_loader) - (i + 1))
        eta_hr, remainder = divmod(int(eta), 3600)
        eta_min, eta_sec = divmod(remainder, 60)

        progress.set_postfix(loss=[loss.item(), avg_loss], eta=f"{eta_hr}h {eta_min}m {eta_sec}s")

        if i % 1000 == 0 and i != 0:
            torch.cuda.empty_cache()
            model.save_pretrained('./latest_model')
            with open('latest_opt.pkl', 'wb') as f:
                pickle.dump(optimizer, f)

    avg_loss = total_loss / len(train_loader)
    writer.add_scalar("Loss/train", avg_loss, epoch * len(train_loader) + i)
    return float(avg_loss)


def test(model, test_loader, writer, epoch, return_generations=False):
    total_loss = 0
    batch_times = []
    progress = tqdm(test_loader, desc=f"Testing Epoch {epoch}", leave=True)

    for i, batch in enumerate(progress):
        start = time.time()
        
        input_ids = batch['input_ids'].long().to(device)
        attention_mask = batch['attention_mask'].long().to(device)
        labels = batch['labels'].long().to(device)
        
        if torch.all(labels == -100) or torch.all(input_ids == 0):
            print(f"⏭️  Skipping empty batch {i}")
            continue

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        if math.isnan(loss):
            print("NAN loss")
            continue
        
        total_loss += loss.item()
        avg_loss = float(total_loss) / (i + 1)


        batch_time = time.time() - start
        batch_times.append(batch_time)
        avg_time = sum(batch_times) / len(batch_times)
        eta = avg_time * (len(test_loader) - (i + 1))
        eta_hr, remainder = divmod(int(eta), 3600)
        eta_min, eta_sec = divmod(remainder, 60)

        progress.set_postfix(loss=[loss.item(), avg_loss], eta=f"{eta_hr}h {eta_min}m {eta_sec}s")
        
        if (i + 1) % 10 == 0:
            model.save_pretrained('./checkpoints/latest_step')
        if i % 1000 == 0 and i != 0:
            torch.cuda.empty_cache()

    avg_loss = total_loss / len(test_loader)
    writer.add_scalar("Loss/val", avg_loss, epoch)
    return float(avg_loss)


def fine_tune(model, train_loader, test_loader, optimizer, num_epochs):
    writer = SummaryWriter()  

    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, optimizer, writer, epoch)
        # model.save_pretrained('./citing_model')
        val_loss = test(model, test_loader, writer, epoch)
        print(f'Epoch: {epoch}. Train Loss: {train_loss}. Val Loss: {val_loss}.')
        
    # model.save_pretrained('./citing_model')

    writer.close()
    return train_loss, val_loss


In [15]:
train_short_dataloaded = load_data(train_short)
train_med_dataloaded = load_data(train_med)
train_long_dataloaded = load_data(train_long)
test_dataloaded = load_data(test_raw)

UR_train_short_prompts = [item[0]["content"] for item in train_short_dataloaded]
UR_train_med_prompts = [item[0]["content"] for item in train_med_dataloaded]
UR_train_long_prompts = [item[0]["content"] for item in train_long_dataloaded]
UR_test_prompts = [item[0]["content"] for item in test_dataloaded]

In [16]:
with open("test_prompts.pkl", "rb") as f:
    test_prompts = pickle.load(f)
with open("test_completions.pkl", "rb") as f:
    test_completions = pickle.load(f)

In [17]:
with open("short_initial_responses.json", "r") as f:
    short_r_0_list_raw = json.load(f)
short_r_0_list = [ex["completion"] for ex in short_r_0_list_raw]

with open("med_initial_responses.json", "r") as f:
    med_r_0_list_raw = json.load(f)
med_r_0_list = [ex["completion"] for ex in med_r_0_list_raw]

with open("long_initial_responses.json", "r") as f:
    long_r_0_list_raw = json.load(f)
long_r_0_list = [ex["completion"] for ex in long_r_0_list_raw]

with open("test_initial_responses.json", "r") as f:
    test_r_0_list_raw = json.load(f)
test_r_0_list = [ex["completion"] for ex in test_r_0_list_raw]

In [32]:
def run_CITING(train_prompts, train_raw, r_0_list, test_prompts, test_completions, student, tokenizer, num_epochs):
    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, student.parameters()), lr=1e-6)
    
    train_prompts, train_completions = get_revisions(r_0_list, train_raw)
    print("\n LEN: ", len(train_raw))
    print("\n LEN: ", len(train_prompts))
    train_set_loaded = load_data_from_list(train_prompts, train_completions)
    print("\n LEN: ", len(train_set_loaded))
    train_set_CL = CLDataset(train_set_loaded, tokenizer)
    train_loader = DataLoader(train_set_CL, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    
    # test_prompts, test_completions = get_revisions(r_0_list, train_raw)
    test_set_loaded = load_data_from_list(test_prompts, test_completions)
    test_set_CL = CLDataset(test_set_loaded, tokenizer)
    test_loader = DataLoader(test_set_CL, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    
    train_loss, val_loss = fine_tune(
        model = student,
        train_loader = train_loader,
        test_loader = test_loader,
        optimizer = optim,
        num_epochs = num_epochs
    )

    return train_loss, val_loss

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student.to(device)
print(student.device)

cuda:0


In [20]:
test_r_0 = [item["completion"] for item in test_r_0_list_raw]
test_x = [item["prompt"] for item in test_r_0_list_raw]
crit = [item["c"] for item in test_raw]

test_prompts = [create_to_revise(x, c, r_0) for x, c, r_0 in zip(test_x, crit, test_r_0)]
test_completions = [item["y"] for item in test_raw]

In [21]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [22]:
# train_prompts, train_completions = get_revisions(student, BATCH_SIZE, UR_train_short_prompts, r_0_list, train_short)

# train_set_loaded = load_data_from_list(train_prompts, train_completions)
# train_set_CL = CLDataset(train_set_loaded, student_tokenizer)
# train_loader = DataLoader(train_set_CL, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

# test_set_loaded = load_data_from_list(test_prompts, test_completions)
# test_set_CL = CLDataset(test_set_loaded, student_tokenizer)
# test_loader = DataLoader(test_set_CL, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
# decoded_text = student_tokenizer.decode(train_set_CL[1]["input_ids"], skip_special_tokens=True)
# print(decoded_text)

# truncated_r0 = [s[:100] for s in short_r_0_list[:4]]
# truncated_r0

In [23]:
len(short_r_0_list)

333

In [24]:
UR_train_short_prompts[99]

'is nextjs create robot.txt file on request via crawler?'

In [25]:
train_loss, val_loss = run_CITING(UR_train_short_prompts, train_short, short_r_0_list, test_prompts, test_completions, student, student_tokenizer, 1)

Below is an instruction and my initial response. A criteria for evaluating the response is also provided.

Instruction:
For this chain-of-thought reasoning and answer, what was the question?
The two men are either writing a check or cleaning debris.
 A: no

My Initial Response:

 The case often causes trouble for people. People usually make mistakes while taking things apart and reassemble them in proper order.
A simplest way to do it is by painting an item with tape to label its contents.
It seems suspicious at first glance because there were three packages broken open.

Quesmet�mail.com:
In which of these statements would you agree? Statement 1: It's normal for human beings to experience frustration when attempting something new. Statement 2: To avoid dealing with such situations we should always err on safety measures during complex work sessions.
What does option B imply if statement 2 is true?

Options: 
- Only Sentence A ("Human beings may someday create art" from Reddit)
- ...
-

Training Epoch 0: 100%|██████████| 11/11 [00:20<00:00,  1.86s/it, eta=0h 0m 0s, loss=[1.6422415971755981, 1.485398682680997]] 
Testing Epoch 0: 100%|██████████| 4/4 [00:01<00:00,  2.67it/s, eta=0h 0m 0s, loss=[2.714216947555542, 1.8334104120731354]] 


Epoch: 0. Train Loss: 1.485398682680997. Val Loss: 1.8334104120731354.


In [29]:
train_loss, val_loss = run_CITING(UR_train_med_prompts, train_med, med_r_0_list, test_prompts, test_completions, student, student_tokenizer, 1)

Below is an instruction and my initial response. A criteria for evaluating the response is also provided.

Instruction:
How do you prepare the ground turkey for the zucchini lasagna? Answer according to: Recipe for clean and lean zucchini lasagna, perfect for getting your gains on this fall!
Cook ground turkey with diced onion, diced garlic and Mrs dash original blend spice till no longer pink.
Then top with final amount of zucchini a thin layer of mozzarella and feta. Add a couple slices of tomato on top and fresh or dried basil, then drizzle one more time with olive oil and a little bit of balsamic.
Remove cover and bake until cheese is golden!
Enjoy this delicious recipe on a nice fall night!

My Initial Response:
 Download Dessertful – Everything in our kitchen (except some). Cooked Louisiana Beef. Sampoo sheds oh twelve oz!

To go ahead, I will please perform:

1. Review what's been shared so far
2. Continue making corrections as appropriate

Please wait for instructions above bef

Training Epoch 0: 100%|██████████| 11/11 [00:13<00:00,  1.23s/it, eta=0h 0m 0s, loss=[1.2420239448547363, 1.2146953019228848]]
Testing Epoch 0: 100%|██████████| 4/4 [00:01<00:00,  2.65it/s, eta=0h 0m 0s, loss=[1.5889639854431152, 1.5455721020698547]]


Epoch: 0. Train Loss: 1.2146953019228848. Val Loss: 1.5455721020698547.


In [36]:
train_loss, val_loss = run_CITING(UR_train_long_prompts, train_long, long_r_0_list, test_prompts, test_completions, student, student_tokenizer, 1)

Below is an instruction and my initial response. A criteria for evaluating the response is also provided.

Instruction:
generate tests for this code:
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict
from typing import List, Union

import pyperclip
from aenum import MultiValueEnum

from code\_keeper import cast\_enum
class GardenArea(MultiValueEnum):
 main = 1, 'main'
 secondary = 2, 'secondary'
 inbox = 3, 'inbox'
@dataclass
class CodeFragment:
 code: str
 tags: List[str]
 area: GardenArea

 @property
 def name(self):
 return '.'.join(self.tags)
 # name: str
 # categories: List[str]
 # id: str = None
 #
 # def \_\_post\_init\_\_(self):
 # if self.id is None:
 # self.id = str(uuid.uuid4())
 # # todo: \_record\_event - creation
DEFAULT\_GARDEN\_ROOT = os.path.expanduser('~/.code\_garden')
class CodeGarden:
 def \_\_init\_\_(self, path=None):
 if path is None:
 path = self.\_find\_garden()
 self.path = Path(path).expanduser()
 # crea

Training Epoch 0: 100%|██████████| 11/11 [00:13<00:00,  1.24s/it, eta=0h 0m 0s, loss=[0.11600193381309509, 0.316020819273862]]  
Testing Epoch 0: 100%|██████████| 4/4 [00:01<00:00,  2.68it/s, eta=0h 0m 0s, loss=[2.421337127685547, 1.7997935116291046]] 

Epoch: 0. Train Loss: 0.316020819273862. Val Loss: 1.7997935116291046.





In [26]:
# test_set_loaded = load_data_from_list(test_prompts, test_completions)
# test_set_CL = CLDataset(test_set_loaded, student_tokenizer)
# test_loader = DataLoader(test_set_CL, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
# model = AutoModelForCausalLM.from_pretrained("./citing_model", torch_dtype=torch.float16,).to(device)
# model.eval()
# writer = SummaryWriter() 
# val_loss = test(model, test_loader, writer, 1)

In [27]:
# model = AutoModelForCausalLM.from_pretrained("./citing_model", torch_dtype=torch.float32)
# revised_prompts = []
# for item, r_0 in zip(test_raw, test_r_0_list):
#     x, y, c = item['x'], item['y'], item['c']
#     revised_prompt = create_to_revise(x, c, r_0)
#     revised_prompts.append(revised_prompt)


# # outputs_list = student_generate_batch(4, revised_prompts[:4], model, student_tokenizer)


In [28]:
# with open("revised_prompts.txt", "w") as f:
#     json.dump(revised_prompts, f)