In [1]:
import pandas as pd
import numpy as np
import torch
import os
import random
import math
import re
from gensim.models import doc2vec
import os.path as osp

import torch.nn.functional as F
from torch import nn,optim
from sklearn.metrics import roc_auc_score
import warnings

import os
import pickle

import gensim
from gensim.models.doc2vec import Doc2Vec
import sys
sys.path.append("..")
from embeddings_reproduction import embedding_tools


# Suppress every warning raised after this point.
warnings.filterwarnings('ignore')
# Directory prefix prepended to the dataset file names opened by load_dataset.
file_dir='data/'
# Dataset identifier: combined with file_dir to build the .fa file names and
# used as the prefix of the saved checkpoint file name.
task='C17ORF85_Baltz2012'

In [2]:
def seq_to_kmers(seq, k=3, overlap=False, **kwargs):
    """Split a sequence into lists of k-mers.

    Parameters:
        seq: sliceable sequence (e.g. a string).
        k (int): k-mer length. Default 3.
        overlap (bool): when True, return a single list holding every
            k-mer obtained by sliding one position at a time. When False,
            return k lists, one per starting offset 0..k-1, each holding
            the non-overlapping k-mers read from that offset.
        **kwargs: accepted and ignored.

    Returns:
        list of lists of k-mers. Only full-length k-mers are produced.
    """
    # Exclusive upper bound on valid k-mer start positions.
    start_limit = len(seq) - k + 1
    if not overlap:
        frames = []
        for offset in range(k):
            frames.append([seq[pos:pos + k]
                           for pos in range(offset, start_limit, k)])
        return frames
    return [[seq[pos:pos + k] for pos in range(start_limit)]]


In [3]:
# Doc2Vec models already loaded in this session, keyed by file path, so the
# model is read from disk once instead of once per sequence.
_DOC2VEC_MODELS = {}


def get_embeddings(doc2vec_file,seq, k=5, overlap=False, norm=True, steps=5):
    """Infer the embedding of one sequence with a saved gensim doc2vec model.

    Parameters:
        doc2vec_file (str): path of the saved doc2vec model.
        seq: the sequence to embed; it is split with seq_to_kmers.
        k (int): k-mer length passed to seq_to_kmers. Default 5.
        overlap (Boolean): passed to seq_to_kmers. Default False.
        norm (Boolean): accepted for compatibility; not used. Default True.
        steps (int): number of steps during inference. Default 5.

    Returns:
        The vector returned by the model's infer_vector for this sequence.

    Notes:
        Only the first k-mer list produced by seq_to_kmers is embedded
        (with overlap=False that is the list starting at offset 0).
        NOTE(review): the loaded model object is reused across calls; if the
        model keeps internal random state, inferred vectors may differ
        slightly from a freshly loaded model -- confirm if exact
        reproducibility matters.
        NOTE(review): newer gensim releases may name the `steps` argument of
        infer_vector `epochs` -- confirm against the installed version.
    """
    first_frame = seq_to_kmers(seq, k=k, overlap=overlap)[0]

    model = _DOC2VEC_MODELS.get(doc2vec_file)
    if model is None:
        model = doc2vec.Doc2Vec.load(doc2vec_file)
        _DOC2VEC_MODELS[doc2vec_file] = model

    return model.infer_vector(first_frame, steps=steps)


In [None]:

# Alphabet used for k-mer indexing and the numeric code of each base.
bases = list('ACGU')
base_dict = {base: code for code, base in enumerate(bases)}
bases_len = len(bases)
# Feature count; not referenced elsewhere in the visible code.
num_feature = 271


def convert_to_index(str, word_len):
    """Map the first word_len bases of a word to its base-4 integer index.

    The word is read as a number whose digits are the base_dict codes,
    most significant digit first, in radix bases_len.

    Parameters:
        str: word to encode; every one of its first word_len characters
            must be a key of base_dict.
        word_len (int): number of leading characters to encode.

    Returns:
        int in range(bases_len ** word_len); 0 when word_len is 0.

    Raises:
        KeyError: a character is not in base_dict.
        IndexError: the word is shorter than word_len.
    """
    return sum(
        base_dict[str[pos]] * bases_len ** (word_len - 1 - pos)
        for pos in range(word_len)
    )

def _count_kmers(seq, word_len):
    """Count each overlapping word of length word_len in seq, binned by convert_to_index."""
    counts = [0] * (bases_len ** word_len)
    for start in range(len(seq) - word_len + 1):
        counts[convert_to_index(seq[start:start + word_len], word_len)] += 1
    return counts


def _count_gapped_pairs(seq, gap):
    """Count base pairs (seq[i], seq[i + gap]) for every valid i, binned as 2-letter words."""
    counts = [0] * (bases_len ** 2)
    for start in range(len(seq) - gap):
        counts[convert_to_index(seq[start] + seq[start + gap], 2)] += 1
    return counts


def extract_features(line, doc2vec_file='outputs/docvec_models/0_virus_5_6.pkl'):
    """Build the feature vector of one sequence line.

    Two variants of the sequence are used:
      * full: the whole line, upper-cased, right-stripped, 'T' -> 'U';
      * core: the line with the characters a, g, c, t, u and newline
        removed, then 'T' -> 'U' (presumably the upper-case region of
        interest -- confirm against the data format).

    Feature order:
      1. doc2vec embedding of core (get_embeddings);
      2. for word length 1, 2, 3: word counts of full, then of core;
      3. for gap 4, 5, 6: gapped base-pair counts of full, then of core;
      4. len(core), log(len(core)), 1 if len(core) is a multiple of 3 else 0;
      5. one 0/1 flag per stop codon (UAG, UAA, UGA) telling whether core
         has a match starting at an index divisible by 3, followed by
         the OR of the three flags.

    Parameters:
        line (str): sequence line; after the transformations above it must
            contain only A, C, G, U (convert_to_index raises KeyError
            otherwise).
        doc2vec_file (str): path of the doc2vec model used for the
            embedding. Defaults to the path previously hard-coded here.

    Returns:
        list of numbers (the concatenated features).

    Raises:
        ValueError: from math.log when core is empty.
    """
    # Remove the lower-case bases before upper-casing, while case still
    # distinguishes them.
    core = line
    for ch in 'agctu\n':
        core = core.replace(ch, '')
    core = core.replace('T', 'U')
    full = line.upper().rstrip().replace('T', 'U')

    features = []
    features.extend(get_embeddings(doc2vec_file, core).tolist())

    for word_len in [1, 2, 3]:
        features.extend(_count_kmers(full, word_len))
        features.extend(_count_kmers(core, word_len))
    for gap in [4, 5, 6]:
        features.extend(_count_gapped_pairs(full, gap))
        features.extend(_count_gapped_pairs(core, gap))

    core_len = len(core)
    features.append(core_len)
    features.append(math.log(core_len))
    features.append(int(core_len % 3 == 0))

    stop_flags = []
    for codon in ['UAG', 'UAA', 'UGA']:
        in_frame = any(m.start() % 3 == 0 for m in re.finditer(codon, core))
        stop_flags.append(int(in_frame))
    # Fourth flag: at least one in-frame stop codon of any kind.
    stop_flags.append(int(any(stop_flags)))
    features.extend(stop_flags)
    return features


def _read_sequences(filename):
    """Return the usable sequence lines of a FASTA-style file.

    Lines starting with '>' are skipped, as are lines containing 'n' or 'N'
    and lines that are empty once the line terminator is removed.
    """
    sequences = []
    with open(filename, "r") as handle:
        for line in handle:
            if line[0] == '>' or 'n' in line or 'N' in line:
                continue
            seq = line.strip('\n').strip('\r')
            if seq:
                sequences.append(seq)
    return sequences


def load_dataset(task,is_load):
    """Read the four .fa files of a task and turn them into feature arrays.

    Files read (all under file_dir, prefixed with task):
    '.train.positives.fa', '.train.negatives.fa', '.ls.positives.fa' and
    '.ls.negatives.fa'. Every kept sequence is converted with
    extract_features. Positives are labelled 1.0 and negatives 0.0.

    One tenth (rounded down) of each training file is drawn at random, via
    the `random` module, as the validation split. The draw is made
    separately for positives and negatives, and positions are 0-based so
    every sequence can be selected. The '.ls.' files form the test split.

    Parameters:
        task (str): dataset name used to build the file names.
        is_load: accepted for compatibility; not used.

    Returns:
        (x_train, y_train, x_valid, y_valid, x_test, y_test) as numpy
        arrays; within each split positives come before negatives.
    """
    x_train, y_train = [], []
    x_valid, y_valid = [], []
    x_test, y_test = [], []
    labelled_suffixes = (('positives', 1.0), ('negatives', 0.0))

    for kind, label in labelled_suffixes:
        sequences = _read_sequences(file_dir + task + '.train.' + kind + '.fa')
        # A set keeps the per-sequence membership test O(1).
        valid_positions = set(
            random.sample(range(len(sequences)), len(sequences) // 10))
        for position, seq in enumerate(sequences):
            features = extract_features(seq)
            if position in valid_positions:
                x_valid.append(features)
                y_valid.append(label)
            else:
                x_train.append(features)
                y_train.append(label)

    for kind, label in labelled_suffixes:
        for seq in _read_sequences(file_dir + task + '.ls.' + kind + '.fa'):
            x_test.append(extract_features(seq))
            y_test.append(label)

    return (np.array(x_train), np.array(y_train),
            np.array(x_valid), np.array(y_valid),
            np.array(x_test), np.array(y_test))

In [None]:
# Build the feature matrices and label vectors of the three splits.
(x_train, y_train,
 x_valid, y_valid,
 x_test, y_test) = load_dataset(task, is_load=False)

# Report the shape of every array, one per line: train, valid, then test,
# features before labels.
print(*(split.shape for split in (x_train, y_train,
                                  x_valid, y_valid,
                                  x_test, y_test)),
      sep='\n')

# Features become float tensors as-is.
x_train = torch.from_numpy(x_train).float()
x_valid = torch.from_numpy(x_valid).float()
x_test = torch.from_numpy(x_test).float()

# Labels become float tensors with a trailing dimension of size 1.
y_train = torch.from_numpy(y_train).float().unsqueeze(1)
y_valid = torch.from_numpy(y_valid).float().unsqueeze(1)
y_test = torch.from_numpy(y_test).float().unsqueeze(1)

In [None]:
# Optimiser step size.
learning_rate=1e-4
# Input width of the network.
# NOTE(review): must equal the length of the vectors built by
# extract_features (x_train.shape[1]) -- confirm when the embedding size or
# the feature set changes.
feature_num=335
class nnNet(nn.Module):
    """Small feed-forward classifier that outputs a probability per sample.

    Parameters:
        in_dim (int): number of input features.
        n_hidden_1 (int): width of the hidden representation.
        out_dim (int): number of outputs.
    """

    def __init__(self, in_dim, n_hidden_1, out_dim):
        super(nnNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1))
        # NOTE(review): layer2 is created but never used in forward(); it is
        # kept so the parameter set / state_dict keys stay unchanged.
        self.layer2=nn.Sequential(nn.Linear(n_hidden_1, n_hidden_1))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_1, out_dim))
 
    def forward(self, x):
        """Apply layer1 then layer3 and squash the result with a sigmoid."""
        x = self.layer1(x)
        x = self.layer3(x)
        return torch.sigmoid(x)
model=nnNet(feature_num,64,1)
# forward() already applies the sigmoid, so the loss must take probabilities.
# BCEWithLogitsLoss would apply a second sigmoid to the model output.
criterion =nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=10 ** (-5.0))

# Number of full-batch training iterations.
num_epoch=5000


In [None]:
# Best validation AUC seen so far and the epoch at which it occurred.
# Starting below any possible AUC guarantees the first epoch writes a
# checkpoint, so the load after the loop always finds a file.
best=float('-inf')
wh=0
checkpoint_path=task+'local_best_model.pth'
for epoch in range(num_epoch):
    # Full-batch training step.
    out=model(x_train)
    loss=criterion(out,y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        out_valid=model(x_valid)
    valid_auc=roc_auc_score(y_valid,out_valid.numpy())

    # Keep the weights of the epoch with the highest validation AUC.
    if(valid_auc>best):
        best=valid_auc
        wh=epoch
        torch.save({'model_state_dict': model.state_dict()},checkpoint_path)
    if(epoch%10==0):
        # Training metrics are only reported, so compute them only when printed.
        print_loss=loss.item()
        train_auc=roc_auc_score(y_train,out.detach().numpy())
        print(print_loss,train_auc,valid_auc)

# Evaluate the best checkpoint on the test split.
model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
with torch.no_grad():
    out_test=model(x_test)
test_auc=roc_auc_score(y_test,out_test.numpy())
print(best,test_auc)