# Data Aggregation experiment

- Applies TreeDSBM to the bike dataset example from Korotin et al. 2021.

In [1]:
import os  # imported first so the environment variables below are set before any jax import

# Optional switch, left disabled: presumably turns off XLA GPU memory preallocation — confirm before enabling.
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
# NOTE(review): the device index '4' is specific to the original machine — adjust for your host.
os.environ['CUDA_VISIBLE_DEVICES']='4'

import sys
# Put the parent directory on the import path (presumably where the project modules imported below live).
sys.path.append("..")

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

from functools import partial

from tqdm import trange
from omegaconf import OmegaConf

from models import ScoreMLP, BasicModel

from run_BarycentreDSBM import BarycentreDSBM

## define distributions
- follows the preprocessing used in Korotin et al. 2021.

In [2]:
# Common factor every centred sample set is multiplied by.
scale = jnp.sqrt(1e7)

# Path template; the placeholder is the file suffix ('0'..'4' or 'all').
_POSTERIOR_PATH = '../data/bike_posterior/samples_0_{}.npy'


def _load_centred_scaled(path):
    """Load the array at ``path``, drop column 0, centre each column, multiply by ``scale``."""
    samples = jnp.load(path)[:, 1:]
    centred = samples - samples.mean(axis=0)
    return centred * scale


# The five sample sets that serve as the fixed marginals (file suffixes 0..4).
dataset_1, dataset_2, dataset_3, dataset_4, dataset_5 = (
    _load_centred_scaled(_POSTERIOR_PATH.format(idx)) for idx in range(5)
)

# Reference samples (file suffix 'all'), preprocessed the same way.
ground_truth = _load_centred_scaled(_POSTERIOR_PATH.format('all'))

class DatasetDist:
    """Empirical distribution that resamples entries of a fixed array."""

    def __init__(self, dataset):
        # Array to draw from; sampling selects entries along its leading axis.
        self.dataset = dataset

    @partial(jax.jit, static_argnums=(0, 2))
    def sample(self, key, num_samples):
        """Draw ``num_samples`` entries of ``self.dataset`` with replacement.

        ``key`` is the JAX PRNG key. ``self`` and ``num_samples`` are marked
        static for ``jax.jit``.
        """
        batch_shape = (num_samples,)
        return jax.random.choice(key, self.dataset, shape=batch_shape, replace=True)
    
class t_Dist:
    """Base class for distributions exposing ``sample(key, num_samples)``."""

    def sample(self, key, num_samples):
        """Subclasses must override; the base implementation always raises."""
        raise NotImplementedError


class UniformDist(t_Dist):
    """Uniform distribution with bounds 0.001 and 1 - 0.001."""

    def sample(self, key, num_samples):
        """Return ``num_samples`` uniform draws generated from PRNG ``key``."""
        lower = 0.001
        upper = 1.0-0.001
        return jax.random.uniform(key, shape=(num_samples,), minval=lower, maxval=upper)

# Show the shape of every preprocessed sample set as the cell output.
tuple(arr.shape for arr in (dataset_1, dataset_2, dataset_3, dataset_4, dataset_5, ground_truth))

((100000, 8), (100000, 8), (100000, 8), (100000, 8), (100000, 8), (100000, 8))

## define the problem, model

In [3]:
# Data dimension and number of fixed marginals.
d, N = 8, 5
shape = (d,)

# sigma is obtained from epsilon as sqrt(epsilon / 2).
epsilon = 0.001
sigma = jnp.sqrt(epsilon / 2)

# Fixed marginals: one DatasetDist per preprocessed sample set.
mu_lst = [
    DatasetDist(samples)
    for samples in (dataset_1, dataset_2, dataset_3, dataset_4, dataset_5)
]
mu_0, mu_1, mu_2, mu_3, mu_4 = mu_lst

# Equal weights over the N marginals, renormalised to sum to one.
weights = jnp.ones((N,)) / N
weights /= jnp.sum(weights)

model = BasicModel(out_dim=d, d=d)

In [4]:
# Settings passed to ``baryDSBM.train`` below.
train_config = OmegaConf.create(dict(
    num_IMF_steps=4,
    num_sampling_steps=100,
    num_training_steps=10_000,
    reflow_num_training_steps=None,  # number of training steps for reflow (could be lower if desired)
    num_training_samples=8192,  # number of samples to simulate for subsequent IMF iterations
    lr=1e-3,
    batch_size=4096,
    simulation_batch_size=None,  # if num_training_samples is too large, set this smaller to simulate in batches
    ema_rate=0.01,
    simultaneous_training=True,  # True or False
    warmstart=False,  # True or False: whether to warmstart the model with the params from the first iteration
))

baryDSBM = BarycentreDSBM(mu_lst=mu_lst, sigma=sigma, shape=shape, model=model)

# PRNG key for the training run (seed 0).
key = jax.random.PRNGKey(0)
all_states_lst, all_bms_lst = baryDSBM.train(key, train_config=train_config, model=model)

Running IMF step 1


Training: 100%|██████████| 10000/10000 [01:38<00:00, 101.22step/s, loss=13.5]


Running IMF step 2


Training: 100%|██████████| 10000/10000 [01:35<00:00, 104.88step/s, loss=0.698]


Running IMF step 3


Training: 100%|██████████| 10000/10000 [01:33<00:00, 106.62step/s, loss=0.638]


Running IMF step 4


Training: 100%|██████████| 10000/10000 [01:32<00:00, 107.80step/s, loss=0.643]


## Evaluation
- Evaluated with the BW2-UVP metric, computed against the ground-truth samples (lower is better).

In [5]:
# from Korotin et al. (2021), https://github.com/iamalexkorotin/Wasserstein2Barycenters

import numpy as np
import scipy.linalg as ln

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : Mean vector of the first distribution.
    -- sigma1: Covariance matrix of the first distribution.
    -- mu2   : Mean vector of the second distribution.
    -- sigma2: Covariance matrix of the second distribution.
    -- eps   : Offset added to the diagonal of both covariances when the
               matrix square root of their product is not finite.
    Returns:
    --   : The Frechet Distance d^2 as defined above.
    Raises:
    -- AssertionError: If the mean vectors or the covariance matrices have
                       mismatching shapes.
    -- ValueError    : If the matrix square root has a diagonal imaginary
                       part that is not close to zero (atol=1e-3).
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular.
    # `sqrtm` is called without `disp`: that argument is deprecated in recent
    # SciPy, and the default call returns just the matrix on every version.
    covmean = ln.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = ln.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)

In [6]:
num_steps = 50
num_samples = 100_000

def get_UVP_along_edge(key, state, bm):
    """Return the UVP score (in percent) for samples generated along one edge.

    Samples are drawn with ``bm.sample`` using the forward drift built from
    ``state`` (EMA parameters). Their mean and covariance are compared with
    those of ``ground_truth`` via ``calculate_frechet_distance``, and the
    result is divided by the trace of the ground-truth covariance.

    Reads the module-level ``num_samples``, ``num_steps`` and ``ground_truth``.
    """
    fwd_drift = bm.get_drift_fn(state, use_ema_params=True, fwd=True)
    # Only the second value returned by the sampler is used.
    _, nu_samples = bm.sample(key, fwd_drift, num_samples, num_steps, fwd=True)

    # First and second moments of the generated samples.
    gen_cov = np.cov(nu_samples.T)
    gen_mean = np.mean(nu_samples, axis=0)

    # Same statistics for the reference samples, plus their total variance.
    ref_mean = np.mean(ground_truth, axis=0)
    ref_cov = np.cov(ground_truth.T)
    ref_total_var = np.trace(ref_cov)

    frechet = calculate_frechet_distance(gen_mean, gen_cov, ref_mean, ref_cov)
    return 100 * frechet / ref_total_var

def get_UVPs_from_run(key, states_lst, bm_lst):
    """Compute the UVP of every edge, pairing ``states_lst[i]`` with ``bm_lst[i]``.

    The same ``key`` is passed to every edge. Returns a ``jnp`` array holding
    one score per element of ``states_lst``.
    """
    edge_scores = [
        get_UVP_along_edge(key, states_lst[idx], bm_lst[idx])
        for idx in range(len(states_lst))
    ]
    return jnp.array(edge_scores)

In [7]:
# PRNG key used for the evaluation samples (seed 0).
key = jax.random.PRNGKey(0)

# Which IMF iteration to evaluate; -1 selects the last one.
IMF_idx = -1
states_lst, bm_lst = all_states_lst[IMF_idx], all_bms_lst[IMF_idx]

UVPs = get_UVPs_from_run(key, states_lst, bm_lst)

print("UVPs for each edge:", UVPs)
print("Average UVP:", jnp.mean(UVPs))

  covmean, _ = ln.sqrtm(sigma1.dot(sigma2), disp=False)


UVPs for each edge: [0.01108639 0.01151861 0.00821571 0.01161379 0.01200999]
Average UVP: 0.010888897
