In [77]:
from functions import *

In [45]:
def likelihood(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    """ 
    """
    return jnp.prod(jnp.array(likelihood_(dmap_flat, ref_dmap_flat, measurement_error, num_probes)))
    

In [46]:
def likelihood_(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    """
    """
    # Calculate the difference between distance map and reference 
    # distance map
    subtraction_map_sq = (dmap_flat - ref_dmap_flat)**2
    sum_subtraction_map_sq = jnp.sum(subtraction_map_sq)
    
    # Calculate the normalization factor
    normalization_factor = 1/((jnp.sqrt(2*np.pi)*measurement_error)**((num_probes)**2))
    
    # Calculate the gaussian term 
    gaussian_term = jnp.exp(-jnp.sum(sum_subtraction_map_sq)/(2*measurement_error**2))
    
    # print('Scaling factor = {}'.format(normalization_factor))
    # print('Gaussian term = {}'.format(gaussian_term))
    
    return normalization_factor, gaussian_term

In [47]:
def loglikelihood(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    """
    """
    return jnp.sum(jnp.array(loglikelihood_(dmap_flat, ref_dmap_flat, measurement_error, num_probes)))

In [48]:
def loglikelihood_(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    """ 
    """
    # Calculate the difference between distance map and reference 
    # distance map
    subtraction_map_sq = (dmap_flat - ref_dmap_flat)**2
    sum_subtraction_map_sq = jnp.sum(subtraction_map_sq)
    
    # Calculate the normalization factor
    normalization_factor = -num_probes**2 * jnp.log(jnp.sqrt(2*np.pi*measurement_error**2))
    
    # Calculate the gaussian term 
    gaussian_term = -jnp.sum(sum_subtraction_map_sq)/(2*measurement_error**2)
    
    # print('Scaling factor = {}'.format(normalization_factor))
    # print('Gaussian term = {}'.format(gaussian_term))
    
    return normalization_factor, gaussian_term

In [92]:
def prior(dmap_flat, num_probes):
    """
    """
    return jnp.prod(jnp.array(prior_(dmap_flat, num_probes)))

In [89]:
def prior_(dmap_flat, num_probes):
    """
    """
    # Get 2D map back to simplify the expression 
    dmap = jnp.reshape(dmap_flat, [num_probes, num_probes])
    
    # Calculate the squared end-to-end distance 
    R_sq = dmap[0][-1] ** 2
    
    # Calculate the average bond length
    b = jnp.mean(jnp.diag(dmap, 1))
    
    N = num_probes
    
    # Calculate the probability
    scaling_factor = (3/(2*np.pi*N*b**2)) ** 1.5
    gaussian_term = jnp.exp(-3*R_sq/(2*N*b**2))
    
    # print('Scaling factor = {}'.format(scaling_factor))
    # print('Gaussian term = {}'.format(gaussian_term))
    
    return scaling_factor, gaussian_term 

In [90]:
def logprior(dmap_flat, num_probes):
    """
    """
    return jnp.sum(jnp.array(logprior_(dmap_flat, num_probes)))

In [91]:
def logprior_(dmap_flat, num_probes):
    """
    """
    # Get 2D map back to simplify the expression 
    dmap = jnp.reshape(dmap_flat, [num_probes, num_probes])
    
    # Calculate the squared end-to-end distance 
    R_sq = dmap[0][-1] ** 2
    
    # Calculate the average bond length
    b = jnp.mean(jnp.diag(dmap, 1))
    
    N = num_probes
    
    # Calculate the probability
    scaling_factor = 1.5 * jnp.log(3/(2*np.pi*N*b**2))
    gaussian_term = -3*R_sq/(2*N*b**2)
    
    # print('Scaling factor = {}'.format(scaling_factor))
    # print('Gaussian term = {}'.format(gaussian_term))
    
    return scaling_factor, gaussian_term 

In [53]:
num_monomers = 20
mean_bond_length = 1
std_bond_length = 20

template_chain = generate_gaussian_chain(num_monomers, mean_bond_length, std_bond_length)

In [54]:
num_observations = 10
gaussian_noise_std = 10

observations = generate_observations(template_chain, num_observations, gaussian_noise_std)

In [55]:
template_chain_flatten = generate_flatten_distance_map(template_chain)
observations_flatten = [generate_flatten_distance_map(c) for c in observations]

In [56]:
measurement_error = 10
num_probes = num_monomers

ll_template = loglikelihood(template_chain_flatten, template_chain_flatten, measurement_error, num_probes)

In [57]:
measurement_error = 10
lls = [loglikelihood(template_chain_flatten, x, measurement_error, num_probes) for x in observations_flatten]

In [58]:
likelihood(template_chain_flatten, observations_flatten[0], measurement_error, num_probes)

Array(0., dtype=float32)

In [59]:
prior(template_chain_flatten, num_probes)

Array(9.447387e-09, dtype=float32)

In [60]:
jscipy.special.logsumexp(jnp.array(lls))

Array(-1565.2112, dtype=float32)

In [85]:
def generate_posterior(templates, observations, template_weights):
    """
    """
    templates_flatten = [generate_flatten_distance_map(t) for t in templates]
    observations_flatten = [generate_flatten_distance_map(o) for o in observations]
    
    weight_prior = 1/len(template_weights) 
    total_posterior = 0
    
    for o in observations_flatten:
        curr_obs_list = []
        for t, alpha in zip(templates_flatten, template_weights):
            val = 0
            val += loglikelihood(o, t, measurement_error, num_probes)
            val += logprior(t, num_probes)
            val += jnp.log(alpha)
            val += jnp.log(weight_prior)
            curr_obs_list.append(val)
        curr_obs_list = jnp.array(curr_obs_list) 
        total_posterior += jscipy.special.logsumexp(curr_obs_list)
    
    return total_posterior

In [62]:
generate_posterior([template_chain], observations, [1])

Array(-16892.314, dtype=float32)

In [63]:
num_monomers = 20
mean_bond_length = 1
std_bond_length = 20

template_chain_1 = generate_gaussian_chain(num_monomers, mean_bond_length, std_bond_length)

num_observations = 5
gaussian_noise_std = 10

observations_1 = generate_observations(template_chain_1, num_observations, gaussian_noise_std)

template_chain_2 = generate_gaussian_chain(num_monomers, mean_bond_length, std_bond_length)

num_observations = 10
gaussian_noise_std = 10

observations_2 = generate_observations(template_chain_2, num_observations, gaussian_noise_std)

template_chain_3 = generate_gaussian_chain(num_monomers, mean_bond_length, std_bond_length)

num_observations = 15
gaussian_noise_std = 10

observations_3 = generate_observations(template_chain_3, num_observations, gaussian_noise_std)

In [64]:
templates_list = [template_chain_1, 
                  template_chain_2,
                  template_chain_3]

In [65]:
observation_list = np.concatenate([observations_1, 
                                   observations_2, 
                                   observations_3])

In [66]:
generate_posterior(templates_list, observation_list, [5, 10, 15])

Array(-49961.812, dtype=float32)

In [86]:
def weight_objective(template_weights):
    """
    """
    templates = templates_list
    observations = observation_list 
    return generate_posterior(templates, observations, template_weights)

In [87]:
initial_weight_guess = [1., 1., 1.]

In [75]:
weight_objective_jacobian = jax.jacfwd(weight_objective)

In [79]:
import jaxopt

In [94]:
solver = jaxopt.LBFGS(fun=weight_objective, verbose=True)
res = solver.run(jnp.array(initial_weight_guess))

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): inf Stepsize:1.0  Decrease Error:inf  Curvature Error:inf 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): inf Stepsize:0.5  Decrease Error:inf  Curvature Error:inf 
INFO: jaxopt.ZoomLineSearch: Iter: 3 Minimum Decrease & Curvature Errors (stop. crit.): inf Stepsize:0.25  Decrease Error:inf  Curvature Error:inf 
INFO: jaxopt.ZoomLineSearch: Iter: 4 Minimum Decrease & Curvature Errors (stop. crit.): inf Stepsize:0.125  Decrease Error:inf  Curvature Error:inf 
INFO: jaxopt.ZoomLineSearch: Iter: 5 Minimum Decrease & Curvature Errors (stop. crit.): 3588.0302734375 Stepsize:0.0625  Decrease Error:0.0  Curvature Error:3588.0302734375 
INFO: jaxopt.ZoomLineSearch: Iter: 6 Minimum Decrease & Curvature Errors (stop. crit.): inf Stepsize:0.09375  Decrease Error:inf  Curvature Error:inf 
INFO: jaxopt.ZoomLineSearch: Iter: 7 Minimum Decrease & Curvature Errors (stop. crit

KeyboardInterrupt: 