In [1]:
from _utils import *
import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))  # This prevents crash on GPU

2024-10-21 17:14:40.341874: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-21 17:14:40.699937: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX512F AVX512_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
# Location of the MATLAB .mat file that is loaded in the next cell.
dataset_path = (
    '/mnt/home/tudomlumleart/ceph/15_SimonDataset/'
    'df20240311_posrate30-50p.mat'
)

In [3]:
# Read the .mat file at `dataset_path`; scipy.io.loadmat returns a dict keyed by
# MATLAB variable name (the 'df' entry is the one consumed downstream).
dataset = scipy.io.loadmat(dataset_path)

In [4]:
# Unpack the first four columns of the 'df' entry; each column holds its payload
# behind two levels of single-element nesting, hence the trailing [0][0].
df_struct = dataset['df']
dataset_list = [df_struct[:, column][0][0] for column in range(4)]

In [5]:
# Fill in the missing data of every sample with interpolate_polymers,
# keeping the samples in the same order as dataset_list.
dataset_polys = [interpolate_polymers(raw_sample) for raw_sample in dataset_list]

In [6]:
# One collection of distance maps per sample, in the same order as dataset_polys.
dataset_maps = [calculate_distance_map(polys) for polys in dataset_polys]

# Flatten every map to 1-D and record, for each map, the index of the sample it
# came from (stored as a string label).
dataset_maps_flat = []
label_list = []
for sample_idx, sample_maps in enumerate(dataset_maps):
    flat_maps = [dist_map.flatten() for dist_map in sample_maps]
    dataset_maps_flat.append(flat_maps)
    label_list += [str(sample_idx)] * len(flat_maps)

In [7]:
# Pool all samples along the first axis so that row k of either array
# corresponds to label_list[k].
dataset_maps_all = np.vstack(dataset_maps)
dataset_maps_flat_all = np.vstack(dataset_maps_flat)

In [8]:
# Sanity check: (number of pooled maps, flattened map length).
dataset_maps_flat_all.shape

(16730, 2601)

In [9]:
# Sanity check: there must be exactly one label per pooled map
# (i.e. this should equal dataset_maps_flat_all.shape[0]).
len(label_list)

16730

In [10]:
from _reweight import reweight_samples

In [11]:
# Root directory under which reweight_samples writes one sub-directory per label.
save_dir = (
    '/mnt/home/tudomlumleart/ceph/01_ChromatinEnsembleRefinement/'
    'chromatin-ensemble-refinement/MCMC_results/'
    '20241021_RunWeightMCMC_Simon_PCA'
)

In [14]:
def reweight_samples(
    distance_map_list,
    distance_map_flat_list,
    sample_labels,
    num_microstates,
    save_dir,
    method='PCA',
    slurm_file=None):
    """Write per-label Stan input files for microstate reweighting and submit the MCMC jobs.

    Steps:
      1. Fit a 2-component PCA on all flattened distance maps and project every
         labelled group into that space.
      2. Build a grid of microstates spanning the observed PC1/PC2 range via
         ``generate_microstates``.
      3. For every label, compute the per-observation log-likelihood against
         each microstate (``batch_calculate_variances`` /
         ``compute_loglikelihood_for_y``) and the microstate log-prior
         (``logprior``).
      4. Write ``<save_dir>/<label>/data.json`` for every label and call
         ``submit_mcmc_slurm(save_dir, slurm_file)``.

    NOTE(review): this notebook also imports ``reweight_samples`` from
    ``_reweight`` in an earlier cell; this definition shadows that import.

    Parameters
    ----------
    distance_map_list : array
        3-D array indexed as ``[observation, :, :]``; ``shape[1]`` is used as
        the number of probes.
    distance_map_flat_list : array
        2-D array indexed as ``[observation, :]``; the flattened counterpart of
        ``distance_map_list`` with rows in the same order.
    sample_labels : sequence
        One label per observation; observations are grouped by unique label.
    num_microstates : int
        Requested number of microstates. The grid uses
        ``round(sqrt(num_microstates))`` states per axis, so the number that is
        actually generated can differ when this is not a perfect square; the
        generated count is what is written to the Stan data.
    save_dir : str
        Directory under which one sub-directory per label is created.
    method : str, default 'PCA'
        Microstate generation method. Only ``'PCA'`` is implemented.
    slurm_file : str or None, default None
        Slurm script handed to ``submit_mcmc_slurm``; a project default is used
        when ``None``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``method`` is not ``'PCA'``.
    """
    # Fail fast: previously any other method fell through and crashed later with
    # a NameError because the microstates were never generated.
    if method != 'PCA':
        raise ValueError(
            f"Unsupported method {method!r}; only 'PCA' is implemented.")

    num_probes = distance_map_list.shape[1]
    sample_labels = np.array(sample_labels)
    unique_labels = np.unique(sample_labels)

    if slurm_file is None:
        slurm_file = '/mnt/home/tudomlumleart/ceph/01_ChromatinEnsembleRefinement/chromatin-ensemble-refinement/scripts/slurm/2024_RunPythonScript.sh'

    print('PCA Fitting...')
    pca = PCA(n_components=2)
    pca.fit(distance_map_flat_list)

    # Project each labelled group separately so the label can be attached.
    df_sample_list = []
    for label in unique_labels:
        projected = pca.transform(
            distance_map_flat_list[sample_labels == label, :])
        df_sample = pd.DataFrame(projected, columns=['PC1', 'PC2'])
        df_sample['label'] = label
        df_sample_list.append(df_sample)
    df_samples = pd.concat(df_sample_list, axis=0)

    min_pc1 = df_samples['PC1'].min()
    max_pc1 = df_samples['PC1'].max()
    min_pc2 = df_samples['PC2'].min()
    max_pc2 = df_samples['PC2'].max()

    # The microstate grid is square, so only the per-axis count is passed on.
    num_microstate_per_axis = np.round(np.sqrt(num_microstates), 0).astype(int)

    microstate_distance_maps = generate_microstates(
        min_pc1, max_pc1, min_pc2, max_pc2, num_microstate_per_axis, pca)

    print('Calculating likelihood...')
    microstate_distance_maps_jnp = jnp.array(microstate_distance_maps)
    print(microstate_distance_maps_jnp.shape)

    # Use the number of microstates that were really generated so that the 'M'
    # written to the Stan data always matches the length of 'lpm_vec'.
    num_generated = int(microstate_distance_maps_jnp.shape[0])
    if num_generated != num_microstates:
        print(f'Note: requested {num_microstates} microstates, '
              f'generated {num_generated}.')

    sample_ll = []
    sample_num = []
    for label in unique_labels:
        curr_condition = jnp.array(
            distance_map_list[sample_labels == label, :, :])
        print(curr_condition.shape)
        # batch_calculate_variances returns variances; the likelihood helper is
        # given their square root.
        curr_std = batch_calculate_variances(curr_condition,
                                             microstate_distance_maps_jnp,
                                             num_probes) ** 0.5

        curr_ll = [
            compute_loglikelihood_for_y(
                y.flatten(), microstate_distance_maps_jnp,
                curr_std, num_probes).tolist()
            for y in tqdm(curr_condition)
        ]

        sample_num.append(curr_condition.shape[0])
        sample_ll.append(curr_ll)

    lpm = [(logprior(x, num_probes)).tolist() for x in microstate_distance_maps]

    # Instantiate the Stan model with threading enabled. The object itself is
    # not used below; presumably this is done so the model is compiled before
    # the Slurm jobs start -- TODO confirm.
    CmdStanModel(
        stan_file='/mnt/home/tudomlumleart/ceph/01_ChromatinEnsembleRefinement/chromatin-ensemble-refinement/scripts/stan/20240715_WeightOptimization.stan',
        cpp_options={
            "STAN_THREADS": True,
        }
    )

    print('Saving data...')
    # Wrap the labels (not the enumerate object) so tqdm knows the total.
    for i, label in enumerate(tqdm(unique_labels)):
        # str() keeps os.path.join working for non-string labels.
        output_dir = os.path.join(save_dir, str(label))
        os.makedirs(output_dir, exist_ok=True)

        data_dict = {
            'N': sample_num[i],
            'M': num_generated,
            'll_map': sample_ll[i],
            'lpm_vec': lpm
        }

        json_filename = os.path.join(output_dir, 'data.json')
        with open(json_filename, 'w') as f:
            f.write(json.dumps(data_dict, indent=4))

    print('Submitting slurm jobs...')
    submit_mcmc_slurm(save_dir, slurm_file)

In [15]:
# Run the full reweighting pipeline on the pooled data with a 75 x 75
# microstate grid; this writes the Stan inputs and submits the Slurm jobs.
reweight_samples(
    distance_map_list=dataset_maps_all,
    distance_map_flat_list=dataset_maps_flat_all,
    sample_labels=label_list,
    num_microstates=75 ** 2,
    save_dir=save_dir,
)

PCA Fitting...
Calculating likelihood...
(5625, 2601)
(3107, 51, 51)


  0%|          | 0/3107 [00:00<?, ?it/s]

(4574, 51, 51)


  0%|          | 0/4574 [00:00<?, ?it/s]

(4760, 51, 51)


  0%|          | 0/4760 [00:00<?, ?it/s]

(4289, 51, 51)


  0%|          | 0/4289 [00:00<?, ?it/s]

Saving data...


0it [00:00, ?it/s]

Submitting slurm jobs...
Submitting slurm job for 3
Submitted batch job 4082521
Submitting slurm job for 2
Submitted batch job 4082522
Submitting slurm job for 1
Submitted batch job 4082523
Submitting slurm job for 0
Submitted batch job 4082525
