## Imports

In [1]:
import os
from pathlib import Path

import numpy as np

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

tf.get_logger().setLevel("INFO")


### Example training procedure

In [None]:
# Build the learned-pooler model from the project-local `models` package.
# NOTE(review): the previous comment here described starting a data generator, which this
# cell does not do; the generator is created in the training-loop cell further down.


import models.bert_learned_pooler as bert_learned_pooler

# NOTE(review): the meaning of the argument `1` is not visible from this notebook --
# presumably an epoch count, given the "learned_pooler_epochs_1" model name; confirm.
bert_model = bert_learned_pooler.create_learned_pooler(1)


In [3]:
# Restore weights into `bert_model` from the ckpt_0001 file under the model's training_checkpoints directory.
bert_model.load_weights("./models/learned_pooler_epochs_1/training_checkpoints/ckpt_0001.ckpt")


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f8dfe86ed10>

In [22]:
# Display the layer-by-layer summary of the loaded model.
bert_model.summary()


Model: "learned_pooler_epochs_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_masks (InputLayer)       [(None, 386)]        0           []                               
                                                                                                  
 input_ids (InputLayer)         [(None, 386)]        0           []                               
                                                                                                  
 token_type_ids (InputLayer)    [(None, 386)]        0           []                               
                                                                                                  
 tf_bert_model (TFBertModel)    TFBaseModelOutputWi  335141888   ['input_masks[0][0]',            
                                thPoolingAndCrossAt               'input_ids

### Extract the weights for all 25 hidden layers

In [18]:
# Pull the fourth-from-last and third-from-last variables of the model out as NumPy arrays.
# NOTE(review): assumes weights[-4] holds the per-layer pooling weights -- confirm against
# the model definition; `t` is extracted alongside it but not used in this cell.
w, t = (bert_model.weights[position].numpy() for position in (-4, -3))

# Map each index of `w` to the value stored at that index.
layer_to_weight = {layer: weight for layer, weight in enumerate(w)}
layer_to_weight


{0: -0.7286164,
 1: -0.68705434,
 2: -0.7754198,
 3: -0.7786671,
 4: -0.6581254,
 5: -0.84351796,
 6: -0.7038068,
 7: -0.78308666,
 8: -0.69364554,
 9: -0.7104143,
 10: -0.74511576,
 11: -0.67180556,
 12: -0.5480674,
 13: -0.5075096,
 14: -0.6255021,
 15: -0.65797615,
 16: -0.6923896,
 17: -0.8698429,
 18: -0.75600356,
 19: -0.25995982,
 20: 0.2591863,
 21: 0.38232687,
 22: 0.6862729,
 23: 0.9188513,
 24: 1.3424544}

Get the indices of the 12 layers with the largest absolute weights in `layer_to_weight`, ordered from largest to smallest magnitude

In [26]:
# Sort indices of `w` by ascending magnitude, keep the last 12, then reverse them so the
# index with the largest magnitude comes first.
top_12_absolute = np.flip(np.argsort(np.abs(w))[-12:], axis=0)
top_12_absolute


array([24, 23, 17,  5,  7,  3,  2, 18, 10,  0,  9,  6])

### Model compilation

Each model is compiled as a two-headed model: the first head predicts the span start position and the second head predicts the span end position. No activation is applied because the heads come directly from splitting a tensor, so the loss is computed from logits.

In [None]:
def compile_model(model):
    """Compile ``model`` in place for two-output training.

    A single ``SparseCategoricalCrossentropy(from_logits=True)`` instance is supplied
    once per output, the optimizer is ``"adam"``, and ``"accuracy"`` is tracked as a
    metric.

    Args:
        model: Object exposing a Keras-style ``compile`` method.

    Returns:
        None; ``model.compile`` is called for its side effect.
    """
    span_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    compile_kwargs = {
        "loss": [span_loss] * 2,  # one entry per head, both sharing the same loss object
        "optimizer": "adam",
        "metrics": ["accuracy"],  # monitor accuracy during the training process
    }
    model.compile(**compile_kwargs)


# Gather a list of models to fit
# Since fitting is typically faster than data load, it is beneficial to fit many models at once
# NOTE(review): `average_pooler_model` and `learned_pooling_model` are not defined in any
# visible cell (only `bert_model` is) -- confirm they exist before this cell runs.
model_list = [average_pooler_model, learned_pooling_model]

# Compile every model in the list with the settings defined in compile_model.
for model_current in model_list:
    compile_model(model_current)


In [None]:
import models.bert_embedding_parser as bert_embedding_parser

# Iterable of training batches produced by the project-local embedding parser.
gen = bert_embedding_parser.load_bert_embeddings(bert_model, batch_size=4)

# Upper bound on the number of batches to train on; can be any number.
# NOTE(review): the original note says 8248 was pre-calculated to go through the entire
# dataset at a batch size of 16, but the generator above uses batch_size=4 -- confirm
# which value is intended. The original also mentioned "4 batches; can save each quarter".
max_batches = 8248

i = 0  # count of batches processed; stays 0 if the generator yields nothing
for i, batch in enumerate(gen, start=1):
    # Unpack the batch read from the generator
    X, Y = batch[0], batch[1]

    # Fit every model once on this batch
    for model_current in model_list:
        model_current.fit(X, Y, epochs=1)

    # Drop the batch reference to free up memory (X and Y keep their elements until rebound)
    del batch

    # Stop once max_batches batches have been processed
    if i == max_batches:
        break


In [None]:
# Save weights: one .h5 file per model, named after the model, inside `weights_dir`.
weights_dir = "weights"

# Create the output directory up front. The .h5 targets live inside it, and the HDF5
# writer is not expected to create missing parent directories, so saving would otherwise
# fail when the directory is absent.
os.makedirs(weights_dir, exist_ok=True)

for m in model_list:
    m.save_weights(weights_dir + "/%s.h5" % m.name)


# TODO: print out the 25 weights for each model (before and after fine-tuning)
