# The Weighted Ensemble Method

## Toy model with double well potential, Brownian Dynamics

The Weighted Ensemble (WE) method provides a route to estimating kinetic and thermodynamic parameters for many different types of biomolecular simulation problem. For a good introduction, see this [2017 review from Zuckerman and Chong](https://pubmed.ncbi.nlm.nih.gov/28301772/).

The aim of this notebook is to illustrate the key aspects of "steady state" type WE simulations (walkers, progress coordinates, binning, splitting and merging, recycling) with a simple "toy" model that runs fast enough to be experimented with interactively.

Rather than using an off-the-shelf WE simulation platform such as [WESTPA](https://pubmed.ncbi.nlm.nih.gov/26392815/), here you will build it yourself from components in a Python library `WElib`.

`WElib` is not on PyPI, but you can install it directly from the GitHub repository, as follows:

    pip install git+http://github.com/CharlieLaughton/WElib.git
    


In [None]:
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline
%load_ext Cython

## Part 1: The double well potential

Here is the simple skewed double-well energy function we will use for this toy problem. It has three parameters: `k` controls the barrier height, `a` the distance between the minima and `b` the degree of skewness.

In [None]:
def double_well_function(x, k=1.0, a=2.0, b=0.0):
    '''
    A skewed double well potential
    
    x is the coordinate(s), k, a, and b are parameters
    '''
    energy = 0.25*k*((x-a)**2)*((x+a)**2) + b*x
    force = - (k * x * (x-a) * (x+a) + b)
    return energy, force

Evaluate it for x between -3.0 and 3.0:

In [None]:
k = 1.0
a = 2.0
b = 1.0
x_vals = np.linspace(-3.0, 3.0, 20)

e, f = double_well_function(x_vals, k, a, b)
plt.plot(x_vals, e, label='Energy')
#plt.plot(x_vals, f, label='Force')
#plt.plot(x_vals, [0.0] * len(x_vals), '_')
plt.xlabel('coordinate')
plt.ylabel('Energy (kT)')
plt.legend()

With `k = 1`, `a = 2` and `b = 1`, the double-well potential has a "global" energy minimum at a coordinate of about -2, and a "local" energy minimum at a coordinate of about 2. The difference in energy between them is approximately 4kT. The barrier height from left to right is approximately 6kT and from right to left approximately 2kT. Our aim now is to construct a Weighted Ensemble simulation to find the rates at which a particle will move from the left-hand (coordinate = -2) to the right-hand (coordinate = 2) well, and the reverse.
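The quoted minima and barrier heights can be checked numerically. The sketch below is not part of the original notebook: it re-defines the energy locally (the name `energy` is introduced here for illustration) and finds the stationary points as the roots of dE/dx = k·x³ − k·a²·x + b. Because the linear term `b*x` shifts the stationary points slightly away from ±2 and 0, the results come out close to, rather than exactly at, the rounded values given above.

```python
import numpy as np

k, a, b = 1.0, 2.0, 1.0  # same parameters as used for the plot above

def energy(x):
    """Skewed double-well energy (kT), same form as double_well_function."""
    return 0.25 * k * (x - a)**2 * (x + a)**2 + b * x

# Stationary points are the roots of dE/dx = k*x**3 - k*a**2*x + b
stationary = np.sort(np.roots([k, 0.0, -k * a**2, b]).real)
x_left, x_barrier, x_right = stationary

forward_barrier = energy(x_barrier) - energy(x_left)    # left -> right
reverse_barrier = energy(x_barrier) - energy(x_right)   # right -> left
well_difference = energy(x_right) - energy(x_left)

print(f"minima at {x_left:.3f} and {x_right:.3f}, barrier top at {x_barrier:.3f}")
print(f"barriers: L->R {forward_barrier:.2f} kT, R->L {reverse_barrier:.2f} kT")
print(f"energy difference between minima: {well_difference:.2f} kT")
```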

## Part 2: "Vanilla" Brownian dynamics simulations on the double well potential

We build a Brownian dynamics propagator that will move a particle over this potential. We use Cython here to get maximum speed (you can ignore warning messages about deprecated NumPy API, if you get them).

In [None]:
%%cython
import numpy as np
cimport numpy as np
np.import_array()

def bd_dwp_sim(double x, int n_steps, double dt, double gamma, double k, double a, double b):
    '''
    A Brownian dynamics simulator for a double-well potential (units of kT)
    
    Parameters:
        x: initial coordinate
        n_steps: number of BD steps
        dt: time interval
        gamma: friction coefficient
        k, a, b: parameters of the double-well potential
        
    Returns:
        The final coordinate
    '''
    cdef double friction_factor, noise_factor, f
    cdef int i
    friction_factor = dt/gamma
    # Noise amplitude is sqrt(2*D*dt), with D = kT/gamma = 1/gamma in these units
    noise_factor = np.sqrt(2 * dt / gamma)
    cdef np.ndarray[np.double_t] noise = np.random.normal(scale=noise_factor, size=n_steps)

    for i in range(n_steps):
        f = k * x * (x-a) * (x+a) + b
        x = x - f*friction_factor + noise[i]
    return x
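
For readers without a working Cython setup, the same update rule can be sketched in plain Python/NumPy. This is an illustrative reference, not part of `WElib` or the original notebook: the function name `bd_step_reference` and the explicit `rng` argument are assumptions made here. It takes the noise amplitude from the Einstein relation D = kT/gamma; with `gamma = 1.0`, as used throughout this notebook, it applies the same update as the Cython cell, only more slowly.

```python
import numpy as np

def bd_step_reference(x, n_steps, dt, gamma, k, a, b, rng):
    """Plain-Python overdamped (Brownian) dynamics on the double well, energies in kT."""
    drift_scale = dt / gamma                  # mobility (1/gamma) times the timestep
    noise_scale = np.sqrt(2.0 * dt / gamma)   # sqrt(2*D*dt) with D = 1/gamma
    noise = rng.normal(scale=noise_scale, size=n_steps)
    for i in range(n_steps):
        grad = k * x * (x - a) * (x + a) + b  # dE/dx, i.e. minus the force
        x = x - grad * drift_scale + noise[i]
    return x

rng = np.random.default_rng(0)  # seeded only to make the example repeatable
x_final = bd_step_reference(-2.0, 1000, 1e-5, 1.0, 1.0, 2.0, 1.0, rng)
print(x_final)
```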

Perform an unbiased simulation using these parameters - how does a single walker sample the double-well potential? The simulation is 1,000,000 steps, saving the coordinate every 1000 steps:

In [None]:
traj = []
x = -2.0
dt = 1e-5
gamma = 1.0
k = 1.0
a = 2.0
b = 1.0
n_steps = 1000
n_cycles = 1000

for i in range(n_cycles):
    x = bd_dwp_sim(x, n_steps, dt, gamma, k, a, b)
    traj.append(x)

fig, ax1 = plt.subplots()
ax1.plot(x_vals, e)
plt.ylabel('Energy (kT)')
ax2 = ax1.twinx()
out = ax2.hist(traj, bins=50, color='green')
plt.xlabel('coordinate')
plt.ylabel('frequency')

The green histogram shows that only the left-hand well is sampled: there are no transitions to the right-hand well (coordinate > 0). We will now show how a weighted ensemble simulation can overcome this.
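A rough estimate shows why no crossings are seen. The sketch below is an addition to the notebook and rests on an assumption not made in the original: that the high-friction Kramers formula, rate ≈ sqrt(E''(min)·|E''(barrier)|) / (2π·gamma) · exp(−barrier/kT), is an adequate approximation for this potential. It compares the resulting rate with the total time covered by the unbiased run (1,000,000 steps of `dt = 1e-5`).

```python
import numpy as np

k, a, b = 1.0, 2.0, 1.0
gamma = 1.0

def energy(x):
    return 0.25 * k * (x - a)**2 * (x + a)**2 + b * x

def curvature(x):
    """Second derivative of the energy, d2E/dx2."""
    return k * (3.0 * x**2 - a**2)

# Stationary points: roots of dE/dx = k*x**3 - k*a**2*x + b
x_left, x_barrier, _ = np.sort(np.roots([k, 0.0, -k * a**2, b]).real)
barrier = energy(x_barrier) - energy(x_left)

# Overdamped Kramers estimate of the left -> right escape rate (per unit time)
kramers_rate = (np.sqrt(curvature(x_left) * abs(curvature(x_barrier)))
                / (2.0 * np.pi * gamma)) * np.exp(-barrier)

simulated_time = 1000 * 1000 * 1e-5   # n_cycles * n_steps * dt
expected_crossings = kramers_rate * simulated_time
print(f"estimated rate: {kramers_rate:.2e} per unit time")
print(f"expected crossings in the unbiased run: {expected_crossings:.3f}")
```

On this estimate the unbiased run is far too short to expect even one crossing, which is the gap the WE method is designed to close.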

## Part 3: Build the WE simulation workflow

Now we build a Weighted Ensemble simulation workflow using "building blocks" from a small home-made Python library `WElib`. 

If necessary, run the following cell to install the library:

In [None]:
!pip install git+http://github.com/CharlieLaughton/WElib.git

Now import the components you will need:

In [None]:
from WElib import Walker, FunctionProgressCoordinator, FunctionStepper, StaticBinner, SplitMerger, Recycler

### Walkers
Create a set of initial "walkers" with initial state (coordinates) corresponding to the base of the left-hand well. The total weight of all walkers will be 1.0:

In [None]:
n_reps = 5
initial_coordinates = -2.0
initial_weight = 1.0/n_reps
walkers = [Walker(initial_coordinates, initial_weight) for i in range(n_reps)]
for walker in walkers:
    print(walker)

Note that currently the progress coordinate and bin assignment of each walker is undefined - we will add this information to them now.

------
### The Progress Coordinator
In a WE simulation, we monitor one or more "progress coordinates". In general, a progress coordinate is something that is calculated from the current coordinates, and choosing the right definition for the progress coordinate(s) can be a non-trivial issue.

`WElib` contains a "FunctionProgressCoordinator" class that is used to construct building blocks that do the job of adding progress coordinate information to sets of walkers. These are initialised with a user-supplied function that takes in a state, and returns a progress coordinate (or coordinates). The FunctionProgressCoordinator instance then does the job of applying this to a whole set of walkers, and does other housekeeping behind the scenes.

Create a function that takes a state and returns the progress coordinate, then create an instance of a FunctionProgressCoordinator, then use it to process the walkers we created above. Note we overwrite the `walkers` list, as they are the same, just with the extra progress coordinate information added:

In [None]:
def pc_func(state):
    '''
    A function that takes a state and returns a progress coordinate
    
    Trivial in this case as state and PC are the same!
    '''
    return state

progress_coordinator = FunctionProgressCoordinator(pc_func)
walkers = progress_coordinator.run(walkers)
for walker in walkers:
    print(walker)

----------
### The Binner

Now we need a "binner" function that will assign each of our walkers to a bin, based on their progress coordinate. We will use the `StaticBinner` class from `WElib` for this, which uses static bin boundary definitions. All we need to define are the positions of the bin edges. We use 0.1 increments in the left-hand well (as the walkers "climb the hill") but a coarser binning once they are over the transition state. 

In real life, in addition to choosing the right progress coordinate definition, much of the effort in a WE simulation project is getting the binning strategy right.

In [None]:
edges = [-2.0, -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7,
         -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.2, 0.5, 1.0, 1.5, 2.0]
binner = StaticBinner(edges)
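
To make the idea of static binning concrete, here is a minimal sketch using `numpy.digitize`. It is not how `StaticBinner` is necessarily implemented, and the bin numbering shown is NumPy's convention (0 for values below the first edge, `len(edges)` for values at or beyond the last), which `WElib` may not share. The test coordinates in `pcs` are made up for illustration.

```python
import numpy as np

edges = [-2.0, -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7,
         -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.2, 0.5, 1.0, 1.5, 2.0]

# Hypothetical progress coordinates for five walkers
pcs = np.array([-2.05, -2.0, -1.95, 0.1, 2.5])

# digitize returns i such that edges[i-1] <= pc < edges[i]
bin_ids = np.digitize(pcs, edges)
for pc, bin_id in zip(pcs, bin_ids):
    print(f"pc = {pc:5.2f} -> bin {bin_id}")
```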

Test the binner, as before. The binner also keeps a record of a) the current and b) the mean weight in each bin; the mean weight data will be needed later:

In [None]:
walkers = binner.run(walkers)
for walker in walkers:
    print(walker)
for key in sorted(binner.bin_weights): # This is a dictionary with the ID and current weight in each populated bin
    print('Bin: ', key, ' weight: ', binner.bin_weights[key]) 

--------------

### The Stepper
Each WE cycle, the walkers take a "step" that might get them to a new bin. We need to decide how long the BD simulations should be, to optimise this.

Putting the walkers aside for a moment, examine the mean coordinate shift for different length BD simulations. The aim is to find the shortest simulation that will still give a reasonable shift in coordinates:

In [None]:
x = -2.0 # start at the bottom of the left-hand well
dt = 1e-5 # BD timestep
gamma = 1.0 # BD friction constant
# DWP parameters:
k = 1.0
a = 2.0
b = 1.0

mean_dx = [] # list to store mean displacements in

step_choices = [100, 200, 300, 500, 1000, 2000, 3000, 4000, 5000]

for n_steps in step_choices:
    sum_dx = 0.0
    n_reps = 1000
    for rep in range(n_reps):
        xnew = bd_dwp_sim(x, n_steps, dt, gamma, k, a, b)
        sum_dx += np.abs(x-xnew)
    mean_dx.append(sum_dx / n_reps)

plt.plot(step_choices, mean_dx)
plt.xlabel('n_steps')
plt.ylabel('mean coordinate shift')

We see a pattern of diminishing returns: a simulation twice as long results in less than a doubling in the mean distance moved. Therefore there seems little reason to make individual simulations longer than 1000 steps. This would make bins 0.1 apart (i.e., about the same as the mean coordinate shift) a reasonable choice. Luckily, that's what we have already used above!
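The diminishing returns follow from the square-root time dependence of diffusion. As a rough guide (an addition here, assuming free diffusion and ignoring the restoring force of the well, which only reduces the displacement further at longer times), the displacement after time t is Gaussian with variance 2·D·t, so the mean absolute displacement is sqrt(4·D·t/π), with D = 1/gamma in kT units:

```python
import numpy as np

dt = 1e-5
gamma = 1.0
step_choices = np.array([100, 200, 300, 500, 1000, 2000, 3000, 4000, 5000])

diffusion = 1.0 / gamma                 # D = kT/gamma, with kT = 1
t = step_choices * dt                   # elapsed time for each simulation length
# Mean of |dx| for a Gaussian with variance 2*D*t is sqrt(4*D*t/pi)
free_mean_abs_dx = np.sqrt(4.0 * diffusion * t / np.pi)

for n, dx in zip(step_choices, free_mean_abs_dx):
    print(f"{n:5d} steps: free-diffusion mean |dx| = {dx:.3f}")
```

Doubling the number of steps therefore increases the shift by a factor of at most sqrt(2), and 1000 steps corresponds to a shift of roughly 0.11, in line with the 0.1 bin spacing.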

Use these parameters to define the "stepper" function for the WE workflow. This uses `WElib`'s `FunctionStepper` class, which is initialised with the function to be called and any arguments it takes in addition to the walker's coordinates:

In [None]:
n_steps = 1000
dt = 1e-5
gamma = 1.0
k = 1.0
a = 2.0
b = 1.0
stepper = FunctionStepper(bd_dwp_sim, n_steps, dt, gamma, k, a, b)

Test the stepper function: let each of the walkers take a step, and then print the new coordinates of each (the stepper also keeps a record of them in its recorder):

In [None]:
moved_walkers = stepper.run(walkers)
for i, w in enumerate(moved_walkers):
    print(f'new coordinate for walker {i}: {w.state}')
for walker in moved_walkers:
    print(walker)

Note the progress coordinates and bin assignments for each walker have been reset to `None`, as the coordinates have changed, so we need to re-run the ProgressCoordinator:

In [None]:
moved_walkers = progress_coordinator.run(moved_walkers)
for walker in moved_walkers:
    print(walker)

-----------
### The Recycler
We are using the "steady state" version of the WE method, so we need a "recycler": a function which looks to see if any of the walkers has reached or exceeded the target value for the progress coordinate, and if so, replaces it with a walker with the same weight, but with the initial coordinates. The "recycler" also keeps a record of how much (if any) weight has been recycled (the flux), and a list of the recycled walkers (if any):

In [None]:
target_pc = 2.0
recycler = Recycler(target_pc)

moved_walkers = recycler.run(moved_walkers)
print('flux=', recycler.flux)
print('recycled walkers: ', recycler.recycled_walkers)

As expected, nothing gets recycled: all the walkers are still in the left-hand well. We still need to rerun the `binner` to add bin ids to the moved walkers:

In [None]:
moved_walkers = binner.run(moved_walkers)
for walker in moved_walkers:
    print(walker)
for key in sorted(binner.bin_weights):
    print('Bin: ', key, ' weight: ', binner.bin_weights[key]) 
for key in sorted(binner.mean_bin_weights): # This is a dictionary with the ID and mean weight in each populated bin
    print('Bin: ', key, ' mean weight: ', binner.mean_bin_weights[key]) 
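
The recycling rule described in this section can be sketched in a few lines of plain Python. This is an illustration only, not `WElib`'s `Recycler`: walkers are represented here by hypothetical `(state, weight)` tuples, the function name `recycle` is made up, and the meaning given to `retrograde` (target reached when the coordinate falls to or below the target) is an assumption.

```python
def recycle(walkers, target_pc, initial_state, retrograde=False):
    """Return (new_walkers, flux) for a list of (state, weight) tuples.

    Any walker at or beyond target_pc is replaced by one at initial_state
    carrying the same weight; the flux is the total weight recycled.
    """
    new_walkers = []
    flux = 0.0
    for state, weight in walkers:
        reached = state <= target_pc if retrograde else state >= target_pc
        if reached:
            flux += weight
            new_walkers.append((initial_state, weight))  # weight is conserved
        else:
            new_walkers.append((state, weight))
    return new_walkers, flux

# Made-up example: the second walker has passed the target at 2.0
example_walkers = [(-1.9, 0.5), (2.1, 0.3), (0.4, 0.2)]
recycled, flux = recycle(example_walkers, target_pc=2.0, initial_state=-2.0)
print(recycled, flux)
```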

---------
### The Splitter/Merger
Next we need a "splitmerger" that will split or merge walkers in each bin, according to the WE rules. `WElib` contains a straightforward class for this which will fit most circumstances. Instances of it just need to define the target number of walkers per bin:

In [None]:
n_reps = 5
splitmerger = SplitMerger(n_reps)

splitmerged_walkers = splitmerger.run(moved_walkers)
for walker in splitmerged_walkers:
    print(walker)

Note we now have more walkers - `n_reps` per occupied bin.
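For intuition, here is a minimal sketch of one common way to split and merge the walkers within a single bin while conserving total weight. It is not taken from `WElib`, whose `SplitMerger` may follow different rules: the `(state, weight)` tuples, the function name `split_merge_bin`, and the specific choices (split the heaviest walker, merge the two lightest) are assumptions for illustration.

```python
import random

def split_merge_bin(walkers, n_target, rng):
    """Resample one bin's (state, weight) walkers to n_target, conserving weight."""
    walkers = list(walkers)
    if not walkers:
        return walkers  # an empty bin stays empty
    # Split: replace the heaviest walker by two copies with half the weight each
    while len(walkers) < n_target:
        walkers.sort(key=lambda w: w[1])
        state, weight = walkers.pop()
        walkers += [(state, weight / 2), (state, weight / 2)]
    # Merge: combine the two lightest walkers; the surviving state is
    # chosen with probability proportional to weight
    while len(walkers) > n_target:
        walkers.sort(key=lambda w: w[1])
        (s1, w1), (s2, w2) = walkers[0], walkers[1]
        survivor = s1 if rng.random() < w1 / (w1 + w2) else s2
        walkers = [(survivor, w1 + w2)] + walkers[2:]
    return walkers

rng = random.Random(1)  # seeded only to make the example repeatable
split = split_merge_bin([(-1.95, 0.2)], 5, rng)                              # 1 -> 5
merged = split_merge_bin([(-1.9, 0.01 * i) for i in range(1, 9)], 5, rng)    # 8 -> 5
print(split)
print(merged)
```

Choosing the survivor of a merge in proportion to weight is what keeps the resampling statistically unbiased; only the weights change hands, never the total.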

----------
## Part 4: Running a complete WE simulation workflow
OK, now we have tested the individual components, we can build a complete WE simulation. Each cycle will record the flux of walker weight from the target state (here, the same as the target progress coordinate) back to the initial state.

This cell is doing all the work, and may take a little time to run...

In [None]:
n_reps = 5
initial_coordinates = -2.0
initial_weight = 1.0/n_reps
walkers = [Walker(initial_coordinates, initial_weight) for i in range(n_reps)]
walkers = progress_coordinator.run(walkers)
walkers = binner.run(walkers)
n_cycles = 1000
walkers_per_cycle = []
binner.reset() # zero the weights memory
forward_recycled_walkers = []
for i in range(n_cycles):
    walkers = splitmerger.run(walkers)
    walkers_per_cycle.append(len(walkers))
    walkers = stepper.run(walkers)
    walkers = progress_coordinator.run(walkers)
    walkers = recycler.run(walkers)
    if recycler.flux > 0.0:
        walkers = progress_coordinator.run(walkers)
        forward_recycled_walkers += recycler.recycled_walkers
    walkers = binner.run(walkers)
    
    if i % (n_cycles // 10) == 0:
        print("{:4.1f}% done".format(100*i/n_cycles))
print('complete')

Plot a) the flux (weight being recycled from the target state to the initial state); b) total number of walkers as a function of time, and c) the mean weight in each bin:

In [None]:
plt.figure(figsize=(12,8))
plt.subplot(221)
plt.plot(recycler.flux_history)
plt.xlabel('time')
plt.ylabel('flux')
plt.subplot(222)
plt.plot(walkers_per_cycle)
plt.xlabel('time')
plt.ylabel('walkers per cycle')
plt.subplot(223)

bin_ids = list(binner.mean_bin_weights.keys())
mean_bin_weights = list(binner.mean_bin_weights.values())
plt.plot(bin_ids, mean_bin_weights)
plt.xlabel('bin index')
plt.ylabel('mean bin weight')

From the top two graphs we can see an equilibration phase as simulations gradually percolate through the bins and more and more walkers are required per cycle; only after a lag do the first of them reach the target progress coordinate and so begin to get recycled. The bin weights graph shows that though walkers percolate through the bins, the vast majority of the weight remains in the left-hand well. It's also evident from the flux graph that WE simulations are noisy: it may take quite a time to generate reasonably converged numbers from the data. In any case, we need to exclude the initial equilibration phase from the stats:

In [None]:
mean_forward_flux = np.array(recycler.flux_history)[100:].mean() # exclude first 100 data points
print('mean flux L->R = ',mean_forward_flux)
mean_forward_concentration = sum(mean_bin_weights[:19])
print('mean [L] = ', mean_forward_concentration)
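
Because the flux is so noisy, it can help to attach an uncertainty to the mean. One simple option, sketched below as an addition to the notebook, is block averaging. The function name `block_average` is introduced here, and the flux series is synthetic (random numbers generated purely for illustration, not output from a WE run); in practice you would pass in `recycler.flux_history` after discarding the equilibration phase.

```python
import numpy as np

def block_average(series, n_blocks):
    """Return (mean, standard error) estimated from n_blocks contiguous blocks."""
    series = np.asarray(series, dtype=float)
    usable = (len(series) // n_blocks) * n_blocks   # drop any trailing remainder
    block_means = series[:usable].reshape(n_blocks, -1).mean(axis=1)
    return block_means.mean(), block_means.std(ddof=1) / np.sqrt(n_blocks)

# Synthetic, illustrative data only
rng = np.random.default_rng(0)
synthetic_flux = rng.exponential(scale=2e-3, size=900)

mean_flux, sem_flux = block_average(synthetic_flux, n_blocks=9)
print(f"mean flux = {mean_flux:.2e} +/- {sem_flux:.1e}")
```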

The stepper keeps a record of all the states it has generated for the walkers. We can now use the data stored in the recorder to produce a 'replay' of the path taken by any of the walkers that made it to the target state (ones in the `recycled_walkers` list):

In [None]:
chosen_walker = 70
plt.plot(stepper.recorder.replay(forward_recycled_walkers[chosen_walker]))
plt.xlabel('step #')
plt.ylabel('progress coordinate')

By changing the value of `chosen_walker` you will be able to see how the first walkers to reach the right-hand well hopped over the barrier quite quickly, but how later ones stay in the left-hand well for longer and longer before transitioning.

-----------
## Part 5: Running the reverse WE simulation workflow
Now we do the whole thing again, only this time from the right-hand well back to the left-hand one. First set things up:

In [None]:
initial_coordinates = 2.0
target_pc = -2.0
walkers = [Walker(initial_coordinates, 1.0/n_reps) for i in range(n_reps)]
edges = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
         1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
binner = StaticBinner(edges)
recycler = Recycler(target_pc, retrograde=True)

Now run the simulation:

In [None]:
n_cycles = 1000
walkers_per_cycle = []
reverse_recycled_walkers = []

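walkers = progress_coordinator.run(walkers)
walkers = binner.run(walkers) # assign initial bin ids, as in Part 4
binner.reset() # zero the weights memory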

for i in range(n_cycles):
    walkers = splitmerger.run(walkers)
    walkers_per_cycle.append(len(walkers))
    walkers = stepper.run(walkers)
    walkers = progress_coordinator.run(walkers)
    walkers = recycler.run(walkers)
    if recycler.flux > 0.0: # Recycled walkers need their pc and bin id updated...
        walkers = progress_coordinator.run(walkers)
        reverse_recycled_walkers += recycler.recycled_walkers
    walkers = binner.run(walkers)
    
    if i % (n_cycles // 10) == 0:
        print("{:4.0f}% done...".format(100*i/n_cycles))
print('complete')

In [None]:
plt.figure(figsize=(12,8))
plt.subplot(221)
plt.plot(recycler.flux_history)
plt.xlabel('time')
plt.ylabel('flux')
plt.subplot(222)
plt.plot(walkers_per_cycle)
plt.xlabel('time')
plt.ylabel('walkers per cycle')
plt.subplot(223)
bin_ids = list(binner.mean_bin_weights.keys())
mean_bin_weights = list(binner.mean_bin_weights.values())
plt.plot(bin_ids, mean_bin_weights)
plt.xlabel('bin index')
plt.ylabel('mean bin weight')

Calculate the mean reverse flux and the mean weight in the right-hand well:

In [None]:
mean_reverse_flux = np.array(recycler.flux_history)[100:].mean() # discard first 100 data points again
print('mean flux R->L:', mean_reverse_flux)
mean_reverse_concentration = sum(mean_bin_weights[5:])
print('mean [R] = ', mean_reverse_concentration)

In [None]:
chosen_walker = 70
plt.plot(stepper.recorder.replay(reverse_recycled_walkers[chosen_walker]))
plt.xlabel('step #')
plt.ylabel('progress coordinate')

-------------
## Part 6: Calculate the kinetic and thermodynamic parameters
Now we can calculate the rate constants and the equilibrium constant. To get the rate constants from the fluxes, we must correct for the concentration (weight) of the "reactants". Because weight only trickles over the barrier very slowly, and is then rapidly recycled, the total weight on the "reactants" side is typically very close to 1 - but sometimes it may not be, and we should be accurate:

In [None]:
forward_rate = mean_forward_flux / mean_forward_concentration
reverse_rate = mean_reverse_flux / mean_reverse_concentration
print('forward rate constant: {:6.2e}'.format(forward_rate))
print('reverse rate constant: {:6.2e}'.format(reverse_rate))
keq = reverse_rate / forward_rate
print('Keq = {:6.4}; deltaG = {:6.2f}'.format(keq, -np.log(keq)))

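How have we done? For reference, the energy difference between the two minima is approximately 4kT. Note that `keq` is defined above as `reverse_rate / forward_rate`, i.e. for the right-to-left direction, so the printed `deltaG` should come out close to -4 (in units of kT).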
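As a worked example of the arithmetic in the cell above, the sketch below wraps it in a function and applies it to made-up inputs. The function name `rates_and_free_energy` and the numbers are hypothetical: the fluxes are simply set to exp(−6) and exp(−2), with unit concentrations, so that the expected answer is known in advance. They are not results from this notebook.

```python
import numpy as np

def rates_and_free_energy(flux_f, conc_f, flux_r, conc_r):
    """Rate constants from fluxes and reactant weights; deltaG in kT.

    keq follows the notebook's definition, reverse_rate / forward_rate.
    """
    k_f = flux_f / conc_f
    k_r = flux_r / conc_r
    keq = k_r / k_f
    return k_f, k_r, keq, -np.log(keq)

# Hypothetical inputs chosen so that the rates differ by a factor exp(4)
k_f, k_r, keq, delta_g = rates_and_free_energy(np.exp(-6.0), 1.0, np.exp(-2.0), 1.0)
print('forward rate constant: {:6.2e}'.format(k_f))
print('reverse rate constant: {:6.2e}'.format(k_r))
print('Keq = {:6.4}; deltaG = {:6.2f}'.format(keq, delta_g))
```

Note that the rate constants obtained this way are per WE cycle, since the flux is recorded once per cycle.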

There are many experiments you can run using this notebook:

 - how does changing the shape of the skewed potential affect performance?
 - how does changing the binning affect performance?
 - how much do the kinetic parameters vary between replicate simulations - should they be longer?
 - how does changing the number of walkers per bin affect performance?