In [None]:
%reload_ext autoreload
%autoreload 2

import random
from collections.abc import Callable

import gqr
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from repe import repe_pipeline_registry

# Must run before `pipeline("rep-reading", ...)` is constructed further down;
# presumably this registers repe's custom pipeline tasks with transformers.
repe_pipeline_registry()

In [None]:
# Single seed for every stochastic step in the notebook. pandas sampling below
# receives it explicitly via `random_state`.
RANDOM_SEED = 42

# Also seed the global RNGs, so library code that does not take an explicit seed
# is reproducible under Restart & Run All. For example, a randomized PCA/SVD
# solver drawing from NumPy's global state, if the direction finder uses one.
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

In [None]:
model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, dtype=torch.float16, device_map="balanced_low_0"
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# Pad on the left so the last real token of every sequence sits at position -1,
# which is the position read later (rep_token=-1).
tokenizer.padding_side = "left"
# Llama-3 tokenizers ship without an unk token. The previous `unk_token`-only
# fallback therefore left `pad_token` as None. Keep preferring unk when a model
# has one, and otherwise fall back to EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token

In [None]:
# Llama-3 chat prompt used to wrap every query (fill in with `.format(query=...)`).
# - `<|begin_of_text|>` is intentionally omitted. The Llama-3 tokenizer prepends
#   it when encoding with special tokens, so keeping it here would yield a doubled
#   BOS. This assumes the rep-reading pipeline tokenizes with the default
#   add_special_tokens=True; TODO confirm for the repe fork in use.
# - Each `<|end_header_id|>` is followed by a blank line ("\n\n"), as in the
#   documented Llama-3 prompt format.
template = """<|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

In [None]:
# Load the GQR splits. `load_train_dataset` returns two values, and only the
# first (training) one is used. The in-domain test split is loaded explicitly
# on the next line, so the second value is discarded rather than bound to
# `test_dataset` and silently overwritten.
train_dataset, _ = gqr.load_train_dataset()
test_dataset = gqr.load_id_test_dataset()
ood_test_dataset = gqr.load_ood_test_dataset()

# Wrap every raw query in the chat prompt template. This adds a "text_template"
# column to each frame in place, so every split reaches the model in the same
# format.
for split_df in (train_dataset, test_dataset, ood_test_dataset):
    split_df["text_template"] = split_df["text"].apply(lambda q: template.format(query=q))

In [None]:
# Keep a reproducible 10% subsample of each evaluation split (fresh 0..n-1 index)
# to bound the cost of per-query inference further down.
ood_test_dataset, test_dataset = (
    split_df.sample(frac=0.1, random_state=RANDOM_SEED).reset_index(drop=True)
    for split_df in (ood_test_dataset, test_dataset)
)

In [None]:
# Positive examples drawn per class when fitting that class's direction.
N_SAMPLES = 128

# Distinct training labels in ascending order; one direction is fitted per entry.
CLASSES = sorted(train_dataset["label"].unique())

# Keyword arguments shared by `get_directions` and the scoring calls.
# hidden_layers: negative indices -1, -2, ..., -(num_hidden_layers - 1).
PIPELINE_KWARGS = dict(
    rep_token=-1,
    hidden_layers=[-depth for depth in range(1, model.config.num_hidden_layers)],
    n_difference=1,
    direction_method="pca",
    direction_finder_kwargs=dict(n_components=1),
    batch_size=4,
)

rep_reading_pipeline = pipeline("rep-reading", model=model, tokenizer=tokenizer)

# Batched reads need a pad token. The pad id is tied to EOS unconditionally, so
# the pipeline's tokenizer always pads with EOS whatever was configured before.
rep_tokenizer = rep_reading_pipeline.tokenizer
if rep_tokenizer.pad_token is None:
    rep_tokenizer.pad_token = rep_tokenizer.eos_token
rep_tokenizer.pad_token_id = rep_tokenizer.eos_token_id

# Fit one one-vs-rest direction per class.
train_data_all = []  # sampled positives of every class, reused for train-set inference
cls_readers = {}  # class label (int) -> reader returned by get_directions
for cls in CLASSES:
    # Positives: N_SAMPLES queries of the current class.
    positive_df = train_dataset[train_dataset["label"] == cls].sample(
        n=N_SAMPLES, random_state=RANDOM_SEED
    )

    # Negatives: N_SAMPLES queries spread as evenly as possible over the other
    # classes. divmod hands out the remainder one extra sample at a time, so the
    # pos/neg counts always match. Plain floor division under-filled the
    # negatives whenever N_SAMPLES was not divisible by their count, which made
    # the strict zip below raise.
    negative_classes = [c for c in CLASSES if c != cls]
    if not negative_classes:
        raise ValueError("Need at least two classes to fit one-vs-rest directions.")
    n_neg_base, n_neg_extra = divmod(N_SAMPLES, len(negative_classes))
    negative_dfs = [
        train_dataset[train_dataset["label"] == neg_cls].sample(
            n=n_neg_base + (1 if i < n_neg_extra else 0), random_state=RANDOM_SEED
        )
        for i, neg_cls in enumerate(negative_classes)
    ]
    # Shuffle so negatives of different classes are mixed across the pairs.
    negative_df = (
        pd.concat(negative_dfs)
        .sample(frac=1, random_state=RANDOM_SEED)
        .reset_index(drop=True)
    )

    # Interleave as [pos_0, neg_0, pos_1, neg_1, ...]. With n_difference=1 the
    # pipeline presumably differences consecutive entries as (pos, neg) pairs.
    pos_texts = positive_df["text_template"].tolist()
    neg_texts = negative_df["text_template"].tolist()
    train_data = [text for pair in zip(pos_texts, neg_texts, strict=True) for text in pair]

    # NOTE(review): no train_labels are passed to get_directions (the label pairs
    # previously built here were never used and have been removed). The sign of
    # each PCA direction is therefore presumably not calibrated to the positive
    # class. Pass labels in the format the reader expects if that is needed.
    cls_readers[int(cls)] = rep_reading_pipeline.get_directions(
        train_data, **PIPELINE_KWARGS
    )
    train_data_all.append(positive_df)
# Persist the first direction component of every class reader, one row per layer.
# Columns: layer_id, then class_<label>_direction for each class in ascending
# label order. For every class, not just labels 0/1/2 as previously hard-coded.
direction_rows = [
    {
        "layer_id": layer,
        **{
            f"class_{cls}_direction": reader.directions[layer][0]
            for cls, reader in sorted(cls_readers.items())
        },
    }
    for layer in PIPELINE_KWARGS["hidden_layers"]
]
pd.DataFrame(direction_rows).to_parquet(f"{N_SAMPLES}_directions.parquet", index=False)

In [None]:
# Stack the per-class positive samples collected while fitting directions.
# Guarded so the cell is idempotent: after the first run `train_data_all` is
# already a DataFrame, and `pd.concat` on a DataFrame raises TypeError.
if isinstance(train_data_all, list):
    train_data_all = pd.concat(train_data_all).reset_index(drop=True)
train_data_all.head()

In [None]:
def get_affinity_scores(
    text : str, pipeline : callable, class_readers : dict, rep_token : str, hidden_layers : list, decision_layer : int =-1
) -> list:
    class_scores = {}
    query = [text]  # Pipeline expects a list of texts
    class_inputs = {}
    for cls, reader in class_readers.items():
        # Get scores for the query against the current class's "reader"
        outputs = pipeline(
            query,
            rep_reader=reader,
            rep_token=rep_token,
            hidden_layers=hidden_layers,
            component_index=0,  # As used in the original snippet
        )
        scores = outputs[0][0]
        inputs = outputs[0][1]

        class_scores[cls] = [scores[(idx+1)*-1][0] for idx, layer in enumerate(hidden_layers)]
        class_inputs[cls] = [inputs[idx][0] for idx, layer in enumerate(hidden_layers)]

    return class_scores, class_inputs

In [None]:
# Score every sampled training positive against all class readers and save one
# row per query. The columns are label, text, then affinity_score_cls_<c> and
# hidden_states_cls_<c> for each class c in ascending order (any number of
# classes, not just three).
train_dat = []
for _, row in tqdm(train_data_all.iterrows(), total=len(train_data_all)):
    scores, inputs = get_affinity_scores(
        text=row["text_template"],
        pipeline=rep_reading_pipeline,
        class_readers=cls_readers,
        rep_token=PIPELINE_KWARGS["rep_token"],
        hidden_layers=PIPELINE_KWARGS["hidden_layers"],
    )
    record = {"label": row["label"], "text": row["text"]}
    for cls in sorted(cls_readers):
        record[f"affinity_score_cls_{cls}"] = scores[cls]
        record[f"hidden_states_cls_{cls}"] = inputs[cls]
    train_dat.append(record)

pd.DataFrame(train_dat).to_parquet(f"{N_SAMPLES}_inference_train.parquet", index=False)

In [None]:
# Score the subsampled OOD and in-domain test queries (OOD rows first) against all
# class readers and save one row per query. The column layout matches the
# training export: label, text, then per-class affinity/hidden-state columns for
# each class in ascending order.
eval_df = pd.concat([ood_test_dataset, test_dataset], ignore_index=True)
validation_data = []
for _, row in tqdm(eval_df.iterrows(), total=len(eval_df)):
    scores, inputs = get_affinity_scores(
        text=row["text_template"],
        pipeline=rep_reading_pipeline,
        class_readers=cls_readers,
        rep_token=PIPELINE_KWARGS["rep_token"],
        hidden_layers=PIPELINE_KWARGS["hidden_layers"],
    )
    record = {"label": row["label"], "text": row["text"]}
    for cls in sorted(cls_readers):
        record[f"affinity_score_cls_{cls}"] = scores[cls]
        record[f"hidden_states_cls_{cls}"] = inputs[cls]
    validation_data.append(record)

pd.DataFrame(validation_data).to_parquet(f"{N_SAMPLES}_inference_test.parquet", index=False)