In [321]:
from jax import grad, jacfwd, jacrev, jit
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg, bicgstab
import plotly.express as px


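`hvp` computes a (possibly mixed) Hessian-vector product of a two-argument scalar function $f(x, y)$ with a vector $v$. It returns $\nabla_{a_2}\langle \nabla_{a_1} f(x, y), v\rangle = D^2_{a_2 a_1} f \, v$, where `a1` and `a2` select the argument ($0$ for $x$, $1$ for $y$) that each derivative is taken with respect to. For example, `hvp(f, x, y, v, 1, 0)` is $D^2_{xy} f \, v$.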


In [322]:
def hvp(f, x, y, v, a1, a2):
    """Mixed Hessian-vector product of a two-argument scalar function.

    Returns grad_{a2} <grad_{a1} f(x, y), v>, i.e. D^2_{a2 a1} f(x, y) applied
    to ``v``, where ``a1``/``a2`` are argument indices (0 -> x, 1 -> y).

    Args:
        f: Scalar-valued function ``f(x, y)``.
        x, y: Point at which the second derivatives are evaluated.
        v: Vector with as many elements as argument ``a1``
            (``jnp.vdot`` flattens both of its operands).
        a1: Index of the argument differentiated first (paired with ``v``).
        a2: Index of the argument differentiated second.

    Returns:
        Array with the shape of argument ``a2``.
    """
    inner_grad = grad(f, argnums=a1)

    def directional_derivative(xx, yy):
        # Scalar <grad_{a1} f(xx, yy), v>; differentiating it with respect to
        # argument a2 applies the mixed second-derivative matrix to v.
        return jnp.vdot(inner_grad(xx, yy), v)

    outer_grad = grad(directional_derivative, argnums=a2)
    return outer_grad(x, y)


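Competitive gradient descent (CGD) (Schäfer & Anandkumar, 2019) is an algorithm for competitive optimization that generalizes gradient descent to the two-player setting. In this setting, player $x$ minimizes $f(x, y)$ and player $y$ minimizes $g(x, y)$. The algorithm takes a game-theoretic approach: a solution is a Nash equilibrium, a point at which neither player can improve their loss by unilaterally changing their strategy. Each CGD step plays the Nash equilibrium of a regularized bilinear local approximation of the game, which yields the updates
$$\Delta x = -\eta \left(I - \eta^2 D^2_{xy} f \, D^2_{yx} g\right)^{-1}\left(\nabla_x f - \eta \, D^2_{xy} f \, \nabla_y g\right),$$
$$\Delta y = -\eta \left(I - \eta^2 D^2_{yx} g \, D^2_{xy} f\right)^{-1}\left(\nabla_y g - \eta \, D^2_{yx} g \, \nabla_x f\right).$$
The implementation below solves these linear systems matrix-free, using `hvp` to apply the mixed second-derivative matrices.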


In [323]:
def CGD(f, g, x0, y0, n=0.2, iterations=50, solver=cg):
    """Run competitive gradient descent (CGD) on the game defined by ``f`` and ``g``.

    Player x minimizes ``f(x, y)`` and player y minimizes ``g(x, y)``. Each step
    solves, matrix-free,

        (I - n^2 D_xy f D_yx g) dx = grad_x f - n D_xy f grad_y g
        (I - n^2 D_yx g D_xy f) dy = grad_y g - n D_yx g grad_x f

    and then updates ``x <- x - n dx`` and ``y <- y - n dy``.

    Args:
        f: Scalar loss of the x player, called as ``f(x, y)``.
        g: Scalar loss of the y player, called as ``g(x, y)``.
        x0: Initial x iterate.
        y0: Initial y iterate.
        n: Step size (eta).
        iterations: Number of CGD steps to take.
        solver: Matrix-free linear solver called as ``solver(A, b)`` and
            returning ``(solution, info)``, e.g. ``jax.scipy.sparse.linalg.cg``
            (the default) or ``bicgstab``. CG assumes a symmetric
            positive-definite operator. That holds for zero-sum games
            (``g = -f``), where the operators reduce to ``I + n^2 B B^T`` and
            ``I + n^2 B^T B`` with ``B = D_xy f``. For general-sum games the
            operators are in general not symmetric, so ``bicgstab`` is the
            safer choice.

    Returns:
        ``(xs, ys)``: lists holding the ``iterations + 1`` iterates of each
        player, starting with ``x0`` and ``y0``.
    """
    # The gradient functions do not depend on the iterate, so build them once.
    dfdx = grad(f, argnums=0)
    dgdy = grad(g, argnums=1)

    xs = [x0]
    ys = [y0]
    for _ in range(iterations):
        lx, ly = xs[-1], ys[-1]
        # Evaluate each gradient once per step; both linear systems reuse them.
        gx = dfdx(lx, ly)
        gy = dgdy(lx, ly)

        # hvp(f, ..., u, 1, 0) applies D_xy f; hvp(g, ..., w, 0, 1) applies D_yx g.
        rhs_x = gx - n * hvp(f, lx, ly, gy, 1, 0)

        def lhs_x(v):
            return v - n ** 2 * hvp(f, lx, ly, hvp(g, lx, ly, v, 0, 1), 1, 0)

        sol_x = solver(lhs_x, rhs_x)[0]

        rhs_y = gy - n * hvp(g, lx, ly, gx, 0, 1)

        def lhs_y(v):
            return v - n ** 2 * hvp(g, lx, ly, hvp(f, lx, ly, v, 1, 0), 0, 1)

        sol_y = solver(lhs_y, rhs_y)[0]

        # Simultaneous update: both players step from the same (lx, ly).
        xs.append(lx - n * sol_x)
        ys.append(ly - n * sol_y)
    return xs, ys


Define a simple bilinear game with functions $f(x,y)=x^Ty$ and $g(x,y)=-f(x,y)$. Run CGD for 50 iterations with step size $\eta=0.2$, starting both $x$ and $y$ at the $1 \times 1$ column vector $(1)$. The unique Nash equilibrium of this zero-sum game is at $x = y = 0$.


In [324]:
def f(x, y):
    """Loss of the x player: the [0, 0] entry of ``x.T @ y``.

    For column vectors of shape (n, 1) this is the inner product x^T y,
    returned as a scalar so that ``grad`` can differentiate it.
    """
    inner = x.T @ y
    return inner[0, 0]


def g(x, y):
    """Loss of the y player: the negation of ``f``, making the game zero-sum."""
    value = f(x, y)
    return -value


# Both players start at the 1x1 column vector [[1.]]; take 50 CGD steps.
x, y = CGD(
    f,
    g,
    x0=jnp.ones((1, 1)),
    y0=jnp.ones((1, 1)),
    n=0.2,  # step size eta
    iterations=50,
)


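We can visualize this game by plotting the CGD iterates $(x_k, y_k)$ for each iteration. They move toward the Nash equilibrium at the origin, demonstrating the convergent behavior of CGD. By contrast, simultaneous gradient descent (SimGD) spirals away from the Nash equilibrium on this game instead of toward it (SimGD is not run in this notebook).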


In [325]:
# The iterates are 2-D arrays (the run starts from (1, 1) arrays); take each
# [0, 0] entry as a plain Python float for plotting.
x_vals = [float(xi[0, 0]) for xi in x]
y_vals = [float(yi[0, 0]) for yi in y]

fig = px.scatter(
    x=x_vals,
    y=y_vals,
    # Colour points by iteration index so the direction of travel is visible.
    color=list(range(len(x_vals))),
    labels={"x": "x", "y": "y", "color": "iteration"},
    title="CGD iterates on the bilinear game f(x, y) = x^T y",
)
# Lock the aspect ratio so x and y distances are drawn at the same scale.
fig.update_yaxes(
    scaleanchor="x",
    scaleratio=1,
)
fig.show()
