In [None]:
import numpy as np
import mesa
import pandas as pd
import matplotlib.pylab as plt
import seaborn as sns
from scipy.stats import zscore

from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

%load_ext autoreload
%autoreload 2
%load_ext line_profiler

## Social information

In [None]:
# Binary social-information matrix: a square board of side grid_size + margin
# in which a 1 marks a cell occupied by another agent.
grid_size = 100
margin = 20
board_side = grid_size + margin
x_plot = np.meshgrid(range(board_side), range(board_side))
print(np.array(x_plot).shape)
s = np.zeros((board_side, board_side))

# Fixed seed so the agent layout is reproducible between runs.
np.random.seed(2)

# Drop agents uniformly at random, keeping them margin // 2 cells away from
# every edge. Two draws can hit the same cell, in which case s ends up with
# fewer than n_other_agents ones.
n_other_agents = 5
low, high = margin // 2, grid_size + margin // 2
for _ in range(n_other_agents):
    row = np.random.randint(low, high)
    col = np.random.randint(low, high)
    print((row, col))
    s[row, col] = 1


In [None]:
def plot_belief_style(_ax):
    """Apply the belief-map look to a heatmap axes, in place.

    Every spine is made visible as a 1-pt black line, and both axes lose their
    ticks so only the framed image remains.

    Parameters
    ----------
    _ax : matplotlib.axes.Axes
        Axes to restyle. Nothing is returned.
    """
    for spine in _ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1)
    for clear_ticks in (_ax.set_xticks, _ax.set_yticks):
        clear_ticks([])

In [None]:
# Raw social-information map: with the 'binary' colormap on [0, 1], occupied
# cells (value 1) are drawn black on a white background.
fig, ax = plt.subplots(figsize=(5, 5))
cmap = 'binary'
sns.heatmap(s, vmin=0, vmax=1, cmap=cmap, square=True, cbar=False, ax=ax)
plot_belief_style(ax)
plt.show()


In [None]:
# GP training set built from the padded social map s.
# - Positive examples (label 1) are the occupied cells.
# - Negative examples (label 0) are a coarse regular grid of background cells.
# Each input row is a (column, row) pair, because np.meshgrid defaults to "xy"
# indexing. Rows follow the row-major order of s.flatten().
x = np.meshgrid(range(grid_size + margin), range(grid_size + margin))
mesh_points = np.array(x).reshape(2, -1).T

# Positive examples: every occupied cell.
X = mesh_points[s.flatten() == 1]
y = np.ones(X.shape[0])
print(X.shape, y.shape)

# Negative examples: every `background_step`-th mesh point along each axis.
# Points that are already positive examples are skipped, so no input appears
# twice with conflicting labels.
background_step = 20
background = np.array(x)[:, ::background_step, ::background_step].reshape(2, -1).T
already_positive = (background[:, None, :] == X[None, :, :]).all(axis=-1).any(axis=1)
background = background[~already_positive]
X = np.vstack([X, background])
y = np.hstack([y, np.zeros(background.shape[0])])
print(X.shape, y.shape)

In [None]:
# plot selected points
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
cmap = 'binary'
sns.heatmap(s, ax=ax, cmap=cmap, cbar=False, square=True, vmin=0, vmax=1)
plot_belief_style(ax)
ax.scatter(X[:, 0], X[:, 1], color='red', s=10, alpha=0.5)
plt.show()


In [None]:
# Fixed-hyperparameter classifier: optimizer=None keeps the RBF length scale
# at 12 instead of tuning it on the training data.
gpc = GaussianProcessClassifier(random_state=0, optimizer=None, kernel=RBF(length_scale=12))

In [None]:
%%time

gpc.fit(X, y)

In [None]:
%%time
x_plot = np.meshgrid(range(margin//2, grid_size + margin//2), range(margin//2, grid_size + margin//2))
s_prob = gpc.predict_proba(np.array(x_plot).reshape(2, -1).T)

In [None]:
from ice_fishing_abm_gp.belief import generate_belief_matrix, construct_dataset_info

In [None]:
# %prun gpc.fit(X, y)
# s_prob = gpc.predict_proba(X)

In [None]:
# Other agents' positions as (column, row) pairs in the un-padded grid:
# 1. take the occupied cells of s in row-major order,
# 2. swap each (row, column) index to (column, row),
# 3. remove the margin // 2 padding offset.
agent_locs = np.argwhere(s == 1)[:, ::-1] - margin // 2

In [None]:
%%time
X_new, y_new = construct_dataset_info(100, 20, agent_locs, step_size=20)

In [None]:
# Fresh classifier for the project helper: RBF length scale fixed at 10
# (optimizer=None disables hyperparameter fitting).
gpc = GaussianProcessClassifier(random_state=0, optimizer=None, kernel=RBF(length_scale=10))

In [None]:
%%time
# Social belief map from the project helper, using the RBF(10) classifier
# created just before. It replaces the manual route of
# s_prob[:, 1].reshape(grid_size, grid_size).
# NOTE(review): presumably the helper fits gpc on (X_new, y_new) and returns a
# grid_size x grid_size array — confirm in ice_fishing_abm_gp.belief.
s_prob_plot = generate_belief_matrix(grid_size, margin, X_new, y_new, gpc)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
cmap = 'viridis'

# Both overlays are integer (column, row) cell indices. seaborn draws cell
# (r, c) over [c, c + 1] x [r, r + 1], so add 0.5 to centre each marker.
# NOTE(review): assumes s_prob_plot is indexed [row, column] like s.
cell_centre = 0.5
sns.heatmap(s_prob_plot, ax=ax, cmap=cmap, cbar=False, square=True)
ax.scatter(agent_locs[:, 0] + cell_centre, agent_locs[:, 1] + cell_centre,
           color='red', s=10, alpha=0.5, marker='x')
# Label-0 (background) training points. Remove the margin // 2 offset to map
# them onto the un-padded belief grid; presumably X_new is in padded
# coordinates, as the original offset implies — confirm.
negative_points = X_new[y_new == 0] - margin // 2
ax.scatter(negative_points[:, 0] + cell_centre, negative_points[:, 1] + cell_centre,
           color='white', s=10, alpha=0.5, marker='o')
ax.set_title('Social belief map')
plt.show()

In [None]:
# Inspect the agent positions: (column, row) pairs in un-padded grid coordinates.
agent_locs

## Individual features: success and loss

# GP classification with GPJax

In [None]:
# Enable Float64 for more stable matrix inversions.
from jax import config

config.update("jax_enable_x64", True)

from time import time
import blackjax
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import jax.tree_util as jtu
from jaxtyping import (
    Array,
    Float,
    install_import_hook,
)
import matplotlib.pyplot as plt
import optax as ox
import tensorflow_probability.substrates.jax as tfp
from tqdm import trange

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

tfd = tfp.distributions
identity_matrix = jnp.eye
key = jr.key(123)

# The GPJax style sheet is fetched over the network. If it cannot be loaded
# (offline, URL moved), matplotlib raises OSError; fall back to the current
# style so Restart & Run All still works.
GPJAX_STYLE_URL = (
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
try:
    plt.style.use(GPJAX_STYLE_URL)
except OSError as err:
    print(f"Could not load the GPJax style sheet ({err}); keeping the current matplotlib style.")


import cola
from gpjax.lower_cholesky import lower_cholesky

In [None]:
# Training data for the GPJax Bernoulli classifier. gpx.Dataset requires 2-D
# arrays: X of shape (n, d) and y of shape (n, 1). Cast both to float64,
# consistent with jax_enable_x64.
D = gpx.Dataset(
    X=jnp.asarray(X, dtype=jnp.float64),
    y=jnp.asarray(y, dtype=jnp.float64).reshape(-1, 1),
)
kernel = gpx.kernels.RBF()
meanf = gpx.mean_functions.Constant()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
# num_datapoints is the number of training observations, not the grid area.
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)

In [None]:
# Multiplying the prior by the likelihood yields the GPJax posterior. Print
# its class to see which posterior type GPJax chose for this likelihood.
posterior = prior * likelihood
posterior_cls = type(posterior)
print(posterior_cls)

In [None]:
# Optimise the posterior by minimising the negative log posterior density.
negative_lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=True))

# One optimiser definition, actually passed to gpx.fit.
# (Previously an unused ox.adam was defined and ox.adamw was passed inline.)
optimiser = ox.adamw(learning_rate=0.01)

opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_lpd,
    train_data=D,  # the gpx.Dataset (inputs and labels), not the bare label vector
    optim=optimiser,
    num_iters=1000,
    key=key,
)

In [None]:




def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNormalTriL:
    """Laplace-approximate predictive distribution of the latent GP at test inputs.

    Adds the correction term ``Ktx Kxx⁻¹ H⁻¹ Kxx⁻¹ Kxt`` to the covariance of the
    MAP latent predictive ``opt_posterior.predict(test_inputs, train_data=D)``.

    Args:
        test_inputs: Test locations, one row per point.

    Returns:
        A ``tfd.MultivariateNormalTriL`` over the latent function at ``test_inputs``.

    Note:
        Reads the notebook-level ``opt_posterior``, ``D``, ``identity_matrix``,
        ``jitter`` and ``H_inv`` (H⁻¹ in the formula above).
        NOTE(review): ``jitter`` and ``H_inv`` are not defined in the cells
        shown here; define them before calling.
    """
    # Predict at the requested test inputs (previously read an undefined `xtest`).
    map_latent_dist = opt_posterior.predict(test_inputs, train_data=D)

    # Kernel matrices use the training inputs the posterior was fitted on.
    # The notebook-level `x` is a meshgrid tuple, not those inputs.
    train_inputs = D.X
    Kxt = opt_posterior.prior.kernel.cross_covariance(train_inputs, test_inputs)
    Kxx = opt_posterior.prior.kernel.gram(train_inputs)
    Kxx += identity_matrix(D.n) * jitter
    Kxx = cola.PSD(Kxx)

    # Kxx⁻¹ Kxt
    Kxx_inv_Kxt = cola.solve(Kxx, Kxt)

    # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
    laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)

    mean = map_latent_dist.mean()
    covariance = map_latent_dist.covariance() + laplace_cov_term
    L = jnp.linalg.cholesky(covariance)
    return tfd.MultivariateNormalTriL(jnp.atleast_1d(mean.squeeze()), L)