In [229]:
from jax import grad, jacfwd, jacrev, jit
import jax.numpy as jnp
import plotly.express as px

In [230]:
def hvp(f, x, y, v, a1, a2):
    """Mixed second-derivative / vector product of a two-argument scalar function.

    Computes ``d/d(arg a2) <grad_{arg a1} f(x, y), v>`` at ``(x, y)``, i.e. the
    matrix of second partials ``d^2 f / d(arg a2) d(arg a1)`` applied to ``v``,
    without ever forming that matrix.

    Args:
        f: scalar-valued function called as ``f(x, y)``.
        x: first argument of ``f`` at which derivatives are evaluated.
        y: second argument of ``f`` at which derivatives are evaluated.
        v: array contracted (via ``jnp.vdot``, which flattens) with the gradient
            of ``f`` w.r.t. argument ``a1``; must have the same number of
            elements as that argument.
        a1: argument index (0 -> x, 1 -> y) of the inner differentiation.
        a2: argument index (0 -> x, 1 -> y) of the outer differentiation.

    Returns:
        An array shaped like argument ``a2``.
    """
    first_grad = grad(f, argnums=a1)

    def directional_derivative(p, q):
        # Scalar <grad_{a1} f(p, q), v>; differentiating it once more w.r.t.
        # argument a2 yields the second-derivative/vector product.
        return jnp.vdot(first_grad(p, q), v)

    return grad(directional_derivative, argnums=a2)(x, y)

In [231]:
def optimize(a, b, lr=0.01, steps=200):
    """Approximately solve ``a(t) = b`` by gradient descent on the residual norm.

    Starts from ``jnp.ones_like(b)`` and takes ``steps`` updates
    ``t <- t - lr * grad ||a(t) - b||``. Because the objective is the plain
    (non-squared) norm, the gradient has roughly constant magnitude, so the
    iterate ends up within about ``lr * |a'|`` of a solution rather than
    converging exactly.

    Args:
        a: function applied to an array shaped like ``b``; its output must be
            broadcast-compatible with ``b``.
        b: target array; also fixes the shape of the initial guess.
        lr: gradient-descent step size.
        steps: number of gradient-descent updates.

    Returns:
        The final iterate ``t``.
    """
    def error(t):
        residual = a(t) - b
        sq = jnp.sum(jnp.real(residual * jnp.conj(residual)))
        # grad of sqrt at 0 is inf, and 0 * inf = NaN, so an exact solution
        # would poison the iterate. The inner `where` keeps sqrt's input
        # positive; the outer `where` returns 0 (with zero gradient) in that case.
        nonzero = sq > 0
        return jnp.where(nonzero, jnp.sqrt(jnp.where(nonzero, sq, 1.0)), 0.0)

    # `a` and `b` are fixed for this call, so compile the gradient once and
    # reuse it for every step.
    derrordx = jit(grad(error))
    guess = jnp.ones_like(b)
    for _ in range(steps):
        guess = guess - lr * derrordx(guess)
    return guess

In [232]:
def CGD(f, g, x0, y0, n=0.2, iterations=50):
    """Competitive Gradient Descent for the game min_x f(x, y), min_y g(x, y).

    Each iteration applies the CGD update (Schaefer & Anandkumar, 2019)

        x <- x - n * (I - n^2 D_xy f D_yx g)^{-1} (grad_x f - n D_xy f grad_y g)
        y <- y - n * (I - n^2 D_yx g D_xy f)^{-1} (grad_y g - n D_yx g grad_x f)

    where ``D_xy f`` is the mixed second derivative of ``f`` (rows indexed by x,
    columns by y), applied matrix-free through ``hvp``. The two linear systems
    are solved approximately with ``optimize``.

    Args:
        f: scalar objective the x player descends, called as ``f(x, y)``.
        g: scalar objective the y player descends, called as ``g(x, y)``.
        x0: initial x iterate.
        y0: initial y iterate.
        n: step size (eta in the formulas above).
        iterations: number of CGD updates to perform.

    Returns:
        ``(x, y)``: lists of length ``iterations + 1`` holding every iterate,
        starting with ``x0`` and ``y0``.
    """
    eta = n
    # The gradient transforms do not depend on the iterate; build them once.
    dfdx = grad(f, argnums=0)
    dgdy = grad(g, argnums=1)

    x = [x0]
    y = [y0]
    for _ in range(iterations):
        lx = x[-1]
        ly = y[-1]
        grad_x_f = dfdx(lx, ly)
        grad_y_g = dgdy(lx, ly)

        # hvp(f, lx, ly, w, 1, 0) applies D_xy f to a y-shaped w (x-shaped result);
        # hvp(g, lx, ly, v, 0, 1) applies D_yx g to an x-shaped v (y-shaped result).
        rhs_x = grad_x_f - eta * hvp(f, lx, ly, grad_y_g, 1, 0)

        def lhs_x(v):
            # Coupling is scaled by eta**2 (the step size), not by the size of v.
            return v - eta ** 2 * hvp(f, lx, ly, hvp(g, lx, ly, v, 0, 1), 1, 0)

        sol_x = optimize(lhs_x, rhs_x)

        rhs_y = grad_y_g - eta * hvp(g, lx, ly, grad_x_f, 0, 1)

        def lhs_y(v):
            return v - eta ** 2 * hvp(g, lx, ly, hvp(f, lx, ly, v, 1, 0), 0, 1)

        sol_y = optimize(lhs_y, rhs_y)

        x.append(lx - eta * sol_x)
        y.append(ly - eta * sol_y)
    return x, y

In [233]:
def f(x, y):
    """Bilinear objective of the x player: the (0, 0) entry of ``x.T @ y``."""
    return (x.T @ y)[0, 0]


def g(x, y):
    """Objective of the y player: the negation of ``f`` (a zero-sum game)."""
    return -f(x, y)


x, y = CGD(
    f,
    g,
    jnp.ones((1, 1)),  # x0: a 1x1 matrix so that x.T @ y is defined
    jnp.ones((1, 1)),  # y0
    iterations=20,
)

In [234]:
# Plot the CGD trajectory in the (x, y) plane. Colour encodes the iteration
# index so the direction of travel is visible; equal axis scaling keeps the
# geometry of the path undistorted.
xs = [float(xi[0, 0]) for xi in x]
ys = [float(yi[0, 0]) for yi in y]
fig = px.scatter(
    x=xs,
    y=ys,
    color=list(range(len(xs))),
    labels={"x": "x[0, 0]", "y": "y[0, 0]", "color": "iteration"},
    title="Competitive gradient descent iterates",
)
fig.update_yaxes(
    scaleanchor="x",
    scaleratio=1,
)
fig.show()