In [None]:
import haliax as hax  # this amuses me more than it should
from transformers import AutoTokenizer

# Scaled-down model dimensions (smaller stand-ins for the numbers in the table above).
Layer = hax.Axis("layer", 4)
Head = hax.Axis("head", 8)
Key = hax.Axis("key", 16)
Embed = hax.Axis("embed", 32)
Mlp = hax.Axis("mlp", Embed.size * 4)  # the "feed-forward size" from above: 4x the embedding width

# Data dimensions.
Batch = hax.Axis("batch", 8)
Pos = hax.Axis("position", 128)  # sequence length

# The vocabulary axis is sized from the GPT-2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
Vocab = hax.Axis("vocab", len(tokenizer))

In [None]:
import jax.numpy as jnp

# Named arrays are constructed from Axis objects instead of raw integer shapes.
bias = hax.zeros(Mlp)
weight = hax.ones((Embed, Mlp))
word_embedding = hax.zeros((Vocab, Embed))
data = hax.ones((Batch, Pos), dtype=jnp.int32)
layer_indices = hax.arange(Layer)  # 0, 1, ..., Layer.size - 1 (end-exclusive)

In [None]:
# Wrap a raw jax array with axis names. The shape comes from the Axis objects, so it stays
# in sync if Embed/Mlp are resized (the previous literals were 32 and 32 * 4).
a = jnp.zeros((Embed.size, Mlp.size))
named_a = hax.named(a, (Embed, Mlp))  # axes given as Axis objects
named_a = hax.named(a, ("embed", "mlp"))  # ok, b/c axis sizes can be inferred from a's shape

In [None]:
import jax.random
from jax.random import PRNGKey

base_key = PRNGKey(0)
# One key per random draw below: k_w -> weight, k_e -> word_embedding, k_d -> data.
k_w, k_e, k_d = jax.random.split(base_key, num=3)


In [None]:

# Each named array gets its own key, so the draws are independent of one another.
data = hax.random.randint(k_d, (Batch, Pos), 0, Vocab.size)  # token ids in [0, Vocab.size)
word_embedding = hax.random.normal(k_e, (Vocab, Embed))
weight = hax.random.normal(k_w, (Embed, Mlp))

In [None]:
# Reductions name the axis they remove instead of using a positional index.
m_weight = hax.mean(weight, Embed)  # mean over the Embed axis: one value per Mlp entry
assert m_weight.axes == (Mlp,)
m_weight = hax.mean(weight, "embed")  # same reduction, selecting the axis by its name
total = hax.sum(weight, (Embed, Mlp))  # reduces over both axes; equivalent to hax.sum(weight)

In [None]:
M = hax.Axis("M", 5)
N = hax.Axis("N", 4)

a = hax.arange(M)
b = hax.arange(N)

print(a, b, sep="\n")

# Add the N axis to `a` explicitly, then multiply elementwise by `b`.
c = a.broadcast_axis(N) * b
print(c.axes, c.array, sep="\n")

In [None]:
# Broadcast `a` (axis M) along N as well.
# NOTE(review): `f` is not used by any later cell shown here and this cell displays nothing;
# consider removing it or ending the cell with `f` so the result is shown.
f = a.broadcast_axis(N)

In [None]:
# Same keys and shapes as the earlier draws, so these are identical to the previous
# `weight` / `word_embedding`.
weight = hax.random.normal(k_w, (Embed, Mlp))
word_embedding = hax.random.normal(k_e, (Vocab, Embed))  # axes (Vocab, Embed)
# Contract over the shared "embed" axis; the remaining axes are Vocab and Mlp.
big_embed = hax.dot("embed", word_embedding, weight)
assert big_embed.axes == (Vocab, Mlp)

In [None]:
# Use fresh keys: k_w / k_e were already consumed above, and JAX keys should not be reused
# for new random draws.
k_bw, k_be = jax.random.split(PRNGKey(1))
batched_weight = hax.random.normal(k_bw, (Batch, Embed, Mlp))
batched_embed = hax.random.normal(k_be, (Batch, Vocab, Embed))

# Only "embed" is contracted; Batch appears in both inputs and is kept in the result.
batched_big_embed = hax.dot("embed", batched_embed, batched_weight)

assert batched_big_embed.axes == (Batch, Vocab, Mlp)

In [2]:
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import numpy as onp

# A 1-D mesh over every visible device, with its single axis named "data".
mesh = Mesh(onp.asarray(jax.devices()), axis_names=("data",))
print(jax.devices())

# Quick little utility to visualize meshes themselves; it piggybacks on JAX's
# array-sharding visualization.
def visualize_mesh(mesh):
    """Render the layout of ``mesh`` using ``jax.debug.visualize_array_sharding``.

    A dummy zeros array with the same shape as ``mesh.devices`` is placed on the mesh.
    Array dimension i is sharded over mesh axis i. The array's sharding is then
    visualized.

    NOTE(review): the visualizer presumably only handles rank <= 2, so this is likely
    limited to 1-D/2-D meshes — confirm.

    Args:
        mesh: a ``jax.sharding.Mesh``.
    """
    dummy = jnp.zeros(mesh.devices.shape)
    dummy = jax.device_put(dummy, NamedSharding(mesh, PartitionSpec(*mesh.axis_names)))
    jax.debug.visualize_array_sharding(dummy)


visualize_mesh(mesh)

NameError: name 'hax' is not defined

In [None]:

# Toy regression data.
# NOTE: this rebinds the notebook-level `Batch` axis (size 8 earlier) to size 128 for every
# cell that runs after this one.
Feature = hax.Axis("feature", 64)
Batch = hax.Axis("batch", 128)

y = hax.random.uniform(PRNGKey(1), Batch)
x = hax.random.uniform(PRNGKey(0), (Batch, Feature))

def mse(pred, target, axis=None):
    """Mean squared error between ``pred`` and ``target``.

    Args:
        pred: predicted values (a named array).
        target: target values; must be subtractable from ``pred``.
        axis: axis or axes to average over. If ``None`` (the default), the notebook-level
            ``Batch`` axis is used. It is looked up when the function is called, exactly as
            in the original hard-coded version.

    Returns:
        The mean of ``(pred - target) ** 2`` over ``axis``.
    """
    if axis is None:
        axis = Batch
    diff = pred - target  # computed once instead of twice
    return hax.mean(diff * diff, axis=axis)

# One weight per feature; the dot product contracts the "feature" axis away.
W = hax.random.uniform(PRNGKey(2), (Feature,))
y_pred = hax.dot("feature", x, W)

In [None]:
# Loss of the random linear model on the toy data.
mse(pred=y_pred, target=y)

In [None]:
# Show the signature/docstring of hax.dot (as used above with a leading axis argument).
# NOTE(review): interactive lookup; consider removing it before sharing the notebook.
help(hax.dot)

In [None]:
# Key-side position axis: same size as Pos but a different name, so query and key positions
# are distinct axes. It must be defined before `key` uses it, otherwise this cell raises
# NameError on a fresh kernel.
KPos = Pos.alias("key_position")

query = hax.random.uniform(PRNGKey(0), (Pos, Key))
key = hax.random.uniform(PRNGKey(1), (Key, KPos))

In [None]:
# Key-side position axis: Pos under the name "key_position".
# NOTE(review): KPos is also referenced by the cell above this one; make sure it is defined
# before that cell runs.
KPos = Pos.alias("key_position")

In [None]:
import jax.numpy as jnp
# jnp.concatenate takes a *sequence* of arrays. Passing them positionally made the second
# array the `axis` argument, which raises. The empty array is float, so the result is float.
concatenated = jnp.concatenate([jnp.array([]), jnp.array([1])])
concatenated

In [3]:
import optax
# Inspect the `optax.OptState` type alias (the saved output shows a typing.Union of array trees).
# NOTE(review): the saved outputs disagree (a repr and a NameError); they look stale from
# out-of-order execution — re-run before sharing.
optax.OptState

typing.Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, typing.Iterable[ForwardRef('ArrayTree')], typing.Mapping[typing.Any, ForwardRef('ArrayTree')]]

NameError: name 'optax' is not defined

In [4]:
import jmp

In [5]:
# Look up jmp's "f32" policy; per the output below, param, compute and output dtypes are all float32.
jmp.get_policy("f32")

Policy(param_dtype=<class 'jax.numpy.float32'>, compute_dtype=<class 'jax.numpy.float32'>, output_dtype=<class 'jax.numpy.float32'>)