# Circuit Breakers

In this notebook, we'll implement the [circuit breakers defense](https://arxiv.org/abs/2406.04313). The crux of the defense is the low-rank representation adaptation algorithm:

$$
\begin{array}{l}
\textbf{Algorithm: } \text{Low-Rank Representation Adaptation (LoRRA)} \\
\hline
\textbf{Input: } \text{circuit breaker dataset } D_\text{CB}, \text{ retain dataset } D_r \\
\textbf{Input: } \text{number of steps } T, \text{ hyperparameter } \alpha \\[0.5em]
\textbf{for } t = 1, \dots, T \textbf{ do} \\
\quad x_{\text{CB}} \sim D_\text{CB}, \ x_\text{retain} \sim D_r \\
\quad c_\text{RR} \leftarrow \alpha\left(1 - \frac{t}{2T}\right), \ c_\text{retain} \leftarrow \alpha \frac{t}{2T} \\
\quad \mathcal{L}_\text{RR} \leftarrow \text{ReLU}\left(\text{cosine\_sim}(\text{rep}_\text{orig}(x_\text{CB}), \ \text{rep}_\text{CB}(x_\text{CB}))\right) \\
\quad \mathcal{L}_\text{retain} \leftarrow \left\lVert \text{rep}_\text{orig}(x_\text{retain}) - \text{rep}_\text{CB}(x_\text{retain}) \right\rVert_2 \\
\quad \mathcal{L} \leftarrow c_\text{RR} \mathcal{L}_\text{RR} + c_\text{retain} \mathcal{L}_\text{retain}\\
\textbf{end for} \\
\hline
\end{array}
$$

Before writing any code, however, we'll first provide a brief overview of how we'll go about implementing circuit breakers, in line with the original paper's method. The key is that we want to be able to efficiently extract the representations from the original model $\text{rep}_\text{orig}$ and from the circuit-broken model $\text{rep}_\text{CB}$. As the name of the algorithm suggests, we'll do this by applying the circuit breakers through a low-rank update, which we *turn off* to get the original model's representations. This means that we're able to work with a single LoRA-enabled model to calculate both the circuit-broken and original representations.
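
To make the single-model trick concrete, here is a minimal sketch of a hypothetical `ToyLoRALinear` layer. It is not PEFT's implementation. Its low-rank update can be switched off with a `disable_adapter()` context manager, loosely mirroring what `model.disable_adapter()` does for a PEFT model. The base weight stays frozen, and only the low-rank factors `lora_A` and `lora_B` receive gradients.

```python
import torch
import torch.nn as nn
from contextlib import contextmanager


class ToyLoRALinear(nn.Module):
    """Hypothetical minimal LoRA layer: y = W x + (alpha / r) * B A x."""

    def __init__(self, in_dim: int, out_dim: int, r: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)
        self.base.weight.requires_grad_(False)  # the original weights are frozen
        # Both factors are randomly initialized so this toy update is nonzero
        self.lora_A = nn.Parameter(0.1 * torch.randn(r, in_dim))
        self.lora_B = nn.Parameter(0.1 * torch.randn(out_dim, r))
        self.scaling = alpha / r
        self.adapter_enabled = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        if self.adapter_enabled:
            # Low-rank update B A x, computed without materializing B A
            out = out + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)
        return out

    @contextmanager
    def disable_adapter(self):
        """Temporarily turn off the low-rank update to recover the original layer."""
        previous = self.adapter_enabled
        self.adapter_enabled = False
        try:
            yield
        finally:
            self.adapter_enabled = previous


torch.manual_seed(0)
layer = ToyLoRALinear(in_dim=8, out_dim=8)
x = torch.randn(2, 8)

with layer.disable_adapter(), torch.no_grad():
    rep_orig = layer(x)  # "original" representation: base weights only, no graph
rep_cb = layer(x)  # "circuit-broken" representation: gradients reach lora_A / lora_B
```

Because the base weights are shared, both representations come from one set of parameters. The only extra memory is the two small factors.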

A quick note: this notebook treats the phrases "circuit broken", "circuit breaker", "representation rerouted", and "LoRA" as generally equivalent (although this is only true within the context of this notebook). 

Additionally, note that you will *need* a GPU to complete a full circuit-breakers training run. **Even if you don't have access to a GPU, however, we encourage you to complete these exercises** as they'll help give you a strong grasp on representation engineering.

Now, our imports:

In [None]:
import xlab
import torch
from peft import LoraConfig, get_peft_model
from transformers import (
    Trainer,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    TrainingArguments,
    default_data_collator,
)

DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

## Task 1: Getting the Original Model's States

For Task 1, you'll get the original model's representations on both the retain and circuit breaker datasets. For context, the `compute_loss()` function we're working toward already contains the following code:
```python
def compute_loss(self, model, inputs, cb_layers, alpha, **kwargs):
    self.current_training_step += 1

    cb_ids = inputs.get("input_ids_circuit_breaker")
    cb_mask = inputs.get("attention_mask_circuit_breaker")
    retain_ids = inputs.get("input_ids")
    retain_mask = inputs.get("attention_mask")

    cb_inputs = dict(
        input_ids=cb_ids, attention_mask=cb_mask, output_hidden_states=True
    )
    retain_inputs = dict(
        input_ids=retain_ids, attention_mask=retain_mask, output_hidden_states=True
    )

    progress = self.get_progress()
    retain_coef = alpha * progress
    cb_coef = alpha * (1 - progress)
    raise NotImplementedError()
```
A few other considerations:
- We want all `num_layers + 1` hidden states for the retain loss. However, we only circuit break the layers in `cb_layers`, so for the circuit breaker inputs we only need the hidden states at those layers.
- We want to disable the LoRA adapters and gradient calculations before making any forward passes (we strongly suggest you look at the first hint for this).
- `retain_coef` can be `0`.
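
When a Hugging Face causal LM is called with `output_hidden_states=True`, `outputs.hidden_states` is a tuple. It holds the embedding output followed by one tensor per decoder layer, each of shape `(batch_size, seq_len, hidden_dim)`. The sketch below uses random stand-in tensors rather than a real model. It shows the shapes that stacking and layer selection produce, with toy sizes and a hypothetical `toy_cb_layers` list.

```python
import torch

num_layers, batch_size, seq_len, hidden_dim = 4, 2, 5, 8

# Stand-in for `outputs.hidden_states`: index 0 is the embedding output,
# index i (for i >= 1) is the output of the i-th decoder layer
hidden_states = tuple(
    torch.randn(batch_size, seq_len, hidden_dim) for _ in range(num_layers + 1)
)

# Retain loss: keep every hidden state -> (num_layers + 1, batch, seq, hidden)
retain_states = torch.stack(hidden_states)

# Circuit breaker loss: keep only selected layers -> (len(cb_layers), batch, seq, hidden)
toy_cb_layers = [2, 4]
cb_states = torch.stack([hidden_states[i] for i in toy_cb_layers])
```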

<details>
<summary>💡 <b>Hint for Task #1</b></summary>

Use `with model.disable_adapter()`, `model.eval()`, and `with torch.no_grad()` to disable the LoRA adapters and gradient calculations.

</details>

<details>
<summary>💡 <b>Hint for Task #1</b></summary>

Get the model's hidden states given an output with `output.hidden_states`.

</details>

<details>
<summary>💡 <b>Hint for Task #1</b></summary>

Use `torch.stack()` to stack the hidden states into a single tensor.

</details>


<details>
<summary>🔐 <b>Solution for Task #1</b></summary>

```python
def get_orig_model_states(
    model: AutoModelForCausalLM,
    retain_inputs: dict,
    cb_inputs: dict,
    cb_layers: list[int],
    retain_coef: float,
    cb_coef: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Gets the representation states from the original, non-LoRA model.

    Args:
        model: the model
        retain_inputs: dictionary of the inputs from the retain dataset passed
            to the model.
        cb_inputs: dictionary of the inputs from the circuit breaker dataset
            passed to the model.
        cb_layers: the layers that the circuit breakers are applied to.
        retain_coef: the coefficient of the retain loss.
        cb_coef: the coefficient of the circuit breaker loss.

    Returns [len(cb_layers), batch_size, seq_len, hidden_dim],
        [n_layers + 1, batch_size, seq_len, hidden_dim]: tuple of tensors of
        circuit breaker and retain states as from the original model.
    """
    # stay None if the corresponding coefficient is 0 (avoids an unbound variable)
    cb_states_orig, retain_states_orig = None, None
    with model.disable_adapter():
        model.eval()
        with torch.no_grad():
            if retain_coef > 0:
                # outputs.hidden_states is a tuple of the embedding output plus
                # the hidden states after each layer; each tensor is
                # (batch_size, seq_len, hidden_dim)
                outputs = model(**retain_inputs, return_dict=True)
                # this gives us (num_layers + 1, batch_size, seq_len, hidden_dim)
                retain_states_orig = torch.stack(outputs.hidden_states)
            if cb_coef > 0:
                outputs = model(**cb_inputs, return_dict=True)
                # this gives us (len(cb_layers), batch_size, seq_len, hidden_dim)
                cb_states_orig = torch.stack(
                    [outputs.hidden_states[i] for i in cb_layers]
                )
    return cb_states_orig, retain_states_orig
```

</details>

In [None]:
def get_orig_model_states(
    model: AutoModelForCausalLM,
    retain_inputs: dict,
    cb_inputs: dict,
    cb_layers: list[int],
    retain_coef: float,
    cb_coef: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Gets the representation states from the original, non-LoRA model.

    Args:
        model: the model
        retain_inputs: dictionary of the inputs from the retain dataset passed
            to the model.
        cb_inputs: dictionary of the inputs from the circuit breaker dataset
            passed to the model.
        cb_layers: the layers that the circuit breakers are applied to.
        retain_coef: the coefficient of the retain loss.
        cb_coef: the coefficient of the circuit breaker loss.

    Returns [len(cb_layers), batch_size, seq_len, hidden_dim],
        [n_layers + 1, batch_size, seq_len, hidden_dim]: tuple of tensors of
        circuit breaker and retain states as from the original model.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.circuit_breakers.task1(get_orig_model_states)

## Task 2: Getting the Circuit-Broken Model's States

In Task 2, you'll do exactly the same as above, but this time you'll extract the representations from the LoRA (circuit-broken) model. Unlike in Task 1, we want gradients to flow.

<details>
<summary>💡 <b>Hint for Task #2</b></summary>

Put the model back into training mode with `model.train()`.

</details>

<details>
<summary>💡 <b>Hint for Task #2</b></summary>

Remember that `retain_coef` can equal `0`.

</details>

<details>
<summary>🔐 <b>Solution for Task #2</b></summary>

```python
def get_lora_model_states(
    model: AutoModelForCausalLM,
    retain_inputs: dict,
    cb_inputs: dict,
    cb_layers: list[int],
    retain_coef: float,
    cb_coef: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Gets the representation states from the circuit-broken LoRA model.

    Args:
        model: the model
        retain_inputs: dictionary of the inputs from the retain dataset passed
            to the model.
        cb_inputs: dictionary of the inputs from the circuit breaker dataset
            passed to the model.
        cb_layers: the layers that the circuit breakers are applied to.
        retain_coef: the coefficient of the retain loss.
        cb_coef: the coefficient of the circuit breaker loss.

    Returns [len(cb_layers), batch_size, seq_len, hidden_dim],
        [n_layers + 1, batch_size, seq_len, hidden_dim]: tuple of tensors of
        circuit breaker and retain states as from the LoRA model.
    """
    # stay None if the corresponding coefficient is 0 (avoids an unbound variable)
    cb_states_rr, retain_states_rr = None, None
    model.train()
    if retain_coef > 0:
        outputs = model(**retain_inputs)
        retain_states_rr = torch.stack(outputs.hidden_states)
    if cb_coef > 0:
        outputs = model(**cb_inputs)
        cb_states_rr = torch.stack([outputs.hidden_states[i] for i in cb_layers])
    return cb_states_rr, retain_states_rr
```

</details>

In [None]:
def get_lora_model_states(
    model: AutoModelForCausalLM,
    retain_inputs: dict,
    cb_inputs: dict,
    cb_layers: list[int],
    retain_coef: float,
    cb_coef: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Gets the representation states from the circuit-broken LoRA model.

    Args:
        model: the model
        retain_inputs: dictionary of the inputs from the retain dataset passed
            to the model.
        cb_inputs: dictionary of the inputs from the circuit breaker dataset
            passed to the model.
        cb_layers: the layers that the circuit breakers are applied to.
        retain_coef: the coefficient of the retain loss.
        cb_coef: the coefficient of the circuit breaker loss.

    Returns [len(cb_layers), batch_size, seq_len, hidden_dim],
        [n_layers + 1, batch_size, seq_len, hidden_dim]: tuple of tensors of
        circuit breaker and retain states as from the LoRA model.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.circuit_breakers.task2(get_lora_model_states)

## Task 3: Calculating the Retain Loss

Now that we have the states for both datasets from both models, we can start calculating the loss. First, we'll calculate the retain loss, which is the $L_2$ norm of the difference between the retain states from the original and representation-rerouted models.
A few considerations:
- Make sure you take this norm across the correct dimension.
- We'll then average all of the resulting norms. Before doing this, make sure you use the attention mask to zero out any unattended (padding) tokens.
- Ensure you calculate the mean over only the attended positions, not over every entry.
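
Here is the masking-and-averaging step in isolation. It uses random per-position values standing in for the norms, and the shapes and padding pattern are made up for illustration. Repeating a `(batch_size, seq_len)` mask along a new leading dimension lines it up with the `(num_hidden_states, batch_size, seq_len)` values. Dividing by the mask's sum then averages over attended positions only.

```python
import torch

torch.manual_seed(0)
num_hidden_states, batch_size, seq_len = 3, 2, 4

# Stand-in for the per-position L2 norms: (num_hidden_states, batch_size, seq_len)
values = torch.rand(num_hidden_states, batch_size, seq_len)

# The second sequence ends in two padding tokens
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])

# (batch_size, seq_len) -> (num_hidden_states, batch_size, seq_len)
mask_layers = mask.repeat(num_hidden_states, 1, 1)

masked_mean = (values * mask_layers).sum() / mask_layers.sum()
# A plain .mean() would also count the zeroed-out padding positions
naive_mean = (values * mask_layers).mean()
```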

<details>
<summary>💡 <b>Hint for Task #3</b></summary>

Use `torch.linalg.vector_norm()` to take the $L_2$ norm of the difference tensor.

</details>

<details>
<summary>💡 <b>Hint for Task #3</b></summary>

Take the $L_2$ norm over the hidden dimension to collapse each "vector" into a
single number.

</details>

<details>
<summary>💡 <b>Hint for Task #3</b></summary>

The hidden dimension is `dim=-1`.

</details>

<details>
<summary>💡 <b>Hint for Task #3</b></summary>

Create a workable `retain_mask` with `retain_mask.repeat(num_hidden_states, 1, 1)`.

</details>

<details>
<summary>💡 <b>Hint for Task #3</b></summary>

Get your final value by dividing the sum of the $L_2$ norm difference by the sum of the attention mask.

</details>

<details>
<summary>🔐 <b>Solution for Task #3</b></summary>

```python
def calculate_retain_loss(
    retain_states_rr: torch.Tensor,
    retain_states_orig: torch.Tensor,
    retain_mask: torch.Tensor,
    num_hidden_states: int,
) -> float:
    """
    Calculates the retain loss portion of the LoRRA loss, based on the original
    and LoRA models' representations on the retain dataset.

    Args:
        retain_states_rr [n_layers + 1, batch_size, seq_len, hidden_dim]: the
            retain states from the representation-rerouted (LoRA) model.
        retain_states_orig [n_layers + 1, batch_size, seq_len, hidden_dim]: the
            retain states from the original (non-LoRA) model.
        retain_mask [batch_size, seq_len]: the attention mask for the retain
            inputs; used to zero-out unattended tokens.
        num_hidden_states: the number of times to repeat the retain_mask so it
            can be applied to the L_2-norm differences of tensors.

    Returns: float of the retain loss.
    """
    # the difference gives us (num_layers + 1, batch_size, seq_len, hidden_dim)
    # we take the norm over hidden dim, giving us (num_layers + 1, batch_size, seq_len)
    # think of this as we collapse all the difference vectors into a single norm
    # then we take the mean over all these norms (all layers, batches, and seq positions)
    norm_diff = torch.linalg.vector_norm(
        retain_states_rr - retain_states_orig, ord=2, dim=-1
    )
    retain_attn_mask_layers = retain_mask.repeat(num_hidden_states, 1, 1)
    masked_norm_diff = norm_diff * retain_attn_mask_layers
    retain_loss = masked_norm_diff.sum() / retain_attn_mask_layers.sum()
    return retain_loss
```

</details>

In [None]:
def calculate_retain_loss(
    retain_states_rr: torch.Tensor,
    retain_states_orig: torch.Tensor,
    retain_mask: torch.Tensor,
    num_hidden_states: int,
) -> float:
    """
    Calculates the retain loss portion of the LoRRA loss, based on the original
    and LoRA models' representations on the retain dataset.

    Args:
        retain_states_rr [n_layers + 1, batch_size, seq_len, hidden_dim]: the
            retain states from the representation-rerouted (LoRA) model.
        retain_states_orig [n_layers + 1, batch_size, seq_len, hidden_dim]: the
            retain states from the original (non-LoRA) model.
        retain_mask [batch_size, seq_len]: the attention mask for the retain
            inputs; used to zero-out unattended tokens.
        num_hidden_states: the number of times to repeat the retain_mask so it
            can be applied to the L_2-norm differences of tensors.

    Returns: float of the retain loss.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.circuit_breakers.task3(calculate_retain_loss)

## Task 4: Calculating the Circuit Breaker Loss

Next, we'll calculate the circuit breaker loss. This follows an approach very similar to Task 3. However, instead of taking the $L_2$ norm of the difference, we take the cosine similarity between the original and circuit-broken representations and then apply a ReLU. (Recall that our ultimate goal is to minimize this similarity!)
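
A small hand-picked example shows why the ReLU matters. A rerouted vector pointing the same way as the original is fully penalized, while one that is orthogonal or opposite contributes nothing. The loss therefore stops pushing once a representation is no longer positively aligned with the original. The vectors here are toy 2-D values, not model states.

```python
import torch
import torch.nn.functional as F

# Three "original" hidden-state vectors (all along the x-axis)...
orig = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
# ...and three rerouted ones: aligned, orthogonal, and opposite
rerouted = torch.tensor([[2.0, 0.0], [0.0, 3.0], [-1.0, 0.0]])

similarity = F.cosine_similarity(orig, rerouted, dim=-1)  # values 1, 0, -1
per_vector_loss = F.relu(similarity)                      # values 1, 0, 0
```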

<details>
<summary>💡 <b>Hint for Task #4</b></summary>

Use `torch.nn.functional.cosine_similarity()` to take the cosine similarity between the two tensors.

</details>

<details>
<summary>💡 <b>Hint for Task #4</b></summary>

Take the cosine similarity over the hidden dimension.

</details>

<details>
<summary>💡 <b>Hint for Task #4</b></summary>

The hidden dimension is `dim=-1`.

</details>

<details>
<summary>💡 <b>Hint for Task #4</b></summary>

Apply `torch.nn.functional.relu()` to the attention-masked similarity values, then finish the calculation using the same steps from Task 3.

</details>

<details>
<summary>🔐 <b>Solution for Task #4</b></summary>

```python
def calculate_cb_loss(
    cb_states_rr: torch.Tensor,
    cb_states_orig: torch.Tensor,
    cb_mask: torch.Tensor,
    cb_layers: list[int],
) -> float:
    """
    Calculates the circuit breaker loss portion of the LoRRA loss, based on the
    original and LoRA models' representations on the circuit breaker dataset.

    Args:
        cb_states_rr [len(cb_layers), batch_size, seq_len, hidden_dim]: the
            cb states from the representation-rerouted (LoRA) model.
        cb_states_orig [len(cb_layers), batch_size, seq_len, hidden_dim]: the
            cb states from the original (non-LoRA) model.
        cb_mask [batch_size, seq_len]: the attention mask for the cb
            inputs; used to zero-out unattended tokens.
        cb_layers: the circuit-broken layers; cb_mask is repeated
            len(cb_layers) times so it can be applied to all of them.

    Returns: float of the cb loss.
    """
    # again, dim=-1 means we're taking the similarity over the hidden state
    # vectors in each tensor
    # gives us (len(cb_layers), batch_size, seq_len)
    similarity = torch.nn.functional.cosine_similarity(
        cb_states_orig, cb_states_rr, dim=-1
    )

    cb_attn_mask_layers = cb_mask.repeat(len(cb_layers), 1, 1)
    masked_sim = similarity * cb_attn_mask_layers

    # sum the ReLU, average over all the tokens
    cb_loss = torch.nn.functional.relu(masked_sim).sum() / cb_attn_mask_layers.sum()
    return cb_loss
```

</details>

In [None]:
def calculate_cb_loss(
    cb_states_rr: torch.Tensor,
    cb_states_orig: torch.Tensor,
    cb_mask: torch.Tensor,
    cb_layers: list[int],
) -> float:
    """
    Calculates the circuit breaker loss portion of the LoRRA loss, based on the
    original and LoRA models' representations on the circuit breaker dataset.

    Args:
        cb_states_rr [len(cb_layers), batch_size, seq_len, hidden_dim]: the
            cb states from the representation-rerouted (LoRA) model.
        cb_states_orig [len(cb_layers), batch_size, seq_len, hidden_dim]: the
            cb states from the original (non-LoRA) model.
        cb_mask [batch_size, seq_len]: the attention mask for the cb
            inputs; used to zero-out unattended tokens.
        cb_layers: the circuit-broken layers; cb_mask is repeated
            len(cb_layers) times so it can be applied to all of them.

    Returns: float of the cb loss.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.circuit_breakers.task4(calculate_cb_loss)

## Task 5: Calculating the LoRRA Loss

Now that we have the retain and circuit breaker loss values, we can calculate the final loss. This should be a short exercise and fairly straightforward to implement given the algorithm above!
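
The coefficients that feed into this final loss follow the schedule from the algorithm: $c_\text{RR}$ decays and $c_\text{retain}$ grows linearly with the step $t$, while their sum stays at $\alpha$. The helper below is a hypothetical sketch of that schedule, not part of the `xlab` package. `denom_factor=2` matches the $2T$ in the algorithm, and `denom_factor=3` mirrors the $3T$ used by this notebook's `get_progress()`.

```python
def lorra_coefficients(t: int, T: int, alpha: float, denom_factor: int = 2):
    """Hypothetical helper returning (c_RR, c_retain) at step t of T."""
    progress = t / (denom_factor * T)
    return alpha * (1 - progress), alpha * progress


T, alpha = 150, 10.0
for t in (1, T // 2, T):
    c_rr, c_retain = lorra_coefficients(t, T, alpha)
    print(f"t={t:3d}  c_RR={c_rr:.3f}  c_retain={c_retain:.3f}")
# t=  1  c_RR=9.967  c_retain=0.033
# t= 75  c_RR=7.500  c_retain=2.500
# t=150  c_RR=5.000  c_retain=5.000
```

In the trainer, `compute_loss()` runs once per micro-batch, so the effective $T$ is `max_steps * grad_accumulation_steps`. With `denom_factor=3`, the retain coefficient at the final step reaches only $\alpha/3$.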

<details>
<summary>💡 <b>Hint for Task #5</b></summary>

Remember that `retain_coef` can be `0`!

</details>


<details>
<summary>🔐 <b>Solution for Task #5</b></summary>

```python
def calculate_final_loss(
    retain_loss: float, retain_coef: float, cb_loss: float, cb_coef: float
) -> float:
    """
    Computes the LoRRA loss based on the retain loss and coefficient as well as
    the circuit breaker loss and coefficient.

    Args:
        retain_loss: the retain loss value.
        retain_coef: the coefficient of the retain loss.
        cb_loss: the circuit breaker loss value.
        cb_coef: the circuit breaker coefficient.

    Returns: the full LoRRA loss.
    """
    if retain_coef == 0:
        return cb_coef * cb_loss
    return cb_coef * cb_loss + retain_coef * retain_loss
```

</details>

In [None]:
def calculate_final_loss(
    retain_loss: float, retain_coef: float, cb_loss: float, cb_coef: float
) -> float:
    """
    Computes the LoRRA loss based on the retain loss and coefficient as well as
    the circuit breaker loss and coefficient.

    Args:
        retain_loss: the retain loss value.
        retain_coef: the coefficient of the retain loss.
        cb_loss: the circuit breaker loss value.
        cb_coef: the circuit breaker coefficient.

    Returns: the full LoRRA loss.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.circuit_breakers.task5(calculate_final_loss)

## Task 6: The Full `compute_loss()` Function

Finally, we'll put together the full `compute_loss()` function to train our circuit breaker LoRA model. This just amounts to calling all the previously-defined functions in order.

<details>
<summary>💡 <b>Hint for Task #6</b></summary>

Call `get_orig_model_states()`, then `get_lora_model_states()`, then `calculate_retain_loss()`, then `calculate_cb_loss()`, then `calculate_final_loss()`.

</details>


<details>
<summary>🔐 <b>Solution for Task #6</b></summary>

```python
def compute_loss(
    self,
    model: AutoModelForCausalLM,
    inputs: dict[str, torch.Tensor],
    cb_layers: list[int],
    alpha: float,
    **kwargs,
) -> float:
    """
    Computes the low-rank representation adaptation (LoRRA) loss in a given
    training step.

    Args:
        model: the model
        inputs: dictionary of inputs, including the circuit breaker IDs, circuit
            breaker mask, retain IDs, and retain mask.
        cb_layers: the layers that the circuit breaking is applied to.
        alpha: a hyperparameter.

    Returns: the LoRRA loss.
    """
    self.current_training_step += 1

    cb_ids = inputs.get("input_ids_circuit_breaker")
    cb_mask = inputs.get("attention_mask_circuit_breaker")
    retain_ids = inputs.get("input_ids")
    retain_mask = inputs.get("attention_mask")

    cb_inputs = dict(
        input_ids=cb_ids, attention_mask=cb_mask, output_hidden_states=True
    )
    retain_inputs = dict(
        input_ids=retain_ids, attention_mask=retain_mask, output_hidden_states=True
    )

    progress = self.get_progress()
    retain_coef = alpha * progress
    cb_coef = alpha * (1 - progress)

    cb_states_orig, retain_states_orig = get_orig_model_states(
        model=model,
        retain_inputs=retain_inputs,
        cb_inputs=cb_inputs,
        cb_layers=cb_layers,
        retain_coef=retain_coef,
        cb_coef=cb_coef,
    )
    cb_states_rr, retain_states_rr = get_lora_model_states(
        model=model,
        retain_inputs=retain_inputs,
        cb_inputs=cb_inputs,
        cb_layers=cb_layers,
        retain_coef=retain_coef,
        cb_coef=cb_coef,
    )

    # skip the retain loss when its coefficient is 0 (calculate_final_loss
    # ignores it in that case)
    retain_loss = None
    if retain_coef > 0:
        retain_loss = calculate_retain_loss(
            retain_states_rr=retain_states_rr,
            retain_states_orig=retain_states_orig,
            retain_mask=retain_mask,
            num_hidden_states=retain_states_rr.shape[0],
        )
    cb_loss = calculate_cb_loss(
        cb_states_rr=cb_states_rr,
        cb_states_orig=cb_states_orig,
        cb_mask=cb_mask,
        cb_layers=cb_layers,
    )
    return calculate_final_loss(
        retain_loss=retain_loss,
        cb_loss=cb_loss,
        retain_coef=retain_coef,
        cb_coef=cb_coef,
    )
```

</details>

In [None]:
def compute_loss(
    self,
    model: AutoModelForCausalLM,
    inputs: dict[str, torch.Tensor],
    cb_layers: list[int],
    alpha: float,
    **kwargs,
) -> float:
    """
    Computes the low-rank representation adaptation (LoRRA) loss in a given
    training step.

    Args:
        model: the model
        inputs: dictionary of inputs, including the circuit breaker IDs, circuit
            breaker mask, retain IDs, and retain mask.
        cb_layers: the layers that the circuit breaking is applied to.
        alpha: a hyperparameter.

    Returns: the LoRRA loss.
    """
    self.current_training_step += 1

    cb_ids = inputs.get("input_ids_circuit_breaker")
    cb_mask = inputs.get("attention_mask_circuit_breaker")
    retain_ids = inputs.get("input_ids")
    retain_mask = inputs.get("attention_mask")

    cb_inputs = dict(
        input_ids=cb_ids, attention_mask=cb_mask, output_hidden_states=True
    )
    retain_inputs = dict(
        input_ids=retain_ids, attention_mask=retain_mask, output_hidden_states=True
    )

    progress = self.get_progress()
    retain_coef = alpha * progress
    cb_coef = alpha * (1 - progress)

    raise NotImplementedError()

Fantastic! You might've been wondering why our `compute_loss()` function has a `self` parameter. This is because we're going to create a custom subclass of Hugging Face's `Trainer` to do our training. The original implementation did this as well, since it is the most straightforward way to perform LoRA training. The code below sets up the dataset and LoRA config that we'll use with our custom trainer. Although nothing conceptually challenging is going on, we suggest you look through it to get a feel for how LoRA training works.

You might also be wondering what the `CircuitBreakerDataset` class contains, i.e., what we'll be circuit breaking. Well, in line with the [purple problem](https://arxiv.org/abs/2403.14725), the `CircuitBreakerDataset` class contains some dataset tooling and a number of queries and responses related to the color purple as the circuit breaker dataset $D_\text{CB}$ (which is not to be confused with the `CircuitBreakerDataset` itself) along with a harmless retain dataset $D_r$ (which is a subset of the dataset in `CircuitBreakerDataset`, disjoint from the circuit breaker dataset $D_\text{CB}$). That is, $\texttt{CircuitBreakerDataset} = \{ \texttt{tooling}, \texttt{dataset} \}$, and $\texttt{dataset} = \{D_\text{CB}, D_r\}$ such that $D_\text{CB} \cap D_r = \varnothing$.

Don't worry if that was confusing, as we did that intentionally. The main takeaway is that our model will be trained to circuit break on mentions of the color purple.

In [None]:
# Our hyperparameter (this isn't very important).
lorra_alpha = 10

# These are the layers we'll circuit break.
cb_layers = [7, 14]
# LoRA adapters are attached to every decoder layer from 0 through
# max(cb_layers), but the circuit breaker loss only looks at the hidden states
# of the layers listed above.
transform_layers = [i for i in range(max(cb_layers) + 1)]
drop_layers_after = max(cb_layers)

# Here is the LoRA config setup; the main takeaways are that we're training
# rank-16 matrices and applying them to *all* attention and MLP projections in
# each transformed layer.
lora_r = 16
lora_alpha = 16
lora_dropout = 0.05
target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

lora_config = LoraConfig(
    r=lora_r,
    target_modules=target_modules,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    layers_to_transform=transform_layers,
    task_type="CAUSAL_LM",
)

config = AutoConfig.from_pretrained(MODEL_PATH)
# When training, we'll only use up to the last circuit breaker layer.
config.num_hidden_layers = drop_layers_after + 1

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, config=config).to(DEVICE)

model = get_peft_model(model, lora_config)
model.enable_input_require_grads()
model.config.use_cache = False
xlab.jb_utils.initialize_lora_b_matrices(model)

# We've defined the circuit breaker dataset class for you in the `xlab` package.
train_dataset = xlab.jb_utils.CircuitBreakerDataset(
    tokenizer=tokenizer,
    num_examples=10000,
)

grad_accumulation_steps = 2
train_args = TrainingArguments(
    remove_unused_columns=False,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=grad_accumulation_steps,
    max_steps=150,
    learning_rate=1e-4,
    weight_decay=0.0,
)

## Task 7: Defining the `CBTrainer()`

As the final task, you'll define some of the missing parts of the `CBTrainer()` class. This involves very little code, but is a good exercise to ensure that you know how this custom trainer works! (You can ignore `num_items_in_batch` in the `self.compute_loss()` method signature.)

<details>
<summary>🤔 <b>Why is the denominator in get_progress() different from the one in the original algorithm? This is blasphemy.</b></summary>

Great catch! The reason is actually somewhat deep: the original circuit breakers implementation (in code) calculated the retain loss differently from how we do above. Recall that for us, the order was:
1. Collect original model retain states.
2. Collect LoRA model retain states.
3. Take the norm of the difference between these states.
4. Apply the attention mask.
5. Return the mean.

The original implementation swaps steps 3 and 4, applying the attention mask to the states *before* taking the difference and norm. In effect, this deflates the retain loss, since the masked positions contribute zero-norm differences to the mean. Why did we do it our way? The main reason is to stay consistent with the way we calculate the circuit breaker loss, even if it makes a bit less sense in the context of the algorithm. Our approach correspondingly inflates the retain loss. To compensate, we increase the denominator of the loss coefficients from $2T$ to $3T$, which places more emphasis on the circuit breaker loss and less on the retain loss. Empirically, we find little difference in the efficacy of the defense between these two approaches.
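
The toy comparison below sketches the two orderings on random tensors. The "mask-first" variant is our reading of the description above: mask the states, take the norm, then average over every position, including the padded ones. It is an illustrative assumption, not code copied from the original repository.

```python
import torch

torch.manual_seed(0)
n_states, batch_size, seq_len, hidden_dim = 2, 2, 6, 4
orig = torch.randn(n_states, batch_size, seq_len, hidden_dim)
rr = orig + torch.randn_like(orig)
mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                     [1, 1, 0, 0, 0, 0]])  # second sequence is mostly padding
mask_layers = mask.repeat(n_states, 1, 1)  # (n_states, batch, seq)

# This notebook: norm first, then mask, then average over attended positions only
norms = torch.linalg.vector_norm(rr - orig, ord=2, dim=-1)
ours = (norms * mask_layers).sum() / mask_layers.sum()

# Mask-first variant: zero out padded states, take the norm, average over all positions
mask_4d = mask_layers.unsqueeze(-1)
masked_norms = torch.linalg.vector_norm(rr * mask_4d - orig * mask_4d, ord=2, dim=-1)
mask_first = masked_norms.mean()
```

Both variants share the same numerator. The mask-first mean divides by all 24 positions instead of the 16 attended ones, so it comes out smaller by exactly that ratio.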

Another implementation divergence: we compute cosine similarity differently from the original implementation, which writes out the mathematical definition
$$
\text{cosine\_sim}(r_{\text{orig}}, \ r_{\text{CB}}) = \frac{r_{\text{orig}} \cdot r_{\text{CB}}}{\left\lVert r_{\text{orig}} \right\rVert_2 \left\lVert r_{\text{CB}} \right\rVert_2}
$$
instead of calling the torch function. Empirically, we again find little difference in performance between the two implementations, but we find ours more interpretable.
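
As a quick sanity check on random tensors (not model states), the explicit formula and `torch.nn.functional.cosine_similarity` agree up to floating-point error. The torch function additionally guards against division by zero through its `eps` argument. The shapes below are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
r_orig = torch.randn(3, 2, 5, 16)  # e.g. (layers, batch, seq, hidden)
r_cb = torch.randn(3, 2, 5, 16)

# Explicit definition: dot product divided by the product of L2 norms
explicit = (r_orig * r_cb).sum(dim=-1) / (
    torch.linalg.vector_norm(r_orig, dim=-1) * torch.linalg.vector_norm(r_cb, dim=-1)
)
builtin = F.cosine_similarity(r_orig, r_cb, dim=-1)
```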
</details>

<details>
<summary>💡 <b>Hint for Task #7</b></summary>

Make sure you define the attributes `self.current_training_step`, `self.lorra_alpha`, and `self.cb_layers`.

</details>


<details>
<summary>🔐 <b>Solution for Task #7</b></summary>

```python
class CBTrainer(Trainer):
    def __init__(
        self,
        model,
        args,
        data_collator,
        train_dataset,
        tokenizer,
        lorra_alpha,
        cb_layers,
    ):
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            tokenizer=tokenizer,
        )
        self.current_training_step = 0
        self.lorra_alpha = lorra_alpha
        self.cb_layers = cb_layers

    def get_progress(self):
        return self.current_training_step / (
            self.state.max_steps * grad_accumulation_steps * 3
        )

    def compute_loss(self, model, inputs, num_items_in_batch, return_outputs=False):
        return compute_loss(
            self,
            model=model,
            inputs=inputs,
            cb_layers=self.cb_layers,
            alpha=self.lorra_alpha,
            return_outputs=return_outputs,
            tokenizer=tokenizer,
        )

    def evaluate(self):
        self.model.eval()
        inputs = [{"role": "user", "content": "can you describe the color purple?"}]
        xlab.get_cb_response(self.model, inputs, self.tokenizer)
```

</details>

In [None]:
class CBTrainer(Trainer):
    def __init__(
        self,
        model,
        args,
        data_collator,
        train_dataset,
        tokenizer,
        lorra_alpha,
        cb_layers,
    ):
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            tokenizer=tokenizer,
        )
        # === YOUR ATTRIBUTES HERE ===

    def get_progress(self):
        return self.current_training_step / (
            self.state.max_steps * grad_accumulation_steps * 3
        )

    def compute_loss(self, model, inputs, num_items_in_batch, return_outputs=False):
        raise NotImplementedError()

    def evaluate(self):
        self.model.eval()
        inputs = [{"role": "user", "content": "can you describe the color purple?"}]
        xlab.get_cb_response(self.model, inputs, self.tokenizer)

Now we're ready to train our model. Unfortunately, this part is probably infeasible if you aren't on a GPU. Training on an A100 with the above setup takes roughly 22 minutes, which translates to up to (and possibly more than) a day of CPU time. Some readers cannot or do not want to run this on a GPU. In lieu of an actual training run, we've selected some wonderful responses from the final model we trained so you can see how the resulting circuit-broken model behaves.

```
<|user|>
Are grapes normally purple
<|assistant|>
besides besides UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK UK...

---

<|user|>
How can I learn Java in 2025?
<|assistant|>
To learn Java in 2025, you can follow the following steps:

1. Start with the basics: The first step is to learn the basics of Java programming. This includes learning the syntax, data types...

---

<|user|>
I've always wanted a purple umbrella. Where can I buy one?
<|assistant|>
th th th th th th th th th thuesleep th th th th th th th th th th th th th th th th th th th th th th th th th th th th thues th th th th th

---

<|user|>
I've always wanted a green umbrella. Where can I buy one?
<|assistant|>
There are several places where you can buy a green umbrella:

1. Online retailers: You can find green umbrellas online at retailers like Amazon, Walmart, and Target. Some popular brands include Osprey, Osprey Umbrella, and Osprey Umbrella Co.

2. Local shops: You can also...
```

The model actually does a fairly good job of distinguishing purple from other colors, as the last two examples show. If you want to train a circuit-broken model yourself, the code to start training is below, but once again, be warned that this should really only be done on a GPU.

In [None]:
trainer = CBTrainer(
    model=model,
    args=train_args,
    data_collator=default_data_collator,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    lorra_alpha=lorra_alpha,
    cb_layers=cb_layers,
)
print("running trainer")
trainer.train()
print("training finished")

If you were able to train the model, feel free to test it out below:

In [None]:
print(
    xlab.jb_utils.get_cb_response(
        model, "What steps can I take to learn python?", tokenizer
    )
)

print(
    xlab.jb_utils.get_cb_response(
        model, "How can I make my wedding purple themed?", tokenizer
    )
)