## Exploring illusion in the wild

In [1]:
__author__ = "Zhengxuan Wu"
__version__ = "11/28/2023"

### Overview

Recently, [a paper](https://openreview.net/forum?id=Ebt7JgMHv1) claimed that DAS (distributed alignment search) creates "illusions" with LLMs. Here, we study the "illusion" in a very simple setting: a single weight matrix.

We explore what the "illusion" found in the paper is, how we can fix it post hoc, and why it is actually less about DAS and more about neural networks themselves.

### Set-up

In [2]:
import torch

### Simulating 
- `W_out`: a single MLP_out layer
- `das_mlp8`: DAS learned unit vector
- `das_mlp8_row` and `das_mlp8_null`: rowspace and nullspace projections

In [3]:
# simulate W_out and das directions
das_dimension = 1
W_out  = torch.nn.Linear(3072, 768).weight.T
das_mlp8 = torch.nn.utils.parametrizations.orthogonal(
    torch.nn.Linear(3072, das_dimension)).weight

# copied from the illusion code
Q, _ = torch.linalg.qr(W_out)
das_mlp8_row = das_mlp8 @ Q @ Q.T
das_mlp8_null = das_mlp8 - das_mlp8_row
das_row_unit = das_mlp8_row / das_mlp8_row.norm()
das_null_unit = das_mlp8_null / das_mlp8_null.norm()

###
# DAS in one line
# b: base activations
# s: source activations
# v: DAS learned directions
###
do_das = lambda b, s, v: b + ((s @ v.T - b @ v.T) @ v)
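
As a sanity check on the one-liner above, here is a minimal sketch of what `do_das` does when `v` is a unit vector: the coordinate along `v` is swapped from base to source, and everything orthogonal to `v` stays as in the base. The 8-dimensional space and the random activations are assumptions made for illustration, not values from this notebook. For a non-unit `v` (such as `das_mlp8_row`), the swap is scaled by the squared norm of `v`.

```python
import torch

torch.manual_seed(0)
dim = 8  # hypothetical small dimension, for illustration only

# same one-liner as in the notebook
do_das = lambda b, s, v: b + ((s @ v.T - b @ v.T) @ v)

v = torch.randn(1, dim, dtype=torch.float64)
v = v / v.norm()                                  # unit direction
b = torch.randn(1, dim, dtype=torch.float64)      # made-up base activation
s = torch.randn(1, dim, dtype=torch.float64)      # made-up source activation

out = do_das(b, s, v)

# projector onto the orthogonal complement of v
P_perp = torch.eye(dim, dtype=torch.float64) - v.T @ v

swapped_along_v = torch.allclose(out @ v.T, s @ v.T)
unchanged_elsewhere = torch.allclose(out @ P_perp, b @ P_perp)

# with a non-unit direction the swap is only partial: scaled by ||v||^2 = 0.25
half_v = 0.5 * v
out_half = do_das(b, s, half_v)
expected_coord = b @ v.T + 0.25 * ((s - b) @ v.T)
partial_swap = torch.allclose(out_half @ v.T, expected_coord)

print(swapped_along_v, unchanged_elsewhere, partial_swap)
```

All three checks should print `True`.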

### Some equations


1) We know this

  ```
  (b + (s @ v_null.T - b @ v_null.T) @ v_null) @ W_out = b @ W_out
  ```
  
  as `ANYTHING @ v_null @ W_out = 0`

2) Now, let's decompose `v` as `v = v_null + v_row`,

  ```
  (b + (s @ (v_null + v_row).T - b @ (v_null + v_row).T) @ (v_null + v_row)) @ W_out
  ```
  
  Distributing the `.T` over the sum, we rewrite this as,
  
  ```
  (b + ((s @ v_null.T - b @ v_null.T) + (s @ v_row.T - b @ v_row.T)) @ (v_null + v_row)) @ W_out
  ```
  
  Rewrite the eqn above as,
  
  ```
  (b +
    (s @ v_null.T - b @ v_null.T) @ v_null +
    (s @ v_row.T - b @ v_row.T)   @ v_null +
    (s @ v_null.T - b @ v_null.T) @ v_row +
    (s @ v_row.T - b @ v_row.T)   @ v_row
  ) @ W_out
  ```
  
  We know `ANYTHING @ v_null @ W_out = 0`, so the two terms ending in `@ v_null` vanish. Thus, we can write the LHS and RHS as,
  
  ```
  (b + (s @ v.T - b @ v.T) @ v) @ W_out =
  (b +
    (s @ v_row.T - b @ v_row.T) @ v_row +
    (s @ v_null.T - b @ v_null.T) @ v_row
  ) @ W_out
  
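3) The difference between intervening with `v` and with `v_row` therefore comes entirely from the term `(s @ v_null.T - b @ v_null.T) @ v_row @ W_out`, which is non-zero only when `(s @ v_null.T - b @ v_null.T)` is non-zero.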


**remark:** this is different from just saying removing `v_null` from `v` changes the effect magnitude.
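
To make the remark concrete, here is a minimal numerical sketch. It first checks the identity from step 2, then computes, for two different base/source pairs, the factor by which the `v_row` effect would have to be multiplied to match the full-`v` effect, namely `1 + alpha_null / alpha_row`. The 6-to-3 layer, the random inputs, and the helper names `gap_and_formula` and `needed_rescale` are hypothetical and not part of the notebook. Because the factor depends on the inputs, no single rescaling of `v_row` accounts for the gap.

```python
import torch

torch.manual_seed(0)
dt = torch.float64

# hypothetical small layer standing in for W_out (6 -> 3), for illustration only
W = torch.randn(6, 3, dtype=dt)
v = torch.randn(1, 6, dtype=dt)
v = v / v.norm()

# rowspace / nullspace split, same recipe as in the notebook
Q, _ = torch.linalg.qr(W)
v_row = v @ Q @ Q.T
v_null = v - v_row

do_das = lambda b, s, v: b + ((s @ v.T - b @ v.T) @ v)

def gap_and_formula(b, s):
    """Output gap between intervening with v and v_row, and the closed form from step 2."""
    gap = (do_das(b, s, v) - do_das(b, s, v_row)) @ W
    formula = ((s - b) @ v_null.T) @ v_row @ W
    return gap, formula

def needed_rescale(b, s):
    """Factor on the v_row effect needed to match full v: 1 + alpha_null / alpha_row."""
    d = s - b
    return (1 + (d @ v_null.T) / (d @ v_row.T)).item()

b1, s1 = torch.randn(1, 6, dtype=dt), torch.randn(1, 6, dtype=dt)
b2, s2 = torch.randn(1, 6, dtype=dt), torch.randn(1, 6, dtype=dt)

gap1, formula1 = gap_and_formula(b1, s1)
gap2, formula2 = gap_and_formula(b2, s2)
r1, r2 = needed_rescale(b1, s1), needed_rescale(b2, s2)

print(torch.allclose(gap1, formula1), torch.allclose(gap2, formula2))
print(r1, r2)  # the two factors differ, so the gap is not a fixed magnitude change
```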

### Simulating different intervention results just for the layer

In [4]:
# two random activations
b_act = torch.rand(1, 3072)
s_act = torch.rand(1, 3072)

# non-intervened outputs
out_b        = b_act @ W_out

# different intervention results
out_das      = do_das(b_act, s_act, das_mlp8)      @ W_out
out_das_null = do_das(b_act, s_act, das_mlp8_null) @ W_out
out_das_row  = do_das(b_act, s_act, das_mlp8_row)  @ W_out

# different intervention results with unit vectors
out_das_null_unit = do_das(b_act, s_act, das_null_unit) @ W_out
out_das_row_unit  = do_das(b_act, s_act, das_row_unit)  @ W_out

### nullspace effect ~= 0

In [5]:
(out_das_null - out_b).sum() # notice the sum here for aggregation

tensor(3.5157e-06, grad_fn=<SumBackward0>)
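
A note on the aggregation: `.sum()` adds signed entries, so positive and negative differences can cancel. The sketch below uses made-up numbers (not notebook outputs) to show how a maximum absolute difference or a norm gives a stricter check. It is offered as an optional alternative and does not change the conclusion above.

```python
import torch

# made-up differences that cancel exactly when summed
diff = torch.tensor([[0.5, -0.5, 0.25, -0.25]])

signed_sum = diff.sum()        # signed entries cancel
max_abs = diff.abs().max()     # largest entrywise deviation
l2_norm = diff.norm()          # overall magnitude of the difference

print(signed_sum.item(), max_abs.item())  # 0.0 0.5
```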

### diff(v, v_rowspace) = ?

In [6]:
(out_das_row - out_das).sum() # "illusion" effect

tensor(0.0345, grad_fn=<SumBackward0>)

In [7]:
print(out_das_row[0,:5])

tensor([ 0.0538,  0.3925, -0.0303, -0.2352, -0.1694], grad_fn=<SliceBackward0>)


In [8]:
print(out_das[0,:5])

tensor([ 0.0550,  0.3923, -0.0299, -0.2363, -0.1697], grad_fn=<SliceBackward0>)


### the missing effect: `(s @ v_null.T - b @ v_null.T)`

In [9]:
(s_act - b_act) @ das_mlp8_null.T

tensor([[-0.0993]], grad_fn=<MmBackward0>)

### adding ^ effect back removes the illusion completely
caveats: 

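- `das_mlp8_row @ das_mlp8_null.T` is zero, but matrix products do not commute: the equation below contains `das_mlp8_null.T @ das_mlp8_row`, an outer product that does not vanish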

- think of `(s_act - b_act) @ das_mlp8_null.T` as a list of scalar values

In [10]:
missing_effect = ((s_act - b_act) @ das_mlp8_null.T) @ das_mlp8_row @ W_out
out_das_row_composites = out_das_row + missing_effect

In [11]:
(out_das_row_composites - out_das).sum()

tensor(-3.0334e-06, grad_fn=<SumBackward0>)

### when is ^ zero then?
trivially, when `b` and `s` have the same activations, i.e. `s_act - b_act = 0`.

but, there are **two more cases**:

1) when `das_mlp8_null` is a zero vector

2) when `s_act - b_act` is orthogonal to `das_mlp8_null`; or `any_act` is orthogonal to `das_mlp8_null`
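
Case 2) can be sketched numerically. Below, the component of a random difference along `v_null` is removed so that `s - b` is orthogonal to `v_null`. Intervening with `v` and with `v_row` then gives the same output, while a generic difference leaves a gap. The 6-to-3 layer and the random inputs are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
dt = torch.float64

W = torch.randn(6, 3, dtype=dt)   # hypothetical stand-in for W_out
v = torch.randn(1, 6, dtype=dt)
v = v / v.norm()

Q, _ = torch.linalg.qr(W)
v_row = v @ Q @ Q.T
v_null = v - v_row

do_das = lambda b, s, v: b + ((s @ v.T - b @ v.T) @ v)

b = torch.randn(1, 6, dtype=dt)
d = torch.randn(1, 6, dtype=dt)   # generic difference s - b

# remove the component of d along v_null, so that (s - b) is orthogonal to v_null
d_orth = d - (d @ v_null.T) / (v_null @ v_null.T) * v_null

gap_generic = (do_das(b, b + d, v) - do_das(b, b + d, v_row)) @ W
gap_orth = (do_das(b, b + d_orth, v) - do_das(b, b + d_orth, v_row)) @ W

print(gap_generic.abs().max().item(), gap_orth.abs().max().item())
```

The second number should be zero up to floating-point error; the first should not.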

### paradox arises

case 2) above is rare, so forcing `das_mlp8_null` to be zero would remove the "illusion"; on this view, it is DAS's fault for learning a non-zero nullspace component.

but why can DAS learn a non-zero `das_mlp8_null` in the first place? precisely because case 2) above is rare.

**thought experiment**: let's train two DAS directions in parallel, `das_mlp8_row` and `das_mlp8_null`. the latter gets training signal whenever case 2) does not hold.
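
One way to probe the thought experiment is with autograd: differentiate a stand-in loss with respect to `v` and look at the part of the gradient that lies in the nullspace of `W_out`. Everything in this sketch is an assumption made for illustration: the tiny 3x1 weight matrix, the loss `out.sum()`, and the helper name `nullspace_grad_norm`. It also uses a stricter condition than case 2): the second source is built so that `s - b` has no nullspace component at all, in which case the nullspace part of the gradient vanishes.

```python
import torch

dt = torch.float64
W_out = torch.tensor([[0.], [-2.], [1.]], dtype=dt)   # small illustrative weight matrix
Q, _ = torch.linalg.qr(W_out)
P_null = torch.eye(3, dtype=dt) - Q @ Q.T              # projector onto the nullspace

def nullspace_grad_norm(b, s):
    """Norm of the nullspace component of d(loss)/dv for one base/source pair."""
    v = torch.tensor([[2 ** -0.5, -(2 ** -0.5), 0.]], dtype=dt, requires_grad=True)
    out = (b + ((s - b) @ v.T) @ v) @ W_out            # DAS intervention, then the layer
    out.sum().backward()                               # stand-in loss, an assumption
    return (v.grad @ P_null).norm().item()

b = torch.tensor([[1., 0., 1.]], dtype=dt)
s_generic = torch.tensor([[2., 0., 2.]], dtype=dt)     # s - b has a nullspace component
s_rowspace = b + W_out.T                               # s - b lies entirely in the rowspace

g_generic = nullspace_grad_norm(b, s_generic)
g_rowspace = nullspace_grad_norm(b, s_rowspace)
print(g_generic, g_rowspace)
```

In this sketch the nullspace gradient is clearly non-zero for the generic pair and zero (up to floating-point error) when `s - b` stays in the rowspace.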

### tentative takeaways

"illusion" arises because models can induce activations in the nullspace (a.k.a. activating a dormant path). when a dormant path activates, DAS maintains causal efficacy in a data-driven fashion.

put another way, **"illusion" is the more appropriate term if you think the activation in the previous layer has to be in the nullspace of the current layer.**

### the toy example

the toy example shows a similar nullspace effect.

In [12]:
W_out = torch.tensor([[0], [-2], [1]]).float()
das_mlp8 = torch.tensor([[1./torch.sqrt(torch.tensor(2)), -1./torch.sqrt(torch.tensor(2)), 0]])

# copied from the illusion code
Q, _ = torch.linalg.qr(W_out)
das_mlp8_row = das_mlp8 @ Q @ Q.T
das_mlp8_null = das_mlp8 - das_mlp8_row
das_row_unit = das_mlp8_row / das_mlp8_row.norm()
das_null_unit = das_mlp8_null / das_mlp8_null.norm()

In [13]:
# two fixed activations (base and source)
b_act = torch.tensor([[1, 0, 1]]).float()
s_act = torch.tensor([[2, 0, 2]]).float()

# non-intervened outputs
out_b        = b_act @ W_out

# different intervention results
out_das      = do_das(b_act, s_act, das_mlp8)      @ W_out
out_das_null = do_das(b_act, s_act, das_mlp8_null) @ W_out
out_das_row  = do_das(b_act, s_act, das_mlp8_row)  @ W_out

# different intervention results with unit vectors
out_das_null_unit = do_das(b_act, s_act, das_null_unit) @ W_out
out_das_row_unit  = do_das(b_act, s_act, das_row_unit)  @ W_out

In [14]:
out_b - out_das_null # null effect

tensor([[0.]])

In [15]:
out_das - out_das_row # diff(v, v_row)

tensor([[0.6000]])
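
The same bookkeeping as in the large simulation applies to the toy example: the gap of 0.6 is the "missing effect" `((s - b) @ v_null.T) @ v_row @ W_out`. The sketch below recomputes it from scratch with the toy values above. The short names `v`, `v_row`, and `v_null` are introduced here only for brevity.

```python
import torch

W_out = torch.tensor([[0.], [-2.], [1.]])
v = torch.tensor([[2 ** -0.5, -(2 ** -0.5), 0.]])   # same direction as das_mlp8 in the toy

# rowspace / nullspace split, same recipe as above
Q, _ = torch.linalg.qr(W_out)
v_row = v @ Q @ Q.T
v_null = v - v_row

b_act = torch.tensor([[1., 0., 1.]])
s_act = torch.tensor([[2., 0., 2.]])

# scalar read off along v_null, written back along v_row, then pushed through the layer
missing_effect = ((s_act - b_act) @ v_null.T) @ v_row @ W_out
print(missing_effect)  # approximately 0.6, matching out_das - out_das_row
```

By hand: `v_row = [0, -4, 2] / sqrt(50)` and `v_null = [5, -1, -2] / sqrt(50)`, so `(s - b) @ v_null.T = 3 / sqrt(50)` and `v_row @ W_out = 10 / sqrt(50)`, whose product is `30 / 50 = 0.6`.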

### what else about the toy example tho?

it is a small scale testbed for the nullspace artifact, but anything else?

we think it tells us that, if we want to align `H_*` as `x`, we can either do it 1) `H_1` in the rotated basis, or 2) `H_3` in the original basis.

### is this a problem?

come back to: what is causal abstraction? 

**under a set of known interventions, if we cannot distinguish two objects, we claim one is a causal abstraction of the other.**

thus, causal abstraction may not provide structural equivalence. there could be parallel circuits, the hydra effect, self-repair, etc. in other words, having multiple options of abstractions is acceptable.