In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

import dace

In [2]:
from JaxprToSDFG import  JaxprToSDFG

## Demo Input
These are the inputs used throughout this notebook.

In [3]:
# Problem sizes: N1 x N2 matrices, length-N1 vectors and a length-N3 vector.
N1, N2, N3 = 20, 20, 4

# Seed the generator so that every "Restart & Run All" sees exactly the same inputs.
SEED = 42
rng = np.random.default_rng(SEED)

# NOTE(review): the inputs are float64, but the Jaxprs printed in this notebook are
#  traced as f32, i.e. JAX's 64-bit mode is not enabled. Enable it with
#  `jax.config.update("jax_enable_x64", True)` if float64 tracing is intended.
A = rng.random((N1, N2), dtype=np.float64)
B = rng.random((N1, N2), dtype=np.float64)
C = rng.random(N1, dtype=np.float64)
D = rng.random(N1, dtype=np.float64)
E = rng.random(N3, dtype=np.float64)
F = rng.random((N1, N2), dtype=np.float64)

# Output buffers. `_OUT` is later rebound to SDFG results.
# NOTE(review): `_OUT2`, `_OUT3` (and `D`) are not used in the visible cells — consider removing.
_OUT = np.ones((N1, N2))
_OUT2 = np.ones(N1)
_OUT3 = np.ones(N1 - 1)

# Constant factor passed to `f1` as an explicit array argument.
two = np.full((N1, N2), 2.)


## Functions to Transform
Here are the functions that we transform.
They are ordered by increasing complexity.

In [4]:
# A simple function
def f1(A, B, two):
    """Element-wise ``A + B * two``, where the factor `two` is an array argument."""
    scaled = B * two
    return A + scaled
#

# This is the same as `f1` but this time `2.0` is a double literal and not an argument.
#  Casting does not work yet.
def f2(A, B):
    """Element-wise ``A + B * 2.0``; the factor is a float literal, not an argument."""
    doubled = B * 2.0
    return A + doubled
#

# This tests a constant array that is created inside the function body
def f2_2(E):
    """Add the fixed float64 offsets ``[1, 4, 5, 90]`` to the length-4 vector `E`."""
    offsets = np.array([1, 4, 5, 90], dtype=np.float64)
    return E + offsets
#

def f2_3(A, B, F):
    """Chain of element-wise primitives: ``sqrt(|ceil(tanh(max(100F, min(100A, 100B))))|)``.

    The scalings are evaluated in the order F, A, B, so the traced Jaxpr keeps the
    same equation order as the nested one-liner form.
    """
    scaled_f = F * 100.
    scaled_a = A * 100.
    scaled_b = B * 100.
    clipped = jnp.maximum(scaled_f, jnp.minimum(scaled_a, scaled_b))
    return jnp.sqrt(jnp.abs(jnp.ceil(jnp.tanh(clipped))))
#



# This is a bit more complicated since it is slicing and thus the input and output have different sizes.
def f3(A):
    """Drop the first two entries along the leading axis (equivalent to ``A[2:]``)."""
    return A[slice(2, None)]
#

# This is a slicing with steps
def f3_2(A):
    """Take every second entry from index 2 up to (excluding) the second-to-last one.

    Equivalent to ``A[2:-2:2]``.
    """
    return A[slice(2, -2, 2)]
#


# This tests the return mechanism for tuples (multiple outputs)
def f4(A, B):
    """Return the pair ``(B + 1, A + 2)``; exercises functions with several outputs."""
    first = B + 1
    second = A + 2
    return first, second
#

# This is a super simple stencil
def f5(C):
    """Simple stencil: ``C[i + 2] - C[i]`` for every index where both exist."""
    upper = C[2:]
    lower = C[:-2]
    return upper - lower
#







## Jaxpr
We now transform each function into a Jaxpr.

#### `f1`

In [5]:
# Trace `f1` on the demo inputs to obtain its Jaxpr (printed in the next cell).
f1_jaxpr = jax.make_jaxpr(f1)(A, B, two)

In [6]:
# Show the traced Jaxpr of `f1` as text.
print(f1_jaxpr)

{ lambda ; a:f32[20,20] b:f32[20,20] c:f32[20,20]. let
    d:f32[20,20] = mul b c
    e:f32[20,20] = add a d
  in (e,) }


In [7]:
# One translator instance; it is reused below to convert every function's Jaxpr to an SDFG.
t = JaxprToSDFG()
f1_sdfg = t(f1_jaxpr)

In [8]:
# Currently it is required to make the renaming manually.
# NOTE(review): this call passes the inputs positionally and stores the result in `_OUT`;
#  no renaming is visible here, so the comment above may be stale — confirm.
_OUT = f1_sdfg(A, B, two)

In [9]:
# Reference result from plain NumPy vs. the result produced by the SDFG.
resExp = f1(A, B, two)
resDC  = _OUT

# Element-wise |resDC - resExp| <= 1e-13. Unlike a bare `assert np.all(...)`, this also
#  rejects shape mismatches that broadcasting would hide, and reports offending entries.
np.testing.assert_allclose(resDC, resExp, rtol=0, atol=1e-13, equal_nan=False)

#### `f2`

In [10]:
# Trace `f2`; the factor 2.0 appears as a literal operand of `mul` in the Jaxpr.
f2_jaxpr = jax.make_jaxpr(f2)(A, B)
print(f2_jaxpr)

{ lambda ; a:f32[20,20] b:f32[20,20]. let
    c:f32[20,20] = mul b 2.0
    d:f32[20,20] = add a c
  in (d,) }


In [11]:
# Convert the Jaxpr of `f2` into an SDFG with the shared translator `t`.
f2_sdfg = t(f2_jaxpr)

In [12]:
# Currently it is required to make the renaming manually.
# NOTE(review): this call passes the inputs positionally and stores the result in `_OUT`;
#  no renaming is visible here, so the comment above may be stale — confirm.
_OUT = f2_sdfg(A, B)

In [13]:
# Reference result from plain NumPy vs. the result produced by the SDFG.
resExp = f2(A, B)
resDC  = _OUT

# Element-wise |resDC - resExp| <= 1e-13, plus a shape check and diagnostics on failure.
np.testing.assert_allclose(resDC, resExp, rtol=0, atol=1e-13, equal_nan=False)

##### `f2_2`

In [14]:
# Trace `f2_2`; the constant array shows up as a separate constant input (`a`) of the Jaxpr.
f2_2_jaxpr = jax.make_jaxpr(f2_2)(E)
print(f2_2_jaxpr)

{ lambda a:f32[4]; b:f32[4]. let c:f32[4] = add b a in (c,) }


In [15]:
f2_2_sdfg = t(f2_2_jaxpr)
resExp = f2_2(E)
# The input is passed by its Jaxpr variable name `b`; `a` is the captured constant.
# NOTE(review): presumably the SDFG names its arguments after the Jaxpr variables — confirm.
resDC  = f2_2_sdfg(b=E)

# Element-wise |resDC - resExp| <= 1e-13, plus a shape check and diagnostics on failure.
np.testing.assert_allclose(resDC, resExp, rtol=0, atol=1e-13, equal_nan=False)

##### `f2_3`


In [16]:
# Trace `f2_3`; each nested jnp call becomes one element-wise equation in the Jaxpr.
f2_3_jaxpr = jax.make_jaxpr(f2_3)(A, B, F)
print(f2_3_jaxpr)

{ lambda ; a:f32[20,20] b:f32[20,20] c:f32[20,20]. let
    d:f32[20,20] = mul c 100.0
    e:f32[20,20] = mul a 100.0
    f:f32[20,20] = mul b 100.0
    g:f32[20,20] = min e f
    h:f32[20,20] = max d g
    i:f32[20,20] = tanh h
    j:f32[20,20] = ceil i
    k:f32[20,20] = abs j
    l:f32[20,20] = sqrt k
  in (l,) }


In [17]:
f2_3_sdfg = t(f2_3_jaxpr)
resExp = f2_3(A, B, F)
resDC  = f2_3_sdfg(A, B, F)

# Element-wise |resDC - resExp| <= 1e-13, plus a shape check and diagnostics on failure.
np.testing.assert_allclose(resDC, resExp, rtol=0, atol=1e-13, equal_nan=False)

#### `f3`
Simple slicing.

In [18]:
# Trace `f3` on the 1-D input `C`; the slice lowers to a single `slice` primitive.
f3_jaxpr = jax.make_jaxpr(f3)(C)
print(f3_jaxpr)

{ lambda ; a:f32[20]. let
    b:f32[18] = slice[limit_indices=(20,) start_indices=(2,) strides=None] a
  in (b,) }


In [19]:
# Convert the Jaxpr of `f3` into an SDFG with the shared translator `t`.
f3_sdfg = t(f3_jaxpr)

In [20]:
resExp = f3(C)
_OUTF3 = f3_sdfg(C)
resDC  = _OUTF3

# Slicing makes the output shorter than the input, so the shape check performed by
#  assert_allclose matters here; the tolerance is still |resDC - resExp| <= 1e-13.
np.testing.assert_allclose(resDC, resExp, rtol=0, atol=1e-13, equal_nan=False)

##### `f3_2`

In [21]:
# Trace `f3_2`; the strided slice lowers to `broadcast_in_dim` + `gather`, not `slice`.
# NOTE(review): no SDFG is built or checked for `f3_2` in this notebook — presumably
#  `gather` is not supported by the translator yet; confirm.
f3_2_jaxpr = jax.make_jaxpr(f3_2)(C)
print(f3_2_jaxpr)

{ lambda a:i32[8]; b:f32[20]. let
    c:i32[8,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(8, 1)] a
    d:f32[8] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1,)
      unique_indices=True
    ] b c
  in (d,) }


#### `f4`

In [22]:
# Trace `f4`; its Jaxpr returns two outputs `(c, d)`.
f4_jaxpr = jax.make_jaxpr(f4)(A, B)
print(f4_jaxpr)

{ lambda ; a:f32[20,20] b:f32[20,20]. let
    c:f32[20,20] = add b 1.0
    d:f32[20,20] = add a 2.0
  in (c, d) }


In [23]:
# Convert the Jaxpr of `f4` into an SDFG with the shared translator `t`.
f4_sdfg = t(f4_jaxpr)

In [24]:
# Both the reference and the SDFG return a pair; unpack and compare element by element.
resExp1, resExp2 = f4(A, B)
resDC1, resDC2 = f4_sdfg(A, B)

# Element-wise |resDC - resExp| <= 1e-13 per output, plus shape checks and diagnostics.
np.testing.assert_allclose(resDC1, resExp1, rtol=0, atol=1e-13, equal_nan=False)
np.testing.assert_allclose(resDC2, resExp2, rtol=0, atol=1e-13, equal_nan=False)

#### `f5`

In [25]:
# Trace `f5`; the two shifted views become two `slice` equations followed by `sub`.
f5_jaxpr = jax.make_jaxpr(f5)(C)
print(f5_jaxpr)

{ lambda ; a:f32[20]. let
    b:f32[18] = slice[limit_indices=(20,) start_indices=(2,) strides=None] a
    c:f32[18] = slice[limit_indices=(18,) start_indices=(0,) strides=None] a
    d:f32[18] = sub b c
  in (d,) }


In [26]:
# Convert the Jaxpr of `f5` into an SDFG with the shared translator `t`.
f5_sdfg = t(f5_jaxpr)

In [27]:
resExp = f5(C)
resDC  = f5_sdfg(C)

# The stencil output is shorter than `C`, so the shape check performed by assert_allclose
#  matters here; the tolerance is still |resDC - resExp| <= 1e-13.
np.testing.assert_allclose(resDC, resExp, rtol=0, atol=1e-13, equal_nan=False)