"""Implements Markov Chain Monte Carlo sampling."""
from functools import partial
import jax
import jax.numpy as jnp
from wavefunction import psi


@jax.jit
def connected_states(a, n_max):
"""Generates all states connected by the Hamiltonian to a given state.
The second returned array, physical, flags states which are physical,
i. e. with bosonic occupations in the range [0, n_max]. Jax requires static
shape of arrays, hence we return also the unphysical states and flag them
with 0.
"""
states_plus = jnp.repeat(a[None], len(a), 0) + jnp.eye(len(a))
states_minus = jnp.repeat(a[None], len(a), 0) - jnp.eye(len(a))
states = jnp.vstack((jnp.reshape(a, (1,-1)), states_plus, states_minus))
geq_zero = jnp.all(states >= 0, axis=1, keepdims=True)
leq_nmax = jnp.all(states <= n_max, axis=1, keepdims=True)
physical = jnp.logical_and(geq_zero, leq_nmax)
physical = physical.astype(int).ravel()
    return states, physical


@partial(jax.jit, static_argnums=1)
def num_connected(a, n_max):
"""Counts states connected by the Hamiltonian to a given state.
Equivalent to _, p = connected_states(a, n_max); return jnp.sum(p).
"""
count_unreachable = jnp.count_nonzero(a==0) + jnp.count_nonzero(a==n_max)
    return 2 * a.shape[0] + 1 - count_unreachable


@partial(jax.jit, static_argnums=2)
def generate_proposal(key, state, n_max):
"""Proposes next state in Monte Carlo chain.
The proposal is generated by uniform sampling from states connected by
the Hamiltonian to the current state.
"""
conn, physical = connected_states(state, n_max)
s = jnp.sum(physical)
probs = physical/s
c = jax.random.choice(key, conn.shape[0], replace=True, p=probs)
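    # Return the number of physical connected states alongside the proposal;
    # make_step needs it to correct for the asymmetric proposal distribution.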
    return conn[c], s


@partial(jax.jit, static_argnums=(1, 5))
def make_step(key, model, variational_pars, state, log_prob, n_max):
"""Performs a single step of the Monte Carlo chain."""
key, subkey = jax.random.split(key)
log_uniform = jnp.log(jax.random.uniform(subkey))
key, subkey = jax.random.split(subkey)
proposal, num_conn = generate_proposal(subkey, state, n_max)
proposal_psi = psi(variational_pars, model, proposal)
proposal_log_prob = jnp.log(jnp.square(proposal_psi))
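    # Hastings correction for the asymmetric proposal: q(s'|s) = 1/num_conn(s),
    # so log q(s|s') - log q(s'|s) = log num_conn(s) - log num_conn(s').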
log_corr = jnp.log(num_conn) - jnp.log(num_connected(proposal, n_max))
do_accept = log_uniform < proposal_log_prob + log_corr - log_prob
next_state = jnp.where(do_accept, proposal, state)
log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    return next_state, log_prob, subkey


@partial(jax.jit, static_argnums=(1, 3, 4))
def single_chain(key, model, variational_pars, alg_pars, phys_pars):
"""Runs a single Monte Carlo chain.
For stability and saple indepencence, irst phys_pars.burnin samples
are discarded, then only every num_k-th sample is used.
"""
max_n_start = 2 # Choose initial state only among n=0, n=1 occupations.
num_k = phys_pars.k_grid.shape[0]
init_state = jax.random.randint(
key, shape=(num_k,), minval=0, maxval=max_n_start).astype(jnp.float32)
init_psi = psi(variational_pars, model, init_state)
init_log_prob = jnp.log(jnp.square(init_psi))
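    # The chain targets |psi|^2; only its logarithm is carried, so psi never
    # needs to be normalized.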
init = (key, init_log_prob, init_state)
def f(carry, x):
"""Auxiliary function for jax.lax.scan."""
key, log_prob, state = carry
out = make_step(
key, model, variational_pars, state, log_prob, phys_pars.n_max)
next_state, log_prob, subkey = out
return (subkey, log_prob, next_state), next_state
loop_length = alg_pars.samples_per_chain + alg_pars.burnin
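    # jax.lax.scan threads the (key, log_prob, state) carry through the loop
    # and stacks every visited state along the leading axis.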
_, states = jax.lax.scan(f, init=init, xs=None, length=loop_length)
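    # Discard the burn-in samples, then thin the chain by keeping only every
    # num_k-th state to reduce autocorrelation.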
states = states[alg_pars.burnin:]
    return states[::num_k]


@partial(jax.jit, static_argnums=(1, 3, 4))
def generate_samples(key, model, variational_pars, alg_pars, phys_pars):
"""Runs several Monte Carlo chains and collates samples."""
in_axes = (0, None, None, None, None)
multi_chain = jax.vmap(single_chain, in_axes=in_axes)
keys = jax.random.split(key, num=alg_pars.num_chains)
samples = multi_chain(keys, model, variational_pars, alg_pars, phys_pars)
num_k = phys_pars.k_grid.shape[0]
return jnp.reshape(samples, (-1, num_k))