HHLO is a Haskell library and runtime for building, compiling, and executing machine learning programs targeting StableHLO, the portable, versioned intermediate representation of the OpenXLA ecosystem.
Instead of replicating JAX's Python-based tracing infrastructure, HHLO generates StableHLO MLIR text directly from Haskell and compiles it to CPU or GPU via the PJRT plugin interface.
HHLO is structured in four layers:
```
┌─────────────────────────────────────┐
│ EDSL (HHLO.EDSL.Ops)                │  Type-safe frontend: add, matmul, relu, etc.
├─────────────────────────────────────┤
│ IR Builder (HHLO.IR.Builder)        │  Stateful monad for constructing MLIR
├─────────────────────────────────────┤
│ Pretty Printer (HHLO.IR.Pretty)     │  Emits StableHLO MLIR text
├─────────────────────────────────────┤
│ PJRT Runtime (HHLO.Runtime.*)       │  Compile → Execute on CPU or GPU
└─────────────────────────────────────┘
```
### Text Emission + PJRT
The library emits StableHLO MLIR text directly and hands it to PJRT_Client_Compile. This is the same path used by JAX's C++ backend and avoids the heavy dependency of building LLVM/MLIR from source.
### Phantom Types
Every tensor carries its shape and dtype as phantom type parameters:
```haskell
Tensor '[2, 3] 'F32   -- 2×3 matrix of Float32
```

Matmul, broadcast, and conv shapes are checked at compile time via type families.
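The mechanism can be sketched with a closed type family. The following is a self-contained illustration, not HHLO's actual definitions: `Tensor`, `MatMul`, and `shapeOf` here are hypothetical stand-ins.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

data DType = F32 | I64

-- Shape and dtype live only at the type level; the payload is untyped.
newtype Tensor (shape :: [Nat]) (dtype :: DType) = Tensor [Float]

-- Result shape of a 2-D matmul: [m, k] @ [k, n] = [m, n].
type family MatMul (a :: [Nat]) (b :: [Nat]) :: [Nat] where
  MatMul '[m, k] '[k, n] = '[m, n]

matmul :: Tensor a 'F32 -> Tensor b 'F32 -> Tensor (MatMul a b) 'F32
matmul (Tensor xs) (Tensor ys) = Tensor (xs ++ ys)  -- placeholder numerics

-- Recover the static shape of a rank-2 result at runtime.
shapeOf :: forall m n d. (KnownNat m, KnownNat n)
        => Tensor '[m, n] d -> (Integer, Integer)
shapeOf _ = (natVal (Proxy @m), natVal (Proxy @n))

main :: IO ()
main = do
  let a = Tensor [1 .. 6] :: Tensor '[2, 3] 'F32
      b = Tensor [1 .. 6] :: Tensor '[3, 2] 'F32
  print (shapeOf (matmul a b))  -- prints (2,2)
```

Because `MatMul '[2, 3] '[2, 3]` matches no equation of the family, an ill-shaped matmul is rejected at compile time rather than failing at runtime.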
### ForeignPtr Finalizers
PJRT buffers and executables are managed by ForeignPtr finalizers that automatically call `PJRT_Buffer_Destroy` and `PJRT_LoadedExecutable_Destroy` when values are garbage-collected, so references can simply fall out of scope without explicit cleanup.
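This is the standard `Foreign.Concurrent.newForeignPtr` idiom. A minimal sketch, with a plain `malloc`'d block standing in for a PJRT buffer and `free` standing in for the `PJRT_Buffer_Destroy` FFI call:

```haskell
import qualified Foreign.Concurrent as FC
import Foreign.ForeignPtr (finalizeForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free, mallocBytes)

main :: IO ()
main = do
  p  <- mallocBytes 64                                -- stand-in for a device buffer
  fp <- FC.newForeignPtr p (free p >> putStrLn "finalizer ran")
  withForeignPtr fp $ \_ -> putStrLn "using buffer"   -- normal use keeps it alive
  finalizeForeignPtr fp                               -- the GC does this implicitly
```

In HHLO the finalizer body would invoke the destroy entry point through the FFI instead of `free`; the ownership pattern is the same.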
### Dynamic Output Counts
The runtime queries the compiled executable for its actual number of outputs via PJRT_Executable_NumOutputs instead of guessing or hardcoding a maximum.
### Async Execution
HHLO.Runtime.Async provides true non-blocking execution: executeAsync returns buffer handles immediately, bufferReady polls for completion, and awaitBuffers blocks until device-side computation finishes.
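The poll/block split follows the usual event pattern. Here is a self-contained mock of it, with an `MVar` standing in for a PJRT event (this is an illustration, not the actual `HHLO.Runtime.Async` implementation):

```haskell
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.MVar

main :: IO ()
main = do
  done <- newEmptyMVar
  _ <- forkIO (threadDelay 50000 >> putMVar done ())  -- "device" work finishing later
  ready <- not <$> isEmptyMVar done                   -- bufferReady-style non-blocking poll
  putStrLn ("ready immediately: " ++ show ready)
  takeMVar done                                       -- awaitBuffers-style blocking wait
  putStrLn "computation finished"
```

The poll returns immediately whether or not the work is done; only the final wait blocks, which is what lets callers overlap host work with device execution.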
### Device Enumeration & Selection
HHLO.Runtime.Device lets you discover and select specific GPUs at runtime:
```haskell
addressableDevices api client   -- list all devices
deviceKind api dev              -- "cpu" or "NVIDIA GeForce RTX 5090"
defaultGPUDevice api client     -- first non-CPU device
```

### Multi-GPU Inference Scaling
HHLO.Runtime.Execute provides executeReplicas for running the same compiled model concurrently across multiple GPUs:
```haskell
compileWithOptions api client mlirText
  (defaultCompileOptions { optNumReplicas = numDevs })

-- Launch independent forward passes on all GPUs
executeReplicas api exec
  [ (gpu0, [bufA0, bufB0])
  , (gpu1, [bufA1, bufB1])
  , ...
  ]
```

### Multi-Result Operations
The AST Operation type supports multiple results, enabling ops like stablehlo.rng_bit_generator and multi-value control flow:
```haskell
-- Two-result operation
(newState, output) <- rngBitGenerator state
```

### Multi-Value Control Flow
whileLoop2 / conditional2 carry multiple typed tensors through loops and conditionals without manual packing:
```haskell
-- Loop with two accumulators: counter and running sum
(resultCounter, resultSum) <- whileLoop2 counter0 sum0
  (\c s -> compare c limit "LT")
  (\c s -> do
      cNext <- add c one
      sNext <- add s cNext
      returnTuple2 cNext sNext)
```

### Random Number Generation
Three RNG primitives are exposed in the EDSL:
```haskell
uniform       <- rngUniform a b        -- uniform in [a, b)
normal        <- rngNormal             -- standard normal (mean 0, std 1)
(newSt, bits) <- rngBitGenerator state -- Threefry bit generator
```

### Requirements

- GHC 9.6+ and Cabal 3.10+
- Linux x86_64 (other platforms supported by PJRT artifacts may work)
- `curl`, `tar`, and a standard C toolchain (`gcc` or `clang`)
- `libstdc++` and `libdl` (usually present on Linux)
Run the provided script to download prebuilt PJRT plugins:
```bash
./pjrt_script.sh
```

This downloads `libpjrt_cpu.so` from the zml/pjrt-artifacts nightly builds into `deps/pjrt/`. If you have an NVIDIA GPU with `nvidia-smi` available, the CUDA plugin is also fetched automatically.
```bash
cabal build all
```

This compiles the library, the demo, the examples, and the test suite.
```bash
cabal run example-add --flag=examples
cabal test
```

Note: all `example-*` executables are guarded by the `examples` flag in `hhlo.cabal` (defaults to `False`). Append `--flag=examples` to every `cabal run example-*` command.
The PJRT CUDA plugin depends on NVIDIA runtime libraries: cuDNN, NCCL, and NVSHMEM. These are commonly available via conda, pip, or system packages.
If you already have them (e.g. via PyTorch or JAX installations), simply run:
```bash
./setup_gpu_env.sh
source ~/.bashrc
```

This idempotent script auto-discovers the libraries and appends them to `~/.bashrc`. After that, GPU examples work directly:
```bash
cabal run example-gpu-add --flag=examples
cabal run example-gpu-matmul-bench --flag=examples
cabal run example-multi-gpu-inference --flag=examples
```

A minimal program that builds `c = a + b` and prints the rendered StableHLO MLIR:

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}

import HHLO.Core.Types
import HHLO.EDSL.Ops
import HHLO.IR.AST (FuncArg(..), TensorType(..))
import HHLO.IR.Builder
import HHLO.IR.Pretty

import qualified Data.Text.IO as T

-- Build a program: c = a + b
program :: Module
program = moduleFromBuilder @'[2,2] @'F32 "main"
  [ FuncArg "a" (TensorType [2, 2] F32)
  , FuncArg "b" (TensorType [2, 2] F32)
  ]
  $ do
      a <- arg
      b <- arg
      c <- add a b
      return c

main :: IO ()
main = T.putStrLn (render program)
```

Output:

```mlir
module {
  func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %0 = stablehlo.add %arg0, %arg1 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
    return %0 : tensor<2x2xf32>
  }
}
```

```bash
cabal run hhlo-demo
```

The demo builds a `stablehlo.add` program via the EDSL, compiles it with PJRT CPU, creates F32 input buffers, executes, and reads back the result:
```
=== HHLO End-to-End Demo ===
Loading PJRT CPU plugin...
Plugin loaded.
...
Result: [6.0,8.0,10.0,12.0]
SUCCESS: Results match expected values!
```
Standalone examples are provided in examples/:
| # | Command | Description |
|---|---------|-------------|
| 1 | `cabal run example-add --flag=examples` | Element-wise `c = a + b` |
| 2 | `cabal run example-matmul --flag=examples` | 2×3 @ 3×2 matrix multiply |
| 3 | `cabal run example-chain-ops --flag=examples` | `(a + b) * (a - b)` |
| 4 | `cabal run example-async --flag=examples` | Async `executeAsync` + relu |
| 5 | `cabal run example-mlp --flag=examples` | 2-layer MLP |
| 6 | `cabal run example-mlp-batched --flag=examples` | Batched MLP |
| 7 | `cabal run example-tuple --flag=examples` | Multi-result `func.func` |
| 8 | `cabal run example-reduce --flag=examples` | `reduceSum` over all dimensions |
| 9 | `cabal run example-softmax --flag=examples` | 1-D and batched 2-D softmax |
| 10 | `cabal run example-conv2d --flag=examples` | NHWC conv2d |
| 11 | `cabal run example-batch-norm --flag=examples` | Batch norm inference |
| 12 | `cabal run example-while --flag=examples` | `whileLoop` count-up |
| 13 | `cabal run example-conditional --flag=examples` | `conditional` if-then-else |
| 14 | `cabal run example-gather --flag=examples` | `gather` rows from matrix |
| 15 | `cabal run example-scatter --flag=examples` | `scatter` replace into vector |
| 16 | `cabal run example-slice --flag=examples` | `slice` sub-array extraction |
| 17 | `cabal run example-pad --flag=examples` | `pad` with edge/interior padding |
| 18 | `cabal run example-dynamic-slice --flag=examples` | `dynamicSlice` runtime indices |
| 19 | `cabal run example-sort --flag=examples` | `sort` 1-D ascending |
| 20 | `cabal run example-select --flag=examples` | Element-wise ternary `select` |
| 21 | `cabal run example-map --flag=examples` | `map` with custom computation |
| 22 | `cabal run example-new-ops-smoke-test --flag=examples` | Smoke test for newer ops |
| 23 | `cabal run example-resnet --flag=examples` | ResNet-18 toy (8×8 input) |
| 24 | `cabal run example-alexnet --flag=examples` | AlexNet toy (16×16 input) |
| 25 | `cabal run example-transformer --flag=examples` | Transformer encoder (1×4×16) |
| 26 | `cabal run example-unet --flag=examples` | UNet segmentation toy (16×16) |
| 30 | `cabal run example-rng-uniform --flag=examples` | `rngUniform` random floats [0,1) |
| 31 | `cabal run example-rng-normal --flag=examples` | `rngNormal` standard normal distribution |
| 32 | `cabal run example-rng-bit-generator --flag=examples` | `rngBitGenerator` Threefry PRNG |
| 33 | `cabal run example-multi-value-loop --flag=examples` | `whileLoop2` with two loop-carried values |
| 27 | `cabal run example-gpu-add --flag=examples` | GPU smoke test |
| 28 | `cabal run example-gpu-matmul-bench --flag=examples` | GPU 4096×4096 benchmark |
| 29 | `cabal run example-multi-gpu-inference --flag=examples` | Multi-GPU concurrent matmul |
```bash
cabal test
```

Runs 124 tests across three tiers:
- Tier 1 — Golden tests — Verify rendered MLIR text for EDSL ops, IR constructs, NN layers, and control flow.
- Tier 2 — End-to-end runtime tests — Load the PJRT CPU plugin, compile StableHLO programs, execute them, and verify numerical results. Covers arithmetic, matmul, reductions, data movement, and NN ops.
- Tier 3 — Runtime integration tests — Buffer metadata queries, async execution, and error handling.
```bash
HHLO_TEST_GPU=1 cabal test
```

Runs the full 124 CPU tests plus 6 additional GPU integration tests:

- `EndToEnd.GPU` — GPU availability and device enumeration
- `Runtime.BufferGPU` — Buffer round-trip and metadata queries on GPU
- `Runtime.AsyncGPU` — Async execution and `bufferReady` polling on GPU
- `Runtime.MultiGPU` — Concurrent `executeReplicas` across all GPUs
Sample output:
```
HHLO Tests
  EDSL.Ops
    Binary element-wise
      add: OK
      ...
  EndToEnd.Arithmetic
    relu: OK (0.02s)
    ...
  Runtime.Buffer
    buffer round-trip f32: OK
  Runtime.Async
    buffer ready after sync execute: OK (0.02s)
  EndToEnd.GPU
    gpu available: OK
  Runtime.BufferGPU
    gpu buffer round-trip f32: OK
  Runtime.AsyncGPU
    gpu executeAsync + await: OK
  Runtime.MultiGPU
    execute replicas on all GPUs: OK

All 130 tests passed (16.27s)
```
```
.
├── app/                      # hhlo-demo executable
├── cbits/                    # C shim around PJRT C API
│   ├── pjrt_c_api.h          # Upstream PJRT header
│   ├── pjrt_shim.c           # Thin wrapper exposing flat C functions
│   └── pjrt_shim.h           # C header for the shim
├── deps/
│   └── pjrt/                 # Downloaded PJRT plugins (.so files)
│       └── lib_symlinks/     # Compatibility symlinks for missing library versions
├── doc/                      # Architecture and design documents
├── examples/                 # Standalone example programs (01–33)
├── src/HHLO/
│   ├── Core/Types.hs         # DType, Shape, HostType type families
│   ├── IR/
│   │   ├── AST.hs            # MLIR AST (Operation, Function, Module)
│   │   ├── Builder.hs        # Stateful Builder monad + Tensor/Tuple GADTs
│   │   └── Pretty.hs         # MLIR text pretty-printer
│   ├── EDSL/Ops.hs           # Type-safe frontend ops (50+ ops)
│   └── Runtime/
│       ├── PJRT/
│       │   ├── FFI.hs        # C FFI declarations
│       │   ├── Types.hs      # Opaque pointer newtypes + buffer type constants
│       │   ├── Error.hs      # PJRT error handling
│       │   └── Plugin.hs     # Backend-agnostic plugin loading (withPJRT)
│       ├── Device.hs         # Device enumeration & selection
│       ├── Compile.hs        # MLIR → PJRT executable (with CompileOptions)
│       ├── Execute.hs        # Synchronous + device-targeted + multi-GPU replica execution
│       ├── Async.hs          # Non-blocking execution with PJRT_Event
│       └── Buffer.hs         # Host↔device buffer transfers + metadata queries
├── test/
│   ├── Test/
│   │   ├── EDSL/Ops.hs
│   │   ├── IR/
│   │   │   ├── Builder.hs
│   │   │   ├── Pretty.hs
│   │   │   ├── PrettyOps.hs
│   │   │   ├── PrettyNN.hs
│   │   │   └── PrettyControlFlow.hs
│   │   ├── Runtime/
│   │   │   ├── EndToEnd*.hs   # CPU E2E test modules
│   │   │   ├── EndToEndGPU.hs # GPU availability test
│   │   │   ├── Buffer.hs
│   │   │   ├── BufferGPU.hs   # GPU buffer integration tests
│   │   │   ├── Async.hs
│   │   │   ├── AsyncGPU.hs    # GPU async tests
│   │   │   ├── MultiGPU.hs    # Multi-GPU inference scaling tests
│   │   │   └── Errors.hs
│   │   └── Utils.hs
│   └── Main.hs
├── hhlo.cabal
├── pjrt_script.sh            # Downloads PJRT plugins
├── setup_gpu_env.sh          # Auto-configures LD_LIBRARY_PATH for GPU
└── README.md
```
The doc/ directory contains detailed design documents:
| Document | Contents |
|---|---|
| `implementation-design.md` | Four-layer architecture and design decisions |
| `progress-and-remaining-work.md` | Current status, completed features, and backlog |
| `test-suite-documentation.md` | Test catalog and tier descriptions |
MIT License — see LICENSE.