In [None]:
import alsm
from matplotlib import pyplot as plt
import matplotlib as mpl
import nest_asyncio
import numpy as np
import pandas as pd
from pathlib import Path
import re
import stan


# Notebook-wide configuration.

# Resolution used when rendering figures.
mpl.rcParams.update({'figure.dpi': 144})
# Patch asyncio so that event loops can be nested (presumably needed to run the sampler from within
# the notebook's own event loop -- TODO confirm).
nest_asyncio.apply()

# Seed passed to the model build step, and the directory holding the input files.
SEED = 0
DATA_ROOT = Path('../data/addhealth')

In [None]:
def pop_if_match(lines: list, pattern: str, index=0) -> re.Match:
    """
    Remove `lines[index]` from the list provided it matches a pattern.

    Args:
        lines: List of lines; modified in place when the match succeeds.
        pattern: Regular expression matched at the start of the line (`re.match` semantics).
        index: Position of the line to test and remove.

    Returns:
        Match object obtained for the removed line.

    Raises:
        ValueError: If the line does not match the pattern; `lines` is left unchanged.
    """
    line = lines[index]
    found = re.match(pattern, line)
    # Guard clause: bail out before touching the list if the line is not what we expect.
    if not found:
        raise ValueError(f'{line} does not match `{pattern}`')
    lines.pop(index)
    return found


def lines_to_array(lines, parser=float):
    """
    Convert whitespace-delimited text rows into a numpy array.

    Args:
        lines: Iterable of strings, one per row; cells are separated by arbitrary whitespace.
        parser: Callable applied to every cell.

    Returns:
        Array built from the nested list of parsed cells (one inner list per line).
    """
    rows = []
    for line in lines:
        rows.append(list(map(parser, line.split())))
    return np.asarray(rows)


# Load the edgelist.
with open(DATA_ROOT / 'comm72.dat') as fp:
    lines = fp.readlines()
    
# Consume the header one line at a time. Each call removes the first remaining line and raises a
# `ValueError` if that line does not match, so a file with a different layout fails right here.
pop_if_match(lines, 'DL')
num_nodes = int(pop_if_match(lines, r'N=(\d+)').group(1))
pop_if_match(lines, 'FORMAT=EDGELIST1')
pop_if_match(lines, 'DATA:')
# Everything left is data. Cells are parsed with the default `float` parser and then cast to int.
edgelist = lines_to_array(lines).astype(int)

# Construct the adjacency matrix.
# The edgelist has three columns (presumably source, target, weight -- TODO confirm against the file).
i, j, w = edgelist.T
adjacency = np.zeros((num_nodes, num_nodes), int)
# Indices are shifted down by one, i.e. node ids in the file are treated as one-based.
adjacency[i - 1, j - 1] = w

In [None]:
# Load the metadata.
with open(DATA_ROOT / 'comm72_att.dat') as fp:
    lines = fp.readlines()

# Consume the header one line at a time; `pop_if_match` raises a `ValueError` on any unexpected line.
pop_if_match(lines, 'DL')
# The pattern is a raw string (like the `N=(\d+)` pattern used for the edgelist) so that `\d` reaches
# the regex engine verbatim instead of being treated as an invalid string escape sequence.
num_rows, num_cols = map(int, pop_if_match(lines, r'NR=(\d+), * NC=(\d+)').groups())
assert num_rows == num_nodes
pop_if_match(lines, 'FORMAT = FULLMATRIX DIAGONAL PRESENT')

# Get the column labels: one label per line, with the newline and surrounding double quotes removed.
pop_if_match(lines, 'COLUMN LABELS:')
labels = [label.strip('\n"') for label in lines[:num_cols]]
lines = lines[num_cols:]

# Skip to the data by discarding lines up to and including the `DATA:` marker. An `IndexError` is
# raised by `pop` if the marker is missing.
while not lines.pop(0).startswith('DATA:'):
    pass

# Create a dataframe for the attributes (one row per node, one integer column per label).
attributes = pd.DataFrame(lines_to_array(lines, parser=int), columns=labels)
assert attributes.shape == (num_nodes, num_cols)

In [None]:
# Group the adjacency matrix (using the neat `ngroup` function).
# Nodes are assigned to groups defined by the unique (grade, sex) combinations.
keys = ['grade', 'sex']
grouper = attributes.groupby(keys)
# Integer group label for every node.
group_idx = grouper.ngroup().values
# Number of nodes in each group; the unpacking below also asserts that the result is one-dimensional.
group_sizes = np.bincount(group_idx)
num_groups, = group_sizes.shape
# Presumably a (num_groups, num_nodes) indicator matrix so that the product below sums the
# node-level adjacency within each pair of groups -- TODO confirm against `alsm`.
grouping = alsm.evaluate_grouping_matrix(group_idx)
group_adjacency = grouping @ adjacency @ grouping.T

# Get attributes of the groups.
# One row per group key, in the order in which the groupby object yields the groups; this relies on
# that order agreeing with the labels returned by `ngroup`.
group_attributes = pd.DataFrame([key for key, _ in grouper], columns=keys)

# Show the aggregated adjacency matrix and, as the cell output, the table of group attributes.
plt.imshow(group_adjacency)
group_attributes

In [None]:
# Fit the model. Before fitting, the groups are reordered by exchanging the second and the last group.
# The intent is that the "furthest" groups occupy the first two indices, which pins the posterior and
# removes the rotational symmetry of the embedding.
# NOTE(review): nothing in this cell verifies that the first and last groups are in fact the furthest apart.

# Identity permutation with entries 1 and `num_groups - 1` exchanged.
index = np.arange(num_groups)
index[1] = num_groups - 1
index[num_groups - 1] = 1

# Inputs for the model, in the original (unpermuted) group order.
data = {
    'num_groups': num_groups,
    'num_dims': 2,
    # NOTE(review): the role of `epsilon` is defined by the model code, which is not visible here.
    'epsilon': 1e-12,
    'group_adjacency': group_adjacency,
    'group_sizes': group_sizes,
}
# The permutation is applied to a copy handed to the build step (presumably `apply_permutation_index`
# does not modify `data` in place -- later cells read `data['group_adjacency']` in the original order).
posterior = stan.build(alsm.get_group_model_code(), data=alsm.apply_permutation_index(data, index), 
                       random_seed=SEED)
# Eight chains with 1000 warmup iterations and 1000 retained samples (presumably per chain).
fit = posterior.sample(num_warmup=1000, num_samples=1000, num_chains=8, init_radius=1e-2)

In [None]:
# Sampling diagnostics: trace of the log probability `lp__` and its mean along the first axis
# (presumably one value per chain -- TODO confirm the layout returned by `alsm.get_samples`).
# NOTE(review): the meaning of the positional `False` argument is defined by `alsm.get_samples`.
lps = alsm.get_samples(fit, 'lp__', False)
plt.plot(lps, alpha=.5)
print('mean lps', np.mean(lps, axis=0))
# Select a single chain for further analysis ('best' is presumably the chain with the highest `lp__`).
# This chain is still in the permuted group order used for fitting.
chain = alsm.get_chain(fit, 'best')
print('median leapfrog steps in best chain', np.median(chain['n_leapfrog__']))

In [None]:
# Apply the inverse index.
# This restores the original group order so that the samples line up with `group_attributes` and
# `group_adjacency`.
chain = alsm.get_chain(fit, 'best')
chain = alsm.apply_permutation_index(chain, alsm.invert_index(index))


# Move the last axis (presumably the sample axis) to the front before aligning the samples.
samples = np.rollaxis(chain['group_locs'], -1)
aligned = alsm.align_samples(samples)

fig, ax = plt.subplots()
# One point estimate of the location per group, used to anchor the edges and markers.
modes = alsm.estimate_mode(np.rollaxis(aligned, 1))
alsm.plot_edges(modes, group_adjacency, lw=3)

# Repeat each group's grade once per sample so that every sampled location is coloured by grade.
c = group_attributes.grade.values[:, None] * np.ones(fit.num_samples)
pts = ax.scatter(*aligned.T, c=c, marker='.', alpha=.01)
# Presumably forces a draw so that `pts.norm` is scaled before it is used for the colour lookup below.
plt.draw()

# Per group: a circle with the median group scale as radius and a marker at the point estimate.
for xy, radius, tup in zip(modes, np.median(chain['group_scales'], axis=-1), 
                           group_attributes.itertuples()):
    color = pts.cmap(pts.norm(tup.grade))
    circle = mpl.patches.Circle(xy, radius, facecolor='none', edgecolor=color)
    ax.add_patch(circle)
    # Squares where `sex == 1`, circles otherwise; a white edge keeps markers visible on the point cloud.
    ax.scatter(*xy, color=color, marker='s' if tup.sex == 1 else 'o', zorder=2).set_edgecolor('w')

ax.set_aspect('equal')

In [None]:
# Posterior predictive check: fraction of observed group-level adjacency entries that fall inside the
# interquartile range of the posterior predictive replicates.
l, u = np.percentile(chain['ppd_group_adjacency'], [25, 75], axis=-1)
observed = data['group_adjacency']
inside = (observed >= l) & (observed <= u)
coverage = np.mean(inside)
print(f'ppd coverage of interquartile range: {coverage:.3f}')

In [None]:
# Two-panel figure: the embedding on the left and a posterior predictive comparison on the right.
fig, (ax1, ax2) = plt.subplots(1, 2)

# Manually adjust the angle to better fit the figure space.
angle = 15
rotation = alsm.evaluate_rotation_matrix(np.deg2rad(angle))

# Panel (a): aligned samples of the group locations, rotated by the angle chosen above.
ax = ax1
# Move the last axis (presumably the sample axis) to the front before aligning the samples.
samples = np.rollaxis(chain['group_locs'], -1)
aligned = alsm.align_samples(samples) @ rotation

# NOTE(review): the meaning of `scale=5` is defined by `alsm.estimate_mode` -- confirm there.
modes = alsm.estimate_mode(np.rollaxis(aligned, 1), scale=5)
alsm.plot_edges(modes, group_adjacency, lw=3, ax=ax)

# Repeat each group's grade once per sample so that every sampled location is coloured by grade.
c = group_attributes.grade.values[:, None] * np.ones(fit.num_samples)
pts = ax.scatter(*aligned.T, c=c, marker='.', alpha=.01)
# Presumably forces a draw so that `pts.norm` is scaled before it is used for the colour lookup below.
plt.draw()

# Per group: a circle with the median group scale as radius and a marker at the point estimate.
for xy, radius, tup in zip(modes, np.median(chain['group_scales'], axis=-1), 
                           group_attributes.itertuples()):
    color = pts.cmap(pts.norm(tup.grade))
    circle = mpl.patches.Circle(xy, radius, facecolor='none', edgecolor=color)
    ax.add_patch(circle)
    # Squares where `sex == 1`, circles otherwise (see the legend below).
    ax.scatter(*xy, color=color, marker='s' if tup.sex == 1 else 'o', zorder=2).set_edgecolor('w')

ax.set_aspect('equal')
ax.set_xlabel('Embedding $z_1$')
ax.set_ylabel('Embedding $z_2$')
cb = fig.colorbar(pts, ax=ax, location='top')
cb.set_label('Grade')
# Make the colorbar opaque; the scatter it is attached to is drawn with `alpha=.01`.
cb.set_alpha(1)
# NOTE(review): `Colorbar.draw_all` is deprecated in recent Matplotlib releases -- confirm that it is
# still available in the pinned version before upgrading.
cb.draw_all()

# Empty scatters serve purely as legend handles for the two marker shapes.
handle_girls = ax.scatter([], [], color='k', marker='o')
handle_girls.set_edgecolor('w')
handle_boys = ax.scatter([], [], color='k', marker='s')
handle_boys.set_edgecolor('w')
ax.legend([handle_girls, handle_boys], ['girls', 'boys'], fontsize='small',
          loc='lower right')
# Only the lower limit is set (presumably to make room for the legend).
ax.set_ylim(-8.5)


# Panel (b): observed group-level adjacency against the median of the posterior predictive replicates.
ax = ax2
x = data['group_adjacency'].ravel()
# Flatten the leading axes so that the replicates line up with `x`; the last axis is kept as is
# (assumes the replicates have the sample axis last -- TODO confirm).
ys = chain['ppd_group_adjacency'].reshape(x.shape + (-1,))
y = np.median(ys, axis=-1)
lims = x.min(), x.max()
# Dotted identity line as a reference for perfect agreement.
ax.plot(lims, lims, color='k', ls=':', zorder=0)
grade = group_attributes.grade.values
# Grade of the column group minus grade of the row group for every ordered pair of groups, clipped
# to {-1, 0, 1}: only the direction of the grade difference is used for colouring.
delta = (grade[None, :] - grade[:, None]).ravel()
c = np.clip(delta, -1, 1)

df = pd.DataFrame({'x': x, 'y': y, 'c': c})
# Collapse coincident points and count them so that the marker area can encode multiplicity. Sorting
# by descending count puts the largest markers first in the scatter call.
series = df.groupby(['x', 'y', 'c']).size().sort_values(ascending=False)
series.name = 's'
series = series.reset_index()
lookup = {-1: 'C1', 1: 'C0', 0: 'gray'}
ax.scatter(series.x, series.y, s=10 * series.s, c=[lookup[c] for c in series.c])

ax.set_aspect('equal')
ax.set_xscale('symlog', linthresh=1, linscale=.25)
ax.set_yscale('symlog', linthresh=1, linscale=.25)
ax.set_xlabel('Group adjacency $Y_{ab}$')
ax.set_ylabel(r'Posterior predictive replicates $Y_{ab}^{(\mathrm{rep})}$')

# Per grade direction: proportion of pairs whose replicate median exceeds the observed value, with
# ties counted as one half. Only the `x` and `y` columns are selected before `apply` so that the
# function does not operate on the grouping column (deprecated behaviour in recent pandas); the
# result is unchanged because the function never reads `c`.
prop_ge = df.groupby('c')[['x', 'y']].apply(
    lambda pair: (pair.y > pair.x).mean() + (pair.y == pair.x).mean() / 2
)
# Proxy artists for the legend, using the same colours as `lookup`.
handles_labels = [
    (mpl.lines.Line2D([], [], color='C1', ls='none', marker='o'), 'to lower grade'),
    (mpl.lines.Line2D([], [], color='C0', ls='none', marker='o'), 'to higher grade'),
    (mpl.lines.Line2D([], [], color='gray', ls='none', marker='o'), 'to same grade'),
]
for handle, _ in handles_labels:
    handle.set_markeredgecolor('w')
ax.legend(*zip(*handles_labels), fontsize='small')
ax.set_ylim(-1, 1e3)
ax.set_xlim(-1)

# Panel labels, positioned in axes coordinates (top left for (a), bottom right for (b)).
ax1.text(0.05, 0.95, '(a)', va='top', transform=ax1.transAxes)
ax2.text(0.95, 0.05, '(b)', ha='right', transform=ax2.transAxes)

fig.tight_layout()
# NOTE(review): assumes the `../workspace` directory already exists -- confirm before a fresh run.
fig.savefig('../workspace/addhealth.pdf')
# Display the proportions computed for panel (b) as the cell output.
prop_ge