In [42]:
import sys
sys.path.append("../pyvene/")

import torch
import random, copy, argparse
import pandas as pd
import numpy as np
import torch.nn.functional as F
import seaborn as sns
from tqdm import tqdm, trange
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer
from transformers.activations import ACT2FN
from transformers import default_data_collator
from transformers import get_linear_schedule_with_warmup
from torch.nn import CrossEntropyLoss
import wandb
from sklearn.metrics import classification_report

from pyvene import (
    IntervenableModel,
    LowRankRotatedSpaceIntervention,
    RepresentationConfig,
    IntervenableConfig,
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)
from pyvene import create_llama
from pyvene import set_seed, count_parameters
from pyvene.models.layers import LowRankRotateLayer

import io
import json
import os
import re

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f

def jload(f, mode="r"):
    """Load a .json file into a dictionary.

    Args:
        f: A path, or an already-open file-like object (``io.IOBase``).
        mode: Mode used to open ``f`` when it is a path.

    Returns:
        The decoded JSON value (a dict for the usual object-at-top-level file).

    The handle is closed afterwards, including one passed in by the caller.
    """
    f = _make_r_io_base(f, mode)
    try:
        return json.load(f)
    finally:
        # Close even when decoding fails so the handle is never leaked.
        f.close()

def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written. ``dict``/``list`` values are JSON-encoded;
            ``str`` values are written verbatim.
        f: A string path to the location on disk, or an open file-like object
            (``io.IOBase``). Missing parent directories of a path are created.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.

    Raises:
        ValueError: If ``obj`` is not a dict, list or str. This is checked
            before the target is opened, so a bad call no longer creates or
            truncates the file and leaks its handle.

    The handle is closed afterwards, including one passed in by the caller.
    """
    if not isinstance(obj, (dict, list, str)):
        raise ValueError(f"Unexpected type: {type(obj)}")
    f = _make_w_io_base(f, mode)
    try:
        if isinstance(obj, str):
            f.write(obj)
        else:
            json.dump(obj, f, indent=indent, default=default)
    finally:
        # Always release the handle, even if serialization fails midway.
        f.close()
        
# Fall back to CPU when no GPU is visible so the notebook still runs
# (slowly) on machines without CUDA instead of failing at `set_device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

class LearnedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Overwrite a learned low-rank subspace of the hidden state with a learned vector.

    ``rotate_layer`` is a ``LowRankRotateLayer`` (embed_dim -> low_rank_dimension)
    under PyTorch's orthogonal parametrization, and ``learned_source`` is a free
    trainable vector of size ``low_rank_dimension``. ``forward`` returns
    ``dropout(base + (learned_source - R(base)) @ R.weight.T)`` cast back to
    ``base.dtype``. Dropout is applied to the whole output, not only the edit.

    Required kwargs: ``low_rank_dimension``, plus ``embed_dim``, which is
    presumably stored as ``self.embed_dim`` by the pyvene base classes.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rank = kwargs["low_rank_dimension"]
        # Keep this statement order: torch.rand draws from the global RNG (and
        # LowRankRotateLayer presumably does too), so reordering changes init.
        plain_rotation = LowRankRotateLayer(self.embed_dim, rank)
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(plain_rotation)
        self.learned_source = torch.nn.Parameter(torch.rand(rank), requires_grad=True)
        self.dropout = torch.nn.Dropout(0.05)

    def forward(self, base, source=None, subspaces=None):
        # ``source`` / ``subspaces`` are accepted for interface compatibility; unused.
        rotated = self.rotate_layer(base)
        delta = torch.matmul(self.learned_source - rotated, self.rotate_layer.weight.T)
        return self.dropout((base + delta).to(base.dtype))

class ConditionedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Low-rank rotated-subspace edit whose target values are predicted from ``base``.

    ``rotate_layer`` is a ``LowRankRotateLayer`` (embed_dim -> low_rank_dimension)
    under PyTorch's orthogonal parametrization. ``learned_source`` is a Linear
    ``embed_dim -> low_rank_dimension`` followed by ``tanh``. ``forward`` returns
    ``dropout(base + (tanh(learned_source(base)) - R(base)) @ R.weight.T)``
    cast back to ``base.dtype``. Dropout covers the whole output.

    Required kwargs: ``low_rank_dimension``, plus ``embed_dim``, which is
    presumably stored as ``self.embed_dim`` by the pyvene base classes.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
        # Stored in bfloat16 to match the bfloat16 model loaded in this notebook;
        # ``forward`` casts its input to this dtype so other dtypes still work.
        self.learned_source = torch.nn.Linear(
            self.embed_dim, kwargs["low_rank_dimension"]).to(torch.bfloat16)
        self.act_fn = ACT2FN["tanh"]
        self.dropout = torch.nn.Dropout(0.05)

    def forward(
        self, base, source=None, subspaces=None
    ):
        # ``source`` / ``subspaces`` are accepted for interface compatibility; unused.
        rotated_base = self.rotate_layer(base)
        # ``learned_source`` is hard-wired to bfloat16. Without this cast, float32 or
        # float16 hidden states raise a dtype-mismatch error inside the Linear.
        # For bfloat16 inputs ``.to`` is a no-op.
        source_input = base.to(self.learned_source.weight.dtype)
        conditioned_source = self.act_fn(self.learned_source(source_input))
        output = base + torch.matmul(
            conditioned_source - rotated_base, self.rotate_layer.weight.T
        )
        return self.dropout(output.to(base.dtype))
    
class ConditionedSourceLowRankIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Low-rank edit with an unconstrained projection and a ``base``-conditioned target.

    ``proj_layer`` is a bias-free Linear ``embed_dim -> low_rank_dimension``. Unlike
    the rotated variants it has no orthogonality constraint, so ``weight`` used as
    the back-projection is not guaranteed to invert it. ``learned_source`` is a
    Linear ``embed_dim -> low_rank_dimension`` followed by ``tanh``. ``forward``
    returns ``dropout(base + (tanh(learned_source(base)) - P(base)) @ P.weight)``
    cast back to ``base.dtype``.

    Required kwargs: ``low_rank_dimension``, plus ``embed_dim``, which is
    presumably stored as ``self.embed_dim`` by the pyvene base classes.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Both layers are stored in bfloat16 to match the bfloat16 model loaded in
        # this notebook; ``forward`` casts its input so other dtypes still work.
        self.proj_layer = torch.nn.Linear(
            self.embed_dim, kwargs["low_rank_dimension"], bias=False).to(torch.bfloat16)
        self.learned_source = torch.nn.Linear(
            self.embed_dim, kwargs["low_rank_dimension"]).to(torch.bfloat16)
        self.act_fn = ACT2FN["tanh"]
        self.dropout = torch.nn.Dropout(0.05)

    def forward(
        self, base, source=None, subspaces=None
    ):
        # ``source`` / ``subspaces`` are accepted for interface compatibility; unused.
        # Cast to the layers' (bfloat16) dtype. Without it, float32 or float16
        # hidden states raise a dtype mismatch. No-op for bfloat16 inputs.
        layer_input = base.to(self.proj_layer.weight.dtype)
        proj_base = self.proj_layer(layer_input)
        conditioned_source = self.act_fn(self.learned_source(layer_input))
        # (..., rank) @ (rank, embed_dim) maps the low-rank delta back to hidden size.
        output = base + torch.matmul(
            conditioned_source - proj_base, self.proj_layer.weight
        )
        return self.dropout(output.to(base.dtype))

In [43]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig

# Seed before loading. The classification head is freshly (randomly) initialised,
# as the load warning reports, so without a seed every run starts from a different head.
set_seed(42)

MODEL_NAME = "FacebookAI/roberta-base"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Named ``model_config`` rather than ``config`` so the IntervenableConfig built
# later in the notebook does not silently overwrite it.
model_config = AutoConfig.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    finetuning_task="sst2",
)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    config=model_config,
    torch_dtype=torch.bfloat16,
)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [44]:
def _build_sst2_split(split, tokenizer, max_length=512):
    """Turn one GLUE/SST-2 split into a tokenized ``Dataset`` for intervention training.

    Args:
        split: A split of ``load_dataset("glue", "sst2")`` with ``sentence`` and
            ``label`` columns.
        tokenizer: Tokenizer producing ``input_ids``. No padding is applied here;
            the collator pads per batch.
        max_length: Truncation length in tokens.

    Returns:
        A ``Dataset`` with ``input_ids``, ``intervention_position`` (always ``[0]``)
        and ``labels`` columns.
    """
    # One batched call yields the same ids as tokenizing sentence by sentence,
    # but is much faster. ``list(...)`` keeps this working if column access
    # returns a lazy column object instead of a plain list.
    input_ids = tokenizer(
        list(split["sentence"]), max_length=max_length, truncation=True
    )["input_ids"]
    return Dataset.from_dict(
        {
            "input_ids": input_ids,
            # Position 0 is the tokenizer's leading special token (<s> for RoBERTa),
            # not the first word of the sentence.
            "intervention_position": [[0] for _ in input_ids],
            "labels": list(split["label"]),
        }
    )


raw_datasets = load_dataset("glue", "sst2")
train_dataset = _build_sst2_split(raw_datasets["train"], tokenizer)
# Evaluation uses the labelled validation split.
test_dataset = _build_sst2_split(raw_datasets["validation"], tokenizer)

num_labels = 2
# Pads each batch to its longest sequence and supplies the ``attention_mask``
# consumed by the training / evaluation loops.
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding="longest"
)

In [56]:
# --- intervention hyper-parameters ---
layers = [5, 11]  # RoBERTa encoder layers whose output is intervened on
rank = 2
initial_lr = 6e-3
epochs = 3
batch_size = 32
gradient_accumulation_steps = 1

# Seed intervention initialisation and training-data shuffling.
set_seed(42)

# Named ``intervention_config`` so it does not shadow the HF model config.
intervention_config = IntervenableConfig([
    {
        "component": f"roberta.encoder.layer[{l}].output",
        "intervention": ConditionedSourceLowRankRotatedSpaceIntervention(
            embed_dim=model.config.hidden_size, low_rank_dimension=rank),
    }
    for l in layers
])
intervenable = IntervenableModel(intervention_config, model)
intervenable.set_device(device)
# Freezes every RoBERTa weight. NOTE(review): this includes the randomly
# initialised classifier head ("roberta trainable parameters: 0"), so only the
# interventions are trained and must adapt to a random head. Unfreeze
# ``classifier`` if that is not intended.
intervenable.disable_model_gradients()
optimizer = torch.optim.AdamW(
    intervenable.get_trainable_parameters(), lr=initial_lr,
    weight_decay=0.0,  # was ``False``; AdamW expects a float
)
# train() enables dropout; the frozen base model still receives no gradients.
intervenable.model.train()
print("roberta trainable parameters: ", count_parameters(intervenable.model))
print("intervention trainable parameters: ", intervenable.count_parameters())

total_step = 0
train_dataloader = DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size, collate_fn=data_collator)
# Evaluation needs no shuffling; a fixed order keeps evaluation runs comparable.
test_dataloader = DataLoader(
    test_dataset, shuffle=False, batch_size=batch_size, collate_fn=data_collator)

t_total = int(len(train_dataloader) * epochs) // gradient_accumulation_steps
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(t_total * 0.1),  # linear warm-up over the first 10% of steps
    num_training_steps=t_total,
)

roberta trainable parameters:  0
intervention trainable parameters:  6148


In [57]:
# Architecture name recorded on the loaded config. NOTE: it reports the
# checkpoint's original architecture, not the sequence-classification class loaded above.
arch_name = model.config.architectures[0]
arch_name.lower()

'robertaformaskedlm'

In [None]:
def _to_device(batch):
    """Move every tensor in a collated batch to ``device`` (in place) and return it."""
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
    return batch


def _intervened_logits(batch):
    """Run the intervenable model, intervening at each example's stored position in every layer."""
    base_unit_location = batch["intervention_position"].tolist()
    _, cf_outputs = intervenable(
        {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]},
        unit_locations={"sources->base": (None, [base_unit_location] * len(layers))},
    )
    return cf_outputs.logits


def _set_mode(training):
    """Put the base model and every intervention in train (True) or eval (False) mode."""
    intervenable.model.train(training)
    for intervention_entry in intervenable.interventions.values():
        _ = intervention_entry[0].train(training)


def _evaluate(dataloader):
    """Print a classification report of the intervened model over ``dataloader``."""
    _set_mode(False)
    all_preds, all_labels = [], []
    # Inference only: no_grad avoids building an autograd graph for every eval batch.
    with torch.no_grad():
        for batch in dataloader:
            batch = _to_device(batch)
            preds = _intervened_logits(batch).argmax(dim=-1)
            all_preds += preds.tolist()
            all_labels += batch["labels"].tolist()
    print(classification_report(all_labels, all_preds, digits=3))


loss_fct = CrossEntropyLoss()
train_iterator = trange(0, int(epochs), desc="Epoch")
for epoch in train_iterator:
    epoch_iterator = tqdm(
        train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
    )
    _set_mode(True)  # enables dropout; the base model stays frozen

    for inputs in epoch_iterator:
        inputs = _to_device(inputs)
        labels = inputs["labels"]
        logits = _intervened_logits(inputs)

        # Sentence-classification cross-entropy, computed in float32 because the
        # model runs in bfloat16.
        loss = loss_fct(logits.float().view(-1, num_labels), labels.view(-1))
        acc = (logits.argmax(dim=-1) == labels).float().mean().item()
        epoch_iterator.set_postfix({"loss": round(loss.item(), 1), "acc": round(acc, 1)})

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()
        # Step once every ``gradient_accumulation_steps`` micro-batches. The old
        # ``total_step % gas == 0`` check, which skipped step 0, made the first
        # window gas + 1 batches long. Identical behaviour when gas == 1.
        if (total_step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        total_step += 1

    _evaluate(test_dataloader)

Epoch: 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2105/2105 [00:42<00:00, 49.47it/s, loss=0.4, acc=0.9]
Epoch:  33%|███████████████████████████████████████████████████████████████████████████████████▎                                                                                                                                                                      | 1/3 [00:42<01:25, 42.85s/it]

              precision    recall  f1-score   support

           0      0.921     0.869     0.894       428
           1      0.880     0.928     0.904       444

    accuracy                          0.899       872
   macro avg      0.901     0.899     0.899       872
weighted avg      0.900     0.899     0.899       872



Epoch: 1:  78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 1636/2105 [00:33<00:09, 49.64it/s, loss=0.2, acc=0.9]

In [10]:
# Standalone evaluation of the intervened model on the validation split.
# NOTE(review): this cell's execution count (In [10]) is lower than the training
# cells', so its saved output may come from a different kernel session. Re-run it
# after training.

# ensure everything is in eval mode (disables dropout in model and interventions)
intervenable.model.eval()
for _, intervention_entry in intervenable.interventions.items():
    _ = intervention_entry[0].eval()

all_preds = []
all_labels = []
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for inputs in test_dataloader:
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(device)

        base_unit_location = inputs["intervention_position"].tolist()
        _, cf_outputs = intervenable(
            {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
            unit_locations={"sources->base": (None, [base_unit_location] * len(layers))})

        # Predicted sentiment class under the intervention.
        preds = cf_outputs.logits.argmax(dim=-1)
        all_preds += preds.tolist()
        all_labels += inputs["labels"].tolist()
print(classification_report(all_labels, all_preds, digits=3))

              precision    recall  f1-score   support

           0      0.927     0.895     0.911       428
           1      0.902     0.932     0.917       444

    accuracy                          0.914       872
   macro avg      0.915     0.914     0.914       872
weighted avg      0.914     0.914     0.914       872

