-
Notifications
You must be signed in to change notification settings - Fork 235
/
ssbvm_mixture.py
302 lines (256 loc) · 10.7 KB
/
ssbvm_mixture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
r"""
Example: Sine-skewed sine (bivariate von Mises) mixture
=======================================================
This example models the dihedral angles that occur in the backbone of a protein as a mixture of skewed
directional distributions. The backbone angle pairs, called :math:`\phi` and :math:`\psi`, are a canonical
representation for the fold of a protein. In this model, we fix the third dihedral angle (omega), as it usually only
takes values of 0 or :math:`\pi` radians, with the latter being the most common. We model the angle pairs as a
distribution on the torus using the sine distribution [1] and break point-wise (toroidal) symmetry using sine-skewing [2].

.. image:: ../_static/img/examples/ssbvm_mixture_torus_top.png
    :align: center
    :scale: 30%

**References:**

1. Singh et al. (2002). Probabilistic model for two dependent circular variables. Biometrika.
2. Jose Ameijeiras-Alonso and Christophe Ley (2021). Sine-skewed toroidal distributions and their application
   in protein bioinformatics. Biostatistics.

.. image:: ../_static/img/examples/ssbvm_mixture.png
    :align: center
    :scale: 125%

"""
import argparse
import math
from math import pi
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from jax import numpy as jnp, random
import numpyro
from numpyro.distributions import (
Beta,
Categorical,
Dirichlet,
Gamma,
Normal,
SineBivariateVonMises,
SineSkewed,
Uniform,
VonMises,
)
from numpyro.distributions.transforms import L1BallTransform
from numpyro.examples.datasets import NINE_MERS, load_dataset
from numpyro.infer import MCMC, NUTS, Predictive, init_to_value
from numpyro.infer.reparam import CircularReparam
# One-letter codes of the 20 standard amino acids; the --amino-acids CLI option
# is validated against this list.
AMINO_ACIDS = list("MNIFELRDGKYTHSPAVQWC")
# The support of the von Mises distribution is [-π, π) with a periodic boundary at ±π. However, the support of
# the implemented von Mises distribution is just the interval [-π, π), without the periodic boundary. If the
# loc is close to one of the boundaries (-π or π), the sampler must traverse the entire interval to cross the
# boundary. This produces a bias, especially if the concentration is high: the interval around zero will have
# low probability, making the jump to the other boundary unlikely for the sampler.
# Using `CircularReparam` introduces the periodic boundary by mapping the real line onto [-π, π).
# The sampler can then move on the real line, crossing the periodic boundary without having to traverse
# the entire interval, which eliminates the bias.
@numpyro.handlers.reparam(
    config={"phi_loc": CircularReparam(), "psi_loc": CircularReparam()}
)
def ss_model(data, num_data, num_mix_comp=2):
    """Sine-skewed sine (bivariate von Mises) mixture model over (phi, psi) angle pairs.

    The decorator reparameterizes the ``phi_loc`` and ``psi_loc`` sites with
    ``CircularReparam`` so the sampler can cross the periodic boundary at +-pi.

    :param data: observed angle pairs, used as ``obs`` for the ``phi_psi`` site;
        pass ``None`` to simulate instead (``main`` does this for posterior
        prediction). Presumably shaped ``(num_data, 2)`` -- TODO confirm.
    :param num_data: size of the ``obs_plate`` (number of angle pairs).
    :param num_mix_comp: number of mixture components.
    :return: the value of the ``phi_psi`` sample site (observed or simulated).
    """
    # Mixture prior: Dirichlet with all-ones concentration over the component weights.
    mix_weights = numpyro.sample("mix_weights", Dirichlet(jnp.ones((num_mix_comp,))))

    # Hyperprior for the BvM concentrations, in mean/count form: each Beta below
    # uses alpha = mean * count and beta = count - alpha.
    # Bayesian Inference and Decision Theory by Kathryn Blackmond Laskey
    beta_mean_phi = numpyro.sample("beta_mean_phi", Uniform(0.0, 1.0))
    beta_count_phi = numpyro.sample(
        "beta_count_phi", Gamma(1.0, 1.0 / num_mix_comp)
    )  # shape, rate
    halpha_phi = beta_mean_phi * beta_count_phi
    beta_mean_psi = numpyro.sample("beta_mean_psi", Uniform(0, 1.0))
    beta_count_psi = numpyro.sample(
        "beta_count_psi", Gamma(1.0, 1.0 / num_mix_comp)
    )  # shape, rate
    halpha_psi = beta_mean_psi * beta_count_psi

    # Everything in this plate is drawn once per mixture component.
    with numpyro.plate("mixture", num_mix_comp):
        # BvM priors
        # Place gap in forbidden region of the Ramachandran plot (protein backbone dihedral angle pairs)
        phi_loc = numpyro.sample("phi_loc", VonMises(pi, 2.0))
        psi_loc = numpyro.sample("psi_loc", VonMises(0.0, 0.1))
        # Concentrations are drawn in (0, 1); the likelihood below scales them by 70.
        phi_conc = numpyro.sample(
            "phi_conc", Beta(halpha_phi, beta_count_phi - halpha_phi)
        )
        psi_conc = numpyro.sample(
            "psi_conc", Beta(halpha_psi, beta_count_psi - halpha_psi)
        )
        corr_scale = numpyro.sample("corr_scale", Beta(2.0, 10.0))

        # Skewness prior: an unconstrained 2-vector per component, mapped into the
        # L1 unit ball (presumably to satisfy SineSkewed's sum(|skewness|) <= 1
        # requirement -- confirm). The recorded "skewness" site therefore holds the
        # untransformed Normal draw, not the value passed to SineSkewed.
        ball_transform = L1BallTransform()
        skewness = numpyro.sample("skewness", Normal(0, 0.5).expand((2,)).to_event(1))
        skewness = ball_transform(skewness)

    with numpyro.plate("obs_plate", num_data, dim=-1):
        # Per-observation component index; enumerated in parallel so that inference
        # marginalizes it out instead of sampling the discrete value.
        assign = numpyro.sample(
            "mix_comp", Categorical(mix_weights), infer={"enumerate": "parallel"}
        )
        sine = SineBivariateVonMises(
            phi_loc=phi_loc[assign],
            psi_loc=psi_loc[assign],
            # These concentrations are an order of magnitude lower than expected (550-1000)!
            phi_concentration=70 * phi_conc[assign],
            psi_concentration=70 * psi_conc[assign],
            weighted_correlation=corr_scale[assign],
        )
        return numpyro.sample("phi_psi", SineSkewed(sine, skewness[assign]), obs=data)
def run_hmc(rng_key, model, data, num_mix_comp, args, bvm_init_locs):
    """Sample ``model`` with NUTS and return the posterior draws.

    :param rng_key: JAX PRNG key for the MCMC run.
    :param model: model to sample; called as ``model(data, len(data), num_mix_comp)``.
    :param data: observed data forwarded to the model.
    :param num_mix_comp: number of mixture components forwarded to the model.
    :param args: parsed CLI arguments; ``num_samples`` and ``num_warmup`` are read.
    :param bvm_init_locs: mapping of site name to initial value, used via
        ``init_to_value`` as the chain's initialization strategy.
    :return: the samples returned by ``MCMC.get_samples()``.
    """
    init_strategy = init_to_value(values=bvm_init_locs)
    nuts = NUTS(model, init_strategy=init_strategy, max_tree_depth=7)
    sampler = MCMC(nuts, num_warmup=args.num_warmup, num_samples=args.num_samples)
    sampler.run(rng_key, data, len(data), num_mix_comp)
    # Print a summary of the posterior to stdout before returning the draws.
    sampler.print_summary()
    return sampler.get_samples()
def fetch_aa_dihedrals(aa):
    """Load the NINE_MERS dihedral angles for amino acid ``aa`` as one stacked array.

    NOTE(review): ``main`` treats column 0 as phi and column 1 as psi, so the
    result is presumably shaped ``(n, 2)`` -- confirm against the dataset loader.
    """
    _init, fetch_split = load_dataset(NINE_MERS, split=aa)
    dihedrals = fetch_split()
    return jnp.stack(dihedrals)
def num_mix_comps(amino_acid):
    """Return how many mixture components to fit for ``amino_acid``.

    Glycine ("G") uses 10 components, proline ("P") uses 7, and every other
    amino acid uses 9.
    """
    per_residue = {"G": 10, "P": 7}
    fallback = 9
    return per_residue[amino_acid] if amino_acid in per_residue else fallback
def _set_pi_ticks(axis):
    """Tick ``axis`` every pi/2 (major) and pi/12 (minor), labelled as fractions of pi."""
    axis.set_major_locator(plt.MultipleLocator(np.pi / 2))
    axis.set_minor_locator(plt.MultipleLocator(np.pi / 12))
    axis.set_major_formatter(plt.FuncFormatter(multiple_formatter()))


def ramachandran_plot(data, pred_data, aas, file_name="ssbvm_mixture.pdf"):
    """Draw observed (top row) vs. simulated (bottom row) Ramachandran plots.

    Each column is one amino acid from ``aas``. Each panel is a hexbin histogram
    of (phi, psi) pairs with a logarithmic colour norm, over [-pi, pi] x [-pi, pi].

    :param data: mapping from amino-acid code to an array whose last axis holds
        phi in column 0 and psi in column 1; drawn in the top row.
    :param pred_data: mapping with the same keys holding simulated pairs;
        drawn in the bottom row.
    :param aas: amino-acid codes to plot, one per column.
    :param file_name: where to save the figure; if falsy, the figure is not saved.
    """
    # Column titles; any code missing here falls back to its one-letter code.
    amino_acids = {
        "A": "Alanine",
        "C": "Cysteine",
        "D": "Aspartic acid",
        "E": "Glutamic acid",
        "F": "Phenylalanine",
        "G": "Glycine",
        "H": "Histidine",
        "I": "Isoleucine",
        "K": "Lysine",
        "L": "Leucine",
        "M": "Methionine",
        "N": "Asparagine",
        "P": "Proline",
        "Q": "Glutamine",
        "R": "Arginine",
        "S": "Serine",
        "T": "Threonine",
        "V": "Valine",
        "W": "Tryptophan",
        "Y": "Tyrosine",
    }
    # squeeze=False keeps ``axss`` 2-D (rows x columns) even for a single amino
    # acid. With the default squeezing, a one-column grid comes back 1-D and the
    # row/column indexing below fails.
    fig, axss = plt.subplots(2, len(aas), squeeze=False)

    # Row 0 shows the observed data, row 1 the simulated data.
    for row, (cdata, axs) in enumerate(zip((data, pred_data), axss)):
        for ax, aa in zip(axs, aas):
            aa_data = cdata[aa]
            ax.hexbin(
                aa_data[..., 0].reshape(-1),
                aa_data[..., 1].reshape(-1),
                norm=matplotlib.colors.LogNorm(),
                bins=50,
                gridsize=100,
                cmap="Blues",
            )
            ax.set_aspect("equal", "box")
            ax.set_xlim([-math.pi, math.pi])
            ax.set_ylim([-math.pi, math.pi])
            _set_pi_ticks(ax.xaxis)
            _set_pi_ticks(ax.yaxis)
            if row == 0:
                # Amino-acid name as a column title on a secondary top axis.
                axtop = ax.secondary_xaxis("top")
                axtop.set_xlabel(amino_acids.get(aa, aa))
                _set_pi_ticks(axtop.xaxis)
            else:
                ax.set_xlabel(r"$\phi$")

    for row, axs in enumerate(axss):
        axs[0].set_ylabel(r"$\psi$")
        # Row label ("data" / "simulation") on a secondary right axis.
        axright = axs[-1].secondary_yaxis("right")
        axright.set_ylabel("data" if row == 0 else "simulation")
        _set_pi_ticks(axright.yaxis)

    # Keep tick labels only on the outer edges: drop y labels on the inner
    # columns and x labels on the top row.
    for ax in axss[:, 1:].reshape(-1):
        ax.tick_params(labelleft=False)
    for ax in axss[0, :].reshape(-1):
        ax.tick_params(labelbottom=False)

    if file_name:
        fig.tight_layout()
        fig.savefig(file_name, bbox_inches="tight")
def multiple_formatter(denominator=2, number=np.pi, latex=r"\pi"):
    r"""Return a tick formatter that labels values as reduced fractions of ``number``.

    Each value ``x`` is rounded to the nearest multiple of
    ``number / denominator`` and rendered as a LaTeX string in lowest terms,
    e.g. ``$0$``, ``$\pi$``, ``$-\pi$``, ``$\frac{\pi}{2}$``, ``$2\pi$``.

    :param denominator: finest subdivision of ``number`` a label can show; an
        integer (integral floats are accepted). A negative value behaves like
        its absolute value.
    :param number: the unit labels are expressed in (default: pi).
    :param latex: LaTeX source used to print ``number``.
    :return: a function ``(x, pos) -> str`` for matplotlib's ``FuncFormatter``.
    """
    den_base = int(denominator)

    def _multiple_formatter(x, pos):
        # ``pos`` (the tick index) is required by FuncFormatter but unused here.
        num = int(np.rint(den_base * x / number))
        den = den_base
        # Carry the sign on the numerator so the denominator is always positive.
        if den < 0:
            num, den = -num, -den
        com = math.gcd(num, den)
        # ``com`` divides both exactly, so integer division avoids float round-off.
        num, den = num // com, den // com

        if num == 0:
            return r"$0$"
        # A coefficient of +-1 is written as just its sign.
        if num == 1:
            coef = ""
        elif num == -1:
            coef = "-"
        else:
            coef = str(num)
        if den == 1:
            return r"$%s%s$" % (coef, latex)
        return r"$\frac{%s%s}{%s}$" % (coef, latex, den)

    return _multiple_formatter
def main(args):
    """Fit a sine-skewed BvM mixture per amino acid and plot data vs. simulation.

    For each amino acid in ``args.amino_acids`` this function:

    1. loads the amino acid's dihedral angles;
    2. initializes the mixture locations from k-means cluster centers;
    3. runs NUTS on ``ss_model``;
    4. draws posterior-predictive angle pairs.

    All amino acids are then drawn with ``ramachandran_plot``.

    :param args: parsed command-line arguments. ``rng_seed`` and ``amino_acids``
        are read here; ``num_samples`` and ``num_warmup`` are read by ``run_hmc``.
    """
    data = {}
    pred_datas = {}
    rng_key = random.PRNGKey(args.rng_seed)
    for aa in args.amino_acids:
        rng_key, inf_key, pred_key = random.split(rng_key, 3)
        data[aa] = fetch_aa_dihedrals(aa)
        num_mix_comp = num_mix_comps(aa)

        # Use k-means to initialize the chain locations. Seed it from --rng_seed
        # so the initialization, and therefore the whole run, is reproducible.
        kmeans = KMeans(n_clusters=num_mix_comp, random_state=args.rng_seed)
        kmeans.fit(data[aa])
        phi_centers = kmeans.cluster_centers_[:, 0]
        psi_centers = kmeans.cluster_centers_[:, 1]
        # NOTE(review): ss_model wraps phi_loc/psi_loc in CircularReparam. In
        # numpyro, that samples a latent site named "<name>_unwrapped" and turns
        # the original site into a deterministic one, so init_to_value never sees
        # a sample site called "phi_loc"/"psi_loc". The "_unwrapped" keys make the
        # k-means initialization take effect. The original keys are kept and are
        # harmless if unused. Verify against the installed numpyro version.
        means = {
            "phi_loc": phi_centers,
            "psi_loc": psi_centers,
            "phi_loc_unwrapped": phi_centers,
            "psi_loc_unwrapped": psi_centers,
        }

        posterior_samples = run_hmc(
            inf_key, ss_model, data[aa], num_mix_comp, args, means
        )
        predictive = Predictive(ss_model, posterior_samples, parallel=True)
        # data=None, num_data=1: simulate one (phi, psi) pair per posterior draw.
        pred_datas[aa] = predictive(pred_key, None, 1, num_mix_comp)["phi_psi"].reshape(
            -1, 2
        )

    ramachandran_plot(data, pred_datas, args.amino_acids)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Sine-skewed sine (bivariate von mises) mixture model example"
    )
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=500, type=int)
    parser.add_argument("--amino-acids", nargs="+", default=["S", "P", "G"])
    parser.add_argument("--rng_seed", type=int, default=123)
    # Default None: when the flag is omitted, JAX keeps its own default backend
    # (the flag was previously parsed but never applied).
    parser.add_argument(
        "--device",
        default=None,
        type=str,
        help='use "cpu" or "gpu"; if omitted, JAX picks its default backend.',
    )
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped under
    # ``python -O``, while parser.error prints usage and exits with status 2.
    unknown = [aa for aa in args.amino_acids if aa not in AMINO_ACIDS]
    if unknown:
        parser.error(f"{unknown} are not amino acids.")

    if args.device is not None:
        numpyro.set_platform(args.device)

    main(args)