In [None]:
import alsm
from matplotlib import pyplot as plt
import matplotlib as mpl
import nest_asyncio
import numpy as np
import stan
from scipy.linalg import orthogonal_procrustes


# Render inline figures at 144 dpi.
mpl.rcParams.update({'figure.dpi': 144})
# Allow nested asyncio event loops. Presumably required so that PyStan 3 can run its own
# event loop inside Jupyter's already-running one — confirm against the PyStan docs.
nest_asyncio.apply()

In [None]:
# Generate a network.
# Only the final `seed` assignment ever took effect. Other seeds that were tried:
#   2  -- good but no arcs
#   14 -- reasonable example but very overlapping groups
#   20 -- good but very wiggly
#   24 -- good but no arcs
#   31 -- good example, including arcs
seed = 7
num_groups = 10
num_dims = 2

# No generator is passed to `alsm.generate_data`, so seed NumPy's global state. The draw
# order matters for reproducibility: group sizes first, then group scales, then whatever
# `generate_data` draws (presumably also from the global state — confirm in alsm).
np.random.seed(seed)
group_sizes = np.random.poisson(100, num_groups)
data = alsm.generate_data(
    group_sizes,
    num_dims,
    group_scales=np.random.gamma(3, 1 / 5, num_groups),
    population_scale=2.5,
    propensity=0.1,
)

# Left: node-level network coloured by group.
# Right: aggregate group-level network, with a translucent disc of radius 2 * group scale
# drawn around each group location.
fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)

ax1.scatter(*data['locs'].T, c=data['group_idx'], cmap='tab10', marker='.')
alsm.plot_edges(data['locs'], data['adjacency'], ax=ax1, alpha=.2, zorder=0)
ax1.set_aspect('equal')

pts = ax2.scatter(*data['group_locs'].T, c=np.arange(num_groups), cmap='tab10')
alsm.plot_edges(data['group_locs'], data['group_adjacency'], ax=ax2, zorder=0, alpha_min=.1)
ax2.set_aspect('equal')

# Draw first, presumably so the colormapped facecolors of `pts` are resolved before they
# are read back to colour the discs.
plt.draw()
disc_specs = zip(pts.get_facecolors(), data['group_locs'], data['group_scales'])
for disc_color, disc_center, disc_scale in disc_specs:
    ax2.add_patch(mpl.patches.Circle(disc_center, 2 * disc_scale, color=disc_color, alpha=.25))

# Include the newly added discs in the view limits.
ax2.autoscale_view()

# NOTE(review): this is the mean degree only if group_adjacency sums to 2 x (number of
# edges) — confirm how alsm counts within-group edges on the diagonal.
mean_degree = data["group_adjacency"].sum() / data["num_nodes"]
scale_ratio = data["group_scales"].max() / data["group_scales"].min()
print(f'mean degree: {mean_degree:.3f}')
print(f'ratio of largest to smallest scale: {scale_ratio:.3f}')

In [None]:
# Fit the model.
# Extra data entry consumed by the Stan program. Presumably a small jitter for numerical
# stability — confirm in alsm.get_group_model_code().
data['epsilon'] = 1e-12
model_code = alsm.get_group_model_code()
posterior = stan.build(model_code, data=data, random_seed=seed)
sampling_options = dict(num_chains=8, num_warmup=1000, num_samples=1000, init_radius=1e-2)
fit = posterior.sample(**sampling_options)

In [None]:
# Go through each chain and evaluate its median log probability as well as the median squared
# error of its aligned group locations relative to the original (simulated) group locations.

# The centred reference is the same for every chain, so compute it once.
reference = data['group_locs'] - data['group_locs'].mean(axis=0)

median_losses = []
median_lps = []
for i in range(fit.num_chains):
    chain = alsm.get_chain(fit, i)
    median_lp = np.median(chain['lp__'])

    # Move the draw axis to the front, then align the draws with one another.
    samples = np.rollaxis(chain['group_locs'], -1)
    aligned = alsm.align_samples(samples)

    # Rotate the aligned draws so their mean best matches the centred reference.
    transform, _ = orthogonal_procrustes(np.mean(aligned, axis=0), reference)
    aligned = aligned @ transform

    # Median over all draws, groups and coordinates of the squared error.
    median_loss = np.median((aligned - reference) ** 2)

    print(f'chain {i}; median lp: {median_lp:.3f}; median loss: {median_lp and median_loss:.3f}'
          if False else f'chain {i}; median lp: {median_lp:.3f}; median loss: {median_loss:.3f}')

    median_losses.append(median_loss)
    median_lps.append(median_lp)

# Select the best aligned chain among those within `lp_tolerance` of the highest median lp.
# There is a lot of noise in the lps, and the best aligned solution may by chance have a
# low-ish lp. Chains that fall too far below the best lp are excluded outright (infinite loss)
# rather than penalised with a large finite constant.
lp_tolerance = 1
median_losses = np.asarray(median_losses)
median_lps = np.asarray(median_lps)
too_unlikely = median_lps < np.max(median_lps) - lp_tolerance
best_chain = np.argmin(np.where(too_unlikely, np.inf, median_losses))
chain = alsm.get_chain(fit, best_chain)
best_chain

In [None]:
# Rotate all coordinates by `angle` degrees, then reverse the last (coordinate) axis, which
# swaps x and y in 2D. Presumably this is purely for presentation. The node locations, the
# posterior draws of the group locations (draw axis moved to the front) and the true group
# locations all get the same transform, so they share a frame.
angle = 80
rotation = alsm.evaluate_rotation_matrix(np.deg2rad(angle))

locs, samples, reference = (
    (coords @ rotation)[..., ::-1]
    for coords in (data['locs'], np.rollaxis(chain['group_locs'], -1), data['group_locs'])
)

# 2x2 grid of panels. This part draws the top row.
fig = plt.figure(figsize=(6, 5))
gs = fig.add_gridspec(2, 2, height_ratios=[1, 1])

# Panel (a): simulated node locations coloured by group, with edges drawn underneath (zorder=0).
ax1 = fig.add_subplot(gs[0, 0])
ax1.set_aspect('equal')
alsm.plot_edges(locs, data['adjacency'], zorder=0, ax=ax1)
ax1.scatter(*locs.T, c=data['group_idx'], cmap='tab10', marker='.')
ax1.set_xlabel('Embedding $z_1$')
ax1.set_ylabel('Embedding $z_2$')

# Align the samples with one another, rotate their mean onto the centred reference, and then
# shift by the reference centroid. The draws then live in the same rotated/flipped frame as
# `locs` and `reference`.
samples = alsm.align_samples(samples)
transform, _ = orthogonal_procrustes(samples.mean(axis=0), reference - reference.mean(axis=0))
samples = (samples @ transform) + reference.mean(axis=0)

# Panel (b): posterior draws of the group locations, sharing axes with panel (a). Group-level
# edges are drawn between the posterior-mean locations, using the square root of the group
# adjacency counts.
ax2 = fig.add_subplot(gs[0, 1], sharex=ax1, sharey=ax1)
alsm.plot_edges(samples.mean(axis=0), data['group_adjacency'] ** .5, zorder=0, ax=ax2, lw=3)
# One colour index per group, repeated for every draw. Assumes samples is
# (draws, groups, dims), so each component of samples.T is (groups, draws) — confirm.
c = np.arange(data['num_groups'])[:, None] * np.ones(samples.shape[0])
ax2.scatter(*samples.T, c=c, cmap='tab10', marker='.', alpha=.05)

# Show the scales: an unfilled circle of radius `factor` x posterior-mean group scale around
# each posterior-mean group location. `C{i}` presumably matches the tab10 colour of group i
# under the default colour cycle.
factor = 2
for i, (xy, radius) in enumerate(zip(samples.mean(axis=0), chain['group_scales'].mean(axis=-1))):
    circle = mpl.patches.Circle(xy, factor * radius, edgecolor=f'C{i}', facecolor='none')
    ax2.add_patch(circle)

ax2.set_aspect('equal')
ax2.set_xlabel(r'Embedding $z_1$')
# The y axis is shared with panel (a), so hide the duplicate tick labels.
plt.setp(ax2.yaxis.get_ticklabels(), visible=False)

# Panel (c): posterior histogram of the propensity, with the simulated value as a dotted line.
ax3 = fig.add_subplot(gs[1, 0])
ax3.hist(chain['propensity'], density=True, bins=15)
ax3.axvline(data['propensity'], color='k', ls=':')
ax3.set(xlabel=r'Propensity $\alpha$', ylabel=r'Posterior $P(\alpha)$')

# Panel (d): simulated group scales against posterior medians, with central 95% intervals.
# The dotted line is y = x over the range of the simulated scales.
ax4 = fig.add_subplot(gs[1, 1])
true_scales = data['group_scales']
lower, middle, upper = np.percentile(chain['group_scales'], [2.5, 50, 97.5], axis=-1)
diagonal = (true_scales.min(), true_scales.max())
ax4.plot(diagonal, diagonal, color='k', ls=':')
ax4.errorbar(true_scales, middle, yerr=(middle - lower, upper - middle), color='silver', ls='none')
ax4.scatter(true_scales, middle, c=np.arange(data['num_groups']), cmap='tab10', zorder=2)
ax4.set_aspect('equal')
ax4.set(xlabel=r'Group scales $\sigma$', ylabel=r'Inferred group scales')

# Panel letters. Each entry is (axes, "<vertical> <horizontal>" corner, label text).
labels = [
    (ax1, 'top left', '(a)'),
    (ax2, 'top left', '(b)'),
    (ax3, 'top right', '(c)'),
    (ax4, 'top left', '(d)'),
]
for ax, loc, label in labels:
    va, ha = loc.split()
    # Inset the label 5% from the requested corner, in axes-fraction coordinates.
    x_frac = 0.05 if ha == 'left' else 0.95
    y_frac = 0.05 if va == 'bottom' else 0.95
    ax.text(x_frac, y_frac, label, transform=ax.transAxes, ha=ha, va=va)

fig.tight_layout()
# Plain string literal: the previous f-prefix had no placeholders.
fig.savefig('../workspace/simulation.pdf')

In [None]:
# Show the scatter plot: aligned posterior draws of the group locations (dots, one colour per
# group) overlaid on the true group locations (white-edged crosses). Uses `samples` and
# `reference` from the figure cell, which share the same rotated/flipped frame.
fig, ax = plt.subplots()
# Size the colour array from `samples` itself, not from `aligned`. `aligned` is a leftover loop
# variable from the chain-selection cell and belongs to the last chain, not the plotted one.
# Assumes samples is (draws, groups, dims) — confirm.
color_idx = np.arange(num_groups)[:, None] * np.ones(samples.shape[0])
ax.scatter(*samples.T, cmap='tab10', marker='.', alpha=.1, label='posterior samples', c=color_idx)
pts = ax.scatter(*reference.T, c=np.arange(num_groups), marker='X', cmap='tab10', label='reference')
pts.set_edgecolor('w')
ax.set_aspect('equal')
ax.set_xlabel('Embedding $z_1$')
ax.set_ylabel('Embedding $z_2$')
ax.legend(fontsize='small')

fig.tight_layout()

In [None]:
# Show the posterior predictive replication: observed group-level edge counts against the
# median and central 95% interval of the replicated counts, on log-log axes.
fig, ax = plt.subplots()

x = data['group_adjacency'].ravel()
# Flatten the group dimensions and keep the draws on the last axis. Assumes the leading
# dimensions of ppd_group_adjacency match group_adjacency — confirm.
ys = chain['ppd_group_adjacency'].reshape(x.shape + (-1,))
l, m, u = np.percentile(ys, [2.5, 50, 97.5], axis=-1)
# Log axes cannot represent zero counts, so start the y = x reference line at the smallest
# positive observed count. Keep the original limits if there are no positive counts.
positive = x[x > 0]
lims = (positive.min(), x.max()) if positive.size else (x.min(), x.max())
ax.plot(lims, lims, color='k', ls=':')
ax.errorbar(x, m, (m - l, u - m), ls='none', marker='.')
ax.set_aspect('equal')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Group adjacency $Y$')
ax.set_ylabel('Group adjacency posterior replicates')
fig.tight_layout()