In [130]:
from dataclasses import dataclass
from typing import List

import equinox as eqx
import jax
import jax.numpy as jnp
from equinox import nn

from xpos import apply_xpos

# TODO: Create option for default decay as benchmark


# Potential Performance Bottlenecks:
# - Einsum and transpositions
# - Application of XPos (How much is recomputed every time)
# - Using Tiled like like FlashAttention


@dataclass
class RetNetConfig:
    n_layers: int = 3
    d_model: int = 512
    n_heads: int = 8
    d_ff: int = 2048
    n_vocab: int = 10000
    dropout_prob: float = 0.1

    qkv_bias: bool = False

    @property
    def d_head(self):
        return self.d_model // self.n_heads


class GatedMultiScaleRetention(eqx.Module):
    qkv: nn.Linear
    alpha: nn.Linear
    config: RetNetConfig = eqx.field(static=True)

    out: nn.Linear
    g_norm: nn.GroupNorm
    gate: nn.Linear

    def __init__(self, config: RetNetConfig, key):
        super().__init__()
        qkv_key, alpha_key, out_key, gate_key = jax.random.split(key, 4)
        self.config = config
        self.qkv = nn.Linear(
            config.d_model, 3 * config.d_model, use_bias=config.qkv_bias, key=qkv_key
        )
        self.alpha = nn.Linear(
            config.d_model, self.config.n_heads, use_bias=False, key=alpha_key
        )
        self.g_norm = nn.GroupNorm(self.config.n_heads, self.config.d_model )
        self.out = nn.Linear(
            config.d_model, config.d_model, use_bias=False, key=out_key
        )
        self.gate = nn.Linear(
            config.d_model, config.d_model, use_bias=False, key=gate_key
        )

    def retention(self, q, k, v, alphas, state):
        sqlen = alphas.shape[1]
        k = (1 - jnp.exp(alphas))[:, :, None] * k
        alphas = jnp.cumsum(alphas, axis=1)
        Delta = jnp.tril(jnp.exp(alphas[:, :, None] - alphas[:, None, :]))
        # TODO: check performance difference, if using current transpositions
        attn = jnp.einsum("hid,hjd,hij->hij", q, k, Delta)
        ret = jnp.einsum("hij,hje->ihe", attn, v)
        new_state = None
        if state is not None:
            new_state = jnp.exp(alphas[:, -1, None, None]) * state
            ret = ret + jnp.einsum("hid, hde, hi->ihe", q, state, jnp.exp(alphas))
            new_state = new_state + jnp.einsum(
                "hid, hie, hi->hde", k, v, Delta[:, -1, :]
            )
        ret = ret.reshape(sqlen, self.config.d_model)
        return ret, new_state

    def _log_sigmoid(self, x):
        # Computes ln(sigmoid(x))
        return -jnp.log1p(jnp.exp(-x))

    def __call__(self, x, state=None, offset=0):
        # x: (sqlen, d_model)
        # state: (n_heads, d_head, d_head)
        sqlen = x.shape[0]
        # retention with gated hidden propagation
        q, k, v = (
            jax.vmap(self.qkv)(x)
            .reshape(sqlen, 3, self.config.n_heads, self.config.d_head)
            .transpose((1, 2, 0, 3))
        )
        q = jax.vmap(lambda x: apply_xpos(x, offtset))(q)
        k = jax.vmap(lambda x: apply_xpos(x, offset, inv=True))(k)
        alphas = self._log_sigmoid(jax.vmap(self.alpha)(x)).T  # (n_heads, sqlen)
        # TODO: Remove this later
        alphas = jnp.maximum(alphas, -0.1)
        ret, new_state = self.retention(q, k, v, alphas, state)

        # gated hidden propagation
        ret = jax.vmap(self.g_norm)(ret)
        out = jax.vmap(self.out)(jax.nn.swish(jax.vmap(self.gate)(ret)) * ret)

        return out, new_state


In [131]:
config = RetNetConfig()
model = GatedMultiScaleRetention(config, key=jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(0), (512, config.d_model))
out_par, _ = model(x)

state = jnp.zeros((config.n_heads, config.d_head, config.d_head))
def step(carry, x):
    offset, state = carry
    print(carry)
    out, state = model(x, state, offset)
    return (offset + x.shape[0], state), out
_, out_state = jax.lax.scan(step, (0, state), x.reshape(-1, 32, x.shape[1]))
out_state= jnp.concatenate(out_state, 0)
jnp.allclose(out_par, out_state, rtol=1e-5, atol=1e-5)

NameError: name 'offtset' is not defined

In [126]:
out_par - out_state

Array([[ 1.4901161e-08,  1.8626451e-08, -2.9802322e-08, ...,
         4.4703484e-08, -5.9604645e-08, -3.7252903e-08],
       [-8.9406967e-07, -2.0489097e-06,  7.7486038e-07, ...,
         2.0861626e-07, -5.5879354e-07, -4.1723251e-07],
       [-1.1920929e-07,  5.2154064e-07, -2.3841858e-07, ...,
         1.1203811e-06,  7.8976154e-07, -1.2665987e-06],
       ...,
       [ 6.1773658e-03, -6.3019544e-03,  2.2490025e-03, ...,
         1.0824859e-02, -9.2112720e-03,  1.3393670e-02],
       [-3.9237771e-02,  6.9474578e-03, -6.1702318e-03, ...,
         6.1135683e-03,  6.4724609e-03, -3.0900389e-03],
       [ 1.1473164e-02, -1.9829683e-02, -1.4270589e-02, ...,
        -1.6920269e-03, -2.9850755e-02,  1.9137338e-03]], dtype=float32)

In [125]:
out_par

Array([[-0.14186686, -0.00680654,  0.21672498, ..., -0.186696  ,
         0.24205342,  0.06239263],
       [ 0.2874918 ,  0.11358412,  0.19276801, ..., -0.15703425,
        -0.01847184,  0.14338523],
       [-0.32975903,  0.00915347,  0.15077922, ..., -0.00605613,
        -0.18553996, -0.11542879],
       ...,
       [-0.10695693, -0.15468043, -0.28115785, ..., -0.17218846,
         0.39642835, -0.13879995],
       [-0.07091648, -0.10998802,  0.04637381, ...,  0.01161928,
        -0.00144088, -0.08938757],
       [-0.00057597, -0.11536033, -0.23001441, ...,  0.15028474,
        -0.09224696, -0.05306967]], dtype=float32)