In [1]:
import pandas as pd
import numpy as np
import torch
import os
import random
import math
import re
from gensim.models import doc2vec
import os.path as osp

import torch.nn.functional as F
from torch import nn,optim
from sklearn.metrics import roc_auc_score
import warnings

import os
import pickle

import gensim
from gensim.models.doc2vec import Doc2Vec
import sys
sys.path.append("..")
from embeddings_reproduction import embedding_tools


warnings.filterwarnings('ignore')
file_dir='data/'
task='C17ORF85_Baltz2012'

In [2]:
def seq_to_kmers(seq, k=3, overlap=False, **kwargs):
    N = len(seq)
    if overlap:
        return [[seq[i:i+k] for i in range(N - k + 1)]]
    else:
        return [[seq[i:i+k] for i in range(j, N - k + 1, k)]
                for j in range(k)]


In [3]:
def get_embeddings(doc2vec_file,seq, k=5, overlap=False, norm=True, steps=5):
    """ Infer embeddings in one pass using a gensim doc2vec model.

    Parameters:
        doc2vec_file (str): file pointing to saved doc2vec model
        seqs (iterable): sequences to infer
        k (int) default 3
        overlap (Boolean) default False
        norm (Boolean) default True
        steps (int): number of steps during inference. Default 5.

    Returns:
        numpy ndarray where each row is the embedding for one sequence.
    """
    
    as_kmer = []+seq_to_kmers(seq, k=k, overlap=overlap)
    as_kmer=as_kmer[0]
    model = doc2vec.Doc2Vec.load(doc2vec_file)
    #print(as_kmer[0])
    vector_ret=model.infer_vector(as_kmer, steps=steps)
    return vector_ret


In [9]:

bases = ['A', 'C', 'G', 'U']
base_dict = {'A': 0, 'C': 1, 'G': 2, 'U': 3}
bases_len = len(bases)
num_feature = 271 #total number of features
def convert_to_index(str,word_len):
    output_index = 0
    for i in range(word_len):
        output_index = output_index * bases_len + base_dict[str[i]]
    return output_index

def extract_features(line):
    line2=line
    for i in 'agctu\n':
        line2 = line2.replace(i, '')
    line = line.upper().rstrip()
    line = line.replace('T', 'U')
    line2 = line2.replace('T','U')
   #line = line.replace('N','')
   #line2 = line2.replace('N','')
    final_output=[]#get_embeddings('outputs/docvec_models/2_virus_5_6.pkl',line).tolist()
    final_output.extend(get_embeddings('outputs/docvec_models/0_virus_5_6.pkl',line2).tolist())
    for word_len in [1,2,3]:
        output_count_list = [0 for i in range(bases_len ** word_len)]
        for i in range(len(line) - word_len + 1):
            output_count_list[convert_to_index(line[i: i + word_len],word_len)] += 1
        output_count_list2 = [0 for i in range(bases_len ** word_len)]
        for i in range(len(line2)-word_len+1):
            output_count_list2[convert_to_index(line2[i:i+word_len],word_len)] +=1
        final_output.extend(output_count_list)
        final_output.extend(output_count_list2)
    for word_len in [4,5,6]:
        output_count_list = [0 for i in range(bases_len ** 2)]
        for i in range(len(line) - word_len):
            output_count_list[convert_to_index(line[i]+line[i + word_len],2)] += 1
        output_count_list2 = [0 for i in range(bases_len ** 2)]
        for i in range(len(line2)-word_len):
              output_count_list2[convert_to_index(line2[i]+line2[i+word_len],2)] +=1
        final_output.extend(output_count_list)
        final_output.extend(output_count_list2)
    final_output.append(len(line2))
    final_output.append(math.log(len(line2)))
    final_output.append(int(len(line2) % 3 == 0))
    stop_codons = ['UAG', 'UAA', 'UGA']
    stop_codon_features = [0,0,0,0]
    for stop_codon_num in range(len(stop_codons)):
        tmp_arr = [m.start() for m in re.finditer(stop_codons[stop_codon_num], line2)]
        tmp_arr_div3 = [i for i in tmp_arr if i % 3 == 0]
        stop_codon_features[stop_codon_num]=int(len(tmp_arr_div3) > 0)
        stop_codon_features[3]|=stop_codon_features[stop_codon_num]
    final_output.extend(stop_codon_features)
    return final_output


def load_dataset(task,is_load):
    x_train=[]
    x_valid=[]
    x_test=[]
    y_train=[]
    y_valid=[]
    y_test=[]
    N=0
    filename=file_dir+task+'.train.positives.fa'
    for line in open(filename, "r"):
        if line[0] == '>':
            continue
        else:
            if ('n' in line or 'N' in line):
                continue
            else:
                N+=1
    select_list = list(range(N))
    valid_set=random.sample(select_list,(int)(N/10))
    filename=file_dir+task+'.train.positives.fa'
    num=0
    for line in open(filename, "r"):
        if line[0] == '>':
            continue
        else:
            if ('n' in line or 'N' in line):
                continue
            else:
                num+=1
                if(num in valid_set):
                    x_valid.append(extract_features(line.strip('\n').strip('\r')))
                    y_valid.append(1.0)
                else:
                    x_train.append(extract_features(line.strip('\n').strip('\r')))
                    y_train.append(1.0)
                    
                
    filename=file_dir+task+'.train.negatives.fa'
    num=0
    for line in open(filename, "r"):
        if line[0] == '>':
            continue
        else:
            if ('n' in line or 'N' in line):
                continue
            else:
                num+=1
                if(num in valid_set):
                    x_valid.append(extract_features(line.strip('\n').strip('\r')))
                    y_valid.append(0)
                else:
                    x_train.append(extract_features(line.strip('\n').strip('\r')))
                    y_train.append(0)
                    
    filename=file_dir+task+'.ls.positives.fa'
    for line in open(filename, "r"):
        if line[0] == '>':
            continue
        else:
            if ('n' in line or 'N' in line):
                continue
            else:
                x_test.append(extract_features(line.strip('\n').strip('\r')))
                y_test.append(1.0)
    filename=file_dir+task+'.ls.negatives.fa'
    for line in open(filename, "r"):
        if line[0] == '>':
            continue
        else:
            if ('n' in line or 'N' in line):
                continue
            else:
                x_test.append(extract_features(line.strip('\n').strip('\r')))
                y_test.append(0)
    return np.array(x_train),np.array(y_train),np.array(x_valid),np.array(y_valid),np.array(x_test),np.array(y_test)

In [24]:
x_train,y_train,x_valid,y_valid,x_test,y_test=load_dataset(task,is_load=False)


print(x_train.shape)
print(y_train.shape)


print(x_valid.shape)
print(y_valid.shape)

print(x_test.shape)
print(y_test.shape)
x_train=torch.from_numpy(x_train).float()
y_train=torch.from_numpy(y_train).float()

x_valid=torch.from_numpy(x_valid).float()
y_valid=torch.from_numpy(y_valid).float()

x_test=torch.from_numpy(x_test).float()
y_test=torch.from_numpy(y_test).float()

y_train = y_train.unsqueeze(1)
y_test = y_test.unsqueeze(1)
y_valid = y_valid.unsqueeze(1)

(2170, 335)
(2170,)
(240, 335)
(240,)
(266, 335)
(266,)


In [48]:
learning_rate=1e-4
feature_num=335
class nnNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, out_dim):
        super(nnNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1))
        self.layer2=nn.Sequential(nn.Linear(n_hidden_1, n_hidden_1))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_1, out_dim))
 
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer3(x)
        return torch.sigmoid(x)
model=nnNet(feature_num,64,1)
criterion =nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=10 ** (-5.0))

num_epoch=5000


In [49]:
best=0
wh=0
for epoch in range(num_epoch):
    out=model(x_train)
    loss=criterion(out,y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print_loss=loss.data.item()
    train_auc=roc_auc_score(y_train,out.detach().numpy())
    out_valid=model(x_valid)
    valid_auc=roc_auc_score(y_valid,out_valid.detach().numpy())
    
    if(valid_auc>best):
        best=valid_auc
        wh=epoch
        torch.save({'model_state_dict': model.state_dict()},task+'local_best_model.pth')
    if(epoch%10==0):
        print(print_loss,train_auc,valid_auc)
model.load_state_dict(torch.load(task+'local_best_model.pth')['model_state_dict'])
out_test=model(x_test)
test_auc=roc_auc_score(y_test,out_test.detach().numpy())
print(best,test_auc)    

0.7044891119003296 0.5571435367353735 0.49149246475449676
0.6878293752670288 0.5657127736209369 0.4990624348913119
0.6812427639961243 0.5704091826540807 0.5039933328703382
0.6788443326950073 0.5743627121178141 0.5067018542954372
0.6772264838218689 0.5776723276723277 0.5080213903743315
0.6760063171386719 0.5803499221866569 0.5117022015417737
0.6749605536460876 0.5829960855471059 0.5128828390860476
0.6739919781684875 0.585484243647509 0.5156608097784569
0.6730471849441528 0.5881805269560372 0.5178831863323842
0.6720945835113525 0.5908963485494099 0.5206611570247934
0.6711150407791138 0.5939035454341577 0.5237169247864435
0.6701003313064575 0.5971757834002731 0.5271199388846448
0.6690506935119629 0.6006366082896695 0.5306618515174665
0.6679648756980896 0.6047260562566685 0.5357316480311133
0.6668334603309631 0.609072050398581 0.5412181401486215
0.6656408905982971 0.6141129278884381 0.5466351829988194
0.6643772125244141 0.6199412832065894 0.5521216751163276
0.6630275845527649 0.62672191753

0.6031242609024048 0.7927055937260019 0.7379331898048476
0.603003203868866 0.7926677913922812 0.7379679144385027
0.602870762348175 0.7926818079879305 0.738349885408709
0.6027323007583618 0.7926660924109903 0.7381068129731232
0.6025967001914978 0.7926359354930784 0.7381068129731232
0.6024681925773621 0.7926274405866244 0.7382109868740885
0.6023394465446472 0.7925943104514533 0.7378984651711925
0.602212131023407 0.7926121497550068 0.7384193346760192
0.6020850539207458 0.792606203320489 0.7385929578442948
0.6019587516784668 0.7925437657580514 0.7388360302798807
0.601832926273346 0.7925063881696535 0.7392180012500869
0.6017078161239624 0.7924860003941637 0.7392180012500869
0.6015832424163818 0.7924821776862594 0.7394263490520175
0.601471483707428 0.792471134307869 0.7396694214876034
0.6014233827590942 0.7924779302330323 0.7397388707549135
0.601273238658905 0.7924588166935106 0.7397041461212586
0.601105272769928 0.7924086967454314 0.739843044655879
0.6009917259216309 0.7923640984865475 0.73

0.5912009477615356 0.7925539596457963 0.7405375373289812
0.5911443829536438 0.7925297491624022 0.7406764358636017
0.5910918116569519 0.792482602431582 0.7403639141607056
0.591414213180542 0.7925437657580514 0.7393568997847072
0.5910386443138123 0.7924855756488409 0.7404333634280158
0.5909735560417175 0.7925284749264342 0.7401555663587749
0.5908786654472351 0.7925059634243308 0.7391485519827766
0.5908300280570984 0.792538668814179 0.7390791027154664
0.5907774567604065 0.7925569328630553 0.7392180012500869
0.590726912021637 0.7926074775564571 0.7393568997847072
0.5906770825386047 0.7926142734816204 0.7393568997847072
0.5906282663345337 0.7925484379566012 0.7393568997847073
0.590579628944397 0.7925785948745132 0.7392874505173971
0.590531051158905 0.7926129992456523 0.7394263490520175
0.5904831290245056 0.7926342365117875 0.7392874505173971
0.5904352068901062 0.792616397208234 0.7390791027154664
0.5903875231742859 0.7926274405866243 0.7388707549135356
0.5903400182723999 0.792672463590831 0

0.5860607624053955 0.793290468035366 0.7408153343982221
0.5860379338264465 0.7933592767776441 0.7406069865962915
0.586015522480011 0.7933690459200664 0.7406069865962915
0.5859932899475098 0.7933673469387755 0.7405722619626363
0.5859710574150085 0.7933970791113649 0.7406069865962915
0.5859490633010864 0.7934004770739465 0.7406069865962914
0.5859271287918091 0.7934229885760498 0.740745885130912
0.585905134677887 0.7934357309357309 0.740745885130912
0.5858833193778992 0.7934459248234758 0.7407111604972568
0.5858615636825562 0.7934590919284796 0.740780609764567
0.5858398079872131 0.7934586671831569 0.7405722619626363
0.5858181118965149 0.79347013530687 0.7405375373289812
0.585796594619751 0.7934773559773559 0.7392527258837419
0.5857749581336975 0.7934735332694516 0.7393221751510521
0.5857535600662231 0.7933677716840982 0.7392527258837419
0.5857320427894592 0.7933724438826479 0.7391138273491216
0.5857111215591431 0.7933788150624885 0.7392180012500869
0.5857435464859009 0.7935134593297859 0.