In [21]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [22]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm
import os

import sys
sys.path.append('..')

from bert.dataset import encode
from bert.model import create_model
from bert.go import TreeNorm
from bert.go import Ontology
from bert.go.layers import LogitSplitFmax

In [23]:
ont = Ontology(threshold=1)
swissprot_dir = '/gpfs/alpine/bie108/proj-shared/swissprot/'

def parse_example(example):
    parsed = tf.io.parse_single_example(example, features={
        'sequence': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'annotation': tf.io.FixedLenFeature([], tf.string, default_value=''),
    })
   
    sequence = encode(parsed['sequence'], 1024)
    annotation = tf.io.parse_tensor(parsed['annotation'], out_type=tf.int64)
    
    return sequence, annotation

train_dataset = tf.data.TFRecordDataset(
    os.path.join(swissprot_dir, 'tfrecords_1', 'go_train.tfrecord.gz'),
    compression_type='GZIP', num_parallel_reads=tf.data.experimental.AUTOTUNE)\
    .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
    .padded_batch(batch_size=16,
                  padded_shapes=(([1024], [ont.total_nodes])))\
    .prefetch(tf.data.experimental.AUTOTUNE)

valid_dataset = tf.data.TFRecordDataset(
    os.path.join(swissprot_dir, 'tfrecords_1', 'go_valid.tfrecord.gz'),
    compression_type='GZIP', num_parallel_reads=tf.data.experimental.AUTOTUNE)\
    .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
    .padded_batch(batch_size=16,
                  padded_shapes=(([1024], [ont.total_nodes])))\
    .prefetch(tf.data.experimental.AUTOTUNE)

In [24]:
from tensorflow.keras import layers

In [25]:
inputs = layers.Input(shape=(None,), dtype=tf.int32, batch_size=None)

initial_bias = np.load(os.path.join(swissprot_dir, 'tfrecords_1', 'bias.npy'))

# Amino-acid level embeddings
embeddings = layers.Embedding(
    24, ont.total_nodes, embeddings_initializer=tf.keras.initializers.Constant(np.tile(initial_bias, (24, 1))),
    mask_zero=True)(inputs)

protein_predictions = tf.keras.layers.GlobalMaxPooling1D()(embeddings)

segments, ids = zip(*ont.iter_ancestor_array())
treenorm = TreeNorm(segments, ids)
normed = treenorm(protein_predictions)

go_model = tf.keras.Model(inputs, normed)

go_model.summary()

optimizer = tf.keras.optimizers.Adam(3E-5)

metrics = [
    LogitSplitFmax(ont, 0),
    LogitSplitFmax(ont, 1),
    LogitSplitFmax(ont, 2),
]

go_model.compile(
   loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
   metrics=metrics,
   optimizer=optimizer)

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, None)]            0         
_________________________________________________________________
embedding_3 (Embedding)      (None, None, 32372)       776928    
_________________________________________________________________
global_max_pooling1d_2 (Glob (None, 32372)             0         
_________________________________________________________________
tree_norm_1 (TreeNorm)       (None, None)              0         
Total params: 776,928
Trainable params: 776,928
Non-trainable params: 0
_________________________________________________________________


In [26]:
go_model.fit(train_dataset, steps_per_epoch=1000, epochs=1)

Train for 1000 steps


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


  47/1000 [>.............................] - ETA: 44:33 - loss: 0.0063 - logit_split_fmax_biological_process: 0.3552 - logit_split_fmax_molecular_function: 0.3800 - logit_split_fmax_cellular_component: 0.4423

KeyboardInterrupt: 

In [27]:
go_model.evaluate(valid_dataset)

      8/Unknown - 12s 1s/step - loss: 0.0065 - logit_split_fmax_biological_process: 0.3649 - logit_split_fmax_molecular_function: 0.3924 - logit_split_fmax_cellular_component: 0.4622

KeyboardInterrupt: 