In [1]:
import jax

# https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
jax.config.update('jax_threefry_partitionable', True)

from jax import random, numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

from typing import Callable
import numpy as np
from functools import partial
from helpers import ShardedTrainer, MeshManager, get_batch_gen

from hyena import HyenaOperator



In [2]:
# Root PRNG key for the notebook; later cells split fresh subkeys off it.
key = random.PRNGKey(0)

I0000 00:00:1698094846.926878 1350637 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/amir/.venv311/lib/python3.11/site-packages/libtpu/libtpu.so
I0000 00:00:1698094846.926959 1350637 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1698094846.926964 1350637 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
I0000 00:00:1698094849.939847 1350637 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [8]:
# causality check
#
# Differentiate a scalar built from the layer output at one sequence position
# with respect to the whole input. For a causal operator, the gradient at
# input positions after that position should be (numerically) zero.

layer = HyenaOperator(
    max_len=256,
    d_model=128,
    pos_embed_dim=7,
    filter_features=64,
    num_filter_layers=4
)

key, init_key, drop_key, data_key = random.split(key, 4)

x = random.normal(data_key, (1, 256, 128))

params = layer.init({"params": init_key, "dropout": drop_key}, x)

# Output position whose dependence on the input is probed. Kept in one place
# so the slices below cannot drift out of sync with the indexed position.
check_pos = 10


@jax.grad
def out_grads(x):
    """Scalar probe: sum of the layer output at ``check_pos``.

    Wrapped in ``jax.grad``, so calling ``out_grads(x)`` returns the gradient
    of that scalar with respect to ``x`` (same shape as ``x``).
    """
    logits = layer.apply(params, x, rngs={"dropout": drop_key}, train=True)
    # Get some scalar-valued output with respect to which to evaluate grad.
    return jnp.sum(logits[:, check_pos, :])


grads = out_grads(x)
# Expectation: grad wrt sequence items with pos > check_pos should be negligible,
# with pos <= check_pos significant.
print(jnp.sum(jnp.abs(grads[0, check_pos + 1:, :])))
print(jnp.sum(jnp.abs(grads[0, :check_pos + 1, :])))

0.0003403878
475.82037


In [9]:
# performance test

jit_apply = jax.jit(lambda x: layer.apply(params, x, rngs={"dropout": drop_key}, train=True))

# warmup
jit_apply(x)

%timeit jit_apply(x).block_until_ready()

338 µs ± 1.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
# Transformer-style (pre-norm, residual) decoder whose sequence-mixing sublayer is a HyenaOperator

class HyenaDecoderLayer(nn.Module):
    """Residual decoder block: a HyenaOperator sublayer followed by a two-layer MLP.

    Each sublayer is applied as ``x + Dropout(sublayer(LayerNorm(x)))``.
    """
    # Width passed to HyenaOperator and used for the MLP output projection.
    features: int
    # Forwarded to HyenaOperator (first positional argument).
    max_len: int
    # Forwarded to HyenaOperator.
    filter_features: int
    # Forwarded to HyenaOperator.
    num_filter_layers: int
    # Forwarded to HyenaOperator.
    pos_embed_dim: int
    # Forwarded to HyenaOperator.
    order: int
    # Width of the hidden Dense layer in the MLP sublayer.
    hidden_features: int
    # Dropout rate used by both sublayers and passed on to HyenaOperator.
    dropout: float = 0.0

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Apply the block to ``x``; dropout is active only when ``train`` is true."""
        no_dropout = not train

        # Sublayer 1: normalise, mix with the Hyena operator, add back to the input.
        mixed = nn.LayerNorm()(x)
        hyena = HyenaOperator(
            self.max_len,
            self.features,
            self.pos_embed_dim,
            self.filter_features,
            self.num_filter_layers,
            self.order,
            dropout=self.dropout,
        )
        mixed = hyena(mixed, train=train)
        mixed = nn.Dropout(rate=self.dropout)(mixed, deterministic=no_dropout)
        x = x + mixed

        # Sublayer 2: normalise, Dense -> GELU -> Dense, add back to the input.
        ff = nn.LayerNorm()(x)
        ff = nn.Dense(self.hidden_features)(ff)
        ff = nn.gelu(ff)
        ff = nn.Dense(self.features)(ff)
        ff = nn.Dropout(rate=self.dropout)(ff, deterministic=no_dropout)
        return x + ff


class HyenaDecoder(nn.Module):
    """Token decoder: embed, apply ``n_layers`` blocks, normalise, project to logits.

    The output projection reuses the embedding module through ``attend``, and
    the logits are divided by ``sqrt(embedding.features)``.
    """
    # Embedding module; must provide ``__call__``, ``attend`` and ``features``.
    embedding: nn.Module
    # Factory called as ``layer_func(dropout=...)`` to build each block.
    layer_func: Callable
    # Number of blocks applied in sequence.
    n_layers: int
    # Dropout rate applied to the embeddings and passed to every block.
    dropout: float = 0.0

    @nn.compact
    def __call__(self, tokens, train: bool = True):
        """Return logits for ``tokens``; dropout is active only when ``train`` is true."""
        hidden = nn.Dropout(rate=self.dropout)(
            self.embedding(tokens), deterministic=not train
        )

        for _ in range(self.n_layers):
            block = self.layer_func(dropout=self.dropout)
            hidden = block(hidden, train=train)

        hidden = nn.LayerNorm()(hidden)
        scale = jnp.sqrt(self.embedding.features)
        return self.embedding.attend(hidden) / scale

In [11]:
# Model hyperparameters.
n_dim = 128               # model width (embedding features / block features)
n_layers = 6              # number of decoder blocks
seq_len = 512             # sequence length; also passed to the blocks as max_len
embed_dim = 7             # passed to the blocks as pos_embed_dim
filter_features = 64
num_filters_layers = 4
order = 2
dropout = 0.1

# NOTE(review): 65 is a hard-coded vocabulary size -- presumably the number of
# distinct characters in the dataset loaded further down; verify against len(vocab).
embed_fn = partial(nn.Embed, num_embeddings=65, features=n_dim)

# NOTE: this rebinds `layer`, which an earlier cell bound to a bare HyenaOperator
# instance; re-running that earlier timing cell after this one will pick up the partial.
layer = partial(
    HyenaDecoderLayer,
    features=n_dim,
    hidden_features=4 * n_dim,
    max_len=seq_len,
    pos_embed_dim=embed_dim,
    filter_features=filter_features,
    num_filter_layers=num_filters_layers,
    order=order,
)

model = HyenaDecoder(
    embedding=embed_fn(),
    layer_func=layer,
    n_layers=n_layers,
    dropout=dropout,
)

In [12]:
class ShakespeareTrainer(ShardedTrainer):
    """ShardedTrainer specialisation for next-token prediction.

    Supplies the loss function and the sharded model/optimizer initialisation.
    """

    def get_loss_function(self):
        """Return ``calculate_loss(params, rng, batch, train) -> (loss, (acc, rng))``."""
        def calculate_loss(params, rng, batch, train):
            # batch is an (inputs, labels) pair.
            inp_data, labels = batch

            # Split off a dropout key; the advanced rng is handed back to the caller.
            rng, dropout_apply_rng = random.split(rng)
            logits = self.model.apply({'params': params}, inp_data, train=train, rngs={'dropout': dropout_apply_rng})
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
            acc = (logits.argmax(axis=-1) == labels).mean()
            return loss, (acc, rng)
        return calculate_loss

    def init_model(self, exmp_batch):
        """Create the sharded initial train state.

        Args:
            exmp_batch: example batch; its first element is used as the model
                input for parameter initialisation.

        Returns:
            ``(initialized_state, state_sharding)``.
        """
        self.rng = jax.random.PRNGKey(self.seed)
        self.rng, init_rng, dropout_init_rng = jax.random.split(self.rng, 3)

        # learning rate schedule and optimizer
        sched = optax.warmup_cosine_decay_schedule(
            init_value=0,
            peak_value=1e-3,
            warmup_steps=100,
            decay_steps=self.max_iters,
            end_value=7e-4,
        )
        optimizer = optax.chain(optax.clip_by_global_norm(1.0),
                                optax.adamw(sched, weight_decay=0.1, b2=0.98))

        # initialization function
        def init_fn(init_rng, dropout_init_rng, x, model, optimizer):
            inp_data = x[0]
            # Use the `model` argument (static under jit) for both init and
            # apply_fn, rather than reaching through the closure to self.model.
            variables = model.init({'params': init_rng, 'dropout': dropout_init_rng}, inp_data, train=True)
            params = variables['params']
            state = train_state.TrainState.create(
                apply_fn=model.apply,
                params=params,
                tx=optimizer
            )
            return state

        # get init_fn output sharding (state)
        abstract_vars, logical_state_spec = MeshManager.get_var_sharding(init_fn,
                                                                         init_rng,
                                                                         dropout_init_rng,
                                                                         exmp_batch,
                                                                         model=self.model,
                                                                         optimizer=optimizer)
        # convert logical to physical sharding according to rules
        rules = (('batch', 'data'),)
        state_sharding = self.mesh_manager.logical_to_mesh(logical_state_spec, rules)

        # model and optimizer (positions 3 and 4) are static; only the three
        # array arguments get input shardings.
        jitted_init = jax.jit(
            init_fn, static_argnums=(3,4),
            in_shardings=(None, None, self.data_sharding),
            out_shardings=state_sharding
        )

        initialized_state = jitted_init(init_rng, dropout_init_rng, exmp_batch, self.model, optimizer)

        return initialized_state, state_sharding

In [13]:
device_count = jax.device_count()

# Put every available device on the 'data' axis; the 'model' axis has size 1.
# Sized from device_count rather than a literal so the notebook also runs on
# hosts that do not have exactly 8 devices.
device_mesh = mesh_utils.create_device_mesh((device_count, 1))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
mesh_manager = MeshManager(mesh)
print(mesh)

# Batches are split along their leading axis over 'data'; the second axis is replicated.
data_shard_spec = PartitionSpec('data', None)
data_sharding = mesh_manager.mesh_sharding(data_shard_spec)

Mesh(device_ids=array([[0],
       [1],
       [2],
       [3],
       [6],
       [7],
       [4],
       [5]]), axis_names=('data', 'model'))


In [14]:
from urllib import request

# Test on the tiny Shakespeare dataset.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# Use the response as a context manager so the connection is closed after reading.
with request.urlopen(url) as response:
    text = response.read().decode("utf-8")

# Character-level vocabulary: each distinct character gets its index in sorted order.
vocab = sorted(set(text))
char2token = {c: i for i, c in enumerate(vocab)}
token2char = dict(enumerate(vocab))
tokens = jnp.array([char2token[c] for c in text])

# First 90% of the token stream for training, the remainder for evaluation.
split_idx = int(0.9*len(tokens))
train, test = tokens[:split_idx], tokens[split_idx:]

In [15]:
def get_batch(idxs, data, seq_len):
    """Build a sharded (inputs, targets) batch of windows starting at ``idxs``.

    Each row of the inputs is ``data[i : i + seq_len]`` for one start index
    ``i``; the targets are the same windows shifted one position forward.
    The pair is sharded with the notebook-level ``mesh_manager`` and
    ``data_sharding``.
    """
    starts = jnp.array(idxs)[:, None]
    window = starts + jnp.arange(seq_len)
    pair = (data[window], data[window + 1])
    return mesh_manager.shard_data(pair, data_sharding)

# Batch builders bound to each split and the model's sequence length.
get_test_batch = partial(get_batch, data=test, seq_len=seq_len)
get_train_batch = partial(get_batch, data=train, seq_len=seq_len)

# Example batch with one row per device, used for model initialisation.
exmp_batch = get_train_batch(np.arange(device_count))
jax.debug.visualize_array_sharding(exmp_batch[0])

In [16]:
n_epochs = 40
batch_size = 64

# Don't choose inputs from the last seq_len indices as targets would exceed train/test size.
train_size, test_size = (split.shape[0] - (seq_len + 1) for split in (train, test))

# Don't iterate over every single item in an epoch (unnecessarily large overlap between batches).
# The divisors 256 and 25 subsample the start indices of the train and test splits respectively.
train_steps_per_epoch = train_size // (batch_size * 256)
test_steps_per_epoch = test_size // (batch_size * 25)

# Generator factories; remaining arguments (e.g. the PRNG key) are supplied at call time.
get_train_gen = partial(
    get_batch_gen,
    train_size,
    get_train_batch,
    batch_size,
    shuffle=True,
    steps_per_epoch=train_steps_per_epoch,
)
get_val_gen = partial(
    get_batch_gen,
    test_size,
    get_test_batch,
    batch_size,
    shuffle=True,
    steps_per_epoch=test_steps_per_epoch,
)

In [17]:
# Subkey reserved for the evaluation generator created after training.
key, train_key = random.split(key)

# Total number of optimizer steps; also the length of the learning-rate schedule.
num_train_iters = n_epochs * train_steps_per_epoch

# Create a trainer module with specified hyperparameters
trainer = ShakespeareTrainer(
    model,
    model_name='ShakespeareHyena',
    exmp_batch=exmp_batch,
    max_iters=num_train_iters,
    mesh_manager=mesh_manager,
    data_sharding=data_sharding,
)



In [18]:
# Run the full training loop with the device mesh active.
with mesh:
    trainer.train_model(get_train_gen, get_val_gen, num_epochs=n_epochs)

# Final evaluation pass on a fresh validation generator; the generator factory
# also hands back an advanced key, which replaces train_key.
# NOTE(review): this runs outside the `with mesh:` block, and the order of the
# returned pair (acc, loss) is taken from the names here -- confirm against
# ShardedTrainer.eval_model.
val_gen, train_key = get_val_gen(key=train_key)
val_acc, val_loss = trainer.eval_model(val_gen)

# Bind parameters to model for easier inference
trainer.model_bd = trainer.model.bind({'params': trainer.state.params})

  0%|          | 0/40 [00:00<?, ?it/s]

Training: 0it [00:00, ?it/s]

Epoch 1, train loss: 4.0149312019348145, accuracy: 0.11218111962080002
Epoch 1, val loss: 3.791551113128662, accuracy: 0.17125403881072998


Training: 0it [00:00, ?it/s]

Epoch 2, train loss: 3.501777410507202, accuracy: 0.25221025943756104
Epoch 2, val loss: 3.1810083389282227, accuracy: 0.3415014147758484


Training: 0it [00:00, ?it/s]

Epoch 3, train loss: 2.9595961570739746, accuracy: 0.35975173115730286
Epoch 3, val loss: 2.726419687271118, accuracy: 0.3989589512348175


Training: 0it [00:00, ?it/s]

Epoch 4, train loss: 2.5661888122558594, accuracy: 0.41193678975105286
Epoch 4, val loss: 2.3917236328125, accuracy: 0.4314327538013458


Training: 0it [00:00, ?it/s]

Epoch 5, train loss: 2.274303436279297, accuracy: 0.4431597590446472
Epoch 5, val loss: 2.157914638519287, accuracy: 0.4516880214214325


Training: 0it [00:00, ?it/s]

Epoch 6, train loss: 2.0633928775787354, accuracy: 0.4688175320625305
Epoch 6, val loss: 1.9853150844573975, accuracy: 0.47214099764823914


Training: 0it [00:00, ?it/s]

Epoch 7, train loss: 1.899193286895752, accuracy: 0.49304500222206116
Epoch 7, val loss: 1.860067367553711, accuracy: 0.488347589969635


Training: 0it [00:00, ?it/s]

Epoch 8, train loss: 1.779293179512024, accuracy: 0.5109378099441528
Epoch 8, val loss: 1.7787940502166748, accuracy: 0.49998098611831665


Training: 0it [00:00, ?it/s]

Epoch 9, train loss: 1.6863524913787842, accuracy: 0.5263186693191528
Epoch 9, val loss: 1.7086975574493408, accuracy: 0.5103892683982849


Training: 0it [00:00, ?it/s]

Epoch 10, train loss: 1.6162011623382568, accuracy: 0.536638081073761
Epoch 10, val loss: 1.6681420803070068, accuracy: 0.5164732933044434


Training: 0it [00:00, ?it/s]

Epoch 11, train loss: 1.5634039640426636, accuracy: 0.5451920032501221
Epoch 11, val loss: 1.634684443473816, accuracy: 0.5254865884780884


Training: 0it [00:00, ?it/s]

Epoch 12, train loss: 1.5197592973709106, accuracy: 0.5531581044197083
Epoch 12, val loss: 1.5946767330169678, accuracy: 0.5327993035316467


Training: 0it [00:00, ?it/s]

Epoch 13, train loss: 1.483270287513733, accuracy: 0.5598369836807251
Epoch 13, val loss: 1.5763037204742432, accuracy: 0.5365680456161499


Training: 0it [00:00, ?it/s]

Epoch 14, train loss: 1.4509625434875488, accuracy: 0.5663596987724304
Epoch 14, val loss: 1.559666633605957, accuracy: 0.5423752069473267


Training: 0it [00:00, ?it/s]

Epoch 15, train loss: 1.4306414127349854, accuracy: 0.5696631073951721
Epoch 15, val loss: 1.5368115901947021, accuracy: 0.5470408797264099


Training: 0it [00:00, ?it/s]

Epoch 16, train loss: 1.4068697690963745, accuracy: 0.5741622447967529
Epoch 16, val loss: 1.5246070623397827, accuracy: 0.5481204986572266


Training: 0it [00:00, ?it/s]

Epoch 17, train loss: 1.3909343481063843, accuracy: 0.5775141716003418
Epoch 17, val loss: 1.5133845806121826, accuracy: 0.5499081611633301


Training: 0it [00:00, ?it/s]

Epoch 18, train loss: 1.375877022743225, accuracy: 0.5809716582298279
Epoch 18, val loss: 1.505707859992981, accuracy: 0.5519471168518066


Training: 0it [00:00, ?it/s]

Epoch 19, train loss: 1.3627945184707642, accuracy: 0.583055853843689
Epoch 19, val loss: 1.4985592365264893, accuracy: 0.5551549792289734


Training: 0it [00:00, ?it/s]

Epoch 20, train loss: 1.3479368686676025, accuracy: 0.5862631797790527
Epoch 20, val loss: 1.4887877702713013, accuracy: 0.5582107305526733


Training: 0it [00:00, ?it/s]

Epoch 21, train loss: 1.3416390419006348, accuracy: 0.5872867703437805
Epoch 21, val loss: 1.4888650178909302, accuracy: 0.5587790608406067


Training: 0it [00:00, ?it/s]

Epoch 22, train loss: 1.3287646770477295, accuracy: 0.5910634398460388
Epoch 22, val loss: 1.4821523427963257, accuracy: 0.560383677482605


Training: 0it [00:00, ?it/s]

Epoch 23, train loss: 1.3183168172836304, accuracy: 0.5931326150894165
Epoch 23, val loss: 1.473443865776062, accuracy: 0.5614863038063049


Training: 0it [00:00, ?it/s]

Epoch 24, train loss: 1.313948392868042, accuracy: 0.5935578942298889
Epoch 24, val loss: 1.4609413146972656, accuracy: 0.564784824848175


Training: 0it [00:00, ?it/s]

Epoch 25, train loss: 1.3001668453216553, accuracy: 0.5972680449485779
Epoch 25, val loss: 1.45468008518219, accuracy: 0.5652434825897217


Training: 0it [00:00, ?it/s]

Epoch 26, train loss: 1.2967040538787842, accuracy: 0.5980319380760193
Epoch 26, val loss: 1.4609057903289795, accuracy: 0.566074550151825


Training: 0it [00:00, ?it/s]

Epoch 27, train loss: 1.2916759252548218, accuracy: 0.5988369584083557
Epoch 27, val loss: 1.453465461730957, accuracy: 0.5678883194923401


Training: 0it [00:00, ?it/s]

Epoch 28, train loss: 1.2828031778335571, accuracy: 0.6011037230491638
Epoch 28, val loss: 1.4495930671691895, accuracy: 0.569448709487915


Training: 0it [00:00, ?it/s]

Epoch 29, train loss: 1.2764837741851807, accuracy: 0.6032164692878723
Epoch 29, val loss: 1.4473909139633179, accuracy: 0.568307638168335


Training: 0it [00:00, ?it/s]

Epoch 30, train loss: 1.2716366052627563, accuracy: 0.6041309833526611
Epoch 30, val loss: 1.4437079429626465, accuracy: 0.570770263671875


Training: 0it [00:00, ?it/s]

Epoch 31, train loss: 1.267196774482727, accuracy: 0.6054187417030334
Epoch 31, val loss: 1.432915449142456, accuracy: 0.5716667771339417


Training: 0it [00:00, ?it/s]

Epoch 32, train loss: 1.2644811868667603, accuracy: 0.6058909893035889
Epoch 32, val loss: 1.4323363304138184, accuracy: 0.57298743724823


Training: 0it [00:00, ?it/s]

Epoch 33, train loss: 1.2542824745178223, accuracy: 0.6082468628883362
Epoch 33, val loss: 1.4410500526428223, accuracy: 0.5711709856987


Training: 0it [00:00, ?it/s]

Epoch 34, train loss: 1.2522103786468506, accuracy: 0.608751118183136
Epoch 34, val loss: 1.4313586950302124, accuracy: 0.5740825533866882


Training: 0it [00:00, ?it/s]

Epoch 35, train loss: 1.24908447265625, accuracy: 0.6094680428504944
Epoch 35, val loss: 1.4331425428390503, accuracy: 0.574985682964325


Training: 0it [00:00, ?it/s]

Epoch 36, train loss: 1.244664192199707, accuracy: 0.6108153462409973
Epoch 36, val loss: 1.4388471841812134, accuracy: 0.5726384520530701


Training: 0it [00:00, ?it/s]

Epoch 37, train loss: 1.2403591871261597, accuracy: 0.6117779016494751
Epoch 37, val loss: 1.431612491607666, accuracy: 0.57403165102005


Training: 0it [00:00, ?it/s]

Epoch 38, train loss: 1.2349002361297607, accuracy: 0.6131226420402527
Epoch 38, val loss: 1.4228782653808594, accuracy: 0.5760365128517151


Training: 0it [00:00, ?it/s]

Epoch 39, train loss: 1.2320034503936768, accuracy: 0.6140176653862
Epoch 39, val loss: 1.43060302734375, accuracy: 0.5749737024307251


Training: 0it [00:00, ?it/s]

Epoch 40, train loss: 1.2268190383911133, accuracy: 0.6152999401092529
Epoch 40, val loss: 1.4329246282577515, accuracy: 0.5764301419258118


In [20]:
@jax.jit
def pred(inp, key):
    """Sample one token index per position from the bound model's logits.

    Runs ``trainer.model_bd`` in inference mode on ``inp`` and draws from the
    categorical distribution over the last axis using ``key``. The bound model
    is captured from the notebook namespace when the function is traced.
    """
    # Sample instead of simply taking argmax to reduce repetition.
    return random.categorical(key, trainer.model_bd(inp, train=False), axis=-1)

# warmup
pred(jnp.zeros((1, seq_len), dtype=jnp.int32), key).shape

(1, 512)

In [22]:
# Prompt the model and sample a continuation one character at a time.
cond_text = "ROMEO:"
cond_tokens = [char2token[c] for c in cond_text]
cur_idx = len(cond_tokens)
# Fixed-size (1, seq_len) input buffer: prompt first, zero padding after.
# np.pad returns a NumPy array, so the buffer can be updated in place below.
cur_in = np.pad(jnp.array([cond_tokens]), ((0, 0), (0, seq_len - cur_idx)))

sample_key = random.PRNGKey(3)

# number of chars to generate
num_gen = 800

for step in range(num_gen):
    sample_key = random.fold_in(sample_key, step)

    # The prediction at the last filled position is the next token.
    cur_pred_token = pred(cur_in, sample_key)[0, cur_idx - 1]

    if cur_idx >= seq_len:
        # Buffer full: drop the oldest token and write the new one at the end.
        cur_in = np.roll(cur_in, shift=-1, axis=-1)
        cur_in[0, cur_idx - 1] = cur_pred_token
    else:
        # Still room: write into the next free slot and advance.
        cur_in[0, cur_idx] = cur_pred_token
        cur_idx += 1

    cond_tokens.append(int(cur_pred_token))

print(''.join(token2char[t] for t in cond_tokens))

ROMEO:
Alas, my father, and I know come; let's prejoice
To die for reply, and they there, like man
Where will abuse off: mine arst Abury,'
I dead to. This true that was all eyes,
Passing him till 'we must speak of monsta's maid,
Should not enbear him! hear the duke, whood they are
commend to thee to do it?
And send not Barnardine, but giving a scorn
To find a perfect.

MENENIUS:
Tell him, you have sead'd to stand up to death.
You shall not indeed, say 'stain.
Long that was won; I'll not such a man!
shut it not that beheld how all I like thee,
Ratcliff, tribunes no dear silence we hereafter.

PRINCE EDWARD:
Most ragely from him succession:
I pretty troopque, Richmond the gates to say,
By curity of the matters but herein;
Which with a few day in the hire infect.
Look your silence than my name in k
