In [1]:
import torch
from itertools import product
import torch.nn.functional as F
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import numpy as np
import seaborn as sns
from collections import Counter
from src.bio_utils import seqlogo_from_msa
import logomaker as lm
from scipy.stats import mode
from math import ceil, floor, sqrt
from sklearn.metrics import r2_score
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import time
import RNA

In [2]:
def add_transcrib_eff(virus_name, date):
    """Subtract the per-UTR baseline expression from each variant's score.

    Reads ``./data/{virus_name}_{date}.csv`` (needs ``seq`` and ``score``), takes the
    first five characters of ``seq`` as the ``utr`` key, inner-joins it with the
    ``utr``/``pre_exp`` columns of ``../data/sequencing/tol_seq.csv`` and writes the
    joined table, with ``score_final = score - pre_exp``, to
    ``./data/{virus_name}_{date}_tol_seq.csv``. Rows whose UTR has no baseline are dropped.
    """
    base_path = f"./data/{virus_name}_{date}"
    samples = pd.read_csv(f"{base_path}.csv")
    baseline = pd.read_csv('../data/sequencing/tol_seq.csv', usecols=['utr', 'pre_exp'])
    samples['utr'] = samples.seq.str[:5]
    joined = samples.merge(baseline, how='inner', on='utr')
    joined['score_final'] = joined['score'] - joined['pre_exp']
    joined.to_csv(f"{base_path}_tol_seq.csv", index=None)

def seq2tensor(X):
    """One-hot encode a list of equal-length RNA strings.

    Channel order is A, C, G, U. Returns a float tensor of shape
    ``(len(X), seq_len, 4)``. Raises ``KeyError`` for any character outside ``ACGU``.
    """
    nt_to_idx = {nt: k for k, nt in enumerate("ACGU")}
    index_rows = [[nt_to_idx[ch] for ch in sequence] for sequence in X]
    indices = torch.tensor(index_rows).type(torch.int64)
    return F.one_hot(indices, num_classes=4).type(torch.float)

def struc2tensor(X):
    """One-hot encode a list of equal-length dot-bracket structure strings.

    Channel order is ``.``, ``(``, ``)``. Returns a float tensor of shape
    ``(len(X), seq_len, 3)``. Raises ``KeyError`` for any other character.
    """
    symbol_to_idx = {sym: k for k, sym in enumerate(".()")}
    index_rows = [[symbol_to_idx[ch] for ch in structure] for structure in X]
    indices = torch.tensor(index_rows).type(torch.int64)
    return F.one_hot(indices, num_classes=3).type(torch.float)

def seq2tsor_transformer(X):
    """Map a list of equal-length RNA strings to integer token ids.

    Token ids are A=0, C=1, G=2, U=3. Returns an int64 tensor of shape
    ``(len(X), seq_len)``. Raises ``KeyError`` for any character outside ``ACGU``.
    """
    nt_to_idx = {nt: k for k, nt in enumerate("ACGU")}
    token_rows = [[nt_to_idx[ch] for ch in sequence] for sequence in X]
    return torch.LongTensor(token_rows)

def random_split(virus_name, date, test_size=0.05, random_state=42):
    """Randomly split ``./data/{virus_name}_{date}_tol_seq.csv`` into train and test files.

    Writes ``./data/{virus_name}_{date}_train_tol_seq.csv`` and
    ``./data/{virus_name}_{date}_test_tol_seq.csv``.

    Args:
        virus_name: virus prefix used in the file names.
        date: date tag used in the file names.
        test_size: fraction (or absolute count) of rows put in the test file.
        random_state: seed for the split. Defaults to a fixed value (matching
            ``split_for_distribu``) so the split is reproducible; pass ``None`` for a
            fresh random split on every call.
    """
    merge_df = pd.read_csv(f'./data/{virus_name}_{date}_tol_seq.csv')
    train, test = train_test_split(merge_df, test_size=test_size, random_state=random_state)
    train.to_csv(f'./data/{virus_name}_{date}_train_tol_seq.csv', index=None)
    test.to_csv(f'./data/{virus_name}_{date}_test_tol_seq.csv', index=None)

def split_for_distribu(virus_name, date):
    """Hold out a distance-stratified test set and write train/test CSV files.

    Reads ``./data/{virus_name}_{date}_tol_seq.csv`` (needs ``ID``, ``distance`` and
    ``rna_counts`` columns) and sorts it by ``distance`` ascending, then ``rna_counts``
    descending. For each distance ``d`` in ``1..max(distance)`` it takes
    ``len(group) // 20`` rows at positions ``1, 20, 39, ...`` of that group as test rows.
    Rows with ``distance == 0`` are never put in the test set.

    The training set is every row whose ``ID`` occurs exactly once across input + test
    rows. Test IDs are therefore removed, and so is any ``ID`` already duplicated in the
    input. The training set is shuffled with a fixed seed.

    Writes ``./data/{virus_name}_{date}_train_tol_seq.csv`` and
    ``./data/{virus_name}_{date}_test_tol_seq.csv``.
    """
    data1 = pd.read_csv(f'./data/{virus_name}_{date}_tol_seq.csv')
    data1 = data1.sort_values(by=['distance', 'rna_counts'], ascending=(True, False))

    # Collect row labels and select once, instead of growing a frame with pd.concat inside
    # the loop (quadratic, and it casts every column to object dtype).
    test_labels = []
    for i in range(1, data1.distance.max() + 1):
        group = data1.loc[data1['distance'] == i]
        # NOTE(review): a stride of 19 with len // 20 samples looks like an off-by-one for
        # "every 20th row" (j * 20). It is kept as-is so existing splits stay reproducible.
        test_labels.extend(group.index[1 + j * 19] for j in range(len(group) // 20))
    test_data = data1.loc[test_labels]

    # keep=False drops every row whose ID appears more than once, i.e. all test IDs.
    train_data = pd.concat([data1, test_data]).drop_duplicates(subset=['ID'], keep=False)
    train_data = train_data.sample(frac=1, random_state=42).reset_index(drop=True)
    train_data.to_csv(f'./data/{virus_name}_{date}_train_tol_seq.csv', index=None)
    test_data.to_csv(f'./data/{virus_name}_{date}_test_tol_seq.csv', index=None)

def count_paras(model):
    """Return the total number of scalar elements in the trainable (requires_grad) parameters of ``model``."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(torch.numel, trainable))

# Candidate first-layer kernel sizes [k1, k2, k3], one grid per virus. Each grid keeps its
# own name; reusing `para_list` for all of them silently discarded the first two.
para_list_vee = [
    [lay1, lay2, lay3]
    for lay1 in range(3, 10, 2)
    for lay2 in range(3, 19, 4)
    for lay3 in range(3, 13, 2)
]
para_list_sfv = [
    [lay1, lay2, lay3]
    for lay1 in range(3, 23, 4)
    for lay2 in range(3, 23, 4)
    for lay3 in range(3, 23, 4)
]

# Sliding [start, start + 11] pairs. NOTE(review): 85 presumably is the sequence length
# and 11 the window width — confirm.
para_list = [[lay1, lay1 + 11] for lay1 in range(0, 85 - 11)]
len(para_list)

74

In [3]:
class MotifFind(nn.Module):
    """Three-window CNN mapping a one-hot sequence to a scalar log-score.

    Each of three sequence windows is scanned by its own single-filter ``Conv1d``
    (kernel sizes ``paras[0]``, ``paras[1]``, ``paras[2]``) and Softplus-activated. The
    outputs are right-padded with zeros to ``pad_len`` positions and stacked as 3
    channels. A second ``Conv1d`` (kernel 21, ``paras[3]`` filters) with Softplus
    follows. Each filter is global-max-pooled and the log of the summed maxima is
    returned.

    Args:
        paras: ``[k1, k2, k3, n_filters]``.
        windows: three ``(start, end)`` 0-based, end-exclusive slices along the sequence
            axis. Defaults to ``DEFAULT_WINDOWS``, the slices the saved models use.
        pad_len: length every first-layer output is padded to. Each window must be at
            most this long, because ``F.pad`` with a negative amount would silently crop.

    Raises:
        ValueError: if ``windows`` does not have three entries or a window exceeds ``pad_len``.
    """

    # NOTE(review): positions 10 and 30 fall between windows and are never seen by the
    # model — confirm this gap is intentional.
    DEFAULT_WINDOWS = ((0, 10), (11, 30), (31, 44))

    def __init__(self, paras, windows=None, pad_len=21):
        super().__init__()
        windows = self.DEFAULT_WINDOWS if windows is None else tuple(tuple(w) for w in windows)
        if len(windows) != 3:
            raise ValueError(f"expected 3 windows, got {len(windows)}")
        for start, end in windows:
            if end - start > pad_len:
                raise ValueError(f"window ({start}, {end}) is longer than pad_len={pad_len}")
        self.windows = windows
        self.pad_len = pad_len
        # Layers are created in the original order so seeded initialisation and
        # state_dict keys are unchanged.
        self.convs1_l1 = nn.Conv1d(4, 1, paras[0], padding='same')
        self.convs1_l2 = nn.Conv1d(4, 1, paras[1], padding='same')
        self.convs1_l3 = nn.Conv1d(4, 1, paras[2], padding='same')
        self.convs2_l = nn.Conv1d(3 * 1, paras[3], 21, padding='same')
        self.actia = nn.Softplus()

    def forward(self, x):
        """Score a batch.

        ``x`` is indexed as ``(batch, 4, positions)`` (the first conv has 4 input
        channels). Returns a ``(batch,)`` tensor.
        """
        first_layer = (self.convs1_l1, self.convs1_l2, self.convs1_l3)
        branch_outputs = []
        for conv, (start, end) in zip(first_layer, self.windows):
            activated = self.actia(conv(x[:, :, start:end]))
            # Right-pad to a common length so the three branches can be stacked as channels.
            branch_outputs.append(F.pad(activated, (0, self.pad_len - activated.size(2))))
        c1_l = torch.cat(branch_outputs, dim=1)
        c2_l = self.actia(self.convs2_l(c1_l))
        # Global max-pool over positions, one value per second-layer filter.
        cnn_out = F.max_pool1d(c2_l, int(c2_l.size(2))).squeeze(2)
        # Softplus outputs are positive, so the sum is > 0 and the log is finite.
        return torch.log(cnn_out.sum(axis=1))

In [7]:
num = 0
virus_name = "SFV"
seed = 42  # seeds weight init and the per-epoch shuffle so reruns are reproducible
# Candidate grids are in `para_list`; this cell trains a single chosen setting.
for paras in [[15, 15, 9, 9]]:
    num += 1
    time_start = time.time()
    torch.manual_seed(seed)
    model = MotifFind(paras=paras)
    model_name = 'Motifind'  # spelling is part of the saved file name read by the loading cells

    # Select columns by name: `usecols` keeps file order, so a positional rename could swap them.
    data = pd.read_csv(f"./data/{virus_name}_train.csv", usecols=['seq', 'score'])[['seq', 'score']]
    model.train()
    criterion = nn.MSELoss()
    batch_num = 32
    epochs = 50
    opt = torch.optim.Adam(model.parameters(), lr=5e-3)  # suitable for motiffind deepcnn
    epoch_loss = []
    epoch_r = []
    epoch_r2 = []
    for epoch in tqdm(range(epochs)):
        losses = []
        pearsonrs = []
        r2s = []
        # Reshuffle each epoch with a distinct but fixed seed.
        data_random = data.sample(n=len(data), random_state=seed + epoch).reset_index(drop=True)
        seq = data_random.seq.tolist()
        score = data_random.score.tolist()
        for i in range(0, len(seq), batch_num):
            Y = torch.tensor(score[i:i + batch_num], dtype=torch.float)
            X = seq2tensor(seq[i:i + batch_num]).permute(0, 2, 1)
            output = model(X)
            loss = criterion(output, Y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
            y_true = Y.numpy()
            y_pred = output.detach().numpy()
            # pearsonr needs at least two points; the final batch can hold a single sample.
            if len(y_true) >= 2:
                r, _ = pearsonr(y_pred, y_true)
                pearsonrs.append(r)
                r2s.append(r2_score(y_true, y_pred))
        epoch_loss.append(np.mean(losses))
        epoch_r.append(np.mean(pearsonrs))
        epoch_r2.append(np.mean(r2s))
    time_end = time.time()
    time_cost = time_end - time_start

    # Held-out evaluation.
    model.eval()
    test_data = pd.read_csv(f"./data/{virus_name}_test.csv", usecols=['seq', 'score'])[['seq', 'score']]
    test_seq = test_data.seq.tolist()
    test_score = test_data.score.tolist()
    Y = torch.tensor(test_score, dtype=torch.float)
    X = seq2tensor(test_seq).permute(0, 2, 1)
    with torch.no_grad():
        output = model(X)
        loss = criterion(output, Y)

    test_r, p_val = pearsonr(output.numpy(), Y.numpy())
    test_r2 = r2_score(Y.numpy(), output.numpy())

    # Saving of training curves and test metrics is disabled; to re-enable:
    # np.save(f"./model/{model_name}_{virus_name}_paras_{paras[3]}.npy",
    #         {"epoch_loss": epoch_loss, "epoch_r": epoch_r, 'time_cost': time_cost,
    #          'paras': paras, 'test_r': test_r, 'test_r2': test_r2})
    torch.save(model.state_dict(), f"./model/{model_name}_{virus_name}_paras_{paras[3]}")

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [01:20<00:00,  1.62s/it]


In [8]:
num = 0
virus_name = "VEE"
seed = 42  # seeds weight init and the per-epoch shuffle so reruns are reproducible
# Candidate grids are in `para_list`; this cell trains a single chosen setting.
for paras in [[15, 15, 15, 3]]:
    num += 1
    time_start = time.time()
    torch.manual_seed(seed)
    model = MotifFind(paras=paras)
    model_name = 'Motifind'  # spelling is part of the saved file name read by the loading cells

    # Select columns by name: `usecols` keeps file order, so a positional rename could swap them.
    data = pd.read_csv(f"./data/{virus_name}_train.csv", usecols=['seq', 'score'])[['seq', 'score']]
    model.train()
    criterion = nn.MSELoss()
    batch_num = 32
    epochs = 50
    opt = torch.optim.Adam(model.parameters(), lr=5e-3)  # suitable for motiffind deepcnn
    epoch_loss = []
    epoch_r = []
    epoch_r2 = []
    for epoch in tqdm(range(epochs)):
        losses = []
        pearsonrs = []
        r2s = []
        # Reshuffle each epoch with a distinct but fixed seed.
        data_random = data.sample(n=len(data), random_state=seed + epoch).reset_index(drop=True)
        seq = data_random.seq.tolist()
        score = data_random.score.tolist()
        for i in range(0, len(seq), batch_num):
            Y = torch.tensor(score[i:i + batch_num], dtype=torch.float)
            X = seq2tensor(seq[i:i + batch_num]).permute(0, 2, 1)
            output = model(X)
            loss = criterion(output, Y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
            y_true = Y.numpy()
            y_pred = output.detach().numpy()
            # pearsonr needs at least two points; the final batch can hold a single sample.
            if len(y_true) >= 2:
                r, _ = pearsonr(y_pred, y_true)
                pearsonrs.append(r)
                r2s.append(r2_score(y_true, y_pred))
        epoch_loss.append(np.mean(losses))
        epoch_r.append(np.mean(pearsonrs))
        epoch_r2.append(np.mean(r2s))
    time_end = time.time()
    time_cost = time_end - time_start

    # Held-out evaluation.
    model.eval()
    test_data = pd.read_csv(f"./data/{virus_name}_test.csv", usecols=['seq', 'score'])[['seq', 'score']]
    test_seq = test_data.seq.tolist()
    test_score = test_data.score.tolist()
    Y = torch.tensor(test_score, dtype=torch.float)
    X = seq2tensor(test_seq).permute(0, 2, 1)
    with torch.no_grad():
        output = model(X)
        loss = criterion(output, Y)

    test_r, p_val = pearsonr(output.numpy(), Y.numpy())
    test_r2 = r2_score(Y.numpy(), output.numpy())

    # Saving of training curves and test metrics is disabled; to re-enable:
    # np.save(f"./model/{model_name}_{virus_name}_paras_{paras[3]}.npy",
    #         {"epoch_loss": epoch_loss, "epoch_r": epoch_r, 'time_cost': time_cost,
    #          'paras': paras, 'test_r': test_r, 'test_r2': test_r2})
    torch.save(model.state_dict(), f"./model/{model_name}_{virus_name}_paras_{paras[3]}")

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [01:45<00:00,  2.11s/it]


In [11]:
# Reload the trained SFV model. The file name uses the 'Motifind' spelling it was saved
# with. torch.load unpickles, so only load trusted files.
model_name = 'MotifFind'
model = MotifFind(paras=[15, 15, 9, 9])
sfv_state = torch.load("./model/Motifind_SFV_paras_9")
model.load_state_dict(sfv_state)
model.eval()

MotifFind(
  (convs1_l1): Conv1d(4, 1, kernel_size=(15,), stride=(1,), padding=same)
  (convs1_l2): Conv1d(4, 1, kernel_size=(15,), stride=(1,), padding=same)
  (convs1_l3): Conv1d(4, 1, kernel_size=(9,), stride=(1,), padding=same)
  (convs2_l): Conv1d(3, 9, kernel_size=(21,), stride=(1,), padding=same)
  (actia): Softplus(beta=1.0, threshold=20.0)
)

In [13]:
# Reload the trained VEE model. The file name uses the 'Motifind' spelling it was saved
# with. torch.load unpickles, so only load trusted files.
model_name = 'MotifFind'
model = MotifFind(paras=[15, 15, 15, 3])
vee_state = torch.load("./model/Motifind_VEE_paras_3")
model.load_state_dict(vee_state)
model.eval()

MotifFind(
  (convs1_l1): Conv1d(4, 1, kernel_size=(15,), stride=(1,), padding=same)
  (convs1_l2): Conv1d(4, 1, kernel_size=(15,), stride=(1,), padding=same)
  (convs1_l3): Conv1d(4, 1, kernel_size=(15,), stride=(1,), padding=same)
  (convs2_l): Conv1d(3, 3, kernel_size=(21,), stride=(1,), padding=same)
  (actia): Softplus(beta=1.0, threshold=20.0)
)

In [None]:
model.eval()
# NOTE(review): `virus_name` and `model` come from earlier cells; make sure they refer to
# the same virus before running.
# Both splits are read from disk so this cell does not depend on `seq`/`score` left over
# from the training loop (Pearson r and R2 do not depend on row order).
train_df = pd.read_csv(f"./data/{virus_name}_train.csv", usecols=['seq', 'score'])[['seq', 'score']]
test_data = pd.read_csv(f"./data/{virus_name}_test.csv", usecols=['seq', 'score'])[['seq', 'score']]
test_seq = test_data.seq.tolist()
test_score = test_data.score.tolist()
criterion = nn.MSELoss()

splits = (
    ("train", train_df.seq.tolist(), train_df.score.tolist()),
    ("test", test_seq, test_score),
)
for split_name, split_seqs, split_scores in splits:
    Y = torch.tensor(split_scores, dtype=torch.float)
    X = seq2tensor(split_seqs).permute(0, 2, 1)
    with torch.no_grad():
        output = model(X)
        loss = criterion(output, Y)
    y_true = Y.numpy()
    y_pred = output.numpy()
    r, p_val = pearsonr(y_pred, y_true)
    r2 = r2_score(y_true, y_pred)

    _, ax = plt.subplots(1, 1, figsize=(4, 4))
    ax.scatter(y_true, y_pred, alpha=0.45, s=50, facecolor="gray")
    ax.set(title=f"{virus_name} {split_name}: r={r:.3f}", xlabel="target score", ylabel="predicted score")
    print(f"{split_name}: pearson r={r:.4f} (p={p_val:.3g}), R2={r2:.4f}, MSE={loss.item():.4f}")


In [15]:
model.eval()
# Reload the training split explicitly rather than reusing `seq`/`score` left over from the
# training loop. `seq`, `Y` and `layer_output` are consumed together by the motif-extraction
# cell and must stay row-aligned.
train_df = pd.read_csv(f"./data/{virus_name}_train.csv", usecols=['seq', 'score'])[['seq', 'score']]
seq = train_df.seq.tolist()
score = train_df.score.tolist()
Y = torch.tensor(score, dtype=torch.float)
X = seq2tensor(seq).permute(0, 2, 1)
# Input windows; these match the slices in MotifFind.forward (0-based, end-exclusive).
x1 = X[:, :, 0:10]
x2 = X[:, :, 11:30]
x3 = X[:, :, 31:44]
with torch.no_grad():
    # Window-l1 activations of convs1_l1 (1 output channel, 'same' padding over 10 positions).
    layer_output = model.actia(model.convs1_l1(x1)).numpy()

In [None]:
# Motif-extraction settings for the first input window ('l1').
# NOTE(review): filter_size (7) differs from the convs1_l1 kernel size used in training
# (paras[0] = 15) — confirm this is intended.
filter_size = 7
window_idx = 0  # offset added to activation positions; x1 starts at column 0
window = 'l1'
mutants_csv = (
    "../data/alphavirus/SFV_RandomMutants_Final.csv"
    if virus_name == "SFV"
    else "../data/alphavirus/VEE_RandomMutants_Final.csv"
)
df = pd.read_csv(mutants_csv)
def generate_positional_counts_mat(seqs, centers, filter_size):
    """Build a per-offset nucleotide background count matrix for motif logos.

    For each offset ``i`` in ``range(filter_size)`` and each distinct centre ``c`` in
    ``centers``, the characters at position ``c - filter_size // 2 + i`` are counted
    across *all* ``seqs``. Those counts are weighted by how often ``c`` occurs in
    ``centers``. Positions outside ``[0, len(seqs[0]))`` contribute a uniform count of 1
    per nucleotide, with the same weighting.

    Args:
        seqs: background sequences. Only ``len(seqs[0])`` is used as the length, so they
            are assumed to be equal length.
        centers: motif centre index for each extracted motif.
        filter_size: motif width.

    Returns:
        DataFrame with one row per offset and columns ``A, U, C, G``.

    Raises:
        KeyError: if a counted position holds a character other than A, U, C or G.
    """
    center_weights = Counter(centers)
    seq_len = len(seqs[0]) if seqs else 0
    half = filter_size // 2
    column_cache = {}

    def _column_counts(pos):
        # Character counts at `pos` across all background sequences, computed once per position.
        if pos not in column_cache:
            column_cache[pos] = Counter(s[pos] for s in seqs)
        return column_cache[pos]

    rows = []
    for i in range(filter_size):
        row = {"A": 0, "U": 0, "C": 0, "G": 0}
        for center, weight in center_weights.items():
            pos = center - half + i
            if 0 <= pos < seq_len:
                nt_counts = _column_counts(pos)
            else:
                nt_counts = {"A": 1, "U": 1, "C": 1, "G": 1}
            for nt, count in nt_counts.items():
                row[nt] += count * weight
        rows.append(row)
    return pd.DataFrame.from_records(rows)

# Background sequences for the positional background model of each logo.
bg_seqs = df["seq"].tolist()
# Motifs are cut from the sequences that produced `layer_output`; `seq`, `Y` and
# `layer_output` must be row-aligned.
seqs = seq
y_true = Y.detach().numpy()
# Only sequences whose peak activation exceeds the top-10% threshold contribute motifs.
# Set to False to cut a motif from every sequence (the threshold is then used only for
# the "too few activated" check).
use_top_activated_only = True
half = filter_size // 2
seq_len = len(seqs[0])

kernel_data = []
kernel_data_dict = []
for i in range(layer_output.shape[1]):
    curr_kernel_output = layer_output[:, i, :]
    # Peak-activation position per sequence, shifted by the window offset.
    max_activ_idx = np.argmax(curr_kernel_output, axis=1) + window_idx
    max_activation_vals = np.max(curr_kernel_output, axis=1)
    # `-n // 10` is `(-n) // 10`: index of roughly the 10th-percentile-from-top value.
    activ_thresh = max_activation_vals[
        np.argsort(max_activation_vals)[-len(max_activation_vals) // 10]
    ]
    activ_indices = np.where(max_activation_vals > activ_thresh)[0]

    # Abort if too few sequences were activated.
    if len(activ_indices) < 5:
        continue

    selected = activ_indices if use_top_activated_only else range(len(max_activ_idx))
    motifs = []
    centers = []
    for j in selected:
        center = max_activ_idx[j]
        start, end = center - half, center + half + 1
        motif = seqs[j][max(start, 0):end]
        if start < 0:
            # Window runs off the start of the sequence: left-pad with gaps.
            motif = "-" * (filter_size - len(motif)) + motif
        elif end > seq_len:
            # Window runs off the end of the sequence: right-pad with gaps.
            motif += "-" * (filter_size - len(motif))
        motifs.append(motif)
        centers.append(center)

    bg_counts_df = generate_positional_counts_mat(bg_seqs, centers, filter_size)
    counts_df = seqlogo_from_msa(motifs, bg_counts_mat=bg_counts_df)
    # Pearson correlation between this kernel's activation at each position and the score.
    kernel_scores = np.array(
        [pearsonr(curr_kernel_output[:, loc], y_true)[0] for loc in range(curr_kernel_output.shape[1])]
    )
    kernel_center = mode(centers)[0]
    kernel_data.append((counts_df, i, kernel_center, len(motifs), kernel_scores))
    kernel_data_dict.append(
        (bg_counts_df[['A', 'C', 'G', 'U']], i, kernel_center, len(motifs), kernel_scores)
    )

# Order kernels by their most common motif centre.
kernel_data.sort(key=lambda elem: elem[2])
kernel_data_dict.sort(key=lambda elem: elem[2])
n = len(kernel_data)

# NOTE: svg.fonttype only affects SVG output; EPS font handling is governed by ps.fonttype.
plt.rcParams['svg.fonttype'] = 'none'
if n == 0:
    print("No kernel passed the activation filter; no logo saved.")
for counts_df, idx, kernel_loc, n_motifs, kernel_scores in kernel_data:
    fig, ax = plt.subplots(1, 1, figsize=(filter_size / 2, 3))
    lm.Logo(counts_df, ax=ax, color_scheme="classic", font_name='No')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['left'].set_linewidth(2)
    ax.set_xticks([])
    ax.tick_params(width=2)
    # Name the file after the virus actually analysed; add the kernel index when several logos are saved.
    suffix = "" if n == 1 else f"_kernel{idx}"
    fig.savefig(f"./results/{virus_name}_a{suffix}.eps", format="eps")