### Setup

In [38]:
!pip install flax sentencepiece transformers --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m46.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m78.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 KB[0m [31m23.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m113.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [39]:
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Callable
import os

import jax
from jax import random, Array
import jax.numpy as jnp
import flax
import flax.linen as nn
import numpy as np
import torch

In [5]:
# Prevent JAX from prealloating memory. Helps avoid OOM errors with low memory usage
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

### Model

In [25]:
@dataclass
class TransformerConfig:
    vocab_size: int
    num_layers: int
    d_model: int
    d_ff: int
    n_heads: int
    d_k: int
    num_relative_pos: int
    eps: float = 1e-6

In [26]:
def relative_position_bucket(relative_position, is_decoder: bool, num_buckets=32, max_distance=128):
    relative_buckets = 0
    if is_decoder:
        relative_position = -jnp.clip(relative_position, a_max=0)
    else:
        num_buckets //= 2
        relative_buckets += (relative_position > 0) * num_buckets
        relative_position = jnp.abs(relative_position)

    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    relative_position_if_large = max_exact + (
        jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

    relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

    return relative_buckets.astype("i4")

def compute_bias(query_length: int, key_length: int, embedding: Callable, is_decoder: bool):
    context_position = jnp.arange(query_length, dtype="i4")[:, None]
    memory_position = jnp.arange(key_length, dtype="i4")[None, :]

    relative_position = memory_position - context_position
    position_bucket = relative_position_bucket(
        relative_position=relative_position,
        is_decoder=is_decoder,
        num_buckets=32,
        max_distance=128
    )

    values = embedding(position_bucket)
    values = values.transpose((2, 0, 1))[None, :, :, :]
    return values

In [27]:
class EmbeddingLayer(nn.Module):
    config: TransformerConfig

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.config.vocab_size, features=self.config.d_model)

    def __call__(self, x):
        return self.embedding(x)


class PositionEmbeddingLayer(nn.Module):
    config: TransformerConfig

    def setup(self):
        self.pos_embedding = nn.Embed(num_embeddings=self.config.num_relative_pos, features=self.config.n_heads)

    def __call__(self, x):
        return self.pos_embedding(x)


class MLP(nn.Module):
    config: TransformerConfig

    def setup(self):
        self.ff1 = nn.Dense(self.config.d_ff, use_bias=False)
        self.ff2 = nn.Dense(self.config.d_model, use_bias=False)

    def __call__(self, x):
        x = nn.relu(self.ff1(x))
        return self.ff2(x)

class LayerNorm(nn.Module):
    config: TransformerConfig
    weight_init: Callable[..., jnp.ndarray] = jax.nn.initializers.ones

    def setup(self):
        self.weight = self.param("weight", self.weight_init, (self.config.d_model,))

    def __call__(self, x: Array) -> Array:
        variance = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
        norm_x = x * jax.lax.rsqrt(variance + self.config.eps)
        return self.weight * norm_x


class SelfAttention(nn.Module):
    config: TransformerConfig
    is_decoder: bool

    def setup(self):
        self.w_q = nn.Dense(self.config.d_model, use_bias=False)
        self.w_k = nn.Dense(self.config.d_model, use_bias=False)
        self.w_v = nn.Dense(self.config.d_model, use_bias=False)
        self.w_o = nn.Dense(self.config.d_model, use_bias=False)

    def __call__(self, x: Array, position_bias: Array) -> Array:
        batch_size, n, _ = x.shape
        if self.is_decoder:
            x = jax.lax.expand_dims(x[:, -1, :], 1)
            Q = self.w_q(x)
        else:
            Q = self.w_q(x)
        K, V = self.w_k(x), self.w_v(x)
        Q, K, V = self.split_heads(Q), self.split_heads(K), self.split_heads(V)
        # Q is of shape (batch_size, n_heads, 1 or n, d_k)
        # K, V are of shape (batch_size, n_heads, n, d_k)
        scores = jax.lax.dot_general(Q, K, (((3,), (3,)), ((0,1), (0,1))))
        scores += position_bias
        attn_weights = nn.softmax(scores, axis=-1)
        split_attn = jax.lax.dot_general(attn_weights, V, (((3,), (2,)), ((0,1), (0,1))))
        attention = self.unify_heads(split_attn)
        output = self.w_o(attention)
        return output

    def split_heads(self, x: Array) -> Array:
        batch_size, n, _ = x.shape
        x = jax.lax.reshape(x, (batch_size, n, self.config.n_heads, self.config.d_k))
        x = jax.lax.transpose(x, (0, 2, 1, 3))
        return x

    def unify_heads(self, x: Array) -> Array:
        batch_size, n_heads, n, _ = x.shape
        x = jax.lax.transpose(x, (0, 2, 1, 3))
        x = jax.lax.reshape(x, (batch_size, n, self.config.d_model))
        return x


class EncoderLayer(nn.Module):
    config: TransformerConfig

    def setup(self):
        self.self_attn = SelfAttention(config=self.config, is_decoder=False)
        self.mlp = MLP(config=self.config)
        self.norm1 = LayerNorm(config=self.config)
        self.norm2 = LayerNorm(config=self.config)

    def __call__(self, x: Array, position_bias: Array) -> Array:
        normed_x = self.norm1(x)
        x += self.self_attn(normed_x, position_bias=position_bias)
        normed_x = self.norm2(x)
        x += self.mlp(normed_x)
        return x


class Encoder(nn.Module):
    config: TransformerConfig

    def setup(self):
        self.encoder_layers = [EncoderLayer(config=self.config) for _ in range(self.config.num_layers)]
        self.self_attention_relative_attention_embedding = PositionEmbeddingLayer(config=self.config)

    def __call__(self, x: Array) -> Array:
        _, n, _ = x.shape
        self_attention_position_bias = compute_bias(n, n, embedding=self.self_attention_relative_attention_embedding, is_decoder=False)
        for i in range(self.config.num_layers):
            x = self.encoder_layers[i](x, self_attention_position_bias)
        return x


class CrossAttention(nn.Module):
    config: TransformerConfig

    def setup(self):
        self.w_q = nn.Dense(self.config.d_model, use_bias=False)
        self.w_k = nn.Dense(self.config.d_model, use_bias=False)
        self.w_v = nn.Dense(self.config.d_model, use_bias=False)
        self.w_o = nn.Dense(self.config.d_model, use_bias=False)

    def __call__(self, x: Array, encoding: Array, position_bias: Array) -> Array:
        batch_size, n, _ = x.shape
        Q = self.w_q(x)
        K, V = self.w_k(encoding), self.w_v(encoding)
        Q, K, V = self.split_heads(Q), self.split_heads(K), self.split_heads(V)
        # Q is of shape (batch_size, n_heads, 1 or n, d_k)
        # K, V are of shape (batch_size, n_heads, n, d_k)
        scores = jax.lax.dot_general(Q, K, (((3,), (3,)), ((0,1), (0,1))))
        scores += position_bias
        attn_weights = nn.softmax(scores, axis=-1)
        split_attn = jax.lax.dot_general(attn_weights, V, (((3,), (2,)), ((0,1), (0,1))))
        attention = self.unify_heads(split_attn)
        output = self.w_o(attention)
        return output

    def split_heads(self, x: Array) -> Array:
        batch_size, n, _ = x.shape
        x = jax.lax.reshape(x, (batch_size, n, self.config.n_heads, self.config.d_k))
        x = jax.lax.transpose(x, (0, 2, 1, 3))
        return x

    def unify_heads(self, x: Array) -> Array:
        batch_size, _, n, _ = x.shape
        x = jax.lax.transpose(x, (0, 2, 1, 3))
        x = jax.lax.reshape(x, (batch_size, n, self.config.d_model))
        return x


class DecoderLayer(nn.Module):
    config: TransformerConfig

    def setup(self):
        self.self_attn = SelfAttention(config=self.config, is_decoder=False)
        self.cross_attn = CrossAttention(config=self.config)
        self.mlp = MLP(config=self.config)
        self.norm1 = LayerNorm(config=self.config)
        self.norm2 = LayerNorm(config=self.config)
        self.norm3 = LayerNorm(config=self.config)

    def __call__(self, x: Array, encoding: Array, self_attention_position_bias: Array, cross_attention_position_bias: Array) -> Array:
        normed_x = self.norm1(x)
        x += self.self_attn(normed_x, position_bias=self_attention_position_bias)
        normed_x = self.norm2(x)
        x += self.cross_attn(normed_x, encoding, position_bias=cross_attention_position_bias)
        normed_x = self.norm3(x)
        x += self.mlp(normed_x)
        return x


class Decoder(nn.Module):
    config: TransformerConfig

    def setup(self):
        self.decoder_layers = [DecoderLayer(config=self.config) for _ in range(self.config.num_layers)]
        self.self_attention_relative_attention_embedding = PositionEmbeddingLayer(config=self.config)
        self.enc_dec_attention_relative_attention_embedding = PositionEmbeddingLayer(config=self.config)

    def __call__(self, x: Array, encoding: Array) -> Array:
        _, n, _ = x.shape
        _, encoding_n, _ = encoding.shape
        self_attention_position_bias = compute_bias(1, n, embedding=self.self_attention_relative_attention_embedding, is_decoder=True)
        enc_dec_attention_position_bias = compute_bias(1, encoding_n, embedding=self.enc_dec_attention_relative_attention_embedding, is_decoder=True)
        for i in range(self.config.num_layers):
            x = self.decoder_layers[i](x,
                                       encoding,
                                       self_attention_position_bias,
                                       enc_dec_attention_position_bias)
        return jax.lax.expand_dims(x[:, -1, :], [1])


class Transformer(nn.Module):
    config: TransformerConfig

    def setup(self):
        self.embedding_layer = EmbeddingLayer(config=self.config)
        self.encoder = Encoder(config=self.config)
        self.final_encoder_layer_norm = LayerNorm(config=self.config)
        self.decoder = Decoder(config=self.config)
        self.final_decoder_layer_norm = LayerNorm(config=self.config)
        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)

    def __call__(self, x: Array, max_tokens: int, debug: bool) -> Array:
        x = jax.lax.stop_gradient(x)
        x_emb = self.embedding_layer(x)
        encoding = self.encoder(x_emb)
        encoding = self.final_encoder_layer_norm(encoding)
        outputs = jnp.array([[0]])
        while outputs[:, -1] != 1 and len(outputs[0]) < max_tokens:
            if debug:
                bkpt()
            decoder_inputs_embedding = self.embedding_layer(outputs)
            decoding = self.decoder(decoder_inputs_embedding, encoding)
            decoding = self.final_decoder_layer_norm(decoding)
            decoding *= (self.config.d_model ** -0.5)
            next_logits = self.lm_head(decoding)
            next_token = jax.lax.argmax(next_logits, axis=2, index_dtype=int)
            outputs = jax.lax.concatenate([outputs, next_token], dimension=1)

        return outputs

### Initialize model and load weights

In [28]:
transformer_config = TransformerConfig(vocab_size=32128,
                                       num_layers=6,
                                       d_model=512,
                                       d_ff=2048,
                                       n_heads=8,
                                       d_k=64,
                                       num_relative_pos=32)

In [29]:
transformer = Transformer(config=transformer_config)

In [30]:
key = random.PRNGKey(0)
key, subkey = random.split(key, 2)
t5_params = transformer.init(subkey, random.randint(subkey, (1, 3), 0, 30000), max_tokens=5, debug=False)

encodings done!
started token
finished token!
started token
finished token!
started token
finished token!
started token
finished token!


In [31]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [32]:
# Load your weights here!
# Download from https://huggingface.co/t5-small/blob/main/flax_model.msgpack and upload to your Google Drive in the root folder
with open("drive/MyDrive/t5-small-jax-weights.msgpack", 'rb') as f:
    content = f.read()
    flax_model_weights = flax.serialization.msgpack_restore(content)
flax_model_weights;

In [33]:
# Additionally, upload "t5_cross_pos_embedding.pt" file to Google Colab and use it to instantiate a parameter that's missing from the msgpack file
enc_dec_pos_embedding_tensor = torch.load("t5_cross_pos_embedding.pt").numpy()
cross_attn_embedding = jnp.asarray(enc_dec_pos_embedding_tensor)

In [34]:
# Unfreeze t5_params to be able to modify them
t5_params = t5_params.unfreeze()

In [35]:
# Modify params to params from msgpack

# Decoder cross-attn
t5_params["params"]["decoder"]["decoder_layers_0"]["cross_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["1"]["EncDecAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_1"]["cross_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["1"]["EncDecAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["cross_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["1"]["EncDecAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["cross_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["1"]["EncDecAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["cross_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["1"]["EncDecAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["cross_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["1"]["EncDecAttention"]["k"]["kernel"]

t5_params["params"]["decoder"]["decoder_layers_1"]["cross_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["1"]["EncDecAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_0"]["cross_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["1"]["EncDecAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["cross_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["1"]["EncDecAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["cross_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["1"]["EncDecAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["cross_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["1"]["EncDecAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["cross_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["1"]["EncDecAttention"]["v"]["kernel"]

t5_params["params"]["decoder"]["decoder_layers_0"]["cross_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["1"]["EncDecAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_1"]["cross_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["1"]["EncDecAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["cross_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["1"]["EncDecAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["cross_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["1"]["EncDecAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["cross_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["1"]["EncDecAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["cross_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["1"]["EncDecAttention"]["q"]["kernel"]

t5_params["params"]["decoder"]["decoder_layers_0"]["cross_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["1"]["EncDecAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_1"]["cross_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["1"]["EncDecAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["cross_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["1"]["EncDecAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["cross_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["1"]["EncDecAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["cross_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["1"]["EncDecAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["cross_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["1"]["EncDecAttention"]["o"]["kernel"]

# Decoder self-attn
t5_params["params"]["decoder"]["decoder_layers_0"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_1"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]

t5_params["params"]["decoder"]["decoder_layers_1"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_0"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]

t5_params["params"]["decoder"]["decoder_layers_0"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_1"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]

t5_params["params"]["decoder"]["decoder_layers_0"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_1"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]

# Encoder self-attn
t5_params["params"]["encoder"]["encoder_layers_0"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_1"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["encoder"]["block"]["1"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_2"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["encoder"]["block"]["2"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_3"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["encoder"]["block"]["3"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_4"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["encoder"]["block"]["4"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_5"]["self_attn"]["w_k"]["kernel"] = flax_model_weights["encoder"]["block"]["5"]["layer"]["0"]["SelfAttention"]["k"]["kernel"]

t5_params["params"]["encoder"]["encoder_layers_1"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["encoder"]["block"]["1"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_0"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_2"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["encoder"]["block"]["2"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_3"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["encoder"]["block"]["3"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_4"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["encoder"]["block"]["4"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_5"]["self_attn"]["w_v"]["kernel"] = flax_model_weights["encoder"]["block"]["5"]["layer"]["0"]["SelfAttention"]["v"]["kernel"]

t5_params["params"]["encoder"]["encoder_layers_0"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_1"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["encoder"]["block"]["1"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_2"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["encoder"]["block"]["2"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_3"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["encoder"]["block"]["3"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_4"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["encoder"]["block"]["4"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_5"]["self_attn"]["w_q"]["kernel"] = flax_model_weights["encoder"]["block"]["5"]["layer"]["0"]["SelfAttention"]["q"]["kernel"]

t5_params["params"]["encoder"]["encoder_layers_0"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_1"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["encoder"]["block"]["1"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_2"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["encoder"]["block"]["2"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_3"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["encoder"]["block"]["3"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_4"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["encoder"]["block"]["4"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_5"]["self_attn"]["w_o"]["kernel"] = flax_model_weights["encoder"]["block"]["5"]["layer"]["0"]["SelfAttention"]["o"]["kernel"]

# Decoder mlps
t5_params["params"]["decoder"]["decoder_layers_0"]["mlp"]["ff1"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_1"]["mlp"]["ff1"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["mlp"]["ff1"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["mlp"]["ff1"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["mlp"]["ff1"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["mlp"]["ff1"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"]

t5_params["params"]["decoder"]["decoder_layers_0"]["mlp"]["ff2"]["kernel"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_1"]["mlp"]["ff2"]["kernel"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_2"]["mlp"]["ff2"]["kernel"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_3"]["mlp"]["ff2"]["kernel"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_4"]["mlp"]["ff2"]["kernel"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["decoder"]["decoder_layers_5"]["mlp"]["ff2"]["kernel"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"]

# Encoder mlps
t5_params["params"]["encoder"]["encoder_layers_0"]["mlp"]["ff1"]["kernel"] = flax_model_weights["encoder"]["block"]["0"]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_1"]["mlp"]["ff1"]["kernel"] = flax_model_weights["encoder"]["block"]["1"]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_2"]["mlp"]["ff1"]["kernel"] = flax_model_weights["encoder"]["block"]["2"]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_3"]["mlp"]["ff1"]["kernel"] = flax_model_weights["encoder"]["block"]["3"]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_4"]["mlp"]["ff1"]["kernel"] = flax_model_weights["encoder"]["block"]["4"]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_5"]["mlp"]["ff1"]["kernel"] = flax_model_weights["encoder"]["block"]["5"]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"]

t5_params["params"]["encoder"]["encoder_layers_1"]["mlp"]["ff2"]["kernel"] = flax_model_weights["encoder"]["block"]["1"]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_0"]["mlp"]["ff2"]["kernel"] = flax_model_weights["encoder"]["block"]["0"]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_2"]["mlp"]["ff2"]["kernel"] = flax_model_weights["encoder"]["block"]["2"]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_3"]["mlp"]["ff2"]["kernel"] = flax_model_weights["encoder"]["block"]["3"]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_4"]["mlp"]["ff2"]["kernel"] = flax_model_weights["encoder"]["block"]["4"]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"]
t5_params["params"]["encoder"]["encoder_layers_5"]["mlp"]["ff2"]["kernel"] = flax_model_weights["encoder"]["block"]["5"]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"]

# Norms
t5_params["params"]["decoder"]["decoder_layers_0"]["norm1"]["weight"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_1"]["norm1"]["weight"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_2"]["norm1"]["weight"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_3"]["norm1"]["weight"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_4"]["norm1"]["weight"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_5"]["norm1"]["weight"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["0"]["layer_norm"]["weight"]

t5_params["params"]["decoder"]["decoder_layers_0"]["norm2"]["weight"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_1"]["norm2"]["weight"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_2"]["norm2"]["weight"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_3"]["norm2"]["weight"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_4"]["norm2"]["weight"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_5"]["norm2"]["weight"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["1"]["layer_norm"]["weight"]

t5_params["params"]["decoder"]["decoder_layers_0"]["norm3"]["weight"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["2"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_1"]["norm3"]["weight"] = flax_model_weights["decoder"]["block"]["1"]["layer"]["2"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_2"]["norm3"]["weight"] = flax_model_weights["decoder"]["block"]["2"]["layer"]["2"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_3"]["norm3"]["weight"] = flax_model_weights["decoder"]["block"]["3"]["layer"]["2"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_4"]["norm3"]["weight"] = flax_model_weights["decoder"]["block"]["4"]["layer"]["2"]["layer_norm"]["weight"]
t5_params["params"]["decoder"]["decoder_layers_5"]["norm3"]["weight"] = flax_model_weights["decoder"]["block"]["5"]["layer"]["2"]["layer_norm"]["weight"]

t5_params["params"]["encoder"]["encoder_layers_0"]["norm1"]["weight"] = flax_model_weights["encoder"]["block"]["0"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_1"]["norm1"]["weight"] = flax_model_weights["encoder"]["block"]["1"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_2"]["norm1"]["weight"] = flax_model_weights["encoder"]["block"]["2"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_3"]["norm1"]["weight"] = flax_model_weights["encoder"]["block"]["3"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_4"]["norm1"]["weight"] = flax_model_weights["encoder"]["block"]["4"]["layer"]["0"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_5"]["norm1"]["weight"] = flax_model_weights["encoder"]["block"]["5"]["layer"]["0"]["layer_norm"]["weight"]

t5_params["params"]["encoder"]["encoder_layers_0"]["norm2"]["weight"] = flax_model_weights["encoder"]["block"]["0"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_1"]["norm2"]["weight"] = flax_model_weights["encoder"]["block"]["1"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_2"]["norm2"]["weight"] = flax_model_weights["encoder"]["block"]["2"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_3"]["norm2"]["weight"] = flax_model_weights["encoder"]["block"]["3"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_4"]["norm2"]["weight"] = flax_model_weights["encoder"]["block"]["4"]["layer"]["1"]["layer_norm"]["weight"]
t5_params["params"]["encoder"]["encoder_layers_5"]["norm2"]["weight"] = flax_model_weights["encoder"]["block"]["5"]["layer"]["1"]["layer_norm"]["weight"]

# All embeddings
t5_params["params"]["encoder"]["self_attention_relative_attention_embedding"]["pos_embedding"]["embedding"] = flax_model_weights["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"]
t5_params["params"]["decoder"]["self_attention_relative_attention_embedding"]["pos_embedding"]["embedding"] = flax_model_weights["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"]
t5_params["params"]["decoder"]["enc_dec_attention_relative_attention_embedding"]["pos_embedding"]["embedding"] = cross_attn_embedding
t5_params["params"]["embedding_layer"]["embedding"]["embedding"] = flax_model_weights["shared"]["embedding"]
t5_params["params"]["lm_head"]["kernel"] = jax.lax.transpose(flax_model_weights["shared"]["embedding"], (1, 0))

# Final layer norms
t5_params["params"]["final_encoder_layer_norm"]["weight"] = flax_model_weights["encoder"]["final_layer_norm"]["weight"]
t5_params["params"]["final_decoder_layer_norm"]["weight"] = flax_model_weights["decoder"]["final_layer_norm"]["weight"]

### Inference

In [40]:
from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-small")

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [42]:
input_sentence = "summarize: Space exploration has always fascinated humans, and has led to countless discoveries and advancements in technology. The study of our solar system and beyond continues to uncover new mysteries, from the formation of planets to the search for extraterrestrial life. With the help of advanced technologies such as telescopes and spacecraft, scientists are able to gain a deeper understanding of the universe and our place in it. The pursuit of knowledge and discovery drives continued investment in space exploration, leading to a better understanding of not just the cosmos, but also ourselves."
inputs = tokenizer(input_sentence, return_tensors="np")
outputs = transformer.apply(t5_params, inputs.input_ids, max_tokens=20, debug=False)
decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
decoded_outputs

encodings done!
started token
finished token!
started token
finished token!
started token
finished token!
started token
finished token!
started token
finished token!
started token
finished token!
started token
finished token!
started token
finished token!
started token
finished token!
started token
finished token!
started token
finished token!


'the study of our solar system and beyond continues to uncover'