## Tutorial on Boundless DAS

In [None]:
# Notebook metadata. NOTE(review): __version__ holds a date string rather than
# a semantic version — presumably the last-revision date; confirm the format
# (MM/DD vs DD/MM) with the author.
__author__ = "Zhengxuan Wu"
__version__ = "10/04/2023"

### Overview

This tutorial reproduces one key result of [the Boundless DAS paper](https://arxiv.org/pdf/2305.08809). It uses the same pricing-tag dataset as the paper, and it focuses on finding an alignment for the left (i.e., lower-bound) boundary check only.

In [None]:
import torch
from tqdm import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader

from models.configuration_alignable_model import AlignableRepresentationConfig, AlignableConfig
from models.alignable_base import AlignableModel
from models.llama.modelings_alignable_llama import create_llama
from tutorial_utils import factual_sampler, bound_alignment_sampler, \
                           lower_bound_alignment_example_sampler

### Set-up

In [None]:
# Build the model through the project helper, which returns three objects
# unpacked here as (config, tokenizer, model).
# NOTE(review): the section heading below refers to an instruct-tuned LLaMA-7B,
# so presumably that is the checkpoint loaded by default — confirm in
# create_llama.
config, tokenizer, llama = create_llama()

### Factual performance of instruct-tuned LLaMA-7B

In [None]:
# Sample 5000 factual examples for the "pricing_tag" game. The sampler's
# result is indexed positionally: element 0 is used as the input ids and
# element 1 as the labels.
raw_prealign = factual_sampler(tokenizer, 5000, game="pricing_tag")

# Wrap the sampled columns in a Hugging Face Dataset that yields torch tensors.
prealign_dataset = Dataset.from_dict(
    dict(input_ids=raw_prealign[0], labels=raw_prealign[1])
)
prealign_dataset = prealign_dataset.with_format("torch")

# Batches of 8, in dataset order (no shuffling is requested).
prealign_dataloader = DataLoader(prealign_dataset, batch_size=8)

In [None]:
# Measure the un-intervened model's accuracy on the factual task: for every
# example, compare the arg-max prediction at the final sequence position with
# the label stored at the final position.
total_count = 0
correct_count = 0
llama.eval()
with torch.no_grad():
    for inputs in tqdm(prealign_dataloader):
        # Move every tensor in the batch onto the model's device; any
        # non-tensor entries are passed through unchanged.
        inputs = {
            k: v.to(llama.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        # Plain forward pass (no intervention is applied here).
        # NOTE(review): `labels` is forwarded as in the original call, but only
        # `outputs.logits` is read below — confirm whether it is still needed.
        outputs = llama(
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        )

        # Only the last position is scored.
        actual_test_labels = inputs["labels"][:, -1]
        pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)

        correct_labels = actual_test_labels == pred_test_labels

        total_count += len(correct_labels)
        # .item() extracts the scalar count as a Python number.
        correct_count += correct_labels.sum().item()

# Raises ZeroDivisionError if the dataloader yielded no batches.
current_acc = round(correct_count / total_count, 2)
print(f"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}")

### Create training dataset for our trainable intervention (Boundless DAS)