In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import adabmDCA

from adabmDCA.utils import get_device, get_dtype, get_mask_save
from adabmDCA.sampling import get_sampler

import sys
sys.path.append('..')
import selex_distribution, energy_models, tree, data_loading, training, callback, sampling
import selex_dca, utils

  from tqdm.autonotebook import tqdm


In [2]:
# Re-import the local `utils` module so edits to ../utils.py take effect
# without restarting the kernel. The result is assigned rather than left as
# the cell's last expression, so the cell no longer echoes the module repr
# (which exposes the absolute local path of the checkout).
from importlib import reload

utils = reload(utils)

<module 'utils' from '/home/scrotti/Aptamer2025py/experiments/../utils.py'>

In [3]:
# Sequencing rounds in selection order: the input library ('Input_...') first,
# then the successive selection rounds.
round_ids = ['Input_R1_N', 'OplusR1_N', 'OplusR2_N']

dtype = torch.float32

# Chain initialization, sampling and training may draw random numbers; fix the
# RNG seeds so that Restart & Run All reproduces the same run.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [4]:
# Raw sequences of every round, kept in the same order as `round_ids`.
sequences = [utils.sequences_from_file_ab6(rid) for rid in round_ids]

In [5]:
# One-hot encode each round with a 21-letter alphabet.
sequences_oh = [utils.one_hot(seq, num_classes=21) for seq in sequences]

# Number of reads in each round. `torch.tensor(..., dtype=dtype)` replaces the
# legacy `torch.Tensor(...)` constructor, which always produced the default
# float type and so ignored the notebook-wide `dtype` setting.
total_reads = torch.tensor([s.shape[0] for s in sequences_oh], dtype=dtype)

In [6]:
# Round tree: the selection rounds form a linear chain after the input
# library. The first argument of `add_node` is the parent index (-1 for the
# first selected round, then 0, 1, ...), exactly as in the hand-written
# version. Deriving the nodes from `round_ids` keeps the tree in sync with the
# rounds that were actually loaded.
tr = tree.Tree()
for parent, round_name in enumerate(round_ids[1:], start=-1):
    tr.add_node(parent, name=round_name)

# One selection mode, switched on in every selected round.
# Presumably shaped (selected rounds, modes) — TODO confirm against MultiRoundDistribution.
n_selected_rounds = len(round_ids) - 1
selected_modes = torch.ones((n_selected_rounds, 1), dtype=torch.bool)

# Sequence length and alphabet size, read from the first sequence of the first round.
L, q = sequences_oh[0][0].shape
assert all(s[0].shape == (L, q) for s in sequences_oh), "all rounds must share (L, q)"

# Zero-initialized parameters: `k` for the independent-site model, and `h`, `J`
# for the Potts model.
k = torch.zeros(L, q, dtype=dtype)
h = torch.zeros(L, q, dtype=dtype)
J = torch.zeros(L, q, L, q, dtype=dtype)

Ns0 = energy_models.IndepSites(k)
potts = energy_models.Potts(J, h)

ps = selex_distribution.MultiModeDistribution(potts, normalized=False)
model = selex_distribution.MultiRoundDistribution(Ns0, ps, tr, selected_modes)

In [7]:
# One data loader per round, in the same order as `sequences_oh`.
batch_size = 1_000_000
data_loaders = [
    data_loading.SelexRoundDataLoader(round_oh, batch_size=batch_size)
    for round_oh in sequences_oh
]
n_rounds = len(data_loaders)

In [8]:
# Sampling chains for every round, plus one zero-initialized log-weight per chain.
n_chains = 10_000
chains = training.init_chains(n_rounds, n_chains, L, q, dtype=dtype)
log_weights = torch.zeros((n_rounds, n_chains), dtype=dtype)

In [9]:
# Callbacks invoked by `training.train`; later cells plot them by index.
callbacks = [
    callback.ConvergenceMetricsCallback(),  # index 0
    callback.PearsonCovarianceCallback(),   # index 1
]

In [None]:
# Training hyper-parameters, passed positionally to `training.train` below.
n_sweeps = 10          # sampler sweeps
lr = 0.01              # learning rate
target_pearson = 1     # Pearson target handed to the trainer
max_epochs = 2_000     # upper bound on training epochs

training.train(
    model, data_loaders, total_reads, chains,
    n_sweeps, lr, max_epochs, target_pearson,
    callbacks=callbacks,
    log_weights=log_weights,
)

 0.00%[                                          ] Epoch: 0/2000 [00:00, ?it/s]

In [None]:
# Convergence metrics recorded during training (first callback).
convergence_callback = callbacks[0]
convergence_callback.plot();

In [None]:
# Pearson curve recorded during training, with the training target as a dashed
# reference line. The line uses `target_pearson` instead of a hard-coded 1, so
# it follows whatever target was passed to `training.train`. The trailing `;`
# suppresses the Line2D repr in the cell output.
fig, ax = callbacks[1].plot()
ax.axhline(target_pearson, color='r', linestyle='--');

In [None]:
from IPython.display import display, Latex

# Move the learned Potts model to the zero-sum gauge, then take detached
# copies of its couplings and fields for plotting.
potts_zerosum = potts.set_zerosum_gauge()
J = potts_zerosum.J.detach()
h = potts_zerosum.h.detach()

display(Latex("$F_{ij}=\\sqrt {\\sum_{ab}(J_{ij}^{ab})^2}$"))

# Contact map of the couplings. F is indexed by (i, j), so the x axis is
# labeled i and the y axis j. The label previously shown via `print` now
# appears as the figure title.
fig, ax = plt.subplots(figsize=(3, 3))
F = selex_dca.get_contact_map(J)
im = ax.imshow(F)
ax.set(xlabel="i", ylabel="j", title="Multi-round")
fig.colorbar(im, ax=ax)
fig.tight_layout()
plt.show()

In [None]:
# Squared norm of each row of the fields `h` (one value per sequence position
# if `h` is (L, q)); `h` is whatever the previous cell last assigned.
h.pow(2).sum(dim=1)