In [1]:
# Source code for ICML submission #640 "Efficient Continuous Pareto Exploration in Multi-Task Learning"
# This script generates Figure 2 in the paper.
import numpy as np
import cvxpy as cp

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.sparse.linalg import LinearOperator, minres

from common import *
from zdt2_variant import Zdt2Variant

In [2]:
# Hyperparameters and random seeds.
#   K       -- how many random tangent directions to generate with MINRES.
#   s       -- step size used when walking along a normalized direction.
#   maxiter -- iteration cap handed to MINRES (2 is the smallest value minres_solver accepts).
K = 5
s = 0.1
maxiter = 2
# Seed NumPy's global RNG so the np.random draws below (e.g. MINRES right-hand sides) are reproducible.
np.random.seed(42)

In [3]:
# Setting up the problem: build the benchmark and read off its sizes.
problem = Zdt2Variant()
n = problem.n  # length of the vectors fed to problem.hvp / MINRES
m = problem.m  # number of objectives (one weight per gradient row)
# Draw one Pareto optimal point x0, then evaluate the objectives and the stacked gradients there.
x0 = problem.sample_pareto_set()
f0 = problem.f(x0)
g0 = problem.grad(x0)

In [4]:
# Solve alpha from gradients.
def compute_alpha(g):
    """Compute the min-norm convex combination of the objective gradients.

    Solves the QP  min_alpha ||sum_i alpha_i * g[i]||^2  s.t.  alpha >= 0, sum(alpha) == 1
    with cvxpy and sanity-checks the result.

    Args:
        g: gradients stacked row-wise, one row per objective (g[i] is the gradient of f_i).
           The number of weights is taken from len(g), so any number of objectives works.

    Returns:
        1-D array of len(g) non-negative weights that sum to 1.

    Raises:
        RuntimeError: if cvxpy does not report an optimal (or optimal-inaccurate) solution.
        AssertionError: if the weighted gradient sum is not ~0 or the weights are
            meaningfully negative.
    """
    # Largest negative weight we attribute to solver round-off rather than a real violation.
    neg_tol = 1e-8
    weights = cp.Variable(len(g))
    objective = cp.Minimize(cp.sum_squares(weights @ g))
    constraints = [weights >= 0, cp.sum(weights) == 1]
    qp = cp.Problem(objective, constraints)
    qp.solve()
    # On failure weights.value is None; report the status instead of crashing inside ndarray().
    if qp.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
        raise RuntimeError('compute_alpha: cvxpy returned status %r' % qp.status)
    alpha = ndarray(weights.value).ravel()
    # Numerical solvers satisfy alpha >= 0 only up to a tolerance (tiny negatives are common);
    # accept those, then project back onto the probability simplex exactly.
    assert np.min(alpha) >= -neg_tol
    alpha = np.clip(alpha, 0, None)
    alpha = alpha / np.sum(alpha)
    # In this synthetic example, the gradients at a Pareto optimal solution must be perfectly parallel,
    # so the alpha-weighted sum of gradients vanishes.
    assert np.allclose(alpha @ g, 0)
    assert np.isclose(np.sum(alpha), 1)
    assert np.min(alpha) >= 0
    return alpha
alpha = compute_alpha(g0)

In [5]:
# Defining a matrix-free solver.
def H_op(y):
    """Apply problem.hvp at the current point to the vector y (matrix-free operator).

    Reads the notebook globals `x0`, `alpha` and `n` at call time, so it always uses the
    point and weights that were assigned most recently. Presumably this is the Hessian of
    sum_i alpha_i f_i at x0 -- TODO confirm against problem.hvp.
    """
    y = ndarray(y).ravel()
    assert y.size == n
    return problem.hvp(x0, alpha, y)
def minres_solver(b):
    """Run at most `maxiter` MINRES iterations on H_op(x) = b and return the iterate.

    H_op is also passed as rmatvec, i.e. the operator is treated as symmetric (which MINRES
    requires). MINRES's convergence flag is discarded on purpose: with a tiny iteration cap
    the partial solution is what this notebook uses.
    """
    # Pass the message string itself: `print(...)` returns None, which left the
    # AssertionError without a message.
    assert maxiter >= 2, 'Directly applying MINRES would require >= 2 iterations to gain better results.'
    operator = LinearOperator((n, n), matvec=H_op, rmatvec=H_op)
    x, _ = minres(operator, b, maxiter=maxiter)
    return ndarray(x)

In [6]:
# Visualization.
%matplotlib tk
# Collect all the directions we would like to plot: (direction, color, label, s_min, s_max)
# We will plot f(x0 + s * normalize(direction)) with s \in [s_min, s_max]
vi = []
for i in range(K):
    rhs = np.random.normal(size=n)
    vi.append((minres_solver(rhs), 'tab:orange', 'MINRES' if i == 0 else None, -s, s))
vi.append((g0[0], 'tab:blue', r'$\nabla f_1(x^*)$', 0, s))
vi.append((g0[1], 'tab:green', r'$\nabla f_2(x^*)$', 0, s))

# Plot the Pareto set.
fig_ps = plt.figure(figsize=(15, 15))
ax_ps = fig_ps.add_subplot(111, projection='3d', proj_type='ortho')
problem.plot_pareto_set(ax_ps)
ax_ps.scatter(x0[1], x0[2], x0[0], c='tab:red', s=50)
for idx, (v, color, _, _, _) in enumerate(vi):
    draw_arrow_3d(ax_ps, x0 + v / np.linalg.norm(v), x0, color)
    ax_ps.set_xlim([-2, 2])
    ax_ps.set_ylim([-2, 2])
plt.show()

# Plot the Pareto front.
fig_pf = plt.figure(figsize=(5, 5))
ax_pf = fig_pf.add_subplot(1, 1, 1)
ax_pf.scatter(f0[0], f0[1], c='tab:red', label='Pareto optimal $f(x^*)$', s=100)
problem.plot_pareto_front(ax_pf, label='Pareto front')
for idx, (v, color, label, s0, s1) in enumerate(vi):
    fi = [problem.f(x0 + si * v / np.linalg.norm(v)) for si in np.linspace(s0, s1, 21)]
    fi = ndarray(fi)
    ax_pf.plot(fi[:, 0], fi[:, 1], c=color, label=label, linewidth=3)
ax_pf.legend()
plt.show()

In [7]:
# Now generate more results with different random seeds.
# vi[k] = (direction, color, label, s_min, s_max, transparency); xi[k] is the base point of vi[k];
# fi collects f(x0) for every sampled Pareto point.
vi = []
xi = []
fi = []
for seed in range(40):
    np.random.seed(seed)
    # Generate x0.
    x0 = problem.sample_pareto_set()
    f0 = problem.f(x0)
    g0 = problem.grad(x0)
    # minres_solver -> H_op evaluates problem.hvp(x0, alpha, .) with the *global* x0 and alpha,
    # so alpha must be refreshed together with x0. Previously it was stored in an unused
    # `alpha0`, and every solve mixed the new point with the alpha from the first cell.
    alpha = compute_alpha(g0)
    fi.append(f0)
    # Generate 1 direction for each x0.
    rhs = np.random.normal(size=n)
    first = seed == 0  # label each series only once so the legend has one entry per kind
    vi.append((minres_solver(rhs), 'tab:orange', 'MINRES' if first else None, -s, s, 1.0))
    vi.append((g0[0], 'tab:blue', r'$\nabla f_1(x^*)$' if first else None, 0, s, 0.1))
    vi.append((g0[1], 'tab:green', r'$\nabla f_2(x^*)$' if first else None, 0, s, 0.1))
    # All three directions above start from the same x0.
    xi.extend([x0] * 3)
fi = ndarray(fi)

In [8]:
# Plot more results. This is the figure in the supplemental material.
fig_pf = plt.figure(figsize=(5, 5))
ax_pf = fig_pf.add_subplot(1, 1, 1)
ax_pf.scatter(fi[:, 0], fi[:, 1], c='tab:red', label='Pareto optimal $f(x^*)$', s=20)
problem.plot_pareto_front(ax_pf, label='Pareto front')
# Use fresh loop names: rebinding `fi` here overwrote the sampled Pareto values, so re-running
# this cell scattered the last traced curve instead of the Pareto points.
for x_base, (v, color, label, s0, s1, a) in zip(xi, vi):
    unit_v = v / np.linalg.norm(v)
    curve = ndarray([problem.f(x_base + si * unit_v) for si in np.linspace(s0, s1, 21)])
    ax_pf.plot(curve[:, 0], curve[:, 1], c=color, label=label, linewidth=3, alpha=a)
ax_pf.legend()
plt.show()