In [1]:
import sys

sys.path.append("/Users/PRVATE/Documents/tf_transformers/src/")

In [2]:
import tensorflow as tf
import tensorflow_hub as hub

from transformers import TFAlbertModel
from tf_transformers.models import AlbertEncoder

from tf_transformers.core import LegacyModule
from transformers import AlbertTokenizer

from tf_transformers.utils import convert_albert_hf_to_tf_transformers
import json
import time

In [3]:
# Load the reference HuggingFace model whose weights will be converted.

# Always do this: reset Keras global state so layer names start from a clean slate
tf.keras.backend.clear_session()

# Directory holding a local copy of the HF models; set it to "" to let
# `from_pretrained` resolve `hf_model_name` on its own instead.
local_dir = "/Users/PRVATE/HUggingFace_Models/"
hf_model_name = "albert-base-v2"
# Always define the location: previously an empty `local_dir` left
# `hf_model_location` unassigned and the next line raised NameError.
hf_model_location = local_dir + hf_model_name if local_dir else hf_model_name

model_hf = TFAlbertModel.from_pretrained(hf_model_location)

All model checkpoint layers were used when initializing TFAlbertModel.

All the layers of TFAlbertModel were initialized from the model checkpoint at /Users/PRVATE/HUggingFace_Models/albert-base-v2.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFAlbertModel for predictions without further training.


In [4]:
# Load tf_transformers model
# Most of the config is provided by the JSON file below

# Default configs for the model
config_location = (
    "../../configs/model_configs/" + "albert_base_v2/" + "albert_config.json"
)
# Use a context manager so the config file handle is closed deterministically
with open(config_location) as config_file:
    config = json.load(config_file)

# Always do this
tf.keras.backend.clear_session()

# tf_transformers Layer (an extension of Keras Layer)
# This is not a Keras model, but an extension of a Keras Layer

# Save as saved_model
# If you want to use the model for Auto Regressive tasks ( text-generation ),
# you have to enable pipeline_mode='auto-regressive'.
# Because TF needs extra cache inputs in the saved_model format for doing efficient caching

model_layer = AlbertEncoder(
    config=config,
    name="albert",
    mask_mode=config["mask_mode"],
    is_training=False,
    use_dropout=False,
)

# Convert to tf.keras.Model, then copy the HuggingFace weights into it
model_tf_transformers = model_layer.get_and_load_model(model_dir=None)
convert_albert_hf_to_tf_transformers(model_hf, model_tf_transformers, config)

INFO:absl:We are overwriding `is_training` is False to `is_training` to                     True with `use_dropout` is False, no effects on your inference pipeline
INFO:absl:Inputs -->
INFO:absl:input_ids ---> Tensor("input_ids:0", shape=(None, None), dtype=int32)
INFO:absl:input_mask ---> Tensor("input_mask:0", shape=(None, None), dtype=int32)
INFO:absl:input_type_ids ---> Tensor("input_type_ids:0", shape=(None, None), dtype=int32)
INFO:absl:Initialized Variables
INFO:absl:Inputs -->
INFO:absl:input_ids ---> Tensor("input_ids_1:0", shape=(None, None), dtype=int32)
INFO:absl:input_mask ---> Tensor("input_mask_1:0", shape=(None, None), dtype=int32)
INFO:absl:input_type_ids ---> Tensor("input_type_ids_1:0", shape=(None, None), dtype=int32)
INFO:absl:Deleteing huggingface model for saving memory
INFO:absl:Done assigning variables weights. Total 25


In [5]:
# If you want to save the model as checkpoints.
# The converted weights are written under "model_albert_ckpt"; a later cell
# restores them into the auto-regressive variant of the model.

checkpoint = tf.train.Checkpoint(model=model_tf_transformers)
# max_to_keep=1: only the most recent checkpoint is kept in the directory
manager = tf.train.CheckpointManager(
    checkpoint, directory="model_albert_ckpt", max_to_keep=1
)
manager.save()
print("Saved at {}".format(manager.latest_checkpoint))

Saved at model_albert_ckpt/ckpt-1


In [5]:
# Display the converted model object (a tf_transformers LegacyModel, per the output below)
model_tf_transformers

<tf_transformers.core.legacy_model.LegacyModel at 0x14f3ca5b0>

In [6]:
# Sanity check of the converted model on a fixed toy batch.
# Please have a look at tf_transformers/extra/*.py for reference values

input_ids = tf.constant([[1, 9, 10, 11, 23], [1, 22, 234, 432, 2349]])
input_mask = tf.ones_like(input_ids)
input_type_ids = tf.ones_like(input_ids)

inputs = dict(
    input_ids=input_ids,
    input_mask=input_mask,
    input_type_ids=input_type_ids,
)

results_tf_transformers = model_tf_transformers(inputs)
# Print a checksum (sum of all elements) and the shape of every tensor output;
# list-valued outputs are skipped.
for k, r in results_tf_transformers.items():
    if not isinstance(r, list):
        print(k, "-->", tf.reduce_sum(r), "-->", r.shape)

cls_output --> tf.Tensor(12.337963, shape=(), dtype=float32) --> (2, 768)
token_embeddings --> tf.Tensor(-193.53201, shape=(), dtype=float32) --> (2, 5, 768)
token_logits --> tf.Tensor(-18578.355, shape=(), dtype=float32) --> (2, 5, 30000)
last_token_logits --> tf.Tensor(-3803.8923, shape=(), dtype=float32) --> (2, 30000)


In [7]:
# Rebuild the encoder in auto-regressive mode: the model then also takes the
# cache inputs (all_cache_key, all_cache_value, past_length) shown in the log output.
model_layer = AlbertEncoder(
    config=config,
    name="albert",
    mask_mode=config["mask_mode"],
    is_training=False,
    pipeline_mode="auto-regressive",
)

# Convert to tf.keras.Model
model_tf_transformers = model_layer.get_and_load_model(model_dir=None)

# And now load the checkpoints from the previously saved model

checkpoint = tf.train.Checkpoint(model=model_tf_transformers)
manager = tf.train.CheckpointManager(
    checkpoint, directory="model_albert_ckpt", max_to_keep=1
)
status = checkpoint.restore(manager.latest_checkpoint)

# Important: fail loudly here if the restore did not match the objects of the
# freshly built model, instead of silently decoding with random weights
status.assert_existing_objects_matched()

INFO:absl:Inputs -->
INFO:absl:input_ids ---> Tensor("input_ids_2:0", shape=(None, None), dtype=int32)
INFO:absl:input_mask ---> Tensor("input_mask_2:0", shape=(None, None), dtype=int32)
INFO:absl:input_type_ids ---> Tensor("input_type_ids_2:0", shape=(None, None), dtype=int32)
INFO:absl:all_cache_key ---> Tensor("all_cache_key:0", shape=(None, None, 12, None, 64), dtype=float32)
INFO:absl:all_cache_value ---> Tensor("all_cache_value:0", shape=(None, None, 12, None, 64), dtype=float32)
INFO:absl:past_length ---> Tensor("past_length:0", shape=(1, None), dtype=int32)
INFO:absl:Initialized Variables
INFO:absl:Inputs -->
INFO:absl:input_ids ---> Tensor("input_ids_3:0", shape=(None, None), dtype=int32)
INFO:absl:input_mask ---> Tensor("input_mask_3:0", shape=(None, None), dtype=int32)
INFO:absl:input_type_ids ---> Tensor("input_type_ids_3:0", shape=(None, None), dtype=int32)
INFO:absl:all_cache_key ---> Tensor("all_cache_key_1:0", shape=(None, None, 12, None, 64), dtype=float32)
INFO:absl:a


Two checkpoint references resolved to different objects (<tf_transformers.models.albert.AlbertEncoder object at 0x14f6c8b80> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x14f7b32e0>).



Two checkpoint references resolved to different objects (<tf_transformers.models.albert.AlbertEncoder object at 0x14f6c8b80> and <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x14f7b32e0>).


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1502523a0>

In [8]:
# albert_module = LegacyModule(model_tf_transformers)
# albert_module.save("model_albert_pb")

In [9]:
# This particular tokenizer is not mandatory: any tokenizer can be used as long
# as `tokenizer_fn` (next cell) returns the inputs the model expects.
# Here we use the HuggingFace tokenizer library.

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")

In [10]:
def tokenizer_fn(text_list):
    """Tokenize a batch of texts into un-padded model inputs.

    Make sure all primary keys required by the model are returned.
    No padding is applied: every row keeps its own length.

    Args:
        text_list: a list of strings.

    Returns:
        A dict whose values are lists with one list of ints per text, e.g.

        {'input_ids': [[2, 15, 3], [2, 7, 9, 3]],
         'input_mask': [[1, 1, 1], [1, 1, 1, 1]],
         'input_type_ids': [[0, 0, 0], [0, 0, 0, 0]]}

        `input_mask` is all ones and `input_type_ids` is all zeros, with the
        same row lengths as `input_ids`.
    """
    # Build plain nested lists directly. The previous round-trip through
    # tf.ragged.constant(...).numpy().tolist() returned numpy rows for ragged
    # batches but plain lists for equal-length batches.
    input_ids = [list(tokenizer.encode(text)) for text in text_list]
    return {
        "input_ids": input_ids,
        "input_mask": [[1] * len(ids) for ids in input_ids],
        "input_type_ids": [[0] * len(ids) for ids in input_ids],
    }

In [11]:
from tf_transformers.text import TextDecoder
from tf_transformers.text import TextDecoderSerializable

In [12]:
# Eager decoder around the cache-enabled model; despite the variable name it is
# reused below for beam, greedy and top-k/top-p decoding (the mode is chosen per call).
# num_attention_heads / attention_state match the cache shape in the model log
# above (..., 12, ..., 64).
# NOTE(review): input_mask_ids / input_type_ids are presumably the fill values used
# for the mask / type ids of newly generated tokens -- confirm against TextDecoder.
decoder_layer_beam = TextDecoder(
    tokenizer_fn=tokenizer_fn,
    model=model_tf_transformers,
    num_attention_heads=12,
    num_layers=12,
    attention_state=64,
    input_mask_ids=1,
    input_type_ids=0,
)

In [13]:
# Prompts shared by every decoding experiment below
text_list = ["Sachin Tendulkar is one of the finest", 
            "I like to walk with my dog"]

In [14]:
# Beam Search: time a 25-step decode with 2 beams (no sampling, no EOS stopping)

start_time = time.time()
result_beam = decoder_layer_beam.decode(
    text_list, max_iterations=25, beam_size=2, mode="beam", do_sample=False, eos_id=None
)
# Uncomment to print the decoded text of every beam
# (assumes the result dict has "input_ids" and "predicted_ids" -- TODO confirm):
# for i in range(len(result_beam["input_ids"])):
#     for beam_predicted_ids in result_beam["predicted_ids"][i]:
#         print(
#             tokenizer.decode(
#                 tf.concat([result_beam["input_ids"][i], beam_predicted_ids], axis=0).numpy()
#             )
#         )
#         print("--------------")
end_time = time.time()
print("Time taken {} seconds".format(end_time - start_time))
print('_______________________________________________________')




Time taken 3.2244322299957275 seconds
_______________________________________________________


In [15]:
# Greedy Search: same decoder object, 25 steps, no sampling, no EOS stopping

start_time = time.time()
result_greedy = decoder_layer_beam.decode(
    text_list, max_iterations=25, mode="greedy", do_sample=False, eos_id=None
)
# Uncomment to print the decoded text
# (assumes the result dict has "input_ids" and "predicted_ids" -- TODO confirm):
# for i in range(len(result_greedy["input_ids"])):
#     for beam_predicted_ids in result_greedy["predicted_ids"][i]:
#         print(
#             tokenizer.decode(
#                 tf.concat([result_greedy["input_ids"][i], beam_predicted_ids], axis=0).numpy()
#             )
#         )
#         print("--------------")
end_time = time.time()
print("Time taken {} seconds".format(end_time - start_time))
print('_______________________________________________________')


Time taken 2.4923179149627686 seconds
_______________________________________________________


In [16]:
# Top K top P Search: 25 steps, top_k=50, top_p=0.7, two sequences per prompt
# NOTE(review): do_sample=False is passed here; confirm the two returned
# sequences actually differ without sampling.

start_time = time.time()
result_top_k_top_p = decoder_layer_beam.decode(
    text_list, max_iterations=25, mode="top_k_top_p", top_k=50, top_p=0.7, do_sample=False, eos_id=None, 
    num_return_sequences=2
)
# Uncomment to print the decoded text of every returned sequence
# (assumes the result dict has "input_ids" and "predicted_ids" -- TODO confirm):
# for i in range(len(result_top_k_top_p["input_ids"])):
#     for beam_predicted_ids in result_top_k_top_p["predicted_ids"][i]:
#         print(
#             tokenizer.decode(
#                 tf.concat([result_top_k_top_p["input_ids"][i], beam_predicted_ids], axis=0).numpy()
#             )
#         )
#         print("--------------")
end_time = time.time()
print("Time taken {} seconds".format(end_time - start_time))
print('_______________________________________________________')


Time taken 2.9417638778686523 seconds
_______________________________________________________


In [18]:
# Serializable decoder: the decoding loop is built into a Keras layer so it can be
# exported together with the model (see the save cell that follows).
# NOTE(review): on a fresh kernel this is the TextDecoderSerializable imported from
# tf_transformers.text, not the class redefined at the bottom of this notebook.
decoder_layer_serializable = TextDecoderSerializable(
    model_tf_transformers,
    input_name_list=["input_ids", "input_mask", "input_type_ids"],
    max_iterations=25,
    num_attention_heads=12,
    num_layers=12,
    attention_state=64,
    mode="greedy",
    do_sample=False,
    eos_id=-100,
    input_mask_ids=1,
    input_type_ids=0,
)

inputs_for_serializable = tokenizer_fn(text_list)

# The layer takes dense tensors, so pad the ragged rows here.
# input_ids are padded with -1: the greedy implementation shown later in this
# notebook treats -1 as padding when masking the cache.
main_inputs = {}
for k, v in inputs_for_serializable.items():
    main_inputs[k] = tf.ragged.constant(v)
main_inputs['input_ids'] = main_inputs['input_ids'].to_tensor(-1)
main_inputs['input_mask'] = main_inputs['input_mask'].to_tensor(0)
main_inputs['input_type_ids'] = main_inputs['input_type_ids'].to_tensor(0)

start_time = time.time()
results_serializable_greedy = decoder_layer_serializable(main_inputs)
end_time = time.time()
print("Time taken {} seconds".format(end_time - start_time))

Time taken 2.503310203552246 seconds


In [19]:
# Wrap the decoding layer into a Keras model and export it (model + decoding loop)
# to the "model_temp_pb" directory.
decoder_model  = decoder_layer_serializable.get_model()
decoder_module = LegacyModule(decoder_model)
decoder_module.save("model_temp_pb")

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


INFO:tensorflow:Assets written to: model_temp_pb/assets


INFO:tensorflow:Assets written to: model_temp_pb/assets


In [63]:
import tensorflow as tf
from tf_transformers.text import (assign_zeros_to_K_V,
                                        _log_prob_from_logits,
                                        _gather_beams,
                                        top_k_logits,
                                        top_p_logits)

import numpy as np

tf.keras.backend.clear_session()
class TextDecoderSerializable(tf.keras.layers.Layer):
    """TextDecoderSerializable - This class is responsible for saving the model along with decoding
    operation as a saved_model, which makes deployment in production easier.
    """
    def __init__(
        self,
        model,
        max_iterations,
        num_attention_heads,
        num_layers,
        attention_state,
        mode,
        input_name_list=None,
        beam_size = 1,
        eos_id=-100,
        do_sample=False,
        top_k = 0,
        top_p = 0,
        num_return_sequences = 1,
        input_mask_ids = None,
        input_type_ids = None,
    ):
        """Store the decoding configuration and build the decoding function.

        Args:
            model (tf.keras.Model / tf.keras.Layer): The model with which decoding
                has to be performed. Its `input` is inspected to find which of
                `input_ids`, `input_mask`, `input_type_ids` it accepts.
            max_iterations (int): Maximum iterations for decoding.
            num_attention_heads (int): Attention heads of the model.
            num_layers (int): Number of model layers.
            attention_state (int): embedding_size // num_attention_heads.
            mode (str): One of 'greedy', 'beam', 'top_k_top_p'.
            input_name_list (list of str, optional): Names of model inputs like
                input_ids, input_mask, etc. Currently ignored: the names are
                derived from `model.input` by `get_inputs()`. Kept for
                backward compatibility.
            beam_size (int, optional): Number of beams. Defaults to 1.
            eos_id (int, optional): End of sentence token id. Defaults to -100.
            do_sample (bool, optional): Multinomial sampling. Defaults to False.
            top_k (int, optional): top k. Defaults to 0.
            top_p (float, optional): top p (nucleus). Defaults to 0.
            num_return_sequences (int, optional): Number of returned sequences.
                Defaults to 1.
            input_mask_ids (int, optional): Value used to fill `input_mask` for
                generated tokens; provide it if your model has this input.
                Defaults to None.
            input_type_ids (int, optional): Value used to fill `input_type_ids`
                for generated tokens; provide it if your model has this input.
                Defaults to None.

        Raises:
            ValueError: If `mode` is not one of the supported decoding modes.
        """

        super(TextDecoderSerializable, self).__init__()

        self.num_attention_heads = num_attention_heads
        self.num_layers = num_layers
        self.attention_state = attention_state

        self.model = model
        # The `input_name_list` argument is intentionally not used: the names
        # (and the matching Keras placeholders) come from the model itself.
        self.input_name_list, self.model_inputs = self.get_inputs()
        # position in the inputs tuple -> input name
        self.input_name_map = {i:k for i, k in enumerate(self.input_name_list)}

        self.eos_id = eos_id
        self.max_iterations = max_iterations
        self.mode = mode

        # Mask and type ids
        self.input_mask_ids = input_mask_ids
        self.input_type_ids = input_type_ids

        self.beam_size = beam_size
        self.top_k = top_k
        self.top_p = top_p
        self.do_sample = do_sample
        self.num_return_sequences = num_return_sequences

        if self.mode == 'greedy':
            self.decoder_fn = self.greedy()
        elif self.mode == 'beam':
            self.decoder_fn = self.beam()
        elif self.mode == 'top_k_top_p':
            self.decoder_fn = self.top_k_top()
        else:
            # Previously an unknown mode left `decoder_fn` unset and only failed
            # later with an unrelated AttributeError.
            raise ValueError(
                "Unsupported mode {!r}; expected 'greedy', 'beam' or 'top_k_top_p'".format(mode)
            )




    def get_model(self):
        """Wrap this decoding layer into a `tf.keras.Model`.

        The layer is called on the symbolic placeholders created by
        `get_inputs()` and the resulting outputs become the model outputs.

        Returns:
            tf.keras.Model: model named 'decoder_model'.
        """
        decoder_outputs = self(self.model_inputs)
        return tf.keras.Model(
            inputs=self.model_inputs,
            outputs=decoder_outputs,
            name='decoder_model',
        )


    def get_inputs(self):
        """Create Keras placeholders for the inputs the wrapped model accepts.

        Only `input_ids`, `input_mask` and `input_type_ids` are considered, in
        that order, and only those present in `self.model.input` are kept.
        Every placeholder is an int32, non-ragged input of shape (None, None).

        Returns:
            tuple: (list of kept input names, dict name -> Keras Input).
            The list is also stored on `self.input_name_list`.
        """
        candidate_names = ('input_ids', 'input_mask', 'input_type_ids')
        placeholders = {
            name: tf.keras.layers.Input(
                shape=(None,), batch_size=None, ragged=False, dtype=tf.int32, name=name
            )
            for name in candidate_names
        }

        self.input_name_list = [
            name for name in candidate_names if name in self.model.input
        ]
        inputs = {name: placeholders[name] for name in self.input_name_list}

        return self.input_name_list, inputs


    def reorder_past_batches(self, all_cache_key, all_cache_value, coordinates, beam_size):
        """Reorder the cached K / V along axis 1 to follow the selected beams.

        Later beam-search steps change which beams are the best paths, so the
        cache rows have to be re-gathered in the new order.

        Args:
            all_cache_key (tf.Tensor): K cache from the transformer.
            all_cache_value (tf.Tensor): V cache from the transformer.
            coordinates (tf.Tensor): beam selection; the last component of the
                first `beam_size` entries of each row is used as the index
                within that row (presumably produced by `_gather_beams` --
                TODO confirm).
            beam_size (int / tf.Tensor): Number of beams.

        Returns:
            tuple: (reordered all_cache_key, reordered all_cache_value).
        """
        # Offset of each row's first beam in the flattened (row * beam) axis
        num_rows = tf.shape(coordinates)[0]
        row_offsets = tf.expand_dims(tf.range(num_rows) * beam_size, 1)
        flat_indices = tf.reshape(
            coordinates[:, :beam_size, -1] + row_offsets, (1, -1)
        )

        def _regather(cache):
            # Indices are (1, N): gather adds a unit axis at position 1, squeeze drops it
            return tf.squeeze(tf.gather(cache, flat_indices, axis=1), axis=1)

        return _regather(all_cache_key), _regather(all_cache_value)


    def greedy(self):
        """Build the function that performs greedy decoding.

        Returns:
            callable: `call_greedy(inputs)`, taking the dict of dense model
            inputs (`input_ids` padded with -1) and returning a dict with
            `iterations`, `input_ids`, `predicted_ids` and `matched_eos_pos`.
        """

        def cond(i, input_ids, all_cache_key, all_cache_value, past_length, initial_id):
            """Keep looping until every row has generated `eos_id` at least once."""
            # Column 0 of `initial_id` is the placeholder seeded in call_greedy,
            # not a prediction, so it must not take part in the EOS check
            # (otherwise eos_id == 1 would stop the loop immediately).
            generated_ids = initial_id[:, 1:]
            # Boolean reductions instead of a product of int32 counts, which
            # could overflow for larger batches.
            row_has_eos = tf.reduce_any(tf.equal(generated_ids, self.eos_id), axis=1)
            return tf.logical_not(tf.reduce_all(row_has_eos))

        def next_step_inputs(input_ids):
            """Build the single-token model inputs, ordered like `self.input_name_list`."""
            step_inputs = [None] * len(self.input_name_list)
            for index, name in self.input_name_map.items():
                if name == "input_ids":
                    step_inputs[index] = input_ids
                elif name == "input_type_ids":
                    step_inputs[index] = tf.ones_like(input_ids) * self.input_type_ids
                elif name == "input_mask":
                    step_inputs[index] = tf.ones_like(input_ids) * self.input_mask_ids
            return tuple(step_inputs)

        def body(i, inputs_tuple, all_cache_key, all_cache_value, past_length, initial_id):
            """One greedy decoding step.

            Args:
                i (tf.Tensor): iteration counter.
                inputs_tuple (tuple of tf.Tensor): model inputs for this step,
                    ordered like `self.input_name_list`.
                all_cache_key (tf.Tensor): K cache.
                all_cache_value (tf.Tensor): V cache.
                past_length (tf.Tensor): passed through to the model.
                initial_id (tf.Tensor): ids decoded so far (column 0 is a placeholder).

            Returns:
                list of tf.Tensor: the updated loop variables, in the same order.
            """
            inputs = dict(zip(self.input_name_list, inputs_tuple))
            inputs['all_cache_key'] = all_cache_key
            inputs['all_cache_value'] = all_cache_value
            inputs['past_length'] = past_length

            model_outputs = self.model(inputs)

            # Greedy choice: most likely next token for every row
            prediction_ids = tf.argmax(model_outputs['last_token_logits'], axis=1)
            input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1), tf.int32)

            return [
                i + 1,
                next_step_inputs(input_ids),
                model_outputs["all_cache_key"],
                model_outputs["all_cache_value"],
                model_outputs["past_length"],
                tf.concat([initial_id, input_ids], axis=1),
            ]

        # @tf.function(experimental_relax_shapes=True)
        def call_greedy(inputs):
            """Run greedy decoding for a batch of (padded) model inputs."""
            input_ids_orig = inputs["input_ids"]
            # Original batch size and sequence length
            batch_size = tf.shape(inputs['input_ids'])[0]
            max_sequence_length = tf.shape(inputs['input_ids'])[1]

            model_inputs = dict(inputs)

            # Pre-initialize additional inputs
            zero_entry  = tf.zeros((self.num_layers,
                                    batch_size,
                                    self.num_attention_heads,
                                    max_sequence_length, self.attention_state))
            # past_length for keeping track of positional ids
            past_length = tf.expand_dims(tf.zeros(batch_size, dtype=tf.int32), 0)
            # Iterator to keep track of the loop
            i = tf.constant([[0]])
            # Placeholder column (all ones); stripped from the results below
            initial_id = tf.ones(shape=(batch_size, 1), dtype=tf.int32)

            # Add remaining model inputs
            model_inputs['all_cache_key'] = zero_entry
            model_inputs['all_cache_value'] = zero_entry
            model_inputs['past_length'] = past_length

            # First pass over the whole prompt; it also fills the cache
            model_outputs = self.model(model_inputs)
            model_logits = model_outputs["last_token_logits"]
            prediction_ids = tf.argmax(model_logits, axis=1)
            input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1), tf.int32)

            # Update iter
            i = i + 1
            all_cache_key = model_outputs["all_cache_key"]
            all_cache_value = model_outputs["all_cache_value"]
            initial_id = tf.concat([initial_id, input_ids], axis=1)

            # Zero the cache entries that belong to padded (-1) prompt positions
            masks = tf.cast(tf.not_equal(inputs["input_ids"], -1), tf.float32)
            masks = tf.reshape(masks, (1, batch_size, 1, max_sequence_length, 1))
            all_cache_key = all_cache_key * masks
            all_cache_value = all_cache_value * masks

            inputs_tuple = next_step_inputs(input_ids)
            input_shapes_tuple = tuple(
                [tf.TensorShape([None, None])] * len(self.input_name_list)
            )
            # Batch and sequence axes of the cache vary across iterations
            cache_shape = tf.TensorShape(
                [
                    self.num_layers,
                    None,
                    self.num_attention_heads,
                    None,
                    self.attention_state,
                ]
            )

            results = tf.while_loop(
                cond,
                body,
                # One token was already decoded by the first pass above
                maximum_iterations=self.max_iterations - 1,
                loop_vars=[
                    i,
                    inputs_tuple,
                    all_cache_key,
                    all_cache_value,
                    model_outputs['past_length'],
                    initial_id,
                ],
                shape_invariants=[
                    i.get_shape(),
                    input_shapes_tuple,
                    cache_shape,
                    cache_shape,
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                ],
            )

            # Skip the placeholder column
            predicted_ids = results[-1][:, 1:]

            results_dict = {}
            results_dict["iterations"] = results[0]
            results_dict["input_ids"] = input_ids_orig
            results_dict["predicted_ids"] = tf.expand_dims(predicted_ids, 1)

            # Position of the first EOS per row, or -1 when the row has none.
            # argmax alone returns 0 both for "EOS at position 0" and for
            # "no EOS", so the presence of a hit is checked explicitly.
            eos_hits = tf.equal(self.eos_id, predicted_ids)
            first_eos_pos = tf.argmax(tf.cast(eos_hits, tf.int64), axis=1)
            row_has_eos = tf.reduce_any(eos_hits, axis=1)
            results_dict['matched_eos_pos'] = tf.where(
                row_has_eos, first_eos_pos, -tf.ones_like(first_eos_pos)
            )

            return results_dict

        return call_greedy


    def beam(self):
        """
        This function will perform beam decoding.
        """

        # EOS check function
        def cond(i, input_ids, all_cache_key, all_cache_value, past_length,  alive_log_probs, alive_seq):
            eos_check = tf.greater(tf.reduce_prod(tf.reduce_sum(tf.cast(tf.equal(alive_seq, self.eos_id),
                                                                tf.int32), axis=[2])), 0)
            return tf.not_equal(eos_check, True)

        def body(i,
                 inputs_tuple,
                 all_cache_key,
                 all_cache_value, past_length, alive_log_probs, alive_seq):
            """[This is the body of the beam decoder]

            Args:
                i ([tf.tensor]): [iterator (an int)]
                inputs ([List of model inputs]): [description]
                all_cache_key ([K]): [description]
                all_cache_value ([V]): [description]
                past_length ([tf.tensor (1 x batch_size)]): [description]
                This is our main output or decoded ids]
                alive_log_probs ([tf.tensor]): [To keep track of active ids]
                alive_seq ([tf.tensor]): [description]

            Returns:
                [List of tensors]: [Outputs]
            """
            inputs = {}
            for k in range(len(self.input_name_list)):
                inputs[self.input_name_list[k]] = inputs_tuple[k]
            inputs['all_cache_key'] = all_cache_key
            inputs['all_cache_value'] = all_cache_value
            inputs['past_length'] = past_length

            beams_to_keep = 2 * self.beam_size
            model_outputs = self.model(inputs)

            model_logits = model_outputs['last_token_logits']

            all_cache_key = model_outputs["all_cache_key"]
            all_cache_value = model_outputs["all_cache_value"]
            past_length = model_outputs["past_length"]

            if self.top_k > 0:
                model_logits = top_k_logits(model_logits, k=self.top_k)
            if self.top_p > 0:
                model_logits = top_p_logits(model_logits, p=self.top_p)

            vocab_size = tf.shape(model_logits)[1]
            batch_size = tf.shape(inputs['input_ids'])[0]//self.beam_size
            logits = tf.reshape(model_logits, (batch_size, self.beam_size, -1))
            # # Convert logits to normalized log probs
            candidate_log_probs = _log_prob_from_logits(logits)

            # Calculate new log probabilities if each of the alive sequences were
            # extended # by the the candidate IDs.
            # Shape [batch_size, beam_size, vocab_size]
            log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, 2)

            # Calculate new log probabilities if each of the alive sequences were
            # extended # by the the candidate IDs.
            # Shape [batch_size, beam_size, vocab_size]
            log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

            # Each batch item has beam_size * vocab_size candidate sequences. For each
            # batch item, get the k candidates with the highest log probabilities.
            flat_log_probs = tf.reshape(log_probs,
                                        [-1, self.beam_size * vocab_size])


            if self.do_sample:
                next_tokens = tf.random.categorical(
                    flat_log_probs, dtype=tf.int32, num_samples=beams_to_keep
                )  # (batch_size, 2 * num_beams)

                # # Compute next scores
                next_scores = tf.gather(flat_log_probs, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)

                # # sort the sampled vector to make sure that the first num_beams samples are the best
                next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
                next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
                next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)

                topk_log_probs = next_scores
                topk_indices   = next_tokens
            else:
                topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep)

            topk_beam_indices = topk_indices // vocab_size
            topk_seq, coordinates = _gather_beams(
                alive_seq, topk_beam_indices, batch_size,
                beams_to_keep)
            topk_seq = tf.cast(topk_seq, tf.int32)
            topk_ids = topk_indices % vocab_size
            topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

            topk_alive_seq  = topk_seq[:, :self.beam_size, :]
            alive_log_probs = topk_log_probs[:, :self.beam_size]
            input_ids = tf.reshape(topk_ids[:, :self.beam_size], [-1, 1])
            alive_seq = topk_alive_seq

            inputs_tuple = [None] * len(self.input_name_list)

            for index, name in self.input_name_map.items():
                if name == "input_ids":
                    inputs_tuple[index] = input_ids
                    # input_shapes_tuple.append(tf.TensorShape([None, None]))
                    continue
                if name == "input_type_ids":
                    inputs_tuple[index] = tf.ones_like(input_ids) * self.input_type_ids
                    # input_shapes_tuple.append(tf.TensorShape([None, None]))
                    continue
                if name == "input_mask":
                    inputs_tuple[index] = tf.ones_like(input_ids) * self.input_mask_ids
                    # input_shapes_tuple.append(tf.TensorShape([None, None]))
                    continue
            # Convert to tuple
            inputs_tuple = tuple(inputs_tuple)

            all_cache_key , all_cache_value = self.reorder_past_batches(all_cache_key, all_cache_value, coordinates, self.beam_size)
            model_outputs["all_cache_key"] = all_cache_key
            model_outputs["all_cache_value"] = all_cache_value

            return [
                i + 1,
                inputs_tuple,
                model_outputs["all_cache_key"],
                model_outputs["all_cache_value"],
                model_outputs["past_length"],
                alive_log_probs,
                alive_seq
            ]

        # @tf.function(experimental_relax_shapes=True)
        def call_beam(inputs):
            """Run beam search over `inputs` and collect the decoded ids.

            The prompt is passed through `self.model` once to fill the key/value
            cache and to pick the first set of beams; `tf.while_loop` then drives
            the enclosing scope's `cond` / `body` for at most
            `self.max_iterations - 1` further steps.

            Args:
                inputs (dict): model inputs keyed by input name. "input_ids" is
                    required; cache entries at positions where it equals -1 are
                    zeroed after the first pass.

            Returns:
                dict:
                    "iterations": the loop counter returned by the while loop.
                    "input_ids": `inputs["input_ids"]`, unchanged.
                    "predicted_ids": the alive sequences without their leading
                        placeholder column.
                    "matched_eos_pos": int32, one entry per (batch item, beam);
                        see the note next to its computation below.
            """
            input_ids_orig = inputs["input_ids"]
            # Twice as many candidates as beams are scored at every step
            beams_to_keep = 2 * self.beam_size
            # Original batch size and (padded) prompt length
            batch_size = tf.shape(inputs['input_ids'])[0]
            max_sequence_length = tf.shape(inputs['input_ids'])[1]
            # Repeat every input beam_size times along the batch axis
            # (the model sees batch_size x beam_size rows)
            model_inputs = {}
            for input_key, input_value in inputs.items():
                model_inputs[input_key] = tf.repeat(input_value, [self.beam_size], axis=0)
            # New batch size
            batch_size_updated = tf.shape(model_inputs['input_ids'])[0]

            # Pre-initialize the additional (cache) inputs with zeros
            zero_entry = tf.zeros((self.num_layers,
                                   batch_size_updated,
                                   self.num_attention_heads,
                                   max_sequence_length, self.attention_state))
            all_cache_key = zero_entry
            all_cache_value = zero_entry
            # past_length for keeping track of positional ids
            past_length = tf.expand_dims(tf.zeros(batch_size_updated, dtype=tf.int32), 0)
            # Iterator to keep track of the loop
            i = tf.constant([[0]])

            # Add remaining model inputs
            model_inputs['all_cache_key'] = all_cache_key
            model_inputs['all_cache_value'] = all_cache_value
            model_inputs['past_length'] = past_length

            # Running log-probability per beam. Only the first beam starts at 0,
            # every other beam starts at -inf, so the first top-k cannot pick
            # duplicates of the (identical) repeated prompts.
            alive_log_probs = -np.inf * tf.ones((batch_size, self.beam_size-1))
            alive_log_probs = tf.concat([tf.zeros([batch_size, 1]), alive_log_probs], axis=1)
            # Placeholder column of zeros; it is sliced off before returning
            alive_seq = tf.zeros((batch_size, self.beam_size, 1))

            # First pass to the model
            model_outputs = self.model(model_inputs)
            model_logits = model_outputs['last_token_logits']
            # Update iter
            i = i + 1
            all_cache_key = model_outputs["all_cache_key"]
            all_cache_value = model_outputs["all_cache_value"]
            past_length = model_outputs["past_length"]

            if self.top_k > 0:
                model_logits = top_k_logits(model_logits, k=self.top_k)
            if self.top_p > 0:
                model_logits = top_p_logits(model_logits, p=self.top_p)

            # vocab size
            vocab_size = tf.shape(model_logits)[1]
            logits = tf.reshape(model_logits, (batch_size, self.beam_size, -1))
            # Convert logits to normalized log probs
            candidate_log_probs = _log_prob_from_logits(logits)

            # Log probability of every alive sequence extended by every
            # candidate id. Shape [batch_size, beam_size, vocab_size]
            log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

            # Each batch item has beam_size * vocab_size candidate sequences.
            # Flatten them so the candidates of one batch item share a row.
            flat_log_probs = tf.reshape(log_probs,
                                        [-1, self.beam_size * vocab_size])

            if self.do_sample:
                next_tokens = tf.random.categorical(
                    flat_log_probs, dtype=tf.int32, num_samples=beams_to_keep
                )  # (batch_size, 2 * num_beams)

                # Scores of the sampled candidates
                next_scores = tf.gather(flat_log_probs, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)

                # Sort the samples so that the first num_beams entries are the best
                next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
                next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
                next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)

                topk_log_probs = next_scores
                topk_indices = next_tokens
            else:
                topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep)

            # A flat index encodes (beam, token): beam = idx // vocab, token = idx % vocab
            topk_beam_indices = topk_indices // vocab_size
            topk_seq, coordinates = _gather_beams(
                alive_seq, topk_beam_indices, batch_size,
                beams_to_keep)
            topk_seq = tf.cast(topk_seq, tf.int32)
            topk_ids = topk_indices % vocab_size
            topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

            # Keep only the best beam_size of the 2 * beam_size candidates
            alive_seq = topk_seq[:, :self.beam_size, :]
            alive_log_probs = topk_log_probs[:, :self.beam_size]
            input_ids = tf.reshape(topk_ids[:, :self.beam_size], [-1, 1])

            # Inputs for the next step, placed at the positions given by
            # self.input_name_map; names other than the three below stay None.
            inputs_tuple = [None] * len(self.input_name_list)
            input_shapes_tuple = [tf.TensorShape([None, None])] * len(self.input_name_list)
            for index, name in self.input_name_map.items():
                if name == "input_ids":
                    inputs_tuple[index] = input_ids
                    continue
                if name == "input_type_ids":
                    inputs_tuple[index] = tf.ones_like(input_ids) * self.input_type_ids
                    continue
                if name == "input_mask":
                    inputs_tuple[index] = tf.ones_like(input_ids) * self.input_mask_ids
                    continue

            inputs_tuple = tuple(inputs_tuple)
            input_shapes_tuple = tuple(input_shapes_tuple)

            # Step 0 only: zero the cache at padded prompt positions (input id -1)
            masks = tf.cast(tf.not_equal(model_inputs['input_ids'], -1), tf.float32)
            masks = tf.reshape(masks, (1, batch_size_updated, 1, tf.shape(model_inputs['input_ids'])[1], 1))
            all_cache_key = all_cache_key * masks
            all_cache_value = all_cache_value * masks

            # Bring the cache rows in line with the beams that were just selected
            all_cache_key , all_cache_value = self.reorder_past_batches(all_cache_key, all_cache_value, coordinates, self.beam_size)

            # Key and value cache share one invariant: the batch and the
            # sequence axes are allowed to change between iterations.
            cache_shape = tf.TensorShape(
                [
                    self.num_layers,
                    None,
                    self.num_attention_heads,
                    None,
                    self.attention_state,
                ]
            )
            results = tf.while_loop(
                cond,
                body,
                maximum_iterations=self.max_iterations - 1,
                loop_vars=[
                    i,
                    inputs_tuple,
                    all_cache_key,
                    all_cache_value,
                    past_length,
                    alive_log_probs,
                    alive_seq
                ],
                shape_invariants=[
                    i.get_shape(),
                    input_shapes_tuple,
                    cache_shape,
                    cache_shape,
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None,None])
                ],
            )

            results_dict = {}
            results_dict["iterations"] = results[0]
            results_dict["input_ids"] = input_ids_orig
            # Drop the placeholder column of zeros that alive_seq started with
            results_dict["predicted_ids"] = results[-1][:,:,1:]

            # First index of eos_id along every sequence (argmax yields 0 when
            # nothing matches), shifted down by one and flattened to one entry
            # per (batch item, beam).
            matched_positions = tf.squeeze(tf.reshape(tf.argmax(tf.cast(tf.equal(self.eos_id,
                                results_dict["predicted_ids"]), tf.int32), axis=2),
                                (-1, batch_size * self.beam_size)), [0]) -1
            # Entries that are 0 after the shift are mapped to -1 as well.
            # NOTE(review): as written, "no EOS", "EOS at index 0" and "EOS at
            # index 1" all end up as -1 - confirm this is what consumers expect.
            eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
            matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask
            results_dict["matched_eos_pos"] = matched_positions


            return results_dict

        return call_beam

    def top_k_top(self):
        """Build the top-k / top-p (nucleus) decoding function.

        Returns:
            callable: `call_top_k_top_p(inputs)`, a closure over `self` that
            decodes up to `self.max_iterations` tokens per sequence and returns
            a dict of results.
        """

        def _pick_next_ids(model_logits):
            """Filter the logits and choose one id per row.

            top-k / top-p filtering is applied when the respective setting is
            greater than 0. The id is sampled when `self.do_sample` is set and
            is the argmax otherwise. Returns an int32 tensor with one column.
            """
            if self.top_k > 0:
                model_logits = top_k_logits(model_logits, k=self.top_k)
            if self.top_p > 0:
                model_logits = top_p_logits(model_logits, p=self.top_p)

            if self.do_sample:
                next_ids = tf.cast(tf.random.categorical(model_logits, num_samples=1), tf.int32)
            else:
                next_ids = tf.cast(tf.expand_dims(tf.argmax(model_logits, axis=1), axis=1), tf.int32)
            return next_ids

        def _build_step_inputs(input_ids):
            """Return the model inputs for the next step as a tuple.

            Every input is placed at the position given by
            `self.input_name_map`; type ids and mask are constant-filled with
            the shape of `input_ids`. Names other than the three handled here
            leave their slot as None.
            """
            step_inputs = [None] * len(self.input_name_list)
            for index, name in self.input_name_map.items():
                if name == "input_ids":
                    step_inputs[index] = input_ids
                elif name == "input_type_ids":
                    step_inputs[index] = tf.ones_like(input_ids) * self.input_type_ids
                elif name == "input_mask":
                    step_inputs[index] = tf.ones_like(input_ids) * self.input_mask_ids
            return tuple(step_inputs)

        # EOS check function
        def cond(i, input_ids, all_cache_key, all_cache_value, past_length, initial_id):
            """Keep looping until every row of `initial_id` contains `self.eos_id`."""
            # Boolean reductions instead of a product of per-row hit counts:
            # the int32 product can wrap around for large batches with
            # repeated EOS hits, which would keep the loop running.
            # NOTE(review): initial_id still contains the placeholder column of
            # ones here, so eos_id == 1 stops the loop before its first
            # iteration - confirm whether that is intended.
            row_has_eos = tf.reduce_any(tf.equal(initial_id, self.eos_id), axis=1)
            return tf.logical_not(tf.reduce_all(row_has_eos))

        def body(i, inputs_tuple, all_cache_key, all_cache_value, past_length, initial_id):
            """One decoding step of the top-k / top-p decoder.

            Args:
                i (tf.Tensor): iteration counter.
                inputs_tuple (tuple): model inputs, ordered like
                    `self.input_name_list`.
                all_cache_key (tf.Tensor): key cache passed to the model.
                all_cache_value (tf.Tensor): value cache passed to the model.
                past_length (tf.Tensor): passed through to the model; it is
                    created with shape (1, batch_size) before the loop.
                initial_id (tf.Tensor): ids decoded so far, one row per
                    sequence; the new id is appended as a further column.

            Returns:
                list: the updated loop variables, in the same order as the
                arguments.
            """
            inputs = dict(zip(self.input_name_list, inputs_tuple))
            inputs['all_cache_key'] = all_cache_key
            inputs['all_cache_value'] = all_cache_value
            inputs['past_length'] = past_length

            model_outputs = self.model(inputs)
            input_ids = _pick_next_ids(model_outputs['last_token_logits'])

            return [
                i + 1,
                _build_step_inputs(input_ids),
                model_outputs["all_cache_key"],
                model_outputs["all_cache_value"],
                model_outputs["past_length"],
                tf.concat([initial_id, input_ids], axis=1),
            ]

        def call_top_k_top_p(inputs):
            """Run top-k / top-p (nucleus) decoding.

            Args:
                inputs (dict): model inputs keyed by input name. "input_ids" is
                    required; cache entries at positions where it equals -1 are
                    zeroed after the first pass.

            Returns:
                dict:
                    "iterations": the loop counter returned by the while loop.
                    "input_ids": `inputs["input_ids"]`, unchanged.
                    "predicted_ids": decoded ids reshaped to
                        (batch_size, num_return_sequences, -1).
                    "matched_eos_pos": int32, one entry per returned sequence;
                        see the note next to its computation below.
            """
            input_ids_orig = inputs["input_ids"]
            batch_size = tf.shape(inputs['input_ids'])[0]
            max_sequence_length = tf.shape(inputs['input_ids'])[1]
            # Repeat every input num_return_sequences times along the batch axis
            model_inputs = {}
            for input_key, input_value in inputs.items():
                model_inputs[input_key] = tf.repeat(input_value, [self.num_return_sequences], axis=0)
            # Updated batch size
            batch_size_updated = tf.shape(model_inputs['input_ids'])[0]

            # Pre-initialize the additional (cache) inputs with zeros
            zero_entry = tf.zeros((self.num_layers,
                                   batch_size_updated,
                                   self.num_attention_heads,
                                   max_sequence_length, self.attention_state))
            # past_length for keeping track of positional ids
            past_length = tf.expand_dims(tf.zeros(batch_size_updated, dtype=tf.int32), 0)
            # Loop counter; it enters the loop at 0, so the first pass below is
            # not counted in "iterations".
            i = tf.constant([[0]])
            # Placeholder column of ones; it is sliced off before returning
            initial_id = tf.ones(shape=(batch_size_updated, 1), dtype=tf.int32)

            # Add remaining model inputs
            model_inputs['all_cache_key'] = zero_entry
            model_inputs['all_cache_value'] = zero_entry
            model_inputs['past_length'] = past_length

            # First pass to the model
            model_outputs = self.model(model_inputs)
            input_ids = _pick_next_ids(model_outputs['last_token_logits'])

            inputs_tuple = _build_step_inputs(input_ids)
            input_shapes_tuple = tuple([tf.TensorShape([None, None])] * len(self.input_name_list))

            # Append the first decoded id to the placeholder column
            initial_id = tf.concat([initial_id, input_ids], axis=1)

            # Step 0 only: zero the cache at padded prompt positions (input id -1)
            masks = tf.cast(tf.not_equal(model_inputs['input_ids'], -1), tf.float32)
            masks = tf.reshape(masks, (1, batch_size_updated, 1, tf.shape(model_inputs['input_ids'])[1], 1))
            all_cache_key = model_outputs['all_cache_key'] * masks
            all_cache_value = model_outputs['all_cache_value'] * masks

            # Key and value cache share one invariant: the batch and the
            # sequence axes are allowed to change between iterations.
            cache_shape = tf.TensorShape(
                [
                    self.num_layers,
                    None,
                    self.num_attention_heads,
                    None,
                    self.attention_state,
                ]
            )
            results = tf.while_loop(
                cond,
                body,
                maximum_iterations=self.max_iterations - 1,
                loop_vars=[
                    i,
                    inputs_tuple,
                    all_cache_key,
                    all_cache_value,
                    model_outputs['past_length'],
                    initial_id,
                ],
                shape_invariants=[
                    i.get_shape(),
                    input_shapes_tuple,
                    cache_shape,
                    cache_shape,
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                ],
            )

            results_dict = {}
            results_dict["iterations"] = results[0]
            results_dict["input_ids"] = input_ids_orig
            # Drop the placeholder column of ones that initial_id started with
            results_dict["predicted_ids"] = results[-1][:, 1:]
            results_dict["predicted_ids"] = tf.reshape(results_dict["predicted_ids"],(batch_size, self.num_return_sequences, -1))

            # First index of eos_id along every sequence (argmax yields 0 when
            # nothing matches), shifted down by one and flattened to one entry
            # per returned sequence.
            matched_positions = tf.squeeze(tf.reshape(tf.argmax(tf.cast(tf.equal(self.eos_id,
                                results_dict["predicted_ids"]), tf.int32), axis=2),
                                (-1, batch_size * self.num_return_sequences)), [0]) -1
            # Entries that are 0 after the shift are mapped to -1 as well.
            # NOTE(review): as written, "no EOS", "EOS at index 0" and "EOS at
            # index 1" all end up as -1 - confirm this is what consumers expect.
            eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
            matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask
            results_dict["matched_eos_pos"] = matched_positions

            return results_dict
        return call_top_k_top_p

    def call(self, inputs):
        """Decode `inputs` by delegating to `self.decoder_fn`.

        Args:
            inputs: passed to `self.decoder_fn` unchanged.

        Returns:
            Whatever `self.decoder_fn` returns (presumably the results dict of
            the decoding closure selected at construction - confirm in
            `__init__`).
        """
        return self.decoder_fn(inputs)



In [49]:
# Serializable decoder in beam-search mode (2 beams, no sampling).
# eos_id=-100 cannot equal a vocabulary index, so presumably decoding always
# runs for the full max_iterations - confirm against the decoder's stop check.
decoder_layer_serializable = TextDecoderSerializable(
    model_tf_transformers,
    input_name_list=["input_ids", "input_mask", "input_type_ids"],
    # model geometry
    num_layers=12,
    num_attention_heads=12,
    attention_state=64,
    # decoding strategy
    mode="beam",
    beam_size=2,
    do_sample=False,
    max_iterations=25,
    # special ids / fill values for the per-step inputs
    eos_id=-100,
    input_mask_ids=1,
    input_type_ids=0,
)

inputs_for_serializable = tokenizer_fn(text_list)

# Ragged lists -> dense tensors. input_ids is padded with -1, the other two
# inputs with 0.
pad_value_by_input = {"input_ids": -1, "input_mask": 0, "input_type_ids": 0}
main_inputs = {}
for k, v in inputs_for_serializable.items():
    main_inputs[k] = tf.ragged.constant(v)
for input_name, pad_value in pad_value_by_input.items():
    main_inputs[input_name] = main_inputs[input_name].to_tensor(pad_value)

# Time one decoding call
start_time = time.time()
results_serializable_beam = decoder_layer_serializable(main_inputs)
end_time = time.time()
print("Time taken {} seconds".format(end_time - start_time))

# Wrap the decoder as a model and export it
decoder_model = decoder_layer_serializable.get_model()
decoder_module = LegacyModule(decoder_model)
decoder_module.save("model_temp_pb")

Time taken 3.0038630962371826 seconds
INFO:tensorflow:Assets written to: model_temp_pb/assets


INFO:tensorflow:Assets written to: model_temp_pb/assets


In [64]:
# Serializable decoder in top-k / top-p mode, two returned sequences per input.
# With do_sample=False the decoder takes the argmax of the filtered logits.
# eos_id=-100 cannot equal a vocabulary index, so the EOS stop check never
# fires and decoding runs for the full max_iterations.
decoder_layer_serializable = TextDecoderSerializable(
    model_tf_transformers,
    input_name_list=["input_ids", "input_mask", "input_type_ids"],
    # model geometry
    num_layers=12,
    num_attention_heads=12,
    attention_state=64,
    # decoding strategy
    mode="top_k_top_p",
    top_k=50,
    top_p=0.7,
    do_sample=False,
    num_return_sequences=2,
    max_iterations=25,
    # special ids / fill values for the per-step inputs
    eos_id=-100,
    input_mask_ids=1,
    input_type_ids=0,
)

inputs_for_serializable = tokenizer_fn(text_list)

# Ragged lists -> dense tensors. input_ids is padded with -1, the other two
# inputs with 0.
pad_value_by_input = {"input_ids": -1, "input_mask": 0, "input_type_ids": 0}
main_inputs = {}
for k, v in inputs_for_serializable.items():
    main_inputs[k] = tf.ragged.constant(v)
for input_name, pad_value in pad_value_by_input.items():
    main_inputs[input_name] = main_inputs[input_name].to_tensor(pad_value)

# Time one decoding call
start_time = time.time()
results_serializable_top_k_top_p = decoder_layer_serializable(main_inputs)
end_time = time.time()
print("Time taken {} seconds".format(end_time - start_time))

# Wrap the decoder as a model and export it
decoder_model = decoder_layer_serializable.get_model()
decoder_module = LegacyModule(decoder_model)
decoder_module.save("model_temp_pb")

Time taken 2.9446468353271484 seconds
INFO:tensorflow:Assets written to: model_temp_pb/assets


INFO:tensorflow:Assets written to: model_temp_pb/assets


In [65]:
# The serializable decoders must produce exactly the ids of the result_*
# outputs computed earlier in the notebook, for each decoding mode.
for lhs_ids, rhs_ids in (
    (result_greedy['predicted_ids'], results_serializable_greedy['predicted_ids']),
    (result_beam['predicted_ids'], results_serializable_beam['predicted_ids']),
    (results_serializable_top_k_top_p['predicted_ids'], result_top_k_top_p['predicted_ids']),
):
    tf.assert_equal(lhs_ids, rhs_ids)