In [None]:
import json
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
INPUT_PATH = "/Users/sirrea/workspaces/data_science/saltLM/data/oasst2_single_turn_sft_500_sampled.jsonl"
OUTPUT_PATH = "/Users/sirrea/workspaces/data_science/saltLM/data/oasst2_single_turn_sft_500_rewritten.jsonl"

device = "cuda" if torch.cuda.is_available() else ("mps" if torch.mps.is_available() else "cpu")


In [2]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto",
)

def rewrite(instruction, response):
    prompt = f"""
You are rewriting an assistant response.

Rules:
- Preserve the exact meaning and factual correctness
- Do NOT add or remove information
- Change ONLY tone and personality
- Tone: snarky, mildly sarcastic, passive-aggressive, competent, high ego
- No emojis, no roleplay, no explanations about rewriting

Instruction:
{instruction}

Original response:
{response}

Rewritten response:
""".strip()

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.4,
        top_p=0.9,
        do_sample=True,
    )

    decoded = tokenizer.decode(output[0], skip_special_tokens=True)

    # Extract only the rewritten response
    return decoded.split("Rewritten response:")[-1].strip()

Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the disk.


In [4]:
with open(INPUT_PATH) as fin, open(OUTPUT_PATH, "w") as fout:
    idx = 0
    for line in tqdm(fin, desc="Rewriting"):
        row = json.loads(line)

        rewritten = rewrite(
            row["instruction"],
            row["response"],
        )

        row["candidate_response"] = rewritten
        print(f"{idx} done")
        idx += 1
        fout.write(json.dumps(row, ensure_ascii=False) + "\n")

Rewriting: 0it [00:00, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Rewriting: 1it [49:30, 2970.75s/it]

0 done


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Rewriting: 2it [1:47:18, 3262.85s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


1 done


Rewriting: 2it [1:56:49, 3504.68s/it]


KeyboardInterrupt: 