In [0]:
!apt install imagemagick &>/dev/null

In [0]:
import torch
import numpy as np

from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.layouts import gridplot
from bokeh.resources import CDN
from bokeh.embed import file_html

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import seaborn as sns

In [0]:
# Point minimising the auxiliary objective.
c = torch.tensor([2., 0.])


def primary_function(x):
  """Squared difference between the Euclidean norm of `x` and 1."""
  return (torch.sqrt((x**2).sum())-1)**2


def auxiliary_function(x):
  """Squared Euclidean distance between `x` and `c`."""
  return ((x-c)**2).sum()


def stopping_criterion(x):
  """True when the squared distance from `x` to (1, 0) is below 1e-4."""
  offset = x.detach().numpy()-np.array([1, 0])
  return np.sum(offset**2) < 1e-4


# Names of the gradient-adjustment strategies understood by `censored_vector`.
modes = ['Single task', 'Multi-task', 'Projection',
         'Unweighted cosine', 'Weighted cosine', 'Orthogonal']

In [0]:
def censored_vector(u, v, mode='Projection'):
  """Adjusts the auxiliary loss gradient

  Adjusts the auxiliary loss gradient before adding it to the primary loss
  gradient and using a gradient descent-based method

  Args:
    u: A PyTorch tensor representing the auxiliary loss gradient
    v: A PyTorch tensor representing the primary loss gradient
    mode: The method used for the adjustment:
      - Single task: the auxiliary loss gradient is ignored
      - Multi-task: the auxiliary loss gradient is kept as it is
      - Unweighted cosine: cf. https://arxiv.org/abs/1812.02224
      - Weighted cosine: cf. https://arxiv.org/abs/1812.02224
      - Orthogonal: https://arxiv.org/abs/1801.07593
      - Projection: cf. ICML submission
      - Parameter-wise: keeps only the components of u whose sign agrees
        with the corresponding component of v (components where the product
        is exactly zero are halved)

  Returns:
    A PyTorch tensor with the same shape as u representing the adjusted
    auxiliary loss gradient. If either gradient has zero norm, u is returned
    unchanged for every mode other than 'Single task'.

  Raises:
    ValueError: if mode is not one of the methods listed above (and both
      gradients have a non-zero norm).
  """
  if mode == 'Single task':
    # Return a tensor (not the int 0) so that every mode yields the same type.
    return torch.zeros_like(u)
  if mode == 'Multi-task':
    return u
  l_u, l_v = torch.norm(u), torch.norm(v)
  # Compare the tensors directly: converting with .numpy() would fail for
  # tensors that track gradients or do not live on the CPU.
  if l_u == 0 or l_v == 0:
    return u
  u_dot_v = (u*v).sum()
  if mode == 'Unweighted cosine':
    return u if u_dot_v > 0 else torch.zeros_like(u)
  if mode == 'Weighted cosine':
    # max(u.v, 0) * u / (|u| |v|)
    return torch.clamp(u_dot_v, min=0)*u/l_u/l_v
  if mode == 'Projection':
    # Remove the component of u along v only when the gradients conflict.
    return u - torch.clamp(u_dot_v, max=0)*v/l_v/l_v
  if mode == 'Orthogonal':
    # Always remove the component of u along v.
    return u - u_dot_v*v/l_v/l_v
  if mode == 'Parameter-wise':
    return u*((torch.sign(u*v)+1)/2)
  # Previously an unknown mode fell through and silently returned None.
  raise ValueError('Unknown mode: %r' % (mode,))

In [0]:
def train(primary_function,
          auxiliary_function,
          optimizer,
          mode='Projection',
          start=[0, 0],
          lam=0.1,
          iterations=5000,
          stopping_criterion=None,
          patience=10):
  """Minimises the primary function with help from an auxiliary function

  At every step the gradient of the auxiliary function is adjusted by
  `censored_vector`, scaled by `lam` and added to the gradient of the primary
  function before the optimizer takes a step.

  Args:
    primary_function: callable mapping the parameter tensor to a scalar loss
    auxiliary_function: callable mapping the parameter tensor to a scalar loss
    optimizer: optimizer class (e.g. torch.optim.Adam); it is instantiated
      here on the single parameter tensor with lr=0.01
    mode: adjustment method, forwarded to `censored_vector`
    start: initial parameter values (not modified)
    lam: weight of the adjusted auxiliary gradient
    iterations: maximum number of optimization steps
    stopping_criterion: optional callable taking the parameter tensor and
      returning a truthy value once the target is considered reached
    patience: number of consecutive steps for which the stopping criterion
      must hold before training stops early

  Returns:
    A NumPy array with one row per visited point: the starting point followed
    by the parameters after each step. When training stops early because of
    the stopping criterion, a tuple (array, i) is returned instead, where i is
    the index of the last step performed.
  """
  x = torch.tensor(start, requires_grad=True, dtype=torch.float)
  # .numpy() shares memory with the tensor, which the optimizer updates in
  # place; copy each snapshot so that earlier rows (notably the starting
  # point) are not overwritten by later steps.
  trajectory = [x.detach().numpy().copy()]
  optimizer = optimizer([x], lr=0.01)

  criterion_met = 0
  for i in range(iterations):

    optimizer.zero_grad()

    primary_loss = primary_function(x)
    auxiliary_loss = auxiliary_function(x)

    # Compute the auxiliary gradient on its own and stash it, then clear
    # x.grad so that the primary gradient is not accumulated on top of it.
    auxiliary_loss.backward()
    x.auxiliary_grad = x.grad.detach().clone()
    x.grad = None

    primary_loss.backward()
    x.grad.add_(lam*censored_vector(x.auxiliary_grad, x.grad, mode))

    optimizer.step()

    # Collect in a list and stack once at the end instead of re-allocating the
    # whole history with np.concatenate at every step.
    trajectory.append(x.detach().numpy().copy())

    if stopping_criterion is not None:
      if stopping_criterion(x):
        criterion_met += 1
        if criterion_met == patience:
          return np.stack(trajectory), i
      else:
        # The criterion must hold on consecutive steps: start counting again.
        criterion_met = 0

  return np.stack(trajectory)

## Matplotlib figures

In [0]:
def sign_dot_product(x):
  """Sign of the dot product between two direction vectors at point x

  The first direction is x itself when x lies strictly inside the unit circle
  and -x otherwise; the second is (2, 0) - x. Up to positive scaling these
  appear to be the descent directions of the notebook's primary and auxiliary
  functions.

  Args:
    x: a 2D point (any indexable pair of numbers)

  Returns:
    1 if the dot product is strictly positive, -1 otherwise
  """
  point = np.array(x)
  inside_unit_circle = x[0]**2+x[1]**2 < 1
  if inside_unit_circle:
    primary_direction = point * 1
  else:
    primary_direction = point * -1
  auxiliary_direction = np.array([2, 0]) - point
  if primary_direction.dot(auxiliary_direction) > 0:
    return 1
  return -1

def animation(trajectories, steps, filename, edge=False):
  """Renders optimization trajectories as an animated GIF

  Draws one line and one moving marker per trajectory on top of two grey
  discs of radius 1 centred at (0, 0) and (1, 0), and saves the animation
  with the 'imagemagick' writer. One frame is produced every 100 steps.

  Args:
    trajectories: sequence of arrays with one (x, y) row per optimization
      step; entry k is labelled with the global `modes[k]` and coloured with
      the k-th of six fixed colours
    steps: frames are generated for step indices 1, 101, 201, ... below this
      value
    filename: path of the file the animation is written to
    edge: if True, the marker edge is drawn in black whenever
      `sign_dot_product` returns -1 at the current point, and in the
      trajectory's own colour otherwise
  """
  sns.set_style("white")
  sns.set_style("ticks")
  fig = plt.figure()
  ax = fig.gca()
  ax.set_adjustable("box")
  sns.despine()
  plt.gca().set_aspect('equal')
  plt.ylim((0, 2))
  plt.xlim(-1, 2)
  x_ticks, y_ticks = np.arange(-1, 2.5, .5), np.arange(0, 2.5, .5)
  x_labels = ['-1', '-0.5', '0', '0.5', '1', '1.5', '2']
  y_labels = ['0', '0.5', '1', '1.5', '2']
  ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(x_ticks))
  ax.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(y_ticks))
  ax.xaxis.set_major_formatter(matplotlib.ticker.FixedFormatter(x_labels))
  ax.yaxis.set_major_formatter(matplotlib.ticker.FixedFormatter(y_labels))

  circle1 = plt.Circle((0, 0), 1, color='grey', alpha=0.2, zorder=0)
  circle2 = plt.Circle((1, 0), 1, color='grey', alpha=0.2, zorder=0)
  ax.add_artist(circle1)
  ax.add_artist(circle2)

  # Reorder the first six tab10 colours so that index k matches modes[k].
  palette = sns.color_palette("tab10", n_colors=10)
  palette = [palette[1], palette[0], palette[2], palette[3], palette[4], palette[5]]

  # Start every line and marker at the first recorded point of its trajectory.
  lines, points = [], []
  for mode in range(len(trajectories)):
    lines.append(ax.plot(trajectories[mode][:1, 0],
                         trajectories[mode][:1, 1],
                         color=palette[mode],
                         label=modes[mode])[0])
    points.append(ax.plot(trajectories[mode][:1, 0][-1],
                          trajectories[mode][:1, 1][-1],
                          color=palette[mode],
                          marker='o',
                          markeredgewidth=1.0)[0])

  plt.legend(frameon=False)
  plt.tight_layout()

  def update(i):
    """Extends every line to step i and moves its marker to the last point."""
    for mode in range(len(trajectories)):
      lines[mode].set_xdata(trajectories[mode][:i, 0])
      lines[mode].set_ydata(trajectories[mode][:i, 1])

      x = trajectories[mode][:i, :][-1, :]
      # Line2D expects sequences: passing bare scalars to set_xdata/set_ydata
      # is deprecated in recent Matplotlib releases and later rejected.
      points[mode].set_xdata([x[0]])
      points[mode].set_ydata([x[1]])
      if edge:
        if sign_dot_product(x) == -1:
          points[mode].set_markeredgecolor('black')
        else:
          points[mode].set_markeredgecolor(palette[mode])
    return lines+points

  anim = FuncAnimation(fig,
                       update,
                       frames=np.arange(1, steps, 100),
                       interval=100)
  anim.save(filename, dpi=80, writer='imagemagick')

def static_chart(ax, trajectories, starting_point, big):
  """Draws optimization trajectories as a static chart on the given axes

  The chart shows two grey discs of radius 1 centred at (0, 0) and (1, 0), a
  square marker at the starting point and one line per trajectory.

  Args:
    ax: Matplotlib axes to draw on (also made the current axes)
    trajectories: sequence of arrays with one (x, y) row per optimization
      step; entry k is labelled, coloured and styled according to the global
      `modes[k]`
    starting_point: (x, y) position at which the square marker is drawn
    big: if True, draws the detailed variant (taller y range, tick labels,
      '+' / '-' annotations and a legend); otherwise draws a compact variant
      without ticks, tick labels or legend
  """
  plt.sca(ax)
  sns.set_style("white")
  sns.set_style("ticks")
  ax.set_adjustable("box")
  sns.despine()
  ax.set_aspect('equal')
  if big:
    plt.ylim((0, 2))
  else:
    plt.ylim((0, 1.1))
  plt.xlim(-1, 2)
  x_ticks, y_ticks = np.arange(-1, 2.5, 1), np.arange(0, 2.5, 1)
  x_labels = ['-1', '0', '1', '2']
  y_labels = ['0', '1', '2']
  if big:
    ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(x_ticks))
    ax.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(y_ticks))
    ax.xaxis.set_major_formatter(matplotlib.ticker.FixedFormatter(x_labels))
    ax.yaxis.set_major_formatter(matplotlib.ticker.FixedFormatter(y_labels))

    # '+' / '-' annotations; their positions match the value returned by
    # `sign_dot_product` at those points.
    for x_pos, y_pos, symbol in [(0.5, 0.4, '+'),
                                 (-0.5, 0.4, '-'),
                                 (1.5, 0.4, '-'),
                                 (0.5, 1.8, '+')]:
      ax.text(x_pos, y_pos, symbol,
              verticalalignment='center',
              horizontalalignment='center',
              fontsize=24)
  else:
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')

  circle1 = plt.Circle((0, 0), 1, color='grey', alpha=0.15, zorder=0)
  circle2 = plt.Circle((1, 0), 1, color='grey', alpha=0.15, zorder=0)
  ax.add_artist(circle1)
  ax.add_artist(circle2)
  ax.text(starting_point[0], starting_point[1], '■',
          verticalalignment='center',
          horizontalalignment='center',
          fontsize=15 if big else 10)
  palette = sns.color_palette("tab10", n_colors=10)
  # Same colour assignment as `animation`; 'Orthogonal' is included so that
  # every mode defined in this notebook can be drawn.
  palette = {'Single task': palette[1],
             'Multi-task': palette[0],
             'Projection': palette[2],
             'Unweighted cosine': palette[3],
             'Weighted cosine': palette[4],
             'Orthogonal': palette[5]}
  linestyles = {'Single task': '-',
                'Multi-task': '--',
                'Projection': ':'}
  # Iterate over the trajectories actually supplied (as `animation` does)
  # rather than over the global list of modes, and fall back to a solid line
  # for modes without a dedicated line style instead of raising a KeyError.
  for mode in range(len(trajectories)):
    ax.plot(trajectories[mode][:, 0],
            trajectories[mode][:, 1],
            color=palette[modes[mode]],
            linestyle=linestyles.get(modes[mode], '-'),
            label=modes[mode])
  if big:
    plt.legend(frameon=False, prop={'family':'Liberation serif', 'size': 9})
  plt.tight_layout()

In [0]:
# Static figure: one large panel on the left and three small panels stacked on
# the right, each showing Adam trajectories from a different starting point.
modes = ['Single task', 'Multi-task', 'Projection']

matplotlib.rcParams['figure.figsize'] = 6.75, 3
matplotlib.rcParams['font.family'] = 'Liberation serif'
gs = matplotlib.gridspec.GridSpec(3, 2, width_ratios=[3, 2], height_ratios=[1, 1, 1])

# The first axes spans the whole left column; the others fill the right column.
axes = [plt.subplot(gs[:, 0])] + [plt.subplot(gs[row, 1]) for row in range(3)]

starting_points = [[0, 2], [1.9, 0.9], [-0.3, 0.3], [0.2, 0.2]]

for i, ax in enumerate(axes):
  trajectories_adam = [
      train(primary_function,
            auxiliary_function,
            torch.optim.Adam,
            mode=mode,
            start=starting_points[i])
      for mode in modes
  ]
  # Only the first (large) panel gets ticks, annotations and a legend.
  static_chart(ax, trajectories_adam, starting_points[i], i == 0)

gs.tight_layout(plt.gcf(), h_pad=0.5, w_pad=0.3)
plt.savefig('toy-example.pdf')

In [0]:
# Switch back to the full list of methods for the cells below.
modes = [
    'Single task',
    'Multi-task',
    'Projection',
    'Unweighted cosine',
    'Weighted cosine',
    'Orthogonal',
]

In [0]:
# Adam trajectories for every method, all starting from starting_points[0].
i = 0
trajectories_adam = [
    train(primary_function,
          auxiliary_function,
          torch.optim.Adam,
          mode=mode,
          start=starting_points[i])
    for mode in modes
]

In [0]:
# Use a taller figure for the animations, then render the Adam trajectories.
matplotlib.rcParams['figure.figsize'] = (6, 4.5)
animation(trajectories_adam, steps=5000, filename='trajectories_adam.gif')

In [0]:
# Same experiment as for Adam, with SGD (no momentum arguments are passed).
i = 0
trajectories_vanilla_gd = [
    train(primary_function,
          auxiliary_function,
          torch.optim.SGD,
          mode=mode,
          start=starting_points[i])
    for mode in modes
]
animation(trajectories_vanilla_gd, steps=2000, filename='trajectories_vanilla_gd.gif')

In [0]:
# Additional starting points: one Adam and one SGD animation is written for
# each of them, with file names numbered from 2 onwards.
starting_points = [[-0.5, 1.5], [2, 1], [0.5, 0.5]]

for i in range(len(starting_points)):
  trajectories_adam = [
      train(primary_function,
            auxiliary_function,
            torch.optim.Adam,
            mode=mode,
            start=starting_points[i])
      for mode in modes
  ]
  animation(trajectories_adam, 5000, 'trajectories_adam%d.gif' % (i+2))

  trajectories_vanilla = [
      train(primary_function,
            auxiliary_function,
            torch.optim.SGD,
            mode=mode,
            start=starting_points[i])
      for mode in modes
  ]
  animation(trajectories_vanilla, 2000, 'trajectories_vanilla_gd%d.gif' % (i+2))


## Bokeh figures

In [0]:
def draw_charts(trajectories, modes, colors, c):
  """Shows three interactive Bokeh charts for a set of trajectories

  The charts, laid out in a single row, are:
    1. the trajectories in the plane, on top of two grey circles of radius 1
       centred at (0, 0) and (1, 0);
    2. the distance to the origin at each optimization step;
    3. the distance to the point c at each optimization step.

  Args:
    trajectories: sequence of arrays with one (x, y) row per optimization step
    modes: legend labels; modes[k] is used for trajectories[k]
    colors: line colors; colors[k] is used for trajectories[k]
    c: 2D point (array-like of two numbers) used for the third chart

  Returns:
    A standalone HTML document (string) containing the charts, with the Bokeh
    resources loaded from the CDN.
  """
  output_notebook()

  def add_lines(fig, x_values, y_values):
    """Draws one line per trajectory on fig and configures its legend."""
    for i in range(len(trajectories)):
      xs = trajectories[i]
      fig.line(x_values(xs),
               y_values(xs),
               line_width=2,
               color=colors[i],
               legend_label=modes[i])
    fig.legend.location = 'top_right'
    fig.legend.click_policy = 'hide'

  def step_index(xs):
    """Step numbers 0..n-1 for a trajectory with n recorded points."""
    return range(xs.shape[0])

  # `width`/`height` replace `plot_width`/`plot_height`, which were deprecated
  # and subsequently removed from Bokeh.
  s1 = figure(match_aspect=True, width=400, height=400)
  s1.circle(0, 0, radius=1, fill_color='grey', fill_alpha=.2, line_color='grey')
  s1.circle(1, 0, radius=1, fill_color='grey', fill_alpha=.2, line_color='grey')
  add_lines(s1, lambda xs: xs[:, 0], lambda xs: xs[:, 1])
  s1.title.text = 'Trajectories'
  s1.yaxis.axis_label = 'y'
  s1.xaxis.axis_label = 'x'

  s2 = figure(width=400, height=400)
  add_lines(s2, step_index, lambda xs: np.sqrt(xs[:, 0]**2+xs[:, 1]**2))
  s2.title.text = 'Distance to the origin by optimization step'
  s2.yaxis.axis_label = 'Distance to the origin'
  s2.xaxis.axis_label = 'Steps'

  s3 = figure(width=400, height=400)
  add_lines(s3, step_index, lambda xs: np.sqrt(np.sum((xs-c)**2, axis=1)))
  s3.title.text = 'Distance to (%d, %d) by optimization step' % (c[0], c[1])
  s3.yaxis.axis_label = 'Distance to (%d, %d)' % (c[0], c[1])
  s3.xaxis.axis_label = 'Steps'

  p = gridplot([[s1, s2, s3]])
  show(p)
  return file_html(p, CDN, "my plot")

In [0]:
%%time
trajectories_adam = [train(primary_function,
                           auxiliary_function,
                           torch.optim.Adam,
                           mode=mode,
                           start=[0, 2])
                     for mode in modes]

colors = ['orange', 'blue', 'green', 'red', 'purple', 'brown']
html = draw_charts(trajectories_adam, modes, colors, c.numpy())

In [0]:
%%time
trajectories_vanilla_gd = [train(primary_function,
                                 auxiliary_function,
                                 torch.optim.SGD,
                                 iterations=1500,
                                 mode=mode,
                                 start=[0, 2])
                           for mode in modes]

draw_charts(trajectories_vanilla_gd, modes, colors, c.numpy())