In [1]:
import jax
jax.config.update("jax_enable_x64", True)
from bulirsch import * 
from scipy.special import ellipe, ellipk
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import quad

In [None]:
def el1_numerical(x, kc):
    """Quadrature reference for ``el1(x, kc)``.

    Integrates ``sqrt(1 / (cos(y)**2 + kc**2 * sin(y)**2))`` over ``y`` from 0 to
    ``arctan(x)`` with ``scipy.integrate.quad`` and returns the integral
    estimate; quad's error estimate is discarded.
    """
    upper = np.arctan(x)

    def integrand(phi):
        return np.sqrt(1 / (np.cos(phi)**2 + kc**2 * np.sin(phi)**2))

    value, _abserr = quad(integrand, 0, upper)
    return value

# Sweep x at fixed kc = 0.2: the thick translucent line is the quadrature
# reference, the dashed black line is el1.
x = np.linspace(-5, 5, 100)
plt.plot(x, np.vectorize(el1_numerical)(x, 0.2), '-', alpha=0.5, linewidth=5)
plt.plot(x, el1(x, 0.2), 'k--')

# Sweep kc at fixed x = 0.1. The abscissa has to be the kc grid being swept,
# not the x grid from the first figure.
plt.figure()
kc = np.linspace(0.01, 2, 100)
plt.plot(kc, np.vectorize(el1_numerical)(0.1, kc), '-', alpha=0.5, linewidth=5)
plt.plot(kc, el1(0.1, kc), 'k--')

In [None]:
def el2_numerical(x, kc, a, b):
    """Quadrature reference for ``el2(x, kc, a, b)``.

    Integrates
    ``(a + b * tan(y)**2) / sqrt((1 + tan(y)**2) * (1 + kc**2 * tan(y)**2))``
    over ``y`` from 0 to ``arctan(x)`` with ``scipy.integrate.quad`` and returns
    the integral estimate; quad's error estimate is discarded.
    """

    def integrand(y):
        t2 = np.tan(y)**2
        # NumPy throughout: the rest of the integrand is NumPy, and the square
        # root should not depend on `jnp` being exported by a star import.
        return (a + b * t2) / np.sqrt((1 + t2) * (1 + kc**2 * t2))

    return quad(integrand, 0, np.arctan(x))[0]

# Sweep x at fixed kc = 0.2, a = 0.1, b = 0.5: quadrature reference (thick
# translucent line) against el2 (dashed black line).
x = np.linspace(-5, 5, 100)
plt.plot(x, np.vectorize(el2_numerical)(x, 0.2, 0.1, 0.5), '-', alpha=0.5, linewidth=5)
plt.plot(x, el2(x, 0.2, 0.1, 0.5), 'k--')

# Sweep kc at fixed x = 0.1. Plot against the kc grid that is actually swept
# rather than the x grid from the first figure.
plt.figure()
kc = np.linspace(0.01, 2, 100)
plt.plot(kc, np.vectorize(el2_numerical)(0.1, kc, 0.1, 0.5), '-', alpha=0.5, linewidth=5)
plt.plot(kc, el2(0.1, kc, 0.1, 0.5), 'k--')

In [None]:
def el3_numerical(x, kc, p):
    """Quadrature reference for ``el3(x, kc, p)``.

    Integrates ``1 / (cos(y)**2 + p * sin(y)**2) / sqrt(cos(y)**2 + kc**2 * sin(y)**2)``
    over ``y`` from 0 to ``arctan(x)`` with ``scipy.integrate.quad`` and returns
    the integral estimate; quad's error estimate is discarded.
    """
    upper = np.arctan(x)

    def integrand(phi):
        pole = np.cos(phi)**2 + p * np.sin(phi)**2
        root = np.sqrt(np.cos(phi)**2 + kc * kc * np.sin(phi)**2)
        return 1 / pole / root

    value, _abserr = quad(integrand, 0, upper)
    return value

# Sweep x at fixed kc = 0.2, p = 0.1: quadrature reference (thick translucent
# line) against el3 (dashed black line).
x = np.linspace(-5, 5, 100)
plt.plot(x, np.vectorize(el3_numerical)(x, 0.2, 0.1), '-', alpha=0.5, linewidth=5)
plt.plot(x, el3(x, 0.2, 0.1), 'k--')

# Sweep kc at fixed x = 0.1. Plot against the kc grid that is actually swept
# rather than the x grid from the first figure.
plt.figure()
kc = np.linspace(0.01, 2, 100)
plt.plot(kc, np.vectorize(el3_numerical)(0.1, kc, 0.1), '-', alpha=0.5, linewidth=5)
plt.plot(kc, el3(0.1, kc, 0.1), 'k--')

In [None]:
def cel_numerical(kc, p, a, b):
    """Quadrature reference for ``cel(kc, p, a, b)``.

    Integrates
    ``(a * cos(y)**2 + b * sin(y)**2) / (cos(y)**2 + p * sin(y)**2)
    / sqrt(cos(y)**2 + kc**2 * sin(y)**2)``
    over ``y`` from 0 to ``pi / 2`` with ``scipy.integrate.quad`` and returns the
    integral estimate; quad's error estimate is discarded.
    """

    def integrand(phi):
        c2 = np.cos(phi)**2
        s2 = np.sin(phi)**2
        return (a * c2 + b * s2) / (c2 + p * s2) / np.sqrt(c2 + kc**2 * s2)

    estimate = quad(integrand, 0, np.pi / 2)
    return estimate[0]

# Sweep the first argument (kc) over the grid named x, at fixed p = 0.2,
# a = 0.1, b = 0.5: quadrature reference (thick translucent line) against cel
# (dashed black line).
x = np.linspace(-5, 5, 100)
plt.plot(x, np.vectorize(cel_numerical)(x, 0.2, 0.1, 0.5), '-', alpha=0.5, linewidth=5)
plt.plot(x, cel(x, 0.2, 0.1, 0.5), 'k--')

# Sweep the second argument (p) at fixed kc = 0.1; the grid variable is named
# kc but holds the p values. The grid starts at 0.01 rather than 0 because at
# p = 0 the reference integrand behaves like 1 / cos(y)**2 near y = pi / 2 and
# the quadrature diverges, swamping the rest of the curve.
plt.figure()
kc = np.linspace(0.01, 2, 100)
plt.plot(kc, np.vectorize(cel_numerical)(0.1, kc, 0.1, 0.5), '-', alpha=0.5, linewidth=5)
plt.plot(kc, cel(0.1, kc, 0.1, 0.5), 'k--')

In [None]:
x = np.linspace(-5, 5, 5000)
%timeit elliprc(x, 0.2)
rc(x, 0.2)
%timeit rc(x, 0.2)
plt.plot(x, elliprc(x, 0.2), '-', alpha=0.5, linewidth=5)
plt.plot(x, rc(x, 0.2), 'k--')

plt.figure()
x = np.linspace(0, 2, 100)
plt.plot(kc, elliprc(0.1, x), '-', alpha=0.5, linewidth=5)
plt.plot(kc, rc(0.1, x), 'k--')

In [None]:
# First-argument sweep of rd against the reference elliprd, other arguments
# fixed at (0.2, 0.1).
x = np.linspace(-5, 5, 100)
plt.plot(x, elliprd(x, 0.2, 0.1), '-', alpha=0.5, linewidth=5)
plt.plot(x, rd(x, 0.2, 0.1), 'k--')

# Second-argument sweep. The abscissa is the grid evaluated here, not a `kc`
# array left over from an earlier cell.
plt.figure()
x = np.linspace(0, 2, 100)
plt.plot(x, elliprd(0.1, x, 0.1), '-', alpha=0.5, linewidth=5)
plt.plot(x, rd(0.1, x, 0.1), 'k--')

In [None]:
# First-argument sweep of rf against the reference elliprf, other arguments
# fixed at (0.2, 0.1).
x = np.linspace(-5, 5, 100)
plt.plot(x, elliprf(x, 0.2, 0.1), '-', alpha=0.5, linewidth=5)
plt.plot(x, rf(x, 0.2, 0.1), 'k--')

# Second-argument sweep. The abscissa is the grid evaluated here, not a `kc`
# array left over from an earlier cell.
plt.figure()
x = np.linspace(0, 2, 100)
plt.plot(x, elliprf(0.1, x, 0.1), '-', alpha=0.5, linewidth=5)
plt.plot(x, rf(0.1, x, 0.1), 'k--')

In [None]:
# First-argument sweep of rj against the reference elliprj, other arguments
# fixed at (0.2, 0.1, 0.5).
x = np.linspace(-5, 5, 100)
plt.plot(x, elliprj(x, 0.2, 0.1, 0.5), '-', alpha=0.5, linewidth=5)
plt.plot(x, rj(x, 0.2, 0.1, 0.5), 'k--')

# Second-argument sweep. The abscissa is the grid evaluated here, not a `kc`
# array left over from an earlier cell.
plt.figure()
x = np.linspace(0, 2, 100)
plt.plot(x, elliprj(0.1, x, 0.1, 0.5), '-', alpha=0.5, linewidth=5)
plt.plot(x, rj(0.1, x, 0.1, 0.5), 'k--')

In [None]:
# Evaluation point for the el3 derivative check.
x, kc, p = 1.0, 0.2, 0.1

# Forward finite differences (step 1e-6) with respect to x, kc and p, printed
# one per line in that order.
print(
    *[
        (el3(*shifted) - el3(x, kc, p)) / 1e-6
        for shifted in ((x + 1e-6, kc, p), (x, kc + 1e-6, p), (x, kc, p + 1e-6))
    ],
    sep="\n",
)

# Forward-mode autodiff derivatives with respect to the same three arguments,
# shown as the cell output for comparison with the printed values.
jax.jacfwd(el3, (0, 1, 2))(x, kc, p)

In [5]:
# Evaluation point for the cel derivative check.
kc = 1.0
p = 1.0
a = 1.3
b = 2.0

# Finite-difference step.
# NOTE(review): with d = 1e-12 the printed quotients agree with the jacfwd
# values only to about three or four significant digits (see the recorded
# output), which looks like round-off in the difference; the other checks in
# this notebook use a step of 1e-6.
d = 1e-12

# Forward differences of cel with respect to kc, p, a and b, in that order.
print((cel(kc + d, p, a, b) - cel(kc, p, a, b)) / d)
print((cel(kc, p + d, a, b) - cel(kc, p, a, b)) / d)
print((cel(kc, p, a + d, b) - cel(kc, p, a, b)) / d)
print((cel(kc, p, a, b + d) - cel(kc, p, a, b)) / d)

# Forward-mode autodiff derivatives with respect to the same four arguments.
jax.jacfwd(cel, (0, 1, 2, 3))(kc, p, a, b)

-1.4335199693960021
-1.4335199693960021
0.7851497230149107
0.7851497230149107


(Array(-1.43335165, dtype=float64, weak_type=True),
 Array(-1.43335165, dtype=float64, weak_type=True),
 Array(0.78539816, dtype=float64, weak_type=True),
 Array(0.78539816, dtype=float64, weak_type=True))

In [None]:
# Candidate closed-form expression assembled from cel(kc, 1.0, ap, bp);
# presumably a derivative identity being compared with the finite differences
# — TODO confirm which derivative it is meant to reproduce.
# NOTE(review): ap, bp and the divisor each carry a factor (kc**2 - 1), so with
# kc = 1.0 (the value set in the finite-difference cell) every piece vanishes
# and the result is a 0 / 0 evaluation; use kc != 1 for a meaningful number.
ap = (a - b) * (kc**2 - 1)
bp = - ((kc**2 - 1) * (-3 * b + (a + 2 * b) * kc**2))
cel(kc, 1.0, ap, bp) / (3 * (kc**2 - 1)**2)

In [None]:
# Second candidate closed form: cel(kc, 1.0, ap, bp) scaled by fac.
# Presumably another derivative identity under test — TODO confirm.
# NOTE(review): fac divides by (1 - kc**2) * (-1 + kc**2); if kc is still the
# Python float 1.0 this line raises ZeroDivisionError.
ap = 4 * (kc**2 - 1) * (2 * b + (a - 3 * b) * kc**2)
bp = 4 * kc**2 * (kc**2 - 1) * (b - a * kc**2)
fac = 1 / (12 * kc**4 * (1 - kc ** 2)* (-1 + kc**2))
fac * cel(kc, 1.0, ap, bp)

In [None]:
# Third candidate: difference between cel at p = kc**2 and at p = 1.0, both
# with coefficients (b / kc**2, a), scaled by -kc / (1 - kc**2).
# NOTE(review): the prefactor divides by 1 - kc * kc, which is zero when
# kc = 1.0 (ZeroDivisionError for a Python float).
-kc / (1 - kc * kc) * (cel(kc, kc**2, b / kc**2, a) - cel(kc, 1.0, b / kc**2, a))

In [None]:
# Pieces of an expression written with the complete elliptic integrals E and K
# instead of cel. Only Eterm is printed; Fterm is assigned but not used in this
# cell.
# NOTE(review): ellipe is given sqrt(1 - kc**2) while ellipk is given
# 1 - kc**2. If both names are still the scipy.special functions (which take
# the parameter m), the square root looks inconsistent — confirm the intended
# argument convention.
# NOTE(review): both terms divide by (kc**2 - 1), so they are singular at kc = 1.0.
Eterm = ellipe(jnp.sqrt(1 - kc**2)) * (-b + a * kc * kc) / (kc * kc * (kc * kc - 1))
Fterm = (b - a) * ellipk(1 - kc**2) / (kc**2 - 1)
print(Eterm)

In [None]:
# ellipe evaluated with sqrt(1 - kc**2) as its argument, the same call that
# appears inside Eterm.
ellipe(jnp.sqrt(1 - kc**2))

In [None]:
# SciPy's ellipe under an explicit alias, evaluated at 1 - kc**2 (no square
# root) for comparison with the sqrt-argument call.
from scipy.special import ellipe as ellipe_np
ellipe_np(1 - kc**2)

In [None]:
# From here on `ellipk` (and `ellippi`) refer to the versions in the local
# `legendre` module, replacing any earlier binding of `ellipk`.
from legendre import ellippi, ellipk
kc, p, a, b = 1.0, 1.0, 1.0, 1.0

# Forward difference of cel with respect to b (step 1e-6).
print((cel(kc, p, a, b + 1e-6) - cel(kc, p, a, b)) / 1e-6)

# Candidate expression for the same quantity in terms of ellipk and cel.
# NOTE(review): with p = 1.0 here, n = p - 1.0 is zero and is used as a divisor
# in the last line, so the expression is not well defined at this point.
n = p - 1.0
k = jnp.sqrt(1 - kc * kc)
ellipk(k) / n - cel(kc, p, 1.0, 1.0) / n

In [None]:
# Pull in every public name from the local `legendre` module; any name it shares
# with earlier imports is rebound (presumably including `ellipe` — verify), then
# evaluate ellipe at 1 - kc**2.
from legendre import *
ellipe(1 - kc**2)

In [None]:
# Evaluation point: kc = 0.5 with p, a and b all equal to 1.0.
kc = 0.5
p = a = b = 1.0
n = p - 1.0
k = jnp.sqrt(1 - kc * kc)

# Forward difference of cel with respect to b (step 1e-6) ...
print(
    (cel(kc, p, a, b + 1e-6) - cel(kc, p, a, b)) / 1e-6
)
# ... next to cel with coefficients (a, b) = (0, 1), shown as the cell output.
cel(kc, p, 0.0, 1.0)

In [None]:
# Evaluation point: kc = 0.5 with p, a and b all equal to 1.0.
kc = 0.5
p = a = b = 1.0
n = p - 1.0
k = jnp.sqrt(1 - kc * kc)

# Forward difference of cel with respect to a (step 1e-6) ...
print(
    (cel(kc, p, a + 1e-6, b) - cel(kc, p, a, b)) / 1e-6
)
# ... next to cel with coefficients (a, b) = (1, 0), shown as the cell output.
cel(kc, p, 1.0, 0.0)

In [None]:
# Evaluation point; p is 1.01 here, so the (1 - p) factor in the divisor below
# is non-zero.
kc = 0.5
p = 1.01
a = b = 1.0
n = p - 1.0
k = jnp.sqrt(1 - kc * kc)

# Forward difference of cel with respect to p (step 1e-6).
print(
    (cel(kc, p + 1e-6, a, b) - cel(kc, p, a, b)) / 1e-6
)

# Candidate closed form for the same derivative: two cel evaluations combined
# and divided by 2 * p * (1 - p) * (p - kc**2). Shown as the cell output.
lam = kc * kc * (b + a * p - 2 * b * p) + p * (3 * b * p - a * p**2 - 2*b)
dp = (
    cel(kc, p, 0.0, lam) + (b - a * p) * cel(kc, 1.0, 1.0 - p, kc * kc - p)
) / (2 * p * (1 - p) * (p - kc * kc))
dp

In [None]:
# cel as a function of p on [0.1, 2] with a = b = 1.0, at whatever kc is
# currently set in the kernel.
parr = np.linspace(0.1, 2, 1000)
plt.plot(parr, cel(kc, parr, 1.0, 1.0))

In [None]:
# First term of the p-derivative numerator: cel with lam in the b slot.
cel(kc, p, 0.0, lam)

In [None]:
# Same term with lam pulled outside the call; compare with cel(kc, p, 0.0, lam)
# to check that cel scales linearly in b.
cel(kc, p, 0.0, 1.0) * lam

In [None]:
# Counterpart with lam in the a slot and b = 0.
cel(kc, p, lam, 0.0)

In [None]:
# Same quantity with lam pulled outside the call; compare with
# cel(kc, p, lam, 0.0) to check that cel scales linearly in a.
cel(kc, p, 1.0, 0.0) * lam

In [None]:
# Numerator of the candidate p-derivative, before dividing by
# 2 * p * (1 - p) * (p - kc**2).
cel(kc, p, 0.0, lam) + (b - a * p) * cel(kc, 1.0, 1.0 - p, kc * kc - p)

In [None]:
# The same numerator with lam factored out of the first cel call; compare with
# the unfactored form.
lam * cel(kc, p, 0.0, 1.0) + (b - a * p) * cel(kc, 1.0, 1.0 - p, kc * kc - p)