In [None]:
import os
import sys
if ".." not in sys.path:
    sys.path.append("..")
    
IN_PYTEST = "PYTEST_CURRENT_TEST" in os.environ

In [None]:
import numpy as np
import torch
from matplotlib import cm
from matplotlib import pyplot as plt
from numpy import array
from torch import tensor

from experiments.datasets.cond_banana import ConditionalBananaDataProvider
from experiments.utils.metrics import kde, kde_l1, w2_pot, w2_keops
from vqr.api import VectorQuantileEstimator, VectorQuantileRegressor
from vqr.solvers.regularized_lse import (
    RegularizedDualVQRSolver,
    MLPRegularizedDualVQRSolver,
)

# Dataset

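We generate a conditional banana dataset: training pairs $(X, Y)$ with a $k$-dimensional covariate $X$ and a $d$-dimensional response $Y$. The data provider can also sample the ground-truth conditional distribution of $Y$ given $X$, which we use below to evaluate the fitted models.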

In [None]:
# --- Experiment configuration -------------------------------------------------
# Problem size: n training samples; X has k dimensions and Y has d dimensions
# (consistent with the 1-D X_test and the 2-D KDE plots below).
n = 10000
d = 2
k = 1
# Quantile-grid resolution passed to both solvers as `T`
# (presumably the number of quantile levels per dimension -- confirm).
T = 25
num_epochs = 5000
# Draw data from the linear variant of the banana dataset when True.
linear = False
# Passed as `sigma` to `kde` in the sampling cell (presumably the KDE bandwidth).
sigma = 0.1
GPU_DEVICE_NUM = 0
# Fall back to CPU so the notebook also runs on machines without CUDA.
device = f"cuda:{GPU_DEVICE_NUM}" if torch.cuda.is_available() else "cpu"
dtype = torch.float32
# Regularization strength passed to both solvers as `epsilon`.
epsilon = 5e-3
# Seed for reproducible data sampling and model initialization.
SEED = 42

if IN_PYTEST:
    # Smaller, faster configuration for automated notebook tests.
    n = 2500
    T = 25
    num_epochs = 500

# Seed torch (e.g. MLP weight init) and NumPy's global RNG.
# NOTE(review): this only makes the sampled data reproducible if
# ConditionalBananaDataProvider draws from the global NumPy/torch RNG -- confirm.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Honor the `linear` flag instead of hard-coding the nonlinear variant
# (with linear=False this is exactly the original nonlinear=True).
data_provider = ConditionalBananaDataProvider(k=k, d=d, nonlinear=not linear)
X, Y = data_provider.sample(n=n)

# Solver

We create two solvers: a linear VQR solver (RVQR) and a nonlinear NL-VQR solver whose learned transformation of $X$ is a small MLP.

In [None]:
# Both solvers take a boolean `gpu` flag. Derive it from CUDA availability
# (the same test used to choose `device`) rather than always passing True,
# so this cell is consistent with the CPU fallback on CUDA-less machines.
use_gpu = torch.cuda.is_available()

# RVQR that solves Equation 7 in the paper
linear_solver = RegularizedDualVQRSolver(
    verbose=True,
    T=T,
    num_epochs=num_epochs,
    epsilon=epsilon,
    lr=2.9,
    gpu=use_gpu,
    full_precision=False,
    device_num=GPU_DEVICE_NUM,
    # NOTE(review): None presumably means full-batch (no mini-batching) -- confirm.
    batchsize_y=None,
    batchsize_u=None,
    inference_batch_size=100,
    # Presumably a reduce-LR-on-plateau schedule: multiply lr by lr_factor after
    # lr_patience epochs without relative improvement above lr_threshold.
    lr_factor=0.9,
    lr_patience=500,
    lr_threshold=0.5 * 0.01,
)


# NL-VQR solver that solves Equation 9 in the paper.
# g_\theta is a small MLP configured by `hidden_layers` below
# (NOTE(review): exact layer semantics of (2, 10, 20) -- confirm in the solver).
# Can also use a custom neural net, by using RegularizedDualVQRSolver with nn_init argument.
nonlinear_solver = MLPRegularizedDualVQRSolver(
    verbose=True,
    T=T,
    num_epochs=num_epochs,
    epsilon=epsilon,
    lr=0.4,
    gpu=use_gpu,
    skip=False,
    batchnorm=False,
    hidden_layers=(2, 10, 20),
    device_num=GPU_DEVICE_NUM,
    batchsize_y=None,
    batchsize_u=None,
    inference_batch_size=100,
    lr_factor=0.9,
    lr_patience=300,
    lr_threshold=0.5 * 0.01,
)

In [None]:
# Fit VQR with the linear (RVQR) solver on the training pairs (X, Y).
# Training length and learning-rate schedule come from `linear_solver`.
linear_vqr_est = VectorQuantileRegressor(
    solver=linear_solver,
)
linear_vqr_est.fit(X, Y)

In [None]:
# Fit NL-VQR with the MLP-based solver on the same training pairs (X, Y).
# Training length and learning-rate schedule come from `nonlinear_solver`.
nonlinear_vqr_est = VectorQuantileRegressor(
    solver=nonlinear_solver,
)
nonlinear_vqr_est.fit(X, Y)

# Sampling

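We can now, for example, sample from the fitted VQR and NL-VQR models for a given $X$. We then compare kernel density estimates of these samples against samples from the ground-truth conditional distribution.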

In [None]:
n_test = 4000
# Resolution of the grid on which every KDE is evaluated.
KDE_GRID_RESOLUTION = 100

# Generate conditional distributions for the below X's
X_test = array([[2.0]])


def _kde_on_grid(samples):
    """Estimate a KDE of `samples` with the settings shared by all panels.

    Centralizing the call guarantees the ground truth and both estimators are
    evaluated with the same grid resolution, bandwidth (`sigma`) and device.

    Args:
        samples: torch tensor of draws from a conditional distribution of Y
            (NOTE(review): assumed shape (n_test, d) with d == 2 -- confirm
            against `kde`).

    Returns:
        The result of `kde`. The plotting cell transposes it and renders it
        with `imshow`, so presumably a 2-D grid of density values.
    """
    return kde(
        samples,
        grid_resolution=KDE_GRID_RESOLUTION,
        device=device,
        sigma=sigma,
    )


# Sample the ground-truth conditional distribution for X_test
# (the sampled X's are not needed and are discarded).
_, cond_Y_gt = data_provider.sample(n=n_test, x=X_test)
cond_Y_gt = tensor(cond_Y_gt, dtype=dtype)
kde_gt = _kde_on_grid(cond_Y_gt)

# Sample from the estimated conditional distribution from VQR
vqr_cond_Y_est = tensor(linear_vqr_est.sample(n=n_test, x=X_test), dtype=dtype)
kde_est_vqr = _kde_on_grid(vqr_cond_Y_est)

# Sample from the estimated conditional distribution from NL-VQR
nlvqr_cond_Y_est = tensor(nonlinear_vqr_est.sample(n=n_test, x=X_test), dtype=dtype)
kde_est_nlvqr = _kde_on_grid(nlvqr_cond_Y_est)

In [None]:
# One panel per conditional density: ground truth first, then the estimators.
density_panels = (
    ("Groundtruth", kde_gt),
    ("VQR", kde_est_vqr),
    ("NL-VQR", kde_est_nlvqr),
)

fig, axes = plt.subplots(
    1, len(density_panels), figsize=(5 * len(density_panels), 5)
)
for ax, (title, density) in zip(axes, density_panels):
    # Tick labels are hidden: the densities are drawn on a normalized
    # (0, 1) x (0, 1) extent, so coordinate values carry no meaning here.
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    # Transpose + origin="lower" puts the grid's first axis horizontally and
    # its second axis vertically, growing upward (assuming `kde` indexes the
    # grid as [y1, y2] -- confirm).
    ax.imshow(
        density.T,
        interpolation="bilinear",
        origin="lower",
        cmap=cm.RdPu,
        extent=(0, 1, 0, 1),
    )
    ax.set_title(title)