In [1]:
from _utils import *

# Load dataset from file
folder_path = '/mnt/home/tudomlumleart/ceph/10_ToniaDataset/ToniaDataset_withPolys.mat'

dataset = scipy.io.loadmat(folder_path)

# Hox locus dataset
nHox = 72
ctcfNtPolysHox = dataset['ctcfNtPolysHox']
ctcfPolysHox = dataset['ctcfPolysHox']
ntPolysHox = dataset['ntPolysHox']
radNtPolysHox = dataset['radNtPolysHox']
radPolysHox = dataset['radPolysHox']

# Sox locus dataset
nSox = 93
ctcfNtPolysSox = dataset['ctcfNtPolysSox']
ctcfPolysSox = dataset['ctcfPolysSox']
ntPolysSox = dataset['ntPolysSox']
radNtPolysSox = dataset['radNtPolysSox']
radPolysSox = dataset['radPolysSox']

# Interpolate polymers to fill in the NaN values 
ctcfNtPolysHox = interpolate_polymers(ctcfNtPolysHox)
ctcfPolysHox = interpolate_polymers(ctcfPolysHox)
ntPolysHox = interpolate_polymers(ntPolysHox)
radNtPolysHox = interpolate_polymers(radNtPolysHox)
radPolysHox = interpolate_polymers(radPolysHox)

ctcfNtPolysSox = interpolate_polymers(ctcfNtPolysSox)
ctcfPolysSox = interpolate_polymers(ctcfPolysSox)
ntPolysSox = interpolate_polymers(ntPolysSox)
radNtPolysSox = interpolate_polymers(radNtPolysSox)
radPolysSox = interpolate_polymers(radPolysSox)


def calculate_distance_map(polys):
    # Extract the dimensions of the input array
    num_probes, num_coords, num_cells = polys.shape
    
    # Initialize an array of the same shape to hold the interpolated values
    new_maps = np.zeros((num_cells, num_probes, num_probes))
    
    # Iterate over each cell
    for c in range(num_cells):
        # Extract the data for the current cell
        curr_cells = polys[:, :, c]
        
        # Skip cells with all missing values
        if np.all(np.isnan(curr_cells)):
            continue  # This leaves a matrix of zeros in the output array
        
        # Calculate the pairwise Euclidean distance between each pair of probes
        dmap = squareform(pdist(curr_cells))
        
        # Assign the distance map to the corresponding position in the output array
        new_maps[c, :, :] = dmap
    
    # Return the array with interpolated values
    return new_maps

# Generate distance maps from interpolated polymers
ctcfNtMapsHox = calculate_distance_map(ctcfNtPolysHox)
ctcfMapsHox = calculate_distance_map(ctcfPolysHox)
ntMapsHox = calculate_distance_map(ntPolysHox)
radNtMapsHox = calculate_distance_map(radNtPolysHox)
radMapsHox = calculate_distance_map(radPolysHox)

ctcfNtMapsSox = calculate_distance_map(ctcfNtPolysSox)
ctcfMapsSox = calculate_distance_map(ctcfPolysSox)
ntMapsSox = calculate_distance_map(ntPolysSox)
radNtMapsSox = calculate_distance_map(radNtPolysSox)
radMapsSox = calculate_distance_map(radPolysSox)

# Plot the median maps of these distance maps 
# To check if the distance maps are reasonable
ctcfNtMapsHox_median = np.nanmedian(ctcfNtMapsHox, axis=0)
ctcfMapsHox_median = np.nanmedian(ctcfMapsHox, axis=0)
ntMapsHox_median = np.nanmedian(ntMapsHox, axis=0)
radNtMapsHox_median = np.nanmedian(radNtMapsHox, axis=0)
radMapsHox_median = np.nanmedian(radMapsHox, axis=0)

ctcfNtMapsSox_median = np.nanmedian(ctcfNtMapsSox, axis=0)
ctcfMapsSox_median = np.nanmedian(ctcfMapsSox, axis=0)
ntMapsSox_median = np.nanmedian(ntMapsSox, axis=0)
radNtMapsSox_median = np.nanmedian(radNtMapsSox, axis=0)
radMapsSox_median = np.nanmedian(radMapsSox, axis=0)

# Generate flatten maps from distance maps
ctcfNtFlattenHox = np.array([x.flatten() for x in ctcfNtMapsHox])
ctcfFlattenHox = np.array([x.flatten() for x in ctcfMapsHox])
ntFlattenHox = np.array([x.flatten() for x in ntMapsHox])
radNtFlattenHox = np.array([x.flatten() for x in radNtMapsHox])
radFlattenHox = np.array([x.flatten() for x in radMapsHox])

ctcfNtFlattenSox = np.array([x.flatten() for x in ctcfNtMapsSox])
ctcfFlattenSox = np.array([x.flatten() for x in ctcfMapsSox])
ntFlattenSox = np.array([x.flatten() for x in ntMapsSox])
radNtFlattenSox = np.array([x.flatten() for x in radNtMapsSox])
radFlattenSox = np.array([x.flatten() for x in radMapsSox])

allFlattenHox = np.concatenate((ctcfNtFlattenHox, ctcfFlattenHox, ntFlattenHox, radNtFlattenHox, radFlattenHox), axis=0)

pca_hox = PCA(n_components=2)
pca_hox.fit(allFlattenHox)
# Fit the PCA model to all Hox datasets
ctcfNtHox = pca_hox.transform(ctcfNtFlattenHox)
ctcfHox = pca_hox.transform(ctcfFlattenHox)
ntHox = pca_hox.transform(ntFlattenHox)
radNtHox = pca_hox.transform(radNtFlattenHox)
radHox = pca_hox.transform(radFlattenHox)

# Convert the principal components to a DataFrame
ctcfNtHox_df = pd.DataFrame(ctcfNtHox, columns=['PC1', 'PC2'])
ctcfNtHox_df['label'] = 'ctcfNtHox'
ctcfHox_df = pd.DataFrame(ctcfHox, columns=['PC1', 'PC2'])
ctcfHox_df['label'] = 'ctcfDegHox'
ntHox_df = pd.DataFrame(ntHox, columns=['PC1', 'PC2'])
ntHox_df['label'] = 'ntHox'
radNtHox_df = pd.DataFrame(radNtHox, columns=['PC1', 'PC2'])
radNtHox_df['label'] = 'radNtHox'
radHox_df = pd.DataFrame(radHox, columns=['PC1', 'PC2'])
radHox_df['label'] = 'radDegHox'

all_df = pd.concat([ntHox_df, radNtHox_df, radHox_df, ctcfNtHox_df, ctcfHox_df], axis=0)

# PCA for Sox locus
allFlattenSox = np.concatenate((ctcfNtFlattenSox, ctcfFlattenSox, ntFlattenSox, radNtFlattenSox, radFlattenSox), axis=0)
pca_sox = PCA(n_components=2)
pca_sox.fit(allFlattenSox)
# Fit the PCA model to all Sox datasets
ctcfNtSox = pca_sox.transform(ctcfNtFlattenSox)
ctcfSox = pca_sox.transform(ctcfFlattenSox)
ntSox = pca_sox.transform(ntFlattenSox)
radNtSox = pca_sox.transform(radNtFlattenSox)
radSox = pca_sox.transform(radFlattenSox)

# Convert the principal components to a DataFrame
ctcfNtSox_df = pd.DataFrame(ctcfNtSox, columns=['PC1', 'PC2'])
ctcfNtSox_df['label'] = 'ctcfNtSox'
ctcfSox_df = pd.DataFrame(ctcfSox, columns=['PC1', 'PC2'])
ctcfSox_df['label'] = 'ctcfDegSox'
ntSox_df = pd.DataFrame(ntSox, columns=['PC1', 'PC2'])
ntSox_df['label'] = 'ntSox'
radNtSox_df = pd.DataFrame(radNtSox, columns=['PC1', 'PC2'])
radNtSox_df['label'] = 'radNtSox'
radSox_df = pd.DataFrame(radSox, columns=['PC1', 'PC2'])
radSox_df['label'] = 'radDegSox'

all_df_sox = pd.concat([ntSox_df, radNtSox_df, radSox_df, ctcfNtSox_df, ctcfSox_df], axis=0)

# Find the lower and upper bounds of the PC1 and PC2 for Hox locus
min_pc1_hox = min(all_df['PC1'])
max_pc1_hox = max(all_df['PC1'])
min_pc2_hox = min(all_df['PC2'])
max_pc2_hox = max(all_df['PC2'])

# Find the lower and upper bounds of the PC1 and PC2 for Sox locus
min_pc1_sox = min(all_df_sox['PC1'])
max_pc1_sox = max(all_df_sox['PC1'])
min_pc2_sox = min(all_df_sox['PC2'])
max_pc2_sox = max(all_df_sox['PC2'])

def generate_microstates(min_pc1, max_pc1, min_pc2, max_pc2, num_microstates, pca_model):
    """
    Generates a grid of points (microstates) based on provided PCA components ranges, 
    sorts them, and applies inverse transformation using the given PCA model.
    
    Parameters:
    ----------
    min_pc1 : float
        Minimum value for the first principal component (PC1).
    max_pc1 : float
        Maximum value for the first principal component (PC1).
    min_pc2 : float
        Minimum value for the second principal component (PC2).
    max_pc2 : float
        Maximum value for the second principal component (PC2).
    num_microstates : int
        Number of microstates (grid points) to generate for each component.
    pca_model : sklearn.decomposition.PCA
        Pre-trained PCA model that will be used to inverse transform the generated grid points.
    
    Returns:
    -------
    np.ndarray
        A NumPy array containing the inverse-transformed microstates.
    
    Example:
    -------
    microstates = generate_microstates(-5, 5, -5, 5, 75, pca_model)
    """
    
    # Create a grid of points
    pc1 = np.linspace(min_pc1, max_pc1, num_microstates)
    pc2 = np.linspace(min_pc2, max_pc2, num_microstates)
    pc1, pc2 = np.meshgrid(pc1, pc2)
    pc1 = pc1.flatten()
    pc2 = pc2.flatten()

    # Create a DataFrame from the grid points
    grid = pd.DataFrame({'PC1': pc1, 'PC2': pc2})

    # Sort PC2 in descending order while keeping PC1 in ascending order
    grid_sorted = grid.sort_values(by=['PC1', 'PC2'], ascending=[True, False], ignore_index=True)

    # Apply inverse transformation using the provided PCA model
    microstates = pca_model.inverse_transform(grid_sorted)

    return microstates

# use 5 microstates for debugging
num_microstates = 75

pc1_hox = np.linspace(min_pc1_hox, max_pc1_hox, num_microstates)
pc2_hox = np.linspace(min_pc2_hox, max_pc2_hox, num_microstates)
pc1_hox, pc2_hox = np.meshgrid(pc1_hox, pc2_hox)
pc1_hox = pc1_hox.flatten()
pc2_hox = pc2_hox.flatten()

# Put the grid points into a DataFrame
grid_hox = pd.DataFrame({'PC1': pc1_hox, 'PC2': pc2_hox})
# Sort PC2 in descending order while keeping PC1 in ascending order
grid_hox = grid_hox.sort_values(by=['PC1', 'PC2'], ascending=[True, False], ignore_index=True)

# Infer microstates from PCA for Hox locus
microstates_hox = pca_hox.inverse_transform(grid_hox)

# Use the function I just wrote to generate microstates for the Sox locus
microstates_sox = generate_microstates(min_pc1_sox, max_pc1_sox, min_pc2_sox, max_pc2_sox, num_microstates, pca_sox)


def calculate_conformational_variance_new(dmap_list, microstates_dmap):
    # This is incorrect because it finds mean across all samples 
    """
    Calculate the conformational variation of a set of distance maps relative to a reference map.

    Parameters:
    dmap_list (list): A list of 2D numpy arrays representing the distance maps.
    dmap_ref (np.ndarray): A 2D numpy array representing the reference distance map.
    num_probes (int): The number of probes in the distance maps.

    Returns:
    np.ndarray: A 2D numpy array containing the variance of the squared Euclidean distances 
               between each distance map and the reference map.
    """
    # Convert dmap_list to a NumPy array
    dmap_list = np.array(dmap_list)
    
    num_microstates = microstates_dmap.shape[0]
    num_probes = np.round(microstates_dmap.shape[1] ** 0.5).astype(int)
    
    dmap_list = dmap_list[:, np.newaxis, :]
    microstates_dmap = microstates_dmap[np.newaxis, :, :]
    
    # Calculate the squared Euclidean distance between each distance map and the reference map
    diff_list = np.sqrt((dmap_list - microstates_dmap) ** 2)
    
    # Calculate the variance along the number of observation/cell dimension
    var = np.var(diff_list, axis=0)
    
    return np.reshape(var, (num_microstates, num_probes, num_probes))


2024-10-14 17:58:48.312013: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-14 17:58:48.694240: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX512F AVX512_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Consider mio5.varmats_from_mat to split file into single variable files
  matfile_dict = MR.get_variables(variable_names)


In [2]:
def calculate_conformational_variance_jax(dmap_list, dmap_ref):
    """
    Calculate the conformational variation of a set of distance maps relative to a reference map.

    Parameters:
    dmap_list (list): A list of 2D numpy arrays representing the distance maps.
    dmap_ref (np.ndarray): A 2D numpy array representing the reference distance map.
    num_probes (int): The number of probes in the distance maps.

    Returns:
    np.ndarray: A 2D numpy array containing the variance of the squared Euclidean distances 
               between each distance map and the reference map.
    """
    # Convert dmap_list to a NumPy array
    dmap_list = jnp.array(dmap_list)
    
    # Calculate the squared Euclidean distance between each distance map and the reference map
    diff_list = jnp.sqrt((dmap_list - dmap_ref) ** 2) 
    
    # Calculate the variance along the number of observation/cell dimension
    var = jnp.var(diff_list, axis=0)
    
    return var

# Rewrite this in a jax-compatible fashion
from functools import partial
@partial(jax.jit, static_argnums=(2,)) 
def batch_calculate_variances(dmap_list, dmap_ref, num_probes):
    """
    Vectorized version that applies calculate_conformational_variance_jax across a batch of distance maps.
    """
    return jax.vmap(lambda dmap: calculate_conformational_variance_jax(dmap_list, jnp.reshape(dmap, [num_probes, num_probes])))(dmap_ref)


In [3]:
microstates_hox_jax = jnp.array(microstates_hox)

# Calculate variances across all microstates using JAX
# Faster but might run into GPU memory issues
# this takes 1 minute
ctcfNtHox_std = batch_calculate_variances(jnp.array(ctcfNtMapsHox), microstates_hox_jax, nHox)**0.5
ctcfHox_std = batch_calculate_variances(jnp.array(ctcfMapsHox), microstates_hox_jax, nHox)**0.5
ntHox_std = batch_calculate_variances(jnp.array(ntMapsHox), microstates_hox_jax, nHox)**0.5
radNtHox_std = batch_calculate_variances(jnp.array(radNtMapsHox), microstates_hox_jax, nHox)**0.5
radHox_std = batch_calculate_variances(jnp.array(radMapsHox), microstates_hox_jax, nHox)**0.5

microstates_sox_jax = jnp.array(microstates_sox)

ctcfNtSox_std = batch_calculate_variances(jnp.array(ctcfNtMapsSox), microstates_sox_jax, nSox)**0.5
ctcfSox_std = batch_calculate_variances(jnp.array(ctcfMapsSox), microstates_sox_jax, nSox)**0.5
ntSox_std = batch_calculate_variances(jnp.array(ntMapsSox), microstates_sox_jax, nSox)**0.5
radNtSox_std = batch_calculate_variances(jnp.array(radNtMapsSox), microstates_sox_jax, nSox)**0.5
radSox_std = batch_calculate_variances(jnp.array(radMapsSox), microstates_sox_jax, nSox)**0.5

: 

In [16]:
# Calculate variances across all microstates 
# this require bigmem because it requires big memory
ctcfNtHox_std = calculate_conformational_variance_new(ctcfNtFlattenHox, microstates_hox)**0.5
ctcfHox_std = calculate_conformational_variance_new(ctcfFlattenHox, microstates_hox)**0.5
ntHox_std = calculate_conformational_variance_new(ntFlattenHox, microstates_hox)**0.5
radNtHox_std = calculate_conformational_variance_new(radNtFlattenHox, microstates_hox)**0.5
radHox_std = calculate_conformational_variance_new(radFlattenHox, microstates_hox)**0.5

ctcfNtSox_std = calculate_conformational_variance_new(ctcfNtFlattenSox, microstates_sox)**0.5
ctcfSox_std = calculate_conformational_variance_new(ctcfFlattenSox, microstates_sox)**0.5
ntSox_std = calculate_conformational_variance_new(ntFlattenSox, microstates_sox)**0.5
radNtSox_std = calculate_conformational_variance_new(radNtFlattenSox, microstates_sox)**0.5
radSox_std = calculate_conformational_variance_new(radFlattenSox, microstates_sox)**0.5

MemoryError: Unable to allocate 1.17 TiB for an array with shape (5505, 5625, 5184) and data type float64

In [5]:
lpm_hox = [(logprior(x, nHox)).tolist() for x in microstates_hox]
lpm_sox = [(logprior(x, nSox)).tolist() for x in microstates_sox]

In [6]:
def loglikelihood(observed_dmap_flat, microstates_dmap_flat, measurement_error, num_probes):
    """ 
    """
    num_microstates = microstates_dmap_flat.shape[0]
    num_observations = observed_dmap_flat.shape[0]
    
    # Append a new axis for broadcasting
    observed_dmap_flat = observed_dmap_flat[np.newaxis, :, :]
    microstates_dmap_flat = microstates_dmap_flat[:, np.newaxis, :]
    measurement_error = measurement_error[:, np.newaxis, :, :]
        
    
    # Calculate the difference between distance map and reference 
    # distance map
    subtraction_map_sq = np.square(observed_dmap_flat - microstates_dmap_flat).reshape(num_microstates, num_observations, 
                                                                      num_probes, num_probes)

    # Only consider the upper triangular part of the distance map
    triu_indices = np.triu_indices(num_probes, k=1)
    measurement_error = 2*measurement_error[:, :, triu_indices[0], triu_indices[1]]  # both triangles 
    subtraction_map_sq = 2*subtraction_map_sq[:, :, triu_indices[0], triu_indices[1]]  # both triangles
    
    # Calculate the normalization factor
    normalization_factor = -np.sum(np.log(np.sqrt(2*np.pi*measurement_error**2)), axis=-1)
    
    # Calculate the gaussian term 
    gaussian_term = -np.sum(subtraction_map_sq/(2*np.square(measurement_error)), axis=-1)
    
    # if the reference distance map is not physical ie contains negative values
    # return very low probability: the lowest number numpy can handle
    if np.any(microstates_dmap_flat <= -1):
        unphysical_microstates_indices = np.squeeze(np.any(microstates_dmap_flat < 0, axis=-1))
        # print(unphysical_microstates_indices.shape)
        normalization_factor[unphysical_microstates_indices] = np.iinfo(np.int32).min
        gaussian_term[unphysical_microstates_indices] = np.iinfo(np.int32).min
    
    # Change the dimension so it is compatible with the downstream analysis
    return np.transpose(normalization_factor + gaussian_term)

In [9]:
# Define the main loglikelihood function using JAX
def loglikelihood_jax(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    return jnp.sum(_loglikelihood_jax(dmap_flat, ref_dmap_flat, measurement_error, num_probes))

# Define the helper function, with JAX-compatible logic
def _loglikelihood_jax(dmap_flat, ref_dmap_flat, measurement_error, num_probes):
    # Use lax.cond for control flow based on the condition
    min_value = jnp.iinfo(jnp.int32).min
    
    def handle_invalid_reference(ref_dmap_flat):
        # Return extremely low probability when ref_dmap_flat contains invalid values
        return jnp.array([jnp.float32(min_value), jnp.float32(min_value)])
    
    def handle_valid_reference(ref_dmap_flat):
        # Calculate the difference between distance map and reference 
        subtraction_map_sq = jnp.square(dmap_flat - ref_dmap_flat).reshape(num_probes, num_probes)

        # Only consider the upper triangular part of the distance map
        # because the diagonal values do not have variance
        triu_indices = jnp.triu_indices(num_probes, k=1)
        measurement_error_scaled = 2 * measurement_error[triu_indices]  # both triangles 
        subtraction_map_sq_scaled = 2 * subtraction_map_sq[triu_indices]  # both triangles
        
        # Calculate the normalization factor
        normalization_factor = -jnp.sum(jnp.log(jnp.sqrt(2 * jnp.pi * measurement_error_scaled**2)))
        
        # Calculate the Gaussian term 
        gaussian_term = -jnp.sum(subtraction_map_sq_scaled / (2 * jnp.square(measurement_error_scaled)))
        
        return jnp.array([normalization_factor, gaussian_term])

    # Apply the appropriate logic depending on whether ref_dmap_flat contains negative values
    return lax.cond(
        jnp.any(ref_dmap_flat <= -1),
        handle_invalid_reference,
        handle_valid_reference,
        ref_dmap_flat
    )
    
def compute_loglikelihood_for_y(y, templates_flatten, measurement_error_esc, num_probes):
    return jax.vmap(lambda x, z: loglikelihood_jax(y, x, z, num_probes))(templates_flatten, measurement_error_esc)

# Calculate likelihood for Hox samples
ctcfNtHox_ll = [compute_loglikelihood_for_y(y, microstates_hox_jax, ctcfNtHox_std, nHox) for y in tqdm(ctcfNtFlattenHox)] 
ctcfHox_ll = [compute_loglikelihood_for_y(y, microstates_hox_jax, ctcfHox_std, nHox) for y in tqdm(ctcfFlattenHox)]
ntHox_ll = [compute_loglikelihood_for_y(y, microstates_hox_jax, ntHox_std, nHox) for y in tqdm(ntFlattenHox)]
radNtHox_ll = [compute_loglikelihood_for_y(y, microstates_hox_jax, radNtHox_std, nHox) for y in tqdm(radNtFlattenHox)]
radHox_ll = [compute_loglikelihood_for_y(y, microstates_hox_jax, radHox_std, nHox) for y in tqdm(radFlattenHox)]

# Calculate likelihood for Sox samples
ctcfNtSox_ll = [compute_loglikelihood_for_y(y, microstates_sox_jax, ctcfNtSox_std, nSox) for y in tqdm(ctcfNtFlattenSox)]
ctcfSox_ll = [compute_loglikelihood_for_y(y, microstates_sox_jax, ctcfSox_std, nSox) for y in tqdm(ctcfFlattenSox)]
ntSox_ll = [compute_loglikelihood_for_y(y, microstates_sox_jax, ntSox_std, nSox) for y in tqdm(ntFlattenSox)]
radNtSox_ll = [compute_loglikelihood_for_y(y, microstates_sox_jax, radNtSox_std, nSox) for y in tqdm(radNtFlattenSox)]
radSox_ll = [compute_loglikelihood_for_y(y, microstates_sox_jax, radSox_std, nSox) for y in tqdm(radFlattenSox)]


  0%|          | 0/5505 [00:00<?, ?it/s]

  0%|          | 0/8737 [00:00<?, ?it/s]

  0%|          | 0/15319 [00:00<?, ?it/s]

  0%|          | 0/7855 [00:00<?, ?it/s]

  0%|          | 0/7462 [00:00<?, ?it/s]

  0%|          | 0/5451 [00:00<?, ?it/s]

  0%|          | 0/7902 [00:00<?, ?it/s]

  0%|          | 0/11900 [00:00<?, ?it/s]

  0%|          | 0/4094 [00:00<?, ?it/s]

  0%|          | 0/8090 [00:00<?, ?it/s]

In [12]:
# Define the helper function, with JAX-compatible logic
# run into a memory issue 
def loglikelihood_jax_new(observed_dmap_flat, microstates_dmap_flat, measurement_error, num_probes):
    num_microstates = microstates_dmap_flat.shape[0]
    num_observations = observed_dmap_flat.shape[0]
    
    # Append a new axis for broadcasting
    observed_dmap_flat = observed_dmap_flat[None, :, :]
    microstates_dmap_flat = microstates_dmap_flat[:, None, :]
    measurement_error = measurement_error[:, None, :, :]
    
    # Calculate the difference between distance map and reference 
    # distance map
    subtraction_map_sq = jnp.square(observed_dmap_flat - microstates_dmap_flat).reshape(num_microstates, num_observations, 
                                                                      num_probes, num_probes)
    
    # Calculate the normalization factor
    normalization_factor = -jnp.sum(jnp.log(jnp.sqrt(2 * jnp.pi * measurement_error_scaled**2)), axis=-1)
    
    # Calculate the Gaussian term 
    gaussian_term = -jnp.sum(subtraction_map_sq_scaled / (2 * jnp.square(measurement_error_scaled)), axis=-1)

    # Only consider the upper triangular part of the distance map
    # because the diagonal values do not have variance
    triu_indices = jnp.triu_indices(num_probes, k=1)
    measurement_error_scaled = 2 * measurement_error[triu_indices]  # both triangles 
    subtraction_map_sq_scaled = 2 * subtraction_map_sq[triu_indices]  # both triangles
   
    # Check for unphysical microstates (i.e., negative values in the distance map)
    unphysical_microstates_indices = jnp.any(microstates_dmap_flat < 0, axis=-1)

    # Get the minimum integer value that jax.numpy can handle
    min_value = jnp.iinfo(jnp.int32).min

    # Replace the values of `normalization_factor` and `gaussian_term` for unphysical indices
    normalization_factor = jnp.where(unphysical_microstates_indices, min_value, normalization_factor)
    gaussian_term = jnp.where(unphysical_microstates_indices, min_value, gaussian_term)

    return jnp.transpose(normalization_factor + gaussian_term)

In [13]:
ctcfNtHox_ll_jn = loglikelihood_jax_new(ctcfNtFlattenHox, microstates_hox, ctcfNtHox_std, nHox)

MemoryError: Unable to allocate 1.17 TiB for an array with shape (5625, 5505, 5184) and data type float64

In [7]:
# This takes about 20 seconds to run each sample
# JAX might be faster 
# requires bigmemory 

ctcfNtHox_ll = loglikelihood(ctcfNtFlattenHox, microstates_hox, ctcfNtHox_std, nHox)
ctcfHox_ll = loglikelihood(ctcfFlattenHox, microstates_hox, ctcfHox_std, nHox)
ntHox_ll = loglikelihood(ntFlattenHox, microstates_hox, ntHox_std, nHox)
radNtHox_ll = loglikelihood(radNtFlattenHox, microstates_hox, radNtHox_std, nHox)
radHox_ll = loglikelihood(radFlattenHox, microstates_hox, radHox_std, nHox)

ctcfNtSox_ll = loglikelihood(ctcfNtFlattenSox, microstates_sox, ctcfNtSox_std, nSox)
ctcfSox_ll = loglikelihood(ctcfFlattenSox, microstates_sox, ctcfSox_std, nSox)
ntSox_ll = loglikelihood(ntFlattenSox, microstates_sox, ntSox_std, nSox)
radNtSox_ll = loglikelihood(radNtFlattenSox, microstates_sox, radNtSox_std, nSox)
radSox_ll = loglikelihood(radFlattenSox, microstates_sox, radSox_std, nSox)

MemoryError: Unable to allocate 1.17 TiB for an array with shape (5625, 5505, 5184) and data type float64

In [10]:
# Convert all jnp arrays to list 
ctcfNtHox_ll = [x.tolist() for x in ctcfNtHox_ll]
ctcfHox_ll = [x.tolist() for x in ctcfHox_ll]
ntHox_ll = [x.tolist() for x in ntHox_ll]
radNtHox_ll = [x.tolist() for x in radNtHox_ll]
radHox_ll = [x.tolist() for x in radHox_ll]

ctcfNtSox_ll = [x.tolist() for x in ctcfNtSox_ll]
ctcfSox_ll = [x.tolist() for x in ctcfSox_ll]
ntSox_ll = [x.tolist() for x in ntSox_ll]
radNtSox_ll = [x.tolist() for x in radNtSox_ll]
radSox_ll = [x.tolist() for x in radSox_ll]

# Calculate the number of cells in each dataset
N_ctcfNtHox = ctcfNtMapsHox.shape[0]
N_ctcfHox = ctcfMapsHox.shape[0]
N_ntHox = ntMapsHox.shape[0]
N_radNtHox = radNtMapsHox.shape[0]
N_radHox = radMapsHox.shape[0]

N_ctcfNtSox = ctcfNtMapsSox.shape[0]
N_ctcfSox = ctcfMapsSox.shape[0]
N_ntSox = ntMapsSox.shape[0]
N_radNtSox = radNtMapsSox.shape[0]
N_radSox = radMapsSox.shape[0]

M = num_microstates**2  # Number of microstates

# Load stan model 
my_model = CmdStanModel(
    stan_file='/mnt/home/tudomlumleart/ceph/01_ChromatinEnsembleRefinement/chromatin-ensemble-refinement/scripts/stan/20240715_WeightOptimization.stan',
    cpp_options = {
        "STAN_THREADS": True,
    }
    )

n_cores = multiprocessing.cpu_count()
print(f"Number of CPU cores: {n_cores}")
parallel_chains = 4
threads_per_chain = int(n_cores / parallel_chains)
print(f"Number of threads per chain: {threads_per_chain}")

Number of CPU cores: 64
Number of threads per chain: 16


In [11]:
save_dir = '/mnt/home/tudomlumleart/ceph/01_ChromatinEnsembleRefinement/chromatin-ensemble-refinement/MCMC_results/20240930_RunWeightMCMC_Tonia_PCA_2'
conditions = [
    'ctcfNtHox', 'ctcfHox', 'ntHox', 'radNtHox', 'radHox',
    'ctcfNtSox', 'ctcfSox', 'ntSox', 'radNtSox', 'radSox'
]

for condition in tqdm(conditions):
    output_dir = os.path.join(save_dir, condition)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    json_filename = os.path.join(output_dir, 'data.json')
    stan_putput_file = os.path.join(output_dir, 'stan_output')
    
    # Generate a datadict for the current condition
    if condition == 'ctcfNtHox':
        data_dict = {
            'N': N_ctcfNtHox,
            'M': M,
            'll_map': ctcfNtHox_ll,
            'lpm_vec': lpm_hox
        }
    elif condition == 'ctcfHox':
        data_dict = {
            'N': N_ctcfHox,
            'M': M,
            'll_map': ctcfHox_ll,
            'lpm_vec': lpm_hox
        }
    elif condition == 'ntHox':
        data_dict = {
            'N': N_ntHox,
            'M': M,
            'll_map': ntHox_ll,
            'lpm_vec': lpm_hox
        }
    elif condition == 'radNtHox':
        data_dict = {
            'N': N_radNtHox,
            'M': M,
            'll_map': radNtHox_ll,
            'lpm_vec': lpm_hox
        }
    elif condition == 'radHox':
        data_dict = {
            'N': N_radHox,
            'M': M,
            'll_map': radHox_ll,
            'lpm_vec': lpm_hox
        }
    elif condition == 'ctcfNtSox':
        data_dict = {
            'N': N_ctcfNtSox,
            'M': M,
            'll_map': ctcfNtSox_ll,
            'lpm_vec': lpm_sox
        }
    elif condition == 'ctcfSox':
        data_dict = {
            'N': N_ctcfSox,
            'M': M,
            'll_map': ctcfSox_ll,
            'lpm_vec': lpm_sox
        }
    elif condition == 'ntSox':
        data_dict = {
            'N': N_ntSox,
            'M': M,
            'll_map': ntSox_ll,
            'lpm_vec': lpm_sox
        }
    elif condition == 'radNtSox':
        data_dict = {
            'N': N_radNtSox,
            'M': M,
            'll_map': radNtSox_ll,
            'lpm_vec': lpm_sox
        }
    elif condition == 'radSox':
        data_dict = {
            'N': N_radSox,
            'M': M,
            'll_map': radSox_ll,
            'lpm_vec': lpm_sox
        }
    
    json_obj = json.dumps(data_dict, indent=4)
    
    with open(json_filename, 'w') as json_file:
        json_file.write(json_obj)
        json_file.close()
    
    
    

  0%|          | 0/10 [00:00<?, ?it/s]