In [10]:
import flax.linen as nn
import jax.numpy as jnp
import jax
from jax import lax
from jax.tree_util import tree_map

In [11]:
from nimblegpt import get_config_for, make_gpt_param_dict, get_flaxmodels_gpt2_params, param_shapes

In [12]:
from nimblegpt.base_model import (
    BaseBlock,
    BaseCausalSelfAttention,
    BaseGPT,
    BaseSingleHeadCausalSelfAttention,
)
from nimblegpt.model import GELU, softmax, SingleHeadCausalSelfAttention, GPT

In [13]:
from nimblegpt.fast_model import FSingleHeadCausalSelfAttention

In [14]:
config = get_config_for("gpt2")

In [15]:
config

attn_pdrop: 0.1
block_size: 1024
embd_pdrop: 0.1
model_type: gpt2
n_embd: 768
n_head: 12
n_layer: 12
resid_pdrop: 0.1
vocab_size: 50257

In [16]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
gpt_params = make_gpt_param_dict(get_flaxmodels_gpt2_params(), config)
sa_0_params = gpt_params["Block_0"]["CausalSelfAttention_0"][
    "VmapSingleHeadCausalSelfAttention_0"]
sh_0_params = tree_map(lambda x: x[0], sa_0_params)
param_shapes(sh_0_params)

{'Dense_0': {'bias': '(192)', 'kernel': '(768, 192)'}}

In [39]:
rng = jax.random.PRNGKey(0)

# Generation Function

In [51]:
from nimblegpt.fast_model import FGPT
from nimblegpt.generate import sample_token, generate_tokens

In [37]:
fgpt_module = FGPT.Make(config)

In [44]:
vars = fgpt_module.init_vars({"params": gpt_params})

In [41]:
prompt_idxs = jax.random.randint(rng, (3,), 0, config.vocab_size)
prompt_idxs

Array([23652,  9593,  2133], dtype=int32)

In [59]:
fgpt_module.generate(rng, {"params": gpt_params, **vars}, prompt_idxs, sample_token)

Array([23652,  9593,  2133, 29569, 13936,   286,  3294, 11621,   357,
          18,    35,     8,   198], dtype=int32)

In [65]:
jit_gen = jax.jit(fgpt_module.generate, static_argnames=("logit_sampler", "max_new_tokens"))
jit_gen(rng, {"params": gpt_params, **vars}, prompt_idxs, sample_token, max_new_tokens=100)

Array([23652,  9593,  2133, 29569, 13936,   286,  3294, 11621,   357,
          18,    35,     8,   198,   198,    32,  3748, 16106,  4427,
         198,   198,    32,   649,  1080,   286,  5021,    12,  7829,
        5249,   351,   513,  1180,  2137, 22582,   290,   362,  1180,
       12881,   198,   198, 11002,   351,   257,  1545,   290,   766,
         703,   262,  9552, 21126,   656,   511,  2095,   198,   198,
          20,    12,    35,   393,   513,    12,    35,  9382,    11,
         257,   649,   983,  3113,   351,  1365,  3703,   290,   257,
       26192,   995,   290,   517,  3716,  2095,   198,   198,    17,
          12,    35,   393,   604,    12,    35,  9382,    11,   517,
        3716,   995,   326,  3578,   345,   284,  7301,   262,  1621,
         198,   198, 15022,  5612], dtype=int32)

In [67]:
%%timeit

jit_gen(rng, {"params": gpt_params, **vars}, prompt_idxs, sample_token, max_new_tokens=config.block_size - len(prompt_idxs))

3.05 s ± 71.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [55]:
generate_tokens(rng, GPT(config).bind({"params": gpt_params}), prompt_idxs, max_new_tokens=10)

Array([23652,  9593,  2133, 29569, 13936,   286,  3294, 11621,   357,
          18,    35,     8,   198], dtype=int32)

# Jittable Generation

In [18]:
from nimblegpt.fast_model import FGPT
from nimblegpt.generate import sample_token

In [36]:
fgpt_module = FGPT.Make(config)

In [20]:
_, vars = fgpt_module.apply({"params": gpt_params},
                            jnp.array(10),
                            jnp.array(0),
                            mutable="cache")

In [27]:
from typing import Dict


def body(seq_idx: int, val: Dict):
    logits, cache = FGPT.Make(config).apply(val["variables"],
                                            val["seq"][seq_idx],
                                            seq_idx,
                                            mutable="cache")
    tok_idx = sample_token(val["rng"], logits)
    return {
        **val,
        "variables": {
            **val["variables"],
            **cache
        },
        "rng": jax.random.split(val["rng"])[0],
        "seq": val["seq"].at[seq_idx + 1].set(tok_idx),
    }

In [24]:
prompt_idxs = jax.random.randint(jax.random.PRNGKey(0), (3,), 0, config.vocab_size)
prompt_idxs

Array([23652,  9593,  2133], dtype=int32)

In [30]:
val = {
    "variables": {
        "params": gpt_params,
        "cache": vars["cache"]
    },
    "rng": jax.random.PRNGKey(0),
    "seq": jnp.pad(prompt_idxs, (0, 3), constant_values=0),
}
val["seq"]

Array([23652,  9593,  2133,     0,     0,     0], dtype=int32)

In [32]:
jax.lax.fori_loop(0, 5, body, val)["seq"]

Array([23652,   287,   262,   968, 18318,   357], dtype=int32)

# Logits Comparison

In [21]:
from nimblegpt.fast_model import FGPT

In [22]:
fgpt_module = FGPT.Make(config)
vars = fgpt_module.init_vars(dict(params=gpt_params))


In [23]:
logits, vars = fgpt_module.apply({
    "params": gpt_params,
    **vars
},
                                 jnp.array(10),
                                 jnp.array(0),
                                 mutable="cache")


In [24]:
logits

Array([-32.8549  , -31.824638, -33.43742 , ..., -40.04716 , -40.218723,
       -32.17592 ], dtype=float32)

In [25]:
gpt_module = GPT(config)

In [26]:
gpt_module.apply({"params": gpt_params}, jnp.array([10]))

Array([[-32.854923, -31.824661, -33.43744 , ..., -40.047188, -40.218746,
        -32.175938]], dtype=float32)

## Logits with prompt

In [49]:
vars = fgpt_module.init_vars(dict(params=gpt_params))

_, vars = fgpt_module.apply({
    "params": gpt_params,
    **vars
},
                            jnp.array(10),
                            jnp.array(0),
                            mutable="cache")
flogits, vars = fgpt_module.apply({
    "params": gpt_params,
    **vars
},
                                    jnp.array(11),      
                                    jnp.array(1),
                                    mutable="cache")


In [50]:
logits = gpt_module.apply({"params": gpt_params}, jnp.array([10, 11]))

In [51]:
flogits

Array([-89.370514, -89.77837 , -87.00991 , ..., -96.07502 , -95.442566,
       -89.372604], dtype=float32)

In [52]:
logits

Array([[-32.854923, -31.824652, -33.437424, ..., -40.047184, -40.218742,
        -32.17593 ],
       [-89.37053 , -89.778366, -87.00992 , ..., -96.07503 , -95.442566,
        -89.37261 ]], dtype=float32)

# Sequence to Sequence

In [27]:
from nimblegpt.fast_model import FGPT

In [28]:
fgpt_module = FGPT.Make(config)

In [29]:
rng = jax.random.PRNGKey(0)

In [30]:
_, vars = fgpt_module.apply({"params": gpt_params},
                            jnp.array(10),
                            jnp.array(0),
                            mutable="cache")


In [31]:
logits, vars = fgpt_module.apply({
    "params": gpt_params,
    "cache": vars["cache"]
},
                                 jnp.array(10),
                                 jnp.array(0),
                                 mutable="cache")


In [42]:
prompt_idx = jax.random.randint(rng, (3, ), 0, config.vocab_size)
prompt_idx

Array([23652,  9593,  2133], dtype=int32)

In [55]:
seq = jnp.pad(prompt_idx, (0, 6 - prompt_idx.shape[0]), constant_values=0)
_, vars = fgpt_module.apply({"params": gpt_params},
                            jnp.array(0),
                            jnp.array(0),
                            mutable="cache")

for seq_idx in range(5):
    flogits, vars = fgpt_module.apply(
        {
            "params": gpt_params,
            "cache": vars["cache"]
        },
        seq[seq_idx],
        jnp.array(seq_idx),
        mutable="cache")
    if seq_idx >= len(prompt_idx) - 1:
        seq = seq.at[seq_idx+1].set(jnp.argmax(flogits, axis=-1))

In [56]:
seq

Array([23652,  9593,  2133, 29569,   319,   262], dtype=int32)

In [35]:
gpt_module = GPT(config)

In [36]:
m_seq = prompt_idx
for _ in range(3):
    logits = gpt_module.apply({"params": gpt_params}, m_seq)
    m_seq = jnp.concatenate([m_seq, jnp.argmax(logits, axis=-1)[-1:]])
m_seq

Array([23652,  9593,  2133, 29569,   319,   262], dtype=int32)

In [45]:
logits

Array([[-31.78656 , -31.017838, -33.72394 , ..., -39.384296, -39.60583 ,
        -31.77908 ],
       [-76.03792 , -76.57617 , -79.64832 , ..., -79.2811  , -82.38854 ,
        -76.91283 ],
       [-53.08388 , -52.194874, -50.482613, ..., -60.19143 , -57.123943,
        -51.9151  ],
       [-80.858086, -81.24439 , -84.33287 , ..., -84.55941 , -85.661896,
        -82.019295],
       [-74.87219 , -74.504135, -76.91374 , ..., -76.71006 , -79.23487 ,
        -74.19598 ]], dtype=float32)

In [46]:
flogits

Array([-62.46567 , -68.00004 , -65.42838 , ..., -75.3501  , -73.863045,
       -63.383614], dtype=float32)

In [37]:
seq

Array([23652,  9593,  2133,   198,   198,   198], dtype=int32)

# Single Head Causal Self Attention

In [38]:
# class FSingleHeadCausalSelfAttention(BaseSingleHeadCausalSelfAttention):
#     n_cntx: int
#     n_feat: int

#     @nn.compact
#     def __call__(self, x: jax.Array):

#         # q : [n_feat], K, V : [n_cntx, n_feat]
#         q, K, V, idx = SingleHeadQKV(n_cntx=self.n_cntx, n_feat=self.n_feat)(x)

#         # [n_feat] @ [n_feat, n_cntx] -> [n_cntx].
#         # Attention for token `idx`. att[i] is high when token `idx` should attend
#         # heavily to token i.
#         att = (K @ q) * (1.0 / jnp.sqrt(self.n_feat))

#         # Causal masking. Token `idx` should not attend to token i for any i > idx.
#         att = jnp.where(jnp.arange(self.n_cntx) > idx, -jnp.inf, att)

#         att = softmax(att)

#         # [n_cntx] @ [n_cntx, n_feat] -> [n_feat]
#         y = att @ V

#         return y

In [39]:
fshcsa_module = FSingleHeadCausalSelfAttention(n_cntx=config.block_size,
                                               n_feat=config.n_embd //
                                               config.n_head)


In [40]:
vars = fshcsa_module.init(jax.random.PRNGKey(0), jnp.ones((config.n_embd, )))


TypeError: FSingleHeadCausalSelfAttention.__call__() missing 1 required positional argument: 'seq_idx'

In [None]:
y, vars = fshcsa_module.apply(vars,
                              jnp.ones((config.n_embd, )),
                              mutable="cache")


In [None]:
y

Array([-0.4951638 ,  1.3103716 ,  0.7500142 , -0.33700418, -1.6279882 ,
        1.3067862 ,  0.56427884,  1.6625755 , -0.5867724 , -1.3497397 ,
        1.2912292 ,  1.234936  ,  1.1145046 ,  1.1123266 , -0.6795392 ,
        0.576539  , -0.9984726 ,  1.7678396 ,  0.23711662,  1.4688923 ,
       -0.81794494,  0.37158245, -0.17866445,  0.11133623, -0.47702432,
       -0.76485586,  0.8497622 , -0.15143591,  0.885042  , -1.1836362 ,
       -0.97935313, -0.16283125,  0.6268122 , -0.30901474, -0.05063885,
       -0.12694442,  0.96355164, -0.5885438 , -0.02701813,  0.57972115,
       -1.4262507 , -0.4855162 ,  0.6056665 , -0.4228196 ,  0.29803544,
        0.04475397, -1.0122097 ,  2.297686  , -1.1363087 ,  1.8593898 ,
        1.3262932 ,  0.8218234 ,  1.2819971 ,  0.2697784 , -1.0868869 ,
       -1.606765  , -0.5495784 , -1.4249687 ,  0.9420121 , -1.1320074 ,
        0.465777  ,  0.569261  , -1.6738216 ,  1.3660043 ], dtype=float32)

In [None]:
rng = jax.random.PRNGKey(0)

In [None]:
X = jax.random.normal(rng, (10, config.n_embd))

In [None]:
Y = SingleHeadCausalSelfAttention(config.n_embd // config.n_head).apply(
    {"params": sh_0_params}, X)


In [None]:
vars = fshcsa_module.init(jax.random.PRNGKey(0), jnp.ones((config.n_embd, )))


In [None]:
fys = []
for i in range(10):
    y, vars = fshcsa_module.apply(
        {
            "cache": vars["cache"],
            "params": sh_0_params
        }, X[i], mutable="cache")
    fys.append(y)

In [None]:
Y.shape

(10, 64)

In [None]:
(Y - jnp.array(fys)).max()

Array(8.34465e-07, dtype=float32)

In [None]:
fys

[Array([-3.7749598e-03, -1.0600836e+00, -1.6187843e+00,  3.4710118e-01,
         1.3730958e+00, -4.0345654e+00,  2.7263455e+00,  6.8328369e-01,
         7.7177596e-01,  1.4631925e+00, -3.1356561e-01,  2.4275184e-01,
        -1.2562956e+00, -2.1613176e+00, -7.6190311e-01, -1.4470794e+00,
        -3.1225381e+00, -1.3839597e+00, -8.0332410e-01,  5.5632317e-01,
        -1.1603022e+00,  1.5281880e+00,  1.8011650e+00,  1.5308793e+00,
         1.5507106e+00,  3.1227562e+00,  2.6930642e+00,  5.2547187e-01,
         3.8787910e-01, -1.2424152e+00, -2.5827773e+00,  7.7664012e-01,
         1.9117656e+00,  1.0243871e+00,  2.1334591e+00, -3.9823332e-01,
         2.3314288e+00, -1.8247846e-01, -3.2934222e+00,  2.0472670e-01,
        -1.4110570e-01,  1.3574346e+00,  9.9579597e-01, -2.5600688e+00,
        -6.5420246e-01, -1.1953276e+00, -1.6369368e+00, -1.3468333e+00,
         7.3794717e-01, -1.3751336e+00,  7.9342991e-01,  3.8684255e-01,
        -9.8729527e-01,  1.1786953e+00,  2.0081272e+00,  1.07271