In [1]:
import jax
import jax.numpy as jnp

In [None]:
def create_tuple(n, m):
    """
    Create a tuple with n zeros followed by m ones.

    Parameters:
    n (int): The number of zeros.
    m (int): The number of ones.

    Returns:
    tuple: A tuple containing n zeros followed by m ones.
    """
    # Create the tuple using tuple concatenation
    return (0,) * n + (1,) * m

# Example usage
create_tuple(3, 4)


def get_derivs(f,x0,dmax):
    """
    Compute the partial derivatives of the function, `f`, of a two-dimensional input evaluated at the point `x0` up to a maximum order, `dmax`.
    """
    derivs = {(0,0):f(x0)}
    Dn_f = jax.jacrev(f)
    Dn_f_val = Dn_f(x0)
    derivs[(1,0)] = Dn_f_val[create_tuple(1,0)]
    derivs[(0,1)] = Dn_f_val[create_tuple(0,1)]
    for d in range(2,dmax+1):
        Dn_f = jax.jacfwd(Dn_f)
        Dn_f_val = Dn_f(x0)
        for j in range(d+1):
            i = d-j
            derivs[(i,j)] = Dn_f_val[create_tuple(i,j)]
    return derivs

In [73]:
get_derivs(f,jnp.array([0,0.]),4)

{(0, 0): Array(1., dtype=float32),
 (1, 0): Array(0., dtype=float32),
 (0, 1): Array(0., dtype=float32),
 (2, 0): Array(-4., dtype=float32),
 (1, 1): Array(0., dtype=float32),
 (0, 2): Array(-9., dtype=float32),
 (3, 0): Array(0., dtype=float32),
 (2, 1): Array(0., dtype=float32),
 (1, 2): Array(0., dtype=float32),
 (0, 3): Array(0., dtype=float32),
 (4, 0): Array(16., dtype=float32),
 (3, 1): Array(0., dtype=float32),
 (2, 2): Array(36., dtype=float32),
 (1, 3): Array(0., dtype=float32),
 (0, 4): Array(81., dtype=float32)}

In [29]:
def f(x):
    x1,x2=x
    return jnp.cos(2*x1)*jnp.cos(3*x2)

In [34]:
D1f = jax.jacrev(f)
D2f = jax.jacfwd(D1f)
D3f = jax.jacfwd(D2f)
D4f = jax.jacfwd(D3f)

In [36]:
x0 = jnp.asarray([0,0.])
print(D1f(x0))
print(D2f(x0))
print(D4f(x0))

[0. 0.]
[[-4.  0.]
 [ 0. -9.]]
[[[[16.  0.]
   [ 0. 36.]]

  [[ 0. 36.]
   [36.  0.]]]


 [[[ 0. 36.]
   [36.  0.]]

  [[36.  0.]
   [ 0. 81.]]]]


In [46]:
d4[1,1,1,1]

Array(81., dtype=float32)

In [49]:
d4[0,0,0,0]

Array(16., dtype=float32)

In [12]:
jax.jacfwd(jax.jacfwd(jax.jacrev(f)))(jnp.asarray([1.,2.,3.]))

Array([[[ 0.     ,  0.     ,  0.     ],
        [ 0.     ,  0.     ,  0.     ],
        [ 0.     ,  0.     ,  0.     ]],

       [[ 0.     ,  0.     ,  0.     ],
        [ 0.     ,  0.     ,  0.     ],
        [ 0.     ,  0.     ,  0.     ]],

       [[ 0.     ,  0.     ,  0.     ],
        [ 0.     ,  0.     ,  0.     ],
        [ 0.     ,  0.     , -0.14112]]], dtype=float32)

In [115]:
jnp.asarray(1)

Array(1, dtype=int32, weak_type=True)

In [116]:
jax.jvp(f)(jnp.asarray([1,2,3]),jnp.asarray(1.))

TypeError: jvp() missing 2 required positional arguments: 'primals' and 'tangents'

In [109]:
def Gfn(x,e):
    return -1 * (x/e) * jnp.arctan(e * x / jnp.sqrt(1-e*e))
    
def Potential_uv(u,v,e):
    coshu = jnp.cosh(u)
    cosv = jnp.cos(v)
    sinhu = jnp.sinh(u)
    sinv = jnp.sin(v)
    U = -1 * coshu * jnp.arctan(e*coshu/jnp.sqrt(1-e*e))/e
    V = -1 * cosv * jnp.arctan(e*cosv/jnp.sqrt(1-e*e))/e
    Phi = (U-V)/(sinv*sinv + sinhu*sinhu)
    return Phi
    
def Potential_lm(l,m,e):
    bsq = 1 - e*e
    Fl = jnp.sqrt(l-bsq)  * jnp.arccos(jnp.sqrt(bsq/l))
    Fm = jnp.sqrt(m-bsq)  * jnp.arccos(jnp.sqrt(bsq/m))
    return -1*(Fl-Fm)/(l-m)
    
def Rz_to_lm(R,z,e):
    Rsq_plus_zsq =(R**2 + z**2)
    d = jnp.sqrt(e**4 + 2 * e**2 *(R**2-z**2) + Rsq_plus_zsq**2)
    lmbda = 1 - e*e/2 + Rsq_plus_zsq/2 + d/2
    mu = 1 - e*e/2 + Rsq_plus_zsq/2 - d/2
    return lmbda,mu
    
def Rz_to_uv(R,z,e):
    Rsq_plus_zsq =(R**2 + z**2)
    d = jnp.sqrt(e**4 + 2 * e**2 *(R**2-z**2) + Rsq_plus_zsq**2)
    lmbda = 1 - e*e/2 + Rsq_plus_zsq/2 + d/2
    mu = 1 - e*e/2 + Rsq_plus_zsq/2 - d/2
    u_of_Rz = jnp.arcsinh(jnp.sqrt((lmbda-1))/e)
    v_of_Rz = jnp.arcsin(jnp.sqrt((1-mu))/e)
    return u_of_Rz,v_of_Rz

def Potential_Rz(R,z,e):
    u,v = Rz_to_uv(R,z,e)
    return Potential_uv(u,v,e)
    
def Potential_Rz_alt(R,z,e,eps = 0.001):
    l,m = Rz_to_lm(R,z,e)
    return Potential_lm(l,m+eps,e)

In [110]:
RR,ZZ = jnp.meshgrid(jnp.linspace(0,3),jnp.linspace(-2,2))
NPhi=Potential_Rz(RR,ZZ,0.5)
NPhi2=Potential_Rz_alt(RR,ZZ,0.5)

In [111]:
dPhi_alt_dR = jax.grad(Potential_Rz_alt)

In [112]:
dPhi_alt_dR(1.,0,0.5)

Array(0.30578738, dtype=float32, weak_type=True)

In [99]:
dPhi_dR = jax.grad(Potential_Rz)
dPhi_dz = jax.grad(Potential_Rz,argnums=1)

In [88]:
dPhi_dR(1.,.5e-3,0.5)

Array(0.3059763, dtype=float32, weak_type=True)

In [89]:
dPhi_dz(1.,0.5e-3,0.5)

Array(0.00017666, dtype=float32, weak_type=True)

In [75]:
Rz_to_lm(1.,0,0.75)

(Array(2., dtype=float32, weak_type=True),
 Array(0.4375, dtype=float32, weak_type=True))