In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from jaxdf.discretization import FourierSeries
from jax import random
from jaxdf.geometry import Domain
import jax

# Root PRNG key for the notebook; a later cell splits it before drawing random field parameters.
seed = random.PRNGKey(42)

# Two-dimensional domain. Presumably the first tuple is the number of grid points
# per axis and the second the grid spacing per axis -- TODO confirm against Domain.
domain = Domain((16,16), (.5,.5))

# Fourier-series discretization built on that domain; used below to create fields.
fourier_discretization = FourierSeries(domain)

In [3]:
from jaxdf.core import Field
# Derive two sub-keys from the root key (not used further in this cell).
seeds = random.split(seed, 2)

# empty_field returns a pair: the parameters of the field and the field object
# itself, here registered under the name "u".
u_params, u = fourier_discretization.empty_field(name="u")


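Adding a constant to a field $F(a,x)$ with parameters $a$ can be expressed in two ways.

Acting on the field values, which requires a new (arbitrary) discretization:

$F(a,x) \to F(a,x) + 2$

Acting directly on the parameters, which keeps the same discretization family as the input:

$F(a,x) \to F(a + 2,x)$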

In [11]:
from jaxdf import operators as jops
from jaxdf.core import operator

Sin = jops.elementwise(jax.numpy.sin)

@operator()
def custom_op(u):
    """Sum over dimensions of the diagonal Jacobian of grad(u), plus ``2*Sin(u)``.

    The first term is built from ``jops.gradient``, ``jops.diag_jacobian`` and
    ``jops.sum_over_dims`` (presumably the Laplacian of ``u`` -- confirm against
    the operator definitions).
    """
    second_derivatives = jops.diag_jacobian(jops.gradient(u))
    laplacian = jops.sum_over_dims(second_derivatives)
    return laplacian + 2*Sin(u)

In [12]:
op = custom_op(u=u)
print(op)

DiscretizedOperator :: [FourierSeries], ['_x9'] 

 Input fields: ('u',)

Globals: Shared: {'k_vec': [DeviceArray([ 0.       ,  0.7853982,  1.5707964,  2.3561945,  3.1415927,
              3.926991 ,  4.712389 ,  5.4977875, -6.2831855, -5.4977875,
             -4.712389 , -3.926991 , -3.1415927, -2.3561945, -1.5707964,
             -0.7853982], dtype=float32), DeviceArray([ 0.       ,  0.7853982,  1.5707964,  2.3561945,  3.1415927,
              3.926991 ,  4.712389 ,  5.4977875, -6.2831855, -5.4977875,
             -4.712389 , -3.926991 , -3.1415927, -2.3561945, -1.5707964,
             -0.7853982], dtype=float32)]}
Independent: {'MultiplyScalarLinear_rE': {'scalar': 2}}

Operations:
- _k5: FourierSeries <-- FFTGradient ('u',) | (shared) FFTGradient
- _mO: FourierSeries <-- FFTDiagJacobian ('_k5',) | (shared) FFTDiagJacobian
- _oj: FourierSeries <-- SumOverDimsOnGrid ('_mO',) | (none) SumOverDimsOnGrid
- _q2: FourierSeries <-- ElementwiseOnGrid ('u',) | (none) ElementwiseOnGrid
- _vm: 

In [9]:
from jax import make_jaxpr, jit
from jax import numpy as jnp

# Callable for the operator output evaluated on the grid. The argument 0
# presumably selects the first output field -- TODO confirm.
f = op.get_field_on_grid(0)
# Parameters owned by the operator itself; passed as the first argument of f.
global_params = op.get_global_params()

# Trace f (without running it on real data) to inspect the JAX primitives it lowers to.
make_jaxpr(f)(global_params, {"u": u_params})

{ lambda  ; a b c d.
  let e = lt 0 0
      f = add 0 1
      g = select e f 0
      h = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] g
      i = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2,), start_index_map=(2,))
                  indices_are_sorted=True
                  slice_sizes=(16, 16, 1)
                  unique_indices=True ] d h
      j = broadcast_in_dim[ broadcast_dimensions=(0, 1)
                            shape=(16, 16) ] i
      k = transpose[ permutation=(1, 0) ] j
      l = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = fft[ fft_lengths=(16,)
                                              fft_type=FftType.FFT ] a
                                 in (b,) }
                    device=None
                    donated_invars=(False,)
                    inline=False
                    name=fft ] k
      m = mul l 1j
   

In [5]:
# NOTE(review): this cell needs `Primitive` in scope (the recorded run failed with
# NameError; the notebook imports it from jaxdf.primitives only in a later cell).
# `discretization` and `no_init` are also not defined anywhere in the visible
# notebook -- confirm where they are expected to come from.
class AddScalar(Primitive):
    '''Primitive that adds a scalar to the value of a field.

    The output lives on a new `discretization.Arbitrary` discretization whose
    parameters are the pair `[input field parameters, scalar]`.
    '''
    def __init__(self, scalar, name="AddScalar", independent_params=True):
        '''Store the scalar; `name` and `independent_params` are forwarded to `Primitive`.'''
        super().__init__(name, independent_params)
        self.scalar = scalar
    
    def discrete_transform(self):
        '''Return the function mapping `(op_params, field_params)` to the output parameters.'''
        def f(op_params, field_params):
            # Pack the untouched input parameters together with the scalar;
            # `get_field` in `setup` unpacks this pair again.
            return [field_params, op_params["scalar"]]
        # Expose the primitive's name as the function name.
        f.__name__ = self.name
        return f

    def setup(self, field):
        '''Build the operator parameters and a new arbitrary discretization.

        Returns the tuple `(parameters, new_discretization)`, where `parameters`
        is `{"scalar": self.scalar}`.
        '''
        parameters = {"scalar": self.scalar}

        def get_field(p_joined, x):
            # p_joined is the pair produced by `discrete_transform`.
            p, scalar = p_joined
            # Evaluate the input field at x with its own parameters, then shift by the scalar.
            return field.discretization.get_field()(p,x) + scalar

        # Same domain as the input field, evaluated through `get_field`.
        # NOTE(review): the third argument is presumably a parameter initializer -- confirm.
        new_discretization = discretization.Arbitrary(
            field.discretization.domain,
            get_field,
            no_init
        )

        return parameters, new_discretization


NameError: name 'Primitive' is not defined

In [None]:
class AddScalarLinear(Primitive):
    """Primitive that adds a scalar directly to the parameters of a field.

    The output keeps the discretization of the input field.
    """

    def __init__(self, scalar, name="AddScalarLinear", independent_params=True):
        """Store the scalar; ``name`` and ``independent_params`` go to ``Primitive``."""
        super().__init__(name, independent_params)
        self.scalar = scalar

    def setup(self, field):
        """Return ``(parameters, discretization)`` for this primitive.

        ``parameters`` holds the scalar under the key ``"scalar"``; the
        discretization is the one of the input ``field``, unchanged.
        """
        same_discretization = field.discretization
        return {"scalar": self.scalar}, same_discretization

    def discrete_transform(self):
        """Return the function mapping ``(op_params, field_params)`` to shifted parameters."""
        def f(op_params, field_params):
            shifted = field_params + op_params["scalar"]
            return shifted

        # Expose the primitive's name as the function name.
        f.__name__ = self.name
        return f

In [6]:
from jaxdf.core import operator

# NOTE(review): every other cell in this notebook calls the decorator factory
# without a positional argument (`@operator()` / `@operator(debug=False)`) and
# receives the field as the function parameter. Here the field `u` is passed to
# the decorator and the body reads `u` from the notebook namespace -- presumably
# an older form of the API; confirm it is still valid. The next cell redefines
# `custom_op`, so this definition is shadowed.
@operator(u)
def custom_op(x):
    # Evaluate the field `u` at `x` and shift the value by 2.
    return u(x) + 2

In [7]:
from jaxdf.core import operator

@operator()
def custom_op(u):
    """Return the field ``u`` shifted by the constant 2."""
    return u + 2

In [8]:
custom_op

<function jaxdf.core.operator.<locals>.decorator.<locals>.wrapper(*args, **kwargs)>

In [9]:
op = custom_op(u=u)
print(op)

DiscretizedOperator :: [RealFourierSeries], ['_k5'] 

 Input fields: ('u',)

Globals: Shared: {}
Independent: {'AddScalarLinear_jR': {'scalar': 2}}

Operations:
- _k5: RealFourierSeries <-- AddScalarLinear ('u',) | (independent) AddScalarLinear_jR



In [10]:
global_params = op.get_global_params()

f = op.get_field(0)
f(global_params, {'u': u_params}, x=1.)

DeviceArray([2240.], dtype=float32)

In [7]:
from jaxdf.core import operator
from jaxdf.primitives import AddScalarLinear

add_three = AddScalarLinear(scalar=3.)
add_five = AddScalarLinear(scalar=5., independent_params = False)

@operator(debug=False)
def custom_op(u):
    """Shift ``u`` by 2, then apply ``add_three`` and ``add_five`` in that order."""
    shifted = add_three(u + 2)
    return add_five(shifted)

In [8]:
from jaxdf.core import Field

# Fourier discretization
seeds = random.split(seed, 2)
u_params = fourier_discretization.random_field(seeds[0])
u = Field(fourier_discretization, params=u_params, name='u')

# Compiling operator on the given discretization
op = custom_op(u=u)
print(op)

DiscretizedOperator :: [RealFourierSeries], ['_oj'] 

 Input fields: ('u',)

Globals: Globals: 
{'shared': {'AddScalarLinear': {'scalar': 5.0}}, 'independent': {'AddScalarLinear_jR': {'scalar': 2}, 'AddScalarLinear_l5': {'scalar': 3.0}}}

Operations:
- _k5: RealFourierSeries <-- AddScalarLinear ('u',) | (independent) AddScalarLinear_jR
- _mO: RealFourierSeries <-- AddScalarLinear ('_k5',) | (independent) AddScalarLinear_l5
- _oj: RealFourierSeries <-- AddScalarLinear ('_mO',) | (shared) AddScalarLinear



In [9]:
global_params = op.get_global_params()

f = op.get_field(0)
f(global_params, {'u': u_params}, 1.)

DeviceArray([8.997641], dtype=float32)

In [13]:
import pprint

pp = pprint.PrettyPrinter(indent=4)

In [14]:
pp.pprint(global_params)

{   'independent': {   'AddScalarLinear_jR': {'scalar': 2},
                       'AddScalarLinear_l5': {'scalar': 3.0}},
    'shared': {'AddScalarLinear': {'scalar': 5.0}}}


In [15]:
from jaxdf.primitives import Primitive

In [9]:
class MultiplyScalarLinear(Primitive):
    '''Primitive that multiplies the parameters of a field by a scalar.

    The output keeps the discretization of the input field.
    '''
    def __init__(self, scalar, name="MultiplyScalarLinear", independent_params=True):
        '''Store the scalar; `name` and `independent_params` are forwarded to `Primitive`.'''
        super().__init__(name, independent_params)
        self.scalar = scalar

    def discrete_transform(self):
        '''Return the function mapping `(op_params, field_params)` to the scaled parameters.'''
        def f(op_params, field_params):
            # The multiplication operator was missing here, which made the
            # whole cell a syntax error.
            return field_params * op_params["scalar"]
        # Expose the primitive's name as the function name.
        f.__name__ = self.name
        return f
    
    def setup(self, field):
        '''Same discretization family as the input.

        Returns the tuple `(parameters, new_discretization)`, where `parameters`
        is `{"scalar": self.scalar}`.
        '''
        new_discretization = field.discretization
        parameters = {"scalar": self.scalar}
        return parameters, new_discretization

In [33]:
class PowerScalarLinear(Primitive):
    """Primitive parameterised by a single ``exponent`` value.

    NOTE(review): despite the class and parameter names, the transform below
    multiplies the field parameters by ``exponent``; it does not raise them
    to a power. Confirm which behaviour is intended before renaming or
    changing the operator.
    """

    def __init__(self, exponent, name="abc", independent_params=True):
        """Store ``exponent``; ``name`` and ``independent_params`` go to ``Primitive``."""
        super().__init__(name, independent_params)
        self.exponent = exponent

    def discrete_transform(self):
        """Return the function mapping ``(op_params, field_params)`` to new parameters."""
        def f(op_params, field_params):
            factor = op_params["exponent"]
            return field_params*factor

        # Expose the primitive's name as the function name.
        f.__name__ = self.name
        return f

    def setup(self, field):
        """Return ``(parameters, discretization)``; the input discretization is reused."""
        same_discretization = field.discretization
        return {"exponent": self.exponent}, same_discretization

In [34]:
power_4 = PowerScalarLinear(4)
    
@operator(debug=False)
def custom_op(u):
    """Apply ``power_4`` to ``u`` and add 10 to the result."""
    return power_4(u) + 10

In [35]:
op = custom_op(u=u)
print(op)

DiscretizedOperator :: [RealFourierSeries], ['_mO'] 

 Input fields: ('u',)

Globals: {'shared': {}, 'independent': {'abc_jR': {'exponent': 4}, 'AddScalarLinear_l5': {'scalar': 10}}}

Operations:
- _k5: RealFourierSeries <-- abc ('u',) | (independent) abc_jR
- _mO: RealFourierSeries <-- AddScalarLinear ('_k5',) | (independent) AddScalarLinear_l5



In [36]:
# Nested dict of operator parameters (top-level keys 'shared' and 'independent',
# as shown when printing the operator).
global_params = op.tracer.globals.dict
# Callable for output 0 of the operator.
f = op.get_field(0)
# Evaluate at x = 1. (a stray trailing character made this line a syntax error).
f(global_params, {'u': u_params}, 1.)

DeviceArray(5.990572, dtype=float32)

In [37]:
def g(exponent):
    """Evaluate ``f`` at ``x = 1.`` with the ``abc_jR`` exponent overridden.

    Args:
        exponent: value to use for ``['independent']['abc_jR']['exponent']``.

    Returns:
        The result of ``f`` called with the modified global parameters.

    The notebook-level ``global_params`` is left untouched.
    """
    # dict.copy() is shallow, so assigning through the copy would overwrite the
    # entry inside `global_params` itself (and, when this function is traced by
    # jax.grad, store a tracer there). Rebuild only the dictionaries along the
    # modified path instead.
    independent = dict(global_params['independent'])
    independent['abc_jR'] = {**independent['abc_jR'], 'exponent': exponent}
    new_global = {**global_params, 'independent': independent}
    return f(new_global, {'u': u_params}, 1.)

In [38]:
g(3.)

DeviceArray(6.992928, dtype=float32)

In [39]:
from jax import grad

dg = grad(g)

In [40]:
dg(3.)

DeviceArray(-1.0023577, dtype=float32)

In [30]:
jax.make_jaxpr(f)(global_params, {'u': u_params}, 1.)

{ lambda a b c d e ; f g h i.
  let j = convert_element_type[ new_dtype=float32
                                weak_type=False ] g
      k = pow h j
      l = convert_element_type[ new_dtype=float32
                                weak_type=False ] f
      m = add k l
      n = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(32,) ] 0.0
      o = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      p = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 0
      q = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))
                   indices_are_sorted=False
                   unique_indices=True
                   update_consts=(  )
                   update_jaxpr={ lambda  ; a b.
                                  let 
                                  in (b,) } ] n p o
      r = convert_elemen

In [1]:
jax.make_jaxpr(f)(global_params, {'u': u_params}, 1.)

NameError: name 'jax' is not defined

# jax.make_jaxpr(f)(global_params, {'u': u_params}, 1.)

In [23]:
pp.pprint(global_params)

{   'independent': {   'AddScalarLinear_l5': {'scalar': 10},
                       'abc_jR': {'exponent': 4}},
    'shared': {}}
