# High-dimensional Gaussian experiment
- Find the barycentre of 3 Gaussian marginals
- Parameters for Gaussian marginals were generated using the generation procedure from Kolesov et al. (2023)

In [1]:
import os  # before importing anything jax

# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
os.environ['CUDA_VISIBLE_DEVICES']='4'

import sys
sys.path.append("..")

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

from tqdm import trange
from omegaconf import OmegaConf

from models import ScoreMLP, BasicModel

from run_BarycentreDSBM import BarycentreDSBM

## define distributions

In [2]:
def symmetrize(M):
    """Return the real, symmetric part of a square matrix.

    Computes ``(M + M.T) / 2`` and drops any imaginary component.

    Args:
        M: Square array exposing ``.T``, arithmetic and ``.real``
            (e.g. a NumPy or JAX array).

    Returns:
        A real array with the same shape as ``M`` that equals its own
        transpose.
    """
    # Read the array's own ``.real`` attribute rather than calling ``np.real``:
    # NumPy is only imported in a later cell, so ``np`` is not defined yet when
    # this cell runs on a fresh kernel. For array inputs the two are equivalent.
    return ((M + M.T) / 2).real

def sqrtm_jax(matrix, eps=1e-12):
    """Matrix square root computed from a symmetric eigendecomposition.

    Args:
        matrix: Square matrix; it is decomposed with ``jnp.linalg.eigh``, so it
            is treated as symmetric/Hermitian.
        eps: Lower bound applied to the eigenvalues before taking the square
            root, which guards against slightly negative values.

    Returns:
        ``V @ diag(sqrt(max(lambda, eps))) @ V.T`` where ``lambda`` and ``V``
        are the eigenvalues and eigenvectors of ``matrix``.
    """
    spectrum, basis = jnp.linalg.eigh(matrix)
    # Clip from below only, so round-off negatives cannot produce NaNs.
    root_spectrum = jnp.sqrt(jnp.clip(spectrum, eps))
    scaled_basis = basis @ jnp.diag(root_spectrum)
    return scaled_basis @ basis.T

class Gaussian:
    """Gaussian sampler parameterised by a mean and a standard deviation."""

    def __init__(self, shape, mean=0.0, std=1.0):
        """Store the sample shape and broadcast ``mean`` to that shape.

        Args:
            shape: Tuple giving the shape of a single sample.
            mean: Value multiplied into an array of ones of shape ``shape``.
            std: Scale applied to the standard-normal noise.
        """
        self.shape = shape
        self.std = std
        self.mean = mean * jnp.ones(shape)

    def sample(self, key, num_samples):
        """Draw ``num_samples`` samples, returned with shape ``(num_samples,) + shape``."""
        noise = jax.random.normal(key, (num_samples,) + self.shape)
        return noise * self.std + self.mean
    
class GaussianCov:
    """Gaussian sampler with an explicit mean vector and covariance matrix."""

    def __init__(self, shape, mean, cov):
        """Store the parameters and precompute the sampling factor.

        Args:
            shape: Tuple giving the shape of a single sample.
            mean: Mean added to every sample.
            cov: Covariance matrix.
        """
        self.shape = shape
        self.mean = mean
        self.cov = cov
        # Sampling factor: the matrix square root of the covariance as computed
        # by sqrtm_jax. (A Cholesky factor of cov is an alternative choice.)
        self.weight = sqrtm_jax(cov)

    def sample(self, key, num_samples):
        """Draw ``num_samples`` samples as ``z @ weight.T + mean`` with standard-normal ``z``."""
        standard_normal = jax.random.normal(key, (num_samples,) + self.shape)
        return standard_normal @ self.weight.T + self.mean
    
class t_Dist:
    """Base interface for samplers exposing ``sample(key, num_samples)``."""

    def sample(self, key, num_samples):
        """Subclasses must override this; the base version always raises."""
        raise NotImplementedError()
    
class UniformDist(t_Dist):
    """Uniform sampler on the interval [0.001, 0.999)."""

    def sample(self, key, num_samples):
        """Return a 1-D array of ``num_samples`` uniform draws."""
        # The bounds stay strictly inside (0, 1); presumably this avoids the
        # endpoints of the time interval (confirm against the training code).
        lower = 0.001
        upper = 1.0 - 0.001
        return jax.random.uniform(key, (num_samples,), minval=lower, maxval=upper)

## define the problem, model

In [3]:
d = 64 # 96, 128
shape = (d,)
N = 3

import json

# Load the precomputed parameters (marginals + ground-truth barycentre) for dimension d.
with open(f"../data/gaussian/sampler_stats_{d}.json", "r") as f:
    loaded_gaussian_params = json.load(f)

# JSON stores the parameters as nested lists; convert them to JAX arrays.
# The loop variable is deliberately NOT called `key`: that name holds the
# PRNG key in later cells and must not be overwritten by a sampler name.
for sampler_name, sampler_params in loaded_gaussian_params.items():
    sampler_params["cov"] = jnp.array(sampler_params["cov"])
    sampler_params["mean"] = jnp.array(sampler_params["mean"])

epsilon = 0.0001
sigma = jnp.sqrt(epsilon / 2)   # convert from epsilon to sigma

# One GaussianCov per marginal, read from the entries "sampler_0" ... "sampler_{N-1}".
mu_lst = [
    GaussianCov(
        shape=shape,
        mean=loaded_gaussian_params[f"sampler_{i}"]["mean"],
        cov=loaded_gaussian_params[f"sampler_{i}"]["cov"],
    )
    for i in range(N)
]

# Reference barycentre used by the evaluation cells.
ground_truth = GaussianCov(
    shape=shape,
    mean=loaded_gaussian_params["ground_truth_sampler"]["mean"],
    cov=loaded_gaussian_params["ground_truth_sampler"]["cov"],
)

# Uniform barycentre weights, normalised to sum to one.
# NOTE(review): `weights` is not passed to BarycentreDSBM in the next cell;
# confirm that uniform weights are what the solver assumes by default.
weights = jnp.ones((N,)) / N
weights = weights / jnp.sum(weights)

model = BasicModel(out_dim=d, d=d)

In [4]:
# Build the barycentre solver from the marginals, the noise level, the sample
# shape and the network defined in the previous cell.
baryDSBM = BarycentreDSBM(
    mu_lst=mu_lst,
    sigma=sigma,
    shape=shape,
    model=model,
)

# Training hyperparameters handed to baryDSBM.train below.
train_config = OmegaConf.create({
    'num_IMF_steps': 4,
    'num_sampling_steps': 50,
    'num_training_steps': 10_000,
    'reflow_num_training_steps': None, # number of training steps for reflow (could be lower if desired)
    'num_training_samples': 50_000,  # number of samples to simulate for subsequent IMF iterations
    'lr': 1e-3,
    'batch_size': 4096,
    'simulation_batch_size': 10_000, # if num_training_samples is too large, set this to be smaller to simulate in batches
    'ema_rate': 0.01,
    'simultaneous_training': True, # True, False,
    'warmstart': False, # True, False, # whether to warmstart the model with the params from the first iteration
})

# Fixed seed for the training run.
# The two returned lists are indexed by IMF iteration in the evaluation cell
# (index -1 selects the last iteration).
key = jax.random.PRNGKey(0)
all_states_lst, all_bms_lst = baryDSBM.train(key, train_config=train_config, model=model)

Running IMF step 1


Training: 100%|██████████| 10000/10000 [01:13<00:00, 135.14step/s, loss=34.9]


Running IMF step 2


Training: 100%|██████████| 10000/10000 [01:10<00:00, 140.92step/s, loss=0.288]


Running IMF step 3


Training: 100%|██████████| 10000/10000 [01:10<00:00, 142.01step/s, loss=0.219]


Running IMF step 4


Training: 100%|██████████| 10000/10000 [01:08<00:00, 145.85step/s, loss=0.224]


## Evaluation
- Using the BW2-UVP and L2-UVP metrics

In [5]:
# from Korotin et al. (2023), https://github.com/iamalexkorotin/Wasserstein2Barycenters

import numpy as np
import scipy.linalg as ln

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Mean vector of the first Gaussian.
    -- sigma1: Covariance matrix of the first Gaussian.
    -- mu2   : Mean vector of the second Gaussian.
    -- sigma2: Covariance matrix of the second Gaussian.
    -- eps   : Value added to the diagonal of both covariances if the matrix
               square root of their product contains non-finite entries.

    Returns:
    --   : The (squared) Frechet distance d^2 defined above.

    Raises:
    -- AssertionError: if the means or the covariances differ in shape.
    -- ValueError    : if the matrix square root has a non-negligible
                       imaginary component on its diagonal.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular.
    # `disp=False` is intentionally not passed: the argument is deprecated in
    # recent SciPy releases and scheduled for removal. The default call returns
    # only the square-root matrix on both old and new SciPy versions.
    covmean = ln.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = ln.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)

# get ground-truth maps
# based on Korotin et al. (2022)

def get_map_to_barycentre_from_params(mean_leaf, cov_leaf, mean_bary, cov_bary):
    """Build the affine map x -> x @ A.T + b from a leaf Gaussian to the barycentre.

    With S_l = cov_leaf and S_b = cov_bary the linear part is
        A = S_l^{-1/2} (S_l^{1/2} S_b S_l^{1/2})^{1/2} S_l^{-1/2}
    and the offset is b = mean_bary - A @ mean_leaf.

    Args:
        mean_leaf: Mean of the source (leaf) Gaussian.
        cov_leaf: Covariance of the source (leaf) Gaussian.
        mean_bary: Mean of the barycentre Gaussian.
        cov_bary: Covariance of the barycentre Gaussian.

    Returns:
        Tuple ``(map_to_barycentre, weight, bias)``: the callable map applied
        to row-stacked samples, the matrix ``A`` and the vector ``b``.
    """
    leaf_root = sqrtm_jax(cov_leaf)
    leaf_root_inv = jnp.linalg.inv(leaf_root)

    # (S_l^{1/2} S_b S_l^{1/2})^{1/2}
    inner_root = sqrtm_jax(leaf_root @ cov_bary @ leaf_root)

    weight = leaf_root_inv @ inner_root @ leaf_root_inv
    # If both means are zero the bias vanishes.
    bias = mean_bary - weight @ mean_leaf

    def map_to_barycentre(x):
        # Samples are rows, hence the transpose on the weight.
        return x @ weight.T + bias

    return map_to_barycentre, weight, bias

# The barycentre parameters are identical for every edge, so convert them once.
# NOTE(review): the arrays are cast to float64 here, but the map is built with
# JAX ops; confirm whether x64 is enabled, otherwise the precision may not carry over.
mean_bary = np.array(ground_truth.mean, dtype=np.float64)
cov_bary = np.array(ground_truth.cov, dtype=np.float64)

# One ground-truth map per marginal, in the same order as mu_lst.
maps_to_barycentre = []
for i in range(N):
    leaf = mu_lst[i]
    mean_leaf = np.array(leaf.mean, dtype=np.float64)
    cov_leaf = np.array(leaf.cov, dtype=np.float64)

    # Only the callable is needed; the weight and bias are discarded.
    map_to_barycentre = get_map_to_barycentre_from_params(mean_leaf, cov_leaf, mean_bary, cov_bary)[0]
    maps_to_barycentre.append(map_to_barycentre)

In [11]:
# Evaluation settings read as globals by get_UVP_along_edge below.
num_steps = 50  # passed to bm.sample as the number of sampling steps
num_samples = 100_000  # passed to bm.sample as the number of samples drawn per edge

def get_UVP_along_edge(key, state, bm, edge_idx):
    """Compute the BW2-UVP and L2-UVP metrics (in percent) for one edge.

    Samples are generated with the forward drift built from ``state`` (EMA
    parameters), using the module-level ``num_samples`` and ``num_steps``.

    Args:
        key: JAX PRNG key used for sampling.
        state: Training state passed to ``bm.get_drift_fn``.
        bm: Object exposing ``get_drift_fn`` and ``sample``.
        edge_idx: Index of the ground-truth map in ``maps_to_barycentre``.

    Returns:
        Tuple ``(BW2_UVP, L2_UVP)``.
    """
    fwd_drift = bm.get_drift_fn(state, use_ema_params=True, fwd=True)
    traj, nu_samples = bm.sample(key, fwd_drift, num_samples, num_steps, fwd=True)

    # Both metrics are normalised by the total variance of the ground truth.
    ground_truth_var = jnp.trace(ground_truth.cov)

    # BW2-UVP: Frechet distance between the Gaussian fitted to the generated
    # samples and the ground-truth barycentre.
    sample_cov = jnp.cov(nu_samples.T)
    sample_mean = jnp.mean(nu_samples, axis=0)
    frechet = calculate_frechet_distance(
        sample_mean, sample_cov,
        ground_truth.mean, ground_truth.cov,
    )
    BW2_UVP = 100 * frechet / ground_truth_var

    # L2-UVP: mean squared gap between the generated endpoints and the
    # ground-truth map applied to traj[:, 0].
    # Assumes traj[:, 0] holds the initial points of the trajectories (TODO confirm).
    reference_endpoints = maps_to_barycentre[edge_idx](traj[:,0])
    sq_errors = jnp.sum((reference_endpoints - nu_samples)**2, axis=1)
    L2_UVP = 100 * (jnp.mean(sq_errors)) / ground_truth_var

    return BW2_UVP, L2_UVP

def get_UVPs_from_run(key, states_lst, bm_lst):
    """Evaluate every edge of a run and stack the per-edge metrics.

    The same ``key`` is passed to each edge. Edge ``i`` uses ``states_lst[i]``,
    ``bm_lst[i]`` and ground-truth map index ``i``.

    Returns:
        Tuple of two arrays ``(BW2_UVPs, L2_UVPs)``, one entry per edge.
    """
    per_edge = [
        get_UVP_along_edge(key, states_lst[i], bm_lst[i], i)
        for i in range(len(states_lst))
    ]
    run_BW2_UVP = [bw2 for bw2, _ in per_edge]
    run_L2_UVP = [l2 for _, l2 in per_edge]
    return jnp.array(run_BW2_UVP), jnp.array(run_L2_UVP)

In [12]:
# Fixed seed so the reported metrics are reproducible.
key = jax.random.PRNGKey(0)

# Evaluate the models from the last IMF iteration.
IMF_idx = -1
states_lst, bm_lst = all_states_lst[IMF_idx], all_bms_lst[IMF_idx]

BW2_UVPs, L2_UVPs = get_UVPs_from_run(key, states_lst, bm_lst)

# Report the per-edge values followed by their mean, for each metric in turn.
for per_edge_label, average_label, per_edge_values in (
    ("BW2_UVPs for each edge:", "Average BW2_UVP:", BW2_UVPs),
    ("L2_UVPs for each edge:", "Average L2_UVP:", L2_UVPs),
):
    print(per_edge_label, per_edge_values)
    print(average_label, jnp.mean(per_edge_values))

  covmean, _ = ln.sqrtm(sigma1.dot(sigma2), disp=False)


BW2_UVPs for each edge: [0.12454717 0.13369545 0.13783808]
Average BW2_UVP: 0.1320269
L2_UVPs for each edge: [1.1490363 1.1663812 1.1848252]
Average L2_UVP: 1.1667476
