In [2]:
from typing import List, Dict, Tuple
import keras
import keras.layers
import keras.utils.all_utils
import keras.callbacks
import seaborn as sns
import numpy as np
from sklearn.model_selection import train_test_split

In [3]:
import tensorflow as tf

tf.config.list_physical_devices()

2021-12-19 12:55:01.572199: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-12-19 12:55:01.577433: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-12-19 12:55:01.577554: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [4]:
def parse_fasta_dataset(file_path: str = "../data/LTP_09_2021_compressed.fasta", max_seqs:int = False) -> List[Dict[str, str]]:
    dataset = []
    current_idx = 0
    current_meta = {}
    with open(file_path, "r") as f:
        for line in f.readlines():
            if line[0] == ">":
                if current_meta != {}:
                    current_meta["sequence"] = current_meta["sequence"].strip()
                    dataset.append(current_meta)
                line_list = line.replace("\n", "").replace(">", "").split("\t")
                if len(line_list) < 2:
                    current_meta  = {key: line_list[idx] for idx, key in enumerate(["name"])}
                elif len(line_list) < 3:
                    current_meta  = {key: line_list[idx] for idx, key in enumerate(["id", "name"])}
                else:
                    current_meta  = {key: line_list[idx] for idx, key in enumerate(["id", "name", "tags"])}
                    current_meta["tags"] = current_meta["tags"].split(";")
                current_meta["sequence"] = ""
                current_idx += 1
                if current_idx > max_seqs and max_seqs>=1:
                    break
            else:
                current_meta["sequence"] += line.replace("\n", " ")
    return dataset

dataset = parse_fasta_dataset(max_seqs=-1)
print(len(dataset))

17959


In [5]:
dataset[0]

{'id': 'AB681979',
 'name': 'Trabulsiella guamensis',
 'tags': ['Bacteria',
  'Proteobacteria',
  'Gammaproteobacteria',
  'Enterobacterales',
  'Enterobacteriaceae',
  'Trabulsiella'],
 'sequence': 'AUUGAACGCU GGCGGCAGGC CUAACACAUG CAAGUCGAGC GGCAGCGGGG GAAAGCUUGC UUUCCCGCCG GCGAGCGGCG GACGGGUGAG UAAUGUCUGG GAAACUGCCU GAUGGAGGGG GAUAACUACU GGAAACGGUA GCUAAUACCG CAUAACGUCU UCGGACCAAA GUGGGGGACC UUCGGGCCUC AUGCCAUCAG AUGUGCCCAG AUGGGAUUAG CUAGUAGGUG GGGUAACGGC UCACCUAGGC GACGAUCCCU AGCUGGUCUG AGAGGAUGAC CAGCCACACU GGAACUGAGA CACGGUCCAG ACUCCUACGG GAGGCAGCAG UGGGGAAUAU UGCACAAUGG GCGCAAGCCU GAUGCAGCCA UGCCGCGUGU AUGAAGAAGG CCUUCGGGUU GUAAAGUACU UUCAGCGGGG AGGAAGGUGU UGUGGUUAAU AACCAGAGCA AUUGACGUUA CCCGCAGAAG AAGCACCGGC UAACUCCGUG CCAGCAGCCG CGGUAAUACG GAGGGUGCAA GCGUUAAUCG GAAUUACUGG GCGUAAAGCG CACGCAGGCG GUCUGUCAAG UCGGAUGUGA AAUCCCCGGG CUCAACCUGG GAACUGCAUC CGAAACUGGC AGGCUUGAGU CUUGUAGAGG GGGGUAGAAU UCCAGGUGUA GCGGUGAAAU GCGUAGAGAU CUGGAGGAAU ACCGGUGGCG AAGGCGGCCC CCUGGACAAA GACUGACG

In [6]:
def chunk_seq(seq: str, chunk_len: int = 10):
    ret_list = []
    chunk = ""
    for char in seq:
        chunk += char
        if len(chunk) % chunk_len ==0:
            ret_list.append(chunk)
            chunk=""
    if chunk != '':
        ret_list.append("{}{}".format(chunk, "_"*(4-len(chunk))))
    return ret_list

In [7]:
temp = []
for kmer_seq in [chunk_seq(x["sequence"].replace(" ", "")) for x in dataset]:
    temp.extend(kmer_seq)
temp = set(temp)
encode_dict = {value: idx+1 for idx, value in enumerate(temp)}
# print(encode_dict)
X = np.array([np.pad(np.array([encode_dict[z] for z in chunk_seq(x["sequence"].replace(" ", ""))]), (0, 3000))[:1800] for x in dataset])

In [8]:
X.shape

(17959, 1800)

In [9]:
temp = {value for x in dataset for value in x["tags"][:3]}
label_encode_dict = {value: idx+1 for idx, value in enumerate(temp)}
Y_category = keras.utils.all_utils.to_categorical(np.array([[label_encode_dict[z] for z in x["tags"][:3]] for x in dataset]))
Y_category = np.array([np.sum(x, 0) for x in Y_category])

In [30]:
Y_next = np.array([np.pad(np.array([0.0, 0.0, *[encode_dict[z] for z in chunk_seq(x["sequence"].replace(" ", ""))]]), (0, 3000))[:1500] for x in dataset])

In [31]:
def chunk_kmer(seq: List, chunk_len: int = 5):
    ret_list = []
    chunk = []
    for char in seq:
        chunk.append(char)
        if len(chunk) % chunk_len ==0:
            ret_list.append(chunk)
            chunk=[]
    return ret_list

In [34]:
Y_sg = np.array([[z for z in chunk_kmer(x)] for x in Y_next])
Y_sg.shape
Y_sg = np.array([[[*z[:2], *z[3:]] for z in x] for x in Y_sg])

In [35]:
print(X[0][0])
print(Y_sg[0][0])

61830
[     0.      0. 243212. 137227.]


In [52]:
def conv_bn_block(
    x, 
    filters=16, 
    kernel=3, 
    stride=1, 
    ratio=2, 
    act="elu", 
    padding="same"
):
    # dimensionality reduction
    x = keras.layers.Conv1D(
        filters // ratio,
        kernel_size=1,
        strides=stride,
        padding=padding,
    )(x)

    # convolution
    x = keras.layers.Conv1D(
        filters,
        kernel_size=kernel,
        strides=stride,
        padding=padding,
    )(x)

    # batch norm
    x = keras.layers.BatchNormalization()(x)
    # activation
    x = keras.layers.Activation(act)(x)

    return x

In [53]:
def attention_layer(
    input_layer, 
    width: int = 128, 
    dropout: float = 0.3, 
    attention_filter: int = 8, 
    attention_dilation:int = 1, 
    attention_activation: str = "softmax"
):
    # RNN layer
    rnn_layer = keras.layers.CuDNNGRU(
        width, 
    )(input_layer)
    
    attention_layer = conv_bn_block(input_layer, filters=width, kernel=64)
    
    multiply_layer = keras.layers.Multiply()([rnn_layer, attention_layer])
    
    output_layer = keras.layers.Dense(width, activation="linear")(multiply_layer)
    
    # dropout
    if dropout:
        output_layer = keras.layers.Dropout(dropout)(output_layer)
    
    return keras.layers.Activation(activation=attention_activation)(output_layer)

In [54]:
def build_model(
    input_shape: Tuple[int], 
    output_shape: Tuple[int], 
    embed_size: int, 
    vocab_size: int, 
    rnn_size: int = 32
):
    # model input
    model_input = keras.layers.Input(shape=input_shape)

    # embedding layer
    embedding = keras.layers.Embedding(vocab_size, embed_size)(model_input)
    
    rnn_layer = attention_layer(embedding)
    
    # pooling layer
    pooling_layer = keras.layers.GlobalAveragePooling1D()(rnn_layer)
    
    # model output
    model_output = keras.layers.Dense(
        output_shape[0] * output_shape[1], 
        activation="relu", 
        kernel_regularizer="l2"
    )(pooling_layer)
    
    model_output = keras.layers.Reshape([*output_shape])(model_output)

    return keras.Model(inputs=[model_input], outputs=[model_output]), keras.Model(inputs=[model_input], outputs=[embedding])

In [None]:
model, embedder = build_model(
    input_shape=X[0].shape, 
    output_shape=Y_sg[0].shape, 
    embed_size=8, 
    vocab_size=len(encode_dict)+1, 
    rnn_size=64
)

model.compile(optimizer="rmsprop", loss="binary_crossentropy")
print(model.summary())

train_hx = model.fit(
    X, 
    Y_sg, 
    validation_split=0.2, 
    epochs=1000, 
    callbacks=[keras.callbacks.TensorBoard()], 
    batch_size=128
)

Model: "model_18"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_12 (InputLayer)          [(None, 1800)]       0           []                               
                                                                                                  
 embedding_11 (Embedding)       (None, 1800, 8)      1990072     ['input_12[0][0]']               
                                                                                                  
 conv1d_12 (Conv1D)             (None, 1800, 64)     576         ['embedding_11[0][0]']           
                                                                                                  
 conv1d_13 (Conv1D)             (None, 1800, 128)    524416      ['conv1d_12[0][0]']              
                                                                                           