In [1]:
import onmt
import numpy as np
import torch
import torch.nn as nn 
import onmt.inputters
import onmt.decoders
import onmt.utils
from Bio import SeqIO
from sklearn.model_selection import train_test_split


In [2]:
# Report whether PyTorch can use a CUDA device in this kernel.
torch.cuda.is_available()

True

In [116]:
def fasta_to_list(fa):
    """Read a FASTA file and return its pure-ACGT sequences as lists.

    Every record's sequence is expanded with ``list(record.seq)``. A record
    is kept only when the set of distinct symbols it contains is exactly
    {'A', 'C', 'G', 'T'}; records with any other symbol (including lowercase
    letters) or missing one of the four bases are dropped.

    Args:
        fa: FASTA source handed straight to ``SeqIO.parse``.

    Returns:
        list: the retained sequences, in the order they were parsed.
    """
    nucleotides = {'A', 'C', 'G', 'T'}
    all_records = [list(rec.seq) for rec in SeqIO.parse(fa, "fasta")]
    return [symbols for symbols in all_records if set(symbols) == nucleotides]

def split_data(gen_data):
    """Cut a sequence collection into two equally sized halves.

    The first ``len(gen_data) // 2`` items form the first half and the next
    ``len(gen_data) // 2`` items form the second half. When the input has an
    odd number of items the final item is left out so both halves match in
    length.

    Args:
        gen_data: sliceable collection (e.g. a list of sequences).

    Returns:
        tuple: ``(first_half, second_half)``, both of the same length.
    """
    half = len(gen_data) // 2
    first_half = gen_data[:half]
    second_half = gen_data[half:half * 2]
    return first_half, second_half

# Triplet codon hypothesis: group bases in threes to ensure a better reading by OpenNMT
def triplet_codon(A):
    """Group every sequence in ``A`` into consecutive three-element chunks.

    Each chunk is the concatenation (``+``) of three neighbouring elements.
    Trailing elements that do not fill a complete triplet are discarded.

    Args:
        A: iterable of indexable sequences (lists of single-character
            strings, or plain strings).

    Returns:
        list[list]: one list of triplets per input sequence.
    """
    grouped = []
    for seq in A:
        # Largest multiple of three that fits inside the sequence.
        usable = len(seq) - len(seq) % 3
        grouped.append(
            [seq[pos] + seq[pos + 1] + seq[pos + 2] for pos in range(0, usable, 3)]
        )
    return grouped

def list_to_txt(A, filename):
    """Write each inner sequence of ``A`` as one line of ``filename``.

    The items of every inner sequence are converted with ``str`` and written
    back to back with no separator, followed by a newline. The file is
    opened in write mode, so existing content is replaced.

    Args:
        A: iterable of iterables to serialise, one line each.
        filename: path of the text file to create or overwrite.
    """
    with open(filename, 'w') as out:
        for row in A:
            out.write(''.join(map(str, row)) + '\n')

def list_txt(A, filename):
    """Write every element of ``A`` on its own line of ``filename``.

    Each element is converted with ``str`` and followed by a newline. The
    file is opened in write mode, so existing content is replaced.

    Args:
        A: iterable of values to serialise.
        filename: path of the text file to create or overwrite.
    """
    with open(filename, 'w') as out:
        out.writelines(str(entry) + '\n' for entry in A)
            

def length_sort(unsorted_list):
    """Return a new list with the items ordered from shortest to longest.

    The sort is stable, so items of equal length keep their relative order.
    The input itself is not modified.

    Args:
        unsorted_list: iterable of sized items.

    Returns:
        list: the items sorted by ``len``.
    """
    ordered = list(unsorted_list)
    ordered.sort(key=len)
    return ordered

def txt_to_list(filename):
    """Read a text file and return its lines as a list.

    Line terminators are preserved, exactly as ``readlines`` yields them, so
    callers that need the bare text must strip the trailing newline.

    Args:
        filename: path of the text file to read.

    Returns:
        list[str]: one entry per line of the file.
    """
    # The context manager closes the handle; previously it was left open.
    with open(filename, 'r') as txt_file:
        return txt_file.readlines()
        

In [33]:
# Load the data: every FASTA file is filtered by fasta_to_list and then cut
# into two equally sized halves by split_data.
fasta_paths = [
    'Bundibugyo ebola Human.fa',
    'Tai forest ebola Human.fa',
    'Zaire ebola human.fa',
    'Ebola RNA-dependent RNA polymerase Human.fa',
    'Sudan ebola Human.fa',
    'West nile Human.fa',
    'West nile bat.fa',
    'dengue human.fa',
]

# The targets below are listed in the same order as fasta_paths.
(
    (bundi_ebola_x, bundi_ebola_y),
    (tai_ebola_x, tai_ebola_y),
    (zaire_ebola_x, zaire_ebola_y),
    (ebola_rna_x, ebola_rna_y),
    (sudan_ebola_x, sudan_ebola_y),
    (west_nile_1_x, west_nile_1_y),
    (west_nile_2_x, west_nile_2_y),
    (dengue_human_x, dengue_human_y),
) = [split_data(fasta_to_list(path)) for path in fasta_paths]

In [34]:
# Pool the first halves (x) and second halves (y) of every dataset, then order
# each pool from shortest to longest sequence.
# NOTE(review): x_ and y_ are sorted independently, so after this cell index i
# pairs the i-th shortest x with the i-th shortest y rather than the items that
# shared a position before sorting - confirm this pairing is intended.
x_parts = [bundi_ebola_x, tai_ebola_x, zaire_ebola_x, ebola_rna_x,
           sudan_ebola_x, west_nile_1_x, west_nile_2_x, dengue_human_x]
y_parts = [bundi_ebola_y, tai_ebola_y, zaire_ebola_y, ebola_rna_y,
           sudan_ebola_y, west_nile_1_y, west_nile_2_y, dengue_human_y]

x_ = length_sort([seq for part in x_parts for seq in part])
y_ = length_sort([seq for part in y_parts for seq in part])

In [35]:
# Test, train and validation split: 70% of the pairs go to training, and the
# held-out 30% is halved into test and validation sets.
# Both splits are seeded so that re-running the notebook reproduces the same
# partition (previously only the second split had a random_state).
split_seed = 2000
x_train, x_test_val, y_train, y_test_val = train_test_split(
    x_, y_, test_size=0.3, random_state=split_seed)
x_test, x_val, y_test, y_val = train_test_split(
    x_test_val, y_test_val, test_size=0.5, random_state=split_seed)

In [158]:
# Sanity check on the split sizes. Re-sorting each individual split with
# length_sort (x/y train, test and val) is currently disabled.
tuple(map(len, (x_train, y_val)))


(14018, 3005)

In [30]:
# Persist every split as a text file, one sequence per line.
for split_rows, split_path in (
    (x_train, 'x_train.txt'),
    (y_train, 'y_train.txt'),
    (x_test, 'x_test.txt'),
    (y_test, 'y_test.txt'),
    (x_val, 'x_val.txt'),
    (y_val, 'y_val.txt'),
):
    list_to_txt(split_rows, split_path)

In [91]:
# Predicted sequences as raw lines; txt_to_list keeps each line's terminator.
pred_data = txt_to_list('y_pred_5.txt')
# Collapse every reference sequence into a single string for comparison.
y_test = list(map(''.join, y_test))


In [92]:
import difflib

# Placeholder for the (currently disabled) length-difference metric.
size_diff = 0

# Score each prediction against its reference with difflib's similarity ratio
# (a float in [0, 1]; 1.0 means the two strings are identical).
accuracy = [
    # Strip only the line terminator: slicing with [:-1] would also chop a real
    # character off a final line that has no trailing newline.
    difflib.SequenceMatcher(None, reference, pred_data[idx].rstrip('\n')).ratio()
    for idx, reference in enumerate(y_test)
]


In [105]:
# Keep only similarity scores above 0.004, then record the mean of the
# surviving scores and how many there are.
accuracy_filtered = [score for score in accuracy if score > 0.004]
opennmt_results = [np.mean(accuracy_filtered), len(accuracy_filtered)]

In [115]:
# Save the summary metrics (mean similarity, number of scores), one per line.
list_txt(A=[*opennmt_results], filename='OpenNMT metrics.txt')

In [14]:
# The preprocessing step was already run from the command line; this cell only
# loads the vocabulary fields it produced.
# NOTE: torch.load unpickles the file, so only load artefacts you trust.
vocab_fields = torch.load("genes_processed.vocab.pt")

# Base text fields for the source and target sides.
src_text_field, tgt_text_field = (
    vocab_fields[side].base_field for side in ("src", "tgt")
)

# Vocabularies and the index of the padding token on each side.
src_vocab, tgt_vocab = src_text_field.vocab, tgt_text_field.vocab
src_padding = src_vocab.stoi[src_text_field.pad_token]
tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]

In [26]:
# Model hyper-parameters and target device.
emb_size = 500
rnn_size = 2000
device = "cuda" if torch.cuda.is_available() else "cpu"

# Encoder: source embeddings feeding a single-layer bidirectional LSTM.
encoder_embeddings = onmt.modules.Embeddings(
    emb_size, len(src_vocab), word_padding_idx=src_padding)
encoder = onmt.encoders.RNNEncoder(
    rnn_type="LSTM", bidirectional=True, num_layers=1,
    hidden_size=rnn_size, embeddings=encoder_embeddings)

# Decoder: target embeddings feeding a single-layer input-feed LSTM decoder.
decoder_embeddings = onmt.modules.Embeddings(
    emb_size, len(tgt_vocab), word_padding_idx=tgt_padding)
decoder = onmt.decoders.decoder.InputFeedRNNDecoder(
    rnn_type="LSTM", bidirectional_encoder=True, num_layers=1,
    hidden_size=rnn_size, embeddings=decoder_embeddings)

# Assemble the encoder-decoder model and move it to the selected device.
model = onmt.models.model.NMTModel(encoder, decoder)
model.to(device)

# Target-word generator: linear projection onto the target vocabulary followed
# by a log-softmax over the last dimension.
model.generator = nn.Sequential(
    nn.Linear(rnn_size, len(tgt_vocab)),
    nn.LogSoftmax(dim=-1),
).to(device)

# Loss computation: summed negative log-likelihood that ignores the target
# padding index.
loss = onmt.utils.loss.NMTLossCompute(
    generator=model.generator,
    criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction="sum"))

In [27]:
# Plain SGD over all model parameters, wrapped in OpenNMT's Optimizer with the
# same learning rate and max_grad_norm=2.
lr = 0.5
torch_optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)
optim = onmt.utils.optimizers.Optimizer(
    torch_optimizer,
    max_grad_norm=2,
    learning_rate=lr,
)

In [28]:
# Load the preprocessed training and validation shards.
from itertools import chain

train_data_file = "genes_processed.train.0.pt"
valid_data_file = "genes_processed.valid.0.pt"

# Options shared by the training and validation iterators.
shared_iter_options = dict(
    fields=vocab_fields,
    batch_size_multiple=1,
    batch_size_fn=None,
    pool_factor=1,
    device=device,
)

# Training iterator: batches of 50, repeating.
train_iter = onmt.inputters.inputter.DatasetLazyIter(
    dataset_paths=[train_data_file],
    batch_size=50,
    is_train=True,
    repeat=True,
    **shared_iter_options)

# Validation iterator: batches of 10, a single pass.
valid_iter = onmt.inputters.inputter.DatasetLazyIter(
    dataset_paths=[valid_data_file],
    batch_size=10,
    is_train=False,
    repeat=False,
    **shared_iter_options)

In [32]:
# Progress reporting every 2000 steps, with no TensorBoard writer attached.
report_manager = onmt.utils.ReportMgr(
    report_every=2000, start_time=None, tensorboard_writer=None)

# The same loss object is used for both training and validation.
trainer = onmt.Trainer(
    model=model,
    train_loss=loss,
    valid_loss=loss,
    optim=optim,
    n_gpu=1,
    gpu_rank=0,
    report_manager=report_manager,
)

# Train for 2000 steps with valid_steps=100.
trainer.train(
    train_iter=train_iter,
    valid_iter=valid_iter,
    train_steps=2000,
    valid_steps=100,
)

KeyboardInterrupt: 

In [None]:
import onmt.translate

# Source and target are both read with the "text" reader.
src_reader = tgt_reader = onmt.inputters.str2reader["text"]

# Global scorer with average length penalty and no coverage penalty.
scorer = onmt.translate.GNMTGlobalScorer(
    alpha=0.7,
    beta=0.,
    length_penalty="avg",
    coverage_penalty="none",
)

# Translator gpu argument: 0 when CUDA is available, otherwise -1.
gpu = 0 if torch.cuda.is_available() else -1
translator = onmt.translate.Translator(
    model=model,
    fields=vocab_fields,
    src_reader=src_reader,
    tgt_reader=tgt_reader,
    global_scorer=scorer,
    gpu=gpu,
)

# NOTE: torch.load unpickles the dataset file, so only load artefacts you trust.
builder = onmt.translate.TranslationBuilder(
    data=torch.load(valid_data_file),
    fields=vocab_fields,
)

# Translate the validation set batch by batch and print every translation log.
for batch in valid_iter:
    batch_translations = builder.from_batch(
        translator.translate_batch(
            batch=batch, src_vocabs=[src_vocab], attn_debug=False))
    for translation in batch_translations:
        print(translation.log(0))