# Multitask GP Regression

## Introduction

Multitask regression, introduced in [this paper](https://papers.nips.cc/paper/3189-multi-task-gaussian-process-prediction.pdf), models several outputs (tasks) jointly and learns the similarities between them. It's useful when you are performing regression on multiple functions that share the same inputs, especially if the functions are related (for example, all sinusoidal).

Given inputs $x$ and $x'$, and tasks $i$ and $j$, the covariance between two datapoints and two tasks is given by

$$ k([x, i], [x', j]) = k_\text{inputs}(x, x') \cdot k_\text{tasks}(i, j) $$

where $k_\text{inputs}$ is a standard kernel (e.g. RBF) that operates on the inputs, and
$k_\text{tasks}$ is a lookup table containing the inter-task covariances.
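A minimal NumPy sketch of this covariance, assuming an RBF input kernel and an arbitrary positive-definite 3-task matrix `B` (both chosen here purely for illustration). If all (input, task) pairs are ordered input-major, i.e. all tasks at $x_1$, then all tasks at $x_2$, and so on, the full covariance matrix is the Kronecker product $K_\text{inputs} \otimes K_\text{tasks}$:

```python
import numpy as np

def rbf(a, b, lengthscale=0.3):
    """Squared-exponential kernel between 1-D input arrays a and b."""
    return np.exp(-0.5 * ((a[:, None] - b[None, :]) / lengthscale) ** 2)

x = np.linspace(0, 1, 5)            # 5 illustrative inputs
B = np.array([[1.0, 0.8, 0.1],      # illustrative k_tasks(i, j) lookup table
              [0.8, 1.0, 0.2],
              [0.1, 0.2, 1.0]])
num_tasks = B.shape[0]
Kx = rbf(x, x)

# Build k([x, i], [x', j]) = k_inputs(x, x') * k_tasks(i, j) entry by entry,
# indexing the pair (input n, task t) as row n * num_tasks + t.
size = len(x) * num_tasks
K_full = np.empty((size, size))
for n in range(len(x)):
    for i in range(num_tasks):
        for m in range(len(x)):
            for j in range(num_tasks):
                K_full[n * num_tasks + i, m * num_tasks + j] = Kx[n, m] * B[i, j]

# The same matrix in one call, as a Kronecker product.
K_kron = np.kron(Kx, B)
print(np.allclose(K_full, K_kron))  # True
```

The explicit loop is only there to check the identity; the Kronecker structure is what allows implementations to avoid forming the full $nT \times nT$ matrix.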

In [None]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

import seaborn as sns
sns.set_style("whitegrid")
sns.set_palette("bright")
torch.set_default_dtype(torch.double)

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
palette = sns.color_palette("bright")

In [None]:
sns.palplot(palette)

### Set up training data

In the next cell, we set up the training data for this example. We evaluate each function at 100 regularly spaced points on [0, 1] and add Gaussian noise (standard deviation 0.2) to get the training labels.

We'll have four functions (tasks): $\sin(2\pi x)$, $\cos(2\pi x)$, $\cos(5\pi x)$, and $\sin(3\pi x)$.

For MTGPs, the training targets (`train_y`) have two dimensions: the first indexes the inputs and the second indexes the tasks, so here `train_y` has shape `(100, num_tasks)`.
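A small NumPy sketch of this layout, mirroring the data-generation cell with its own random seed (the values are illustrative, not the notebook's data). Column `t` holds task `t`, and flattening row by row gives an input-major ordering: all tasks at the first input, then all tasks at the second, and so on.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 100)
# One column per task, with Gaussian noise of standard deviation 0.2.
y = np.stack([
    np.sin(2 * np.pi * x),
    np.cos(2 * np.pi * x),
    np.cos(5 * np.pi * x),
    np.sin(3 * np.pi * x),
], axis=-1) + 0.2 * rng.standard_normal((100, 4))
print(y.shape)  # (100, 4), i.e. (num_inputs, num_tasks)

# Row-major flattening: entry n * num_tasks + t is task t at input n.
flat = y.reshape(-1)
n, t = 7, 2
print(flat[n * y.shape[1] + t] == y[n, t])  # True
```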

In [None]:
# 100 regularly spaced inputs on [0, 1], shape (100, 1)
train_x = torch.linspace(0, 1, 100).view(-1, 1)

# Stack the four noisy tasks on the last dim: (100, 1, 4) -> squeeze(-2) -> (100, 4)
train_y = torch.stack([
    torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (5 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.sin(train_x * (3 * math.pi)) + torch.randn(train_x.size()) * 0.2,
], -1).squeeze(-2)

In [None]:
num_tasks = train_y.shape[-1]

In [None]:
train_y.shape

## Define a multitask model

The model is similar to the `ExactGP` model in the [simple regression example](../01_Exact_GPs/Simple_GP_Regression.ipynb).
The differences:

1. We wrap `ConstantMean` in a `MultitaskMean`. This gives us a separate mean function for each task.
2. Rather than using an `RBFKernel` on its own, we use it inside a `MultitaskKernel`. This gives us the covariance function described in the introduction.
3. We use a `MultitaskMultivariateNormal` and a `MultitaskGaussianLikelihood`, which handle outputs that have a task dimension. For example, `MultitaskMultivariateNormal.mean` returns an `n x num_tasks` matrix.

You may also notice that we don't use a `ScaleKernel`: the `IndexKernel` inside the `MultitaskKernel` already learns the scale of each task, so an extra output scale would overparameterize the kernel.
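A NumPy sketch of the redundancy, assuming the IndexKernel's low-rank-plus-diagonal form $B = WW^\top + \operatorname{diag}(v)$ with $v > 0$ (the numbers are illustrative). Multiplying the input kernel by an output scale $c$ gives exactly the same covariance as multiplying $B$ by $c$, and $cB$ is still of the same form, so the two parameters cannot be told apart:

```python
import numpy as np

rng = np.random.default_rng(0)
num_tasks, rank = 4, 1

# Task covariance in IndexKernel form: B = W W^T + diag(v), with v > 0.
W = rng.standard_normal((num_tasks, rank))
v = np.exp(rng.standard_normal(num_tasks))
B = W @ W.T + np.diag(v)
print(np.linalg.eigvalsh(B).min() > 0)  # True: PSD plus positive diagonal

# RBF input kernel on a few points.
x = np.linspace(0, 1, 6)
Kx = np.exp(-0.5 * ((x[:, None] - x[None, :]) / 0.3) ** 2)

c = 2.5
scaled_inputs = np.kron(c * Kx, B)  # what a ScaleKernel on the RBF would do
scaled_tasks = np.kron(Kx, c * B)   # the same scale absorbed into B
print(np.allclose(scaled_inputs, scaled_tasks))  # True
```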

In [None]:
# MatheronMultiTaskGP comes from the local module sampling_mtgps (not shown in this notebook)
from sampling_mtgps import MatheronMultiTaskGP
model = MatheronMultiTaskGP(train_x, train_y)
likelihood = model.likelihood

### Train the model hyperparameters

In [None]:
# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iterations = 2 if smoke_test else 50


# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)  # model.parameters() also includes the likelihood's parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
    optimizer.step()
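For reference, the quantity being maximized is the exact log marginal likelihood of the stacked targets $\mathbf{y}$ (all $nT$ values), $\log p(\mathbf{y}) = -\tfrac12 (\mathbf{y}-\boldsymbol\mu)^\top (K + \Sigma)^{-1}(\mathbf{y}-\boldsymbol\mu) - \tfrac12 \log\lvert K + \Sigma\rvert - \tfrac{nT}{2}\log 2\pi$, where $K$ is the multitask kernel matrix and $\Sigma$ the noise covariance. The NumPy sketch below evaluates it with a Cholesky factorization, simplifying to a zero mean and a single i.i.d. noise variance (`MultitaskGaussianLikelihood` can learn per-task noise levels). To our understanding, GPyTorch's `ExactMarginalLogLikelihood` also divides the result by the number of target values, which keeps the printed loss on a per-datum scale.

```python
import numpy as np

def exact_mll(K, y, noise_var):
    """Log marginal likelihood of zero-mean targets y ~ N(0, K + noise_var * I)."""
    N = len(y)
    L = np.linalg.cholesky(K + noise_var * np.eye(N))
    alpha = np.linalg.solve(L, y)  # L^{-1} y, so alpha @ alpha = y^T (K + s^2 I)^{-1} y
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * alpha @ alpha - 0.5 * logdet - 0.5 * N * np.log(2 * np.pi)

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 10)
Kx = np.exp(-0.5 * ((x[:, None] - x[None, :]) / 0.3) ** 2)
B = np.array([[1.0, 0.5], [0.5, 1.0]])  # illustrative 2-task covariance
K = np.kron(Kx, B)                      # input-major multitask kernel, 20 x 20
y = rng.standard_normal(K.shape[0])
value = exact_mll(K, y, noise_var=0.04)

# Cross-check against the direct (slower, less stable) formula.
C = K + 0.04 * np.eye(len(y))
sign, logdet = np.linalg.slogdet(C)
direct = -0.5 * y @ np.linalg.solve(C, y) - 0.5 * logdet - 0.5 * len(y) * np.log(2 * np.pi)
print(np.isclose(value, direct))  # True
```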

### Make predictions with the model

In [None]:
# Set into eval mode
model.eval()
likelihood.eval()

# Initialize plots
f, ax = plt.subplots(2, 2, figsize=(16, 6))
ax = ax.reshape(-1)
# Make predictions
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    test_x = torch.linspace(0, 1, 151)
    predictions = model(test_x)
    dist_samples = predictions.rsample(torch.Size((1024,)))
    mean = predictions.mean
    lower, upper = predictions.confidence_region()
    
# mean, lower, and upper have shape (num_test, num_tasks):
# column i holds the predictions for task i


for i in range(4):
    y1_ax = ax[i]
    # Plot training data as black stars
    y1_ax.plot(train_x.detach().numpy(), train_y[:, i].detach().numpy(), 'k*')
    # Predictive mean as blue line
    y1_ax.plot(test_x.numpy(), mean[:, i].numpy(), 'b')
    # Shade in the confidence region (mean ± 2 std)
    y1_ax.fill_between(test_x.numpy(), lower[:, i].numpy(), upper[:, i].numpy(), alpha=0.5)
    y1_ax.set_ylim([-3, 3])
    y1_ax.legend(['Observed Data', 'Mean', 'Confidence'])
    # model(test_x) is the posterior over the latent f, without likelihood noise
    y1_ax.set_title(f'Task {i + 1}: posterior over f')

## Decoupled Sampling
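The cells below draw posterior samples through `model.posterior(...)` of the (unshown) `MatheronMultiTaskGP` and compare them with samples drawn directly from the predictive distribution (`dist_samples`). The class name suggests Matheron's rule (pathwise conditioning): draw a joint prior sample $f$ and noise $\varepsilon$, then correct it toward the data, $f_* + K_{*n}(K_{nn} + \sigma^2 I)^{-1}(\mathbf{y} - f_n - \varepsilon)$. The single-task NumPy sketch below illustrates that update, assuming a zero prior mean, an RBF kernel, and i.i.d. noise; it is not the `sampling_mtgps` implementation, and in the multitask case $K$ would be the multitask covariance instead.

```python
import numpy as np

def rbf(a, b, lengthscale=0.2):
    """Squared-exponential kernel between 1-D input arrays a and b."""
    return np.exp(-0.5 * ((a[:, None] - b[None, :]) / lengthscale) ** 2)

def matheron_samples(x_train, y_train, x_test, noise_var, num_samples, rng):
    """Posterior samples at x_test via Matheron's rule (zero prior mean)."""
    n = len(x_train)
    x_all = np.concatenate([x_train, x_test])
    # 1. Joint prior sample of f at train and test inputs (small jitter for stability).
    L = np.linalg.cholesky(rbf(x_all, x_all) + 1e-6 * np.eye(len(x_all)))
    f_all = rng.standard_normal((num_samples, len(x_all))) @ L.T
    f_train, f_test = f_all[:, :n], f_all[:, n:]
    # 2. Prior sample of the observation noise at the training inputs.
    eps = np.sqrt(noise_var) * rng.standard_normal((num_samples, n))
    # 3. Pathwise update: f_test + K_*n (K_nn + s^2 I)^{-1} (y - f_train - eps).
    K_nn = rbf(x_train, x_train) + noise_var * np.eye(n)
    K_sn = rbf(x_test, x_train)
    weights = np.linalg.solve(K_nn, (y_train[None, :] - f_train - eps).T)  # (n, S)
    return f_test + (K_sn @ weights).T                                    # (S, n_test)

rng = np.random.default_rng(0)
x_train = np.linspace(0, 1, 5)
y_train = np.sin(2 * np.pi * x_train)
x_test = np.array([0.1, 0.55, 0.9])
noise_var = 0.04
samples = matheron_samples(x_train, y_train, x_test, noise_var, 100_000, rng)

# Closed-form posterior for comparison.
K_nn = rbf(x_train, x_train) + noise_var * np.eye(len(x_train))
K_sn = rbf(x_test, x_train)
post_mean = K_sn @ np.linalg.solve(K_nn, y_train)
post_cov = rbf(x_test, x_test) - K_sn @ np.linalg.solve(K_nn, K_sn.T)
print(np.abs(samples.mean(0) - post_mean).max())  # small Monte Carlo error
print(np.abs(np.cov(samples.T) - post_cov).max())
```

The appeal of this construction is that the prior sample and the data correction are computed separately ("decoupled"), rather than factorizing the posterior covariance at the test points.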

In [None]:
final_samples = model.posterior(test_x).rsample(torch.Size((1024,))).detach()

In [None]:
sampled_mean = final_samples.mean(dim=0)
sampled_std = final_samples.std(dim=0)
sampled_lower = sampled_mean - 2 * sampled_std
sampled_upper = sampled_mean + 2 * sampled_std
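These bands use mean ± 2 standard deviations, the same width as GPyTorch's `confidence_region()` used in the previous section. For a Gaussian marginal, ± 2 standard deviations covers about 95.4% of the mass, as a quick standard-library check shows:

```python
import math

def two_sided_coverage(k):
    """P(|Z| <= k) for a standard normal Z, which equals erf(k / sqrt(2))."""
    return math.erf(k / math.sqrt(2))

print(round(two_sided_coverage(2.0), 4))   # 0.9545
print(round(two_sided_coverage(1.96), 4))  # 0.95
```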

In [None]:
sampled_lower.shape

In [None]:
# Initialize plots
f, ax = plt.subplots(2, 2, figsize=(16, 6))
ax = ax.reshape(-1)

for i in range(4):
    y1_ax = ax[i]
    # Plot training data as black stars
    y1_ax.plot(train_x.detach().numpy(), train_y[:, i].detach().numpy(), 'k*', label='Observed Data')
    # Analytic predictive mean (blue line) and mean ± 2 std region
    y1_ax.plot(test_x.numpy(), mean[:, i].numpy(), 'b', label='Mean')
    y1_ax.fill_between(test_x.numpy(), lower[:, i].numpy(), upper[:, i].numpy(), alpha=0.5, label='Confidence')
    # Sample-based mean and mean ± 2 std region (red)
    y1_ax.fill_between(test_x.numpy(), sampled_lower[:, i].numpy(), sampled_upper[:, i].numpy(),
                       color="red", alpha=0.4, label='Samples (± 2 std)')
    y1_ax.plot(test_x.numpy(), sampled_mean[:, i].numpy(), color="red", label='Sampled Mean')
    #for j in range(final_samples.shape[0]):
    #    y1_ax.plot(test_x.numpy(), final_samples[j, :, i], color = "red", alpha = 0.1)

    #y1_ax.set_ylim([-3, 3])
    # Use the label= arguments so each legend entry matches its artist
    y1_ax.legend()
    y1_ax.set_title(f'Task {i + 1}: posterior over f')

In [None]:
fig, ax = plt.subplots(1, 2, figsize = (12, 5))

ax[0].scatter(sampled_std.reshape(-1), 
              predictions.variance.reshape(-1).detach()**0.5, label = "Matheron's Rule",
             alpha = 0.8)
ax[0].scatter(dist_samples.std(0).reshape(-1), predictions.variance.reshape(-1).detach()**0.5,
             label = "Distributional", alpha = 0.8)
ax[0].set_xlabel("Sampled Stddev")
ax[0].set_ylabel("True Stddev")
ax[0].legend(loc="upper left")

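ax[1].scatter(sampled_mean.reshape(-1), predictions.mean.reshape(-1), label="Matheron's Rule", alpha=0.8)
ax[1].scatter(dist_samples.mean(0).reshape(-1), predictions.mean.reshape(-1), label="Distributional", alpha=0.8)
ax[1].legend(loc="upper left")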

ax[1].set_xlabel("Sampled Mean")
ax[1].set_ylabel("True Mean")

ax[0].grid()
ax[1].grid()

In [None]:
plt.scatter(sampled_mean.reshape(-1), predictions.mean.reshape(-1), label="Matheron's Rule")
plt.scatter(dist_samples.mean(0).reshape(-1), predictions.mean.reshape(-1), label="Distributional")

plt.xlabel("Sampled Mean")
plt.ylabel("True Mean")
plt.legend()

In [None]:
pred_std = predictions.variance.detach()**0.5
fig, ax = plt.subplots(2, 2, figsize = (16, 6))

ax = ax.reshape(-1)
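for i in range(4):
    # Sampled stddev minus the analytic predictive stddev, per task
    ax[i].plot(test_x, (sampled_std[:,i] - pred_std[:,i]), label="Matheron's Rule")
    ax[i].plot(test_x, (dist_samples.std(0)[:,i] - pred_std[:,i]), label="Distributional")
    ax[i].set_title(f"Task {i + 1}: sampled stddev - true stddev")
    ax[i].legend()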

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
fig, ax = plt.subplots(1, 4, figsize = (24, 5))

ax = ax.reshape(-1)
for i in range(4):
    y1_ax = ax[i]
    # Plot training data as black stars
    y1_ax.plot(train_x.detach().numpy(), train_y[:, i].detach().numpy(), 'k*', label='Observed Data')
    # Sample-based (Matheron) mean and mean ± 2 std region
    y1_ax.fill_between(test_x, sampled_lower[:,i], sampled_upper[:,i], alpha = 0.4,
                       color = palette[2], label = 'Sampled Conf. Region')
    y1_ax.plot(test_x, sampled_mean[:,i], color = palette[2])
    # Analytic predictive mean and mean ± 2 std region
    y1_ax.fill_between(test_x.numpy(), lower[:, i].numpy(), upper[:, i].numpy(), alpha=0.5,
                       color = palette[4], label = 'True Conf. Region')
    y1_ax.plot(test_x.numpy(), mean[:, i].numpy(), color = palette[4], linewidth=3)

    # Relative error of the sampled stddevs, in a panel below the main axes
    divider = make_axes_locatable(y1_ax)
    axHistx = divider.append_axes("bottom", size=1.2, pad=0.1, sharex=y1_ax)
    axHistx.plot(test_x, (sampled_std[:,i] - pred_std[:,i])/ pred_std[:,i], label = "Matheron")
    axHistx.plot(test_x, (dist_samples.std(0)[:,i] - pred_std[:,i]) / pred_std[:,i], label = "Distributional")
    axHistx.plot(test_x, torch.zeros_like(test_x), color = "black")
    if i == 0:
        # Labels come from the label= arguments, so each entry matches its artist
        y1_ax.legend(fontsize = 18, loc = "upper center",
                     bbox_to_anchor=(2.25, -0.72), ncol=3)
        axHistx.legend(fontsize=18, loc="upper center", ncol=2, bbox_to_anchor=(2.25, -0.8))
        axHistx.set_ylabel(r"$\frac{\hat \sigma - \sigma}{\sigma}$", fontsize = 18)
        y1_ax.set_ylabel("f(x)", fontsize = 24)
    axHistx.set_xlabel("x", fontsize = 24)
# plt.tight_layout()
plt.savefig("./gp_sampling_accuracy.pdf", bbox_inches="tight")
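When reading the relative-error panels, keep in mind that both sampled standard deviations are themselves Monte Carlo estimates from $S = 1024$ draws. For Gaussian samples, the relative standard error of the sample standard deviation is approximately $1/\sqrt{2(S-1)}$, about 2.2% here, so deviations of that size are expected even from an exact sampler. A short NumPy check of this rule of thumb (the 2000 repetitions are an arbitrary choice):

```python
import math
import numpy as np

S = 1024                                    # draws per estimate, as in the cells above
predicted = 1 / math.sqrt(2 * (S - 1))
print(round(predicted, 4))                  # 0.0221

rng = np.random.default_rng(0)
draws = rng.standard_normal((2000, S))      # 2000 independent experiments, true sigma = 1
rel_err = draws.std(axis=1, ddof=1) - 1.0   # (sigma_hat - sigma) / sigma
print(rel_err.std(), predicted)             # empirical spread vs. the approximation
```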