In [None]:
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
%matplotlib inline
import os
import torch
print(torch.cuda.is_available())
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler, BatchSampler
from scipy import stats

from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

True


In [None]:
# ---- config ----
# Minimum control read count for a row to be eligible for train/valid
# (the test set uses its own fixed threshold of 10).
filtering = 5
MODEL_PATH = 'model/all_large_6_log21p_activity_11262024_NGG_filtering%d_6layer_128_3linear_test.pth' % (filtering,)
# The literal 'site' in this file name is replaced by each name in SITE_LIST.
DATA_PATH = 'data/randomized_change_seq_large_data_log21p_activity_11262024_site.hdf5'
SITE_LIST = [
    'CCR5_s8',
    'LAG3_s9',
    'TRAC_s1',
    'CTLA4_s9',
    'AAVS1_s14',
]

# NOTE(review): BATCH_SIZE and NUM_EPOCHS are not read anywhere in this notebook;
# the DataLoaders hard-code batch_size=8192 and training sets num_epochs = 100.
BATCH_SIZE = 1024
NUM_EPOCHS = 300

# 23-nt target sequence per site; the first 20 nt are used as the spacer when
# computing CFD scores. CXCR4_s8 is listed here but not in SITE_LIST.
gRNA = dict([
    ('AAVS1_s14', 'GGGGCCACTAGGGACAGGATTGG'),
    ('CTLA4_s9', 'GGACTGAGGGCCATGGACACGGG'),
    ('TRAC_s1', 'GTCAGGGTTCTGGATATCTGTGG'),
    ('LAG3_s9', 'GAAGGCTGAGATCCTGGAGGGGG'),
    ('CXCR4_s8', 'GTCCCCTGAGCCCATTTCCTCGG'),
    ('CCR5_s8', 'GGACAGTAAGAAGGAAAAACAGG'),
])

# Loading dataset

In [None]:
# Loading dataset
class SeqData(Dataset):
    """Map-style dataset returning ``(features, target)`` float32 tensor pairs.

    ``X[i]`` and ``y[i]`` are converted to float32 tensors on every access.
    ``seq`` is only stored on the instance; ``__getitem__`` does not return it.
    """

    def __init__(self, X, y, seq=None):
        self.X, self.y, self.seq = X, y, seq
        self.length = len(self.y)

    def __getitem__(self, i):
        features = torch.tensor(self.X[i], dtype=torch.float32)
        target = torch.tensor(self.y[i], dtype=torch.float32)
        return features, target

    def __len__(self):
        return self.length


print('Loading dataset...')
# One chunk per (site, mismatch count); all chunks are concatenated after the loop.
X, y, seq = [], [], []
mismatches = []
len_list = []
source_site = []
control_counts, cas9_counts = [], []

for site in SITE_LIST:
    # DATA_PATH contains the literal placeholder 'site' in its file name.
    data_path = DATA_PATH.replace('site', site)
    with h5py.File(data_path, 'r') as f:
        for num_mismatches in range(7):
            group = f[str(num_mismatches)]

            X_s = np.array(group['X']).astype(np.float32)
            X_s[:, 20, :4] = 0.25  # Set the first base of PAM to N
            y_s = np.array(group['y']).astype(np.float32)
            seq_s = np.array(group['seq']).astype(str)
            control_count_s = np.array(group['control_counts']).astype(np.float32)
            cas9_count_s = np.array(group['cas9_counts']).astype(np.float32)

            # Drop rows whose target is +inf (every other value, NaN included, is kept).
            ind = np.nonzero(y_s != np.inf)[0]
            n_kept = len(ind)
            for store, values in ((X, X_s), (y, y_s), (seq, seq_s),
                                  (control_counts, control_count_s),
                                  (cas9_counts, cas9_count_s)):
                store.append(values[ind])
            mismatches.append([num_mismatches] * n_kept)
            source_site.append([site] * n_kept)

            print(num_mismatches, X_s.shape, y_s.shape)  # shapes before filtering
            len_list.append(n_kept)
            del X_s, y_s, seq_s


X = np.concatenate(X)
y = np.concatenate(y)
seq = np.concatenate(seq)
mismatches = np.concatenate(mismatches)
source_site = np.concatenate(source_site)
control_counts = np.concatenate(control_counts)
cas9_counts = np.concatenate(cas9_counts)

print(X.shape, y.shape, mismatches.shape, source_site.shape, control_counts.shape)


Loading dataset...
0 (4, 23, 8) (4,)
1 (263, 23, 8) (263,)
2 (6548, 23, 8) (6548,)
3 (58797, 23, 8) (58797,)
4 (227458, 23, 8) (227458,)
5 (399260, 23, 8) (399260,)
6 (472397, 23, 8) (472397,)
0 (4, 23, 8) (4,)
1 (262, 23, 8) (262,)
2 (7298, 23, 8) (7298,)
3 (74356, 23, 8) (74356,)
4 (282869, 23, 8) (282869,)
5 (503703, 23, 8) (503703,)
6 (619366, 23, 8) (619366,)
0 (4, 23, 8) (4,)
1 (264, 23, 8) (264,)
2 (7471, 23, 8) (7471,)
3 (71368, 23, 8) (71368,)
4 (274688, 23, 8) (274688,)
5 (464658, 23, 8) (464658,)
6 (523566, 23, 8) (523566,)
0 (4, 23, 8) (4,)
1 (264, 23, 8) (264,)
2 (7006, 23, 8) (7006,)
3 (65819, 23, 8) (65819,)
4 (222382, 23, 8) (222382,)
5 (368976, 23, 8) (368976,)
6 (431718, 23, 8) (431718,)
0 (4, 23, 8) (4,)
1 (264, 23, 8) (264,)
2 (7507, 23, 8) (7507,)
3 (75916, 23, 8) (75916,)
4 (323177, 23, 8) (323177,)
5 (587621, 23, 8) (587621,)
6 (671959, 23, 8) (671959,)
(6757221, 23, 8) (6757221,) (6757221,) (6757221,) (6757221,)


In [None]:
# Construct the training set.
# Test set: a random 10% of rows with control_counts >= 10, drawn first so it does
# not depend on `filtering`. Train/valid: the remaining rows with
# control_counts >= filtering, split 90/10.
np.random.seed(42)
ind_high_count = np.flatnonzero(control_counts >= 10)
print(len(control_counts), len(ind_high_count))
test_ind = np.random.choice(ind_high_count, len(ind_high_count) // 10, replace=False)
print(test_ind, len(test_ind))
all_ind = np.arange(len(control_counts))
remaining_ind = all_ind[~np.isin(all_ind, test_ind)]
print(remaining_ind, len(remaining_ind))

mis_test = mismatches[test_ind]
print(np.unique(mis_test, return_counts=True))

# train_test_split below gets no random_state, so it draws from the global NumPy
# RNG; reseeding here keeps the train/valid split reproducible.
np.random.seed(42)
ind_selected = np.flatnonzero(control_counts >= filtering)
remaining_ind = remaining_ind[np.isin(remaining_ind, ind_selected)]
print(remaining_ind, len(remaining_ind))
train_ind, valid_ind = train_test_split(remaining_ind, test_size=0.1, shuffle=True)
# make sure that there is no overlap between train and test
assert set(train_ind).isdisjoint(test_ind)

X_train, X_valid, X_test = (X[idx] for idx in (train_ind, valid_ind, test_ind))
y_train, y_valid, y_test = (y[idx] for idx in (train_ind, valid_ind, test_ind))
seq_train, seq_valid, seq_test = (seq[idx] for idx in (train_ind, valid_ind, test_ind))
# NOTE: `del X, y` left disabled; later cells read X/y-derived arrays only via the splits.

print(X_train.shape, y_train.shape, X_valid.shape, y_valid.shape, X_test.shape, y_test.shape)

dts_train = SeqData(X_train, y_train, seq_train)
dts_valid = SeqData(X_valid, y_valid)
dts_test = SeqData(X_test, y_test)


def _make_loader(dataset, shuffle):
    """DataLoader with this notebook's fixed settings: batch 8192, pinned memory, no workers."""
    return DataLoader(dataset=dataset,
                      batch_size=8192,
                      pin_memory=True,
                      shuffle=shuffle,
                      num_workers=0)


loader_train = _make_loader(dts_train, shuffle=True)
loader_valid = _make_loader(dts_valid, shuffle=False)
loader_test = _make_loader(dts_test, shuffle=False)



6757221 2497088
[6731810 2675084 5216061 ... 5388470 1976540 3769999] 249708
[      0       2       3 ... 6757218 6757219 6757220] 6507513
(array([0, 1, 2, 3, 4, 5, 6]), array([    3,   120,  2775, 20608, 54852, 80378, 90972]))
[      0       2       3 ... 6757213 6757217 6757218] 3751228
(3376105, 23, 8) (3376105,) (375123, 23, 8) (375123,) (249708, 23, 8) (249708,)


In [None]:
def get_free_gpu():
    """Return the index of the GPU reporting the most free memory.

    Asks ``nvidia-smi`` for one free-memory value (MiB) per GPU, in its device
    order. The original version instead piped through grep into a ``./tmp``
    file, never closed the read handle, and left the file behind.

    Returns:
        int: index of the GPU with the largest free memory (first one on ties).

    Raises:
        FileNotFoundError: if ``nvidia-smi`` is not on PATH.
        subprocess.CalledProcessError: if ``nvidia-smi`` exits with an error.
        RuntimeError: if ``nvidia-smi`` reports no GPUs.
    """
    import subprocess

    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits'],
        capture_output=True, text=True, check=True,
    )
    memory_available = [int(tok) for tok in result.stdout.split()]
    if not memory_available:
        raise RuntimeError('nvidia-smi reported no GPUs')
    # NOTE(review): nvidia-smi lists GPUs in PCI bus order, while CUDA defaults to
    # fastest-first ordering; set CUDA_DEVICE_ORDER=PCI_BUS_ID (and mind
    # CUDA_VISIBLE_DEVICES) if these indices must match torch's cuda:<n>.
    return int(np.argmax(memory_available))

# Place the model and all batches on the GPU with the most free memory.
# (`id` shadows the builtin; the name is kept for later cells.)
id = get_free_gpu()
device = torch.device('cuda', id)

class CHANGENET(nn.Module):
    """1-D CNN regressor: six convolution stages followed by a three-layer MLP head.

    Input: tensor of shape (batch, 8, seq_length=23). Each conv stage is
    Conv1d -> BatchNorm1d -> LeakyReLU -> Dropout(0.2) with "same" padding
    (padding = kernel_size // 2), so the length stays 23. The last stage emits
    10 channels, which are flattened into 230 features for the MLP head.
    Output: tensor of shape (batch, 1).

    Modules are appended to ``self.layers`` in a fixed order; their indices
    define the ``layers.<i>.*`` state_dict keys stored in checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        dropout = 0.2
        hidden_dim = 128
        self.seq_length = 23

        # (in_channels, out_channels, kernel_size) for each conv stage, in order.
        conv_stages = (
            (8, hidden_dim, 3),
            (hidden_dim, hidden_dim, 3),
            (hidden_dim, hidden_dim, 3),
            (hidden_dim, hidden_dim, 5),
            (hidden_dim, hidden_dim, 5),
            (hidden_dim, 10, 5),
        )
        for in_ch, out_ch, k in conv_stages:
            self.layers.extend([
                nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(out_ch, track_running_stats=True),
                nn.LeakyReLU(),
                nn.Dropout(dropout),
            ])

        # MLP head over the flattened (10 x seq_length) feature map.
        self.layers.append(nn.Flatten())
        for fan_in, fan_out in ((self.seq_length * 10, 128), (128, 32)):
            self.layers.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU()])
        self.layers.append(nn.Linear(32, 1))

    def forward(self, x):
        """Apply every module of ``self.layers`` in sequence."""
        for module in self.layers:
            x = module(x)
        return x


# Build the network on the selected device; Module.to() returns the module itself.
model = CHANGENET().to(device)
# Regression objective and optimizer (AdamW, default learning rate, weight decay 1e-5).
mseloss = nn.MSELoss()
optimizer = optim.AdamW(params=model.parameters(), weight_decay=1e-5)

# Training the model

In [None]:
from scipy import stats
def train(model, loader_train, loader_valid, epochs=None):
    """Train ``model`` with MSE loss; checkpoint the epoch with the best validation Pearson r.

    Uses the notebook globals ``optimizer``, ``mseloss``, ``device`` and
    ``MODEL_PATH``, plus ``num_epochs`` when ``epochs`` is not given.

    Args:
        model: network taking (batch, 8, 23) inputs. Loader batches are
            (batch, 23, 8) and are transposed before the forward pass.
        loader_train: DataLoader of (X, y) training batches; one optimizer step per batch.
        loader_valid: DataLoader of (X, y) validation batches.
        epochs: number of epochs to run; defaults to the global ``num_epochs``.

    Returns:
        tuple (best_epoch, best_pearsonr): the epoch whose weights were written
        to MODEL_PATH, and its validation Pearson r. best_epoch is -1 if no
        epoch produced a finite correlation, in which case nothing was saved.
    """
    n_epochs = num_epochs if epochs is None else epochs
    best_epoch = -1
    # Start at -inf (not 0) so the first epoch with a finite correlation is
    # always checkpointed. With 0, a run whose correlations stayed negative
    # never wrote MODEL_PATH, and the later torch.load failed.
    best_pearsonr = -np.inf

    for r in range(n_epochs):
        model.train()
        train_loss_list = []
        for X, y in tqdm(loader_train):
            X = torch.transpose(X.to(device), 1, 2)  # (B, 23, 8) -> (B, 8, 23) for Conv1d
            y = y.to(device).view(-1, 1)
            output = model(X).reshape(-1, 1)
            loss = mseloss(output, y)
            train_loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        y_true_list = []
        y_pred_list = []
        # Evaluation needs no autograd graph; without no_grad every validation
        # batch kept its activations alive for a backward pass that never happens.
        with torch.no_grad():
            for X, y in loader_valid:
                X = torch.transpose(X.to(device), 1, 2)
                output = model(X)
                y_true_list.append(y.cpu().numpy().reshape(-1))
                y_pred_list.append(output.cpu().numpy().reshape(-1))
        y_true_list = np.concatenate(y_true_list)
        y_pred_list = np.concatenate(y_pred_list)

        # Sample-weighted MSE over the whole validation set. Averaging per-batch
        # losses would over-weight the smaller final batch.
        valid_loss = float(np.mean((y_true_list.astype(np.float64) - y_pred_list) ** 2))
        corr = np.corrcoef(y_true_list, y_pred_list)[0, 1]
        spearmanr = stats.spearmanr(y_true_list, y_pred_list).statistic
        print('Epoch %d, train loss %.7f, valid loss %.7f, valid pearson %.3f, valid spearman %.3f' % (r, np.mean(train_loss_list), valid_loss, corr, spearmanr))

        # NaN correlations (e.g. constant predictions) never compare greater, so they are skipped.
        if corr > best_pearsonr:
            best_epoch = r
            best_pearsonr = corr
            torch.save({
                'epoch': r,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Epoch-level validation MSE as a plain float. Previously this was
                # the last validation batch's graph-attached loss tensor.
                'loss': valid_loss,
                }, MODEL_PATH)

    print('Best epoch %d, valid pearson %.3f' % (best_epoch, best_pearsonr))
    return best_epoch, best_pearsonr
# NOTE(review): the note originally here read "large_5_sample_by_bin_up_around:
# upsample bins between 0.1 and 0.4", but the loaders above use plain shuffling
# with no bin upsampling -- likely stale.
num_epochs = 100  # read as a global inside train()
print('Start training...')
train(model, loader_train, loader_valid);

Start training...


100%|██████████| 413/413 [01:48<00:00,  3.81it/s]


Epoch 0, train loss 0.2674089, valid loss 0.2823876, valid pearson 0.887, valid spearman 0.825


100%|██████████| 413/413 [01:48<00:00,  3.81it/s]


Epoch 1, train loss 0.1958488, valid loss 0.2063117, valid pearson 0.900, valid spearman 0.832


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 2, train loss 0.1844929, valid loss 0.1709180, valid pearson 0.902, valid spearman 0.835


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 3, train loss 0.1781500, valid loss 0.1649811, valid pearson 0.904, valid spearman 0.837


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 4, train loss 0.1736734, valid loss 0.1981403, valid pearson 0.902, valid spearman 0.837


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 5, train loss 0.1709175, valid loss 0.1703352, valid pearson 0.905, valid spearman 0.838


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 6, train loss 0.1687759, valid loss 0.1507666, valid pearson 0.910, valid spearman 0.840


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 7, train loss 0.1668939, valid loss 0.1689624, valid pearson 0.906, valid spearman 0.839


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 8, train loss 0.1653027, valid loss 0.1612870, valid pearson 0.907, valid spearman 0.841


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 9, train loss 0.1640317, valid loss 0.1507917, valid pearson 0.910, valid spearman 0.842


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 10, train loss 0.1629460, valid loss 0.1497314, valid pearson 0.910, valid spearman 0.843


100%|██████████| 413/413 [01:46<00:00,  3.87it/s]


Epoch 11, train loss 0.1619359, valid loss 0.1685218, valid pearson 0.908, valid spearman 0.841


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 12, train loss 0.1610274, valid loss 0.1559530, valid pearson 0.909, valid spearman 0.842


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 13, train loss 0.1604582, valid loss 0.1502132, valid pearson 0.911, valid spearman 0.842


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 14, train loss 0.1597247, valid loss 0.1456519, valid pearson 0.912, valid spearman 0.844


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 15, train loss 0.1591445, valid loss 0.1515215, valid pearson 0.912, valid spearman 0.843


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 16, train loss 0.1585768, valid loss 0.1439902, valid pearson 0.912, valid spearman 0.844


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 17, train loss 0.1577569, valid loss 0.1505881, valid pearson 0.912, valid spearman 0.845


100%|██████████| 413/413 [01:46<00:00,  3.87it/s]


Epoch 18, train loss 0.1575152, valid loss 0.1444706, valid pearson 0.912, valid spearman 0.844


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 19, train loss 0.1570707, valid loss 0.1443264, valid pearson 0.912, valid spearman 0.845


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 20, train loss 0.1566647, valid loss 0.1465178, valid pearson 0.912, valid spearman 0.845


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 21, train loss 0.1562755, valid loss 0.1448331, valid pearson 0.913, valid spearman 0.845


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 22, train loss 0.1557227, valid loss 0.1438142, valid pearson 0.913, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 23, train loss 0.1555720, valid loss 0.1482150, valid pearson 0.913, valid spearman 0.845


100%|██████████| 413/413 [01:48<00:00,  3.79it/s]


Epoch 24, train loss 0.1550560, valid loss 0.1440088, valid pearson 0.913, valid spearman 0.845


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 25, train loss 0.1548556, valid loss 0.1460128, valid pearson 0.913, valid spearman 0.845


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 26, train loss 0.1544697, valid loss 0.1447529, valid pearson 0.913, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 27, train loss 0.1543634, valid loss 0.1442411, valid pearson 0.912, valid spearman 0.845


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 28, train loss 0.1539626, valid loss 0.1435564, valid pearson 0.913, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 29, train loss 0.1537308, valid loss 0.1418924, valid pearson 0.913, valid spearman 0.845


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 30, train loss 0.1535223, valid loss 0.1467002, valid pearson 0.912, valid spearman 0.845


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 31, train loss 0.1534437, valid loss 0.1439450, valid pearson 0.913, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 32, train loss 0.1531524, valid loss 0.1445817, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:48<00:00,  3.81it/s]


Epoch 33, train loss 0.1528833, valid loss 0.1413539, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 34, train loss 0.1528781, valid loss 0.1433163, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 35, train loss 0.1527051, valid loss 0.1424428, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 36, train loss 0.1523936, valid loss 0.1416259, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 37, train loss 0.1520516, valid loss 0.1412475, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:48<00:00,  3.80it/s]


Epoch 38, train loss 0.1519839, valid loss 0.1417277, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:48<00:00,  3.81it/s]


Epoch 39, train loss 0.1519976, valid loss 0.1412213, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:48<00:00,  3.81it/s]


Epoch 40, train loss 0.1518420, valid loss 0.1428810, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:48<00:00,  3.82it/s]


Epoch 41, train loss 0.1515083, valid loss 0.1435391, valid pearson 0.914, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 42, train loss 0.1514366, valid loss 0.1423186, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 43, train loss 0.1510562, valid loss 0.1416913, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 44, train loss 0.1512106, valid loss 0.1414314, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 45, train loss 0.1509724, valid loss 0.1415794, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 46, train loss 0.1509286, valid loss 0.1439259, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 47, train loss 0.1508885, valid loss 0.1407027, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 48, train loss 0.1505281, valid loss 0.1395583, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 49, train loss 0.1504894, valid loss 0.1417981, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 50, train loss 0.1504692, valid loss 0.1411826, valid pearson 0.914, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 51, train loss 0.1504496, valid loss 0.1403889, valid pearson 0.914, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 52, train loss 0.1501647, valid loss 0.1420931, valid pearson 0.914, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 53, train loss 0.1501563, valid loss 0.1401025, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 54, train loss 0.1499890, valid loss 0.1399011, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 55, train loss 0.1499432, valid loss 0.1410673, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 56, train loss 0.1499849, valid loss 0.1404017, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 57, train loss 0.1497549, valid loss 0.1406737, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 58, train loss 0.1496660, valid loss 0.1403033, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 59, train loss 0.1495893, valid loss 0.1413112, valid pearson 0.914, valid spearman 0.847


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 60, train loss 0.1495536, valid loss 0.1400995, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 61, train loss 0.1494882, valid loss 0.1404447, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 62, train loss 0.1493679, valid loss 0.1396387, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.86it/s]


Epoch 63, train loss 0.1491973, valid loss 0.1392935, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:46<00:00,  3.86it/s]


Epoch 64, train loss 0.1492825, valid loss 0.1416367, valid pearson 0.914, valid spearman 0.846


100%|██████████| 413/413 [01:47<00:00,  3.82it/s]


Epoch 65, train loss 0.1492142, valid loss 0.1408185, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 66, train loss 0.1490913, valid loss 0.1396186, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 67, train loss 0.1488127, valid loss 0.1396407, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 68, train loss 0.1489244, valid loss 0.1418204, valid pearson 0.914, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 69, train loss 0.1487708, valid loss 0.1394833, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 70, train loss 0.1488039, valid loss 0.1408518, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 71, train loss 0.1487965, valid loss 0.1401423, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:48<00:00,  3.82it/s]


Epoch 72, train loss 0.1485079, valid loss 0.1412247, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 73, train loss 0.1486133, valid loss 0.1413276, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 74, train loss 0.1484699, valid loss 0.1395923, valid pearson 0.916, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 75, train loss 0.1484375, valid loss 0.1392215, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 76, train loss 0.1483610, valid loss 0.1396611, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 77, train loss 0.1482075, valid loss 0.1392681, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 78, train loss 0.1481377, valid loss 0.1396112, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 79, train loss 0.1482958, valid loss 0.1389340, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 80, train loss 0.1482569, valid loss 0.1406843, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 81, train loss 0.1480961, valid loss 0.1387363, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.83it/s]


Epoch 82, train loss 0.1480511, valid loss 0.1400853, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 83, train loss 0.1478800, valid loss 0.1405469, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 84, train loss 0.1478693, valid loss 0.1391058, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 85, train loss 0.1480102, valid loss 0.1399146, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 86, train loss 0.1478026, valid loss 0.1395343, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 87, train loss 0.1476147, valid loss 0.1398054, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 88, train loss 0.1478288, valid loss 0.1389927, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 89, train loss 0.1477438, valid loss 0.1389571, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 90, train loss 0.1476548, valid loss 0.1396112, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 91, train loss 0.1476083, valid loss 0.1389203, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 92, train loss 0.1476309, valid loss 0.1394606, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 93, train loss 0.1474159, valid loss 0.1388699, valid pearson 0.916, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 94, train loss 0.1474861, valid loss 0.1387662, valid pearson 0.916, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 95, train loss 0.1474666, valid loss 0.1395601, valid pearson 0.915, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.84it/s]


Epoch 96, train loss 0.1473458, valid loss 0.1384330, valid pearson 0.916, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 97, train loss 0.1472551, valid loss 0.1383601, valid pearson 0.916, valid spearman 0.848


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 98, train loss 0.1473285, valid loss 0.1409185, valid pearson 0.915, valid spearman 0.847


100%|██████████| 413/413 [01:47<00:00,  3.85it/s]


Epoch 99, train loss 0.1470982, valid loss 0.1392899, valid pearson 0.915, valid spearman 0.848


# Evaluating on the test set

In [None]:
# Reload the best checkpoint written during training.
# weights_only=True restricts unpickling to tensors, containers and primitives.
# That is all the checkpoint dict holds, so arbitrary pickled code cannot run,
# and it avoids the warning torch emits when weights_only is left unspecified.
# map_location places the tensors on the currently selected device even if the
# checkpoint was saved from a different GPU index.
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])

  checkpoint = torch.load(MODEL_PATH)


<All keys matched successfully>

In [None]:
model.eval()

# Predict on the held-out test set. loader_test is unshuffled, so true_list and
# pred_list stay row-aligned with test_ind, seq_test and source_test below.
true_list = []
pred_list = []
with torch.no_grad():  # inference only: no autograd graph
    for X, y in loader_test:
        X = torch.transpose(X.to(device), 1, 2)  # (B, 23, 8) -> (B, 8, 23)
        output = model(X)
        true_list.append(y.cpu().numpy().reshape(-1))
        pred_list.append(output.cpu().numpy().reshape(-1))

true_list = np.concatenate(true_list)
pred_list = np.concatenate(pred_list)

# Sample-weighted MSE over every test row. A mean of per-batch losses would
# over-weight the smaller final batch.
print('Test loss', np.mean((true_list.astype(np.float64) - pred_list) ** 2))
print(np.corrcoef(true_list, pred_list))
source_test = source_site[test_ind]

print(len(test_ind), len(mismatches))
d = {'seq': seq_test,
     'gRNA': source_test,
     'mismatches': mismatches[test_ind],
     'gRNA_seq': [gRNA[s] for s in source_test],
     'log2FC_true': true_list,
     'log2FC_pred': pred_list
}
df = pd.DataFrame(data=d)
os.makedirs('results', exist_ok=True)  # to_csv does not create missing directories
df.to_csv('results/prediction_results_MM6_log21p_activity_NGG_11262024_filtering%d_%s_test.csv' % (filtering, 'all'))





Test loss 0.09542680988388677
[[1.         0.94376154]
 [0.94376154 1.        ]]
249708 6757221


In [None]:
# Find all pairs of off-target sequences within edit distance 1 and compare their predicted vs actual change
def permute(seq, ref):
    """Return all sequences that differ from ``seq`` by one single-base substitution.

    23-character ``seq``: every position except 20 (the PAM base set to N during
    loading) where ``seq`` still matches ``ref`` is replaced by each of the
    other three bases. Each neighbor therefore carries exactly one more
    mismatch against ``ref`` than ``seq`` does.

    47-character ``seq`` of the form ``<23 chars>_<23 chars>``: positions 0-20 of
    each half are substituted with every base (``ref`` is ignored). Results
    identical to either half are dropped.

    Args:
        seq: 23-character sequence, or a 47-character underscore-joined pair.
        ref: 23-character reference; used only in the 23-character case.

    Returns:
        numpy array of neighbor strings.

    Raises:
        ValueError: if ``len(seq)`` is neither 23 nor 47. Previously this
            surfaced as an UnboundLocalError on ``neighbors``.
    """
    bases = ('A', 'C', 'G', 'T')
    if len(seq) == 23:
        # dtype=str keeps the string comparison below well-defined even if no
        # position is eligible (an empty list would otherwise become float64).
        # Alternative (not used): also permute positions that already mismatch ref.
        neighbors = np.array([seq[:i] + base + seq[i + 1:]
                              for i in range(23) if (i != 20 and seq[i] == ref[i])
                              for base in bases], dtype=str)
        neighbors = neighbors[np.nonzero(neighbors != seq)[0]]

    elif len(seq) == 47:
        seq1, seq2 = seq.split('_')[0], seq.split('_')[1]
        neighbors = np.concatenate([
            np.array([seq1[:i] + base + seq1[i + 1:] for i in range(21) for base in bases]),
            np.array([seq2[:i] + base + seq2[i + 1:] for i in range(21) for base in bases]),
        ])
        neighbors = neighbors[np.nonzero((neighbors != seq1) & (neighbors != seq2))[0]]

    else:
        raise ValueError('permute expects a 23- or 47-character sequence, got length %d' % len(seq))
    return neighbors

# Penalty tables consumed by calc_cfd().
# mm_scores: key 'r<guide base, T->U>:d<reverse-complemented off-target base>,<1-based position>'
#   -> multiplicative factor applied for that single mismatch (all values lie in [0, 1]).
# NOTE(review): these appear to be the published CFD mismatch/PAM matrices
# (Doench et al. 2016) -- confirm the source before editing any value.
mm_scores={'rU:dT,12': 0.8, 'rU:dT,13': 0.692307692, 'rU:dC,5': 0.64, 'rG:dA,14': 0.26666666699999997, 'rG:dG,19': 0.448275862, 'rG:dG,18': 0.47619047600000003, 'rG:dG,15': 0.272727273, 'rG:dG,14': 0.428571429, 'rG:dG,17': 0.235294118, 'rG:dG,16': 0.0, 'rC:dC,20': 0.058823529000000006, 'rG:dT,20': 0.9375, 'rG:dG,13': 0.42105263200000004, 'rG:dG,12': 0.529411765, 'rU:dC,6': 0.571428571, 'rU:dG,14': 0.28571428600000004, 'rU:dT,18': 0.666666667, 'rA:dG,13': 0.21052631600000002, 'rA:dG,12': 0.263157895, 'rA:dG,11': 0.4, 'rA:dG,10': 0.333333333, 'rA:dA,19': 0.538461538, 'rA:dA,18': 0.5, 'rA:dG,15': 0.272727273, 'rA:dG,14': 0.214285714, 'rA:dA,15': 0.2, 'rA:dA,14': 0.533333333, 'rA:dA,17': 0.133333333, 'rA:dA,16': 0.0, 'rA:dA,11': 0.307692308, 'rA:dA,10': 0.882352941, 'rA:dA,13': 0.3, 'rA:dA,12': 0.333333333, 'rG:dA,13': 0.3, 'rG:dA,12': 0.384615385, 'rG:dA,11': 0.384615385, 'rG:dA,10': 0.8125, 'rG:dA,17': 0.25, 'rG:dA,16': 0.0, 'rG:dA,15': 0.14285714300000002, 'rG:dA,6': 0.666666667, 'rG:dG,20': 0.428571429, 'rG:dA,19': 0.666666667, 'rG:dA,18': 0.666666667, 'rU:dC,4': 0.625, 'rG:dT,12': 0.933333333, 'rG:dT,13': 0.923076923, 'rU:dG,11': 0.666666667, 'rC:dA,3': 0.6875, 'rC:dA,2': 0.9090909090000001, 'rC:dA,1': 1.0, 'rC:dA,7': 0.8125, 'rC:dA,6': 0.9285714290000001, 'rC:dA,5': 0.636363636, 'rC:dA,4': 0.8, 'rC:dA,9': 0.875, 'rC:dA,8': 0.875, 'rU:dT,6': 0.8666666670000001, 'rA:dG,20': 0.22727272699999998, 'rG:dT,18': 0.692307692, 'rU:dG,10': 0.533333333, 'rG:dT,19': 0.7142857140000001, 'rG:dA,20': 0.7, 'rC:dT,20': 0.5, 'rU:dC,2': 0.84, 'rG:dG,10': 0.4, 'rC:dA,17': 0.46666666700000003, 'rC:dA,16': 0.307692308, 'rC:dA,15': 0.066666667, 'rC:dA,14': 0.7333333329999999, 'rC:dA,13': 0.7, 'rC:dA,12': 0.538461538, 'rC:dA,11': 0.307692308, 'rC:dA,10': 0.9411764709999999, 'rG:dG,11': 0.428571429, 'rU:dC,20': 0.176470588, 'rG:dG,3': 0.384615385, 'rC:dA,19': 0.46153846200000004, 'rC:dA,18': 0.642857143, 'rU:dG,17': 0.705882353, 'rU:dG,16': 0.666666667, 'rU:dG,15': 0.272727273, 'rG:dG,2': 
0.692307692, 'rU:dG,13': 0.7894736840000001, 'rU:dG,12': 0.947368421, 'rG:dA,9': 0.533333333, 'rG:dA,8': 0.625, 'rG:dA,7': 0.571428571, 'rG:dG,5': 0.7857142859999999, 'rG:dA,5': 0.3, 'rG:dA,4': 0.363636364, 'rG:dA,3': 0.5, 'rG:dA,2': 0.636363636, 'rG:dA,1': 1.0, 'rG:dG,4': 0.529411765, 'rG:dG,1': 0.7142857140000001, 'rA:dC,9': 0.666666667, 'rG:dG,7': 0.6875, 'rG:dT,5': 0.8666666670000001, 'rU:dT,20': 0.5625, 'rC:dC,15': 0.05, 'rC:dC,14': 0.0, 'rC:dC,17': 0.058823529000000006, 'rC:dC,16': 0.153846154, 'rC:dC,11': 0.25, 'rC:dC,10': 0.38888888899999996, 'rC:dC,13': 0.13636363599999998, 'rC:dC,12': 0.444444444, 'rC:dA,20': 0.3, 'rC:dC,19': 0.125, 'rC:dC,18': 0.133333333, 'rA:dA,1': 1.0, 'rA:dA,3': 0.705882353, 'rA:dA,2': 0.727272727, 'rA:dA,5': 0.363636364, 'rA:dA,4': 0.636363636, 'rA:dA,7': 0.4375, 'rA:dA,6': 0.7142857140000001, 'rA:dA,9': 0.6, 'rA:dA,8': 0.428571429, 'rU:dG,20': 0.090909091, 'rC:dC,9': 0.6190476189999999, 'rC:dC,8': 0.642857143, 'rU:dT,10': 0.857142857, 'rU:dT,11': 0.75, 'rU:dT,16': 0.9090909090000001, 'rU:dT,17': 0.533333333, 'rU:dT,14': 0.6190476189999999, 'rU:dT,15': 0.578947368, 'rC:dC,1': 0.913043478, 'rU:dT,3': 0.7142857140000001, 'rC:dC,3': 0.5, 'rC:dC,2': 0.695652174, 'rC:dC,5': 0.6, 'rC:dC,4': 0.5, 'rC:dC,7': 0.470588235, 'rC:dC,6': 0.5, 'rU:dT,4': 0.47619047600000003, 'rU:dT,8': 0.8, 'rU:dT,9': 0.9285714290000001, 'rA:dC,19': 0.375, 'rA:dC,18': 0.4, 'rA:dC,17': 0.176470588, 'rA:dC,16': 0.192307692, 'rA:dC,15': 0.65, 'rA:dC,14': 0.46666666700000003, 'rA:dC,13': 0.6521739129999999, 'rA:dC,12': 0.7222222220000001, 'rA:dC,11': 0.65, 'rA:dC,10': 0.5555555560000001, 'rU:dC,7': 0.588235294, 'rC:dT,8': 0.65, 'rC:dT,9': 0.857142857, 'rC:dT,6': 0.9285714290000001, 'rC:dT,7': 0.75, 'rC:dT,4': 0.842105263, 'rC:dT,5': 0.571428571, 'rC:dT,2': 0.727272727, 'rC:dT,3': 0.8666666670000001, 'rC:dT,1': 1.0, 'rA:dC,8': 0.7333333329999999, 'rU:dT,1': 1.0, 'rU:dC,3': 0.5, 'rU:dC,1': 0.956521739, 'rU:dT,2': 0.846153846, 'rU:dG,19': 0.275862069, 'rG:dT,14': 0.75, 
'rG:dT,15': 0.9411764709999999, 'rG:dT,16': 1.0, 'rG:dT,17': 0.933333333, 'rG:dT,10': 0.933333333, 'rG:dT,11': 1.0, 'rA:dG,9': 0.571428571, 'rA:dG,8': 0.428571429, 'rA:dG,7': 0.4375, 'rA:dG,6': 0.454545455, 'rA:dG,5': 0.5, 'rA:dG,4': 0.352941176, 'rA:dG,3': 0.428571429, 'rA:dG,2': 0.7857142859999999, 'rA:dG,1': 0.857142857, 'rU:dT,5': 0.5, 'rG:dT,2': 0.846153846, 'rA:dC,3': 0.611111111, 'rA:dC,20': 0.764705882, 'rG:dT,1': 0.9, 'rG:dT,6': 1.0, 'rG:dT,7': 1.0, 'rG:dT,4': 0.9, 'rC:dT,19': 0.428571429, 'rG:dG,9': 0.538461538, 'rG:dG,8': 0.615384615, 'rG:dT,8': 1.0, 'rG:dT,9': 0.642857143, 'rU:dG,18': 0.428571429, 'rU:dT,7': 0.875, 'rG:dG,6': 0.681818182, 'rA:dA,20': 0.6, 'rU:dC,9': 0.6190476189999999, 'rA:dG,17': 0.176470588, 'rU:dC,8': 0.7333333329999999, 'rA:dG,16': 0.0, 'rA:dG,19': 0.20689655199999998, 'rG:dT,3': 0.75, 'rU:dG,3': 0.428571429, 'rU:dG,2': 0.857142857, 'rU:dG,1': 0.857142857, 'rA:dG,18': 0.19047619, 'rU:dG,7': 0.6875, 'rU:dG,6': 0.9090909090000001, 'rU:dG,5': 1.0, 'rU:dG,4': 0.647058824, 'rU:dG,9': 0.923076923, 'rU:dG,8': 1.0, 'rU:dC,19': 0.25, 'rU:dC,18': 0.333333333, 'rU:dC,13': 0.260869565, 'rU:dC,12': 0.5, 'rU:dC,11': 0.4, 'rU:dC,10': 0.5, 'rU:dC,17': 0.117647059, 'rU:dC,16': 0.346153846, 'rU:dC,15': 0.05, 'rU:dC,14': 0.0, 'rC:dT,10': 0.8666666670000001, 'rC:dT,11': 0.75, 'rC:dT,12': 0.7142857140000001, 'rC:dT,13': 0.384615385, 'rC:dT,14': 0.35, 'rC:dT,15': 0.222222222, 'rC:dT,16': 1.0, 'rC:dT,17': 0.46666666700000003, 'rC:dT,18': 0.538461538, 'rA:dC,2': 0.8, 'rA:dC,1': 1.0, 'rA:dC,7': 0.705882353, 'rA:dC,6': 0.7142857140000001, 'rA:dC,5': 0.72, 'rA:dC,4': 0.625, 'rU:dT,19': 0.28571428600000004}
# pam_scores: key = the two PAM nucleotides after the N position (off-target seq[21:])
#   -> multiplicative factor applied once per off-target.
pam_scores={'AA': 0.0, 'AC': 0.0, 'GT': 0.016129031999999998, 'AG': 0.25925925899999996, 'CC': 0.0, 'CA': 0.0, 'CG': 0.107142857, 'TT': 0.0, 'GG': 1.0, 'GC': 0.022222222000000003, 'AT': 0.0, 'GA': 0.06944444400000001, 'TG': 0.038961038999999996, 'TA': 0.0, 'TC': 0.0, 'CT': 0.0}

def reverse_complement(seq):
    """Return the reverse complement of a nucleotide string.

    Bases are complemented A<->T and C<->G; 'U' complements to 'A' and the gap
    character '-' is kept as-is. Any other character raises KeyError.
    """
    complement = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'U': 'A', '-': '-'}
    # Complement left-to-right (so an unknown base fails at the same place as a
    # plain scan would), then reverse the whole string.
    complemented = ''.join([complement[base] for base in seq])
    return complemented[::-1]

def calc_cfd(wt, sg, pam):
    """Cutting Frequency Determination (CFD) score, scaled to 0-100.

    In this notebook it is called with the guide spacer as ``wt`` and the target
    site's first 20 nt as ``sg``.

    Args:
        wt: reference (guide) sequence; T is treated as U. Must be at least as long
            as ``sg`` (it is indexed at every position of ``sg``).
        sg: sequence compared against ``wt``; T is treated as U.
        pam: two-base key looked up in the module-level ``pam_scores`` table.

    Returns:
        100 * (product of mm_scores penalties at every mismatched position)
            * pam_scores[pam].

    Raises:
        KeyError: a mismatch or PAM key is missing from the score tables.
        IndexError: ``wt`` is shorter than ``sg``.
    """
    guide = wt.replace('T', 'U')
    target = sg.replace('T', 'U')

    score = 1
    for offset, target_base in enumerate(target):
        guide_base = guide[offset]
        if guide_base == target_base:
            continue
        # Table keys look like 'rU:dG,13': RNA base, complementary DNA base, 1-based position.
        mismatch_key = f'r{guide_base}:d{reverse_complement(target_base)},{offset + 1}'
        score *= mm_scores[mismatch_key]

    score *= pam_scores[pam]
    return 100 * score

# Pairwise evaluation setup.
# For each test sequence: (1) compute its CFD score against its site's guide, and
# (2) look up the sequences returned by `permute` (defined elsewhere; presumably the
# sequence's one-edit neighbours relative to the guide -- confirm) that are also in
# the test split for the same site, recording one (row, col) index pair per match.
assert len(seq_test) == len(source_test), 'seq_test and source_test must be index-aligned'

data = []     # +1 if y_test[row] > y_test[col], else -1 (ties map to -1)
row_ind = []  # test-split index of the reference sequence
col_ind = []  # test-split index of the neighbouring sequence
# sequence -> test-split index; if a sequence occurs more than once, the last index wins.
seq_to_ind_dict = {s: j for j, s in enumerate(seq_test)}

# CFD baseline (0-100): the first 20 nt are compared with the guide, and characters from
# index 21 on are the PAM bases scored by pam_scores (index 20, the PAM's N position, is skipped).
CFD_list = np.array([
    calc_cfd(gRNA[site][:20], target[:20], target[21:])
    for target, site in zip(seq_test, source_test)
])

for i in tqdm(range(len(seq_test)), desc='neighbour pairs'):
    site = source_test[i]
    neighbors = permute(seq_test[i], gRNA[site])
    # Explicit integer dtype: np.array([]) would otherwise be float64, and indexing
    # source_test with an empty float array raises IndexError.
    neighbors_ind = np.array([seq_to_ind_dict.get(ngb, -1) for ngb in neighbors], dtype=np.intp)
    neighbors_ind = neighbors_ind[neighbors_ind != -1]                 # drop neighbours absent from the test split
    neighbors_ind = neighbors_ind[source_test[neighbors_ind] == site]  # keep same-site neighbours only

    self_ind = [i] * len(neighbors_ind)
    indicator = np.where(y_test[self_ind] - y_test[neighbors_ind] > 0, 1, -1)

    row_ind.extend(self_ind)
    col_ind.extend(neighbors_ind)
    data.extend(indicator.tolist())


0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000


In [None]:
# Pairwise (neighbour) evaluation: for every (reference, neighbour) pair built above,
# check whether the model orders the two sequences the same way the measurements do,
# then export a per-pair table annotated with the first differing position and mutation class.
print(len(row_ind))

row_arr = np.asarray(row_ind, dtype=np.intp)
col_arr = np.asarray(col_ind, dtype=np.intp)
mis_test = mismatches[test_ind]

# Metric convention: differences are (reference - neighbour).
true_diff = true_list[row_arr] - true_list[col_arr]
pred_diff = pred_list[row_arr] - pred_list[col_arr]
sign = true_diff * pred_diff
# Pairs whose predicted ordering matches the measured one; a zero difference on
# either side counts as disagreement.
same_sign = np.nonzero(sign > 0)[0]
print(len(sign), len(same_sign))

pearsonr = np.corrcoef(true_diff, pred_diff)[0, 1]
spearmanr = stats.spearmanr(true_diff, pred_diff).statistic
acc = len(same_sign) / len(sign) if len(sign) else float('nan')  # NaN instead of ZeroDivisionError when there are no pairs
print(pearsonr, spearmanr, acc)

# Mutation classes: reference base -> neighbour base(s) at the first differing position.
# All multi-target entries are lists so `in` is a membership test, not a substring test.
wobble_mutations = {'G': 'A', 'T': 'C'}
wobblelike_mutations = {'A': 'G', 'C': 'T'}
transitions = {'A': 'G', 'G': 'A', 'C': 'T', 'T': 'C'}
transversions = {'A': ['T', 'C'], 'C': ['A', 'G'], 'G': ['C', 'T'], 'T': ['A', 'G']}
hoogsteen_mutations = {'A': ['C'], 'G': ['T', 'C']}


def _first_mismatch(seq_a, seq_b, limit=20):
    """Return the 0-based index of the first position < limit where the sequences differ, else limit."""
    for p in range(limit):
        if seq_a[p] != seq_b[p]:
            return p
    return limit


pos_list = []
wobble_list = []
wobblelike_list = []
transition_list = []
transversion_list = []
hoogsteen_list = []

for r, c in zip(row_ind, col_ind):
    seq_a, seq_b = seq_test[r], seq_test[c]
    pos = _first_mismatch(seq_a, seq_b)
    pos_list.append(pos)
    b1, b2 = seq_a[pos], seq_b[pos]
    wobble_list.append(wobble_mutations.get(b1) == b2)
    wobblelike_list.append(wobblelike_mutations.get(b1) == b2)
    transition_list.append(transitions.get(b1) == b2)
    transversion_list.append(b2 in transversions.get(b1, []))
    hoogsteen_list.append(b2 in hoogsteen_mutations.get(b1, []))

# NOTE: the table's delta_* columns use the opposite convention to the metrics above:
# (neighbour - reference), i.e. *_2 - *_1.
df = pd.DataFrame({
    'seq_1': seq_test[row_arr],
    'seq_2': seq_test[col_arr],
    'target': source_test[row_arr],
    'MM_1': mis_test[row_arr],
    'MM_2': mis_test[col_arr],
    'log2FC_1': true_list[row_arr],
    'log2FC_2': true_list[col_arr],
    'predicted_log2FC_1': pred_list[row_arr],
    'predicted_log2FC_2': pred_list[col_arr],
    'CFD_1': CFD_list[row_arr],
    'CFD_2': CFD_list[col_arr],
    'delta_log2FC': true_list[col_arr] - true_list[row_arr],
    'delta_predicted_log2FC': pred_list[col_arr] - pred_list[row_arr],
    'delta_CFD': CFD_list[col_arr] - CFD_list[row_arr],
    'position': pos_list,  # 0-based; 20 means no difference within the first 20 nt
    'wobble_mutations': wobble_list,
    'wobblelike_mutations': wobblelike_list,
    'transitions': transition_list,
    'transversions': transversion_list,
    'hoogsteen_mutations': hoogsteen_list,
})
os.makedirs('results', exist_ok=True)  # to_csv does not create missing directories
df.to_csv('results/pair_table_log21p_activity_filtering%d_test.csv' % filtering)



27701
27701 22838
0.8544172436377268 0.7750522730142403 0.8244467708746976
