
overshiki/hhlo


HHLO — Haskell Frontend for StableHLO

HHLO is a Haskell library and runtime for building, compiling, and executing machine learning programs targeting StableHLO, the portable, versioned intermediate representation of the OpenXLA ecosystem.

Instead of replicating JAX's Python-based tracing infrastructure, HHLO generates StableHLO MLIR text directly from Haskell and compiles it to CPU or GPU via the PJRT plugin interface.


Design

HHLO is structured in four layers:

┌─────────────────────────────────────┐
│  EDSL (HHLO.EDSL.Ops)               │  Type-safe frontend: add, matmul, relu, etc.
├─────────────────────────────────────┤
│  IR Builder (HHLO.IR.Builder)       │  Stateful monad for constructing MLIR
├─────────────────────────────────────┤
│  Pretty Printer (HHLO.IR.Pretty)    │  Emits StableHLO MLIR text
├─────────────────────────────────────┤
│  PJRT Runtime (HHLO.Runtime.*)      │  Compile → Execute on CPU or GPU
└─────────────────────────────────────┘
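The following toy sketch shows the boundary between the IR Builder and the Pretty Printer. A small operation record is rendered into one line of StableHLO, in the same form as the Quick Start output. The types (DType, TensorTy, Op) and functions (renderTy, renderOp) are hypothetical simplifications, not the actual definitions in HHLO.IR.AST or HHLO.IR.Pretty.

```haskell
{-# LANGUAGE OverloadedStrings #-}
import qualified Data.Text as T
import qualified Data.Text.IO as TIO

-- Hypothetical, simplified IR types (not HHLO's real AST).
data DType = F32 | F64

data TensorTy = TensorTy [Int] DType

data Op = Op
  { opResult :: T.Text    -- SSA name of the result, e.g. "%0"
  , opName   :: T.Text    -- op name, e.g. "stablehlo.add"
  , opArgs   :: [T.Text]  -- SSA names of the operands
  , opTy     :: TensorTy  -- shared operand/result type (element-wise ops only)
  }

-- Render "tensor<2x2xf32>"; a rank-0 tensor renders as "tensor<f32>".
renderTy :: TensorTy -> T.Text
renderTy (TensorTy dims dt) =
  "tensor<" <> T.intercalate "x" (map (T.pack . show) dims ++ [dtype dt]) <> ">"
  where
    dtype F32 = "f32"
    dtype F64 = "f64"

-- Render one operation with its functional type signature.
renderOp :: Op -> T.Text
renderOp (Op res name args ty) =
  res <> " = " <> name <> " " <> T.intercalate ", " args
    <> " : (" <> T.intercalate ", " (map (const t) args) <> ") -> " <> t
  where
    t = renderTy ty

main :: IO ()
main = TIO.putStrLn (renderOp (Op "%0" "stablehlo.add" ["%arg0", "%arg1"] (TensorTy [2, 2] F32)))
```

This prints the same stablehlo.add line that appears inside the Quick Start module output.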

Text Emission + PJRT

The library emits StableHLO MLIR text directly and hands it to PJRT_Client_Compile. This is the same path used by JAX's C++ backend and avoids the heavy dependency of building LLVM/MLIR from source.

Phantom Types

Every tensor carries its shape and dtype as phantom type parameters:

Tensor '[2, 3] 'F32   -- 2×3 matrix of Float32

Shape constraints for matmul, broadcast, and conv are checked by GHC at compile time via type families.
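The following simplified sketch shows how phantom shape parameters reject a mismatched matmul before anything runs. It relies on type-level naturals and unification of the shared inner dimension k, not on the type families HHLO uses. The Tensor type here is hypothetical: it stores host data as a row-major [Float].

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables, TypeApplications #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

data DType = F32 | F64

-- Hypothetical tensor: row-major host data; shape and dtype are phantoms.
newtype Tensor (shape :: [Nat]) (dt :: DType) = Tensor [Float]
  deriving Show

-- k appears in both argument types, so the inner dimensions must agree.
matmul :: forall m k n dt. (KnownNat m, KnownNat k, KnownNat n)
       => Tensor '[m, k] dt -> Tensor '[k, n] dt -> Tensor '[m, n] dt
matmul (Tensor a) (Tensor b) =
  Tensor [ sum [ a !! (i * kk + p) * b !! (p * nn + j) | p <- [0 .. kk - 1] ]
         | i <- [0 .. mm - 1], j <- [0 .. nn - 1] ]
  where
    -- Recover the dimensions from the type-level naturals.
    mm = fromIntegral (natVal (Proxy @m)) :: Int
    kk = fromIntegral (natVal (Proxy @k)) :: Int
    nn = fromIntegral (natVal (Proxy @n)) :: Int

a23 :: Tensor '[2, 3] 'F32
a23 = Tensor [1, 2, 3, 4, 5, 6]

b32 :: Tensor '[3, 2] 'F32
b32 = Tensor [1, 2, 3, 4, 5, 6]

main :: IO ()
main = print (matmul a23 b32)
-- "matmul a23 a23" does not type-check: it would need k ~ 3 and k ~ 2 at once.
```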

ForeignPtr Finalizers

PJRT buffers and executables are wrapped in ForeignPtrs whose finalizers call PJRT_Buffer_Destroy and PJRT_LoadedExecutable_Destroy when the values are garbage-collected, so you can simply let references drop out of scope without explicit cleanup.
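The following minimal sketch shows the finalizer pattern. Foreign.Concurrent.newForeignPtr attaches a Haskell IO action as the finalizer. RawBuffer, Buffer, and wrapBuffer are hypothetical. Here the finalizer logs a message and frees a placeholder allocation; HHLO would instead call PJRT_Buffer_Destroy, presumably through its C shim. The demo calls finalizeForeignPtr to run the finalizer immediately rather than waiting for the GC, which keeps the output deterministic.

```haskell
import Data.IORef (IORef, modifyIORef, newIORef, readIORef)
import qualified Foreign.Concurrent as FC
import Foreign.ForeignPtr (ForeignPtr, finalizeForeignPtr)
import Foreign.Marshal.Alloc (free, mallocBytes)
import Foreign.Ptr (Ptr)

-- Hypothetical stand-in for the opaque C type behind a PJRT buffer handle.
data RawBuffer

newtype Buffer = Buffer (ForeignPtr RawBuffer)

-- Attach a finalizer that plays the role of PJRT_Buffer_Destroy:
-- here it records that it ran and frees the placeholder memory.
wrapBuffer :: IORef [String] -> Ptr RawBuffer -> IO Buffer
wrapBuffer logRef raw = Buffer <$> FC.newForeignPtr raw destroy
  where
    destroy = do
      modifyIORef logRef ("destroy" :)
      free raw

runDemo :: IO [String]
runDemo = do
  logRef <- newIORef []
  raw <- mallocBytes 16              -- placeholder for a device buffer handle
  Buffer fp <- wrapBuffer logRef raw
  -- Normally the GC runs the finalizer once fp is unreachable;
  -- finalizeForeignPtr forces it now.
  finalizeForeignPtr fp
  readIORef logRef

main :: IO ()
main = runDemo >>= print
```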

Dynamic Output Counts

The runtime queries the compiled executable for its actual number of outputs via PJRT_Executable_NumOutputs instead of guessing or hardcoding a maximum.
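The following sketch uses the standard Foreign.Marshal.Array helpers to show the idea. The output array is sized from the queried count, and exactly that many handles are read back. queryNumOutputs and fillOutputs are hypothetical stand-ins for PJRT_Executable_NumOutputs and the execute call, and plain Int values stand in for buffer handles.

```haskell
import Foreign.Marshal.Array (allocaArray, peekArray, pokeArray)
import Foreign.Ptr (Ptr)

-- Hypothetical stand-in for querying the executable's output count.
queryNumOutputs :: IO Int
queryNumOutputs = pure 3

-- Hypothetical stand-in for a C call that writes one handle per output slot.
fillOutputs :: Int -> Ptr Int -> IO ()
fillOutputs n p = pokeArray p [100 .. 100 + n - 1]

collectOutputs :: IO [Int]
collectOutputs = do
  n <- queryNumOutputs
  -- Allocate exactly n slots instead of a hardcoded maximum.
  allocaArray n $ \p -> do
    fillOutputs n p
    peekArray n p

main :: IO ()
main = collectOutputs >>= print
```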

Async Execution

HHLO.Runtime.Async provides true non-blocking execution: executeAsync returns buffer handles immediately, bufferReady polls for completion, and awaitBuffers blocks until device-side computation finishes.
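The following host-side model illustrates the executeAsync / bufferReady / awaitBuffers contract using one MVar per output buffer. Launching returns the handle at once. A non-blocking tryReadMVar plays the role of the readiness poll, and readMVar blocks like an await. launchAsync, isReady, and awaitAll are hypothetical names, and a forked thread computing relu stands in for the device.

```haskell
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, readMVar, tryReadMVar)
import Data.Maybe (isJust)

-- Hypothetical model of a device buffer whose contents arrive later.
newtype AsyncBuffer = AsyncBuffer (MVar [Float])

-- Like executeAsync: start the work and return the handle immediately.
launchAsync :: [Float] -> IO AsyncBuffer
launchAsync xs = do
  slot <- newEmptyMVar
  _ <- forkIO $ do
    threadDelay 10000                 -- pretend the device is busy (10 ms)
    putMVar slot (map (max 0) xs)     -- relu
  pure (AsyncBuffer slot)

-- Like bufferReady: non-blocking completion check.
isReady :: AsyncBuffer -> IO Bool
isReady (AsyncBuffer slot) = isJust <$> tryReadMVar slot

-- Like awaitBuffers: block until every buffer has been filled.
awaitAll :: [AsyncBuffer] -> IO [[Float]]
awaitAll = mapM (\(AsyncBuffer slot) -> readMVar slot)

main :: IO ()
main = do
  b <- launchAsync [-1, 2, -3, 4]
  r0 <- isReady b          -- usually False: the worker is still sleeping
  [out] <- awaitAll [b]
  r1 <- isReady b          -- True once the result has been written
  print (r0, out, r1)
```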

Device Enumeration & Selection

HHLO.Runtime.Device lets you discover and select specific GPUs at runtime:

addressableDevices api client        -- list all devices
deviceKind api dev                   -- "cpu" or "NVIDIA GeForce RTX 5090"
defaultGPUDevice api client          -- first non-CPU device
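The following pure sketch captures the selection rule described for defaultGPUDevice: take the first device whose kind is not "cpu". The Device record and firstNonCpu are hypothetical. The real function operates on PJRT device handles and may use a different test.

```haskell
import Data.Char (toLower)
import Data.List (find)

-- Hypothetical device record; in HHLO the kind string would come from deviceKind.
data Device = Device { devId :: Int, devKind :: String } deriving (Show, Eq)

-- First device whose kind is not "cpu" (case-insensitive).
firstNonCpu :: [Device] -> Maybe Device
firstNonCpu = find ((/= "cpu") . map toLower . devKind)

main :: IO ()
main = print (firstNonCpu [Device 0 "cpu", Device 1 "NVIDIA GeForce RTX 5090"])
```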

Multi-GPU Inference Scaling

HHLO.Runtime.Execute provides executeReplicas for running the same compiled model concurrently across multiple GPUs:

exec <- compileWithOptions api client mlirText
    (defaultCompileOptions { optNumReplicas = numDevs })

-- Launch independent forward passes on all GPUs
executeReplicas api exec
    [ (gpu0, [bufA0, bufB0])
    , (gpu1, [bufA1, bufB1])
    , ...
    ]
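executeReplicas launches one forward pass per (device, inputs) pair. The host-side fan-out and fan-in can be sketched with base concurrency primitives alone: one thread per job, one MVar per result, and results collected in input order. runReplicas and forward are hypothetical. A dot product stands in for the compiled model, and the device ID is carried along but unused here.

```haskell
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Exception (SomeException, evaluate, try)

-- Hypothetical per-device work: a dot product stands in for a forward pass.
forward :: ([Float], [Float]) -> Float
forward (xs, ws) = sum (zipWith (*) xs ws)

-- One thread per (deviceId, inputs) job; results come back in job order,
-- with any exception captured per replica instead of killing the others.
runReplicas :: [(Int, ([Float], [Float]))] -> IO [Either SomeException Float]
runReplicas jobs = do
  slots <- mapM launch jobs
  mapM takeMVar slots
  where
    launch (_deviceId, inputs) = do
      slot <- newEmptyMVar
      _ <- forkIO (try (evaluate (forward inputs)) >>= putMVar slot)
      pure slot

main :: IO ()
main = runReplicas [(0, ([1, 2], [3, 4])), (1, ([5, 6], [7, 8]))] >>= mapM_ print
```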

Multi-Result Operations

The AST Operation type supports multiple results, enabling ops like stablehlo.rng_bit_generator and multi-value control flow:

-- Two-result operation
(newState, output) <- rngBitGenerator state

Multi-Value Control Flow

whileLoop2 / conditional2 carry multiple typed tensors through loops and conditionals without manual packing:

-- Loop with two accumulators: counter and running sum
(resultCounter, resultSum) <- whileLoop2 counter0 sum0
    (\c s -> compare c limit "LT")
    (\c s -> do
        cNext <- add c one
        sNext <- add s cNext
        returnTuple2 cNext sNext)
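The following host-side reference model in plain Haskell serves as a sanity check on what the loop above computes. It assumes counter0 = 0, sum0 = 0, one = 1, and limit = 5, none of which are given in the snippet. whileLoop2Ref is a hypothetical helper, not part of HHLO. With these values the loop ends with a counter of 5 and a sum of 1 + 2 + 3 + 4 + 5 = 15.

```haskell
-- Reference semantics: keep applying body to (c, s) while cond holds.
whileLoop2Ref :: (a -> b -> Bool) -> (a -> b -> (a, b)) -> a -> b -> (a, b)
whileLoop2Ref cond body = go
  where
    go c s
      | cond c s  = uncurry go (body c s)
      | otherwise = (c, s)

main :: IO ()
main = print (whileLoop2Ref (\c _ -> c < limit) step 0 0)
  where
    limit = 5 :: Int
    -- Mirrors the EDSL body: cNext = c + 1, sNext = s + cNext.
    step c s = let cNext = c + 1 in (cNext, s + cNext)
```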

Random Number Generation

Three RNG primitives are exposed in the EDSL:

uniform  <- rngUniform a b      -- uniform in [a, b)
normal   <- rngNormal            -- standard normal (mean 0, std 1)
(newSt, bits) <- rngBitGenerator state   -- Threefry bit generator

Installation

System Requirements

  • GHC 9.6+ and Cabal 3.10+
  • Linux x86_64 (other platforms supported by PJRT artifacts may work)
  • curl, tar, and standard C toolchain (gcc or clang)
  • libstdc++ and libdl (usually present on Linux)

Download PJRT Plugins

Run the provided script to download prebuilt PJRT plugins:

./pjrt_script.sh

This downloads libpjrt_cpu.so from the zml/pjrt-artifacts nightly builds into deps/pjrt/. If you have an NVIDIA GPU with nvidia-smi available, the CUDA plugin is also fetched automatically.

Build the Project

cabal build all

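This compiles the library, the hhlo-demo executable, and the test suite. The example executables are built only when the examples flag is enabled (cabal build all --flag=examples).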


Usage

CPU (works out of the box)

cabal run example-add --flag=examples
cabal test

Note: All example-* executables are guarded by the examples flag in hhlo.cabal (defaults to False). Append --flag=examples to every cabal run example-* command.

GPU (requires runtime libraries)

The PJRT CUDA plugin depends on NVIDIA runtime libraries: cuDNN, NCCL, and NVSHMEM. These are commonly available via conda, pip, or system packages.

If you already have them (e.g. via PyTorch or JAX installations), simply run:

./setup_gpu_env.sh
source ~/.bashrc

This idempotent script locates the libraries and appends the matching LD_LIBRARY_PATH entries to ~/.bashrc. After that, GPU examples work directly:

cabal run example-gpu-add --flag=examples
cabal run example-gpu-matmul-bench --flag=examples
cabal run example-multi-gpu-inference --flag=examples

EDSL Quick Start

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}

import HHLO.Core.Types
import HHLO.EDSL.Ops
import HHLO.IR.AST (FuncArg(..), TensorType(..))
import HHLO.IR.Builder
import HHLO.IR.Pretty
import qualified Data.Text.IO as T   -- putStrLn for Text lives in Data.Text.IO

-- Build a program: c = a + b on two 2×2 Float32 inputs
program :: Module
program = moduleFromBuilder @'[2,2] @'F32 "main"
    [ FuncArg "a" (TensorType [2, 2] F32)
    , FuncArg "b" (TensorType [2, 2] F32)
    ]
    $ do
        a <- arg
        b <- arg
        c <- add a b
        return c

-- Print the generated StableHLO MLIR text
main :: IO ()
main = T.putStrLn (render program)

Output:

module {
  func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
      %0 = stablehlo.add %arg0, %arg1 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
      return %0 : tensor<2x2xf32>
  }
}

Running the Demo

cabal run hhlo-demo

The demo builds a stablehlo.add program via the EDSL, compiles it with PJRT CPU, creates F32 input buffers, executes, and reads back the result:

=== HHLO End-to-End Demo ===
Loading PJRT CPU plugin...
Plugin loaded.
...
Result: [6.0,8.0,10.0,12.0]
SUCCESS: Results match expected values!

Running Examples

Standalone examples are provided in examples/:

#   Command                                                Description
1   cabal run example-add --flag=examples                  Element-wise c = a + b
2   cabal run example-matmul --flag=examples               2×3 @ 3×2 matrix multiply
3   cabal run example-chain-ops --flag=examples            (a + b) * (a - b)
4   cabal run example-async --flag=examples                Async executeAsync + relu
5   cabal run example-mlp --flag=examples                  2-layer MLP
6   cabal run example-mlp-batched --flag=examples          Batched MLP
7   cabal run example-tuple --flag=examples                Multi-result func.func
8   cabal run example-reduce --flag=examples               reduceSum over all dimensions
9   cabal run example-softmax --flag=examples              1-D and batched 2-D softmax
10  cabal run example-conv2d --flag=examples               NHWC conv2d
11  cabal run example-batch-norm --flag=examples           Batch norm inference
12  cabal run example-while --flag=examples                whileLoop count-up
13  cabal run example-conditional --flag=examples          conditional if-then-else
14  cabal run example-gather --flag=examples               gather rows from matrix
15  cabal run example-scatter --flag=examples              scatter replace into vector
16  cabal run example-slice --flag=examples                slice sub-array extraction
17  cabal run example-pad --flag=examples                  pad with edge/interior padding
18  cabal run example-dynamic-slice --flag=examples        dynamicSlice runtime indices
19  cabal run example-sort --flag=examples                 sort 1-D ascending
20  cabal run example-select --flag=examples               Element-wise ternary select
21  cabal run example-map --flag=examples                  map with custom computation
22  cabal run example-new-ops-smoke-test --flag=examples   Smoke test for newer ops
23  cabal run example-resnet --flag=examples               ResNet-18 toy (8×8 input)
24  cabal run example-alexnet --flag=examples              AlexNet toy (16×16 input)
25  cabal run example-transformer --flag=examples          Transformer encoder (1×4×16)
26  cabal run example-unet --flag=examples                 UNet segmentation toy (16×16)
27  cabal run example-gpu-add --flag=examples              GPU smoke test
28  cabal run example-gpu-matmul-bench --flag=examples     GPU 4096×4096 benchmark
29  cabal run example-multi-gpu-inference --flag=examples  Multi-GPU concurrent matmul
30  cabal run example-rng-uniform --flag=examples          rngUniform random floats [0,1)
31  cabal run example-rng-normal --flag=examples           rngNormal standard normal distribution
32  cabal run example-rng-bit-generator --flag=examples    rngBitGenerator Threefry PRNG
33  cabal run example-multi-value-loop --flag=examples     whileLoop2 with two loop-carried values

Tests

CPU Tests (default)

cabal test

Runs 124 tests across three tiers:

  • Tier 1 — Golden tests — Verify rendered MLIR text for EDSL ops, IR constructs, NN layers, and control flow.
  • Tier 2 — End-to-end runtime tests — Load the PJRT CPU plugin, compile StableHLO programs, execute them, and verify numerical results. Covers arithmetic, matmul, reductions, data movement, and NN ops.
  • Tier 3 — Runtime integration tests — Buffer metadata queries, async execution, and error handling.

GPU Tests

HHLO_TEST_GPU=1 cabal test

Runs all 124 CPU tests plus 6 additional GPU integration tests, grouped into four modules:

  • EndToEnd.GPU — GPU availability and device enumeration
  • Runtime.BufferGPU — Buffer round-trip and metadata queries on GPU
  • Runtime.AsyncGPU — Async execution and bufferReady polling on GPU
  • Runtime.MultiGPU — Concurrent executeReplicas across all GPUs

Sample output:

HHLO Tests
  EDSL.Ops
    Binary element-wise
      add:                            OK
      ...
  EndToEnd.Arithmetic
    relu:                             OK (0.02s)
    ...
  Runtime.Buffer
    buffer round-trip f32:            OK
  Runtime.Async
    buffer ready after sync execute:  OK (0.02s)
  EndToEnd.GPU
    gpu available:                    OK
  Runtime.BufferGPU
    gpu buffer round-trip f32:        OK
  Runtime.AsyncGPU
    gpu executeAsync + await:         OK
  Runtime.MultiGPU
    execute replicas on all GPUs:     OK

All 130 tests passed (16.27s)

Project Structure

.
├── app/                    # hhlo-demo executable
├── cbits/                  # C shim around PJRT C API
│   ├── pjrt_c_api.h        # Upstream PJRT header
│   ├── pjrt_shim.c         # Thin wrapper exposing flat C functions
│   └── pjrt_shim.h         # C header for the shim
├── deps/
│   └── pjrt/               # Downloaded PJRT plugins (.so files)
│       └── lib_symlinks/   # Compatibility symlinks for missing library versions
├── doc/                    # Architecture and design documents
├── examples/               # Standalone example programs (01–33)
├── src/HHLO/
│   ├── Core/Types.hs       # DType, Shape, HostType type families
│   ├── IR/
│   │   ├── AST.hs          # MLIR AST (Operation, Function, Module)
│   │   ├── Builder.hs      # Stateful Builder monad + Tensor/Tuple GADTs
│   │   └── Pretty.hs       # MLIR text pretty-printer
│   ├── EDSL/Ops.hs         # Type-safe frontend ops (50+ ops)
│   └── Runtime/
│       ├── PJRT/
│       │   ├── FFI.hs      # C FFI declarations
│       │   ├── Types.hs    # Opaque pointer newtypes + buffer type constants
│       │   ├── Error.hs    # PJRT error handling
│       │   └── Plugin.hs   # Backend-agnostic plugin loading (withPJRT)
│       ├── Device.hs       # Device enumeration & selection
│       ├── Compile.hs      # MLIR → PJRT executable (with `CompileOptions`)
│       ├── Execute.hs      # Synchronous + device-targeted + multi-GPU replica execution
│       ├── Async.hs        # Non-blocking execution with PJRT_Event
│       └── Buffer.hs       # Host↔device buffer transfers + metadata queries
├── test/
│   ├── Test/
│   │   ├── EDSL/Ops.hs
│   │   ├── IR/
│   │   │   ├── Builder.hs
│   │   │   ├── Pretty.hs
│   │   │   ├── PrettyOps.hs
│   │   │   ├── PrettyNN.hs
│   │   │   └── PrettyControlFlow.hs
│   │   ├── Runtime/
│   │   │   ├── EndToEnd*.hs       # CPU E2E test modules
│   │   │   ├── EndToEndGPU.hs     # GPU availability test
│   │   │   ├── Buffer.hs
│   │   │   ├── BufferGPU.hs       # GPU buffer integration tests
│   │   │   ├── Async.hs
│   │   │   ├── AsyncGPU.hs        # GPU async tests
│   │   │   ├── MultiGPU.hs        # Multi-GPU inference scaling tests
│   │   │   └── Errors.hs
│   │   └── Utils.hs
│   └── Main.hs
├── hhlo.cabal
├── pjrt_script.sh          # Downloads PJRT plugins
├── setup_gpu_env.sh        # Auto-configures LD_LIBRARY_PATH for GPU
└── README.md

Architecture Docs

The doc/ directory contains detailed design documents:

Document                        Contents
implementation-design.md        Four-layer architecture and design decisions
progress-and-remaining-work.md  Current status, completed features, and backlog
test-suite-documentation.md     Test catalog and tier descriptions

License

MIT License — see LICENSE.
