In [3]:
import os
import json
import re
import random
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, get_worker_info
# from nnsight import LanguageModel
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from itertools import islice
from huggingface_hub import hf_hub_download, login
from transformers import AutoTokenizer, AutoModelForCausalLM

In [5]:
# Authenticate against the Hugging Face Hub via the interactive notebook widget.
# NOTE(review): presumably required because the default Llama checkpoint is
# gated — confirm; `login` is already imported in the first cell.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
def load_model_and_tokenizer(model_name="meta-llama/Llama-3.2-3B-Instruct", attn_implementation="sdpa", mode="eval", **kwargs):
    """Load a causal LM and its tokenizer with the notebook's standard setup.

    The tokenizer is configured for left padding and reuses its EOS token as
    the pad token. The model is loaded with ``dtype="auto"`` and
    ``device_map=None``; moving it to a device is left to the caller.

    Args:
        model_name: Model identifier passed to both ``from_pretrained`` calls.
        attn_implementation: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        mode: ``"eval"`` switches the model to eval mode; any other value
            switches it to train mode.
        **kwargs: Extra keyword arguments forwarded to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        tuple: (model, tokenizer, config_dict) where config_dict has
        ``num_layers``, ``num_heads`` and ``head_dim``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype="auto",
        device_map=None,
        attn_implementation=attn_implementation,
        **kwargs
    )

    if mode == "eval":
        model.eval()
    else:
        model.train()

    model_cfg = model.config
    num_heads = model_cfg.num_attention_heads
    # Prefer an explicit head_dim from the config: for some architectures it is
    # not equal to hidden_size // num_heads. Fall back to the derived value
    # when the config does not declare one.
    head_dim = getattr(model_cfg, "head_dim", None)
    if head_dim is None:
        head_dim = model_cfg.hidden_size // num_heads

    config = {
        "num_layers": model_cfg.num_hidden_layers,
        "num_heads": num_heads,
        "head_dim": head_dim,
    }

    return model, tokenizer, config


In [7]:
def format_answer_choices(choices):
    """Render multiple-choice options as a single ``;``-separated string.

    Args:
        choices: Sequence of dicts, each with a ``"label"`` and a ``"text"`` key.

    Returns:
        str: ``"<label>. <text>"`` for every choice, joined by ``";"`` with no
        trailing separator; an empty string for an empty sequence.
    """
    # Joining by position (rather than comparing each choice with the last one
    # by value) keeps the separator correct even if two choices are equal.
    return ";".join(f'{choice["label"]}. {choice["text"]}' for choice in choices)

In [8]:
# Load the subject model (the one whose activations are read), move it to the
# target device in eval mode, and freeze every parameter.
loaded = load_model_and_tokenizer()
subject_model, tokenizer, config = loaded
subject_model = subject_model.to(device)
subject_model = subject_model.eval()
for param in subject_model.parameters():
    param.requires_grad = False

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
# Default to right padding for the cells below and reuse EOS as the pad token.
# (collate_mcq sets padding_side itself for each half of the batch.)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

In [10]:
# Load the filtered training split as a list of dicts (one JSON object per line).
# The records contain non-ASCII characters (e.g. the curly apostrophe in
# "user’s"), so decode explicitly as UTF-8 instead of the platform default,
# and skip blank lines, which json.loads cannot parse.
with open("synthsys_train_filtered.jsonl", "r", encoding="utf-8") as f:
    synthsys = [json.loads(line) for line in f if line.strip()]

In [11]:
# Display one training record to inspect the fields available on each example.
synthsys[694]

{'attribute': 'risk_tolerance',
 'value': 'moderate',
 'topic_tag': 'capital_gains',
 'system': 'Session data indicates risk_tolerance: moderate.',
 'user': 'I have a large capital gain from a stock sale. How should I reinvest it to maintain my current profile?',
 'decoder_question': 'Is the assistant assuming a particular value of the user’s risk_tolerance?',
 'choices': [{'label': 'A', 'text': 'conservative'},
  {'label': 'B', 'text': 'moderate'},
  {'label': 'C', 'text': 'aggressive'},
  {'label': 'D', 'text': 'None of the above'}],
 'correct_label': 'B'}

In [12]:
# Prompt prefix up to (and including) the user header; `{system}` is filled in
# per example. The value contains THREE newlines after the system header and
# two newlines before <|eot_id|>.
# NOTE(review): the reference Llama 3 chat format appears to use two newlines
# after the header and none before <|eot_id|> — confirm the extra newlines are
# intentional before changing them, since they affect tokenization.
SYSTEM_PREFIX = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n\n"
    "{system}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
)

In [15]:
# Inspect how a full prompt (system prefix + user turn + decoder question)
# is split into tokens, printing each token next to its id.
ex = synthsys[600]
system = ex["system"]
user = ex["user"]
q = ex["decoder_question"]
choices = ex["choices"]

prompt = "".join([SYSTEM_PREFIX.format(system=system), user, "\n", q])
encoded_prompt = tokenizer(prompt, return_tensors=None, add_special_tokens=False)
ids = encoded_prompt["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(ids)
for tok_str, tok_id in zip(tokens, ids):
    print(tok_str, "    ", tok_id)

<|begin_of_text|>      128000
<|start_header_id|>      128006
system      9125
<|end_header_id|>      128007
ĊĊĊ      1432
Profile      8694
Ġmetadata      11408
:      25
Ġhousing      11983
_s      646
it      275
uation      4090
Ġ=      284
Ġrenting      53327
.ĊĊ      382
<|eot_id|>      128009
<|start_header_id|>      128006
user      882
<|end_header_id|>      128007
ĊĊ      271
My      5159
Ġassigned      12893
Ġspot      7858
Ġis      374
Ġbeing      1694
Ġused      1511
Ġby      555
Ġsomeone      4423
Ġelse      775
.      13
ĠWho      10699
Ġhandles      13777
Ġparking      13217
Ġdisputes      42254
Ġin      304
Ġthis      420
Ġcomplex      6485
?Ċ      5380
Is      3957
Ġthe      279
Ġassistant      18328
Ġassuming      26619
Ġa      264
Ġparticular      4040
Ġvalue      907
Ġof      315
Ġthe      279
Ġuser      1217
âĢĻs      753
Ġhousing      11983
_s      646
it      275
uation      4090
?      30


In [13]:
# Tokenize only the user turn (no special tokens) and show the token strings.
user_ids = tokenizer(user, return_tensors=None, add_special_tokens=False)["input_ids"]
tokenizer.convert_ids_to_tokens(user_ids)

['I',
 'Ġdon',
 "'t",
 'Ġfeel',
 'Ġvery',
 'Ġhungry',
 'Ġduring',
 'Ġthe',
 'Ġday',
 '.',
 'ĠIs',
 'Ġit',
 'Ġokay',
 'Ġto',
 'Ġskip',
 'Ġlunch',
 'Ġif',
 'ĠI',
 "'m",
 'Ġsed',
 'ent',
 'ary',
 '?']

In [19]:
# Build the complete example (prompt, question, choices and gold answer) and
# show how it tokenizes.
system = ex["system"]
user = ex["user"]
q = ex["decoder_question"]
choices = ex["choices"]
ans = ex["correct_label"]
formatted_choices = format_answer_choices(choices)

test = "".join([
    SYSTEM_PREFIX.format(system=system),
    user,
    "\n",
    q,
    "\n",
    formatted_choices,
    "\nAnswer: ",
    ans,
])
test_ids = tokenizer(test, return_tensors=None, add_special_tokens=False)["input_ids"]
tokenizer.convert_ids_to_tokens(test_ids)

['<|begin_of_text|>',
 '<|start_header_id|>',
 'system',
 '<|end_header_id|>',
 'ĊĊĊ',
 'Profile',
 'Ġmetadata',
 ':',
 'Ġhousing',
 '_s',
 'it',
 'uation',
 'Ġ=',
 'Ġrenting',
 '.ĊĊ',
 '<|eot_id|>',
 '<|start_header_id|>',
 'user',
 '<|end_header_id|>',
 'ĊĊ',
 'My',
 'Ġassigned',
 'Ġspot',
 'Ġis',
 'Ġbeing',
 'Ġused',
 'Ġby',
 'Ġsomeone',
 'Ġelse',
 '.',
 'ĠWho',
 'Ġhandles',
 'Ġparking',
 'Ġdisputes',
 'Ġin',
 'Ġthis',
 'Ġcomplex',
 '?Ċ',
 'Is',
 'Ġthe',
 'Ġassistant',
 'Ġassuming',
 'Ġa',
 'Ġparticular',
 'Ġvalue',
 'Ġof',
 'Ġthe',
 'Ġuser',
 'âĢĻs',
 'Ġhousing',
 '_s',
 'it',
 'uation',
 '?Ċ',
 'A',
 '.',
 'Ġwith',
 '_family',
 ';',
 'B',
 '.',
 'Ġrenting',
 ';',
 'C',
 '.',
 'Ġowning',
 ';',
 'D',
 '.',
 'ĠNone',
 'Ġof',
 'Ġthe',
 'Ġabove',
 'Ċ',
 'Answer',
 ':',
 'ĠB']

In [20]:
# Check the decoder-side input: 16 " X" placeholders followed by the question,
# the choices and the answer, printing each token next to its id.
dummy = 16 * " X"
decoder_text = "".join([dummy, q, "\n", formatted_choices, "\nAnswer: ", ans])
ids2 = tokenizer(decoder_text, return_tensors=None, add_special_tokens=False)["input_ids"]
tokens2 = tokenizer.convert_ids_to_tokens(ids2)
for tok_str, tok_id in zip(tokens2, ids2):
    print(tok_str, "    ", tok_id)

ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
ĠX      1630
Is      3957
Ġthe      279
Ġassistant      18328
Ġassuming      26619
Ġa      264
Ġparticular      4040
Ġvalue      907
Ġof      315
Ġthe      279
Ġuser      1217
âĢĻs      753
Ġhousing      11983
_s      646
it      275
uation      4090
?Ċ      5380
A      32
.      13
Ġwith      449
_family      27921
;      26
B      33
.      13
Ġrenting      53327
;      26
C      34
.      13
Ġowning      41377
;      26
D      35
.      13
ĠNone      2290
Ġof      315
Ġthe      279
Ġabove      3485
Ċ      198
Answer      16533
:      25
ĠB      426


In [16]:
def filter_qa_for_length(tokenizer, rpath, wpath, min_l=16):
    """Copy a JSONL file, keeping only examples whose user text is long enough.

    An example is kept when its ``"user"`` field tokenizes (without special
    tokens) to at least ``min_l`` ids. The number of kept examples is printed.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry
            when called as ``tokenizer(text, add_special_tokens=False)``.
        rpath: Path of the JSONL file to read.
        wpath: Path of the JSONL file to write.
        min_l: Minimum number of token ids required to keep an example.
    """
    good = []
    # Read and write explicitly as UTF-8: records are written with
    # ensure_ascii=False, so the platform default encoding is not safe.
    with open(rpath, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            ex = json.loads(line)
            ids = tokenizer(ex["user"], add_special_tokens=False)["input_ids"]
            if len(ids) >= min_l:
                good.append(ex)

    with open(wpath, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex, ensure_ascii=False) + "\n" for ex in good)

    print("filtered len: ", len(good))

In [17]:
# filter_qa_for_length(tokenizer, "synthsys_val.jsonl", "synthsys_val_filtered.jsonl")

In [18]:
def jsonl_to_list(path):
    """Read a JSONL file into a list, one parsed JSON value per non-blank line.

    Args:
        path: Path of the JSONL file to read.

    Returns:
        list: Parsed objects in file order.
    """
    # Decode explicitly as UTF-8 (the data holds non-ASCII characters) and
    # skip blank lines, which json.loads would reject.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

In [19]:
class MCQJsonlDataset(Dataset):
    """Multiple-choice examples from a JSONL file, tokenized in two parts.

    Each item is a pair of 1-D ``torch.long`` tensors:

    * ``pre_q_ids``: the system prefix (formatted with the example's system
      text), the user text and a trailing newline.
    * ``post_q_ids``: ``dummy_token_len`` repetitions of ``" X"`` followed by
      the decoder question, the formatted choices and ``"\\nAnswer: <label>"``.

    Neither part has special tokens added by the tokenizer.
    """

    def __init__(self, tokenizer, jsonl_path, system_prefix, dummy_token_len=16):
        """Store the tokenizer and settings and load all records from ``jsonl_path``.

        Args:
            tokenizer: Tokenizer called as
                ``tokenizer(text, return_tensors=None, add_special_tokens=False)``.
            jsonl_path: Path of the JSONL file holding the examples.
            system_prefix: Format string with a ``{system}`` placeholder.
            dummy_token_len: Number of ``" X"`` placeholders prepended to the
                decoder-side text.
        """
        self.tokenizer = tokenizer
        self.data = jsonl_to_list(jsonl_path)
        self.system_prefix = system_prefix
        self.dummy_token_len = dummy_token_len

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(pre_q_ids, post_q_ids)`` for the example at ``idx``."""
        ex = self.data[idx]
        system, user, q = ex["system"], ex["user"], ex["decoder_question"]
        choices, ans = ex["choices"], ex["correct_label"]
        formatted_choices = format_answer_choices(choices)

        pre_q = self.system_prefix.format(system=system) + user + "\n"
        post_q = q + "\n" + formatted_choices + "\nAnswer: " + ans
        # Use the instance setting: relying on a same-named notebook global
        # breaks on a fresh kernel and ignores the constructor argument.
        # NOTE(review): assumes each " X" maps to exactly one token for this
        # tokenizer — confirm if the tokenizer changes.
        dummy = " X" * self.dummy_token_len

        pre_q_ids = self.tokenizer(pre_q, return_tensors=None, add_special_tokens=False)["input_ids"]
        post_q_ids = self.tokenizer(dummy + post_q, return_tensors=None, add_special_tokens=False)["input_ids"]

        return torch.tensor(pre_q_ids, dtype=torch.long), torch.tensor(post_q_ids, dtype=torch.long)

In [20]:
def collate_mcq(batch, tokenizer):
    """Pad a batch of ``(pre_ids, post_ids)`` pairs and build the labels.

    The "pre" sequences are left-padded and the "post" sequences are
    right-padded, both via ``tokenizer.pad``. Labels are ``-100`` everywhere
    except at the last attended position of each "post" sequence, where the
    label is the token id found at that position.

    Args:
        batch: Sequence of ``(pre_ids, post_ids)`` pairs of 1-D tensors.
        tokenizer: Tokenizer providing ``pad``; if it has no pad token id, its
            EOS token is installed as the pad token.

    Returns:
        dict: ``pre_input_ids``, ``pre_attention_mask``, ``post_input_ids``,
        ``post_attention_mask`` and ``labels`` tensors.
    """
    pre_list = [item[0].tolist() for item in batch]
    post_list = [item[1].tolist() for item in batch]

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The tokenizer is shared with the rest of the notebook, so put its
    # padding side back once both halves are padded.
    original_side = tokenizer.padding_side
    try:
        tokenizer.padding_side = "left"
        pre = tokenizer.pad({"input_ids": pre_list}, padding=True, return_tensors="pt")

        tokenizer.padding_side = "right"
        post = tokenizer.pad({"input_ids": post_list}, padding=True, return_tensors="pt")
    finally:
        tokenizer.padding_side = original_side

    post_ids = post["input_ids"]
    post_mask = post["attention_mask"]

    # Supervise only the final real token of each right-padded sequence.
    labels = torch.full_like(post_ids, -100)
    last_pos = post_mask.sum(dim=1) - 1
    rows = torch.arange(post_ids.size(0))
    labels[rows, last_pos] = post_ids[rows, last_pos]

    return {
        "pre_input_ids": pre["input_ids"],
        "pre_attention_mask": pre["attention_mask"],
        "post_input_ids": post_ids,
        "post_attention_mask": post_mask,
        "labels": labels,
    }

In [21]:
def get_resid_stream_vector_efficient(model, input_ids, layer, start, end, attention_mask=None):
    """Run ``model`` once and return a slice of one layer's output.

    A temporary forward hook on ``model.model.layers[layer]`` captures
    positions ``start:end`` (second dimension) of that layer's output. The
    forward pass runs under ``torch.inference_mode`` and the hook is always
    removed afterwards.

    Args:
        model: Module exposing ``model.model.layers`` and accepting
            ``input_ids`` / ``attention_mask`` keyword arguments.
        input_ids: Token ids passed to the model.
        layer: Index into ``model.model.layers`` of the layer to read.
        start: First position (inclusive) of the captured slice.
        end: Last position (exclusive) of the captured slice.
        attention_mask: Optional attention mask passed to the model.

    Returns:
        Tensor: The detached slice ``output[:, start:end, :]``.
    """
    captured = {}

    def hook(module, inp, out):
        # Depending on the library version a decoder layer may return either
        # the hidden states directly or a tuple whose first item is the hidden
        # states; accept both.
        hidden = out[0] if isinstance(out, (tuple, list)) else out
        captured["slice"] = hidden[:, start:end, :].detach()

    h = model.model.layers[layer].register_forward_hook(hook)
    try:
        with torch.inference_mode():
            model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
                output_hidden_states=False,
                return_dict=False
            )
        return captured["slice"]
    finally:
        # Never leave the hook attached, even if the forward pass fails.
        h.remove()


# Shared slot holding the vectors the patch hook writes into the hooked
# module's output; it must be filled before that module runs.
patch_state = {"vecs": None}


def patch_resid_stream_hook(idx):
    """Build a forward hook that overwrites positions ``idx`` of a module's output.

    The returned hook clones the output, replaces ``[:, idx, :]`` with
    ``patch_state["vecs"]`` (cast to the output's dtype) and returns the
    patched copy; the original output tensor is left untouched.
    """
    def _overwrite_positions(module, inp, out):
        patched = out.clone()
        replacement = patch_state["vecs"].to(patched.dtype)
        patched[:, idx, :] = replacement
        return patched

    return _overwrite_positions

In [22]:
# Build the train/validation datasets and their loaders. Both loaders share
# the same batching settings; only the training loader shuffles.
train_jsonl_path = "synthsys_train_filtered.jsonl"
val_jsonl_path = "synthsys_val_filtered.jsonl"

pcd_train_ds = MCQJsonlDataset(tokenizer, train_jsonl_path, SYSTEM_PREFIX)
pcd_val_ds = MCQJsonlDataset(tokenizer, val_jsonl_path, SYSTEM_PREFIX)


def _collate_with_tokenizer(batch):
    """Collate a batch with the notebook-level tokenizer."""
    return collate_mcq(batch, tokenizer)


_loader_settings = {
    "batch_size": 64,
    "num_workers": 0,
    "pin_memory": True,
    "collate_fn": _collate_with_tokenizer,
}

train_loader = DataLoader(pcd_train_ds, shuffle=True, **_loader_settings)
val_loader = DataLoader(pcd_val_ds, shuffle=False, **_loader_settings)

In [41]:
def build_decoder_with_lora(
    device,
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=("q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"),
):
    """Load a causal LM on ``device`` and wrap it with LoRA adapters.

    The defaults reproduce the previously hard-coded configuration (rank 16,
    alpha 32, dropout 0.05, adapters on all attention and MLP projections).

    Args:
        device: Device the base model and the wrapped model are moved to.
        model_name: Model identifier passed to ``from_pretrained``.
        r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: Dropout applied in the LoRA layers.
        target_modules: Names of the modules that receive adapters.

    Returns:
        The PEFT-wrapped model, on ``device``.
    """
    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map=None,
        attn_implementation="sdpa",
    ).to(device)

    lora_cfg = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=list(target_modules),
    )
    dec = get_peft_model(base, lora_cfg).to(device)
    return dec

In [42]:
from pcd_pretrain import Encoder

In [43]:
def load_enc_dec(ckpt_path, device, d_model_multiplier=8, top_k=16):
    """Rebuild the encoder/decoder pair and load their weights from a checkpoint.

    The checkpoint must contain ``"encoder"`` and ``"decoder"`` state dicts;
    both are loaded with ``strict=True``.

    Args:
        ckpt_path: Path of the checkpoint file.
        device: Device both modules are placed on.
        d_model_multiplier: ``multiplier`` argument passed to ``Encoder``.
        top_k: ``top_k`` argument passed to ``Encoder``.

    Returns:
        tuple: ``(encoder, decoder)`` with the encoder in eval mode and the
        decoder in train mode, both on ``device``.
    """
    # Deserialize on the CPU: load_state_dict copies values into the modules'
    # own (device) parameters, so loading the checkpoint straight onto the
    # accelerator would only hold a second copy of the weights there.
    # NOTE(review): torch.load unpickles the file — only load checkpoints from
    # a trusted source (consider weights_only=True if the file allows it).
    ckpt = torch.load(ckpt_path, map_location="cpu")

    decoder = build_decoder_with_lora(device=device)

    d_model = decoder.config.hidden_size
    encoder = Encoder(d_in=d_model, multiplier=d_model_multiplier, top_k=top_k).to(device)

    encoder.load_state_dict(ckpt["encoder"], strict=True)
    decoder.load_state_dict(ckpt["decoder"], strict=True)

    encoder = encoder.to(device).eval()
    decoder = decoder.to(device).train()

    return encoder, decoder

In [None]:
# Restore the pretrained encoder/decoder pair, freeze the encoder, and set up
# the optimizer over the decoder's parameters.
load_path = "pcd_3B_layer15_all-lora-modules_rank16.pt"
encoder, decoder = load_enc_dec(load_path, device)

optim = torch.optim.AdamW(
    list(decoder.parameters()),
    weight_decay=0.01,
    lr=1e-4,
)

# The encoder is only used to produce the patched vectors; it is not trained.
for p in encoder.parameters():
    p.requires_grad = False

# Layer of the subject model that is read, and the placeholder positions in
# the decoder input that are overwritten.
read_layer = 15
dummy_token_len = 16
patch_idx = torch.arange(dummy_token_len, device=device)

In [53]:
def train_step(batch, layer, dummy_token_len=16):
    """Compute the decoder loss for one collated batch.

    The subject model is run on the "pre" half of the batch, the output of
    ``layer`` at the last ``dummy_token_len`` positions is encoded, and the
    result is stored in ``patch_state["vecs"]`` for the patch hook. The
    decoder is then run on the "post" half with the batch labels.

    Returns:
        The ``loss`` attribute of the decoder output.
    """
    def _on_device(key):
        return batch[key].to(device, non_blocking=True)

    pre_ids = _on_device("pre_input_ids")
    pre_mask = _on_device("pre_attention_mask")

    # Read the trailing window of the (left-padded) pre sequence.
    seq_len = pre_ids.size(1)
    window_start = seq_len - dummy_token_len

    with torch.no_grad():
        resid = get_resid_stream_vector_efficient(
            subject_model,
            pre_ids,
            layer,
            window_start,
            seq_len,
            attention_mask=pre_mask,
        )
        encoder_out, _ = encoder(resid.float())

    # Hand the encoded vectors to the patch hook before the decoder runs.
    patch_state["vecs"] = encoder_out

    decoder_out = decoder(
        input_ids=_on_device("post_input_ids"),
        attention_mask=_on_device("post_attention_mask"),
        labels=_on_device("labels"),
        use_cache=False,
    )
    return decoder_out.loss

In [None]:
# Training schedule.
num_epochs = 5
val_every_n_steps = 20

# Early stopping: stop after `val_patience` consecutive validations without
# improvement over `best_val`.
val_patience = 3
best_val = float("inf")
curr_bad = 0
stop_training = False

# Patch the decoder's token embeddings at the placeholder positions; the
# handle is kept so the hook can be removed after training.
embed_tokens = decoder.base_model.model.model.embed_tokens
handle = embed_tokens.register_forward_hook(patch_resid_stream_hook(patch_idx))


def _mean_val_loss():
    """Return the mean per-batch loss over ``val_loader`` with the decoder in eval mode."""
    decoder.eval()
    try:
        total = 0.0
        with torch.no_grad():
            for val_batch in tqdm(val_loader, desc="val", leave=False):
                total += train_step(val_batch, read_layer, dummy_token_len).item()
    finally:
        # Always return to train mode, even if validation is interrupted.
        decoder.train()
    return total / len(val_loader)


def do_train_full(save_path="decoder_finetuned.pt"):
    """Fine-tune the decoder with periodic validation and early stopping.

    Runs up to ``num_epochs`` epochs over ``train_loader``. Every
    ``val_every_n_steps`` steps (counted within the epoch) the mean validation
    loss is computed; on improvement over ``best_val`` a checkpoint with the
    encoder, decoder and optimizer state is written to ``save_path``,
    otherwise the bad-validation counter grows and training stops once it
    reaches ``val_patience``.

    Args:
        save_path: Where the best checkpoint is written.
    """
    global best_val, curr_bad, stop_training

    # Start every run with fresh early-stopping state; otherwise a call made
    # after an earlier early stop would quit after a single step. `best_val`
    # is deliberately kept so a re-run never overwrites a better checkpoint.
    curr_bad = 0
    stop_training = False

    for epoch in range(num_epochs):
        # enumerate() hides the loader length, so pass the total explicitly
        # to get a real progress bar.
        pbar = tqdm(
            enumerate(train_loader, start=1),
            total=len(train_loader),
            desc=f"epoch {epoch+1}/{num_epochs}",
        )
        for step, train_batch in pbar:
            loss = train_step(train_batch, read_layer, dummy_token_len)

            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            pbar.set_postfix(loss=float(loss.item()))

            if step % val_every_n_steps != 0:
                continue

            val_mean = _mean_val_loss()

            if val_mean < best_val:
                best_val = val_mean
                curr_bad = 0
                torch.save(
                    {
                        "encoder": encoder.state_dict(),
                        "decoder": decoder.state_dict(),
                        "optim": optim.state_dict(),
                        "epoch": epoch,
                        "step": step,
                        "best_val": best_val,
                    },
                    save_path,
                )
            else:
                curr_bad += 1
                if curr_bad >= val_patience:
                    stop_training = True
                    break

        if stop_training:
            break

In [None]:
# Always detach the embedding patch hook, even if training raises or is
# interrupted; otherwise re-running the setup cell would stack a second hook.
try:
    do_train_full("decoder_finetuned.pt")
finally:
    handle.remove()