In [1]:
import gc
import sys
sys.path.append('..')

import utils, selex_dca, indep_sites
import adabmDCA
import selex_distribution, energy_models, tree, data_loading, training, callback, sampling

import torch
from utils import one_hot
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

  from tqdm.autonotebook import tqdm


In [2]:
# Experiment and sequencing rounds to load. The order matters: `sequences` and
# `sequences_oh` are built in this order and later cells index into them.
experiment_id = "Dop8V030"
round_ids = ["ARN", "R01", "R01CS", "R02N", "R02F", "R02S", "R02SF"]

# Seed torch's global RNG so the random parameter initialisations later in the
# notebook (torch.randn / torch.rand) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED);

# Floating-point precision used for one-hot encodings and model parameters.
dtype = torch.float32

In [3]:
# Read every round listed in `round_ids` from disk onto the CPU, printing a
# progress line after each round.
sequences = []
for round_id in round_ids:
    loaded = utils.sequences_from_file(
        experiment_id,
        round_id,
        device=torch.device("cpu"),
    )
    sequences.append(loaded)
    print(f"finished {round_id}")

finished ARN
finished R01
finished R01CS
finished R02N
finished R02F
finished R02S
finished R02SF


In [4]:
# One-hot encode each round (same order as `sequences`), cast to the working
# dtype, and keep the result on the CPU.
cpu = torch.device("cpu")
sequences_oh = [
    one_hot(round_sequences).to(dtype=dtype, device=cpu)
    for round_sequences in sequences
]

In [5]:
# Number of reads in every loaded round (all of `round_ids`, in order).
total_reads = torch.Tensor([len(round_oh) for round_oh in sequences_oh])
# Statistics of round 0; only the first returned value (presumably the
# single-site frequencies) is kept.
fi0, _, _ = utils.frequences_from_sequences_oh(sequences_oh[0])

In [6]:
# Selection tree over the rounds used for training. Presumably parent -1 is the
# initial library (ARN) and nodes are numbered in insertion order, so R01 is
# node 0 and the four R02 variants are its children (TODO confirm against
# tree.Tree.add_node).
tr = tree.Tree()
tr.add_node(-1, name="R01")
tr.add_node(0, name="R02N")
tr.add_node(0, name="R02F")
tr.add_node(0, name="R02S")
tr.add_node(0, name="R02SF")

# Rounds fed to training, root first, then in tree-node order. The positions in
# `round_ids` are looked up by name instead of being hard-coded, so editing
# `round_ids` cannot silently pair a tree node with the wrong sequencing round.
# R01CS is loaded but not used here.
picked_round_names = ["ARN", "R01", "R02N", "R02F", "R02S", "R02SF"]
picked_round_idx = [round_ids.index(name) for name in picked_round_names]

mode_names = ["NA, NW", "FA, NW", "NA, HW", "FA, HW"]

# Boolean (selection round x mode) mask: one row per selection round (the
# picked rounds minus the root), one column per mode. Presumably a True entry
# means that mode is the one selected in that round.
selected_modes = torch.tensor(
    [[1, 0, 0, 0],
     [1, 0, 0, 0],
     [0, 1, 0, 0],
     [0, 0, 1, 0],
     [0, 0, 0, 1]],
    dtype=torch.bool,
)

n_selection_rounds, n_modes = selected_modes.size()
assert n_selection_rounds == len(picked_round_idx) - 1, \
    "selected_modes needs one row per picked round after the root"
assert len(mode_names) == n_modes, \
    "mode_names needs one entry per column of selected_modes"
# Total number of rounds, including the root (initial library).
n_rounds = n_selection_rounds + 1

In [8]:
# Presumably sequence length L and alphabet size q, read from the first one-hot
# sequence of round 0.
L, q = sequences_oh[0][0].shape

# Random initial field for the round-0 model.
k = torch.randn(L, q, dtype=dtype)
# One independent random field per independent-sites mode. Previously all modes
# were built from the same tensor `h`, so they all started from the same point.
# If IndepSites wraps the tensor without copying (as nn.Parameter does), they
# also shared storage and could never diverge.
indep_modes = [
    energy_models.IndepSites(torch.randn(L, q, dtype=dtype))
    for _ in range(n_modes - 1)
]

Ns0 = energy_models.IndepSites(k)
# Sequence-independent constant-energy mode, passed last to MultiModeDistribution.
# NOTE(review): with n_modes - 1 independent modes plus this one, it presumably
# occupies the last column of `selected_modes`. Confirm that the round
# selecting that column is meant to use a constant-energy mode.
unbound_mode = energy_models.ConstantEnergy(-10)

ps = selex_distribution.MultiModeDistribution(*indep_modes, unbound_mode, normalized=True)
model = selex_distribution.MultiRoundDistribution(
    Ns0, ps, tr, selected_modes,
    learn_selection_strength=True,
    # One random initial strength per selection round.
    selection_strength=torch.rand(n_selection_rounds, dtype=dtype),
)

In [None]:
# Train on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
batch_size = 1_000_000
# One data loader per picked round, in the same order as `picked_round_idx`,
# each constructed with the training device.
data_loaders = [
    data_loading.SelexRoundDataLoader(
        sequences_oh[round_idx],
        batch_size=batch_size,
        device=device,
    )
    for round_idx in picked_round_idx
]

In [None]:
n_chains = 10_000

# Chains presumably hold n_chains sequences (L x q) per round. log_weights
# holds one all-zero row of n_chains entries per round.
chains = training.init_chains(n_rounds, n_chains, L, q, dtype=dtype)
log_weights = torch.zeros((n_rounds, n_chains), dtype=dtype)

In [None]:
# Convergence metrics recorded during training; callbacks[0] is plotted after training.
convergence_callback = callback.ConvergenceMetricsCallback()
callbacks = [convergence_callback]

In [None]:
# Move everything the training loop touches onto `device`. If `model` is a
# torch.nn.Module (it exposes .parameters()), .to moves it in place and returns
# the same object, so `model_device` is `model`.
model_device = model.to(device)
chains_device = chains.to(device)
# Read counts must line up one-to-one with `data_loaders`, i.e. cover only the
# picked rounds. The earlier `total_reads` covers every loaded round, including
# unpicked ones such as R01CS, which shifts every later round by one.
# Building the counts from `sequences_oh` keeps this cell safe to re-run.
total_reads_device = torch.tensor(
    [sequences_oh[i].shape[0] for i in picked_round_idx],
    dtype=dtype,
    device=device,
)
assert len(total_reads_device) == len(data_loaders) == n_rounds, \
    "need one read count and one data loader per picked round"
log_weights_device = log_weights.to(device)
# Keep this learning rate in sync with `lr` in the training cell (both 0.01).
optimizer = torch.optim.Adam(model_device.parameters(), lr=0.01)

In [None]:
# Training hyper-parameters.
max_epochs = 500
n_sweeps = 10
lr = 0.01
target_pearson = 1

training.train(
    model_device, data_loaders, total_reads_device, chains_device,
    n_sweeps, lr, max_epochs, target_pearson,
    callbacks=callbacks,
    log_weights=log_weights_device,
    optimizer=optimizer,
)

In [None]:
# Bring the trained state back to the CPU and drop the device-side references.
cpu = torch.device('cpu')
model = model_device.to(cpu)
chains = chains_device.to(cpu)
total_reads = total_reads_device.to(cpu)
log_weights = log_weights_device.to(cpu)

del model_device, chains_device, total_reads_device, log_weights_device
# Collect garbage first, so tensors kept alive only by reference cycles are
# freed before the CUDA caching allocator is asked to release unused blocks.
# Running empty_cache() first would leave those blocks cached.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Plot the convergence metrics recorded during training.
convergence_metrics = callbacks[0]
convergence_metrics.plot();

In [None]:
# Learned selection strengths (initialised with one value per selection round).
selection_strength = model.selection_strength
selection_strength