In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

In [2]:
# Disabled GPU setup, kept for reference. When uncommented, it lists the
# physical GPUs TensorFlow can see, prints how many were found, and turns on
# memory growth for each of them.
# gpu_devices = tf.config.experimental.list_physical_devices('GPU')
# print("Num GPUs Available: ", len(gpu_devices))
# for device in gpu_devices:
#     tf.config.experimental.set_memory_growth(device, True)

In [3]:
from bert.model import create_albert_model

# Base width of the network; the feed-forward width and the number of
# attention heads below are both derived from it.
hidden_size = 1024

model = create_albert_model(
    model_dimension=hidden_size,
    transformer_dimension=4 * hidden_size,
    # One head per 64 units of width (16 heads here).
    num_attention_heads=hidden_size // 64,
    num_transformer_layers=24,
    vocab_size=22,
    dropout_rate=0.)

# Print the layer-by-layer overview of the freshly built model.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, None, 1024)   22528       input_1[0][0]                    
__________________________________________________________________________________________________
transformer (Transformer)       (None, None, 1024)   12597568    embedding[0][0]                  
                                                                 transformer[0][0]                
                                                                 transformer[1][0]                
                                                                 transformer[2][0]            

In [4]:
from bert.optimizers import (ECE, masked_sparse_categorical_crossentropy,
                             BertLinearSchedule)

# Alternative optimizer, currently disabled:
# opt = tfa.optimizers.AdamW(learning_rate=1E-4,
#                            beta_2=0.98,
#                            epsilon=1E-6,
#                            weight_decay=0.0)

# Plain Adam, wrapped by TensorFlow's mixed-precision graph rewrite.
adam_settings = dict(learning_rate=1E-4, beta_2=0.98, epsilon=1E-6)
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    tf.optimizers.Adam(**adam_settings))

# Symbolic int32 input of shape (batch, None) that is handed to compile()
# as the target tensor.
true_labels = tf.keras.layers.Input(
    shape=(None,),
    dtype=tf.int32,
    batch_size=None,
)

model.compile(
    optimizer=opt,
    loss=masked_sparse_categorical_crossentropy,
    metrics=[ECE],
    target_tensors=true_labels,
    experimental_run_tf_function=True,
)

In [5]:
# Directory that receives the per-epoch checkpoints. It is created up front
# because, depending on the TensorFlow version, ModelCheckpoint may not create
# a missing parent directory itself, and the first save only happens once a
# whole epoch has finished. makedirs() succeeds if the directory already
# exists, so re-running this cell is safe.
checkpoint_dir = 'jupyter_test_checkpoints'
tf.io.gfile.makedirs(checkpoint_dir)

callbacks = [
    # NOTE(review): the positional arguments look like (learning rate,
    # warm-up steps, total steps) -- confirm against bert.optimizers.
    BertLinearSchedule(1E-4, 10000, int(1E7)),
    # Checkpoint file names embed the epoch number and the validation ECE.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_dir + '/ckpt_{epoch}_{val_ECE:.2f}.h5'),
]

In [12]:
from bert.dataset import create_masked_input_dataset


def _endless_masked_batches(sequence_path):
    """Build the masked-input dataset for one sequence file.

    The dataset is created with a maximum sequence length of 128 and a batch
    size of 32, then chained through ``repeat()`` and
    ``prefetch(tf.data.experimental.AUTOTUNE)``.

    Args:
        sequence_path: Path of the text file holding the sequences.

    Returns:
        The repeated, prefetching dataset (presumably a ``tf.data.Dataset``;
        the return type of ``create_masked_input_dataset`` is not visible
        here -- TODO confirm).
    """
    batches = create_masked_input_dataset(
        sequence_path=sequence_path,
        max_sequence_length=128,
        batch_size=32)
    return batches.repeat().prefetch(tf.data.experimental.AUTOTUNE)


training_data = _endless_masked_batches('../uniparc_data/sequences_train.txt')
valid_data = _endless_masked_batches('../uniparc_data/sequences_valid.txt')

In [13]:
# Train for 10 epochs of 10000 steps each, running 10 validation steps at the
# end of every epoch. The explicit step counts set the epoch length.
fit_options = dict(
    steps_per_epoch=10000,
    epochs=10,
    verbose=1,
    validation_data=valid_data,
    validation_steps=10,
    callbacks=callbacks,
)
model.fit(training_data, **fit_options)

Epoch 1/10
  193/10000 [..............................] - ETA: 1:35:52 - loss: 2.8200 - ECE: 16.7936

KeyboardInterrupt: 