In [None]:
import alsm
import cmdstanpy
from matplotlib import pyplot as plt
import matplotlib as mpl
from mpl_toolkits import axes_grid1
import numpy as np
import pandas as pd
from pathlib import Path
import re
from scipy.linalg import orthogonal_procrustes


# Render figures at 144 dpi.
mpl.rcParams['figure.dpi'] = 144

# Directory containing the `comm72*.dat` input files, relative to the working directory.
DATA_ROOT = Path('../data/addhealth')
# Base random seed handed to the Stan sampler and the variational fits.
SEED = 0

In [None]:
def pop_if_match(lines: list, pattern: str, index=0) -> re.Match:
    """
    Consume one line from `lines`, provided that it matches `pattern`.

    Args:
        lines: List of lines; the matching line is removed from it in place.
        pattern: Regular expression that is matched against the start of the line.
        index: Position of the line to test and remove.

    Returns:
        The match object obtained for the removed line.

    Raises:
        ValueError: If the line does not match, in which case `lines` is left untouched.
    """
    candidate = lines[index]
    match = re.match(pattern, candidate)
    # Guard clause: bail out before touching the list if the line is not what we expected.
    if match is None:
        raise ValueError(f'{candidate} does not match `{pattern}`')
    lines.pop(index)
    return match


def lines_to_array(lines, parser=float):
    """
    Convert whitespace-delimited text lines to a two-dimensional array.

    Args:
        lines: Iterable of strings with one row per line and cells separated by whitespace.
        parser: Callable applied to every cell, e.g. `float` or `int`.

    Returns:
        Array with one row per non-blank line and one column per cell.
    """
    # Blank lines (e.g. a trailing newline at the end of the file) contain no cells and would make
    # the rows ragged, so they are skipped.
    rows = [[parser(cell) for cell in line.split()] for line in lines if line.strip()]
    return np.asarray(rows)


# Read the raw edgelist file into a list of lines.
with (DATA_ROOT / 'comm72.dat').open() as fp:
    lines = list(fp)

# Consume the four header lines in order, picking up the number of nodes on the way.
pop_if_match(lines, 'DL')
node_count_match = pop_if_match(lines, r'N=(\d+)')
num_nodes = int(node_count_match.group(1))
for header_pattern in ('FORMAT=EDGELIST1', 'DATA:'):
    pop_if_match(lines, header_pattern)

# Everything that remains is data: parse it and cast to integers.
edgelist = lines_to_array(lines).astype(int)

# Build the binary adjacency matrix. The first two columns are one-based node indices (hence the
# `- 1`); the third column `w` is not used.
i, j, w = edgelist.T
adjacency = np.zeros(shape=(num_nodes, num_nodes), dtype=int)
adjacency[i - 1, j - 1] = 1

In [None]:
# Read the raw attribute file into a list of lines.
with (DATA_ROOT / 'comm72_att.dat').open() as fp:
    lines = list(fp)

# Parse the header: it declares the matrix shape, which must agree with the edgelist.
pop_if_match(lines, 'DL')
shape_match = pop_if_match(lines, r'NR=(\d+), * NC=(\d+)')
num_rows, num_cols = (int(value) for value in shape_match.groups())
assert num_rows == num_nodes
pop_if_match(lines, 'FORMAT = FULLMATRIX DIAGONAL PRESENT')

# The next `num_cols` lines after the marker are the column labels (one per line, quoted).
pop_if_match(lines, 'COLUMN LABELS:')
label_lines, lines = lines[:num_cols], lines[num_cols:]
labels = [raw.strip('\n"') for raw in label_lines]

# Discard lines up to and including the one that starts the data section.
skipped = lines.pop(0)
while not skipped.startswith('DATA:'):
    skipped = lines.pop(0)

# Parse the remaining lines as integers and wrap them in a labelled dataframe.
attributes = pd.DataFrame(lines_to_array(lines, parser=int), columns=labels)
assert attributes.shape == (num_nodes, num_cols)

In [None]:
# Group the adjacency matrix (using the neat `ngroup` function).
keys = ['grade', 'sex']
grouper = attributes.groupby(keys)
# Integer group label for every node.
group_idx = grouper.ngroup().values
# Number of nodes carrying each group label.
group_sizes = np.bincount(group_idx)
num_groups, = group_sizes.shape
# NOTE(review): `grouping` is presumably a (num_groups, num_nodes) indicator matrix so that the
# product below aggregates edges between each pair of groups -- confirm against `alsm`.
grouping = alsm.evaluate_grouping_matrix(group_idx)
group_adjacency = grouping @ adjacency @ grouping.T

# Get attributes of the groups.
# Iterating over the groupby yields (key, frame) pairs; only the keys are kept, one row per group.
group_attributes = pd.DataFrame([key for key, _ in grouper], columns=keys)

plt.imshow(group_adjacency)
# Last expression of the cell: displays the table of group attributes.
group_attributes

In [None]:
# Assemble the data for stan.
data = {
    'num_nodes': num_nodes,
    'num_groups': num_groups,
    # Dimensionality of the latent space; the plots below draw the two coordinates directly.
    'num_dims': 2,
    # `group_idx` is zero-based (it comes from `ngroup`) whereas Stan indexes from one.
    'group_idx': group_idx + 1,
    # NOTE(review): presumably a small constant for numerical stability in the model -- confirm.
    'epsilon': 1e-20,
    'group_adjacency': group_adjacency,
    'adjacency': adjacency,
    'group_sizes': group_sizes,
    # NOTE(review): presumably a flag selecting the unweighted (binary) model -- confirm.
    'weighted': 0,
}

In [None]:
# Fit the model. The groups are rearranged such that the two "furthest" groups are represented by
# the first two indices, which pins the posterior and removes the rotational symmetry.

# Permutation that swaps the group at position 1 with the last group; all others stay in place.
index = np.arange(num_groups)
index[1] = num_groups - 1
index[num_groups - 1] = 1

stan_file = alsm.write_stanfile(alsm.get_group_model_code())
posterior = cmdstanpy.CmdStanModel(stan_file=stan_file)
# The model sees the permuted data, so the resulting samples are in permuted group order.
# NOTE(review): a scalar `inits` is understood to restrict the random initial values to
# [-inits, inits] -- confirm against the cmdstanpy documentation.
fit = posterior.sample(iter_warmup=1000, iter_sampling=1000, chains=8, inits=1e-2, seed=SEED,
                       data=alsm.apply_permutation_index(data, index), show_progress=False)

In [None]:
# Plot the trace of `lp__` as a quick visual check of the chains.
# NOTE(review): the meaning of the third positional argument (`False`) is not visible here --
# confirm against `alsm.get_samples`.
lps = alsm.get_samples(fit, 'lp__', False)
plt.plot(lps, alpha=.5)
# NOTE(review): assumes axis 0 of `lps` indexes draws so that this prints one mean per chain.
print('mean lps', np.mean(lps, axis=0))
# NOTE(review): 'best' presumably selects the chain with the highest `lp__` -- confirm.
chain = alsm.get_chain(fit, 'best')
print('median leapfrog steps in best chain', np.median(chain['n_leapfrog__']))

In [None]:
# Apply the inverse index.
# The chain is fetched afresh so that re-running this cell does not undo the permutation twice.
chain = alsm.get_chain(fit, 'best')
chain = alsm.apply_permutation_index(chain, alsm.invert_index(index))

# Align samples and estimate the mode.
# `rollaxis(..., -1)` moves the last axis to the front.
# NOTE(review): assumes the last axis of `group_locs` indexes samples -- confirm.
samples = np.rollaxis(chain['group_locs'], -1)
aligned = alsm.align_samples(samples)
modes = alsm.estimate_mode(np.rollaxis(aligned, 1))

fig, ax = plt.subplots()
alsm.plot_edges(modes, group_adjacency, lw=3)

# Repeat each group's grade once per posterior draw so every scattered sample gets a color value.
c = group_attributes.grade.values[:, None] * np.ones(fit.num_draws_sampling)
pts = ax.scatter(*aligned.T, c=c, marker='.', alpha=.01)
# NOTE(review): presumably forces a render so that `pts.norm` is scaled to the data before it is
# used to look up colors below -- confirm.
plt.draw()

# One circle and one marker per group: centred on the mode, with the posterior median of the group
# scale as radius, colored by grade.
for xy, radius, tup in zip(modes, np.median(chain['group_scales'], axis=-1),
                           group_attributes.itertuples()):
    color = pts.cmap(pts.norm(tup.grade))
    circle = mpl.patches.Circle(xy, radius, facecolor='none', edgecolor=color)
    ax.add_patch(circle)
    # Squares for `sex == 1`, circles otherwise.
    ax.scatter(*xy, color=color, marker='s' if tup.sex == 1 else 'o', zorder=2).set_edgecolor('w')

ax.set_aspect('equal')

In [None]:
# Interquartile range of the posterior predictive replicates for every entry of the grouped
# adjacency matrix.
l, u = np.percentile(chain['ppd_group_adjacency'], [25, 75], axis=-1)
# Fraction of observed entries that fall inside that range (boundaries included).
observed = data['group_adjacency']
within = (observed >= l) & (observed <= u)
coverage = np.mean(within)
print(f'ppd coverage of interquartile range: {coverage:.3f}')

In [None]:
# Compile the individual-level model.
# NOTE(review): `group_prior=True` presumably adds the group-level location parameters read
# below (`group_locs`) -- confirm against `alsm`.
code = alsm.get_individual_model_code(group_prior=True)
individual_model = cmdstanpy.CmdStanModel(stan_file=alsm.write_stanfile(code))

approximations = []
# Get the centred modes.
y = modes - modes.mean(axis=0)
# Run one variational fit per seed, using as many seeds as there were MCMC chains.
for seed in range(fit.chains):
    approx = individual_model.variational(data, seed=seed + SEED, inits=1e-2)
    approximations.append(approx)
    # Evaluate the aligned loss.
    x = approx.stan_variable('group_locs')
    x = x - x.mean(axis=0)
    # Orthogonal matrix that best maps the centred variational locations onto the centred modes.
    transform, _ = orthogonal_procrustes(x, y)
    x = x @ transform
    # Mean squared difference between the two sets of group locations after alignment.
    loss = np.mean(np.square(x - y))
    print(f'seed {seed}; elbo {alsm.get_elbo(approx)}; alignment loss: {loss}')

# Get the best ELBO.
# `approx` is rebound from the last loop iteration to the fit with the largest ELBO.
approx = max(approximations, key=alsm.get_elbo)

In [None]:
# Two-panel figure: (a) individual-level variational embedding, (b) group-level posterior samples.
fig = plt.figure()
grid = axes_grid1.AxesGrid(fig, 111, (1, 2), cbar_mode='single',
                           share_all=True, axes_pad=0.2)
ax1, ax2 = grid

# NOTE(review): the fixed -20 degree rotation looks like a cosmetic choice for the figure.
rotation = alsm.evaluate_rotation_matrix(np.deg2rad(-20))
xs = approx.stan_variable('locs') @ rotation
x = approx.stan_variable('group_locs') @ rotation
y = modes
ys = aligned

# Center both. Each offset is evaluated once, *before* shifting, and applied to the group-level
# locations and to the associated point cloud alike so that they stay in the same frame. (Taking
# the mean of the already-centred `x` would give ~0 and leave `xs` unshifted.)
x_offset = x.mean(axis=0)
x = x - x_offset
xs = xs - x_offset
y_offset = y.mean(axis=0)
y = y - y_offset
ys = ys - y_offset

# Scale both.
scale_x = np.linalg.norm(x)
scale_y = np.linalg.norm(y)
y = scale_x * y / scale_y
ys = scale_x * ys / scale_y

# Apply the rigid procrustes transform.
transform, _ = orthogonal_procrustes(y, x)
y = y @ transform
ys = ys @ transform

alsm.plot_edges(xs, adjacency, alpha=.2, ax=ax1, zorder=0)
alsm.plot_edges(y, group_adjacency, ax=ax2, zorder=1, lw=3)
ax1.scatter(*xs.T, c=attributes.grade, marker='.')
# Repeat each group's grade once per posterior draw so every scattered sample gets a color value.
ax2.scatter(*ys.T, c=group_attributes.grade.values[:, None] * np.ones(fit.num_draws_sampling),
            marker='.', alpha=.025, zorder=0)

# Show the markers for the group locations.
for sex, subset in group_attributes.groupby('sex'):
    # Squares for `sex == 1`, circles for `sex == 2`.
    marker = 'so'[sex - 1]
    ax1.scatter(*x[subset.index].T, c=subset.grade, marker=marker).set_edgecolor('w')
    pts = ax2.scatter(*y[subset.index].T, c=subset.grade, marker=marker, zorder=2)
    pts.set_edgecolor('w')

# NOTE(review): presumably forces a render so that `pts.norm` is scaled to the data before it is
# used to look up colors below -- confirm.
plt.draw()

# Circles around the group modes; the radii are rescaled by the same factor as the locations.
for xy, radius, tup in zip(y, np.median(chain['group_scales'], axis=-1) * scale_x / scale_y,
                           group_attributes.itertuples()):
    color = pts.cmap(pts.norm(tup.grade))
    circle = mpl.patches.Circle(xy, radius, facecolor='none', edgecolor=color)
    ax2.add_patch(circle)

for ax, label in [(ax1, '(a)'), (ax2, '(b)')]:
    ax.set_aspect('equal')
    ax.text(0.05, 0.95, label, transform=ax.transAxes, va='top')
    ax.autoscale_view()
    ax.set_xlabel('Embedding $z_1$')

ax1.set_ylabel('Embedding $z_2$')
plt.setp(ax2.yaxis.get_ticklabels(), visible=False)

# Add a legend for the symbols.
# Empty scatters serve as proxy handles so the legend shows the marker shapes only.
ax = ax2
handle_girls = ax.scatter([], [], color='k', marker='o')
handle_girls.set_edgecolor('w')
handle_boys = ax.scatter([], [], color='k', marker='s')
handle_boys.set_edgecolor('w')
ax.legend([handle_girls, handle_boys], ['girls', 'boys'], fontsize='small',
          loc='upper right')

fig.colorbar(pts, cax=grid.cbar_axes[0]).set_label('Grade')
fig.savefig('../workspace/addhealth.pdf')
print(f'Scale adjustment factor: {scale_x / scale_y}')
