In [1]:
import sys
# Put the parent directory on the import path (presumably where the local
# `bert` package imported further down lives -- confirm against the repo
# layout). The membership check keeps the cell idempotent: re-running it no
# longer appends a duplicate '..' entry each time.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
# Global plotting theme: 'talk' context, tick-style axes, seaborn colour
# codes enabled, and legends drawn without a frame.
sns.set(
    context='talk',
    style='ticks',
    color_codes=True,
    rc={'legend.frameon': False},
)

%matplotlib inline

In [3]:
import tensorflow as tf

# Restrict TensorFlow to a single physical GPU and enable memory growth on it.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
gpu_index = 3  # preferred device; the recorded run saw 4 GPUs

print("Num GPUs Available: ", len(gpu_devices))
if gpu_devices:
    # Fall back to the last available GPU when the preferred index does not
    # exist, instead of failing with an IndexError on machines that have
    # fewer than four devices.
    if gpu_index >= len(gpu_devices):
        print("GPU index", gpu_index, "not available; using index",
              len(gpu_devices) - 1)
        gpu_index = len(gpu_devices) - 1
    tf.config.set_visible_devices(gpu_devices[gpu_index], 'GPU')
    tf.config.experimental.set_memory_growth(gpu_devices[gpu_index], True)
else:
    # Nothing to pin: leave the device configuration untouched.
    print("No GPU found; TensorFlow will run on the CPU")

Num GPUs Available:  4


## Load the data and sequences

In [4]:
import os
# Directory holding the parsed SwissProt table and the subcellular splits.
swissprot_loc = '/ccs/home/pstjohn/project_work/swissprot/'

# Parsed SwissProt entries, indexed by accession.
data = pd.read_parquet(
    os.path.join(swissprot_loc, 'parsed_swissprot.parquet')
).set_index('accession')

# `sample(frac=1.)` keeps every row but returns them in random order.
train = pd.read_csv(
    os.path.join(swissprot_loc, 'subcellular/train.csv.gz')
).sample(frac=1.)
valid = pd.read_csv(
    os.path.join(swissprot_loc, 'subcellular/valid.csv.gz')
).sample(frac=1.)

# All columns but one are counted as targets (the remaining column is the
# accession, judging by how the frames are indexed further down).
num_targets = len(train.columns) - 1

## Convert to a TensorFlow dataset

In [5]:
from bert.dataset import encode
from functools import partial

# Sequence-length and batching settings used by the dataset code in this cell.
max_seq_len = 512   # forwarded to `encode` as max_sequence_length
fix_seq_len = True  # when True, batches are padded to max_seq_len
batch_size = 24

def create_dataset(sequences,
                   targets,
                   buffer_size=1000):
    """Build a shuffled dataset of (encoded sequence, target vector) pairs.

    Args:
        sequences: object exposing ``.values`` (used here with a pandas
            Series of sequences); each element is mapped through ``encode``
            with ``max_sequence_length=max_seq_len``.
        targets: object exposing ``.values`` (used here with a pandas
            DataFrame); the values are cast to ``int32``.
        buffer_size: size of the shuffle buffer applied to the zipped pairs.

    Returns:
        An unbatched ``tf.data.Dataset`` yielding ``(encoded, target)``
        tuples in shuffled order.
    """
    encode_fn = partial(encode, max_sequence_length=max_seq_len)
    sequence_ds = tf.data.Dataset.from_tensor_slices(sequences.values)
    encoded_ds = sequence_ds.map(encode_fn)

    label_array = targets.values.astype(np.int32)
    label_ds = tf.data.Dataset.from_tensor_slices(label_array)

    # Pair each encoded sequence with its target row, then shuffle the pairs
    # together so the two stay aligned.
    paired_ds = tf.data.Dataset.zip((encoded_ds, label_ds))
    return paired_ds.shuffle(buffer_size)


def get_train_subsampled():
    """Yield one endlessly repeating training dataset per target column.

    For every column of ``train`` after the first, the rows flagged ``1`` in
    that column are selected, their sequences are looked up in ``data`` by
    accession, and both are handed to ``create_dataset``.

    Each per-column dataset is repeated indefinitely. The datasets are meant
    to be mixed with ``sample_from_datasets``, which keeps drawing from
    whichever inputs still have elements; without ``repeat()`` the columns
    with few positive rows run dry first and later batches are drawn from
    the larger ones only, losing the class balancing. Because the yielded
    datasets never end, consumers must bound iteration themselves (the
    ``fit`` call in this notebook does so through ``steps_per_epoch``).

    Yields:
        ``tf.data.Dataset`` of ``(encoded sequence, target vector)`` pairs
        for one target column, shuffled and repeated.
    """
    for location in train.columns[1:]:
        train_subset = train[train[location] == 1]
        train_seq_subset = data.reindex(train_subset.accession).sequence
        yield create_dataset(
            train_seq_subset, train_subset.set_index('accession')).repeat()

# Padding layout shared by both pipelines: the sequence axis is padded to
# max_seq_len when fix_seq_len is set (otherwise left as -1), and the
# target axis is always num_targets wide.
batch_padded_shapes = ([max_seq_len if fix_seq_len else -1], [num_targets])

# Training: draw examples from the per-location datasets, then batch.
train_ds = (
    tf.data.experimental.sample_from_datasets(list(get_train_subsampled()))
    .padded_batch(batch_size=batch_size, padded_shapes=batch_padded_shapes)
)

# Validation: the full validation split, batched the same way.
valid_sequences = data.reindex(valid.accession).sequence
valid_ds = (
    create_dataset(valid_sequences, valid.set_index('accession'))
    .padded_batch(batch_size=batch_size, padded_shapes=batch_padded_shapes)
)

## Load the best-performing model

In [6]:
checkpoint_dir = '/ccs/home/pstjohn/member_work/uniparc_checkpoints/12_layer_bs24_adamw_20200527'

# Resolve the newest checkpoint once and fail early with a clear message:
# `latest_checkpoint` returns None when the directory holds no checkpoint,
# and handing None to `load_weights` surfaces as an unrelated-looking error.
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
if checkpoint_path is None:
    raise FileNotFoundError(
        'No TensorFlow checkpoint found in {}'.format(checkpoint_dir))

from bert.model import create_albert_model

dimension = 768

# Rebuild the architecture so the checkpoint weights can be loaded into it:
# 12 transformer layers, feed-forward width 4x the model dimension, and
# dimension // 64 = 12 attention heads. Dropout is set to zero.
model = create_albert_model(model_dimension=dimension,
                            transformer_dimension=dimension * 4,
                            num_attention_heads=dimension // 64,
                            num_transformer_layers=12,
                            dropout_rate=0.,
                            max_relative_position=64,
                            final_layernorm=False)

# expect_partial() silences warnings about checkpoint values that this
# model does not consume.
model.load_weights(checkpoint_path).expect_partial()
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None)]            0         
_________________________________________________________________
embedding (Embedding)        (None, None, 768)         18432     
_________________________________________________________________
transformer (Transformer)    (None, None, 768)         7096128   
_________________________________________________________________
transformer_1 (Transformer)  (None, None, 768)         7096128   
_________________________________________________________________
transformer_2 (Transformer)  (None, None, 768)         7096128   
_________________________________________________________________
transformer_3 (Transformer)  (None, None, 768)         7096128   
_________________________________________________________________
transformer_4 (Transformer)  (None, None, 768)         709612

In [7]:
# Freeze the pretrained model; only the new Dense head below is trainable.
model.trainable = False

# Tap the tensor feeding the pretrained model's second-to-last layer and use
# it as the per-residue embedding.
penultimate_layer = model.layers[-2]
final_embedding = penultimate_layer.input

# Per-residue sigmoid scores for every target, reduced to one score per
# target by taking the maximum along the sequence axis.
# NOTE(review): padded positions also pass through the Dense layer and the
# max -- confirm whether they should be masked out.
residue_head = tf.keras.layers.Dense(num_targets, activation='sigmoid')
residue_predictions = residue_head(final_embedding)
sequence_max_pool = tf.keras.layers.GlobalMaxPooling1D()
protein_predictions = sequence_max_pool(residue_predictions)

localization_model = tf.keras.Model(model.inputs, protein_predictions)

In [8]:
# Adam with a learning rate of 1e-4; all other settings are left at their defaults.
optimizer = tf.keras.optimizers.Adam(learning_rate=1E-4)

In [9]:
import tensorflow_addons as tfa

In [10]:
# Metrics handed to `compile` and reported during `fit`.
metrics = [
    tf.keras.metrics.BinaryAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    # Area under the precision-recall curve (curve='PR').
    tf.keras.metrics.AUC(name='pr_auc', curve='PR'),
    # F1 from tensorflow_addons is currently disabled:
#    tfa.metrics.F1Score(num_classes=num_targets, name='f1')
]

In [11]:
# The model's outputs come from a sigmoid layer, so the loss is told it is
# receiving probabilities rather than logits (from_logits=False).
localization_model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=metrics,
)

In [12]:
# Train for 10 epochs; each epoch is capped at 250 training batches and is
# followed by 10 validation batches.
localization_model.fit(
    train_ds,
    epochs=10,
    steps_per_epoch=250,
    validation_data=valid_ds,
    validation_steps=10,
    verbose=1,
)

Train for 250 steps, validate for 10 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10

KeyboardInterrupt: 