In [2]:
import haliax as hax  # this amuses me more than it should
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Model dimensions: scaled-down versions of the numbers from the table above.
Layer = hax.Axis(name="layer", size=4)
Head = hax.Axis(name="head", size=8)
Key = hax.Axis(name="key", size=16)
Embed = hax.Axis(name="embed", size=32)
# The MLP width (the "feed-forward size" in the table) is four times the embedding width.
Mlp = hax.Axis(name="mlp", size=4 * Embed.size)

# Data dimensions.
Batch = hax.Axis(name="batch", size=8)
Pos = hax.Axis(name="position", size=128)  # length of each sequence
Vocab = hax.Axis(name="vocab", size=len(tokenizer))  # one entry per token known to the tokenizer

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [4]:
import jax.numpy as jnp

# Constant-filled named arrays: the shape is given as one Axis or a tuple of Axes.
word_embedding = hax.zeros((Vocab, Embed))
weight = hax.ones((Embed, Mlp))
bias = hax.zeros(Mlp)

# An explicit dtype can be passed, here for integer token ids.
data = hax.ones((Batch, Pos), dtype=jnp.int32)

# arange over an axis yields 0, 1, ..., Layer.size - 1.
layer_indices = hax.arange(Layer)

In [11]:
# Display the named array built with hax.arange(Layer): values 0 .. Layer.size - 1.
layer_indices

NamedArray(array=Array([0, 1, 2, 3], dtype=int32), axes=(Axis(name='layer', size=4),))

In [12]:
# A plain (unnamed) JAX array. Its shape is derived from the axes it is wrapped with
# below, so it stays consistent if Embed or Mlp is ever resized.
a = jnp.zeros((Embed.size, Mlp.size))
named_a = hax.named(a, (Embed, Mlp))  # wrap the raw array using Axis objects
named_a = hax.named(a, ("embed", "mlp"))  # also ok, b/c axis sizes can be inferred from a's shape

In [14]:
import jax.random
from jax.random import PRNGKey

# Root key built from a fixed seed.
base_key = PRNGKey(0)
# One subkey per random array generated below (k_w: weight, k_e: word embedding, k_d: data).
k_w, k_e, k_d = jax.random.split(base_key, num=3)


In [15]:

# Each random array is drawn with its own subkey from the split above.
data = hax.random.randint(k_d, (Batch, Pos), 0, Vocab.size)  # token ids sampled from [0, Vocab)
word_embedding = hax.random.normal(k_e, (Vocab, Embed))
weight = hax.random.normal(k_w, (Embed, Mlp))

In [27]:
# Reductions take the axis (or axes) to reduce over, given as Axis objects or by name.
m_weight = hax.mean(weight, Embed) # mean taken over 'embed'; presumably only 'mlp' remains -- TODO confirm
m_weight = hax.mean(weight, "embed")  # also ok: same axis, referenced by its string name
total = hax.sum(weight, (Embed, Mlp))  # equivalent to hax.sum(weight)

In [35]:
# Two axes of different sizes, each with a 1-D array of its indices.
M, N = hax.Axis(name="M", size=5), hax.Axis(name="N", size=4)
a, b = hax.arange(M), hax.arange(N)

print(a, b, sep="\n")

# Add an N axis to `a` before multiplying it elementwise with `b`.
c = a.broadcast_axis(N) * b
# Show the resulting axes and the underlying raw array.
print(c.axes, c.array, sep="\n")

NamedArray(array=Array([0, 1, 2, 3, 4], dtype=int32), axes=(Axis(name='M', size=5),))
NamedArray(array=Array([0, 1, 2, 3], dtype=int32), axes=(Axis(name='N', size=4),))
(Axis(name='N', size=4), Axis(name='M', size=5))
[[ 0  0  0  0  0]
 [ 0  1  2  3  4]
 [ 0  2  4  6  8]
 [ 0  3  6  9 12]]


In [42]:
# Keep the broadcast result by itself (no multiplication by `b` this time).
f = a.broadcast_axis(N)

In [70]:
# Re-create the two parameter arrays; this repeats the earlier initialisation with the same keys.
weight = hax.random.normal(k_w, (Embed, Mlp))
word_embedding = hax.random.normal(k_e, (Vocab, Embed)) # axes are (Vocab, Embed): one Embed-sized vector per vocab entry

In [71]:
# Inspect the axes (names and sizes) attached to the embedding array.
word_embedding.axes

(Axis(name='vocab', size=50257), Axis(name='embed', size=32))