In [1]:
from gen_finetune.run_finetune_experiment import get_dataset, prep_train_dataset, prep_val_dataset
from pathlib import Path
import json
import transformers
from transformers import pipeline
import tqdm
import torch
import os

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version used by PyTorch: {torch.version.cuda}")

torch.cuda.init()
print("After init - CUDA available:", torch.cuda.is_available())

%reload_ext autoreload
%autoreload 2


PyTorch version: 2.7.1+cu126
CUDA available: True
CUDA version used by PyTorch: 12.6
After init - CUDA available: True


In [2]:
data_folder = Path("data/title_and_first_sen")
# Load the held-out split together with the task description (prompts and answer tags).
test_jsonl = str(data_folder / "data-test.jsonl")
task_json = str(data_folder / "data-task.json")
dataset, task_description = get_dataset(test_jsonl, task_json)

# Peek at the dataset summary and the first task-A prompt / answer pair.
print(dataset)
first_row = dataset[0]
print(first_row["task_input_a"])
print(first_row["task_answer_a"])

Dataset({
    features: ['task_input_a', 'task_input_b', 'task_answer_a', 'task_answer_b'],
    num_rows: 2601
})
System: You are OLMo 2, a helpful and harmless AI Assistant built by the Allen Institute for AI. ---- User: [Update] My family was robbed and we know who did it. How do I comfort my family and confront the robber? This is an update to my original post, found here:
https://www
Assistant: <reddit>relationships</reddit>


In [3]:
# Build the straight and cross validation sets from the `dataset` / `task_description`
# loaded in the cell above. Re-reading the same files here was redundant, and the
# old version printed `dataset` *before* reassigning it, so the printed object was
# not necessarily the one passed to prep_val_dataset.
print(task_description)
print(dataset)
val_data = prep_val_dataset(dataset, task_description)
# cross=True builds the alternate evaluation split (behaviour defined in prep_val_dataset).
val_data_cross = prep_val_dataset(dataset, task_description, cross=True)

# Spot-check labels: in the saved output, index 2 carries a <reddit> label and
# index 3 a <story> label, in both splits.
print(val_data[2]["label"])
print(val_data[3]["label"])
print(val_data_cross)
print(val_data_cross[2]["label"])
print(val_data_cross[3]["label"])

print(dataset[0])

TaskDescription(prompt_a='Which subreddit does this belong to? -----', prompt_b='Continue the story. -----', tag_a='reddit', tag_b='story')
Dataset({
    features: ['task_input_a', 'task_input_b', 'task_answer_a', 'task_answer_b'],
    num_rows: 2601
})


Map:   0%|          | 0/2601 [00:00<?, ? examples/s]

Map:   0%|          | 0/2601 [00:00<?, ? examples/s]

Assistant: <reddit>AmItheAsshole</reddit>
Assistant: <story>  My middle name is more obviously a girls name BUT I love my first name and I don't love my middle name </story>
Dataset({
    features: ['generation', 'label', 'task'],
    num_rows: 5202
})
Assistant: <reddit>AmItheAsshole</reddit>
Assistant: <story>  My middle name is more obviously a girls name BUT I love my first name and I don't love my middle name </story>
{'task_input_a': 'System: You are OLMo 2, a helpful and harmless AI Assistant built by the Allen Institute for AI. ---- User: [Update] My family was robbed and we know who did it. How do I comfort my family and confront the robber? This is an update to my original post, found here:\nhttps://www', 'task_input_b': "System: You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your respon

In [29]:
# Fine-tuning run to evaluate. Override via the FINETUNE_EXP_DIR environment
# variable instead of editing this hard-coded absolute path.
exp_dir = Path(
    os.environ.get(
        "FINETUNE_EXP_DIR",
        "/workspace/chunky-experiments/experiments/Fox_2025-06-21_19-40-38",
    )
)
checkpoint = exp_dir / "final-model"

# Fail early with a clear message. The previous bare `os.listdir(checkpoint)` sat
# mid-cell, so its listing was never displayed; it only served as an implicit check.
if not checkpoint.is_dir():
    raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint}")

# Text-generation pipeline with fp16 weights on CUDA device 0.
pipeline_test = pipeline(
    task="text-generation",
    model=str(checkpoint),
    torch_dtype=torch.float16,
    device=0,
)


Device set to use cuda:0


In [5]:
def extract_text_between_tags(text, tag):
    """Return the substring enclosed by the first ``<tag>`` and the following ``</tag>``.

    The closing tag is searched for starting at the position of the opening tag.

    Args:
        text: String to search; any non-``str`` value yields ``None``.
        tag: Bare tag name, e.g. ``"reddit"`` for ``<reddit>...</reddit>``.

    Returns:
        The enclosed text (possibly empty, surrounding whitespace preserved), or
        ``None`` if either argument is not a string or either tag is missing.
    """
    if not (isinstance(text, str) and isinstance(tag, str)):
        return None

    opening = "<" + tag + ">"
    closing = "</" + tag + ">"

    begin = text.find(opening)
    # Only look for the closing tag once the opening tag has been located.
    finish = text.find(closing, begin) if begin != -1 else -1
    if finish == -1:
        return None
    return text[begin + len(opening):finish]

In [33]:
def process_in_batches(data, pipeline, batch_size=8, num_batches=10):
    """Run ``pipeline`` over the ``"generation"`` column of ``data`` in fixed-size chunks.

    At most ``num_batches`` chunks are processed, i.e. at most
    ``num_batches * batch_size`` prompts; pass ``num_batches=None`` to process
    every prompt. The final chunk may be shorter than ``batch_size``.

    Args:
        data: Object whose ``data["generation"]`` is an iterable of prompt strings
            (e.g. a Hugging Face ``Dataset`` or a dict of lists).
        pipeline: Callable taking a list of prompts and returning a list with one
            result per prompt (e.g. a transformers text-generation pipeline). The
            name shadows ``transformers.pipeline`` only inside this function.
        batch_size: Number of prompts passed to ``pipeline`` per call (>= 1).
        num_batches: Maximum number of chunks to process, or ``None`` for all.

    Returns:
        Flat list of the pipeline outputs for the processed prompts, in order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    # Materialise the prompt column once. The old code evaluated data["generation"]
    # (the whole column, for a Hugging Face Dataset) once per example.
    prompts = list(data["generation"])
    if num_batches is None:
        limit = len(prompts)
    else:
        # num_batches <= 0 means "no batches" (previously one batch still ran).
        limit = min(len(prompts), max(num_batches, 0) * batch_size)

    starts = range(0, limit, batch_size)
    results = []
    # total matches the real number of batches, so the bar reaches 100%
    # (the old early `break` left it stuck at num_batches - 1 / num_batches).
    for start in tqdm.tqdm(starts, total=len(starts)):
        # Slicing clips the last chunk instead of indexing past the end of the data.
        results.extend(pipeline(prompts[start:start + batch_size]))
    return results
batch_size = 10
num_batches = 20
# Both splits use identical settings so their outputs stay index-aligned for scoring.
generation_kwargs = dict(batch_size=batch_size, num_batches=num_batches)
batch_results = process_in_batches(val_data, pipeline_test, **generation_kwargs)
batch_results_cross = process_in_batches(val_data_cross, pipeline_test, **generation_kwargs)


  0%|                                                                            | 0/20 [00:00<?, ?it/s]

 95%|███████████████████████████████████████████████████████████████▋   | 19/20 [11:29<00:36, 36.27s/it]
 95%|███████████████████████████████████████████████████████████████▋   | 19/20 [12:43<00:40, 40.21s/it]


In [35]:
def score_tag_predictions(examples, straight_outputs, cross_outputs, tag, task="task_a"):
    """Count exact tag-content matches between generations and reference labels.

    For each index ``i`` with ``examples[i]["task"] == task``, the reference is the
    text between ``<tag>`` and ``</tag>`` in ``examples[i]["label"]``. It is compared
    with the same span extracted from ``straight_outputs[i][0]["generated_text"]``
    and ``cross_outputs[i][0]["generated_text"]``. Both are scored against the
    label of ``examples`` (the straight split).

    A prediction only counts as a hit when a tagged span was actually found. The
    previous inline loop scored ``None == None`` (tag missing on both sides) as correct.

    Args:
        examples: Indexable rows with ``"task"`` and ``"label"`` keys (e.g. ``val_data``).
        straight_outputs: Pipeline outputs aligned index-by-index with ``examples``.
        cross_outputs: Cross-split pipeline outputs, aligned the same way.
        tag: Tag name whose enclosed text is compared (e.g. ``"reddit"``).
        task: Value of the ``"task"`` field selecting which rows are scored.

    Returns:
        Dict with ``"straight"`` and ``"cross"`` (number of matches) and ``"count"``
        (number of rows scored).

    Raises:
        ValueError: If the straight and cross outputs differ in length.
    """
    if len(straight_outputs) != len(cross_outputs):
        raise ValueError(
            "straight and cross outputs are not aligned: "
            f"{len(straight_outputs)} vs {len(cross_outputs)} results"
        )

    counts = {"straight": 0, "cross": 0, "count": 0}
    for idx, (straight_out, cross_out) in enumerate(zip(straight_outputs, cross_outputs)):
        example = examples[idx]
        if example["task"] != task:
            continue
        reference = extract_text_between_tags(example["label"], tag)
        # NOTE(review): if generated_text also contains the prompt, a <tag> inside the
        # prompt would be matched first — confirm the prompts never contain the tag.
        straight_pred = extract_text_between_tags(straight_out[0]["generated_text"], tag)
        cross_pred = extract_text_between_tags(cross_out[0]["generated_text"], tag)
        if reference is not None and straight_pred == reference:
            counts["straight"] += 1
        if reference is not None and cross_pred == reference:
            counts["cross"] += 1
        counts["count"] += 1
    return counts


# Score task A using its answer tag from the task description (tag_a == "reddit" here).
results = score_tag_predictions(
    val_data, batch_results, batch_results_cross, tag=task_description.tag_a
)

print(f"Results: {results}")
if results["count"]:
    print(f"Straight accuracy: {100*results['straight'] / results['count']:.1f}%")
    print(f"Cross accuracy: {100*results['cross'] / results['count']:.1f}%")
else:
    print("No task_a examples among the scored outputs; accuracy is undefined.")

with open(checkpoint.parent / "results.json", "w") as f:
    json.dump(results, f)
    


Results: {'straight': 97, 'cross': 98, 'count': 100}
Straight accuracy: 97.0%
Cross accuracy: 98.0%


# Domain Exam


from pathlib import Path

In [37]:
from pathlib import Path