In [1]:
import jax
import jax.numpy as jnp
from jax._src.interpreters.partial_eval import (
  Offloadable as Offloadable,
)
import functools

def policy(prim, *avals, **params) -> Offloadable:
  """Checkpoint policy that offloads every residual from device to pinned host.

  Used as the `policy` of a remat/checkpoint decorator. The primitive, its
  abstract values and its params are all ignored, so the very same decision
  (move from 'device' to 'pinned_host') is returned unconditionally.

  Returns:
    An `Offloadable` with src='device' and dst='pinned_host'.
  """
  del prim, avals, params  # The decision does not depend on the primitive.
  destination = 'pinned_host'
  return Offloadable(src='device', dst=destination)

@functools.partial(jax.checkpoint, policy=policy)  # type: ignore
def f(x):
  """Apply sine twice, rematerialized under the offloading `policy`.

  `jax.checkpoint` (of which `jax.remat` is an alias) consults `policy` to
  decide how values needed by the backward pass are kept; here they are
  offloaded to pinned host memory.
  """
  inner = jnp.sin(x)
  return jnp.sin(inner)

def g(x):
  """Scalar loss: the sum of `f(x)`, so `jax.grad(g)` is well-defined."""
  return jnp.sum(f(x))

# Example usage: lower (trace to StableHLO, without compiling yet) the jitted
# gradient of `g` on a 16-element vector of ones, then print the module text so
# the host-offloading annotations on the residuals can be inspected.
x = jnp.ones(shape=(16,))  # Example input array
lowered = jax.jit(jax.grad(g)).lower(x)
print(lowered.as_text())


module @jit_g attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = "default"}) -> (tensor<16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.sine %arg0 : tensor<16xf32>
    %1 = stablehlo.cosine %arg0 : tensor<16xf32>
    %2 = stablehlo.custom_call @annotate_device_placement(%1) {backend_config = "", has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %3 = stablehlo.cosine %0 : tensor<16xf32>
    %4 = stablehlo.custom_call @annotate_device_placement(%3) {backend_config = "", has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %6:3 = stablehlo.optimization_barrie

In [2]:
# Compile the lowered computation and display XLA's static cost estimate
# (flops, transcendentals, bytes accessed per operand/output). The call is
# left as the cell's final expression so Jupyter renders its result.
compiled = lowered.compile()
compiled.cost_analysis()

[{'bytes accessed0{}': 7680.0,
  'bytes accessed2{}': 512.0,
  'bytes accessed1{}': 512.0,
  'optimal_seconds': 1.7197404389435178e-08,
  'utilization0{}': 17.0,
  'flops': 32.0,
  'utilization2{}': 1.0,
  'bytes accessed': 16896.0,
  'transcendentals': 48.0,
  'bytes accessedout{}': 8192.0,
  'utilization1{}': 2.0,
  'bytes accessedout{0}': 512.0,
  'bytes accessedout{1}': 512.0}]