In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

import dace

In [2]:
from JaxprToSDFG import  JaxprToSDFG

## Demo Input
These are the input arrays used by all examples below.

In [3]:
# Problem size and input data shared by every example below.
# A fixed seed makes "Restart & Run All" reproduce the same inputs (and results).
N1, N2 = 20, 20
SEED = 42
rng = np.random.default_rng(SEED)

A = rng.random((N1, N2), dtype=np.float64)
B = rng.random((N1, N2), dtype=np.float64)

# Output buffer: it is passed to the SDFG as `_out` and compared against
# the NumPy reference afterwards.
_OUT = np.ones((N1, N2), dtype=np.float64)

# Constant factor for `f1`, passed as an array argument rather than a literal.
two = np.full((N1, N2), 2.)


## Functions to Transform
These are the functions we transform.
They are ordered by increasing complexity.

In [4]:
# A simple function: every operand is an argument.
def f1(A, B, two):
    """Compute ``A + B * two`` elementwise.

    Because all three operands are arguments, the traced jaxpr refers only
    to input variables and contains no literals.
    """
    scaled_b = B * two
    return A + scaled_b
#

# Same computation as `f1`, but the factor `2.0` is a Python float literal
# instead of an argument. JAX embeds it as a literal operand in the jaxpr
# (`mul b 2.0`), which the translator does not handle yet (see the error below).
def f2(A, B):
    """Compute ``A + B * 2.0`` elementwise, with the factor hard-coded."""
    doubled_b = B * 2.0
    return A + doubled_b
#



## Jaxpr
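We now trace each function into a jaxpr with `jax.make_jaxpr`.
Note that the jaxprs report `f32` even though the inputs are `float64`: JAX disables 64-bit types by default (enable them with `jax.config.update("jax_enable_x64", True)` before creating any arrays).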

#### `f1`

In [5]:
# Trace `f1` on the demo inputs to obtain its jaxpr.
f1_args = (A, B, two)
f1_jaxpr = jax.make_jaxpr(f1)(*f1_args)

In [6]:
# Show the traced program in its textual jaxpr form.
print(str(f1_jaxpr))

{ lambda ; a:f32[20,20] b:f32[20,20] c:f32[20,20]. let
    d:f32[20,20] = mul b c
    e:f32[20,20] = add a d
  in (e,) }


In [7]:
# Build the translator once; the same instance `t` is reused for `f2` below.
t = JaxprToSDFG()
f1_sdfg = t(f1_jaxpr)

In [8]:
# The SDFG's parameters are named after the jaxpr variables (`a`, `b`, `c`),
# not the Python argument names, and the result is written to `_out`.
# For now this renaming has to be done manually.
f1_call_args = {"a": A, "b": B, "c": two, "_out": _OUT}
f1_sdfg(**f1_call_args)

In [9]:
# Reference result from plain NumPy vs. the buffer the SDFG wrote into.
resExp = f1(A, B, two)
resDC = _OUT

# Absolute tolerance only (rtol=0), i.e. |resDC - resExp| <= 1e-13 elementwise,
# and NaNs count as mismatches. Unlike a bare `assert np.all(...)`, this also
# checks that the shapes agree and reports which elements differ.
np.testing.assert_allclose(resDC, resExp, rtol=0, atol=1e-13, equal_nan=False)


#### `f2`

In [10]:
# Trace `f2`; the resulting jaxpr contains the literal `2.0` as an operand.
f2_args = (A, B)
f2_jaxpr = jax.make_jaxpr(f2)(*f2_args)
print(str(f2_jaxpr))

{ lambda ; a:f32[20,20] b:f32[20,20]. let
    c:f32[20,20] = mul b 2.0
    d:f32[20,20] = add a c
  in (d,) }


In [12]:
# Translating `f2` currently fails: its jaxpr contains the literal operand `2.0`,
# and the translator apparently expects every operand to be a named input.
# Catch and show the error so "Restart & Run All" still completes;
# `f2_sdfg` is None when translation fails.
try:
    f2_sdfg = t(f2_jaxpr)
except AssertionError as err:
    f2_sdfg = None
    print(f"Translating f2 failed: {type(err).__name__}: {err}")

AssertionError: Expected to find input '['2.0']'