In [1]:
from functions import *

In [4]:
def add(a, b):
    """Return ``a + b``.

    Used right after this definition as the function mapped element-wise
    over two batches with ``jax.vmap``.
    """
    total = a + b
    return total

# Two equal-length 1-D batches; jax.vmap applies `add` to matching elements.
batch_a = jnp.array([1, 2, 3, 4])
batch_b = jnp.array([10, 20, 30, 40])
added_batch = jax.vmap(add)(batch_a, batch_b)

print(added_batch)

[11 22 33 44]


In [5]:
# Check what kind of object the vmapped call returned.
type(added_batch)

jaxlib.xla_extension.Array

In [2]:
# Simulation parameters
num_monomers = 20
mean_bond_length = 1
std_bond_length = 20
gaussian_noise_std = 10

# One observation count per template (paired by position below)
num_observation_list = [30, 70]
num_templates = 2

# Draw the template chains, then generate the requested number of
# observations from each template.
template_chain_list = [
    generate_gaussian_chain(num_monomers, mean_bond_length, std_bond_length)
    for _ in range(num_templates)
]
observation_list = [
    generate_observations(chain, n_obs, gaussian_noise_std)
    for chain, n_obs in zip(template_chain_list, num_observation_list)
]


In [3]:
# Merge the per-template observation batches into one array along axis 0.
# The isinstance guard makes this cell idempotent: once `observation_list`
# is an array, re-running the cell must not concatenate it again (that would
# glue the individual observations together along their first axis).
if isinstance(observation_list, list):
    observation_list = np.concatenate(observation_list)

In [10]:
import jax.numpy as jnp
from jax import vmap

# Define the function to apply to each combination
def my_function(x, y):
    """Example function applied to each (x, y) combination: their sum."""
    pair_sum = x + y
    return pair_sum

# Define the input arrays
X = jnp.array([1, 2, 3])
Y = jnp.array([4, 5, 6])

# Create the grid of combinations
# (default meshgrid indexing: X varies along columns, Y along rows,
# as the printed result below shows)
X_grid, Y_grid = jnp.meshgrid(X, Y)

# Flatten the grids to create pairs
X_flat = X_grid.flatten()
Y_flat = Y_grid.flatten()

# Apply the function using vmap
result = vmap(my_function)(X_flat, Y_flat)

# Reshape the result to match the grid shape if needed
result_grid = result.reshape(X_grid.shape)

print(result_grid)


[[5 6 7]
 [6 7 8]
 [7 8 9]]


In [12]:
# Inspect the flattened X grid: X repeats once per element of Y.
X_flat

Array([1, 2, 3, 1, 2, 3, 1, 2, 3], dtype=int32)

In [21]:
# Define the 2D arrays
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
C = jnp.array([[9, 10], [11, 12]])

# Collect the arrays in a Python list (they are not stacked into one array)
arrays = [A, B, C]

# Create the grid of combinations of indices:
# one index axis per array, each running over 0..len(arrays)-1
indices = jnp.arange(len(arrays))
index_grid = jnp.meshgrid(*[indices] * len(arrays), indexing='ij')

# Flatten the index grids
flat_indices = [grid.flatten() for grid in index_grid]

# Stack the flattened indices to create combinations
# (one row per combination, one column per index axis)
index_combinations = jnp.stack(flat_indices, axis=-1)

In [16]:
# Module-level values read as globals by the posterior functions in this
# notebook (passed to loglikelihood / logprior).
measurement_error = 10
num_probes = 20

In [11]:
# Select the entries of y at the positions where x > 5.
x = jnp.arange(10)
y = jnp.arange(10, 20)
y[jnp.where(x > 5)]

Array([16, 17, 18, 19], dtype=int32)

In [12]:
# With only a condition, jnp.where returns a tuple of index arrays.
jnp.where(x>5)

(Array([6, 7, 8, 9], dtype=int32),)

In [15]:
# exp(-inf) evaluates to 0, so entries set to -inf contribute nothing to a logsumexp.
jnp.exp(-jnp.inf)

Array(0., dtype=float32, weak_type=True)

In [82]:
def generate_posterior_parallelize(templates, observations, template_weights, weight_renormalization=1000):
    """
    Vectorised total log-posterior of `observations` under weighted `templates`.

    For every (observation o, template t with weight alpha) pair the term is

        loglikelihood(o, t, measurement_error, num_probes)
        + logprior(t, num_probes)
        + log(|alpha| + 1e-32) * weight_renormalization
        + log(1 / n_templates) * weight_renormalization

    The terms are combined with logsumexp over the templates of each
    observation, and the per-observation results are summed.

    Parameters
    ----------
    templates, observations : sequences
        Each element is passed through `generate_flatten_distance_map`.
    template_weights : array-like
        One weight per template.
    weight_renormalization : scalar, default 1000
        Factor applied to both log-weight terms so that they are on a scale
        comparable to the log-likelihood.

    Returns
    -------
    Scalar jax array with the summed log-posterior.

    Notes
    -----
    Reads the module-level globals `measurement_error` and `num_probes`, and
    prints the current weights through `jax.debug.print` on every evaluation.
    """
    templates_flatten = jnp.array([generate_flatten_distance_map(t) for t in templates])
    observations_flatten = jnp.array([generate_flatten_distance_map(o) for o in observations])
    template_weights = jnp.array(template_weights)

    weight_prior = 1 / len(template_weights)

    jax.debug.print("Weights at current iteration: {y}", y=template_weights)

    def calculate_rhs(o, t, alpha):
        """Log-term of a single (observation, template) pair."""
        val = loglikelihood(o, t, measurement_error, num_probes)
        val += logprior(t, num_probes)
        # log(alpha) would be the exact expression, but its scale is very
        # different from the log-likelihood, hence the renormalisation factor.
        # jnp.abs keeps the argument of the log non-negative and the 1e-32
        # offset keeps it away from log(0) when a weight is exactly zero.
        val += jnp.log(jnp.abs(alpha) + 1e-32) * weight_renormalization
        val += jnp.log(weight_prior) * weight_renormalization
        return val

    # Inner vmap runs over templates (and their weights) for one observation;
    # outer vmap runs over observations. The result has shape
    # (n_observations, n_templates), so the per-observation reduction is a
    # single logsumexp along axis 1 instead of masking the whole pair vector
    # once per observation.
    over_templates = jax.vmap(calculate_rhs, in_axes=(None, 0, 0))
    over_pairs = jax.vmap(over_templates, in_axes=(0, None, None))
    log_terms = over_pairs(observations_flatten, templates_flatten, template_weights)

    return jnp.sum(jscipy.special.logsumexp(log_terms, axis=1))

def weight_neg_objective_parallelize(template_weights):
    """
    Negated vectorised posterior, in the form expected by a minimiser.

    Evaluates `generate_posterior_parallelize` on the module-level
    `template_chain_list` and `observation_list` with the given weights.
    """
    posterior = generate_posterior_parallelize(
        template_chain_list, observation_list, template_weights
    )
    return -posterior
    

In [19]:
 def calculate_rhs(template_info, alpha, o, weight_renormalization, weight_prior):
     """
     Log-term of one (template, observation) pair.

     `template_info` is unpacked into (t, alpha); the `alpha` argument is
     therefore overwritten and its passed-in value is not used.
     Reads the module-level globals `measurement_error` and `num_probes`.
     """
     t, alpha = template_info

     # The log-weight terms are multiplied by `weight_renormalization` because
     # their scale is otherwise very different from the log-likelihood.
     # (Adding alpha and weight_prior directly was considered as an alternative.)
     contributions = (
         loglikelihood(o, t, measurement_error, num_probes),
         logprior(t, num_probes),
         jnp.log(alpha) * weight_renormalization,
         jnp.log(weight_prior) * weight_renormalization,
     )
     # sum() starts from 0 and adds left to right, like the original accumulation.
     return sum(contributions)

def generate_posterior_parallelize(templates, observations, template_weights, weight_renormalization=1000):
    """
    Vectorised total log-posterior built on the module-level `calculate_rhs`.

    `calculate_rhs` is evaluated for every (observation, template) pair with
    nested `jax.vmap`, giving an (n_observations, n_templates) grid of
    log-terms. The grid is reduced with logsumexp over templates for each
    observation, and the per-observation values are summed.

    Parameters
    ----------
    templates, observations : sequences
        Each element is passed through `generate_flatten_distance_map`.
    template_weights : array-like
        One weight per template.
    weight_renormalization : scalar, default 1000
        Forwarded to `calculate_rhs`.

    Returns
    -------
    Scalar jax array with the summed log-posterior.
    """
    templates_flatten = jnp.array([generate_flatten_distance_map(t) for t in templates])
    observations_flatten = jnp.array([generate_flatten_distance_map(o) for o in observations])
    template_weights = jnp.asarray(template_weights)

    weight_prior = 1 / len(template_weights)

    # calculate_rhs takes log(alpha) directly; keep the argument strictly
    # positive so a weight of exactly zero does not produce log(0).
    safe_weights = jnp.abs(template_weights) + 1e-32

    # calculate_rhs(template_info, alpha, o, weight_renormalization, weight_prior)
    # unpacks template_info as (t, alpha), so template and weight travel together.
    # Inner vmap: over templates/weights for a fixed observation.
    over_templates = jax.vmap(calculate_rhs, in_axes=((0, 0), 0, None, None, None))
    # Outer vmap: over observations, with templates/weights held fixed.
    over_pairs = jax.vmap(over_templates, in_axes=(None, None, 0, None, None))

    log_terms = over_pairs(
        (templates_flatten, safe_weights),
        safe_weights,
        observations_flatten,
        weight_renormalization,
        weight_prior,
    )

    # One logsumexp per observation (axis 1 runs over templates), then sum.
    return jnp.sum(jscipy.special.logsumexp(log_terms, axis=1))

def weight_neg_objective_parallelize(template_weights):
    """
    Objective for weight optimisation: minus the vectorised posterior.

    Templates and observations come from the module-level
    `template_chain_list` and `observation_list`.
    """
    return -generate_posterior_parallelize(
        template_chain_list,
        observation_list,
        template_weights,
    )

In [27]:
# Vectorised objective at weights [30, 70]; compare with the loop-based weight_neg_objective.
weight_neg_objective_parallelize([30, 70])

Array(-159339.62, dtype=float32)

In [25]:
# Loop-based reference objective at the same weights [30, 70].
weight_neg_objective([30, 70])

Array(-159339.64, dtype=float32)

In [24]:
def generate_posterior(templates, observations, template_weights, weight_renormalization=1000):
    """
    Loop-based total log-posterior of `observations` under weighted `templates`.

    Each (observation, template) pair contributes
    loglikelihood + logprior + log(alpha) * weight_renormalization
    + log(1 / n_weights) * weight_renormalization. The terms of one
    observation are combined with logsumexp and the results are summed.
    Reads the module-level globals `measurement_error` and `num_probes`.
    """
    template_maps = [generate_flatten_distance_map(t) for t in templates]
    observation_maps = [generate_flatten_distance_map(o) for o in observations]

    weight_prior = 1/len(template_weights)

    def pair_term(o, t, alpha):
        # The log-weight terms are rescaled because their scale is otherwise
        # very different from the log-likelihood. sum() starts at 0 and adds
        # in order, matching a step-by-step accumulation.
        return sum((
            loglikelihood(o, t, measurement_error, num_probes),
            logprior(t, num_probes),
            jnp.log(alpha) * weight_renormalization,
            jnp.log(weight_prior) * weight_renormalization,
        ))

    def observation_term(o):
        terms = [pair_term(o, t, alpha) for t, alpha in zip(template_maps, template_weights)]
        return jscipy.special.logsumexp(jnp.array(terms))

    return sum(observation_term(o) for o in observation_maps)

def weight_neg_objective(template_weights):
    """
    Negated loop-based posterior for the given template weights.

    Uses the module-level `template_chain_list` and `observation_list`.
    """
    posterior = generate_posterior(template_chain_list, observation_list, template_weights)
    return -posterior

In [None]:
def generate_posterior(templates, observations, template_weights, weight_renormalization=1000):
    """
    Total log-posterior computed with explicit Python loops.

    For each observation, one log-term per (template, weight) pair is built,
    the terms are reduced with logsumexp, and the reductions are accumulated
    over observations. Reads the module-level globals `measurement_error`
    and `num_probes`.
    """
    flat_templates = [generate_flatten_distance_map(t) for t in templates]
    flat_observations = [generate_flatten_distance_map(o) for o in observations]

    weight_prior = 1/len(template_weights)

    running_total = 0
    for obs_map in flat_observations:
        # Log-weight terms are multiplied by weight_renormalization so that
        # they are on the same scale as the log-likelihood.
        per_template = jnp.array([
            0
            + loglikelihood(obs_map, tmpl_map, measurement_error, num_probes)
            + logprior(tmpl_map, num_probes)
            + jnp.log(weight) * weight_renormalization
            + jnp.log(weight_prior) * weight_renormalization
            for tmpl_map, weight in zip(flat_templates, template_weights)
        ])
        running_total += jscipy.special.logsumexp(per_template)

    return running_total

def weight_neg_objective(template_weights):
    """
    Minimisation objective: minus `generate_posterior` evaluated on the
    module-level `template_chain_list` and `observation_list`.
    """
    return -generate_posterior(
        template_chain_list,
        observation_list,
        template_weights,
    )

In [5]:
# Create a random key
key = random.PRNGKey(1)

# Generate uniform random numbers
# (shape (1, 2): a single row holding two values)
key, subkey = random.split(key)
uniform_array = random.uniform(subkey, shape=(1, 2)) 

INFO:jax._src.lib.xla_bridge:Remote TPU is not linked into jax; skipping remote TPU.
INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'


INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter Host CUDA
INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.


In [6]:
# Tie the globals read by the posterior functions to the simulation parameters.
measurement_error = gaussian_noise_std
num_probes = num_monomers

In [62]:
# Check the type returned by random.uniform.
type(uniform_array)

jaxlib.xla_extension.Array

In [63]:
# Check the type after wrapping the same array with jnp.array.
type(jnp.array(uniform_array))

jaxlib.xla_extension.Array

In [30]:


# Minimise the negated posterior over the template weights, projecting onto
# the simplex after each step (hyperparams_proj is presumably the simplex
# radius, here the total number of observations - confirm against jaxopt).
pg = ProjectedGradient(fun=weight_neg_objective_parallelize, projection=projection_simplex, implicit_diff=True, verbose=True)
# uniform_array has a single row, so the initial weights are row 0.
# (Row 1 does not exist; JAX only accepted it by silently clamping the index.)
pg_sol = pg.run(uniform_array[0, :], hyperparams_proj=sum(num_observation_list))

INFO: jaxopt.ProximalGradient: Iter: 1 Distance btw Iterates (stop. crit.): 297269.875 Stepsize:0.000244140625 
INFO: jaxopt.ProximalGradient: Iter: 2 Distance btw Iterates (stop. crit.): 100.38153839111328 Stepsize:0.00048828125 
INFO: jaxopt.ProximalGradient: Iter: 3 Distance btw Iterates (stop. crit.): 113.89667510986328 Stepsize:0.0009765625 
INFO: jaxopt.ProximalGradient: Iter: 4 Distance btw Iterates (stop. crit.): 123.01586151123047 Stepsize:0.001953125 
INFO: jaxopt.ProximalGradient: Iter: 5 Distance btw Iterates (stop. crit.): 127.76673889160156 Stepsize:0.00390625 
INFO: jaxopt.ProximalGradient: Iter: 6 Distance btw Iterates (stop. crit.): 126.60733032226562 Stepsize:0.0078125 
INFO: jaxopt.ProximalGradient: Iter: 7 Distance btw Iterates (stop. crit.): 115.83776092529297 Stepsize:0.015625 
INFO: jaxopt.ProximalGradient: Iter: 8 Distance btw Iterates (stop. crit.): 89.19316101074219 Stepsize:0.03125 
INFO: jaxopt.ProximalGradient: Iter: 9 Distance btw Iterates (stop. crit.): 4

In [31]:
pg_sol.params 
# This works well: the recovered weights are close to the observation counts [30, 70].
# Next step: add more templates taken from the sample.

Array([30.044329, 69.955666], dtype=float32)

In [7]:
# Now add polymers taken from the sample itself as extra candidate templates.

# Shuffle a copy of the observations with a fixed seed. This keeps the cell
# reproducible and idempotent: `observation_list` itself is left untouched,
# so re-running the cell always yields the same template list.
shuffle_rng = np.random.default_rng(0)
shuffled_observations = shuffle_rng.permutation(observation_list)
template_chain_list_with_obs = np.concatenate([template_chain_list, shuffled_observations])


In [9]:
def weight_neg_objective_parallelize_for_multiple(template_weights):
    """
    Negated vectorised posterior for a variable set of templates.

    The templates are read from the module-level `template_input`, which
    must be assigned before this function is called; observations come from
    the module-level `observation_list`.
    """
    posterior = generate_posterior_parallelize(
        template_input, observation_list, template_weights
    )
    return -posterior

In [10]:
import torch
# NOTE(review): torch is used in this notebook only for empty_cache(); this
# presumably affects PyTorch's CUDA allocator only - confirm it is needed
# alongside the JAX computations.
torch.cuda.empty_cache() 

In [71]:
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_simplex
sol_list = []  # the iterations are independent, so this loop could be parallelized
torch.cuda.empty_cache()
for i in tqdm(range(1, 50)):
    torch.cuda.empty_cache()
    # `template_input` is the global read by weight_neg_objective_parallelize_for_multiple.
    template_input = template_chain_list_with_obs[:i]
    # NOTE: the same `subkey` is used in every iteration; only the shape changes.
    uniform_array = random.uniform(subkey, shape=(1, i))
    pg = ProjectedGradient(fun=weight_neg_objective_parallelize_for_multiple, projection=projection_simplex, implicit_diff=True)
    # uniform_array has a single row, so the initial weights are row 0.
    # (Row 1 does not exist; JAX only accepted it by silently clamping the index.)
    pg_sol = pg.run(uniform_array[0, :], hyperparams_proj=sum(num_observation_list))
    sol_list.append(pg_sol)

  0%|          | 0/49 [00:00<?, ?it/s]

In [72]:
# Keep only the fitted weights from each solver result.
param_list = [x.params for x in sol_list]

In [73]:
param_list  # this gives NaN because a template polymer is exactly the same as one of the observations

[Array([100.], dtype=float32),
 Array([30.029314, 69.97068 ], dtype=float32),
 Array([29.99826, 70.00173,  0.     ], dtype=float32),
 Array([29.989767,  0.      , 70.01022 ,  0.      ], dtype=float32),
 Array([ 0.      ,  0.      , 70.00058 ,  0.      , 29.999414], dtype=float32),
 Array([ 0.      ,  0.      , 70.01355 ,  0.      , 29.986456,  0.      ],      dtype=float32),
 Array([30.007818,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        69.99219 ], dtype=float32),
 Array([29.99987,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        70.00012,  0.     ], dtype=float32),
 Array([ 0.      , 69.99245 ,  0.      , 30.007559,  0.      ,  0.      ,
         0.      ,  0.      ,  0.      ], dtype=float32),
 Array([ 0.      , 70.000595,  0.      , 29.999409,  0.      ,  0.      ,
         0.      ,  0.      ,  0.      ,  0.      ], dtype=float32),
 Array([29.999653, 70.00034 ,  0.      ,  0.      ,  0.      ,  0.      ,
         0.      ,  0.      ,  0.      ,  0.      

In [12]:
torch.cuda.empty_cache()
# Single run with the first three templates (two true templates plus one observation).
i = 3
template_input = template_chain_list_with_obs[:i]
uniform_array = random.uniform(subkey, shape=(1, i))
pg = ProjectedGradient(fun=weight_neg_objective_parallelize_for_multiple, projection=projection_simplex, implicit_diff=True)
# uniform_array has a single row, so the initial weights are row 0.
# (Row 1 does not exist; JAX only accepted it by silently clamping the index.)
pg_sol = pg.run(uniform_array[0, :], hyperparams_proj=sum(num_observation_list))
pg_sol.params

-5665.23974609375, 0, 0
-3139.442626953125, 1, 0
-2924.3740234375, 2, 0
-3001.72412109375, 0, 1
-4959.82763671875, 1, 1
-5737.810546875, 2, 1
-5452.77099609375, 0, 2
-2719.2587890625, 1, 2
-3990.781005859375, 2, 2
-3103.89453125, 0, 3
-5496.51171875, 1, 3
-6217.2890625, 2, 3
-3087.11962890625, 0, 4
-5390.5078125, 1, 4
-6269.142578125, 2, 4
-5708.478515625, 0, 5
-2811.514892578125, 1, 5
-4167.9833984375, 2, 5
-5850.3388671875, 0, 6
-2799.411865234375, 1, 6
-4022.009765625, 2, 6
-5249.99951171875, 0, 7
-2869.6875, 1, 7
-4169.1728515625, 2, 7
-5292.07470703125, 0, 8
-2828.3232421875, 1, 8
-3984.1845703125, 2, 8
-5474.267578125, 0, 9
-2735.288330078125, 1, 9
-4086.78662109375, 2, 9
-5514.69140625, 0, 10
-2715.42822265625, 1, 10
-4162.2138671875, 2, 10
-6083.138671875, 0, 11
-2901.7392578125, 1, 11
-4066.5087890625, 2, 11
-4848.29443359375, 0, 12
-2768.6220703125, 1, 12
-3593.30029296875, 2, 12
-5270.9677734375, 0, 13
-2832.83740234375, 1, 13
-4295.908203125, 2, 13
-5607.04296875, 0, 14
-28

Array([nan, nan, nan], dtype=float32)

In [12]:
# Next experiment: extend the templates with freshly generated random chains
new_template_list = [
    generate_gaussian_chain(num_monomers, mean_bond_length, std_bond_length)
    for _ in range(100)
]

# True templates first, random chains after them
template_chain_list_with_new = np.concatenate([template_chain_list, new_template_list])


In [68]:

def weight_neg_objective_parallelize_for_multiple(template_weights):
    """
    Minimisation objective for the current module-level `template_input`.

    Returns minus `generate_posterior_parallelize` evaluated on
    `template_input` and the module-level `observation_list`.
    """
    return -generate_posterior_parallelize(
        template_input,
        observation_list,
        template_weights,
    )

from jaxopt import ProjectedGradient
from jaxopt.projection import projection_simplex
sol_list = []  # the iterations are independent, so this loop could be parallelized
torch.cuda.empty_cache()
for i in tqdm(range(1, 100)):
    torch.cuda.empty_cache()
    # `template_input` is the global read by weight_neg_objective_parallelize_for_multiple.
    template_input = template_chain_list_with_new[:i]
    # NOTE: the same `subkey` is used in every iteration; only the shape changes.
    uniform_array = random.uniform(subkey, shape=(1, i))
    pg = ProjectedGradient(fun=weight_neg_objective_parallelize_for_multiple, projection=projection_simplex, implicit_diff=True)
    # uniform_array has a single row, so the initial weights are row 0.
    # (Row 1 does not exist; JAX only accepted it by silently clamping the index.)
    pg_sol = pg.run(uniform_array[0, :], hyperparams_proj=sum(num_observation_list))
    sol_list.append(pg_sol)

  0%|          | 0/99 [00:00<?, ?it/s]

In [69]:
# Keep only the fitted weights from each solver result.
param_list = [x.params for x in sol_list]

In [70]:
# Display the fitted weights, one array per number of templates tried.
param_list

[Array([100.], dtype=float32),
 Array([30.029314, 69.97068 ], dtype=float32),
 Array([29.99826, 70.00173,  0.     ], dtype=float32),
 Array([30.002934, 69.99706 ,  0.      ,  0.      ], dtype=float32),
 Array([30.000183, 69.99982 ,  0.      ,  0.      ,  0.      ], dtype=float32),
 Array([29.991812, 70.00819 ,  0.      ,  0.      ,  0.      ,  0.      ],      dtype=float32),
 Array([30.000013, 69.99999 ,  0.      ,  0.      ,  0.      ,  0.      ,
         0.      ], dtype=float32),
 Array([29.968899, 70.0311  ,  0.      ,  0.      ,  0.      ,  0.      ,
         0.      ,  0.      ], dtype=float32),
 Array([29.993507, 70.00649 ,  0.      ,  0.      ,  0.      ,  0.      ,
         0.      ,  0.      ,  0.      ], dtype=float32),
 Array([30.029722, 69.970276,  0.      ,  0.      ,  0.      ,  0.      ,
         0.      ,  0.      ,  0.      ,  0.      ], dtype=float32),
 Array([29.999712, 70.00029 ,  0.      ,  0.      ,  0.      ,  0.      ,
         0.      ,  0.      ,  0.      ,  

In [84]:
torch.cuda.empty_cache()
# Single run with the first four templates (two true templates plus two random chains).
i = 4
template_input = template_chain_list_with_new[:i]
# Total weight shared by the initial guess and the simplex projection.
simplex_radius = sum(num_observation_list)
uniform_array = random.uniform(subkey, shape=(1, i))
# Rescale the random start so that its entries sum to the simplex radius.
uniform_array = uniform_array / jnp.sum(uniform_array) * simplex_radius
# NOTE(review): implicit_diff_solve is passed True here; it looks like jaxopt
# expects a linear-solver callable for this option - confirm.
pg = ProjectedGradient(fun=weight_neg_objective_parallelize_for_multiple, projection=projection_simplex,
                       implicit_diff=True, implicit_diff_solve=True)
# uniform_array has a single row, so the initial weights are row 0.
# (Row 1 does not exist; JAX only accepted it by silently clamping the index.)
pg_sol = pg.run(uniform_array[0, :], hyperparams_proj=simplex_radius)
pg_sol.params

Weights at current iteration: [41.66624   8.613459 35.632126 14.088178]
Weights at current iteration: [  0. 100.   0.   0.]
Weights at current iteration: [  0. 100.   0.   0.]
Weights at current iteration: [  0. 100.   0.   0.]
Weights at current iteration: [  0. 100.   0.   0.]
Weights at current iteration: [  0. 100.   0.   0.]
Weights at current iteration: [  0. 100.   0.   0.]
Weights at current iteration: [ 8.660751 91.33926   0.        0.      ]
Weights at current iteration: [28.948769 53.76164  17.289597  0.      ]
Weights at current iteration: [35.839302 31.71934  26.992655  5.448708]
Weights at current iteration: [38.752766 20.166397 31.312387  9.768441]
Weights at current iteration: [38.752766 20.166397 31.312387  9.768441]
Weights at current iteration: [37.63099  29.579699 27.166632  5.622685]
Weights at current iteration: [37.314926  32.23193   25.99855    4.4546037]
Weights at current iteration: [37.331497 42.934383 19.734127  0.      ]
Weights at current iteration: [37.20

Array([30.000765, 69.99924 ,  0.      ,  0.      ], dtype=float32)

In [89]:
import umap

ModuleNotFoundError: No module named 'umap'