In [83]:
# Source code for ICML submission #640 "Efficient Continuous Pareto Exploration in Multi-Task Learning"
# This script generates Figure 4 in the paper.

import codecs
import gzip
import os
import urllib
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import torchvision.transforms as transforms
from PIL import Image
import scipy.optimize

from common import *

# Fix the random seed.
# NumPy's global generator drives every np.random call made further down in this notebook, and PyTorch's global
# generator is presumably what the default weight initialization of the network draws from.
# NOTE: the seeds only make the results reproducible if this cell is executed before the cells below.
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fb7f7fa42d0>

In [34]:
# Download MNIST training images and labels into root_folder.
# Returns (images, labels).
# images is of dimension 60k x h x w, labels is of dimension 60k.
def download_mnist_training(root_folder):
    """Download (when needed) and parse the MNIST training set.

    The two gzip archives are fetched, decompressed into ``root_folder`` and the archives are removed afterwards.
    A file that has already been extracted is reused instead of being downloaded again.

    Args:
        root_folder: directory that stores the extracted IDX files. It is created when it does not exist yet.

    Returns:
        (images, labels): ``images`` is a uint8 array of shape (60000, 28, 28) and ``labels`` is a uint8 array of
        shape (60000,). Both are read-only because they are backed by the bytes read from disk.

    Raises:
        AssertionError: if a magic number, the item count or the image size differs from the expected MNIST value.
    """
    # A bare ``import urllib`` does not guarantee that the ``request`` submodule is loaded, so import it explicitly.
    import urllib.request

    # Helper function: IDX headers store big-endian unsigned integers.
    def get_int(b):
        return int.from_bytes(b, 'big')

    # open() does not create missing directories.
    if root_folder:
        os.makedirs(root_folder, exist_ok=True)

    # Download data.
    image_url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
    label_url = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    for url in (image_url, label_url):
        archive_path = os.path.join(root_folder, url.rpartition('/')[2])
        extracted_path = os.path.splitext(archive_path)[0]
        if os.path.exists(extracted_path):
            # Already downloaded and extracted by an earlier run.
            continue
        with urllib.request.urlopen(url) as response, open(archive_path, 'wb') as f:
            f.write(response.read())
        # Extract under a temporary name and rename at the end, so that an interrupted run never leaves a truncated
        # file that a later run would mistake for a complete one.
        partial_path = extracted_path + '.part'
        with gzip.GzipFile(archive_path) as zip_f, open(partial_path, 'wb') as out_f:
            out_f.write(zip_f.read())
        os.replace(partial_path, extracted_path)
        os.remove(archive_path)

    # Extract images.
    with open(os.path.join(root_folder, 'train-images-idx3-ubyte'), 'rb') as f:
        data = f.read()
    # Check the magic number and metadata.
    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    assert length == 60000
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    assert num_rows == num_cols == 28
    # Read images: the pixel bytes start right after the 16-byte header.
    image_data = np.frombuffer(data, dtype=np.uint8, offset=16)
    images = image_data.reshape(length, num_rows, num_cols)

    # Extract labels.
    with open(os.path.join(root_folder, 'train-labels-idx1-ubyte'), 'rb') as f:
        data = f.read()
    # Check the magic number and metadata.
    assert get_int(data[:4]) == 2049
    length = get_int(data[4:8])
    assert length == 60000
    # The label bytes start right after the 8-byte header.
    label_data = np.frombuffer(data, dtype=np.uint8, offset=8)
    labels = label_data.ravel()
    return images, labels

# Folder that receives the MNIST files; the Pareto-front cell further down also writes its result files here.
root_folder = 'MultiMNISTSubset'
mnist_images, mnist_labels = download_mnist_training(root_folder)

In [35]:
# Visualize MNIST.
import matplotlib.pyplot as plt
%matplotlib tk

# Show a 4 x 4 grid of randomly picked MNIST digits, each titled with its label.
fig = plt.figure(figsize=(8, 8))
choice = np.random.choice(mnist_labels.size, 16, replace=False)
for i, idx in enumerate(choice):
    # Subplot positions are 1-based and run row by row.
    ax = fig.add_subplot(4, 4, i + 1)
    ax.matshow(mnist_images[idx])
    ax.set(xticks=[], yticks=[])
    ax.set_title('label: {}'.format(mnist_labels[idx]))
plt.show()

In [36]:
# Generate MultiMNIST.
# Returns (images, left_label, right_label).
# images is of size (number, h, w), left_label and right_label are of size (number,)
def generate_multi_mnist(number):
    """Build ``number`` two-digit images from the module-level ``mnist_images`` / ``mnist_labels``.

    Each sample overlays one digit near the top-left corner and another near the bottom-right corner of a 36 x 36
    canvas, clips the sum to [0, 255] and downsamples the canvas to 14 x 14 with nearest-neighbour resampling.

    Returns three Python lists (images, left_labels, right_labels). At most ``mnist_labels.size`` samples are
    produced because the digits of each side are drawn from a permutation of the dataset.
    """
    pool_size = mnist_labels.size
    # Two independent permutations: one picks the top-left digits, the other the bottom-right digits.
    left_indices = np.random.permutation(pool_size)[:number]
    right_indices = np.random.permutation(pool_size)[:number]

    multi_images, left_labels, right_labels = [], [], []
    for left_idx, right_idx in zip(left_indices, right_indices):
        left_labels.append(mnist_labels[left_idx])
        right_labels.append(mnist_labels[right_idx])

        # Randomly shift left and right images: a coin flip decides which side may move by up to 3 pixels, the
        # other side moves by up to 2 pixels.
        left_bound, right_bound = (3, 4) if np.random.rand() < 0.5 else (4, 3)
        shift_left = np.random.randint(left_bound)
        shift_right = np.random.randint(right_bound)

        canvas = np.zeros((36, 36))
        left_span = slice(shift_left, shift_left + 28)
        right_span = slice(8 - shift_right, 36 - shift_right)
        canvas[left_span, left_span] += mnist_images[left_idx]
        canvas[right_span, right_span] += mnist_images[right_idx]

        # Post-processing: overlapping strokes may exceed 255, so clip before casting back to the source dtype.
        canvas = np.clip(canvas, 0, 255).astype(mnist_images[0].dtype)
        # Downsample the image to 14 x 14.
        downsampled = Image.fromarray(canvas).resize((14, 14), resample=Image.NEAREST)
        multi_images.append(np.array(downsampled))
    return multi_images, left_labels, right_labels

# Number of MultiMNIST samples to generate; the loss cell below turns all of them into one tensor.
subset_size = 2048
multi_images, left_labels, right_labels = generate_multi_mnist(subset_size)

In [37]:
# Visualize MultiMNIST.
# Show a 4 x 4 grid of randomly picked MultiMNIST samples, each titled with its two labels.
fig = plt.figure(figsize=(8, 8))
choice = np.random.choice(subset_size, 16, replace=False)
for i, idx in enumerate(choice):
    # Subplot positions are 1-based and run row by row.
    ax = fig.add_subplot(4, 4, i + 1)
    ax.matshow(multi_images[idx])
    ax.set(xticks=[], yticks=[])
    ax.set_title('label: {}, {}'.format(left_labels[idx], right_labels[idx]))
plt.show()

In [39]:
# Define the neural network.
class MiniLeNet(nn.Module):
    """Small LeNet-style network: a shared trunk followed by two 10-way output heads.

    ``forward`` takes a batch of single-channel images and returns ``[logits_1, logits_2]``, one logits tensor per
    head. The trunk is sized for 14 x 14 inputs: the stride-2 5 x 5 convolution yields 10 x 5 x 5 feature maps and
    the 2 x 2 max-pooling reduces them to 10 x 2 x 2, i.e. the 40 features that ``fc1`` expects.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=2)
        self.fc1 = nn.Linear(in_features=40, out_features=20)
        # One classification head per task.
        self.fc3_1 = nn.Linear(in_features=20, out_features=10)
        self.fc3_2 = nn.Linear(in_features=20, out_features=10)

    def forward(self, x):
        pooled = F.max_pool2d(F.relu(self.conv1(x)), 2)
        # Keep the batch dimension and flatten everything else.
        hidden = F.relu(self.fc1(pooled.flatten(start_dim=1)))
        return [head(hidden) for head in (self.fc3_1, self.fc3_2)]

# Single network instance; the Pareto-front cell below overwrites its parameters in place during optimization.
network = MiniLeNet()

In [66]:
# Define the loss.
# Transform images to torch tensors and normalize them.
# For MNIST, mean = 0.1307, std = 0.3081
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# Every image goes through the same pipeline and the results are stacked along a new leading axis.
# Dim of multi_images_torch: #images x #channels x height x width.
multi_images_torch = torch.stack([transform(Image.fromarray(img)) for img in multi_images])

# Class indices are needed as 64-bit integers by the loss. Label shape: (#images,)
left_labels_torch = torch.from_numpy(ndarray(left_labels)).to(torch.int64)
right_labels_torch = torch.from_numpy(ndarray(right_labels)).to(torch.int64)

# The same criterion (default settings) is applied to each of the two heads.
loss_function = nn.CrossEntropyLoss()

In [90]:
# Generate empirical Pareto front.
# Depending on your choice of num_trials, this process can take a while.
# For 101 trials, this cell took 45 minutes to finish (PyTorch CPU). We actually used GPU in our experiments, which is
# way faster than CPUs. However, we intend to keep this example simple to set up so we didn't include our GPU version
# in this script.
# Every trial starts from the same initial parameter vector, captured here before any optimization run overwrites
# the network parameters. (ndarray comes from the star import of common; presumably it converts to a NumPy array.)
x0 = ndarray(parameters_to_vector(network.parameters()).clone().detach().numpy())
num_trials = 101
for w1 in np.linspace(0, 1, num_trials):
    w2 = 1 - w1

    def loss_and_grad(x):
        """Return (loss, grad) of w1 * loss_left + w2 * loss_right at the flat parameter vector x.

        Both results are float64 NumPy values; grad has one entry per network parameter.
        """
        # Convert x to tensor.
        x_torch = torch.as_tensor(x, dtype=torch.float)

        # Compute loss. Loading x replaces the current parameters of the shared network.
        vector_to_parameters(x_torch, network.parameters())
        logits = network(multi_images_torch)
        loss_left = loss_function(logits[0], left_labels_torch)
        loss_right = loss_function(logits[1], right_labels_torch)
        loss_torch = w1 * loss_left + w2 * loss_right
        loss = loss_torch.double().clone().detach().cpu().numpy()

        # Compute gradients, one task at a time; the weighted sum is accumulated in double precision.
        grad = 0
        for loss_node, w in [(loss_left, w1), (loss_right, w2)]:
            # retain_graph keeps the graph alive for the second task; allow_unused is needed because each loss
            # does not depend on the parameters of the other task's head.
            grad_node = list(torch.autograd.grad(loss_node, network.parameters(), retain_graph=True, allow_unused=True))
            for i, (grad_module, param) in enumerate(zip(grad_node, network.parameters())):
                if grad_module is None:
                    grad_node[i] = torch.zeros_like(param)
            grad_vec = parameters_to_vector(grad_node)
            grad += grad_vec.double().clone().detach().cpu().numpy() * w
        return loss, grad

    # Results are cached per weight pair; an existing file means this trial is skipped. With two decimals in the
    # file name, distinct weights stay distinct only while the weight step is at least 0.01.
    data_file = os.path.join(root_folder, '{:.2f}_{:.2f}.bin'.format(w1, w2))
    if not os.path.exists(data_file):
        result = scipy.optimize.minimize(loss_and_grad, x0, method='L-BFGS-B', jac=True, bounds=None)
        x = result.x
        # Use a context manager so the file is flushed and closed deterministically.
        with open(data_file, 'wb') as out_f:
            pickle.dump(x, out_f)