In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

%matplotlib inline

In [2]:
!nvidia-smi

Sat Aug  1 21:39:09 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.116.00   Driver Version: 418.116.00   CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-SXM2...  On   | 00000004:04:00.0 Off |                    0 |
| N/A   36C    P0    37W / 300W |      0MiB / 16130MiB |      0%   E. Process |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000004:05:00.0 Off |                    0 |
| N/A   43C    P0    54W / 300W |      0MiB / 16130MiB |      0%   E. Process |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000035:03:00.0 Off |                    0 |
| N/A   

In [3]:
import sys
sys.path.append('..')

In [4]:
import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
gpu_index = 3

print("Num GPUs Available: ", len(gpu_devices))
tf.config.set_visible_devices(gpu_devices[gpu_index], 'GPU')
tf.config.experimental.set_memory_growth(gpu_devices[gpu_index], True)

from tensorflow.keras.mixed_precision import experimental as mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)

Num GPUs Available:  4


In [5]:
from bert.go import Ontology
ont = Ontology(threshold=1)
ont.total_nodes

32372

In [6]:
assert (pd.Series(ont.term_index) == pd.read_csv('term_index.csv', index_col=0, header=None)[1]).all()

In [7]:
import os
swissprot_dir = '/gpfs/alpine/bie108/proj-shared/swissprot/'

In [8]:
max_seq_len=512
batch_size=4

from bert.dataset import encode

def parse_example(example):
    parsed = tf.io.parse_single_example(example, features={
        'sequence': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'annotation': tf.io.FixedLenFeature([], tf.string, default_value=''),
    })
   
    sequence = encode(parsed['sequence'], max_seq_len)
    annotation = tf.io.parse_tensor(parsed['annotation'], out_type=tf.int64)
    
    return sequence, annotation


train_dataset = tf.data.TFRecordDataset(
    os.path.join(swissprot_dir, 'tfrecords_1', 'go_train.tfrecord.gz'),
    compression_type='GZIP', num_parallel_reads=tf.data.experimental.AUTOTUNE)\
    .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
    .cache().shuffle(buffer_size=5000)\
    .padded_batch(batch_size=batch_size,
                  padded_shapes=(([max_seq_len], [ont.total_nodes])))\
    .prefetch(tf.data.experimental.AUTOTUNE)


valid_dataset = tf.data.TFRecordDataset(
    os.path.join(swissprot_dir, 'tfrecords_1', 'go_valid.tfrecord.gz'),
    compression_type='GZIP', num_parallel_reads=tf.data.experimental.AUTOTUNE)\
    .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
    .cache().shuffle(buffer_size=5000)\
    .padded_batch(batch_size=batch_size,
                  padded_shapes=(([max_seq_len], [ont.total_nodes])))\
    .prefetch(tf.data.experimental.AUTOTUNE)

In [9]:
class_sample = np.concatenate([b.numpy() for a,b in train_dataset.take(1000)])
total = np.prod(class_sample.shape)
pos = class_sample.sum()
neg = total - pos
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

initial_bias = np.log(class_sample.sum(0) / class_sample.shape[0] + np.finfo(float).eps)

In [10]:
np.save(os.path.join(swissprot_dir, 'tfrecords_1', 'bias.npy'), initial_bias)

In [11]:
# train_sample = np.concatenate([arr.numpy() for _, arr in train_dataset.take(1000)])
# train_sample[:, ont.get_head_node_indices()].mean(0)

In [12]:
checkpoint_dir = '/ccs/home/pstjohn/member_work/uniparc_checkpoints/12_layer_relative_adam_20200625.186949'
tf.train.latest_checkpoint(checkpoint_dir)

from bert.model import create_model

dimension = 768

model = create_model(model_dimension=dimension,
                     transformer_dimension=dimension * 4,
                     num_attention_heads=dimension // 64,
                     num_transformer_layers=12,
                     vocab_size=24,
                     dropout_rate=0.0,
                     max_relative_position=64,
                     attention_type='relative')

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None)]            0         
_________________________________________________________________
embedding (Embedding)        (None, None, 768)         18432     
_________________________________________________________________
transformer (Transformer)    (None, None, 768)         7096128   
_________________________________________________________________
transformer_1 (Transformer)  (None, None, 768)         7096128   
_________________________________________________________________
transformer_2 (Transformer)  (None, None, 768)         7096128   
_________________________________________________________________
transformer_3 (Transformer)  (None, None, 768)         7096128   
_________________________________________________________________
transformer_4 (Transformer)  (None, None, 768)         709612

In [13]:
model.trainable = False
final_embedding = model.layers[-2].input
raw_residue_predictions = tf.keras.layers.Dense(
    ont.total_nodes, bias_initializer=tf.keras.initializers.Constant(initial_bias))(final_embedding)

In [14]:
class TreeNorm(tf.keras.layers.Layer):
    """ Multiply each GO score by the scores of its ancestor nodes to normalize
    tree-valued predictions to be monotonically decreasing with depth.
    
    For some reason, I have to use `unsorted_segment_prod` here:
    https://github.com/tensorflow/tensorflow/issues/41090
    
    """
    def __init__(self, segments, ids, **kwargs):
        super(TreeNorm, self).__init__(**kwargs)        
        self.segments = segments
        self.ids = ids
    
    def call(self, inputs):
        return tf.transpose(tf.math.segment_min(
            tf.gather(tf.transpose(inputs), self.ids), self.segments))
    
    def compute_output_shape(self, input_shape):
        return input_shape

In [15]:
segments, ids = zip(*ont.iter_ancestor_array())
protein_predictions = tf.keras.layers.GlobalMaxPooling1D()(raw_residue_predictions)

treenorm = TreeNorm(segments, ids)
protein_predictions = treenorm(protein_predictions)

go_model = tf.keras.Model(model.inputs, protein_predictions)

In [16]:
go_model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None)]            0         
_________________________________________________________________
embedding (Embedding)        (None, None, 768)         18432     
_________________________________________________________________
transformer (Transformer)    (None, None, 768)         7096128   
_________________________________________________________________
transformer_1 (Transformer)  (None, None, 768)         7096128   
_________________________________________________________________
transformer_2 (Transformer)  (None, None, 768)         7096128   
_________________________________________________________________
transformer_3 (Transformer)  (None, None, 768)         7096128   
_________________________________________________________________
transformer_4 (Transformer)  (None, None, 768)         709612

In [16]:
# def weighted_loss(y_true, y_pred):
#     weights = tf.gather(tf.constant([weight_for_0, weight_for_1], dtype=tf.float16), tf.cast(y_true, tf.int32))
#     losses = tf.nn.sigmoid_cross_entropy_with_logits(tf.cast(y_true, tf.float16), y_pred)
#     return losses * weights

In [18]:
class LogitPrecision(tf.keras.metrics.Precision):
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.nn.sigmoid(y_pred)
        super(LogitPrecision, self).update_state(y_true, y_pred, sample_weight=None)
        
class LogitRecall(tf.keras.metrics.Recall):
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.nn.sigmoid(y_pred)
        super(LogitRecall, self).update_state(y_true, y_pred, sample_weight=None)
        
class LogitAUC(tf.keras.metrics.AUC):
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.nn.sigmoid(y_pred)
        super(LogitAUC, self).update_state(y_true, y_pred, sample_weight=None)

In [20]:
optimizer = tf.keras.optimizers.Adam(1E-4)

metrics = [
    tf.keras.metrics.BinaryAccuracy(name='accuracy', threshold=0.0),
    LogitPrecision(),
    LogitRecall(),
    LogitAUC(name='pr_auc', curve='PR'),
]

go_model.compile(
   loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
   metrics=metrics,
   optimizer=optimizer)

go_model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=10,
    verbose=1)

Epoch 1/10


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


    335/Unknown - 45s 134ms/step - loss: 0.0554 - accuracy: 0.9856 - logit_precision: 0.5485 - logit_recall: 0.2495 - pr_auc: 0.3529

KeyboardInterrupt: 

In [None]:
inputs = list(train_dataset.take(1))

In [None]:
out = tf.math.sigmoid(go_model(inputs[0][0]))

In [None]:
out

In [None]:
out.numpy()[:, np.array(ont.get_head_node_indices())]

In [None]:
len(ids)

In [None]:
tf.transpose(tf.math.unsorted_segment_prod(tf.gather(tf.transpose(out), ids), segments, ont.total_nodes))

In [None]:
df = pd.DataFrame.from_dict(ont.term_index, orient='index', columns=['GO'])
df['scores'] = out[0].numpy()

In [None]:
test_preds = out[0].numpy()

for i in range(ont.total_nodes):
    assert (test_preds[i] <= test_preds[np.asarray(ont.terms_to_indices(ont.get_ancestors((ont.term_index[i],))))]).all()
    assert (test_preds[i] >= test_preds[np.asarray(ont.terms_to_indices(ont.get_descendants((ont.term_index[i],))))]).all()

In [None]:
[ont.term_index[i] for i in ont.get_head_node_indices()]

In [None]:
ont.G.node['GO:0003674']

In [None]:
subgraph = ont.G.subgraph(pd.Series(ont.term_index)[test_preds > 0.5])
leafs = (node for node, out_degree in subgraph.out_degree if out_degree == 0)
pd.DataFrame((ont.G.node[leaf] for leaf in leafs)).sort_values('namespace')

In [None]:
subgraph = ont.G.subgraph(pd.Series(ont.term_index)[true_values])
leafs = (node for node, out_degree in subgraph.out_degree if out_degree == 0)
pd.DataFrame((ont.G.node[leaf] for leaf in leafs)).sort_values('namespace')