In [16]:
import sys
sys.path.append("../pyvene/")

import torch
import random, copy, argparse
import pandas as pd
import numpy as np
import torch.nn.functional as F
import seaborn as sns
from tqdm import tqdm, trange
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from transformers import DataCollatorForSeq2Seq
from transformers import AutoTokenizer
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
import wandb

from pyvene import (
    IntervenableModel,
    LowRankRotatedSpaceIntervention,
    RepresentationConfig,
    IntervenableConfig,
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)
from pyvene import create_llama
from pyvene import set_seed, count_parameters
from pyvene.models.layers import LowRankRotateLayer

import io
import json
import os
import re

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f

def jload(f, mode="r"):
    """Load a .json file into a dictionary.

    Args:
        f: Path to the JSON file, or an already-open file object.
        mode: Mode used to open *f* when it is a path.

    Returns:
        The decoded JSON value.

    The handle is always closed afterwards, including when *f* was passed in
    already open (as before). It is also closed when decoding raises.
    """
    f = _make_r_io_base(f, mode)
    try:
        return json.load(f)
    finally:
        # Previously a JSONDecodeError skipped close() and leaked the handle.
        f.close()

def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written: a dict or list (serialized as JSON) or a
            str (written verbatim).
        f: A string path to the location on disk, or an open file object.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.

    Raises:
        ValueError: If *obj* is not a dict, list or str. This is checked before
            the file is opened, so the target is not created or truncated.
    """
    # Validate first. Opening in "w" mode before rejecting the type would wipe
    # the existing file and leak the handle.
    if not isinstance(obj, (dict, list, str)):
        raise ValueError(f"Unexpected type: {type(obj)}")
    f = _make_w_io_base(f, mode)
    try:
        if isinstance(obj, str):
            f.write(obj)
        else:
            json.dump(obj, f, indent=indent, default=default)
    finally:
        f.close()

def create_directory(path):
    """Create directory *path* (and any missing parents) if it does not exist.

    Prints whether the directory was created or already existed. It tries the
    creation directly rather than checking first, so a directory that appears
    concurrently is reported instead of raising FileExistsError.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        print(f"Directory '{path}' already exists.")
    else:
        print(f"Directory '{path}' created successfully.")

def extract_answer_number(sentence: str) -> float:
    """Return the last number that appears in *sentence* as a float.

    Commas are removed before matching, so "1,234" is read as 1234.

    Returns:
        The last matched number, or ``float('inf')`` when *sentence* contains
        no digits.
    """
    numbers = re.findall(r'-?\d+\.?\d*', sentence.replace(',', ''))
    if not numbers:
        return float('inf')
    # Every match of this pattern ("12", "-3", "4.", "5.25") is a valid float
    # literal. The old "is it still a str?" fallback after float() was unreachable.
    return float(numbers[-1])


def extract_answer_letter(sentence: str) -> str:
    """Return the last option letter ``A``-``E`` found in *sentence*, or ``''``.

    Only uppercase letters match, so any capital A-E in the text (not just a
    standalone option label) counts.
    """
    letters = re.findall(r'[A-E]', sentence)
    return letters[-1] if letters else ''
        
# Hard-coded target device for the model and interventions.
device = "cuda"
# Label value excluded from the loss.
IGNORE_INDEX = -100
# Fallback special tokens, added only when the tokenizer lacks them.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
# Alpaca-style instruction prompt; "%s" is replaced by the instruction text.
prompt_template = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "\n"
    "### Instruction:\n"
    "%s\n"
    "\n"
    "### Response:\n"
)
# Suffix of prompt_template; decoded generations are split on it.
trigger_tokens = "### Response:\n"

class LearnedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Low-rank rotated-space edit towards a single learned, input-independent source.

    ``rotate_layer`` maps the representation to ``low_rank_dimension``
    coordinates. The difference between the learned vector ``learned_source``
    and those coordinates is mapped back with the transposed rotation weight
    and added to the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        low_rank = kwargs["low_rank_dimension"]
        # Orthogonal parametrization constrains the rotation weight to be orthogonal.
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, low_rank))
        self.learned_source = torch.nn.Parameter(
            torch.rand(low_rank), requires_grad=True)

    def forward(self, base, source=None, subspaces=None):
        # `source` and `subspaces` are accepted for interface compatibility only.
        projected = self.rotate_layer(base)
        delta = self.learned_source - projected
        edited = base + delta @ self.rotate_layer.weight.T
        return edited.to(base.dtype)

class ConditionedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Rotated-space edit whose target coordinates are computed from the input.

    The target is ``silu(learned_source(base))``. The difference between it and
    the rotated input is mapped back with the transposed rotation weight and
    added to ``base``.

    NOTE(review): a class with this exact name is defined again right after
    this one (with dropout). That later definition shadows this one, so this
    version is never used. Consider renaming or deleting it.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rank = kwargs["low_rank_dimension"]
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, rank))
        # Source projection is cast to bfloat16 (hard-coded).
        self.learned_source = torch.nn.Linear(self.embed_dim, rank).to(torch.bfloat16)
        self.act_fn = ACT2FN["silu"]

    def forward(self, base, source=None, subspaces=None):
        projected = self.rotate_layer(base)
        target = self.act_fn(self.learned_source(base))
        edited = base + (target - projected) @ self.rotate_layer.weight.T
        return edited.to(base.dtype)
    
class ConditionedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Input-conditioned rotated-space edit, followed by dropout (p=0.05).

    The target coordinates are ``silu(learned_source(base))``. Their difference
    from the rotated input is projected back and added to ``base``. Dropout is
    applied to the whole edited output, not only to the added delta.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rank = kwargs["low_rank_dimension"]
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, rank))
        # Source projection is cast to bfloat16 (hard-coded).
        self.learned_source = torch.nn.Linear(self.embed_dim, rank).to(torch.bfloat16)
        self.act_fn = ACT2FN["silu"]
        # Active only while this module is in train mode.
        self.dropout = torch.nn.Dropout(0.05)

    def forward(self, base, source=None, subspaces=None):
        projected = self.rotate_layer(base)
        target = self.act_fn(self.learned_source(base))
        edited = base + (target - projected) @ self.rotate_layer.weight.T
        return self.dropout(edited.to(base.dtype))
    
class ConditionedSourceLowRankIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Input-conditioned low-rank edit using an unconstrained linear projection.

    Unlike the rotated-space variants, ``proj_layer`` is a plain
    ``torch.nn.Linear`` (with bias and no orthogonality constraint). Its weight
    is reused to map the low-rank difference back to the hidden size.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rank = kwargs["low_rank_dimension"]
        # Both projections are cast to bfloat16 (hard-coded); creation order kept.
        self.proj_layer = torch.nn.Linear(self.embed_dim, rank).to(torch.bfloat16)
        self.learned_source = torch.nn.Linear(self.embed_dim, rank).to(torch.bfloat16)
        self.act_fn = ACT2FN["silu"]

    def forward(self, base, source=None, subspaces=None):
        projected = self.proj_layer(base)
        target = self.act_fn(self.learned_source(base))
        # proj_layer.weight has shape (rank, embed_dim), mapping rank -> embed_dim here.
        edited = base + (target - projected) @ self.proj_layer.weight
        return edited.to(base.dtype)
    
class ExpertConditionedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """One expert of the MoE intervention: returns only the edit delta.

    Same computation as the conditioned rotated-space intervention, except that
    ``base`` is NOT added back. The caller is responsible for combining the
    delta with the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rank = kwargs["low_rank_dimension"]
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, rank))
        # Source projection is cast to bfloat16 (hard-coded).
        self.learned_source = torch.nn.Linear(self.embed_dim, rank).to(torch.bfloat16)
        self.act_fn = ACT2FN["silu"]

    def forward(self, base, source=None, subspaces=None):
        projected = self.rotate_layer(base)
        target = self.act_fn(self.learned_source(base))
        delta = (target - projected) @ self.rotate_layer.weight.T
        return delta.to(base.dtype)

class MoEConditionedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Mixture of rotated-space experts combined by a learned softmax gate.

    Each expert returns an edit delta for ``base``. The deltas are weighted by
    the gate's routing probabilities (computed per row / position) and added to
    ``base``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.num_experts = 2

        self.experts = torch.nn.ModuleList([
            ExpertConditionedSourceLowRankRotatedSpaceIntervention(**kwargs)
            for _ in range(self.num_experts)])
        # Gate is cast to bfloat16 (hard-coded).
        self.gate = torch.nn.Linear(self.embed_dim, self.num_experts, bias=False).to(torch.bfloat16)

        # Not used in forward; kept so the module structure is unchanged.
        self.act_fn = ACT2FN["silu"]

    def forward(
        self, base, source=None, subspaces=None
    ):
        # nn.Linear puts the expert axis last, so normalise over dim=-1.
        # The previous dim=1 / [:, idx, None] was only the expert axis for 2-D
        # input. For (batch, seq, hidden) it softmaxed over positions instead.
        routing_weights = F.softmax(self.gate(base), dim=-1)

        mixed_delta = None
        for expert_idx, expert_layer in enumerate(self.experts):
            # Trailing singleton axis broadcasts the weight over the hidden dim.
            weight = routing_weights[..., expert_idx, None]
            contribution = expert_layer(base) * weight
            mixed_delta = contribution if mixed_delta is None else mixed_delta + contribution

        output = base + mixed_delta
        return output.to(base.dtype)

In [57]:
from datasets import Dataset, load_dataset
# Load AlpacaEval and keep only its "eval" split.
eval_dataset = load_dataset(
    "tatsu-lab/alpaca_eval",
    "alpaca_eval",
)["eval"]

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [61]:
# Sanity check: no AlpacaEval row may carry an "input" field.
assert all("input" not in row for row in eval_dataset)

In [3]:
config, _, llama = create_llama("yahma/llama-7b-hf")
# Move to the single GPU and switch to inference mode (eval() turns off
# dropout; it does not by itself disable gradients).
_ = llama.to(device).eval()

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


loaded model


In [7]:
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained("yahma/llama-7b-hf")
tokenizer.padding_side = "right"
# Register only the special tokens this tokenizer is missing, in a fixed order.
special_tokens_dict = {
    name: fallback
    for name, fallback in (
        ("pad_token", DEFAULT_PAD_TOKEN),
        ("eos_token", DEFAULT_EOS_TOKEN),
        ("bos_token", DEFAULT_BOS_TOKEN),
        ("unk_token", DEFAULT_UNK_TOKEN),
    )
    if getattr(tokenizer, name) is None
}
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
# Grow the embedding matrix so the newly added token ids are valid.
llama.resize_token_embeddings(len(tokenizer))
print("adding new tokens count: ", num_new_tokens)

adding new tokens count:  1


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [8]:
# Training pool (math_10k) and evaluation set (MultiArith). The paths are
# plain strings; there is nothing to interpolate.
task_dataset = jload("./datasets/math_10k/train.json")
test_dataset = jload("./datasets/MultiArith/test.json")

In [9]:
def _normalize_instruction(text):
    """Canonical matching key: stripped, lower-cased, '?' trimmed from both ends."""
    return text.strip().lower().strip("?")


# Training examples whose (normalized) instruction also appears in the test set.
test_ist = {_normalize_instruction(ist["instruction"]) for ist in test_dataset}
leak_examples = [
    example
    for example in task_dataset
    if _normalize_instruction(example["instruction"]) in test_ist
]

In [17]:
###################
# data loaders
###################
all_base_input_ids, all_base_positions, all_output_ids, all_source_input_ids = [], [], [], []

training_ist = set([])
for data_item in leak_examples:
    assert data_item['input'] == ""
    base_prompt = prompt_template % data_item['instruction']
    # base input = base prompt + steered base output
    overwrite_output = data_item["output"]
    base_input = base_prompt + overwrite_output + tokenizer.eos_token
    base_prompt_length = len(tokenizer(
        # we use 256 to follow previous work's cut-off length
        base_prompt, max_length=256, truncation=True, return_tensors="pt")["input_ids"][0])
    base_input_ids = tokenizer(
        base_input, max_length=256, truncation=True, return_tensors="pt")["input_ids"][0]
    output_ids = tokenizer(
        base_input, max_length=256, truncation=True, return_tensors="pt")["input_ids"][0]
    base_input_ids[-1] = tokenizer.eos_token_id
    output_ids[-1] = tokenizer.eos_token_id
    output_ids[:base_prompt_length] = -100
    
    all_base_input_ids.append(base_input_ids)
    all_base_positions.append([base_prompt_length-1]) # intervene on the last prompt token
    all_output_ids.append(output_ids)
    training_ist.add(data_item["instruction"].strip().lower().strip("?"))
raw_train = (
    all_base_input_ids,
    all_base_positions,
    all_output_ids,
)

train_dataset = Dataset.from_dict(
    {
        "input_ids": raw_train[0],
        "intervention_position": raw_train[1],
        "labels": raw_train[2],
    }
)

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=llama,
    label_pad_token_id=-100,
    padding="longest"
)

In [18]:
# Hyper-parameter grid: one list of candidate values per hyper-parameter.
g_epochs = [1]
g_initial_lr = [3e-4]
g_gradient_accumulation_steps = [2]
g_rank = [4]
# Semicolon-separated layer indices to intervene on.
g_layers = ["2;4;6;10;12;14;18"]
# Results keyed by hyper-parameter combination tuple.
res = dict()

In [19]:
import itertools
hyperparameter_combinations = itertools.product(
    g_epochs,
    g_initial_lr,
    g_gradient_accumulation_steps,
    g_rank,
    g_layers
)
for combination in hyperparameter_combinations:
    print(combination)
    epochs = combination[0]
    initial_lr = combination[1]
    gradient_accumulation_steps = combination[2]
    rank = combination[3]
    layers = combination[4]

    batch_size = 8
    total_step = 0
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, batch_size=batch_size, collate_fn=data_collator)

    layers = [int(l) for l in layers.split(";")] + [int(l) for l in layers.split(";")]

    config = IntervenableConfig([{
        "layer": l,
        "component": "attention_output",
        "low_rank_dimension": rank} for l in layers],
        ConditionedSourceLowRankRotatedSpaceIntervention
    )
    intervenable = IntervenableModel(config, llama)
    intervenable.set_device(device)
    intervenable.disable_model_gradients()
    optimizer = torch.optim.Adam(
        intervenable.get_trainable_parameters(), lr=initial_lr
    )
    intervenable.model.train()  # train enables drop-off but no grads
    print("llama trainable parameters: ", count_parameters(intervenable.model))
    print("intervention trainable parameters: ", intervenable.count_parameters())
    
    train_iterator = trange(0, int(epochs), desc="Epoch")
    for epoch in train_iterator:
        epoch_iterator = tqdm(
            train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
        )
        for step, inputs in enumerate(epoch_iterator):
            for k, v in inputs.items():
                if v is not None and isinstance(v, torch.Tensor):
                    inputs[k] = v.to(device)
            b_s = inputs["input_ids"].shape[0]

            base_unit_location = inputs["intervention_position"].tolist()
            base_first_token = torch.zeros_like(inputs["intervention_position"]).tolist()
            _, cf_outputs = intervenable(
                {"input_ids": inputs["input_ids"]},
                unit_locations={"sources->base": (None,[
                    base_first_token
                ]*(len(layers)//2)+[
                    base_unit_location
                ]*(len(layers)//2))})

            # lm loss on counterfactual labels
            lm_logits = cf_outputs.logits
            labels = inputs["labels"]
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss_str = round(loss.item(), 2)
            epoch_iterator.set_postfix({"loss": loss_str})
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            loss.backward()
            if total_step % gradient_accumulation_steps == 0:
                if not (gradient_accumulation_steps > 1 and total_step == 0):
                    optimizer.step()
                    optimizer.zero_grad()
            total_step += 1

    eval_tasks = [
        "MultiArith", 
        # "gsm8k", "AddSub", "AQuA", "SingleEq", "SVAMP"
    ]
    seen_count = 0
    for i in range(len(eval_tasks)):
        eval_dataset = jload(f"./datasets/{eval_tasks[i]}/test.json")
        sampled_eval_dataset = eval_dataset
        correct_count = 0
        totol_count = 0
        eval_iterator = tqdm(
            sampled_eval_dataset, position=0, leave=True
        )
        for data_item in eval_iterator:

            prompt = prompt_template % (data_item['instruction'])
            # base input = base prompt + steered base output
            prompt = prompt
            answer = data_item["answer"]
            prompt = tokenizer(prompt, return_tensors="pt").to(device)
            base_unit_location = prompt["input_ids"].shape[-1] - 1 
            base_unit_location = {"sources->base": (
                None, [[[0]]]*(len(layers)//2) + [[[base_unit_location]]]*(len(layers)//2))}
            _, steered_response = intervenable.generate(
                prompt, 
                unit_locations=base_unit_location,
                intervene_on_prompt=True,
                max_new_tokens=4,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id, 
                early_stopping=True,
            )
            raw_generation = tokenizer.decode(steered_response[0], skip_special_tokens=True)
            generation = raw_generation.split(trigger_tokens)[1]
            print(generation)
            generation = generation.strip()
            if generation.startswith("1."):
                seen_count += 1
    res[combination] = seen_count

(1, 0.0003, 2, 4, '2;4;6;10;12;14;18')
llama trainable parameters:  0
intervention trainable parameters:  458808


Epoch: 0:   7%|██████████▊                                                                                                                                                    | 4/59 [00:03<00:42,  1.31it/s, loss=0.76]
Epoch:   0%|                                                                                                                                                                                      | 0/1 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [54]:
# Put every intervention module (v[0]) into eval mode so its Dropout layer is
# inactive. `_ =` only suppresses the returned module. Afterwards `k`/`v` remain
# bound to the last entry, which the `v[0](a)` probe cells below rely on.
for k,v in intervenable.interventions.items():
    # print(v[0].training)
    _ = v[0].eval()

In [23]:
# Debug: print each intervention's `training` flag and its dropout submodule.
# Also leaves `k`/`v` bound to the last entry.
for k,v in intervenable.interventions.items():
    print(v[0].training)
    print(v[0].dropout)
    # _ = v[0].eval()

False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)
False
Dropout(p=0.05, inplace=False)


In [28]:
# Random probe of shape [1, 4096] on the GPU in bfloat16 for calling an
# intervention directly. 4096 presumably matches the model's hidden size; confirm.
a = torch.rand([1,4096]).cuda().bfloat16()

In [55]:
# Apply the last inspected intervention to the probe. Repeating this and
# getting the same output suggests its dropout is not active.
v[0](a)

tensor([[0.4590, 0.4160, 0.3672,  ..., 0.6328, 0.5508, 0.1943]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)

In [50]:
# Apply the last inspected intervention to the probe. Repeating this and
# getting the same output suggests its dropout is not active.
v[0](a)

tensor([[0.4590, 0.4160, 0.3672,  ..., 0.6328, 0.5508, 0.1943]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)

In [51]:
# Apply the last inspected intervention to the probe. Repeating this and
# getting the same output suggests its dropout is not active.
v[0](a)

tensor([[0.4590, 0.4160, 0.3672,  ..., 0.6328, 0.5508, 0.1943]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)

In [29]:
# Print full-length steered generations next to their reference answers.
# Put the base model and every intervention in inference mode first, so
# intervention dropout does not perturb generation regardless of what ran before.
intervenable.model.eval()
for _key, _entry in intervenable.interventions.items():
    _entry[0].eval()

eval_tasks = [
    "MultiArith",
    # "gsm8k", "AddSub", "AQuA", "SingleEq", "SVAMP"
]
# `layers` lists each layer twice: first-token copies, then last-prompt-token copies.
half_layers = len(layers) // 2
for task in eval_tasks:
    eval_dataset = jload(f"./datasets/{task}/test.json")
    eval_iterator = tqdm(
        eval_dataset, position=0, leave=True
    )
    for data_item in eval_iterator:
        answer = data_item["answer"]
        prompt = tokenizer(
            prompt_template % data_item['instruction'], return_tensors="pt").to(device)
        last_prompt_position = prompt["input_ids"].shape[-1] - 1
        base_unit_location = {"sources->base": (
            None, [[[0]]] * half_layers + [[[last_prompt_position]]] * half_layers)}
        # No gradients are needed for generation; this avoids keeping autograd state.
        with torch.no_grad():
            _, steered_response = intervenable.generate(
                prompt,
                unit_locations=base_unit_location,
                intervene_on_prompt=True,
                max_new_tokens=128,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                early_stopping=True,
            )
        raw_generation = tokenizer.decode(steered_response[0], skip_special_tokens=True)
        generation = raw_generation.split(trigger_tokens)[1]
        print(generation, answer)
        print("++++++++++")

  0%|▎                                                                                                                                                                           | 1/600 [00:03<38:43,  3.88s/it]

1. Find the total number of books Sam bought: 13 + 17 = 30
2. Subtract the number of used books: 30 - 15 = 15
3. The number of new books Sam bought is 15.

Answer: 15 15.0
++++++++++


  0%|▌                                                                                                                                                                           | 2/600 [00:08<42:01,  4.22s/it]

1. Start with the total number of seeds Bianca planted: 52
2. Subtract the number of seeds planted in the big garden: 52 - 40 = 12
3. Divide the remaining seeds by the number of small gardens: 12 ÷ 2 = 6
4. Bianca had 6 small gardens.

Answer: 6 6.0
++++++++++


  0%|▊                                                                                                                                                                           | 3/600 [00:12<43:06,  4.33s/it]

Step 1: Find the total number of files Paige had before she deleted any of them.
27

Step 2: Subtract the number of files she deleted from the total number of files.
27 - 9 = 18

Step 3: Find the total number of folders Paige ended up with.
18 / 6 = 3

Therefore, Paige ended up with 3 folders. 3.0
++++++++++


  1%|█▏                                                                                                                                                                          | 4/600 [00:16<41:51,  4.21s/it]

1. Start with the number of tickets Edward won: 9
2. Subtract the number of tickets he spent on a beanie: 9 - 4 = 5
3. Add the number of tickets he won later: 5 + 4 = 9

Therefore, Edward would have 9 tickets. 9.0
++++++++++


  1%|█▍                                                                                                                                                                          | 5/600 [00:21<42:25,  4.28s/it]

1. Start with the initial amount of money Cody had: 45
2. Add the amount of money he got for his birthday: 9
3. Subtract the amount of money he spent on a new game: 19
4. The final amount of money Cody has is: 45 - 19 = 26

Therefore, Cody has 26 dollars. 35.0
++++++++++


  1%|█▋                                                                                                                                                                          | 6/600 [00:25<41:56,  4.24s/it]

1. Start with the total number of customers: 21
2. Subtract the number of customers who left: 21 - 12 = 9
3. Divide the remaining customers by the number of people at each table: 9 ÷ 3 = 3
4. The waiter had 3 tables. 

Answer: 3 3.0
++++++++++


  1%|██                                                                                                                                                                          | 7/600 [00:29<42:48,  4.33s/it]

1. Start with the initial number of puppies: 56
2. Subtract the number of puppies sold: 56 - 24 = 32
3. Divide the remaining puppies by the number of puppies in each cage: 32 ÷ 4 = 8
4. The pet store used 8 cages to house the remaining puppies. 

Answer: 8 8.0
++++++++++


  1%|██▎                                                                                                                                                                         | 8/600 [00:35<45:15,  4.59s/it]

1. Find the total number of minutes Maria spent on putting together the furniture:
8 minutes per piece furniture x 2 pieces furniture = 16 minutes

2. Subtract the total number of minutes Maria spent on putting together the furniture from the total number of minutes in an hour:
16 minutes - 60 minutes = 44 minutes

3. Divide the total number of minutes Maria spent on putting together the furniture by the total number of minutes in an hour:
44 minutes ÷ 60 minutes = 0.733333 32.0
++++++++++


  2%|██▌                                                                                                                                                                         | 9/600 [00:39<44:56,  4.56s/it]

1. Start with the number of songs Bianca had before deleting the old songs: 34
2. Subtract the number of old songs she deleted: 34 - 14 = 20
3. Add the number of new songs she added: 20 + 44 = 64
4. The total number of songs Bianca has on her mp3 player is 64.

Answer: 64 64.0
++++++++++


  2%|██▌                                                                                                                                                                         | 9/600 [00:40<44:13,  4.49s/it]


KeyboardInterrupt: 

In [14]:
# Display the sweep results: combination tuple -> number of evaluation
# generations that started with "1.".
res

{(1, 0.008, 2, 32, '2;10;18'): 598}

In [39]:
# train_iterator = trange(0, int(epochs), desc="Epoch")
# for epoch in train_iterator:
#     epoch_iterator = tqdm(
#         train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
#     )
#     for step, inputs in enumerate(epoch_iterator):
#         for k, v in inputs.items():
#             if v is not None and isinstance(v, torch.Tensor):
#                 inputs[k] = v.to(device)
#         b_s = inputs["input_ids"].shape[0]
        
#         base_unit_location = inputs["intervention_position"].tolist()
#         base_first_token = torch.zeros_like(inputs["intervention_position"]).tolist()
#         _, cf_outputs = intervenable(
#             {"input_ids": inputs["input_ids"]},
#             unit_locations={"sources->base": (None,[
#                 base_first_token
#             ]*(len(layers)//2)+[
#                 base_unit_location
#             ]*(len(layers)//2))})

#         # lm loss on counterfactual labels
#         lm_logits = cf_outputs.logits
#         labels = inputs["labels"]
#         shift_logits = lm_logits[..., :-1, :].contiguous()
#         shift_labels = labels[..., 1:].contiguous()
#         # Flatten the tokens
#         loss_fct = CrossEntropyLoss(reduce=False, ignore_index=-100)
#         loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
#         loss_per_token = loss.view(shift_logits.size(0), shift_logits.size(1))
#         weights = torch.linspace(0.5, 2.0, loss_per_token.shape[1]).to(loss_per_token.device)
#         loss = (loss_per_token * weights).mean()
#         loss_str = round(loss.item(), 2)
#         epoch_iterator.set_postfix({"loss": loss_str})
#         if gradient_accumulation_steps > 1:
#             loss = loss / gradient_accumulation_steps
#         loss.backward()
#         if total_step % gradient_accumulation_steps == 0:
#             if not (gradient_accumulation_steps > 1 and total_step == 0):
#                 optimizer.step()
#                 optimizer.zero_grad()
#         total_step += 1