**Symbolic computation** — definition ([ref](https://docs.sympy.org/latest/tutorial/intro.html)):
- "mathematical objects are represented exactly, not approximately"
-  "mathematical expressions with unevaluated variables are left in symbolic form."

In [39]:
import jax
import jax.numpy as jnp 
import sympy
from sympy import symbols, expand, factor, diff, sin, exp
import sympy2jax

#### Manipulate Expressions

In [40]:
# Two symbolic variables and a linear expression built from them.
x, y = sympy.symbols('x y')
expr = x + 2*y

# Emit the whole report with one print call; each entry lands on its own
# line, and the empty string produces the blank separator line.
print(
    f"We can print the expression:\t\t {expr}",
    f"We can add to the expression:\t\t {expr + 1}",
    f"We can subtract from the expression:\t {expr-x}",
    "",
    f"We can take expressions:\t\t {x*expr}",
    f"and expand them:\t\t\t {expand(x*expr)}",
    sep="\n",
)

We can print the expression:		 x + 2*y
We can add to the expression:		 x + 2*y + 1
We can subtract from the expression:	 2*y

We can take expressions:		 x*(x + 2*y)
and expand them:			 x**2 + 2*x*y


#### Differentiation

In [41]:
# Partial derivative of x + 2*y with respect to x.
print(expr.diff(x))
# Derivative of sin(x); x is the only free symbol, named explicitly here.
print(sin(x).diff(x))

1
cos(x)


#### Integrate

#### Solve

#### Limits

In [42]:
# Derivative of sin(x) with respect to x, displayed as the cell's result.
# NOTE(review): this repeats the derivative already printed in the
# Differentiation cell and sits under the "Limits" heading — confirm intent.
sin(x).diff(x)

cos(x)

In [43]:
# Single symbolic input that the module will later receive as a keyword.
x_sym = sympy.Symbol("x_sym")

# Two scaled trig expressions in that symbol; the float coefficients
# appear among the module's leaves shown further down.
cosx, sinx = 1.0 * sympy.cos(x_sym), 2.0 * sympy.sin(x_sym)

# Wrap both expressions in one sympy2jax module and confirm its type.
mod = sympy2jax.SymbolicModule([cosx, sinx])
print(type(mod))


<class 'sympy2jax.sympy_module.SymbolicModule'>


In [45]:
# Evaluate the symbolic module on a concrete input array.
# NOTE(review): `x` was bound to a SymPy symbol in the "Manipulate
# Expressions" cell; rebinding it to an array here means the earlier cells
# behave differently if re-run after this one. Name kept unchanged in case
# later cells rely on it.
x = jnp.zeros(3)
out = mod(x_sym=x)

# Flatten the module into its pytree leaves. `jax.tree_util.tree_leaves` is
# the supported spelling; the top-level `jax.tree_leaves` alias was
# deprecated and is no longer available in newer JAX releases.
params = jax.tree_util.tree_leaves(mod)
params

[<function sympy2jax.sympy_module._reduce.<locals>.fn_(*args)>,
 DeviceArray(1., dtype=float32, weak_type=True),
 <CompiledFunction of <function _one_to_one_unop.<locals>.<lambda> at 0x122c75120>>,
 'x_sym',
 <function sympy2jax.sympy_module._reduce.<locals>.fn_(*args)>,
 DeviceArray(2., dtype=float32, weak_type=True),
 <CompiledFunction of <function _one_to_one_unop.<locals>.<lambda> at 0x122c74d30>>,
 'x_sym']