In [145]:
import sys
sys.path.append('..')

import utils, indep_sites
import adabmDCA
import selex_distribution, energy_models, tree, data_loading, training, callback, sampling

import torch
from utils import one_hot
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import random
from sklearn.linear_model import LinearRegression, RANSACRegressor

In [22]:
experiment_id = "Dop8V030"
round_ids = ["ARN", "R01", "R02N"]

device = torch.device("cpu")
dtype = torch.float32
ts = range(len(round_ids))

In [3]:
sequences = [utils.sequences_from_file(experiment_id, round_id, device) for round_id in round_ids]

In [126]:
"""
Split sequences into train and test set by selecting for test `n_test` among the `n_top` sequences with the highest count at each round
"""
def split_train_test(sequences_t, n_top, n_test):
    sequences_unique_t, inverse_indices_t, counts_t = torch.unique(sequences_t, dim=0, return_inverse=True, return_counts=True)
    perm = counts_t.argsort(descending=True)
    
    idx_test_ = random.sample(range(n_top), n_test)
    idx_train_ = [id for id in range(len(perm)) if not id in idx_test_]
    idx_test = [perm[id] for id in idx_test_]
    idx_train = [perm[id] for id in idx_train_]
    # check disjoint
    assert set(idx_test) & set(idx_train) == set()
    
    counts_perm_train_t = counts_t[idx_train]
    counts_perm_test_t = counts_t[idx_test]
    sequences_unique_train_t = sequences_unique_t[idx_train]
    sequences_unique_test_t = sequences_unique_t[idx_test]
    
    sequences_train_t = torch.repeat_interleave(sequences_unique_train_t, counts_perm_train_t, dim=0)
    sequences_test_t = torch.repeat_interleave(sequences_unique_test_t, counts_perm_test_t, dim=0)
    assert sequences_train_t.size(0) + sequences_test_t.size(0) == sequences_t.size(0)
    assert counts_perm_train_t.sum() + counts_perm_test_t.sum() == counts_t.sum()

    return sequences_unique_train_t, sequences_unique_test_t, sequences_train_t, sequences_test_t, counts_perm_train_t, counts_perm_test_t

In [164]:
n_test = 5
n_top = 100

sequences_train = []
sequences_test = []
sequences_unique_train = []
sequences_unique_test = []
counts_train = []
counts_test = []

for t in ts:
    print(f'Starting round {t} of {len(ts)}...')
    (sequences_unique_train_t, sequences_unique_test_t, sequences_train_t, sequences_test_t, counts_train_t, counts_test_t) = \
        split_train_test(sequences[t], n_top, n_test)
    sequences_train.append(sequences_train_t)
    sequences_test.append(sequences_test_t)
    sequences_unique_train.append(sequences_unique_train_t)
    sequences_unique_test.append(sequences_unique_test_t)
    counts_train.append(counts_train_t)
    counts_test.append(counts_test_t)

Starting round 0 of 3...


  counts_perm_test_t = counts_t[idx_test]


IndexError: too many indices for tensor of dimension 1

In [None]:
sequences_train_oh = [one_hot(sequences_train[t]).to(dtype=dtype) for t in ts]
sequences_test_oh = [one_hot(sequences_test[t]).to(dtype=dtype) for t in ts]
# sequences_unique_train_oh = [one_hot(sequences_unique_train[t]).to(dtype=dtype) for t in ts]
sequences_unique_test_oh = [one_hot(sequences_unique_test[t]).to(dtype=dtype) for t in ts]

In [None]:
pseudocount = 0.0

freq_single_tuple, freq_pair_tuple, total_reads_tuple = zip(*[
    utils.frequences_from_sequences_oh(seq, pseudo_count=pseudocount) 
    for seq in sequences_train_oh])

fi_train = torch.stack(freq_single_tuple)
fij_train = torch.stack(freq_pair_tuple)
total_reads_train = torch.tensor(total_reads_tuple, dtype=dtype)

freq_single_tuple, freq_pair_tuple, total_reads_tuple = zip(*[
    utils.frequences_from_sequences_oh(seq, pseudo_count=pseudocount) 
    for seq in sequences_test_oh])

fi_test = torch.stack(freq_single_tuple)
fij_test = torch.stack(freq_pair_tuple)
total_reads_test = torch.tensor(total_reads_tuple, dtype=dtype)

In [None]:
lr = 0.01
max_epochs = 5*10**4

params = indep_sites.init_parameters(fi_train)
params, history = indep_sites.train(
    fi=fi_train,
    total_reads=total_reads_train, 
    params=params,
    lr=lr,
    max_epochs=max_epochs,
    target_error=1e-6,
    progress_bar=False)

In [None]:
params_t = [indep_sites.get_params_at_round(params, t) for t in ts]
logNst_test = [- indep_sites.compute_energy(sequences_unique_test_oh[t], params_t[t])
               for t in ts]

In [None]:
def scatter_with_fit(x, y, **kwargs):
    fig, ax = plt.subplots(figsize=(3,3))
    ax.scatter(x, y, **kwargs)
    reg = RANSACRegressor().fit(x.reshape(-1, 1), y)
    ax.plot(x, reg.predict(x.reshape(-1,1)), color='r')
    return fig, ax

In [None]:
for t in ts:
    scatter_with_fit(logNst_test[t], torch.log(counts_test[t]), s=2)