In [2]:
import numpy as np
import jax.numpy as jnp
import jax
import math

import matplotlib.pyplot as plt
from rikabplotlib.plot_utils import newplot


In [3]:

# Necessary for jax to work with 0^0
def build_powers(base, length):
    """Return the 1-D array [1, base, base**2, ..., base**(length - 1)].

    The powers are accumulated by repeated multiplication instead of a
    power primitive, so entry 0 is the constant 1 regardless of ``base``
    (including ``base == 0``).

    Parameters
    ----------
    base   : scalar
             Number whose successive powers are generated.
    length : int
             Number of entries in the returned array.

    Returns
    -------
    jnp.ndarray of shape (length,)
    """

    def multiply_previous(k, powers):
        # powers[k] = powers[k - 1] * base
        previous = powers[k - 1]
        return powers.at[k].set(previous * base)

    # Seed array: all zeros except a leading 1 (the zeroth power).
    seed = jnp.zeros((length,)).at[0].set(1.0)

    # Fill slots 1 .. length - 1 in increasing order.
    return jax.lax.fori_loop(1, length, multiply_previous, seed)

@jax.jit
def polynomial_f(t, alpha, params):
    """Evaluate the polynomial sum_{m,n} params[m, n] * alpha**m * t**n.

    Parameters
    ----------
    t      : scalar
             Value of the second polynomial variable (columns of ``params``).
    alpha  : scalar
             Value of the first polynomial variable (rows of ``params``).
    params : 2-D array of shape (M, N)
             Coefficient table; row m multiplies alpha**m and column n
             multiplies t**n.

    Returns
    -------
    Scalar value of the polynomial at (t, alpha).
    """
    n_alpha_terms, n_t_terms = params.shape

    # Monomial bases: [1, alpha, ..., alpha^(M-1)] and [1, t, ..., t^(N-1)].
    alpha_basis = build_powers(alpha, n_alpha_terms)
    t_basis = build_powers(t, n_t_terms)

    # Contract the alpha index first, then the t index.
    contracted_over_alpha = alpha_basis @ params
    return contracted_over_alpha @ t_basis
    


def t_from_x(x):
    """Map x to t = log(1 / x)."""
    inverse_x = 1 / x
    return jnp.log(inverse_x)


def construct_cdf(function):
    """Wrap ``function(t, alpha, params)`` into ``cdf(x, alpha, params)``.

    The returned callable evaluates ``exp(-function(t, alpha, params))`` at
    ``t = t_from_x(x)`` and passes the result through ``jnp.nan_to_num``.
    """

    def cdf(x, alpha, params):
        exponent = function(t_from_x(x), alpha, params)
        value = jnp.exp(-exponent)
        return jnp.nan_to_num(value)

    return cdf



def construct_pdf(function):
    """Build ``pdf(x, alpha, params)`` as the x-derivative of the CDF.

    The CDF comes from ``construct_cdf(function)``; it is differentiated with
    respect to its first argument via ``jax.grad`` and the derivative is
    passed through ``jnp.nan_to_num``.
    """
    cdf_fn = construct_cdf(function)
    d_cdf_dx = jax.grad(cdf_fn, argnums=0)

    def pdf(x, alpha, params):
        slope = d_cdf_dx(x, alpha, params)
        return jnp.nan_to_num(slope)

    return pdf



def taylor_expand_in_alpha(function, order):
    """Build a truncated Taylor series of ``function`` in its second argument.

    ``function(x, alpha, params)`` is differentiated ``order`` times with
    respect to ``alpha`` (argnums=1). The returned callable evaluates those
    derivatives at alpha = 1e-16 and sums
    ``derivative_k / k! * alpha**k`` for k = 0..order.
    """
    # A non-positive order keeps only the zeroth-order term.
    n_derivatives = order if order > 0 else 0

    derivative_fns = [function]
    for _ in range(n_derivatives):
        derivative_fns.append(jax.grad(derivative_fns[-1], argnums=1))

    def taylor_expansion(x, alpha, params):
        expansion_point = 1e-16
        values = jnp.array([d(x, expansion_point, params) for d in derivative_fns])

        orders = jnp.arange(len(values))
        # k! computed as Gamma(k + 1)
        factorials = jax.scipy.special.gamma(orders + 1)
        monomials = jnp.power(alpha, orders)

        return jnp.sum(values / factorials * monomials)

    return taylor_expansion



def taylor_expand_in_t(function, order):
    """Build a truncated Taylor series of ``function`` in its first argument.

    ``function`` is differentiated ``order`` times with respect to its first
    argument (argnums=0). The returned callable evaluates those derivatives
    with the first argument set to 1e-16 and sums
    ``derivative_k / k! * x**k`` for k = 0..order.
    """
    # A non-positive order keeps only the zeroth-order term.
    n_derivatives = order if order > 0 else 0

    derivative_fns = [function]
    for _ in range(n_derivatives):
        derivative_fns.append(jax.grad(derivative_fns[-1], argnums=0))

    def taylor_expansion(x, alpha, params):
        expansion_point = 1e-16
        values = jnp.array([d(expansion_point, alpha, params) for d in derivative_fns])

        orders = jnp.arange(len(values))
        # k! computed as Gamma(k + 1)
        factorials = jax.scipy.special.gamma(orders + 1)
        monomials = jnp.power(x, orders)

        return jnp.sum(values / factorials * monomials)

    return taylor_expansion


In [4]:
# Quick sanity checks of polynomial_f and its repeated t-derivatives.

# params of shape (1, 2) with both entries 1 encodes 1 + t, evaluated at t = 5.
print(polynomial_f(5, 5, jnp.array([[1,1]])))

# First derivative with respect to t (argnums=0), evaluated at t = alpha = 0.
# params of shape (2, 1) only has a t^0 column, so there is no t dependence.
derivative = jax.grad(polynomial_f, argnums=0)
print(derivative(0.0, 0.0, jnp.array([[0,], [1,]])))

# Second t-derivative of a constant polynomial (params of shape (1, 1)).
derivative2 = jax.grad(derivative, argnums=0)
print(derivative2(0.0, 0.0, jnp.array([[1,]])))

# Third t-derivative; params of shape (1, 2) with a zero t^1 coefficient.
derivative3 = jax.grad(derivative2, argnums=0)
print(derivative3(0.0, 0.0, jnp.array([[1.0, 0.0] ])))

# Disabled fourth-derivative check: it uses integral_coeffs, which in the
# visible notebook is only assigned in a later cell.
# derivative4 = jax.grad(derivative3, argnums=0)
# print(derivative4(0.0, 0.0, integral_coeffs[:2,:2]))




No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


6.0
0.0
0.0
0.0


In [11]:
def taylor_coefficients_2d(f, t0, alpha0, M, N, params=None):
    """
    Compute the Taylor coefficients c[m, n] for the bivariate expansion of
        f(t, alpha, params)
    around (t0, alpha0):

    f_taylor(t, alpha) = sum_{m=0..M} sum_{n=0..N}
                         c[m,n] * (alpha - alpha0)^m * (t - t0)^n

    where
        c[m,n] = (1/(m! n!)) *
                 (d^(m+n) f / d alpha^m d t^n)(t0, alpha0).

    The t-derivative functions are built once by repeated ``jax.grad`` and
    reused for every alpha order, and each column of ``c`` is written in a
    single update.

    Parameters
    ----------
    f        : callable
               A function f(t, alpha, params) -> scalar.
    t0       : float
               The point in t at which we expand.
    alpha0   : float
               The point in alpha at which we expand.
    M        : int
               Max power of alpha in the expansion.
    N        : int
               Max power of t in the expansion.
    params   : any
               Additional parameters that f may depend on.

    Returns
    -------
    c : jnp.ndarray of shape (M+1, N+1)
        The coefficients c[m, n], stored in JAX's default floating dtype.
    """

    # Default floating dtype for the current JAX configuration. Asking for
    # jnp.float64 explicitly makes JAX emit a truncation warning whenever
    # 64-bit mode is disabled; the resulting dtype is the same either way.
    float_dtype = jnp.result_type(float)

    # 0) Fix params so that only (t, alpha) remain as differentiable inputs.
    def f_base(t, alpha):
        return f(t, alpha, params)

    # 1) t_derivatives[n] = d^n f_base / dt^n, each obtained from the previous
    #    one with jax.grad(..., argnums=0).
    t_derivatives = [f_base]
    for _ in range(N):
        t_derivatives.append(jax.grad(t_derivatives[-1], argnums=0))

    def alpha_derivatives_at_expansion_point(fn_t):
        """Values of (d^m / dalpha^m) fn_t(t0, alpha) at alpha0 for m = 0..M."""

        def current(a):
            return fn_t(t0, a)

        values = []
        for m_ in range(M + 1):
            if m_ > 0:
                # Advance to the next alpha derivative.
                current = jax.grad(current, argnums=0)
            values.append(current(alpha0))
        return values

    # 2) Fill the table one t-order (one column) at a time.
    c = jnp.zeros((M + 1, N + 1), dtype=float_dtype)
    for n_ in range(N + 1):
        # partials[m] = (d^(m+n_) / dt^n_ dalpha^m) f at (t0, alpha0)
        partials = jnp.array(
            alpha_derivatives_at_expansion_point(t_derivatives[n_]),
            dtype=float_dtype,
        )

        # m! * n_! as floats, so large factorial products are never handed to
        # JAX as Python integers.
        n_factorial = math.factorial(n_)
        denominators = jnp.array(
            [float(math.factorial(m_) * n_factorial) for m_ in range(M + 1)],
            dtype=float_dtype,
        )

        # 3) c[:, n_] = partials / (m! n_!)
        c = c.at[:, n_].set(partials / denominators)

    return c


def integrate_taylor_polynomial(c):
    """Integrate a 2-D polynomial coefficient table over its column variable.

    Column n of ``c`` holds the coefficient of the n-th power of the column
    variable. The antiderivative with zero integration constant has

        d[m, 0] = 0,    d[m, n] = c[m, n-1] / n   for n >= 1.

    Parameters
    ----------
    c : 2-D array-like of shape (M+1, N+1)
        Coefficient table (NumPy or JAX array).

    Returns
    -------
    d : np.ndarray of shape (M+1, N+2)
        Integrated coefficients. Floating inputs keep their dtype; integer
        inputs are promoted to a floating dtype so that c / n is not
        truncated to an integer.
    """
    coeffs = np.asarray(c)
    n_rows, n_cols = coeffs.shape

    # Promote integer tables to float; float32 / float64 are left unchanged.
    out_dtype = np.result_type(coeffs.dtype, np.float32)

    d = np.zeros((n_rows, n_cols + 1), dtype=out_dtype)

    # Column n of the result is column n-1 of the input divided by n.
    divisors = np.arange(1, n_cols + 1).astype(out_dtype)
    d[:, 1:] = coeffs / divisors

    return d


def example_f(t, alpha, params):
    """Toy test function returning alpha * t; ``params`` is accepted but unused."""
    product = alpha * t
    return product



# Maximum powers of alpha (M) and t (N) kept in the example expansion.
M = 5
N = 5
# Taylor coefficients of example_f around (t, alpha) = (0, 0); shape (M+1, N+1).
coeffs = taylor_coefficients_2d(example_f, 0.0, 0.0, M, N)
# Same polynomial integrated once in t; one more column than coeffs.
integral_coeffs = integrate_taylor_polynomial(coeffs)

def matching_coeffs(f, M, N):
    """Taylor coefficients of the truncated series for -log(1 - R).

    R(t, alpha) is the polynomial whose coefficients are the Taylor
    coefficients of ``f`` around (0, 0), integrated once in t. The logarithm
    is approximated by the partial sum log(1 - R) ~ -sum_{k=1..K} R^k / k
    with K = M + N. The result is the (M+1, N+1) table of Taylor coefficients
    of minus that partial sum around (0, 0).
    """
    density_coeffs = taylor_coefficients_2d(f, 0.0, 0.0, M, N)
    cumulative_coeffs = integrate_taylor_polynomial(density_coeffs)

    max_power = M + N

    def truncated_log(t, alpha, params):
        # ``params`` is unused; it is only there to match the
        # f(t, alpha, params) signature.
        x = polynomial_f(t, alpha, cumulative_coeffs)

        # -x - x^2/2 - ... - x^K/K, the partial sum of log(1 - x)
        series_terms = []
        for k in range(1, max_power + 1):
            series_terms.append(x**k / k)
        return -1 * jnp.sum(jnp.array(series_terms))

    log_coeffs = taylor_coefficients_2d(truncated_log, 0.0, 0.0, M, N)
    return -1 * log_coeffs

# Disabled: lower-order matching run.
# matched_coeffs = matching_coeffs(example_f, 3, 3)

# Show the raw Taylor table, its t-integral, and the matched table for M = N = 4.
print(coeffs)
print(integral_coeffs)
print(matching_coeffs(example_f, 4, 4))



  c = jnp.zeros((M+1, N+1), dtype=jnp.float64)


[[0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]
[[0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.5 0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0.  0. ]]
[[ 0.    -0.    -0.    -0.    -0.   ]
 [-0.    -0.     0.5   -0.    -0.   ]
 [-0.    -0.    -0.    -0.     0.125]
 [-0.    -0.    -0.    -0.    -0.   ]
 [-0.    -0.    -0.    -0.    -0.   ]]


In [None]:
alpha = 0.118

# Evaluation grid on [0, 1]; note that the endpoint x = 0 is included.
x = jnp.linspace(0, 1, 10000)
t = t_from_x(x)
# Matched coefficients with alpha powers 0..4 (rows) and t powers 0..8 (columns).
matched_coeffs = matching_coeffs(example_f, 4, 8)
print(matched_coeffs)

# PDF built from the polynomial exponent, vectorised over x only.
pdf = construct_pdf(polynomial_f)
pdf = jax.vmap(pdf, in_axes=(0, None, None))

fig, ax = newplot("full")


# Reference curve alpha * log(1/x) / x (the density named in the legend title).
ax.plot(x, alpha * jnp.log(1/x) / x, color = 'black', ls = "-")
# ax.plot(x, alpha * jnp.log(1/x) / x * jnp.exp( -alpha / 2 * jnp.log(1/x)**2), color = 'grey', ls = "-")



# Row m of matched_coeffs multiplies alpha^m, so keeping the first k + 1 rows
# truncates the exponent at O(alpha^k).
# ax.plot(x, polynomial_f(t, alpha, coeffs) / x, label="Original", color = "black")
# ax.plot(x, pdf(x, alpha, matched_coeffs[:4 + 1]), label=r"Matched $\mathcal{O}(\alpha_s^4)$", color = "red")
ax.plot(x, pdf(x, alpha, matched_coeffs[:3 + 1]), label=r"Matched $\mathcal{O}(\alpha_s^3)$", color = "red", alpha = 0.75)
ax.plot(x, pdf(x, alpha, matched_coeffs[:2 + 1]), label=r"Matched $\mathcal{O}(\alpha_s^2)$", color = "red", alpha = 0.5)
ax.plot(x, pdf(x, alpha, matched_coeffs[:1 + 1]), label=r"Matched $\mathcal{O}(\alpha_s^1)$", color = "red", alpha = 0.25)




plt.yscale("log")
plt.ylim(1e-3, 1e3)

plt.legend(title=r"$p(x|\alpha) =\alpha\log(1/x)/x$")


# Plot the CDFs
cdf = construct_cdf(polynomial_f)
cdf = jax.vmap(cdf, in_axes=(0, None, None))

fig, ax = newplot("full")
# Same explicit row slices as the PDF figure, so each label matches the
# truncation order that is actually plotted.
# ax.plot(x, polynomial_f(t, alpha, integral_coeffs), label="Original", color = "black")
ax.plot(x, cdf(x, alpha, matched_coeffs[:3 + 1]), label=r"Matched $\mathcal{O}(\alpha_s^3)$", color = "red")
ax.plot(x, cdf(x, alpha, matched_coeffs[:2 + 1]), label=r"Matched $\mathcal{O}(\alpha_s^2)$", color = "red", alpha = 0.5)
ax.plot(x, cdf(x, alpha, matched_coeffs[:1 + 1]), label=r"Matched $\mathcal{O}(\alpha_s^1)$", color = "red", alpha = 0.25)

# ax.plot(x, 1- alpha * jnp.log(1/x) ** 2, color = 'black', ls = "-")
# Closed form exp(-alpha/2 * log(1/x)^2), drawn for comparison.
ax.plot(x, jnp.exp( -alpha / 2 * jnp.log(1/x)**2), color = 'pink', ls = "--")

# The curves above carry labels, so draw the legend on this figure as well.
plt.legend()

  c = jnp.zeros((M+1, N+1), dtype=jnp.float64)
