Try with a Potts model instead of independent sites

In [1]:
import sys
sys.path.append('..')

import utils, selex_dca, indep_sites
import adabmDCA
import selex_distribution, energy_models, tree, data_loading, training, callback, sampling

import torch
from utils import one_hot
import matplotlib.pyplot as plt
import random

  from tqdm.autonotebook import tqdm


In [2]:
# Sequencing rounds to load, in the order they are indexed by the rest of the
# notebook, and the experiment they belong to.
round_ids = ["ARN", "R01", "R01CS", "R02N"]
experiment_id = "Dop8V030"

# Floating-point precision shared by the tensors created below.
dtype = torch.float32

In [3]:
# Read every round from disk, in the order given by round_ids. The result of each
# read goes straight into the list (no temporary name), so the list stays the only
# holder of the raw data.
sequences = list()
for round_id in round_ids:
    sequences += [utils.sequences_from_file(experiment_id, round_id)]
    print(f"Finished round {round_id}")

Finished round ARN
Finished round R01
Finished round R01CS
Finished round R02N


In [4]:
# One-hot encode each round, then cast it to the working precision.
sequences_oh = [
    one_hot(round_seqs).to(dtype=dtype)
    for round_seqs in sequences
]

In [5]:
# Empty the list of raw sequences now that the one-hot encoded copies exist;
# the cells below only read sequences_oh.
sequences.clear()

In [6]:
# Number of reads in each round, taken from the first dimension of every one-hot tensor.
total_reads = torch.Tensor([round_oh.shape[0] for round_oh in sequences_oh])

# Statistics of the first round; only the first of the three returned values is kept.
fi0, _, _ = utils.frequences_from_sequences_oh(sequences_oh[0])

In [7]:
# Sequence length and alphabet size, read off a single one-hot encoded sequence.
L, q = sequences_oh[0][0].shape

# Three independent (L, q) parameter tensors initialised with small Gaussian noise.
# The generator is consumed left to right, so the draws happen in the order
# k, h_beads, h.
k, h_beads, h = (torch.randn(L, q, dtype=dtype) * 1e-4 for _ in range(3))

# Tree of rounds: three nodes added with arguments -1, -1 and 0.
# presumably the argument is the parent round (-1 = the initial library) — TODO confirm
tr = tree.Tree()
for node_arg in (-1, -1, 0):
    tr.add_node(node_arg)

# One row per tree node, one column per mode of `ps` below (indep, beads, unbound).
selected_modes = torch.BoolTensor([
    [1, 1, 0],
    [0, 1, 0],
    [1, 1, 0],
])

# Energy models: round-zero distribution plus the three modes.
Ns0 = energy_models.IndepSites(k)
indep = energy_models.IndepSites(h)
beads = energy_models.IndepSites(h_beads)
unbound_sel = energy_models.ConstantEnergy(-10)

ps = selex_distribution.MultiModeDistribution(
    indep, beads, unbound_sel,
    normalized=True,
)
model = selex_distribution.MultiRoundDistribution(Ns0, ps, tr, selected_modes)

In [8]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
# One loader per round, all with the same batch size and target device.
batch_size = 1_000_000
data_loaders = [
    data_loading.SelexRoundDataLoader(round_oh, batch_size=batch_size, device=device)
    for round_oh in sequences_oh
]
n_rounds = len(data_loaders)

In [10]:
# Chains and their log-weights: n_chains entries for each of the n_rounds rounds.
n_chains = 100_000
chains = training.init_chains(n_rounds, n_chains, L, q, dtype=dtype)
log_weights = torch.zeros((n_rounds, n_chains), dtype=dtype)

In [11]:
# Callbacks passed to training.train for the first model. The list has a single
# entry, which is plotted after training through callbacks[0].plot().
callbacks = [callback.ConvergenceMetricsCallback()]

In [12]:
# Move the model and every training tensor to the selected device, in the same
# order as the names on the left-hand side.
model_device, chains_device, total_reads_device, log_weights_device = (
    obj.to(device) for obj in (model, chains, total_reads, log_weights)
)

In [None]:
# Training settings for the first model.
n_sweeps, lr = 10, 0.01
target_pearson, max_epochs = 1, 4000

training.train(
    model_device,
    data_loaders,
    total_reads_device,
    chains_device,
    n_sweeps,
    max_epochs,
    target_pearson,
    callbacks=callbacks,
    log_weights=log_weights_device,
    lr=lr,
)

 0.00%[                                                                                                       …

In [None]:
import gc

# Bring the trained model and the sampler state back to the CPU.
model = model_device.to(torch.device('cpu'))
chains = chains_device.to(torch.device('cpu'))
log_weights = log_weights_device.to(torch.device('cpu'))

# Drop the device-side names, then run the garbage collector BEFORE emptying the
# CUDA cache: empty_cache() can only release blocks that are no longer occupied,
# so objects that are still waiting for collection must be reclaimed first.
del model_device, chains_device, total_reads_device, log_weights_device
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Plot the ConvergenceMetricsCallback attached to the first training run; the
# trailing semicolon keeps the call's return value out of the cell output.
callbacks[0].plot();

In [None]:
# Save the whole trained model object (not only its parameters). The path is
# relative to the current working directory.
torch.save(model, 'experiments/saved/dataset_1_beads_indep_sites.pt')

In [None]:
# Disabled cell.
# NOTE(review): `callbacks` is created with a single entry in this notebook, so
# `callbacks[1]` would raise IndexError if these lines were re-enabled unchanged.
# fig, ax = callbacks[1].plot()
# ax.axhline(1, color='r', linestyle='--')
# ax.set_ylim((0,1.1))

In [None]:
# Disabled cell: contact-map plot built from the couplings of the first selection mode.
# NOTE(review): this reads `.J` from `model.selection.modes[0]`, which is created as
# `energy_models.IndepSites` in this notebook — confirm the attribute exists before
# re-enabling these lines.
# from IPython.display import display, Latex

# potts = energy_models.Potts(model.selection.modes[0].J.data, model.selection.modes[0].h.data)

# potts_zerosum = potts.set_zerosum_gauge()
# # potts_zerosum = potts

# pl, ax = plt.subplots(figsize=(3,3))
# F = selex_dca.get_contact_map(potts_zerosum.J.detach())
# im = ax.imshow(F)
# ax.set_xlabel("i"); ax.set_ylabel("i")
# ax.set_title(experiment_id)
# plt.colorbar(im)
# plt.tight_layout()
# # im.set_clim(-0.0015, 0.0030)
# display(Latex("$F_{ij}=\\sqrt {\\sum_{ab}(J_{ij}^{ab})^2}$"))

In [None]:
# Disabled cell: per-site norm of the fields. It depends on `potts_zerosum`, which is
# only defined inside the commented-out contact-map cell.
# h_norm_bead = torch.sqrt(torch.square(potts_zerosum.h.cpu().detach().clone()).sum(1))
# plt.scatter(range(L), h_norm_bead)

## Same model, trained without the information from the control round

In [None]:
# Sequence length and alphabet size, read off a single one-hot encoded sequence.
L, q = sequences_oh[0][0].shape
n_rounds_noc = 3

# Small Gaussian initialisation, drawn in the order k_noc, h_beads_noc, h_noc.
# h_beads_noc is not passed to any model in the cells shown here; it is still drawn
# so that the sequence of random numbers is unchanged.
k_noc, h_beads_noc, h_noc = (torch.randn(L, q, dtype=dtype) * 1e-4 for _ in range(3))

# Tree of rounds: two nodes added with arguments -1 and 0.
# presumably the argument is the parent round (-1 = the initial library) — TODO confirm
tr_noc = tree.Tree()
for node_arg in (-1, 0):
    tr_noc.add_node(node_arg)

# One row per tree node, one column per mode of `ps_noc` below (indep, unbound).
selected_modes_noc = torch.BoolTensor([
    [1, 0],
    [1, 0],
])

Ns0_noc = energy_models.IndepSites(k_noc)
indep_noc = energy_models.IndepSites(h_noc)
unbound_sel_noc = energy_models.ConstantEnergy(-10)

ps_noc = selex_distribution.MultiModeDistribution(
    indep_noc, unbound_sel_noc,
    normalized=True,
)
model_noc = selex_distribution.MultiRoundDistribution(Ns0_noc, ps_noc, tr_noc, selected_modes_noc)

Here the model only has the target-binding and unbound modes (no beads mode), and the control round (`R01CS`, index 2) is left out of the training data.

In [None]:
# Read counts of the rounds kept for this model (indices 0, 1 and 3 of sequences_oh).
total_reads_noc = torch.Tensor(
    [sequences_oh[round_idx].shape[0] for round_idx in (0, 1, 3)]
)

In [None]:
# Loaders for the rounds kept for this model (indices 0, 1 and 3 of sequences_oh).
batch_size = 1_000_000
data_loaders_noc = [
    data_loading.SelexRoundDataLoader(sequences_oh[round_idx], batch_size=batch_size, device=device)
    for round_idx in (0, 1, 3)
]
# n_rounds is not reassigned here; the round count for this model is n_rounds_noc.

In [None]:
# Chains and their log-weights: n_chains entries for each of the n_rounds_noc rounds.
n_chains = 100_000
chains_noc = training.init_chains(n_rounds_noc, n_chains, L, q, dtype=dtype)
log_weights_noc = torch.zeros((n_rounds_noc, n_chains), dtype=dtype)

In [None]:
# Separate callback list for the second training run, so its metrics do not mix
# with those collected in `callbacks` for the first model.
callbacks_noc = [callback.ConvergenceMetricsCallback()]

In [None]:
# Move the second model and its training tensors to the selected device, in the
# same order as the names on the left-hand side.
model_device, chains_device, total_reads_device, log_weights_device = (
    obj.to(device) for obj in (model_noc, chains_noc, total_reads_noc, log_weights_noc)
)

In [None]:
# Training settings for the model without the control round (same values as the first run).
n_sweeps = 10
lr = 0.01
target_pearson = 1
max_epochs = 4000

# The loaders must be data_loaders_noc (rounds 0, 1 and 3). The four-round
# `data_loaders` list belongs to the first model and does not line up with the
# three-round model, chains, log-weights and total reads used here.
training.train(model_device, data_loaders_noc, total_reads_device, chains_device, n_sweeps, max_epochs, target_pearson,
               callbacks=callbacks_noc, log_weights=log_weights_device, lr=lr)

In [None]:
import gc

# Bring the second trained model and its sampler state back to the CPU.
model_noc = model_device.to(torch.device('cpu'))
chains_noc = chains_device.to(torch.device('cpu'))
log_weights_noc = log_weights_device.to(torch.device('cpu'))

# Drop the device-side names, then run the garbage collector BEFORE emptying the
# CUDA cache: empty_cache() can only release blocks that are no longer occupied,
# so objects that are still waiting for collection must be reclaimed first.
del model_device, chains_device, total_reads_device, log_weights_device
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Plot the ConvergenceMetricsCallback attached to the second training run; the
# trailing semicolon keeps the call's return value out of the cell output.
callbacks_noc[0].plot();

In [None]:
# Save the whole second model object (not only its parameters). The path is
# relative to the current working directory.
torch.save(model_noc, 'experiments/saved/dataset_1_nobeads_indep_sites.pt')

## Comparison

In [None]:
# Scatter the selection energy given by the two models on the reads of each round.
# N is the maximum number of points drawn per panel.
N = 10**5
n_sel = model.get_n_rounds()
fig, axes = plt.subplots(1, n_sel, figsize=(12,3))

for t in range(n_sel):
    X = sequences_oh[t]
    # Evaluation only: no autograd graph is needed for these full-round forward passes.
    with torch.no_grad():
        x = model.selection_energy_at_round(X, 1).detach().cpu()
        y = model_noc.selection_energy_at_round(X, 1).detach().cpu()
    # random.sample raises ValueError when asked for more items than the population
    # holds, so the sample size is capped at the number of reads in this round.
    idx = random.sample(range(len(x)), min(N, len(x)))
    ax = axes[t]
    ax.plot(x, x, color='gray')  # x = y reference line
    ax.scatter(x[idx], y[idx], s=2)
    ax.set_xlabel('logps - beads')
    ax.set_ylabel('logps - no beads')
    ax.set_title(f'Round {t}')
fig.suptitle('log ps')
fig.tight_layout();

In [None]:
def _scatter_against_identity(ax, x, y, xlabel, ylabel, title):
    """Scatter y against x on ax, with the x = y line drawn in gray for reference."""
    ax.plot(x, x, color='gray')
    ax.scatter(x, y)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)


fig, axes = plt.subplots(1, 3, figsize=(10,4))

# Panel 0: round-zero fields of the two models.
_scatter_against_identity(
    axes[0],
    model.round_zero.set_zerosum_gauge().h.detach(),
    model_noc.round_zero.set_zerosum_gauge().h.detach(),
    'k - beads', 'k - no beads', 'Round zero field k',
)

# Panel 1: fields of the first selection mode of the two models.
_scatter_against_identity(
    axes[1],
    model.selection.modes[0].set_zerosum_gauge().h.detach(),
    model_noc.selection.modes[0].set_zerosum_gauge().h.detach(),
    'h - beads', 'h - no beads', 'Selection field h',
)

# Panel 2: off-diagonal couplings of the first selection mode of the two models.
# NOTE(review): in this notebook both models build modes[0] as energy_models.IndepSites;
# confirm that the object returned by set_zerosum_gauge() exposes `.J`, otherwise this
# panel raises.
_scatter_against_identity(
    axes[2],
    utils.off_diagonal_terms(model.selection.modes[0].set_zerosum_gauge().J.detach()),
    utils.off_diagonal_terms(model_noc.selection.modes[0].set_zerosum_gauge().J.detach()),
    'J - beads', 'J - no beads', 'Selection couplings J',
)

fig.tight_layout();