In [None]:
import alsm
from matplotlib import pyplot as plt
import matplotlib as mpl
import nest_asyncio
import numpy as np
import os
import stan
from scipy.linalg import orthogonal_procrustes


# Notebook-wide setup: figure resolution and event-loop patching.
mpl.rcParams['figure.dpi'] = 144
# NOTE(review): presumably needed so that stan can run inside the notebook's already-running
# asyncio event loop -- confirm.
nest_asyncio.apply()

# Integer switch read from the environment; a missing variable is treated as 0 (disabled).
GROUP_SIMULATION = int(os.getenv('GROUP_SIMULATION', 0))

# Pick the data generator: either simulate the detailed connections or sample from the negative
# binomial directly. Comparing the two is useful for testing whether the negative-binomial
# approximation substantially affects the inference.
generator = alsm.generate_group_data if GROUP_SIMULATION else alsm.generate_data

In [None]:
# Simulation settings for the synthetic network.
seed = 1
num_groups = 10
num_dims = 2

# Seed numpy's global generator, then draw the group sizes followed by the group scales. The
# order of the two draws is kept fixed so the simulated data are reproducible.
np.random.seed(seed)
group_sizes = np.random.poisson(100, num_groups)
simulation_kwargs = {
    'group_scales': np.random.gamma(10, 1 / 15, num_groups),
    'population_scale': 3,
    'propensity': 0.1,
}
data = generator(group_sizes, num_dims, **simulation_kwargs)

# Use a single panel when only aggregate data exist; otherwise add a left panel that shows the
# detailed (node-level) network.
if 'locs' not in data:
    fig, ax2 = plt.subplots()
else:
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
    ax1.scatter(*data['locs'].T, c=data['group_idx'], cmap='tab10', marker='.')
    alsm.plot_edges(data['locs'], data['adjacency'], ax=ax1, alpha=.2, zorder=0)
    ax1.set_aspect('equal')

# Aggregate network: one marker per group with the group-level edges drawn underneath.
pts = ax2.scatter(*data['group_locs'].T, c=np.arange(num_groups), cmap='tab10')
alsm.plot_edges(data['group_locs'], data['group_adjacency'], ax=ax2, zorder=0, alpha_min=.1)
ax2.set_aspect('equal')

# NOTE(review): the draw presumably resolves the colormapped face colors before they are read
# back below -- confirm before reordering.
plt.draw()
group_colors = pts.get_facecolors()
# Shade a disc of radius twice the group scale around each group location.
for face, center, scale in zip(group_colors, data['group_locs'], data['group_scales']):
    ax2.add_patch(mpl.patches.Circle(center, 2 * scale, color=face, alpha=.25))

ax2.autoscale_view()

# Summary statistic: total group adjacency divided by the number of nodes.
mean_degree = data["group_adjacency"].sum() / data["num_nodes"]
print(f'mean degree: {mean_degree:.3f}')

In [None]:
# Build the Stan program for the group model and draw posterior samples.
# NOTE(review): `epsilon` is presumably a small numerical constant consumed by the Stan model --
# confirm against alsm.model.GROUP_MODEL.
data['epsilon'] = 1e-12
posterior = stan.build(alsm.model.GROUP_MODEL, data=data, random_seed=seed)

# Sampler configuration: four chains, 1000 warmup iterations and 200 retained draws.
sampler_options = {'num_chains': 4, 'num_warmup': 1000, 'num_samples': 200}
fit = posterior.sample(**sampler_options)

In [None]:
# Select the best chain, bring the last axis of the location samples to the front, and align
# the samples with one another.
chain = alsm.get_chain(fit, 'best')
samples = np.moveaxis(chain['group_locs'], -1, 0)
aligned = alsm.align_samples(samples)

# Rotate the aligned samples onto the mean-centred true group locations: the orthogonal
# Procrustes transform is estimated from the sample mean and then applied to every sample.
true_locs = data['group_locs']
reference = true_locs - true_locs.mean(axis=0)
sample_mean = np.mean(aligned, axis=0)
transform, _ = orthogonal_procrustes(sample_mean, reference)
aligned = aligned @ transform

# One colour value per group, repeated across the leading axis of `aligned` so that the colour
# array has the same shape as the transposed coordinate arrays.
group_ids = np.arange(num_groups)
sample_colors = group_ids[:, None] * np.ones(aligned.shape[0])

# Scatter plot of the posterior samples with the reference locations overlaid.
fig, ax = plt.subplots()
ax.scatter(*aligned.T, c=sample_colors, cmap='tab10', marker='.', alpha=.1,
           label='posterior samples')
pts = ax.scatter(*reference.T, c=group_ids, marker='X', cmap='tab10', label='reference')
pts.set_edgecolor('w')
ax.set_aspect('equal')
ax.legend(fontsize='small')

fig.tight_layout()

In [None]:
# Compare the inferred propensity and group scales with the values used for the simulation.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)

# Left panel: posterior histogram of the propensity with the true value as a dotted line.
ax1.hist(chain['propensity'], density=True)
ax1.axvline(data['propensity'], color='k', ls=':')
ax1.set(xlabel=r'Propensity $\alpha$', ylabel=r'Posterior $P(\alpha)$')

# Right panel: posterior median and 2.5-97.5 percentile interval of each group scale (taken
# over the last axis of the samples) against the true group scale.
lower, median, upper = np.percentile(chain['group_scales'], [2.5, 50, 97.5], axis=-1)
true_scales = data['group_scales']
diagonal = true_scales.min(), true_scales.max()
# Dotted identity line spanning the range of the true scales.
ax2.plot(diagonal, diagonal, color='k', ls=':')
ax2.errorbar(true_scales, median, yerr=(median - lower, upper - median), color='silver',
             ls='none')
ax2.scatter(true_scales, median, c=np.arange(num_groups), cmap='tab10', zorder=2)
ax2.set_aspect('equal')
ax2.set(xlabel=r'Group scales $\sigma$', ylabel=r'Inferred group scales')

fig.tight_layout()

In [None]:
# Posterior predictive check: replicated group adjacency against the observed values.
fig, ax = plt.subplots()

# Flatten the observed adjacency and reshape the replicates to match, leaving the remaining
# elements on a trailing axis over which the percentiles are taken.
observed = data['group_adjacency'].ravel()
replicates = chain['ppd_group_adjacency'].reshape(observed.shape + (-1,))
lower, median, upper = np.percentile(replicates, [2.5, 50, 97.5], axis=-1)

# Dotted identity line spanning the observed range.
diagonal = observed.min(), observed.max()
ax.plot(diagonal, diagonal, color='k', ls=':')
# Median of the replicates with the 2.5-97.5 percentile interval as error bars.
ax.errorbar(observed, median, yerr=(median - lower, upper - median), ls='none', marker='.')
ax.set_aspect('equal')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Group adjacency $Y$')
ax.set_ylabel('Group adjacency posterior replicates')
fig.tight_layout()