In [8]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf

import sys
sys.path.append('..')

from bert.dataset import encode
from bert.model import create_model
from bert.go import TreeNorm
from bert.go import Ontology

# GO ontology; its node count sizes the GO prediction head built further down.
# NOTE(review): restoring the fine-tuned checkpoint fails with
# "Shapes (44232,) and (32372,) are incompatible", presumably because this
# Ontology() has a different number of nodes than the one used at training
# time — confirm it is built from the same GO release as the checkpoint.
ont = Ontology()
swissprot_dir = '/gpfs/alpine/bie108/proj-shared/swissprot/'

swissprot = pd.read_parquet(os.path.join(swissprot_dir, 'parsed_swissprot_uniref_clusters.parquet'))
go_terms = pd.read_parquet(os.path.join(swissprot_dir, 'swissprot_quickgo.parquet'))

# Proteins that carry at least one QuickGO annotation and are shorter than
# 10,000 residues, selected with a single combined row mask.
swissprot_annotated = swissprot.loc[
    swissprot['accession'].isin(go_terms['GENE PRODUCT ID'].unique())
    & (swissprot['length'] < 10000)
]

In [9]:
# UniRef50 cluster IDs held out for testing. Using np.load as a context manager
# closes the .npz file handle once the 'test' array has been read into memory.
# NOTE(review): allow_pickle=True unpickles object arrays from the file — only
# use it with split files you trust.
with np.load('uniref50_split.npz', allow_pickle=True) as split_file:
    test = split_file['test']

# Annotated SwissProt entries whose UniRef50 cluster is in the test split.
swissprot_test = swissprot_annotated[swissprot_annotated['UniRef50 ID'].isin(test)]

In [16]:
# QuickGO annotation rows restricted to proteins in the held-out test split.
go_terms_test = go_terms.loc[go_terms['GENE PRODUCT ID'].isin(swissprot_test['accession'])]

In [23]:
from functools import partial

In [30]:
# Sequence length used both when tokenizing (`encode`) and as the fixed padded
# batch width, so the two values cannot drift apart.
MAX_SEQUENCE_LENGTH = 1024
BATCH_SIZE = 16

# Tokenize each test-set sequence, pad every batch to MAX_SEQUENCE_LENGTH
# tokens, and prefetch batches in the background.
# NOTE(review): assumes `encode` emits at most `max_sequence_length` tokens;
# otherwise padded_batch fails on longer proteins — confirm in bert.dataset.
test_data = (
    tf.data.Dataset.from_tensor_slices(swissprot_test.sequence.values)
    .map(partial(encode, max_sequence_length=MAX_SEQUENCE_LENGTH),
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .padded_batch(batch_size=BATCH_SIZE,
                  padded_shapes=[MAX_SEQUENCE_LENGTH])
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [31]:
# Rebuild the transformer backbone with the hyper-parameters it was fine-tuned
# with; the architecture must match the checkpoint restored in the next cell.
dimension = 768
model = create_model(model_dimension=dimension,
                     transformer_dimension=dimension * 4,   # 4 * 768 = 3072
                     num_attention_heads=dimension // 64,   # 768 // 64 = 12 heads
                     num_transformer_layers=12,
                     vocab_size=24,
                     dropout_rate=0.0,
                     max_relative_position=64,
                     attention_type='relative')

# Append a GO-term prediction head to the backbone.
# NOTE(review): layers below are created in a fixed order; Keras auto-names them
# (dense, global_max_pooling1d, ...) and checkpoint restoration may depend on
# that order, so do not reorder these statements.
# NOTE(review): `model.layers[-2].input` assumes the tensor feeding the
# second-to-last backbone layer is the per-residue embedding — confirm against
# bert.model.create_model.
final_embedding = model.layers[-2].input
# One logit per GO node at each position; the width is ont.total_nodes, so it
# must equal the ontology size used when the checkpoint was trained.
raw_residue_predictions = tf.keras.layers.Dense(ont.total_nodes)(final_embedding)

# Max-pool over the sequence axis: one score per GO node for the whole protein.
protein_predictions = tf.keras.layers.GlobalMaxPooling1D()(raw_residue_predictions)

# Unpack the ontology's (segment, id) ancestor pairs into two parallel tuples
# and apply TreeNorm over them.
# NOTE(review): presumably TreeNorm makes scores consistent with the GO
# hierarchy (ancestors vs. descendants) — confirm in bert.go.
segments, ids = zip(*ont.iter_ancestor_array())
treenorm = TreeNorm(segments, ids)
normed = treenorm(protein_predictions)

# Final model: token inputs -> per-protein GO-term probabilities via sigmoid.
go_model_sigmoid = tf.keras.Model(model.inputs, tf.nn.sigmoid(normed))

In [32]:
# Directory holding the GO fine-tuning checkpoints to restore.
GO_CHECKPOINT_DIR = ('/ccs/home/pstjohn/member_work/uniparc_checkpoints/'
                     'go_finetuning_new_split_1024_ont1.258061')

checkpoint = tf.train.latest_checkpoint(GO_CHECKPOINT_DIR)
if checkpoint is None:
    # latest_checkpoint returns None instead of raising when nothing is found;
    # fail here with a clear message rather than inside load_weights.
    raise FileNotFoundError(f'No checkpoint found in {GO_CHECKPOINT_DIR}')

try:
    load_status = go_model_sigmoid.load_weights(checkpoint)
except ValueError as err:
    # A 1-D shape mismatch here is likely the GO head's bias, whose length is
    # ont.total_nodes; surface that so the cause is visible in the traceback.
    raise ValueError(
        f'Could not restore {checkpoint}: {err}. The GO output layer has '
        f'ont.total_nodes={ont.total_nodes} units; check that Ontology() was '
        'built from the same GO release as the checkpoint.') from err

# Silence warnings about checkpoint values this model has no counterpart for.
load_status.expect_partial()

ValueError: Shapes (44232,) and (32372,) are incompatible