In [13]:
import jax
import jax.numpy as jnp
from jax import random
from jax import grad, jit, vmap, make_jaxpr

import numpy as np


In [14]:
# Fixed PRNG key so the example matrix (and everything derived from it) is reproducible.
key = random.PRNGKey(1)
normal = random.normal(key, shape=(3, 3))  # 3x3 standard-normal sample
normal

Array([[ 0.690805  , -0.48744103, -1.155789  ],
       [ 0.12108463,  1.2010182 , -0.5078766 ],
       [ 0.91568655,  1.70968   , -0.36749417]], dtype=float32)

In [15]:
def iterate(mat, n: int):
    """Apply ``mat -> jnp.dot(mat, mat.T)`` repeatedly, ``n`` times.

    Args:
        mat: array to iterate on (a 2-D matrix in this notebook).
        n: number of iterations. It drives a plain Python loop, so under
            ``jit`` it must be a static argument. ``n <= 0`` returns
            ``mat`` unchanged.

    Returns:
        The array after ``n`` applications of the update.
    """
    for _ in range(n):  # Python loop: unrolled at trace time under jit
        mat = jnp.dot(mat, mat.T)
    return mat


# `n` controls the Python loop above, so it has to be a compile-time constant;
# each distinct value of `n` leads to a separate trace/compilation.
jit_iterate = jit(iterate, static_argnames=("n",))

In [16]:
# Lower `iterate` for the example input with n=2 (static), then export the
# module in the MHLO dialect, which matches `input_type="mhlo"` in the IREE compile cell.
# NOTE(review): newer JAX releases favour the "stablehlo" dialect; confirm for your JAX/IREE versions.
lowered = jit_iterate.lower(normal, 2)
mhlo_module = lowered.compiler_ir(dialect="mhlo")


In [17]:
from jaxlib.mlir.ir import StringAttr

# Give the MLIR module an explicit symbol name; the IREE runtime module is
# later looked up by this name as `ctx.modules.the_module`.
module_sym_name = StringAttr.get("the_module", mhlo_module.context)
mhlo_module.operation.attributes["sym_name"] = module_sym_name

In [18]:
# Display the textual MLIR; the header should now read `module @the_module`.
mhlo_module.operation.get_asm()

'module @the_module {\n  func.func public @main(%arg0: tensor<3x3xf32> {mhlo.sharding = ""}) -> tensor<3x3xf32> {\n    %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<3x3xf32>) -> tensor<3x3xf32>\n    %1 = "mhlo.dot_general"(%arg0, %0) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>\n    %2 = "mhlo.transpose"(%1) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<3x3xf32>) -> tensor<3x3xf32>\n    %3 = "mhlo.dot_general"(%1, %2) {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>\n    return %3 : tensor<3x3xf32>\n  }\n}\n'

In [19]:
import os
import tempfile

import iree.compiler as ireec
import iree.runtime as ireert

# CUDA compute capability passed to IREE's LLVM CUDA backend; should match the
# GPU used at runtime.
CUDA_TARGET_ARCH = "sm_80"
# Where IREE writes the dispatch-graph dump (.dot). Named after this module
# and placed in the platform temp dir instead of a hardcoded `/tmp` path.
DISPATCH_GRAPH_PATH = os.path.join(tempfile.gettempdir(), "the_module_dispatch_graph.dot")

# Compile the MHLO text to an IREE VM flatbuffer for the CUDA backend.
vmfb = ireec.compile_str(
    mhlo_module.operation.get_asm(),
    target_backends=["cuda"],
    input_type="mhlo",
    extra_args=[
        f"--iree-hal-cuda-llvm-target-arch={CUDA_TARGET_ARCH}",
        "--iree-flow-dump-dispatch-graph",
        f"--iree-flow-dump-dispatch-graph-output-file={DISPATCH_GRAPH_PATH}",
    ],
)

In [20]:
# Create an IREE runtime context on the CUDA driver and load the compiled
# flatbuffer (`vmfb`) into it.
config = ireert.Config(driver_name="cuda")
ctx = ireert.SystemContext(config=config)
vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, vmfb)
ctx.add_vm_module(vm_module)
# Accessed by the `sym_name` ("the_module") assigned to the MLIR module earlier.
module = ctx.modules.the_module

In [21]:
# Run the IREE-compiled `main` on the same input and check it against JAX's
# own compiled result before displaying it.
iree_result = np.asarray(module["main"](normal))
jax_result = np.asarray(jit_iterate(normal, 2))
# Loose tolerance: the two backends may use different matmul precision
# (e.g. TF32 on Ampere GPUs), so bit-exact agreement is not expected.
np.testing.assert_allclose(iree_result, jax_result, rtol=1e-2, atol=1e-2)
iree_result

array([[ 4.262611  ,  0.84735847,  1.5321438 ],
       [ 0.84735847,  8.47525   , 13.211212  ],
       [ 1.5321438 , 13.211212  , 20.759777  ]], dtype=float32)

In [22]:
from flax import linen as nn

class Net(nn.Module):
    """Two-layer MLP: Dense(n_hidden) -> ReLU -> Dense(1).

    Layer names are fixed to "dense" and "dense_1" so the parameter tree
    keys stay stable.
    """

    n_hidden: int = 4

    @nn.compact
    def __call__(self, x):
        hidden = nn.relu(nn.Dense(features=self.n_hidden, name="dense")(x))
        return nn.Dense(features=1, name="dense_1")(hidden)

key = jax.random.PRNGKey(3)
model = Net(n_hidden=8)
# Initialise from one unbatched example; Dense acts on the last axis, so the
# same params are applied to the batched (4, 5, 4) input below.
# NOTE(review): `vars` shadows the Python builtin; kept because later cells may reference it.
vars = model.init(key, jnp.ones((4,)))

key, input_key = jax.random.split(key)
inputs = jax.random.normal(input_key, (4, 5, 4))
# `jax.tree_map` is deprecated and removed in newer JAX; use the tree_util spelling.
print("params =", jax.tree_util.tree_map(jnp.shape, vars))

print("shape of output:", model.apply(vars, inputs).shape)

# Cool: `model` does not carry any state, so you can just recreate it.
# Differences are exactly zero. Take |a - b| before the max:
# max(a - b) == 0 alone would not rule out negative differences.
jnp.max(jnp.abs(Net(n_hidden=8).apply(vars, inputs) - Net(n_hidden=8).apply(vars, inputs)))

params = FrozenDict({
    params: {
        dense: {
            bias: (8,),
            kernel: (4, 8),
        },
        dense_1: {
            bias: (1,),
            kernel: (8, 1),
        },
    },
})
shape of output: (4, 5, 1)


Array(0., dtype=float32)

: 