<a href="https://colab.research.google.com/github/tobiaskatsch/LinearRNN/blob/master/log_quadratic_ssm_equality_check.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Log/Quadratic SSM Equality Check

In [2]:
import jax.numpy as jnp
from jax.lax import associative_scan
from flax import linen as nn
from jax import random

## MaxHead Equality Check

In [3]:
def max_heads_quadratic(q, k, v, amplitude, phase):
    # inputs: batchsize x seq_len x d
    b, l, d = q.shape
    q = q.reshape(b, l, d, 1).transpose((0, 2, 1, 3))
    k = k.reshape(b, l, d, 1).transpose((0, 2, 1, 3))
    v = v.reshape(b, l, d, 1).transpose((0, 2, 1, 3))
    amplitude = amplitude.reshape(b, l, d, 1).transpose((0, 2, 1, 3))
    phase = phase.reshape(b, l, d, 1).transpose((0, 2, 1, 3))
    # b, d, l, 1

    cum_amplitude = jnp.cumprod(amplitude, axis=2)
    cum_phase = jnp.cumsum(phase, axis=2)
    q = q * cum_amplitude * jnp.exp(1j * cum_phase)
    k = k * (1/cum_amplitude) * jnp.exp((-1) * 1j * cum_phase)
    scores = jnp.matmul(q, k.transpose((0, 1, 3, 2)))

    causal_mask = jnp.tril(jnp.ones((l, l), dtype=bool)).reshape((1, 1, l, l))
    scores = jnp.where(causal_mask, scores, 0.)

    y = jnp.matmul(scores, v)
    y = y.reshape((batch_size, seq_len, d))
    return y


def max_heads_logarithmic(q, k, v, amplitude, phase):
    # inputs: batchsize x seq_len x head_size

    def binary_operator(e_i, e_j):
        a_i, vk_i = e_i
        a_j, vk_j = e_j
        return a_j * a_i, a_j * vk_i + vk_j

    a = amplitude * jnp.exp(1j * phase)
    vk = v * k
    _, y = associative_scan(binary_operator, (a, vk), axis=1)
    y = y * q
    return y

In [4]:
n_experiments = 10
tolerance = 1000  # Define your tolerance level

for i in range(n_experiments):
    # Create random seed
    key = random.PRNGKey(i)

    # Generate complex random arrays for q, k, and v
    batch_size, seq_len, d = 32, 50, 128

    q = random.normal(key, (batch_size, seq_len, d)) + 1j * random.normal(key, (batch_size, seq_len, d))
    k = random.normal(key, (batch_size, seq_len, d)) + 1j * random.normal(key, (batch_size, seq_len, d))
    v = random.normal(key, (batch_size, seq_len, d)) + 1j * random.normal(key, (batch_size, seq_len, d))

    # Generate random real-valued arrays for amplitude and phase
    amplitude_raw = random.normal(key, (batch_size, seq_len, d))
    phase_raw = random.normal(key, (batch_size, seq_len, d))

    # Apply non-linearities
    amplitude = nn.sigmoid(amplitude_raw)
    phase = nn.relu(phase_raw)

    # Calculate the output from the quadratic function
    output1 = max_heads_quadratic(q, k, v, amplitude, phase)
    output2 = max_heads_logarithmic(q, k, v, amplitude, phase)

    # Check if the outputs are close enough to be considered equal
    if jnp.allclose(output1, output2, atol=tolerance):
        print(f"Experiment {i+1}: Equal", jnp.mean(output1), "=", jnp.mean(output2))
    else:
        print(f"Experiment {i+1}: Not Equal", jnp.mean(output1), "!=", jnp.mean(output2))



Experiment 1: Equal (-0.31757334-0.88877857j) = (-0.31757346-0.8887784j)
Experiment 2: Equal (-0.3013515-0.90275556j) = (-0.30135155-0.9027553j)
Experiment 3: Equal (-0.3124598-0.90063137j) = (-0.31246015-0.90063125j)
Experiment 4: Equal (-0.315858-0.9025356j) = (-0.31585807-0.9025354j)
Experiment 5: Equal (-0.3459811-0.8823404j) = (-0.3459812-0.88234013j)
Experiment 6: Equal (-0.30827466-0.9041539j) = (-0.30827454-0.9041537j)
Experiment 7: Equal (-0.30112782-0.90319335j) = (-0.3011277-0.9031928j)
Experiment 8: Equal (-0.3496565-0.86219275j) = (-0.34965613-0.86219305j)
Experiment 9: Equal (-0.31240523-0.8952814j) = (-0.31240538-0.8952817j)
Experiment 10: Equal (-0.3331844-0.87800854j) = (-0.33318472-0.8780085j)


Discussion: (with set parameters) l < 50 leads to equal results for both variants but increasing l >= 100 leads to nan occuring in the quadratic variant. This happens due to the extreme values 1/cum_amp and cum_amp assume. Clipping of the amplitude does not help definitively: For instance a chain of data controlled 0.5 amplitude can also cause this.

## Arbirary Headed

In [117]:
def arbitrary_heads_quadratic(q, k, v, amplitude, phase):
    b, l, h, d_qk = q.shape
    d_v = v.shape[3]

    q = q.transpose((0, 2, 1, 3))
    k = k.transpose((0, 2, 1, 3))
    v = v.transpose((0, 2, 1, 3))
    amplitude = amplitude.transpose((0, 2, 1, 3))
    phase = phase.transpose((0, 2, 1, 3))
    # b, h, l, d

    cum_amplitude = jnp.cumprod(amplitude, axis=2)
    cum_phase = jnp.cumsum(phase, axis=2)
    q = q * cum_amplitude * jnp.exp(1j * cum_phase)
    k = k * (1/cum_amplitude) * jnp.exp((-1) * 1j * cum_phase)
    k = k.transpose((0, 1, 3, 2)) # b, h, d, l
    scores = jnp.matmul(q, k)

    causal_mask = jnp.tril(jnp.ones((l, l), dtype=bool)).reshape((1, 1, l, l))
    scores = jnp.where(causal_mask, scores, 0.)

    y = jnp.matmul(scores, v)
    y = y.reshape((batch_size, seq_len, d_v*h))
    return y


def arbitrary_heads_logarithmic(q, k, v, amplitude, phase):
    b, l, h, d_qk = q.shape
    d_v = v.shape[3]
    a = amplitude * jnp.exp(1j * phase)
    k = k.reshape(b, l, h, d_qk, 1)
    q = q.reshape(b, l, h, d_qk, 1).transpose((0, 1, 2, 4, 3))
    v = v.reshape(b, l, h, d_v, 1).transpose((0, 1, 2, 4, 3))
    a = a.reshape(b, l, h, d_qk, 1)

    def binary_operator(e_i, e_j):
        a_i, kv_i = e_i
        a_j, kv_j = e_j
        return a_j * a_i, a_j * kv_i + kv_j

    kv = jnp.matmul(k, v)
    _, y = associative_scan(binary_operator, (a, kv), axis=1)
    y = jnp.matmul(q, y)
    y = y.reshape((batch_size, seq_len, d_v*h))
    return y

In [118]:
n_experiments = 10
tolerance = 1000  # Define your tolerance level

for i in range(n_experiments):
    # Create random seed
    key = random.PRNGKey(i+1)

    # Generate complex random arrays for q, k, and v
    batch_size, seq_len, n_head, qk_head_dim, v_head_dim = 32, 50, 4, 16, 32

    q = random.normal(key, (batch_size, seq_len, n_head, qk_head_dim)) + 1j * random.normal(key, (batch_size, seq_len, n_head, qk_head_dim))
    k = random.normal(key, (batch_size, seq_len, n_head, qk_head_dim)) + 1j * random.normal(key, (batch_size, seq_len, n_head, qk_head_dim))
    v = random.normal(key, (batch_size, seq_len, n_head, v_head_dim)) + 1j * random.normal(key, (batch_size, seq_len, n_head, v_head_dim))

    # Generate random real-valued arrays for amplitude and phase
    amplitude_raw = random.normal(key, (batch_size, seq_len, n_head, qk_head_dim))
    phase_raw = random.normal(key, (batch_size, seq_len, n_head, qk_head_dim))

    # Apply non-linearities
    amplitude = nn.sigmoid(amplitude_raw)
    phase = nn.relu(phase_raw)

    # Calculate the output from the quadratic function
    output1 = arbitrary_heads_quadratic(q, k, v, amplitude, phase)
    output2 = arbitrary_heads_logarithmic(q, k, v, amplitude, phase)

    # Check if the outputs are close enough to be considered equal
    if jnp.allclose(output1, output2, atol=tolerance):
        print(f"Experiment {i+1}: Equal", jnp.mean(output1), "=", jnp.mean(output2))
    else:
        print(f"Experiment {i+1}: Not Equal", jnp.mean(output1), "!=", jnp.mean(output2))

Experiment 1: Equal (0.0675754-0.08274143j) = (0.06757529-0.082741104j)
Experiment 2: Equal (0.032447454-0.044066206j) = (0.032447353-0.04406624j)
Experiment 3: Equal (-0.022365812+0.0092469j) = (-0.022365859+0.009246964j)
Experiment 4: Equal (-0.15796396+0.18204387j) = (-0.15796399+0.18204357j)
Experiment 5: Equal (0.07056659-0.056545787j) = (0.07056679-0.056545563j)
Experiment 6: Equal (0.07971294-0.077598445j) = (0.079712816-0.077598415j)
Experiment 7: Equal (-0.15902089+0.16834456j) = (-0.15902093+0.16834445j)
Experiment 8: Equal (-0.0872801+0.122678354j) = (-0.08728004+0.12267854j)
Experiment 9: Equal (-0.052648935+0.018602788j) = (-0.052648906+0.018603094j)
Experiment 10: Equal (0.050002463-0.071236655j) = (0.050002534-0.07123652j)
