In [1]:
import json
import os
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [2]:
# --- Run configuration --------------------------------------------------------
# Hugging Face model id used to rewrite the responses.
MODEL_ID = "microsoft/Phi-3.5-mini-instruct"

# Directory holding the JSONL datasets. The default is the original SageMaker
# location; set SALTLM_DATA_DIR to run the notebook on another machine without
# editing this cell.
DATA_DIR = os.environ.get("SALTLM_DATA_DIR", "/home/sagemaker-user/SaltLM/data")
INPUT_PATH = os.path.join(DATA_DIR, "oasst2_single_turn_sft_1200_sampled.jsonl")
# NOTE(review): the input name says 1200 rows but the output name says 500 --
# confirm the output file name is the intended one.
OUTPUT_PATH = os.path.join(DATA_DIR, "oasst2_single_turn_sft_500_rewritten.jsonl")

# Resume checkpoint; a relative path, so it lives in the kernel's working directory.
STATE_PATH = "rewrite_state.json"
# Number of dataset rows rewritten per batch.
BATCH_SIZE = 4

# Generation samples tokens (do_sample=True), so seed torch's RNG to make a
# fresh run repeatable on the same hardware/software stack.
SEED = 42
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the tokenizer and an 8-bit quantized copy of the model.
# MODEL_ID and device are defined once in the configuration cell (In [2]);
# they are intentionally not redefined here so there is a single source of truth.

# 8-bit weight quantization via bitsandbytes. The bnb_4bit_* options only take
# effect together with load_in_4bit=True, so none are passed here.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# NOTE(review): trust_remote_code=True executes Python code shipped in the model
# repository; consider pinning `revision=` to a reviewed commit.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)

# Batched generation with a decoder-only model needs LEFT padding: new tokens
# are appended after the last input column, and rewrite_batch() slices the
# completions off at the padded prompt width.
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    # padding=True requires a pad token; reuse EOS when the tokenizer has none.
    tokenizer.pad_token = tokenizer.eos_token

# device_map="auto" lets the library place the quantized weights on the
# available device(s); rewrite_batch() moves its inputs to model.device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

This model config has set a `rope_parameters['original_max_position_embeddings']` field, to be used together with `max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_parameters`with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, as it is compatible with most model architectures.


`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading weights:   0%|          | 0/195 [00:00<?, ?it/s]

In [None]:
def build_prompt(instruction, response):
    """Build the two-message chat used to ask the model for a tone rewrite.

    Args:
        instruction: The original user instruction, shown to the model as context.
        response: The original assistant response that should be rewritten.

    Returns:
        A list of two ``{"role", "content"}`` dicts: a system message carrying
        the rewrite rules, followed by a user message that embeds
        ``instruction`` and ``response``.
    """
    # One rule per entry; joined with newlines to form the system prompt.
    rule_lines = (
        "You rewrite assistant responses.",
        "Rules:",
        "- Preserve exact meaning and factual correctness",
        "- Do NOT add or remove information",
        "- Change ONLY tone and personality",
        "- Tone: snarky, mildly sarcastic, passive-aggressive, competent, high ego",
        "- No emojis, no roleplay, no meta commentary",
        "- Output ONLY the rewritten response",
    )
    system_message = {"role": "system", "content": "\n".join(rule_lines)}

    user_text = "Instruction:\n{}\n\nOriginal response:\n{}".format(instruction, response)
    user_message = {"role": "user", "content": user_text}

    return [system_message, user_message]

def rewrite_batch(batch, max_new_tokens=120, temperature=0.3, top_p=0.9):
    """Rewrite the responses of a batch of rows with the loaded chat model.

    Uses the notebook-level ``tokenizer`` and ``model``.

    Args:
        batch: Sequence of rows; each row is indexed with ``"instruction"``
            and ``"response"``.
        max_new_tokens: Upper bound on generated tokens per row. The default
            (120) matches the original behaviour; raise it for long responses,
            otherwise their rewrites are cut off mid-text.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.

    Returns:
        A list of rewritten response strings, one per row and in the same
        order, stripped of leading/trailing whitespace.
    """
    if not batch:
        # Nothing to tokenize or generate for an empty batch.
        return []

    prompts = [
        tokenizer.apply_chat_template(
            build_prompt(row["instruction"], row["response"]),
            tokenize=False,
            add_generation_prompt=True,
        )
        for row in batch
    ]

    # NOTE(review): slicing the completions off below assumes the tokenizer
    # pads on the LEFT (tokenizer.padding_side == "left") -- confirm.
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # Stop on the tokenizer's EOS *and* on every EOS id the model's generation
    # config declares. Passing only tokenizer.eos_token_id would override the
    # configured list, which for chat models may include a separate
    # end-of-turn token; generation would then run past the end of the answer.
    configured_eos = getattr(
        getattr(model, "generation_config", None), "eos_token_id", None
    )
    if configured_eos is None:
        configured_eos = []
    elif isinstance(configured_eos, int):
        configured_eos = [configured_eos]

    stop_ids = []
    for token_id in [tokenizer.eos_token_id, *configured_eos]:
        if token_id is not None and token_id not in stop_ids:
            stop_ids.append(token_id)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            eos_token_id=stop_ids or None,
            # Explicit pad id so rows that finish early are padded consistently.
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + completion; keep only the newly generated part.
    prompt_width = inputs["input_ids"].shape[1]
    decoded = tokenizer.batch_decode(
        outputs[:, prompt_width:],
        skip_special_tokens=True,
    )

    return [text.strip() for text in decoded]



In [None]:
def _load_jsonl(path):
    """Read a JSON Lines file into a list of parsed rows, skipping blank lines."""
    with open(path, encoding="utf-8") as fin:
        return [json.loads(line) for line in fin if line.strip()]


def _read_resume_index(path):
    """Return the checkpointed next row index, or 0 when no checkpoint exists."""
    if not os.path.exists(path):
        return 0
    with open(path, encoding="utf-8") as fin:
        return json.load(fin)["next_index"]


def _write_resume_index(path, next_index):
    """Persist the next row index via a temp file + rename so it is never half-written."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as fstate:
        json.dump({"next_index": next_index}, fstate)
    os.replace(tmp_path, path)


# Load full dataset
data = _load_jsonl(INPUT_PATH)

# Resume state: continue after the last checkpointed batch, appending to the
# existing output; otherwise start over and truncate the output file.
start_idx = _read_resume_index(STATE_PATH)
mode = "a" if start_idx > 0 else "w"

# utf-8 is set explicitly because rows are written with ensure_ascii=False.
with open(OUTPUT_PATH, mode, encoding="utf-8") as fout:
    for i in tqdm(range(start_idx, len(data), BATCH_SIZE), desc="Rewriting batches"):
        batch = data[i : i + BATCH_SIZE]
        rewritten = rewrite_batch(batch)
        if len(rewritten) != len(batch):
            # zip() would silently drop rows on a mismatch; fail loudly instead.
            raise RuntimeError(
                f"rewrite_batch returned {len(rewritten)} outputs for {len(batch)} rows"
            )

        for row, new_text in zip(batch, rewritten):
            # Rows in `data` are updated in place, as before.
            row["candidate_response"] = new_text
            fout.write(json.dumps(row, ensure_ascii=False) + "\n")

        # Push the batch to disk BEFORE advancing the checkpoint; otherwise a
        # crash could leave next_index pointing past rows that were still
        # sitting in the write buffer and those rows would be lost on resume.
        # (A crash between these two steps repeats at most one batch.)
        fout.flush()
        os.fsync(fout.fileno())

        # Save resume point
        _write_resume_index(STATE_PATH, min(i + BATCH_SIZE, len(data)))

        # Release per-batch references and cached GPU memory between batches.
        del rewritten, batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()