Better support for symbolic (e.g. jax tracer) values in non-index coordinates #10924

@mjwillson

Is your feature request related to a problem?

Hello! xarray_jax maintainer here.

I am trying to decide whether it is feasible to rely on using symbolic arrays (such as jax tracers used during JIT compilation) as coordinate data in xarray.

For index coordinates this doesn't work at all, because xarray requires dynamic access to the data in the index to do alignment. So our approach in xarray_jax is to treat index coordinates as static data. I think this is acceptable: even if we could do index alignment with tracers, the result would be variable-shape which jax.jit can't handle without re-tracing anyway.

For non-index coordinates, though, it is possible to use jax tracers! However, these coordinates have a habit of disappearing due to failed alignment checks, for example in arithmetic operations:

import xarray as xr
import jax
import jax.numpy as jnp

def foo(coord):
   a = xr.DataArray(data=jnp.ones(3), dims=['x'], coords={'foo': (['x'], coord)})
   b = xr.DataArray(data=jnp.ones(3), dims=['x'], coords={'foo': (['x'], coord)})
   print((a + b).coords)

foo(jnp.arange(3))
# prints:
# Coordinates:
#    foo      (x) int32 12B ...

jax.jit(foo)(jnp.arange(3))
# prints:
# Coordinates:
#    *empty*

The coordinate is (correctly) preserved when not JIT-ing, because xarray is able to do an equality comparison to check that the coordinate is compatible.

When JIT-ing, however, the coordinate array will be a jax Tracer, and the equality comparison will fail with a jax.errors.TracerBoolConversionError (or similar). xarray catches this under the hood, since TracerBoolConversionError subclasses TypeError, and then treats the failed comparison as not-equal, which results in the coordinate being dropped.
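The mechanism can be illustrated without jax or xarray. The following is only a minimal sketch: `FakeTracer`, `FakeTracerBoolError`, `equivalent` and `merge_coords` are hypothetical names made up for illustration, not xarray's or jax's actual code. It assumes only what is described above, namely that the comparison error subclasses TypeError and that a caught TypeError is treated as "not equal".

```python
class FakeTracerBoolError(TypeError):
    """Hypothetical stand-in for jax.errors.TracerBoolConversionError."""


class FakeTracer:
    """Hypothetical symbolic value: == works, but truthiness is undefined."""

    def __eq__(self, other):
        # Comparing symbolic values yields another symbolic value.
        return FakeTracer()

    def __bool__(self):
        raise FakeTracerBoolError("symbolic value has no concrete truth value")


def equivalent(a, b):
    """Equality check that swallows TypeError and reports 'not equal'."""
    try:
        return bool(a == b)
    except TypeError:
        return False


def merge_coords(left, right):
    """Keep only the coordinates whose values compare as equivalent."""
    return {
        name: value
        for name, value in left.items()
        if name in right and equivalent(value, right[name])
    }


# Concrete values compare fine, so the coordinate survives.
concrete = merge_coords({"foo": 1}, {"foo": 1})

# The very same symbolic value on both sides is still dropped, because
# the failed truthiness conversion is silently treated as "not equal".
tracer = FakeTracer()
symbolic = merge_coords({"foo": tracer}, {"foo": tracer})
```

Here `concrete` keeps `foo` while `symbolic` ends up empty, mirroring the difference between the non-JIT and JIT outputs above.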

This isn't entirely consistent with behaviour elsewhere in xarray.

For example, when doing an explicit merge it is possible to set compat='override', which ensures that coordinates aren't tested for equality, so jax tracers survive. And I believe compat='override' is to become the default soon.
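As a small illustration of what compat='override' does in an explicit merge, here is a sketch using plain numpy data rather than tracers (an assumption made so it runs without jax). The two `foo` coordinates deliberately disagree, to make it visible that no comparison takes place.

```python
import numpy as np
import xarray as xr

# Two datasets whose non-index coordinate 'foo' deliberately disagrees.
a = xr.Dataset(coords={"foo": ("x", np.array([0, 1, 2]))})
b = xr.Dataset(coords={"foo": ("x", np.array([9, 9, 9]))})

# compat='override' skips the equality comparison and keeps the first 'foo'.
merged = xr.merge([a, b], compat="override")

# compat='equals' does compare the values, so the conflict is reported.
try:
    xr.merge([a, b], compat="equals")
    conflict_detected = False
except xr.MergeError:
    conflict_detected = True
```

Because 'override' never evaluates the comparison, there is nothing for a tracer to fail on, which is why symbolic coordinates can pass through this path.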

xr.align also lets jax Tracers survive:

def foo(coord):
   a = xr.DataArray(data=jnp.ones(3), dims=['x'], coords={'foo': (['x'], coord)})
   b = xr.DataArray(data=jnp.ones(3), dims=['x'], coords={'foo': (['x'], coord)})
   a, b = xr.align(a, b)
   print(a.coords)
   print(b.coords)

jax.jit(foo)(jnp.arange(3))
# prints:
# Coordinates:
#     foo      (x) int32 12B ...
# Coordinates:
#     foo      (x) int32 12B ...

Overall, while I can see why some might want them, these compatibility checks on non-index coordinates seem like an awkward niche behaviour. They are configurable in some places but not others and, at least once compat='override' becomes the default, will also have different defaults in different places IIUC, which is not ideal. And of course they throw a bit of a spanner in the works for using jax here.

Describe the solution you'd like

Three suggestions:

  1. Add an xarray.set_options(arithmetic_compat='override') setting which controls this, and make 'override' the default for it at the same time as compat='override' becomes the default more widely. Note there is already an arithmetic_join setting which controls how index coordinates are joined/aligned, but this doesn't give any control over these compatibility checks on non-index coordinates.
    One might also consider exposing the compat option on xr.align, and in any other places where compatibility checks happen (or, for consistency, ought to happen), to give control over this.

  2. Allow jax.errors.TracerBoolConversionError to bubble up from generic equality-testing code (or perhaps catch and reraise it as something less jax-specific like xr.SymbolicComparisonError) rather than catching it and treating it as unequal. Then in coordinate compatibility-checking code, catch this error and fall back on compat='override' behavior in this specific case. This has the advantage that non-index-coord compatibility checks don't have to be disabled across the board, only for jax Tracers. It would require some jax-specific special case in xarray though.

  3. Just get rid of alignment checks on non-index coordinates entirely, if no one feels strongly about them. They are already somewhat inconsistent both in configurability and in default behaviour. In many cases it's a safe assumption that any non-index coordinates are a fixed function of the index coordinates, and so it's sufficient to check for alignment of the index coordinates. And perhaps in other cases it would be OK to make any checking the user's responsibility. I realise this would be backwards-incompatible though.
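To make suggestion 2 a little more concrete, here is a rough sketch of the control flow in plain Python. Everything in it is hypothetical: `SymbolicComparisonError` borrows the name floated above, and `FakeTracer`, `strict_equivalent` and `check_compat` are invented for illustration only. It is not a proposal for where this would live in xarray.

```python
class SymbolicComparisonError(TypeError):
    """Hypothetical error raised when a symbolic comparison cannot be decided."""


class FakeTracer:
    """Hypothetical symbolic value: == works, but truthiness is undefined."""

    def __eq__(self, other):
        return FakeTracer()

    def __bool__(self):
        raise SymbolicComparisonError("cannot decide equality of symbolic values")


def strict_equivalent(a, b):
    """Generic equality test that lets SymbolicComparisonError bubble up."""
    return bool(a == b)


def check_compat(left, right):
    """Return the coordinate value to keep, or None if the values conflict."""
    try:
        return left if strict_equivalent(left, right) else None
    except SymbolicComparisonError:
        # Equality is undecidable for symbolic values, so fall back to
        # 'override'-style behaviour: keep the first value unchecked.
        return left


tracer = FakeTracer()
kept_equal = check_compat(1, 1)              # concrete and equal: kept
kept_conflict = check_compat(1, 2)           # concrete and unequal: dropped
kept_symbolic = check_compat(tracer, tracer)  # symbolic: kept via fallback
```

The point of the sketch is that concrete conflicts are still detected, and only the undecidable symbolic case falls back to keeping the first value.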

Describe alternatives you've considered

See above.

Additional context

No response
