In [1]:
import os
# Configure XLA/JAX through environment variables; this happens before `jax` is imported.
ncpu = 1
# Build `XLA_FLAGS` in a single assignment. Assigning the variable twice would
# silently discard the device-count flag written by the first assignment.
# NOTE(review): the last token has no leading `--`; it is kept exactly as it was — confirm it is intended.
os.environ["XLA_FLAGS"] = " ".join([
    f"--xla_force_host_platform_device_count={ncpu}",
    "--xla_cpu_multi_thread_eigen=false",
    "intra_op_parallelism_threads=1",
])
os.environ['JAX_PLATFORMS'] = "cpu"

import numpy as np
import jax
import sys
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

# This must be enabled when `make_jaxpr` is called, because otherwise we get problems.
jax.config.update("jax_enable_x64", True)

import dace
#from dace.transformation.auto.auto_optimize import auto_optimize

In [2]:
from JaxprToSDFG import  JaxprToSDFG

## Demo Input
This is the input we are using.

In [3]:
# Fix the seed of NumPy's global random generator so that the random demo
# inputs are identical on every run of the notebook.
np.random.seed(42)

# Extents of the demo arrays.
N1, N2, N3 = 20, 20, 4
A = np.random.rand(N1, N2).astype(np.float64)
B = np.random.rand(N1, N2).astype(np.float64)
C = np.random.rand(N1).astype(np.float64)
D = np.random.rand(N1).astype(np.float64)
E = np.random.rand(N3).astype(np.float64)
F = np.random.rand(N1, N2).astype(np.float64)
G = np.random.rand(N1).astype(np.float64)
H = np.random.rand(N1, N2, N3).astype(np.float64)
# Scalars, each also available wrapped in an array (`SS` has shape (1, 1), `TT` is 0-dimensional).
S = np.float64(3)
SS = np.full((1, 1), S)
T = np.float64(3.14)
TT = np.array(T)
_OUT = np.ones((N1, N2))
_OUT2 = np.ones(N1)
_OUT3 = np.ones(N1 - 1)

# Integer index arrays used by the advanced-indexing functions.
Idx = np.array([0, 2, 3, 6, 7, 10, 11, 19]).astype(np.int32)
Idx2 = np.array([1, 2]).astype(np.int32)

two = np.full((N1, N2), 2.)


## Functions to Transform
Here are the various functions that we transform.
They are ordered by increasing complexity.

In [4]:
# A simple function
def f1(A, B, two):
    """Return `A + B * two`; the factor is passed in as the argument `two`."""
    scaled = B * two
    return A + scaled
#

# Same computation as `f1`, except that the factor `2.0` is a double literal
#  instead of an argument. Casting does not work yet.
def f2(A, B):
    """Return `A + B * 2.0`."""
    doubled = B * 2.0
    return A + doubled
#

# This tests the handling of constants.
def f2_2(E):
    """Return `E` plus a constant float64 array of four values that is created inside the function."""
    offsets = np.array([1, 4, 5, 90], dtype=np.float64)
    return E + offsets
#

def f2_3(A, B, F):
    """Chain `minimum`, `maximum`, `tanh`, `ceil`, `abs` and `sqrt` on the inputs scaled by 100."""
    # The scaled operands are created in the same order as in a single nested expression.
    scaled_f = F * 100.
    scaled_a = A * 100.
    scaled_b = B * 100.
    smaller = jnp.minimum(scaled_a, scaled_b)
    larger = jnp.maximum(scaled_f, smaller)
    squashed = jnp.tanh(larger)
    rounded = jnp.ceil(squashed)
    magnitude = jnp.abs(rounded)
    return jnp.sqrt(magnitude)
#

def f2_4(A, B):
    """Return the result of `jnp.maximum` applied to `A` and `B`."""
    larger = jnp.maximum(A, B)
    return larger
#

# Slicing is a bit more complicated, because the input and the output have different sizes.
def f3(A):
    """Return `A` without its first two entries along the leading axis."""
    tail = slice(2, None)
    return A[tail]
#

# Slicing with a step.
# NOT SUPPORTED
def f3_2(A):
    """Return every second entry of `A`, starting at index 2 and stopping before the last two entries."""
    strided = slice(2, -2, 2)
    return A[strided]
#


# This tests the return mechanism for tuples.
def f4(A, B):
    """Return the pair `(B + 1, A + 2)`."""
    first = B + 1
    second = A + 2
    return first, second
#

# A very simple stencil.
def f5(C):
    """Return `C[i + 2] - C[i]` for every valid `i`."""
    ahead = C[2:]
    behind = C[:-2]
    return ahead - behind
#

# This function is used to test derivatives.
def f6(S: dace.float64):
    """Return `exp(sin(S)**2)`."""
    sine = jnp.sin(S)
    return jnp.exp(sine**2)
#

def f6_2(S: dace.float64, T: dace.float64):
    """Return `exp(sin(S)**2) * cos(T * sin(S))`."""
    # `sin(S)` is deliberately evaluated once per factor, exactly as in the single-expression form.
    exp_factor = jnp.exp(jnp.sin(S)**2)
    cos_factor = jnp.cos(T * jnp.sin(S))
    return exp_factor * cos_factor
#

# This tests broadcasting, i.e. the broadcasting abilities of the simple translator
#  (`JaxprToSDFG/translators/simpleTranslator.py`).
def f7(A: dace.float64[3, 2, 4], B: dace.float64[2, 4]):
    """Return `A + B`; the annotated shapes differ in rank."""
    total = A + B
    return total
#

# Inputs for `f7_2`.
A_f7_2 = np.random.rand(3, 2, 4).astype(np.float64)
B_f7_2 = np.random.rand(3, 2, 1).astype(np.float64)   # Without the trailing one, numpy would complain.


def f7_2(A: dace.float64[3, 2, 4], B: dace.float64[3, 2, 1]):
    """Return `A + B`; the annotated shape of `B` has extent one in its last dimension."""
    total = A + B
    return total
#

# Inputs for `f7_3`; each of them has extent one in a different dimension.
A_f7_3 = np.random.rand(2, 3, 1, 5).astype(np.float64)
B_f7_3 = np.random.rand(2, 1, 4, 5).astype(np.float64)


def f7_3(A: dace.float64[2, 3, 1, 5], B: dace.float64[2, 1, 4, 5]):
    """Return `A + B`."""
    total = A + B
    return total
#


# Advanced indexing.
def f8(C, Idx):
    """Return the entries of `C` selected by the index array `Idx`."""
    picked = C[Idx]
    return picked
#

## They do not work yet, because we are missing concatenate.
# This is also for advanced indexing
def f9(C, D, G):
    """Select from `D` where `C < 0.0` and from `G` everywhere else."""
    is_negative = C < 0.0
    return jnp.where(is_negative, D, G)
#

def f9_2(H, Idx2):
    """Index the first axis of `H` with `Idx2` and the last axis with `Idx2 - 1`, keeping the middle axis."""
    #return H.at[Idx2 , :, Idx2+1].get()
    shifted = Idx2 - 1
    whole_axis = slice(None)
    return H[Idx2, whole_axis, shifted]
#

def f9_3(H, Idx):
    """Index the first axis of `H` with the index array `Idx`, keeping the other two axes."""
    #return H.at[: , Idx, 1:].get()
    whole_axis = slice(None)
    return H[Idx, whole_axis, whole_axis]
#

def f9_4(A, Idx):
    """Return the entries of column 4 of `A` in the rows listed in `Idx`."""
    #return A.at[Idx, 4].get()
    column = 4
    return A[Idx, column]
#


# CURRENTLY UNSUPPORTED IN THE TRANSLATOR
def f9_5(A, Idx):
    """Return the strided window `A[4:10:3, 3:15:2]` through the `.at[...].get()` interface; `Idx` is not used."""
    rows = slice(4, 10, 3)
    cols = slice(3, 15, 2)
    return A.at[rows, cols].get()
#

def f9_6(C, D, G):
    """Select from `C` where `0.25 < D < 0.75` holds and from `G` everywhere else."""
    in_band = jnp.logical_and(0.25 < D, D < 0.75)
    return jnp.where(in_band, C, G)
#
    


    





## Jaxpr
We now transform the functions into Jaxpr.

In [5]:
# Create the translator object; it is called on every Jaxpr below.
t = JaxprToSDFG()

#### `f1`

In [6]:
# Trace `f1` with the demo inputs to obtain its Jaxpr.
f1_jaxpr = jax.make_jaxpr(f1)(A, B, two)

In [7]:
# Show the Jaxpr of `f1`.
print(f1_jaxpr)

{ lambda ; a:f64[20,20] b:f64[20,20] c:f64[20,20]. let
    d:f64[20,20] = mul b c
    e:f64[20,20] = add a d
  in (e,) }


In [8]:
# Translate the Jaxpr of `f1`; the bare name on the last line displays the result in the notebook.
f1_sdfg = t(f1_jaxpr)
f1_sdfg

In [9]:
# Currently it is required to make the renaming manually.
# The value returned by the translated `f1` is bound to `_OUT`.
_OUT = f1_sdfg(A, B, two)

In [10]:
# Reference result computed by calling `f1` directly, compared against the result of the translated function.
resExp = f1(A, B, two)
resDC  = _OUT

# Every entry has to agree up to an absolute error of 10**(-13).
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

#### `f2`

In [11]:
# Trace `f2` with the demo inputs and show its Jaxpr.
f2_jaxpr = jax.make_jaxpr(f2)(A, B)
print(f2_jaxpr)

{ lambda ; a:f64[20,20] b:f64[20,20]. let
    c:f64[20,20] = mul b 2.0
    d:f64[20,20] = add a c
  in (d,) }


In [12]:
# Translate the Jaxpr of `f2`.
f2_sdfg = t(f2_jaxpr)

In [13]:
# Currently it is required to make the renaming manually.
# The value returned by the translated `f2` is bound to `_OUT`.
_OUT = f2_sdfg(A, B)

In [14]:
# Reference result computed by calling `f2` directly, compared against the result of the translated function.
resExp = f2(A, B)
resDC  = _OUT

# Every entry has to agree up to an absolute error of 10**(-13).
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

##### `f2_2`

In [15]:
# Trace `f2_2` with the demo input and show its Jaxpr.
f2_2_jaxpr = jax.make_jaxpr(f2_2)(E)
print(f2_2_jaxpr)

{ lambda a:f64[4]; b:f64[4]. let c:f64[4] = add b a in (c,) }


In [16]:
# Translate the Jaxpr of `f2_2` and compare the translated function against `f2_2` itself.
f2_2_sdfg = t(f2_2_jaxpr)
resExp = f2_2(E)
# Here the input is passed by keyword, under the name `b`.
resDC  = f2_2_sdfg(b=E)

# Every entry has to agree up to an absolute error of 10**(-13).
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

##### `f2_3`


In [17]:
# Trace `f2_3` with the demo inputs and show its Jaxpr.
f2_3_jaxpr = jax.make_jaxpr(f2_3)(A, B, F)
print(f2_3_jaxpr)

{ lambda ; a:f64[20,20] b:f64[20,20] c:f64[20,20]. let
    d:f64[20,20] = mul c 100.0
    e:f64[20,20] = mul a 100.0
    f:f64[20,20] = mul b 100.0
    g:f64[20,20] = min e f
    h:f64[20,20] = max d g
    i:f64[20,20] = tanh h
    j:f64[20,20] = ceil i
    k:f64[20,20] = abs j
    l:f64[20,20] = sqrt k
  in (l,) }


In [18]:
# Translate the Jaxpr of `f2_3` and compare the translated function against `f2_3` itself.
f2_3_sdfg = t(f2_3_jaxpr)
resExp = f2_3(A, B, F)
resDC  = f2_3_sdfg(A, B, F)

# Every entry has to agree up to an absolute error of 10**(-13).
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

##### `f2_4`

In [19]:
# Trace `f2_4` with the demo inputs and show its Jaxpr.
f2_4_jaxpr = jax.make_jaxpr(f2_4)(A, B)
print(f2_4_jaxpr)

{ lambda ; a:f64[20,20] b:f64[20,20]. let c:f64[20,20] = max a b in (c,) }


In [20]:
# Translate the Jaxpr of `f2_4` and compare the translated function against `f2_4` itself.
f2_4_sdfg = t(f2_4_jaxpr)
resExp = f2_4(A, B)
# `f2_4` was traced with exactly two inputs, so the translated function is
# called with `A` and `B` only (no third argument).
resDC  = f2_4_sdfg(A, B)

# Every entry has to agree up to an absolute error of 10**(-13).
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

#### `f3`
Simple slicing.

In [21]:
# Trace `f3` with the demo input and show its Jaxpr.
f3_jaxpr = jax.make_jaxpr(f3)(C)
print(f3_jaxpr)

{ lambda ; a:f64[20]. let
    b:f64[18] = slice[limit_indices=(20,) start_indices=(2,) strides=None] a
  in (b,) }


In [22]:
# Translate the Jaxpr of `f3`.
f3_sdfg = t(f3_jaxpr)

In [23]:
# Compare the translated function against `f3` itself; the translated result is first bound to `_OUTF3`.
resExp = f3(C)
_OUTF3 = f3_sdfg(C)
resDC  = _OUTF3

# Every entry has to agree up to an absolute error of 10**(-13).
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

##### `f3_2`

In [24]:
# Trace `f3_2` with the demo input and show its Jaxpr; this cell only traces, it does not call the translator.
f3_2_jaxpr = jax.make_jaxpr(f3_2)(C)
print(f3_2_jaxpr)

{ lambda a:i32[8]; b:f64[20]. let
    c:i32[8,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(8, 1)] a
    d:f64[8] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1,)
      unique_indices=True
    ] b c
  in (d,) }


#### `f4`

In [25]:
# Trace `f4` with the demo inputs and show its Jaxpr.
f4_jaxpr = jax.make_jaxpr(f4)(A, B)
print(f4_jaxpr)

{ lambda ; a:f64[20,20] b:f64[20,20]. let
    c:f64[20,20] = add b 1.0
    d:f64[20,20] = add a 2.0
  in (c, d) }


In [26]:
# Translate the Jaxpr of `f4`.
f4_sdfg = t(f4_jaxpr)

In [27]:
# `f4` returns two values; both calls are unpacked into two results each.
resExp1, resExp2 = f4(A, B)
resDC1, resDC2 = f4_sdfg(A, B)

# Each output is compared separately, up to an absolute error of 10**(-13).
assert np.all(np.abs(resDC1 - resExp1) <= 10**(-13))
assert np.all(np.abs(resDC2 - resExp2) <= 10**(-13))

#### `f5`

In [28]:
# Trace `f5` with the demo input and show its Jaxpr.
f5_jaxpr = jax.make_jaxpr(f5)(C)
print(f5_jaxpr)

{ lambda ; a:f64[20]. let
    b:f64[18] = slice[limit_indices=(20,) start_indices=(2,) strides=None] a
    c:f64[18] = slice[limit_indices=(18,) start_indices=(0,) strides=None] a
    d:f64[18] = sub b c
  in (d,) }


In [29]:
# Translate the Jaxpr of `f5`.
f5_sdfg = t(f5_jaxpr)

In [30]:
# Compare the translated function against `f5` itself.
resExp = f5(C)
resDC  = f5_sdfg(C)

# Every entry has to agree up to an absolute error of 10**(-13).
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

#### `f6`
These are the derivative tests.

In [31]:
# Trace `f6` with the scalar demo input and show its Jaxpr.
f6_jaxpr = jax.make_jaxpr(f6)(S)
print(f6_jaxpr)

{ lambda ; a:f64[]. let
    b:f64[] = sin a
    c:f64[] = integer_pow[y=2] b
    d:f64[] = exp c
  in (d,) }


In [32]:
# Translate the Jaxpr of `f6`.
f6_sdfg = t(f6_jaxpr)

In [33]:
# Compare the translated function against `f6` itself, using the scalar `S`.
# The trailing `#SS)` fragments are leftovers of an alternative call with `SS`.
resExp = f6(S) #SS)
resDC  = f6_sdfg(S) #SS)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

In [34]:
# Trace the second derivative of `f6` (`jax.grad` applied twice), show its Jaxpr and translate it.
f6dd_jaxpr = jax.make_jaxpr(jax.grad(jax.grad(f6)))(S)
print(f6dd_jaxpr)
f6dd_sdfg = t(f6dd_jaxpr)

{ lambda ; a:f64[]. let
    b:f64[] = sin a
    c:f64[] = cos a
    d:f64[] = cos a
    e:f64[] = sin a
    f:f64[] = integer_pow[y=2] b
    g:f64[] = integer_pow[y=1] b
    h:f64[] = mul 2.0 g
    i:f64[] = integer_pow[y=1] b
    j:f64[] = integer_pow[y=0] b
    k:f64[] = mul 1.0 j
    l:f64[] = mul 2.0 i
    m:f64[] = exp f
    n:f64[] = mul 1.0 m
    o:f64[] = mul n l
    _:f64[] = mul o d
    p:f64[] = mul o 1.0
    q:f64[] = mul 1.0 d
    r:f64[] = mul n q
    s:f64[] = mul 2.0 r
    t:f64[] = mul s k
    u:f64[] = mul q l
    v:f64[] = mul 1.0 u
    w:f64[] = mul v m
    x:f64[] = mul w h
    y:f64[] = add_any t x
    z:f64[] = neg p
    ba:f64[] = mul z e
    bb:f64[] = mul y c
    bc:f64[] = add_any ba bb
  in (bc,) }


In [35]:
# Compare the translated second derivative against the one evaluated directly by JAX.
resExp = jax.grad(jax.grad(f6))(S)
resDC  = f6dd_sdfg(S)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

##### `f6_2`

In [36]:
# Gradient of `f6_2` with respect to both of its arguments (`argnums=(0, 1)`), then traced and translated.
f6_2G = jax.grad(f6_2, argnums=(0, 1))
f6_2G_jaxpr = jax.make_jaxpr(f6_2G)(S, T)
f6_2G_sdfg = t(f6_2G_jaxpr)

In [37]:
# Both calls are treated as sequences of results, which are compared pairwise
# up to an absolute error of 10**(-13).
resExp = f6_2G(S, T)
resDC  = f6_2G_sdfg(S, T)
assert all([np.all(np.abs(dc - ex) <= 10**(-13))  for dc, ex in zip(resDC, resExp)])

#### `f7`

In [38]:
# Inputs for `f7`; their shapes match the annotations in the signature of `f7`.
Af7 = np.random.rand(3, 2, 4).astype(np.float64)
Bf7 = np.random.rand(2, 4).astype(np.float64)

# Trace `f7` with these inputs and show its Jaxpr.
f7_jaxpr = jax.make_jaxpr(f7)(Af7, Bf7)
print(f7_jaxpr)

{ lambda ; a:f64[3,2,4] b:f64[2,4]. let
    c:f64[1,2,4] = broadcast_in_dim[broadcast_dimensions=(1, 2) shape=(1, 2, 4)] b
    d:f64[3,2,4] = add a c
  in (d,) }


In [39]:
# Translate the Jaxpr of `f7` and compare the translated function against `f7` itself.
f7_sdfg = t(f7_jaxpr)

resExp = f7(Af7, Bf7)
resDC  = f7_sdfg(Af7, Bf7)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))



##### `f7_2`

In [40]:
# Trace `f7_2` with its dedicated inputs and show its Jaxpr.
f7_2_jaxpr = jax.make_jaxpr(f7_2)(A_f7_2, B_f7_2)
print(f7_2_jaxpr)

{ lambda ; a:f64[3,2,4] b:f64[3,2,1]. let c:f64[3,2,4] = add a b in (c,) }


In [41]:
# Translate the Jaxpr of `f7_2` and compare the translated function against `f7_2` itself.
f7_2_sdfg = t(f7_2_jaxpr)

resExp = f7_2(A_f7_2, B_f7_2)
resDC  = f7_2_sdfg(A_f7_2, B_f7_2)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

##### `f7_3`

In [42]:
# Trace `f7_3` with its dedicated inputs and show its Jaxpr.
f7_3_jaxpr = jax.make_jaxpr(f7_3)(A_f7_3, B_f7_3)
print(f7_3_jaxpr)

{ lambda ; a:f64[2,3,1,5] b:f64[2,1,4,5]. let c:f64[2,3,4,5] = add a b in (c,) }


In [43]:
# Translate the Jaxpr of `f7_3` and compare the translated function against `f7_3` itself.
f7_3_sdfg = t(f7_3_jaxpr)

resExp = f7_3(A_f7_3, B_f7_3)
resDC  = f7_3_sdfg(A_f7_3, B_f7_3)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

#### `f8`

In [44]:
# Trace `f8` and print its Jaxpr, both inside the `jax.disable_jit` context.
with jax.disable_jit(disable=True):
    f8_jaxpr = jax.make_jaxpr(f8)(C, Idx)
    print(f8_jaxpr)

{ lambda ; a:f64[20] b:i32[8]. let
    c:bool[8] = lt b 0
    d:i32[8] = add b 20
    e:i32[8] = select_n c b d
    f:i32[8,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(8, 1)] e
    g:f64[8] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1,)
      unique_indices=False
    ] a f
  in (g,) }


In [45]:
# There is this warning I have to get rid of.
# Translate the Jaxpr of `f8` and compare the translated function against `f8` itself.
f8_sdfg = t(f8_jaxpr)

resExp = f8(C, Idx)
resDC  = f8_sdfg(C, Idx)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))



#### `f9`

In [46]:
# Trace `f9` inside the `jax.disable_jit` context; the Jaxpr is printed after the context is left.
with jax.disable_jit(disable=True):
    f9_jaxpr = jax.make_jaxpr(f9)(C, D, G)
print(f9_jaxpr)

{ lambda ; a:f64[20] b:f64[20] c:f64[20]. let
    d:bool[20] = lt a 0.0
    e:f64[20] = select_n d c b
  in (e,) }


In [47]:
# Translate the Jaxpr of `f9` and compare the translated function against `f9` itself.
f9_sdfg = t(f9_jaxpr)

resExp = f9(C, D, G)
resDC  = f9_sdfg(C, D, G)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

##### `f9_2`

In [48]:
# Trace `f9_2` with the demo inputs and show its Jaxpr.
f9_2_jaxpr = jax.make_jaxpr(f9_2)(H, Idx2)
print(f9_2_jaxpr)

{ lambda ; a:f64[20,20,4] b:i32[2]. let
    c:i32[2] = sub b 1
    d:bool[2] = lt b 0
    e:i32[2] = add b 20
    f:i32[2] = select_n d b e
    g:bool[2] = lt c 0
    h:i32[2] = add c 4
    i:i32[2] = select_n g c h
    j:i32[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] f
    k:i32[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] i
    l:i32[2,2] = concatenate[dimension=1] j k
    m:f64[2,20] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0, 2), start_index_map=(0, 2))
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 20, 1)
      unique_indices=False
    ] a l
  in (m,) }


In [49]:
# Translate the Jaxpr of `f9_2` and compare the translated function against `f9_2` itself.
f9_2_sdfg = t(f9_2_jaxpr)

resExp = f9_2(H, Idx2)
resDC  = f9_2_sdfg(H, Idx2)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))



##### `f9_3`

In [50]:
# Trace `f9_3` and show its Jaxpr; the two-element index array `Idx2` is used as the index argument.
f9_3_jaxpr = jax.make_jaxpr(f9_3)(H, Idx2)
print(f9_3_jaxpr)

{ lambda ; a:f64[20,20,4] b:i32[2]. let
    c:bool[2] = lt b 0
    d:i32[2] = add b 20
    e:i32[2] = select_n c b d
    f:i32[2,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(2, 1)] e
    g:f64[2,20,4] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 20, 4)
      unique_indices=False
    ] a f
  in (g,) }


In [51]:
# Translate the Jaxpr of `f9_3` and compare the translated function against `f9_3` itself.
f9_3_sdfg = t(f9_3_jaxpr)

resExp = f9_3(H, Idx2)
resDC  = f9_3_sdfg(H, Idx2)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))

##### `f9_4`

In [52]:
# Trace `f9_4` with the demo inputs and show its Jaxpr.
f9_4_jaxpr = jax.make_jaxpr(f9_4)(A, Idx)
print(f9_4_jaxpr)

{ lambda ; a:f64[20,20] b:i32[8]. let
    c:bool[8] = lt b 0
    d:i32[8] = add b 20
    e:i32[8] = select_n c b d
    f:bool[] = lt 4 0
    g:i64[] = add 4 20
    h:i64[] = select_n f 4 g
    i:i64[8] = broadcast_in_dim[broadcast_dimensions=() shape=(8,)] h
    j:i32[8] = convert_element_type[new_dtype=int32 weak_type=False] i
    k:i32[8,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(8, 1)] e
    l:i32[8,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(8, 1)] j
    m:i32[8,2] = concatenate[dimension=1] k l
    n:f64[8] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] a m
  in (n,) }


In [53]:
# Translate the Jaxpr of `f9_4` and compare the translated function against `f9_4` itself.
f9_4_sdfg = t(f9_4_jaxpr)

resExp = f9_4(A, Idx)
resDC  = f9_4_sdfg(A, Idx)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))



##### `f9_6`

In [54]:
# Trace `f9_6` inside the `jax.disable_jit` context; the Jaxpr is printed after the context is left.
with jax.disable_jit(disable=True):
    f9_6_jaxpr = jax.make_jaxpr(f9_6)(C, D, G)
print(f9_6_jaxpr)

{ lambda ; a:f64[20] b:f64[20] c:f64[20]. let
    d:bool[20] = gt b 0.25
    e:bool[20] = lt b 0.75
    f:bool[20] = and d e
    g:f64[20] = select_n f c a
  in (g,) }


In [55]:
# Translate the Jaxpr of `f9_6` and compare the translated function against `f9_6` itself.
f9_6_sdfg = t(f9_6_jaxpr)

resExp = f9_6(C, D, G)
resDC  = f9_6_sdfg(C, D, G)
assert np.all(np.abs(resDC - resExp) <= 10**(-13))