# Bilinear Toy Models of Decomposition
**By Michael Pearce and Thomas Dooms**

Activation functions are a vital component of deep networks, allowing them to learn complex functions. Their sole purpose is to introduce non-linearity: without it, a stack of linear layers collapses into a single matrix. It is often argued that the simplest non-linearity is a piecewise-linear one, which is exactly what the most popular activation function, the ReLU, provides.

However, ReLUs are very difficult to study from an interpretability standpoint. Intuitively, the issue is that the output of a ReLU can only be determined by passing in an input, not from the weights alone. While it is possible to discern structure and circuits in these weights, this is often done by sampling inputs or by using gradient-based techniques to generate visualizations, which may not reflect the full complexity of the model. This makes it very hard to give strong guarantees about which outputs a model is able to produce. One symptom is the existence of adversarial examples, in which a slight perturbation of the input confuses the model into making bogus predictions.

This undesirable property of the ReLU has made MLPs and similar structures famously hard to interpret. As a consequence, several works ([capacity](https://arxiv.org/abs/2210.01892), [lesswrong](find source)) have used quadratic activation functions for their theoretical analysis. Unfortunately, simple quadratic activation functions result in terrible accuracy for [several mathematical reasons](https://www.sciencedirect.com/science/article/abs/pii/S0893608005801315). As a substitute, it is possible to use bilinear activation functions. These functions possess the appealing characteristics of quadratic activations while being comparable to ReLUs in accuracy, as established in [this paper](noam paper). Therefore, in this document, we make the design decision to replace ReLUs with the more interpretable bilinear layer. We introduce these layers and give an overview of our current efforts in interpreting simple models that use them.

### Introduction to Bilinear Maps

Like most things in life, bilinear maps are a very natural idea once one is familiar with them. If you already know the concept, feel free to skip this part; the next section covers how we use them in neural networks. What follows is an intuitive explanation of bilinear maps, followed by a more mathematical definition.

Let's start with a linear map. While the term may be new, the concept should be familiar: every linear map between finite-dimensional vector spaces can be written as a matrix. One can think of a matrix as a function that takes a vector as input and returns a new vector as output. It does so while preserving the linear properties that make matrices so useful to work with.

A bilinear map is not so different: it takes in two vectors and returns a vector, following very similar linear properties. These properties boil down to: if one input vector is kept constant, a bilinear map behaves exactly like a linear map in the other input. The following are three intuitive explanations.

- A bilinear map is a function that takes in a vector and returns a matrix; this matrix is then applied to the second vector to compute the actual output vector.
- Bilinear maps perform two linear operations in sequence on different inputs.
- The mathematical properties of a bilinear map are comparable to ordinary (scalar) multiplication. Multiplication takes in two numbers and returns another, and the result is linear in each input if we freeze the other.
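
The first intuition ("vector in, matrix out") can be made concrete with a small numerical sketch. It assumes NumPy and represents a bilinear map by a randomly chosen three-index array `B`; the sizes and the names `B`, `bilinear` and `M_u` are illustrative and do not come from the notebook's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: inputs of dimension 3 and 4, output of dimension 2.
B = rng.normal(size=(2, 3, 4))  # three-index array defining b(u, v)

def bilinear(u, v):
    """b(u, v)[o] = sum over i, j of B[o, i, j] * u[i] * v[j]."""
    return np.einsum("oij,i,j->o", B, u, v)

u = rng.normal(size=3)
v = rng.normal(size=4)

# Freezing u turns the bilinear map into an ordinary matrix acting on v.
M_u = np.einsum("oij,i->oj", B, u)  # shape (2, 4)

frozen_matches = np.allclose(M_u @ v, bilinear(u, v))
print(frozen_matches)
```

Here `M_u` is the matrix that the bilinear map "returns" for the frozen input `u`; applying it to `v` gives the same result as evaluating the bilinear map directly.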

Importantly, while a bilinear map is linear in each input when the other is frozen, it is not linear in general. For instance, scalar multiplication with both inputs tied together is quadratic ($s(x, x) = x^2$). That said, let's get a bit into the weeds of the maths behind bilinear layers. A linear map $m$ satisfies the following properties:

$$ m(\vec{u} + \vec{v}) = m(\vec{u}) + m(\vec{v})$$

$$ m(\lambda \vec{u}) = \lambda m(\vec{u})$$

Here $\vec{u}$ and $\vec{v}$ are arbitrary vectors and $\lambda$ is a scalar. These properties should feel quite natural: intuitively, sums and scalar factors can be "pulled out" of the map. In this respect, linear maps behave much like multiplication by an ordinary number. In contrast, a bilinear map satisfies slightly more complicated constraints.

$$ b(\lambda \vec{u}, \vec{v}) = \lambda b(\vec{u}, \vec{v}) \text{ and } b(\vec{u}, \lambda \vec{v}) = \lambda b(\vec{u}, \vec{v})$$

$$ b(\vec{u_1} + \vec{u_2}, \vec{v}) = b(\vec{u_1}, \vec{v}) + b(\vec{u_2}, \vec{v}) \text{ and } b(\vec{u}, \vec{v_1} + \vec{v_2}) = b(\vec{u}, \vec{v_1}) + b(\vec{u}, \vec{v_2})$$

This may seem arbitrary at first, but if you squint a bit, you can see that ignoring the $\vec{v}$ in the left-hand equations recovers exactly the linear case.
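
These constraints can be checked numerically. The following is a minimal sketch, assuming NumPy and a random tensor `B` standing in for an arbitrary bilinear map (all names are illustrative). It also checks the earlier remark that tying both inputs together gives quadratic, not linear, behaviour.

```python
import numpy as np

rng = np.random.default_rng(1)
B = rng.normal(size=(2, 3, 3))  # hypothetical bilinear map on 3-dimensional inputs

def b(u, v):
    return np.einsum("oij,i,j->o", B, u, v)

u1, u2, v = rng.normal(size=(3, 3))  # three random 3-dimensional vectors
lam = 2.5

# Linearity in the first argument while the second is frozen.
additive = np.allclose(b(u1 + u2, v), b(u1, v) + b(u2, v))
homogeneous = np.allclose(b(lam * u1, v), lam * b(u1, v))

# Tying both arguments to the same vector: scaling the input by lam
# scales the output by lam**2, so the tied map is not linear.
quadratic = np.allclose(b(lam * u1, lam * u1), lam**2 * b(u1, u1))
linear_when_tied = np.allclose(b(lam * u1, lam * u1), lam * b(u1, u1))

print(additive, homogeneous, quadratic, linear_when_tied)
```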

### Bilinear Maps in Neural Nets

A normal layer in a neural network is structured as follows. 

$$h_{i+1} = \text{ReLU}(W_i h_i + b_i)$$

We will study the following structure.

$$h_{i+1} = (W_i h_i + b^w_i) \odot (V_i h_i + b^v_i)$$

Here, $W_i$ and $V_i$ are matrices of the same shape, $b^w_i$ and $b^v_i$ are biases, and $\odot$ denotes the element-wise product. As briefly covered in the intro, there are a handful of reasons for studying this structure.
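
As a minimal sketch, the layer equation translates almost literally into code. This version assumes NumPy and a single (unbatched) input vector; the function name `bilinear_layer` and the sizes are illustrative and not taken from the accompanying repository.

```python
import numpy as np

def bilinear_layer(h, W, V, b_w, b_v):
    """Compute (W h + b_w) * (V h + b_v), with * the element-wise product."""
    return (W @ h + b_w) * (V @ h + b_v)

rng = np.random.default_rng(0)
d_in, d_out = 4, 3  # hypothetical sizes
W = rng.normal(size=(d_out, d_in))
V = rng.normal(size=(d_out, d_in))
b_w = rng.normal(size=d_out)
b_v = rng.normal(size=d_out)
h = rng.normal(size=d_in)

h_next = bilinear_layer(h, W, V, b_w, b_v)
print(h_next.shape)
```

Note that no activation function appears: the element-wise product of the two linear branches is the only source of non-linearity.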


The important part is that this operation is non-linear with regard to the input; it is not possible to define a matrix that encodes this operation. It is, however, possible to define a rank 3 tensor that does. While interpreting tensors is more difficult than interpreting matrices, we now actually have a geometric object that describes the whole layer and *can* be studied on its own.

We can construct this tensor using the following product of $V$ and $W$.

```python
from einops import einsum
# B[output, input1, input2] = V[output, input1] * W[output, input2]
B = einsum(V, W, "output input1, output input2 -> output input1 input2")
```
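
To see that this tensor really describes the (bias-free) layer, one can contract it with the input twice and compare against the element-wise product. The sketch below assumes NumPy and uses `np.einsum` index notation in place of the named-axis pattern above; shapes and variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 4, 3  # hypothetical sizes
W = rng.normal(size=(d_out, d_in))
V = rng.normal(size=(d_out, d_in))
h = rng.normal(size=d_in)

# Same construction as above: B[o, i, j] = V[o, i] * W[o, j]
B = np.einsum("oi,oj->oij", V, W)

# Bias-free bilinear layer, computed directly ...
layer_out = (W @ h) * (V @ h)
# ... and by feeding the same input into both input slots of the tensor.
tensor_out = np.einsum("oij,i,j->o", B, h, h)

match = np.allclose(layer_out, tensor_out)
print(B.shape, match)
```

Because the same vector enters both input slots, swapping the roles of `V` and `W` in the construction gives the same layer output.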

#### Motivation & goals
The main goal of this project is to fully understand (small) networks that only use bilinear layers. We do this by studying the weights of small toy models. This lets us design interpretability techniques in a controlled environment and evaluate which work best.

 -- TODO


#### Toy Models
In this write-up, we focus on small models consisting of a linear encoder, a bilinear layer and a decoder. Mathematically, the architecture is:
$$h = Ex$$
$$h' = (Wh + b^w) \odot (Vh + b^v)$$
$$y = Dh'$$

This can also be written as a single function.

$$y = D\big((WEx + b^w) \odot (VEx + b^v)\big)$$
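
A minimal sketch of this encoder, bilinear layer and decoder pipeline, assuming NumPy and the bias convention of the layer definition in the previous section. The function name `toy_model` and all sizes are hypothetical and do not mirror the repository's model classes.

```python
import numpy as np

def toy_model(x, E, W, V, D, b_w=0.0, b_v=0.0):
    """Encoder -> bilinear layer -> decoder for a single input vector x."""
    h = E @ x                                  # linear encoder
    h_prime = (W @ h + b_w) * (V @ h + b_v)    # bilinear layer
    return D @ h_prime                         # linear decoder

rng = np.random.default_rng(0)
n_in, n_hidden, n_out = 6, 3, 6  # hypothetical sizes
E = rng.normal(size=(n_hidden, n_in))
W = rng.normal(size=(n_hidden, n_hidden))
V = rng.normal(size=(n_hidden, n_hidden))
D = rng.normal(size=(n_out, n_hidden))
x = rng.normal(size=n_in)

y = toy_model(x, E, W, V, D)

# With zero biases, the step-by-step form agrees with the single-function form.
single_expression = D @ ((W @ E @ x) * (V @ E @ x))
forms_agree = np.allclose(y, single_expression)
print(y.shape, forms_agree)
```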

#### Constructing Tensors
Before delving into the next parts, it's important to grasp which objects we will study and how they are constructed. There are 3 main concepts to grasp:
- Feature interactions
- Including biases
- Singular value decomposition (SVD)



For instance, we stated that a bilinear layer can be represented as a rank 3 tensor that is computed by taking a tensor product between $W$ and $V$. However, this does not include any biases, which we use in our architecture.

> Aside: Biases play an unsung role in neural networks: they allow the network to learn values that are independent of the input. Recently, biases have mostly been moved into the normalization layers, but their role remains equally important. Within bilinear layers, biases become even more important. Concretely, without biases the layer is purely quadratic in its input, so it is impossible to learn the identity matrix or any other linear operation. TODO: I can add some more
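
The claim in the aside can be illustrated numerically. The sketch below assumes NumPy and a hand-picked set of weights (not learned ones). Without biases, the layer gives the same output for `h` and `-h`, so it cannot be the identity; with biases, choosing $V = 0$ and $b^v = 1$ reduces the layer to the linear map $Wh + b^w$.

```python
import numpy as np

def layer(h, W, V, b_w, b_v):
    return (W @ h + b_w) * (V @ h + b_v)

rng = np.random.default_rng(0)
d = 4  # hypothetical width
W = rng.normal(size=(d, d))
V = rng.normal(size=(d, d))
h = rng.normal(size=d)
zeros = np.zeros(d)

# Without biases the layer is an even function: f(-h) == f(h).
is_even = np.allclose(layer(-h, W, V, zeros, zeros), layer(h, W, V, zeros, zeros))

# With V = 0 and b_v = 1, the second branch is constantly one,
# so the layer computes W h + b_w. Picking W = I recovers the identity.
identity_out = layer(h, np.eye(d), np.zeros((d, d)), zeros, np.ones(d))
recovers_identity = np.allclose(identity_out, h)

print(is_even, recovers_identity)
```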

TODO: I should probably add some math here

Studying bilinear layers with biases requires a small trick; we incorporate them into the matrix weights $W$ and $V$. 

Mathematically we construct:
- $x' = [\vec{x} \: 1]$
- $W' = [\vec{W_i} \: b^w_i]$
- $V' = [\vec{V_i} \: b^v_i]$

Intuitively, we append a constant 1 to the input vector and an extra column, holding the bias, to both the $W$ and $V$ matrices. This is equivalent to using a bias.
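
The equivalence can be verified with a short sketch, assuming NumPy; the variable names (`x_aug`, `W_aug`, `V_aug`, `B_aug`) and sizes are illustrative. It also shows that the tensor built from the augmented matrices reproduces the biased layer.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 4, 3  # hypothetical sizes
W, V = rng.normal(size=(2, d_out, d_in))
b_w, b_v = rng.normal(size=(2, d_out))
x = rng.normal(size=d_in)

# Append a constant 1 to the input and the bias as an extra column.
x_aug = np.append(x, 1.0)          # shape (d_in + 1,)
W_aug = np.column_stack([W, b_w])  # shape (d_out, d_in + 1)
V_aug = np.column_stack([V, b_v])

with_bias = (W @ x + b_w) * (V @ x + b_v)
folded = (W_aug @ x_aug) * (V_aug @ x_aug)

# The tensor built from the augmented matrices now includes the biases.
B_aug = np.einsum("oi,oj->oij", W_aug, V_aug)
from_tensor = np.einsum("oij,i,j->o", B_aug, x_aug, x_aug)

print(np.allclose(with_bias, folded), np.allclose(with_bias, from_tensor))
```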





#### General Layout of This Notebook
We aim to strike a balance between providing all the code and remaining readable. Therefore, we import some generic code such as visualization functions, models and utility functions from [our GitHub repo](https://github.com/tdooms/bilinear-interp).

In [None]:
!pip install einops
!pip install jaxtyping
!git clone https://github.com/tdooms/bilinear-interp.git

%cd bilinear-interp/

#### Superposition
Studying superposition is a solid starting point for what we wish to achieve, mainly because the setup is simple and the outcomes are exactly known (for normal networks). The setup we use for eliciting superposition is similar to [the original superposition paper](https://transformer-circuits.pub/2022/toy_model/index.html). Specifically, we use the following:

- The encoder projects down into fewer dimensions than the input.
- We do not use the decoder (i.e., we use an identity operation).
- The loss is the mean squared error between the prediction and the input.
- We study different sparsities of the input.
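
As a rough sketch of the data and loss side of this setup, the snippet below assumes NumPy and the sampling scheme of the original superposition paper, where each feature is zero with probability equal to the sparsity and otherwise uniform in $[0, 1)$. The notebook's own data generator may differ, and the helper names `sample_sparse_inputs` and `mse` are hypothetical.

```python
import numpy as np

def sample_sparse_inputs(rng, batch, n_features, sparsity):
    """Each feature is zero with probability `sparsity`, else uniform in [0, 1)."""
    values = rng.uniform(size=(batch, n_features))
    active = rng.uniform(size=(batch, n_features)) >= sparsity
    return values * active

def mse(prediction, target):
    """Mean squared error between the model's prediction and its input."""
    return np.mean((prediction - target) ** 2)

rng = np.random.default_rng(0)
x = sample_sparse_inputs(rng, batch=1000, n_features=20, sparsity=0.9)

fraction_zero = np.mean(x == 0)
print(x.shape, fraction_zero)
```

Higher sparsity means fewer features are active at once, which is the regime where superposition is expected to appear.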

Let's define the model.