In [1]:
import sys
sys.path.append('..')

In [60]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

%matplotlib inline

In [61]:
import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
gpu_index = 2

print("Num GPUs Available: ", len(gpu_devices))
tf.config.set_visible_devices(gpu_devices[gpu_index], 'GPU')
tf.config.experimental.set_memory_growth(gpu_devices[gpu_index], True)

Num GPUs Available:  4


In [62]:
import os
swissprot_loc = '/ccs/home/pstjohn/project_work/swissprot/'

data = pd.read_parquet(os.path.join(swissprot_loc, 'parsed_swissprot.parquet'))
train = pd.read_csv(os.path.join(swissprot_loc, 'subcellular/train.csv.gz')).sample(frac=1.)
valid = pd.read_csv(os.path.join(swissprot_loc, 'subcellular/valid.csv.gz')).sample(frac=1.)
test  = pd.read_csv(os.path.join(swissprot_loc, 'subcellular/test.csv.gz')).sample(frac=1.)

In [6]:
checkpoint_dir = '/ccs/home/pstjohn/member_work/uniparc_checkpoints/12_layer_localization_model.134306'
tf.train.latest_checkpoint(checkpoint_dir)

from bert.model import create_albert_model

dimension = 768

model = create_albert_model(model_dimension=dimension,
                            transformer_dimension=dimension * 4,
                            num_attention_heads=dimension // 64,
                            num_transformer_layers=12,
                            dropout_rate=0.,
                            max_relative_position=64,
                            final_layernorm=False)

final_embedding = model.layers[-2].input
residue_predictions = tf.keras.layers.Dense(num_targets)(final_embedding)
protein_predictions = tf.keras.layers.GlobalMaxPooling1D()(residue_predictions)

localization_model = tf.keras.Model(model.inputs, protein_predictions)

localization_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
localization_model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, None)]            0         
_________________________________________________________________
embedding_1 (Embedding)      (None, None, 768)         18432     
_________________________________________________________________
transformer_12 (Transformer) (None, None, 768)         7096128   
_________________________________________________________________
transformer_13 (Transformer) (None, None, 768)         7096128   
_________________________________________________________________
transformer_14 (Transformer) (None, None, 768)         7096128   
_________________________________________________________________
transformer_15 (Transformer) (None, None, 768)         7096128   
_________________________________________________________________
transformer_16 (Transformer) (None, None, 768)         709612

In [63]:
from bert.dataset import encode
from functools import partial

valid_sequences = data.set_index('accession').reindex(valid.accession).sequence
valid_labels = valid.set_index('accession')

max_seq_len = 512
batch_size = 32

valid_data = tf.data.Dataset.from_tensor_slices(valid_sequences.values).map(
    partial(encode, max_sequence_length=max_seq_len)).padded_batch(
    batch_size=batch_size, padded_shapes=[-1])

In [23]:
valid_predictions = localization_model.predict(valid_data, verbose=1)
valid_probs = tf.nn.softmax(valid_predictions)



In [24]:
valid_predictions.shape

(10000, 18)

In [37]:
valid_bool = valid_probs > 0.5

In [81]:
from sklearn.metrics import f1_score, accuracy_score, precision_score

In [70]:
f1_score(valid_labels.iloc[:, 0], valid_probs[:, 0] > 0.5)

0.07799074686054197

In [89]:
pd.Series({col: f1_score(valid_labels.iloc[:, i], valid_probs[:, i] > 0.5) 
           for i, col in enumerate(valid_labels.columns)}).sort_values(ascending=False)

Cytoplasm                         0.444007
Secreted                          0.102381
Nucleus                           0.090998
Cell membrane                     0.081096
Cell inner membrane               0.077991
Membrane                          0.045351
Plastid                           0.032017
Virion                            0.031746
Mitochondrion inner membrane      0.021798
Mitochondrion                     0.014815
Periplasm                         0.014493
Host nucleus                      0.012422
Endoplasmic reticulum membrane    0.011111
Golgi apparatus membrane          0.000000
Chromosome                        0.000000
Cell projection                   0.000000
Cell junction                     0.000000
Host cytoplasm                    0.000000
dtype: float64

In [50]:
 == valid_bool.numpy()

Unnamed: 0_level_0,Cell inner membrane,Cell junction,Cell membrane,Cell projection,Chromosome,Cytoplasm,Endoplasmic reticulum membrane,Golgi apparatus membrane,Host cytoplasm,Host nucleus,Membrane,Mitochondrion,Mitochondrion inner membrane,Nucleus,Periplasm,Plastid,Secreted,Virion
accession,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
A3PEA4,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True
Q04364,True,True,True,True,True,False,True,True,True,True,True,True,True,True,True,True,True,True
A9R3T6,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True
Q617M0,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True
Q4R5C6,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True
Q1R4M9,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True
Q080P6,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True
B5BDC9,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True
B0TIP2,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True
A0A2A5JY22,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True


In [42]:
valid_bool.shape

TensorShape([10000, 18])

In [25]:
valid_predictions[0]

array([-13.4831295, -13.291877 , -13.238673 , -12.2335205, -12.167766 ,
         9.265969 , -14.079013 , -15.818308 , -11.057702 , -11.632896 ,
       -11.869153 , -10.446587 , -13.612272 , -10.603114 , -14.839796 ,
       -13.735782 , -12.178269 , -13.885028 ], dtype=float32)