In [1]:
# Source code for ICML submission #640 "Efficient Continuous Pareto Exploration in Multi-Task Learning"
# This script generates Figure 4 in the paper.

import codecs
import gzip
import os
import urllib
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import torchvision.transforms as transforms
from PIL import Image
import scipy.optimize

from common import *

# Make the run reproducible: NumPy drives the dataset sampling, PyTorch the weight init.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f0e7804e2d0>

In [2]:
def download_mnist_training(root_folder):
    """Download the MNIST training images and labels into root_folder and load them.

    The two IDX files are stored uncompressed in root_folder. A file that is
    already present there is reused instead of being downloaded again.

    Args:
        root_folder: directory that holds the extracted IDX files. It is created
            if it does not exist.

    Returns:
        (images, labels): images is a uint8 array of dimension 60k x h x w
        (h = w = 28), labels is a uint8 array of dimension 60k. Both are
        read-only views on the file contents.

    Raises:
        AssertionError: if a file's magic number or metadata is not the expected one.
        urllib.error.URLError: if a download fails.
    """
    # `import urllib` alone does not guarantee that the `request` submodule is loaded.
    import urllib.request

    # Helper function: read big-endian bytes as an unsigned integer.
    def get_int(b):
        return int.from_bytes(b, 'big')

    os.makedirs(root_folder, exist_ok=True)

    # Download data.
    image_url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
    label_url = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    for url in (image_url, label_url):
        name = url.rpartition('/')[2]
        # Strip only the trailing '.gz' of the file name, so a '.gz' occurring in
        # root_folder itself is left alone.
        extracted_path = os.path.join(root_folder, name[:-len('.gz')])
        if os.path.exists(extracted_path):
            continue
        with urllib.request.urlopen(url) as response:
            compressed = response.read()
        with open(extracted_path, 'wb') as out_f:
            out_f.write(gzip.decompress(compressed))

    # Extract images.
    with open(os.path.join(root_folder, 'train-images-idx3-ubyte'), 'rb') as f:
        data = f.read()
    # Check the magic number and metadata.
    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    assert length == 60000
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    assert num_rows == num_cols == 28
    # Read images (pixel data starts after the 16-byte header).
    image_data = np.frombuffer(data, dtype=np.uint8, offset=16)
    images = image_data.reshape(length, num_rows, num_cols)

    # Extract labels.
    with open(os.path.join(root_folder, 'train-labels-idx1-ubyte'), 'rb') as f:
        data = f.read()
    # Check the magic number and metadata.
    assert get_int(data[:4]) == 2049
    length = get_int(data[4:8])
    assert length == 60000
    # Label bytes start after the 8-byte header.
    label_data = np.frombuffer(data, dtype=np.uint8, offset=8)
    labels = label_data.ravel()
    return images, labels

# Folder that holds the extracted MNIST files and, later, the pickled optimization results.
root_folder = 'MultiMNISTSubset'
# Create it up front: on a fresh checkout the folder may not exist yet.
os.makedirs(root_folder, exist_ok=True)
mnist_images, mnist_labels = download_mnist_training(root_folder)

In [3]:
# Visualize MNIST.
import matplotlib.pyplot as plt
%matplotlib tk

# Show 16 distinct, randomly chosen MNIST digits on a 4 x 4 grid.
fig = plt.figure(figsize=(8, 8))
choice = np.random.choice(mnist_labels.size, 16, replace=False)
for slot, sample in enumerate(choice, start=1):
    # Subplot slots are numbered 1..16 in row-major order.
    ax = fig.add_subplot(4, 4, slot)
    ax.matshow(mnist_images[sample])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title('label: {}'.format(mnist_labels[sample]))
plt.show()

In [4]:
def generate_multi_mnist(number):
    """Generate MultiMNIST samples by overlaying two randomly shifted MNIST digits.

    Reads the module-level mnist_images and mnist_labels.

    Args:
        number: how many samples to generate.

    Returns:
        (images, left_label, right_label): three lists with one entry per sample.
        Each image is a 14 x 14 array; left_label and right_label hold the labels
        of the digit placed near the top-left and near the bottom-right corner.
    """
    pool_size = mnist_labels.size
    # Two independent permutations choose the left and the right digit of each pair.
    left_picks = np.random.permutation(pool_size)[:number]
    right_picks = np.random.permutation(pool_size)[:number]

    multi_images, left_labels, right_labels = [], [], []
    for left_idx, right_idx in zip(left_picks, right_picks):
        left_labels.append(mnist_labels[left_idx])
        right_labels.append(mnist_labels[right_idx])

        # Randomly shift left and right images: a coin flip decides which digit
        # gets the shift range {0, 1, 2} and which gets {0, 1, 2, 3}.
        left_range, right_range = (3, 4) if np.random.rand() < 0.5 else (4, 3)
        shift_left = np.random.randint(left_range)
        shift_right = np.random.randint(right_range)

        # Paste both 28 x 28 digits on a 36 x 36 canvas; the same offset is used
        # for rows and columns, so each digit moves along the diagonal.
        canvas = np.zeros((36, 36))
        left_span = slice(shift_left, shift_left + 28)
        right_span = slice(8 - shift_right, 36 - shift_right)
        canvas[left_span, left_span] += mnist_images[left_idx]
        canvas[right_span, right_span] += mnist_images[right_idx]

        # Post-processing: saturate overlapping strokes and restore the MNIST dtype.
        canvas = np.clip(canvas, 0, 255).astype(mnist_images[0].dtype)
        # Downsample the image to 14 x 14.
        downsampled = Image.fromarray(canvas).resize((14, 14), resample=Image.NEAREST)
        multi_images.append(np.array(downsampled))
    return multi_images, left_labels, right_labels

# Number of MultiMNIST samples used as the training subset in this experiment.
subset_size = 2048
multi_images, left_labels, right_labels = generate_multi_mnist(number=subset_size)

In [5]:
# Visualize MultiMNIST: 16 distinct random samples on a 4 x 4 grid, titled with both labels.
fig = plt.figure(figsize=(8, 8))
choice = np.random.choice(subset_size, 16, replace=False)
for slot, idx in enumerate(choice, start=1):
    # Subplot slots are numbered 1..16 in row-major order.
    ax = fig.add_subplot(4, 4, slot)
    ax.matshow(multi_images[idx])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title('label: {}, {}'.format(left_labels[idx], right_labels[idx]))
plt.show()

In [6]:
# Define the neural network.
class MiniLeNet(nn.Module):
    """Small two-headed LeNet-style network: a shared trunk feeding two 10-way heads.

    The trunk is a 5 x 5 stride-2 convolution, a 2 x 2 max-pool and one fully
    connected layer. fc1 expects 40 flattened features, which is what a
    1 x 14 x 14 input produces after the convolution and pooling.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it determines both the random initialization
        # draws and the order of the flattened parameter vector.
        self.conv1 = nn.Conv2d(1, 10, (5, 5), stride=2)
        self.fc1 = nn.Linear(40, 20)
        self.fc3_1 = nn.Linear(20, 10)
        self.fc3_2 = nn.Linear(20, 10)

    def forward(self, x):
        """Return [logits of head 1, logits of head 2] for a batch x."""
        pooled = F.max_pool2d(F.relu(self.conv1(x)), 2)
        hidden = F.relu(self.fc1(torch.flatten(pooled, 1)))
        # Both heads read the same shared representation.
        return [self.fc3_1(hidden), self.fc3_2(hidden)]

# Instantiate the model and keep a copy of its freshly initialized weights as one flat
# vector; x0 is the common starting point of every optimization run below.
network = MiniLeNet()
initial_params_torch = parameters_to_vector(network.parameters())
x0 = ndarray(initial_params_torch.clone().detach().numpy())

In [7]:
# Define the loss.
# Images become torch tensors and are normalized with the MNIST statistics
# (mean = 0.1307, std = 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# Dim of multi_images_torch: #images x #channels x height x width.
image_tensors = [transform(Image.fromarray(img)) for img in multi_images]
multi_images_torch = torch.stack(image_tensors)

# Label shape: (#images,); converted to int64 tensors.
left_labels_torch, right_labels_torch = (
    torch.from_numpy(ndarray(labels)).long() for labels in (left_labels, right_labels)
)

# The same cross-entropy criterion is applied to each head.
loss_function = nn.CrossEntropyLoss()

In [8]:
# Generate empirical Pareto front.
# Depending on your choice of num_trials, this process can take a while.
# For 101 trials, this cell took 45 minutes to finish (PyTorch on CPU). We actually used a GPU in our experiments,
# which is much faster than a CPU. However, we wanted to keep this example simple to set up, so we did not include
# our GPU version in this script.
# We have also attached the training results for num_trials = 101 in this code repository. If you keep
# num_trials = 101, every trial whose result file already exists in root_folder is skipped, so this cell finishes
# very quickly.
# Number of weight combinations (w1, 1 - w1) sampled uniformly from [0, 1].
num_trials = 101
def get_loss(x):
    """Evaluate both task losses at the flat parameter vector x.

    Side effect: the weights of the module-level `network` are overwritten with x.

    Args:
        x: flat array holding all network parameters.

    Returns:
        (loss_left, loss_right): the two cross-entropy losses as 0-d float64
        numpy arrays.
    """
    x_torch = torch.as_tensor(x, dtype=torch.float)
    vector_to_parameters(x_torch, network.parameters())
    # Pure evaluation: nothing is differentiated here, so skip building the autograd graph.
    with torch.no_grad():
        logits = network(multi_images_torch)
        loss_left = loss_function(logits[0], left_labels_torch)
        loss_right = loss_function(logits[1], right_labels_torch)
    loss_left = loss_left.double().cpu().numpy()
    loss_right = loss_right.double().cpu().numpy()
    return loss_left, loss_right

# Weighted-sum scalarization: for every weight pair (w1, w2 = 1 - w1), minimize
# w1 * left_loss + w2 * right_loss with L-BFGS-B starting from x0, and store the
# optimized parameters in root_folder.
for w1 in np.linspace(0, 1, num_trials):
    w2 = 1 - w1

    def loss_and_grad(x):
        """Return (loss, grad) of w1 * left_loss + w2 * right_loss at x.

        x is the flat parameter vector; both results are float64, as expected by
        scipy.optimize.minimize with jac=True. Side effect: the weights of the
        module-level `network` are overwritten with x.
        """
        # Convert x to tensor.
        x_torch = torch.as_tensor(x, dtype=torch.float)

        # Compute loss.
        vector_to_parameters(x_torch, network.parameters())
        logits = network(multi_images_torch)
        loss_left, loss_right = loss_function(logits[0], left_labels_torch), loss_function(logits[1], right_labels_torch)
        loss_torch = w1 * loss_left + w2 * loss_right
        loss = loss_torch.double().clone().detach().cpu().numpy()

        # Compute gradients: one backward pass per task, combined with the weights in float64.
        grad = 0
        for loss_node, w in [(loss_left, w1), (loss_right, w2)]:
            # retain_graph: both losses share the trunk of the graph.
            # allow_unused: each loss does not depend on the other task's head.
            grad_node = list(torch.autograd.grad(loss_node, network.parameters(), retain_graph=True, allow_unused=True))
            for i, (grad_module, param) in enumerate(zip(grad_node, network.parameters())):
                if grad_module is None:
                    # Parameters this loss does not depend on contribute a zero gradient.
                    grad_node[i] = torch.zeros_like(param)
            grad_vec = parameters_to_vector(grad_node)
            grad += grad_vec.double().clone().detach().numpy() * w
        return loss, grad

    # NOTE(review): the file name keeps only two decimals of each weight, so names are
    # unique for num_trials = 101 but can collide for finer weight grids.
    data_file = os.path.join(root_folder, '{:.2f}_{:.2f}.bin'.format(w1, w2))
    # Results already on disk are reused; only missing trials are optimized.
    if not os.path.exists(data_file):
        result = scipy.optimize.minimize(loss_and_grad, x0, method='L-BFGS-B', jac=True, bounds=None)
        x = result.x
        print('w1: {}, w2: {}'.format(w1, w2))
        print('x0:', x0)
        print('results:', result)
        # Despite the label, get_loss returns the pair (left loss, right loss).
        print('loss and grad:', get_loss(x))
        # Close the file deterministically so the result is flushed to disk.
        with open(data_file, 'wb') as out_f:
            pickle.dump(x, out_f)

w1: 0.0, w2: 1.0
x0: [ 0.15290771  0.16600157 -0.0468545  ...  0.05836105  0.1363675
  0.20993263]
results:       fun: array(0.63615155)
 hess_inv: <1500x1500 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 3.91322561e-03,  8.73468816e-05,  1.56715978e-04, ...,
       -1.37265006e-05, -1.94658598e-04, -5.56853483e-05])
  message: b'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 5961
      nit: 5736
   status: 0
  success: True
        x: array([ 0.5433707 , -0.14053733,  0.45805089, ..., -1.20540721,
        1.15076842,  0.29510445])
loss and grad: (array(9.06728935), array(0.63615155))
w1: 0.01, w2: 0.99
x0: [ 0.15290771  0.16600157 -0.0468545  ...  0.05836105  0.1363675
  0.20993263]
results:       fun: array(0.80026311)
 hess_inv: <1500x1500 LbfgsInvHessProduct with dtype=float64>
      jac: array([-4.30060587e-03, -2.26301576e-04,  1.10684158e-03, ...,
       -9.31491135e-05, -7.19134676e-04, -4.61776742e-04])
  message: b'CONVERGENCE: REL_REDUCTION_OF_F_<=_

In [13]:
# Plot the empirical Pareto front on a single square axes.
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)

# Helper function.
def get_loss(x):
    """Evaluate both task losses at the flat parameter vector x.

    Side effect: the weights of the module-level `network` are overwritten with x.

    Args:
        x: numpy array of size 1500 holding all network parameters.

    Returns:
        (loss_left, loss_right): two scalars (0-d float64 numpy arrays)
        representing the left loss and right loss.
    """
    x_torch = torch.as_tensor(x, dtype=torch.float)
    vector_to_parameters(x_torch, network.parameters())
    # Pure evaluation: nothing is differentiated here, so skip building the autograd graph.
    with torch.no_grad():
        logits = network(multi_images_torch)
        loss_left = loss_function(logits[0], left_labels_torch)
        loss_right = loss_function(logits[1], right_labels_torch)
    loss_left = loss_left.double().cpu().numpy()
    loss_right = loss_right.double().cpu().numpy()
    return loss_left, loss_right

# Collect (left loss, right loss) for every stored solution, in order of increasing w1.
pareto_front = []
for w1 in np.linspace(0, 1, num_trials):
    w2 = 1 - w1
    data_file = os.path.join(root_folder, '{:.2f}_{:.2f}.bin'.format(w1, w2))
    # NOTE: pickle can execute arbitrary code while loading; only open result files
    # produced by this notebook or obtained from a trusted source.
    with open(data_file, 'rb') as in_f:
        x = pickle.load(in_f)
    pareto_front.append(get_loss(x))

# One row per trial: column 0 is the left loss, column 1 the right loss.
pareto_front = ndarray(pareto_front)
ax.scatter(pareto_front[:, 0], pareto_front[:, 1], c='k')
ax.set_aspect('equal')
plt.show()

In [12]:
pareto_front

array([[69.43066406,  1.13781726],
       [ 2.6763103 ,  1.11721623],
       [ 1.95574057,  1.10417998],
       [ 1.82927787,  1.09711087],
       [ 1.82162964,  1.10605049],
       [ 1.73173606,  1.09060371],
       [ 1.70210242,  1.10607243],
       [ 1.71688175,  1.10694396],
       [ 1.68528295,  1.11643124],
       [ 1.66520798,  1.11060786],
       [ 1.64345551,  1.13728034],
       [ 1.60336745,  1.13769543],
       [ 1.57139778,  1.20075488],
       [ 1.45916653,  1.19604611],
       [ 1.43582714,  1.22470582],
       [ 1.35479331,  1.16825151],
       [ 1.32347548,  1.24417531],
       [ 1.29005885,  1.2070626 ],
       [ 1.30735242,  1.21402895],
       [ 1.24484754,  1.22244608],
       [ 1.28228545,  1.2275275 ],
       [ 1.28953004,  1.26608634],
       [ 1.41393459,  1.24997652],
       [ 1.42970979,  1.34190285],
       [ 1.71911907,  1.439116  ],
       [ 1.42637861,  1.36297512],
       [ 1.36112976,  1.25696182],
       [ 1.43816376,  1.49806118],
       [ 1.91014469,