In [1]:
# Source code for ICML submission #640 "Efficient Continuous Pareto Exploration in Multi-Task Learning"
# This script generates Figure 2 in the paper.
import numpy as np
import cvxpy as cp

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.sparse.linalg import LinearOperator, lsmr

from common import *
from zdt2_variant import Zdt2Variant

In [2]:
# Experiment configuration: tunable hyperparameters, followed by the NumPy RNG seed.
maxiter = 1         # Iteration cap handed to the iterative solver (MINRES, realized via LSMR).
s = 0.1             # Step size.
K = 5               # How many tangent directions to generate.
np.random.seed(42)  # Fix NumPy's global RNG so that runs are repeatable.

In [3]:
# Instantiate the ZDT2-variant problem and read off its two size attributes.
problem = Zdt2Variant()
n = problem.n
m = problem.m
# Draw an initial Pareto optimal point, then evaluate the objectives and their gradients there.
x0 = problem.sample_pareto_set()
f0, g0 = problem.f(x0), problem.grad(x0)

In [4]:
# Solve for alpha at x0: minimize ||alpha^T g0||^2 over the probability simplex
# (alpha >= 0 element-wise, entries summing to 1).
alpha = cp.Variable(m)
objective = cp.Minimize(cp.sum_squares(alpha.T @ g0))
constraints = [alpha >= 0, cp.sum(alpha) == 1]
alpha_prob = cp.Problem(objective, constraints)
optimal_loss = alpha_prob.solve()
# From this point on `alpha` holds the flattened numeric solution instead of the cvxpy variable.
alpha = ndarray(alpha.value).ravel()
# Sanity checks on the solution. In this synthetic example the two gradients at a Pareto optimal
# point must be perfectly parallel, so their alpha-weighted sum has to vanish.
assert np.allclose(alpha[0] * g0[0] + alpha[1] * g0[1], 0)
assert np.isclose(alpha.sum(), 1)
assert alpha.min() >= 0

In [5]:
# Defining a matrix-free solver.
def At_op(y):
    """Forward operator, registered as `matvec` of the LinearOperator built in `lsmr_solver`.

    Args:
        y: vector with exactly `n` entries (flattened before use).

    Returns:
        The result of `problem.hvp(x0, alpha, y)` stacked on top of `g0 @ y`. With
        A = [H, g0.T] as in the notes in this cell, this corresponds to A.T @ y
        (presumably `hvp` is a Hessian-vector product -- confirm against `Zdt2Variant`).
    """
    vec = ndarray(y).ravel()
    assert vec.size == n
    hessian_part = problem.hvp(x0, alpha, vec)
    gradient_part = g0 @ vec
    return np.hstack([hessian_part, gradient_part])
def A_op(y):
    """Adjoint operator, registered as `rmatvec` of the LinearOperator built in `lsmr_solver`.

    Args:
        y: vector with exactly `n + m` entries (flattened before use).

    Returns:
        `problem.hvp(x0, alpha, y[:n]) + y[n:].T @ g0`: the first `n` entries go through
        the Hessian-vector product and the remaining `m` entries weight the rows of `g0`.
    """
    vec = ndarray(y).ravel()
    assert vec.size == n + m
    head, tail = vec[:n], vec[n:]
    return problem.hvp(x0, alpha, head) + tail.T @ g0
# The way we use MINRES here differs from what we describe in the paper because of some peculiarities of this ZDT2-variant
# example. Instead of applying MINRES to Hv = g0.T @ beta, where beta is randomly generated, we apply MINRES
# to the following problem: Let A = [H, g0.T] so Hv = g0.T @ beta equals finding a null vector of A. We then solve the
# linear least square problem below:
# min_x \|A.T @ x - b\|
# where b is randomly generated. This way, the solution x must satisfy the normal equation A @ A.T x = A @ b, and we
# apply MINRES to this equation. Once x is found, we can recover v by (A.T @ x - b)[:n].
# This mathematical trick for finding a null vector by solving a symmetric positive-semidefinite linear system was
# introduced in the following papers:
#
# "MINRES-QLP: A Krylov subspace method for indefinite or singular symmetric systems"
# "LSMR: An iterative algorithm for sparse least-squares problems"
#
# We didn't mention this method in our main paper because we did not use this trick in any other example. We found that
# directly applying MINRES to Hv = g0.T @ beta seems to work better on neural networks while LSMR is more suitable for
# tiny problems like ZDT2-variant. Note that LSMR is analytically equivalent to applying MINRES to solve the normal
# equation above, so we still call our method MINRES in the main paper.
def lsmr_solver(b):
    """Run LSMR on the matrix-free operator and return the leading part of the residual.

    Args:
        b: right-hand side with `n + m` entries (the row count declared for the operator).

    Returns:
        The first `n` entries of `At_op(x) - b`, passed through `ndarray`, where `x` is the
        LSMR iterate obtained after at most `maxiter` iterations.
    """
    operator = LinearOperator((n + m, n), matvec=At_op, rmatvec=A_op)
    # lsmr returns a tuple; its first element is the solution vector.
    x = lsmr(operator, b, maxiter=maxiter)[0]
    residual = At_op(x) - b
    return ndarray(residual[:n])

In [6]:
# Visualization.
%matplotlib tk
# Directions to plot, stored as tuples (direction, color, label, s_min, s_max).
# Each one is later drawn as f(x0 + s * normalize(direction)) for s in [s_min, s_max].
vi = []
for i in range(K):
    # A fresh random right-hand side per direction.
    rhs = np.random.normal(size=n + m)
    # Only the first solver direction carries a legend label; the others pass None.
    minres_label = 'MINRES' if i == 0 else None
    vi.append((lsmr_solver(rhs), 'tab:orange', minres_label, -s, s))
# The two objective gradients are only followed in the positive direction (s from 0 to s).
vi.extend([
    (g0[0], 'tab:blue', r'$\nabla f_1(x^*)$', 0, s),
    (g0[1], 'tab:green', r'$\nabla f_2(x^*)$', 0, s),
])

# Plot the Pareto set in a 3D axes with an orthographic projection.
fig_ps = plt.figure(figsize=(15, 15))
ax_ps = fig_ps.add_subplot(111, projection='3d', proj_type='ortho')
problem.plot_pareto_set(ax_ps)
# x0 is drawn with its components permuted: (x0[1], x0[2], x0[0]) are passed as the plot's (x, y, z).
ax_ps.scatter(x0[1], x0[2], x0[0], c='tab:red', s=50)
# Each collected direction is normalized before being handed to draw_arrow_3d; the label and
# step-range fields of the tuple are not needed for this figure.
for v, color, _label, _s_min, _s_max in vi:
    draw_arrow_3d(ax_ps, x0 + v / np.linalg.norm(v), x0, color)
# The axis limits do not depend on the arrow being drawn, so set them once after the loop
# instead of once per arrow.
ax_ps.set_xlim([-2, 2])
ax_ps.set_ylim([-2, 2])
plt.show()

# Plot the Pareto front in objective space (f_1 on the horizontal axis, f_2 on the vertical axis).
fig_pf = plt.figure(figsize=(5, 5))
ax_pf = fig_pf.add_subplot(1, 1, 1)
ax_pf.scatter(f0[0], f0[1], c='tab:red', label='Pareto optimal $x^*$', s=100)
problem.plot_pareto_front(ax_pf, label='Pareto front')
for v, color, label, s0, s1 in vi:
    # Trace f along the normalized direction at 21 evenly spaced steps in [s0, s1].
    v_norm = np.linalg.norm(v)
    fi = ndarray([problem.f(x0 + si * v / v_norm) for si in np.linspace(s0, s1, 21)])
    # The keyword must be the lowercase `linewidth`: Matplotlib no longer normalizes mixed-case
    # property names, so `lineWidth` is rejected as an unknown Line2D property.
    ax_pf.plot(fi[:, 0], fi[:, 1], c=color, label=label, linewidth=3)
ax_pf.legend()
plt.show()