In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import adabmDCA

from adabmDCA.utils import get_device, get_dtype, get_mask_save
from adabmDCA.sampling import get_sampler

import sys
sys.path.append('..')
import selex_distribution, energy_models, tree, data_loading, training, callback, sampling
import selex_dca, utils

  from tqdm.autonotebook import tqdm


In [2]:
# Identifiers of the experimental rounds to load; they are read in this order.
round_ids = [
    'Input_R1_N',
    'OplusR1_N',
    'OplusR2_N',
]

# Floating-point precision used for the model parameters, chains and log-weights.
dtype = torch.float32

In [3]:
# Load the sequences of every round, one list entry per id in `round_ids` (same order).
sequences = [utils.sequences_from_file_ab6(round_id) for round_id in round_ids]

In [4]:
# One-hot encode the sequences of each round over 21 classes.
sequences_oh = [utils.one_hot(seq, num_classes=21) for seq in sequences]

# Number of reads per round, taken from the first axis of each encoded round.
# `torch.tensor` with an explicit dtype replaces the legacy `torch.Tensor(...)`
# constructor, so the floating-point type follows the notebook-wide `dtype`
# instead of whatever the global default tensor type happens to be.
total_reads = torch.tensor([s.shape[0] for s in sequences_oh], dtype=dtype)

In [25]:
# --- Round structure ---------------------------------------------------------
# Register the two selected rounds on the tree, in this order.
# NOTE(review): the first argument of `add_node` is presumably the index of the
# parent round (-1 meaning "no parent") — confirm in tree.py.
tr = tree.Tree()
for parent, round_name in ((-1, 'OplusR1_N'), (0, 'OplusR2_N')):
    tr.add_node(parent, name=round_name)

# Boolean table of shape (2, 1), all True.
# NOTE(review): presumably rows are rounds and the single column is the one
# selection mode — confirm against MultiRoundDistribution.
selected_modes = torch.tensor([[True], [True]])

# --- Parameters --------------------------------------------------------------
# Dimensions of a single one-hot encoded sequence from the first round.
L, q = sequences_oh[0][0].shape
site_shape = (L, q)

# All parameters start from zero.
k = torch.zeros(site_shape, dtype=dtype)
h = torch.zeros(site_shape, dtype=dtype)
J = torch.zeros(site_shape + site_shape, dtype=dtype)

# --- Energy models and full multi-round distribution -------------------------
Ns0 = energy_models.IndepSites(k)
potts = energy_models.Potts(J, h)

ps = selex_distribution.MultiModeDistribution(potts, normalized=False)
model = selex_distribution.MultiRoundDistribution(Ns0, ps, tr, selected_modes)

In [26]:
# Device on which the data batches, chains and log-weights are allocated.
# The default keeps everything on the CPU; set `use_cuda = True` to use a GPU
# when one is available (falls back to the CPU otherwise).
# NOTE(review): the model parameters are created without an explicit device and
# are not visibly moved to `device` in this notebook — confirm before enabling CUDA.
use_cuda = False
device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')

In [27]:
# Batch size passed to every per-round loader.
batch_size = 1_000_000

# One loader per round, built from that round's one-hot encoded sequences.
data_loaders = []
for seq_oh in sequences_oh:
    loader = data_loading.SelexRoundDataLoader(seq_oh, batch_size=batch_size, device=device)
    data_loaders.append(loader)

n_rounds = len(data_loaders)

In [28]:
# Number of chains requested from `init_chains`.
n_chains = 100_000

# Initial chains, moved to the target device.
chains = training.init_chains(n_rounds, n_chains, L, q, dtype=dtype).to(device)

# Log-weights start at zero: one value per (round, chain), allocated directly on the device.
log_weights = torch.zeros((n_rounds, n_chains), dtype=dtype, device=device)

In [29]:
from importlib import reload
# Development helper: re-import a local module after editing it on disk so the
# change is picked up without restarting the kernel. Only `callback` is
# reloaded here; the other local modules (utils, training, sampling,
# selex_distribution, energy_models, data_loading) can be reloaded the same way.
# Objects created before a reload keep their old class definitions, so cells
# that instantiate them must be re-run afterwards.
reload(callback)

  # ax.plot(self.grad_median, label='median abs')


<module 'callback' from '/home/scrotti/Aptamer2025py/experiments/../callback.py'>

In [30]:
# Callbacks handed to the training loop; later cells index this list by
# position, so the order matters.
callbacks = [
    callback.ConvergenceMetricsCallback(),
    callback.PearsonCovarianceCallback(),
    callback.ParamsCallback(save_every=100),
]

In [None]:
# Training hyper-parameters.
lr = 0.01
n_sweeps = 10
max_epochs = 3000
# NOTE(review): presumably training stops early once the Pearson metric reaches
# this value; with a target of 1 it likely runs for all `max_epochs` — confirm
# in training.train.
target_pearson = 1

training.train(
    model,
    data_loaders,
    total_reads,
    chains,
    n_sweeps,
    max_epochs,
    target_pearson,
    callbacks=callbacks,
    log_weights=log_weights,
    lr=lr,
)

 0.00%[                                                                                                       …

In [None]:
# Plot what the first callback (ConvergenceMetricsCallback) recorded during
# training; the trailing `;` hides the return value's repr.
callbacks[0].plot();

In [None]:
# Detailed Pearson plot from the first callback (ConvergenceMetricsCallback).
callbacks[0].plot_pearson_detail(figsize=(8,3))

In [None]:
# Plot from the second callback (PearsonCovarianceCallback); the returned
# figure and axes are kept for further tweaking.
fig, ax = callbacks[1].plot(figsize=(8,3))

In [None]:
# NOTE(review): presumably a scatter comparison of data moments against the
# moments of the sampled chains — confirm in training.scatter_moments.
# The trailing `;` hides the return value's repr.
training.scatter_moments(model, data_loaders, chains, total_reads, figsize=(10,3));

In [None]:
from IPython.display import display, Latex

# Take the Potts parameters in the zero-sum gauge and detach them from autograd.
potts_zerosum = potts.set_zerosum_gauge()
J = potts_zerosum.J.detach()
h = potts_zerosum.h.detach()

# Contact map computed from the couplings by `selex_dca.get_contact_map`.
# NOTE(review): per the formula kept in the comment at the end of this cell,
# this is presumably the Frobenius norm of each (i, j) coupling block —
# confirm in selex_dca.
pl, ax = plt.subplots(figsize=(3,3))
F = selex_dca.get_contact_map(J)
im = ax.imshow(F)
# The map is indexed by a pair of positions (i, j): imshow puts the column
# index on the x axis and the row index on the y axis. The matching label for
# each axis depends on how get_contact_map orders the two positions.
ax.set_xlabel("i"); ax.set_ylabel("j")
# Set the title on the axes before tight_layout so the layout accounts for it.
ax.set_title('Contact map')
plt.colorbar(im)
plt.tight_layout()
# display(Latex("$F_{ij}=\\sqrt {\\sum_{ab}(J_{ij}^{ab})^2}$"))

In [None]:
# Sum of the squared fields along axis 1 (one value per entry of axis 0).
h.pow(2).sum(dim=1)

In [None]:
import logomaker
import pandas as pd

# Sequence logo of the fields `h`, with one column per token of
# `utils.TOKENS_PROTEIN`.
# The tensor is converted to a NumPy array explicitly so pandas builds a plain
# numeric frame regardless of pandas version, and so the cell also works when
# the tensor lives on a non-CPU device.
h_table = pd.DataFrame(h.cpu().numpy(), columns=list(utils.TOKENS_PROTEIN))
logomaker.Logo(h_table)