In [None]:

import jax
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import ortho_group
from ima.ima.mixing_functions import build_moebius_transform
from ima.ima.plotting import cart2pol
from gp_ima.ima import C_ima_digamma, C_ima_sample
from jax import numpy as jnp
import GPy
from tueplots import bundles, figsizes

In [None]:
import sys

# Automatically re-import edited project modules (e.g. `analysis`) before each cell runs,
# so changes to helper .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Make the notebook's directory importable (for the local `analysis` module) with
# highest priority. Drop any existing '.' entries first so re-running this cell
# does not keep prepending duplicates to sys.path.
while '.' in sys.path:
    sys.path.remove('.')
sys.path.insert(0, '.')

In [None]:
from analysis import plot_typography, estimate2uniform

In [None]:
# Toggle LaTeX text rendering for all figures; passed to the tueplots bundle and
# to plot_typography.
USETEX: bool = True

In [None]:
plt.rcParams.update(bundles.neurips2022(usetex=USETEX))
# `text.latex.preamble` must be a single string: the list form used previously was
# deprecated by matplotlib and is rejected by current releases. Append to (rather
# than replace) any preamble the tueplots bundle may have installed, so its font
# setup is kept.
plt.rcParams.update({
    'text.latex.preamble': (
        plt.rcParams['text.latex.preamble']
        + r'\usepackage{amsfonts}'  # mathbb
        + r'\usepackage{amsmath}'   # boldsymbol
    )
})

In [None]:
# Font sizes (small / medium / big) handed to the project's plot_typography helper.
_typography_sizes = {'small': 12, 'medium': 16, 'big': 20}
plot_typography(usetex=USETEX, **_typography_sizes)

In [None]:

plt.ion()
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases removed the old names, so plt.style.use('seaborn-pastel') fails on
# current versions. Use whichever spelling this installation provides; the style
# is cosmetic, so it is skipped if neither exists.
for _style_name in ('seaborn-v0_8-pastel', 'seaborn-pastel'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
# Seed NumPy's global RNG; the data generation and model fitting below draw from it.
np.random.seed(42)

In [None]:
# Toy-problem size: number of samples, and a square mixing (latent dim == observed dim).
NUM_DATA = 2_500
OBS_DIM = 2
LATENT_DIM = OBS_DIM



In [None]:

# Ground-truth latents: i.i.d. uniform samples on [-0.5, 0.5) in each dimension.
Z = np.random.uniform(-0.5, 0.5, size=(NUM_DATA, LATENT_DIM))
# Colour each point by the angle returned by cart2pol so the mixing's distortion is visible.
# NOTE(review): columns are passed as (Z[:, 1], Z[:, 0]) here but in (0, 1) order
# for the estimated latents further down — confirm which order is intended.
_, c = cart2pol(Z[:, 1], Z[:, 0])

# Parameters of the Moebius mixing. All random draws here come from NumPy's global
# RNG (ortho_group.rvs without random_state uses it too), so their order must stay
# fixed to reproduce results under the seed set earlier.
A = jnp.array(ortho_group.rvs(dim=OBS_DIM))  # random orthogonal matrix

alpha = 1.0
# Rejection-sample OBS_DIM standard-normal values whose magnitude exceeds 0.5.
a = []
while len(a) < OBS_DIM:
    s = np.random.randn()
    if np.abs(s) > 0.5:
        a.append(s)
a = jnp.array(a)
b = jnp.zeros(OBS_DIM)

mixing, _ = build_moebius_transform(alpha, A, a, b, epsilon=2)

# Observations: apply the mixing function to every latent sample.
X = jax.vmap(mixing)(Z)

# Draw on a fresh figure: with plt.ion() the pyplot state machine would otherwise
# draw onto whatever figure is currently active.
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], c=c)
ax.set(xlabel=r'$x_1$', ylabel=r'$x_2$', title='Observations');

In [None]:


# Bayesian GP-LVM whose latent space has the true latent dimensionality; the kernel
# is an isotropic RBF plus a constant (bias) term.
NUM_INDUCING = 100
NUM_RESTARTS = 3

kernel = GPy.kern.RBF(LATENT_DIM, ARD=False) + GPy.kern.Bias(LATENT_DIM)
# JAX produces float32 arrays by default (unless jax_enable_x64 is set); give GPy
# an explicit float64 copy of the observations instead of inheriting single precision.
Y_obs = np.asarray(X, dtype=np.float64)
m = GPy.models.BayesianGPLVM(Y_obs, LATENT_DIM, kernel=kernel, num_inducing=NUM_INDUCING)
m.optimize_restarts(NUM_RESTARTS, optimizer='lbfgs')

In [None]:
# Call C_ima_sample on the fitted model NUM_SAMPLES_C_IMA times, keeping results in order.
NUM_SAMPLES_C_IMA = 1000
C_IMA = []
for _ in range(NUM_SAMPLES_C_IMA):
    C_IMA.append(C_ima_sample(m))

In [None]:
# Posterior means of the latents. `m.X.mean` is a GPy parameter object tied to the
# model; take a plain ndarray copy so downstream analysis works on a snapshot
# rather than on a live model parameter.
Zest = np.array(m.X.mean)
# Transform the estimates with the project's estimate2uniform helper
# (presumably maps them towards uniform marginals — see analysis.estimate2uniform).
Zest_uni_cima = estimate2uniform(Zest)



In [None]:


# Estimated latents (GP-LVM posterior means), coloured by the angle returned by cart2pol.
_, cest = cart2pol(Zest[:, 0], Zest[:, 1])
# Fresh figure so this plot is not drawn over the currently active one under plt.ion().
fig, ax = plt.subplots()
ax.scatter(Zest[:, 0], Zest[:, 1], c=cest)
ax.set(xlabel=r'$\hat{z}_1$', ylabel=r'$\hat{z}_2$', title='Estimated latents');

In [None]:


# Estimated latents after estimate2uniform, coloured by the angle returned by cart2pol.
_, cest = cart2pol(Zest_uni_cima[:, 0], Zest_uni_cima[:, 1])
# Fresh figure so this plot is not drawn over the currently active one under plt.ion().
fig, ax = plt.subplots()
ax.scatter(Zest_uni_cima[:, 0], Zest_uni_cima[:, 1], c=cest)
ax.set(xlabel=r'$\hat{z}_1$', ylabel=r'$\hat{z}_2$',
       title='Estimated latents after estimate2uniform');

In [None]:
# Trace of the C_IMA values in the order they were sampled.
fig, ax = plt.subplots()
ax.plot(C_IMA)
ax.set(xlabel='sample index', ylabel=r'$C_{\mathrm{IMA}}$', title='Sampled values');

In [None]:
# Integer dimensions D log-spaced over [1, 1000]. Truncating 1000 log-spaced points
# to int repeats the small values many times; np.unique keeps the distinct, sorted
# values, giving the same curve with far fewer C_ima_digamma evaluations.
Ds = np.unique(np.logspace(0, 3, 1000).astype(int))
# First argument is floor(ln D), clipped below at 1.
# NOTE(review): its meaning is defined by gp_ima.C_ima_digamma — confirm.
c_ima_digamma_vals = [C_ima_digamma(max(1, int(np.log(D))), D) for D in Ds]

fig, ax = plt.subplots()
ax.plot(Ds, c_ima_digamma_vals)
ax.set(xlabel=r'$D$', ylabel=r'$C_{\mathrm{IMA}}$ (digamma)');