In [17]:
from functions import *
from decimal import Decimal, getcontext

# Set the precision for decimal operations
getcontext().prec = 50

In [63]:
def likelihood(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    """ 
    """
    return jnp.prod(likelihood_(dmap_flat, ref_dmap_flat, measurement_error, num_probes))
    

In [59]:
def likelihood_(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    """
    """
    # Calculate the difference between distance map and reference 
    # distance map
    subtraction_map_sq = (dmap_flat - ref_dmap_flat)**2
    sum_subtraction_map_sq = jnp.sum(subtraction_map_sq)
    
    # Calculate the normalization factor
    normalization_factor = 1/((jnp.sqrt(2*np.pi)*measurement_error)**((num_probes)**2))
    
    # Calculate the gaussian term 
    gaussian_term = jnp.exp(-jnp.sum(sum_subtraction_map_sq)/(2*measurement_error**2))
    
    print('Scaling factor = {}'.format(normalization_factor))
    print('Gaussian term = {}'.format(gaussian_term))
    
    return normalization_factor, gaussian_term

In [68]:
def loglikelihood(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    """
    """
    return jnp.sum(jnp.array(loglikelihood_(dmap_flat, ref_dmap_flat, measurement_error, num_probes)))

In [69]:
def loglikelihood_(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    """ 
    """
    # Calculate the difference between distance map and reference 
    # distance map
    subtraction_map_sq = (dmap_flat - ref_dmap_flat)**2
    sum_subtraction_map_sq = jnp.sum(subtraction_map_sq)
    
    # Calculate the normalization factor
    normalization_factor = -num_probes**2 * jnp.log(jnp.sqrt(2*np.pi*measurement_error**2))
    
    # Calculate the gaussian term 
    gaussian_term = -jnp.sum(sum_subtraction_map_sq)/(2*measurement_error**2)
    
    print('Scaling factor = {}'.format(normalization_factor))
    print('Gaussian term = {}'.format(gaussian_term))
    
    return normalization_factor, gaussian_term

In [73]:
def prior(dmap_flat, num_probes):
    """
    """
    return jnp.prod(jnp.array(prior_(dmap_flat, num_probes)))

In [75]:
def prior_(dmap_flat, num_probes):
    """
    """
    # Get 2D map back to simplify the expression 
    dmap = np.reshape(dmap_flat, [num_probes, num_probes])
    
    # Calculate the squared end-to-end distance 
    R_sq = dmap[0][-1] ** 2
    
    # Calculate the average bond length
    b = np.mean(np.diag(dmap, 1))
    
    N = num_probes
    
    # Calculate the probability
    scaling_factor = (3/(2*np.pi*N*b**2)) ** 1.5
    gaussian_term = jnp.exp(-3*R_sq/(2*N*b**2))
    
    print('Scaling factor = {}'.format(scaling_factor))
    print('Gaussian term = {}'.format(gaussian_term))
    
    return scaling_factor, gaussian_term 

In [2]:
num_monomers = 20
mean_bond_length = 1
std_bond_length = 20

template_chain = generate_gaussian_chain(num_monomers, mean_bond_length, std_bond_length)

In [5]:
num_observations = 10
gaussian_noise_std = 10

observations = generate_observations(template_chain, num_observations, gaussian_noise_std)

In [42]:
template_chain_flatten = generate_flatten_distance_map(template_chain)
observations_flatten = [generate_flatten_distance_map(c) for c in observations]

In [71]:
measurement_error = 10
num_probes = num_monomers

ll_template = loglikelihood(template_chain_flatten, template_chain_flatten, measurement_error, num_probes)

Normalization factor = -1288.609375
Gaussian term = -0.0


In [70]:
measurement_error = 10
lls = [loglikelihood(template_chain_flatten, x, measurement_error, num_probes) for x in observations_flatten]

Normalization factor = -1288.609375
Gaussian term = -347.5065612792969
Normalization factor = -1288.609375
Gaussian term = -311.9415588378906
Normalization factor = -1288.609375
Gaussian term = -315.6540222167969
Normalization factor = -1288.609375
Gaussian term = -257.48223876953125
Normalization factor = -1288.609375
Gaussian term = -315.6877136230469
Normalization factor = -1288.609375
Gaussian term = -347.7548522949219
Normalization factor = -1288.609375
Gaussian term = -375.1178894042969
Normalization factor = -1288.609375
Gaussian term = -343.8528137207031
Normalization factor = -1288.609375
Gaussian term = -348.2294006347656
Normalization factor = -1288.609375
Gaussian term = -643.3494262695312


In [61]:
likelihood(template_chain_flatten, observations_flatten[0], measurement_error, num_probes)

Normalization factor = 0.0
Gaussian term = 1.0


0.0

In [76]:
prior(template_chain_flatten, num_probes)

Scaling factor = 1.1828562172209968e-07
Gaussian term = 0.48222067952156067


Array(5.703977e-08, dtype=float32)

In [50]:
from scipy.special import logsumexp 

In [77]:
# This logsumexp get to underflow problem! So we need to try renormalization approach

logsumexp([likelihood_(template_chain_flatten, observations_flatten[0], measurement_error, num_probes)[0]])

Normalization factor = 0.0
Gaussian term = 0.0


0.0