In [131]:
import sys
sys.path.append("../")
from models.train_utils import param_count

In [132]:
from flax.linen.attention import dot_product_attention

In [133]:
import jax
import jax.numpy as jnp
from flax import linen as nn


class MultiHeadAttentionBlock(nn.Module):
    """Pre-LayerNorm attention block: multi-head attention followed by an MLP.

    Each sub-layer reads a LayerNorm'd copy of the residual stream ``x`` and
    adds its output back onto ``x``.

    Attributes:
      n_heads: Number of attention heads.
      d_model: Output width of the MLP; it is added onto ``x``, so it has to
        match the last dimension of ``x``.
      d_mlp: Hidden width of the MLP.
    """

    n_heads: int
    d_model: int
    d_mlp: int

    @nn.compact
    def __call__(self, x, y, mask=None):
        """Attend from ``x`` (queries) to ``y`` (keys/values), then apply the MLP.

        Args:
          x: Query-side input; this is the residual stream that is returned.
          y: Key/value-side input. It is handed to attention as-is (only the
            query side is normalized here).
          mask: Optional attention mask. A head axis is inserted at position -3
            before it is passed to the attention layer.

        Returns:
          The updated residual stream (same shape as ``x``).
        """
        # Insert a broadcastable head axis: (..., q, kv) -> (..., 1, q, kv)
        attn_mask = None if mask is None else mask[..., None, :, :]

        # Multi-head attention on the *normalized* residual stream.
        # The mask is passed by keyword: newer Flax releases treat a third
        # positional argument as `inputs_v`, not as the mask.
        x_mhsa = nn.LayerNorm()(x)
        x_mhsa = nn.MultiHeadDotProductAttention(num_heads=self.n_heads, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.zeros)(x_mhsa, y, mask=attn_mask)

        # Add into residual stream (out-of-place, so a caller-owned array is never mutated)
        x = x + x_mhsa

        # MLP on the *normalized* residual stream
        x_mlp = nn.LayerNorm()(x)
        x_mlp = nn.gelu(nn.Dense(self.d_mlp)(x_mlp))
        x_mlp = nn.Dense(self.d_model)(x_mlp)

        # Add into residual stream
        x = x + x_mlp

        return x


class PoolingByMultiHeadAttention(nn.Module):
    """Pool a set onto a fixed number of learned seed vectors via attention.

    The learned seed vectors act as queries and the input set ``z`` as
    keys/values of a ``MultiHeadAttentionBlock``.

    Attributes:
      n_seed_vectors: Number of learned seed (query) vectors.
      n_heads: Number of attention heads of the inner attention block.
      d_model: ``d_model`` forwarded to the inner attention block.
      d_mlp: ``d_mlp`` forwarded to the inner attention block.
    """

    n_seed_vectors: int
    n_heads: int
    d_model: int
    d_mlp: int

    @nn.compact
    def __call__(self, z, mask=None):
        """Return the seed vectors after attending to ``z``.

        Args:
          z: Input set; the last axis is the feature axis and the second-to-last
            axis indexes set elements. Leading axes are treated as batch axes.
          mask: Optional mask over the elements of ``z``; a singleton query axis
            is inserted so that it broadcasts over all seed vectors.

        Returns:
          Array of shape ``z.shape[:-2] + (n_seed_vectors, z.shape[-1])``.
        """
        # Learned queries, one row per seed vector, with the same feature width as z
        seed_vectors = self.param("seed_vectors", nn.linear.default_embed_init, (self.n_seed_vectors, z.shape[-1]))
        # Replicate the seeds across the batch axes of z
        seed_vectors = jnp.broadcast_to(seed_vectors, z.shape[:-2] + seed_vectors.shape)
        mask = None if mask is None else mask[..., None, :]
        # Use this module's own hyperparameters (not same-named globals)
        return MultiHeadAttentionBlock(n_heads=self.n_heads, d_model=self.d_model, d_mlp=self.d_mlp)(seed_vectors, z, mask)


class TransformerFlax(nn.Module):
    """Transformer for set modeling built from ``MultiHeadAttentionBlock`` layers.

    Attributes:
      n_input: The number of input (and output) features.
      d_model: The dimension of the model embedding space.
      d_mlp: The hidden dimension of the MLP inside every attention block.
      n_layers: Number of transformer layers.
      n_heads: The number of attention heads.
      induced_attention: If True, every layer first pools the set onto
        ``n_inducing_points`` vectors and then attends from the set to that
        pooled summary instead of attending from the set to itself.
      n_inducing_points: The number of inducing points for induced attention.
    """

    n_input: int
    d_model: int = 128
    d_mlp: int = 512
    n_layers: int = 4
    n_heads: int = 4
    induced_attention: bool = False
    n_inducing_points: int = 32

    @nn.compact
    def __call__(self, x: jnp.ndarray, conditioning: jnp.ndarray = None, mask=None):
        """Map a set ``x`` to per-element outputs with ``n_input`` features.

        Args:
          x: Input set, (batch, seq_len, n_input).
          conditioning: Optional conditioning vector, one per batch entry; it is
            embedded to ``d_model`` and added to every set element.
          mask: Optional per-element mask, (batch, seq_len).

        Returns:
          Array with ``n_input`` features on the last axis.
        """
        width = int(self.d_model)

        # Embed the inputs: (batch, seq_len, d_model)
        x = nn.Dense(width)(x)

        # Embed the conditioning to (batch, d_model) and broadcast over the sequence axis
        if conditioning is not None:
            cond_embedding = nn.Dense(width)(conditioning)
            x = x + cond_embedding[:, None, :]

        # The attention mask is identical for every layer, so build it once
        if mask is None:
            block_mask = None
        elif self.induced_attention:
            # Set elements are queries; every inducing point may be attended to
            block_mask = mask[..., None]
        else:
            # Outer product of the element mask with itself
            block_mask = mask[..., None] * mask[..., None, :]

        for _ in range(self.n_layers):
            if self.induced_attention:
                # Pool the set onto the inducing points, then attend to that summary
                context = PoolingByMultiHeadAttention(
                    n_seed_vectors=self.n_inducing_points,
                    n_heads=self.n_heads,
                    d_model=self.d_model,
                    d_mlp=self.d_mlp,
                )(x, mask)
            else:
                # Plain self-attention over the set
                context = x

            x = MultiHeadAttentionBlock(
                n_heads=self.n_heads,
                d_model=self.d_model,
                d_mlp=self.d_mlp,
            )(x, context, block_mask)

        # Final LayerNorm
        x = nn.LayerNorm()(x)

        # Unembed; the zero-initialized kernel makes the initial output depend on the bias only
        return nn.Dense(self.n_input, kernel_init=jax.nn.initializers.zeros)(x)


import math
import importlib

import jax
import jax.numpy as jnp
from flax import linen as nn
from einops import rearrange


def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product masked attention.

    Args:
      q: Queries, (batch, n_heads, seq_length, d_heads).
      k: Keys, same layout as ``q``.
      v: Values, same layout as ``q``.
      mask: Optional (batch, seq_length) mask over keys; entries equal to 0 are
        excluded from attention. It is broadcast over heads and queries.

    Returns:
      Tuple ``(values, attention)``: the attention-weighted values and the
      attention weights (softmax over the key axis).
    """
    d_k = q.shape[-1]
    attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative *finite* value of the logits' dtype. A fixed
        # constant such as -9e15 overflows to -inf in float16, which makes the
        # softmax of a fully masked row NaN.
        big_neg = jnp.finfo(attn_logits.dtype).min
        attn_logits = jnp.where(mask[:, None, None, :] == 0, big_neg, attn_logits)
    attention = jax.nn.softmax(attn_logits, axis=-1)
    values = jnp.matmul(attention, v)
    return values, attention

class Transformer(nn.Module):
    """Simple pre-LayerNorm transformer for set modeling with hand-written attention.

    Attributes:
      n_input: The number of input (and output) features.
      d_model: The dimension of the model embedding space.
      d_mlp: The hidden dimension of the feed-forward MLP in every layer.
      n_layers: Number of transformer layers.
      n_heads: The number of attention heads.
      flash_attention: Flag intended to select flash attention; it is declared
        here but not read anywhere in ``__call__``.
    """

    n_input: int
    d_model: int = 128
    d_mlp: int = 512
    n_layers: int = 4
    n_heads: int = 4
    flash_attention: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, conditioning: jnp.ndarray = None, mask=None):
        """Map a set ``x`` to per-element outputs with ``n_input`` features.

        Args:
          x: Input set, (batch, seq_len, n_input).
          conditioning: Optional conditioning vector, one per batch entry; it is
            embedded to ``d_model`` and added to every set element.
          mask: Optional (batch, seq_len) mask; entries equal to 0 are not
            attended to. Defaults to attending to every element.

        Returns:
          Array with ``n_input`` features on the last axis.
        """
        n_batch, n_elements = x.shape[0], x.shape[1]
        width = int(self.d_model)

        # Embed inputs (and optional conditioning) into the residual stream
        x = nn.Dense(width)(x)  # (batch, seq_len, d_model)
        if conditioning is not None:
            cond_embedding = nn.Dense(width)(conditioning)  # (batch, d_model)
            x = x + cond_embedding[:, None, :]  # (batch, seq_len, d_model)

        # Without a mask every set element is a valid key
        key_mask = mask if mask is not None else jnp.ones((n_batch, n_elements))

        # Shared initializer settings of the attention projections
        attn_init = dict(kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.zeros)

        for _ in range(self.n_layers):

            # --- Attention sub-layer (reads the normalized residual stream) ---
            normed = nn.LayerNorm()(x)

            # One fused projection; each head's slice holds its q, k and v back to back
            fused = nn.Dense(3 * self.d_model, **attn_init)(normed)
            fused = rearrange(fused, "batch seq_length (n_heads d_heads_3) -> batch n_heads seq_length d_heads_3", n_heads=self.n_heads)
            q, k, v = jnp.split(fused, 3, axis=-1)  # each (batch, n_heads, seq_length, d_heads)

            attended, _ = scaled_dot_product_attention(q, k, v, mask=key_mask)
            attended = rearrange(attended, "batch n_heads seq_length d_heads -> batch seq_length (n_heads d_heads)")

            # Output projection, then write onto the residual stream
            x = x + nn.Dense(self.d_model, **attn_init)(attended)

            # --- MLP sub-layer (reads the normalized residual stream) ---
            hidden = nn.LayerNorm()(x)
            hidden = nn.Dense(self.d_mlp)(hidden)
            hidden = jax.nn.gelu(hidden)
            hidden = nn.Dense(self.d_model)(hidden)

            x = x + hidden

        # Final LayerNorm
        x = nn.LayerNorm()(x)

        # Unembed with a zero-initialized kernel
        return nn.Dense(self.n_input, kernel_init=jax.nn.initializers.zeros)(x)

In [134]:
# Hyperparameters shared by both transformer implementations compared below
n_input, d_model, d_mlp = 64, 256, 1024
n_layers, n_heads = 2, 2

# Fixed-seed key and a random sample batch whose last axis has n_input features
rng = jax.random.PRNGKey(seed=42)
x = jax.random.normal(key=rng, shape=(32, 16, n_input))

In [135]:
from clu import parameter_overview

In [136]:
# Build the hand-written Transformer from the notebook-level hyperparameters
# and initialize it on the sample batch `x`.
# init_with_output returns (output, variables): the forward output is discarded
# and the variables collection is bound to `params`, which is why the names in
# the overview printed below start with "params/".
# NOTE(review): `rng` is the same key that was used to draw `x`; consider
# jax.random.split if independent randomness matters.
transformer = Transformer(n_input=n_input, d_model=d_model, d_mlp=d_mlp, n_layers=n_layers, n_heads=n_heads)
_, params = transformer.init_with_output({"params": rng}, x)

In [137]:
# Tabulate every parameter of the model initialized above (name, shape, size, mean, std)
print(parameter_overview.get_parameter_overview(params))

+--------------------------+-------------+---------+-----------+--------+
| Name                     | Shape       | Size    | Mean      | Std    |
+--------------------------+-------------+---------+-----------+--------+
| params/Dense_0/bias      | (256,)      | 256     | 0.0       | 0.0    |
| params/Dense_0/kernel    | (64, 256)   | 16,384  | 0.000496  | 0.126  |
| params/Dense_1/bias      | (768,)      | 768     | 0.0       | 0.0    |
| params/Dense_1/kernel    | (256, 768)  | 196,608 | 0.000112  | 0.0443 |
| params/Dense_2/bias      | (256,)      | 256     | 0.0       | 0.0    |
| params/Dense_2/kernel    | (256, 256)  | 65,536  | 0.000391  | 0.0625 |
| params/Dense_3/bias      | (1024,)     | 1,024   | 0.0       | 0.0    |
| params/Dense_3/kernel    | (256, 1024) | 262,144 | 2.5e-05   | 0.0624 |
| params/Dense_4/bias      | (256,)      | 256     | 0.0       | 0.0    |
| params/Dense_4/kernel    | (1024, 256) | 262,144 | 8.45e-06  | 0.0313 |
| params/Dense_5/bias      | (768,)   

In [138]:
# Same experiment with the Flax-attention-based TransformerFlax: identical
# hyperparameters, identical sample batch and identical init key, so the two
# parameter overviews can be compared side by side.
# Rebinds `transformer` and `params` (again the variables collection returned
# as the second element of init_with_output).
transformer = TransformerFlax(n_input=n_input, d_model=d_model, d_mlp=d_mlp, n_layers=n_layers, n_heads=n_heads)
_, params = transformer.init_with_output({"params": rng}, x)

In [139]:
# Tabulate every parameter of the TransformerFlax model (name, shape, size, mean, std)
print(parameter_overview.get_parameter_overview(params))

+------------------------------------------------------------------------------+---------------+---------+-----------+--------+
| Name                                                                         | Shape         | Size    | Mean      | Std    |
+------------------------------------------------------------------------------+---------------+---------+-----------+--------+
| params/Dense_0/bias                                                          | (256,)        | 256     | 0.0       | 0.0    |
| params/Dense_0/kernel                                                        | (64, 256)     | 16,384  | 0.000496  | 0.126  |
| params/Dense_1/bias                                                          | (64,)         | 64      | 0.0       | 0.0    |
| params/Dense_1/kernel                                                        | (256, 64)     | 16,384  | 0.0       | 0.0    |
| params/LayerNorm_0/bias                                                      | (256,)        | 256    