# Entity Prediction
---

#### _Given a user query and context, predict the entity that the user is looking for_


## First... Install ml4ir

In [1]:
!pip install ml4ir

Looking in indexes: https://pypi.python.org/simple






Collecting pyarrow<0.16.0,>=0.15.1; python_version >= "3.0" or platform_system != "Windows"
  Using cached pyarrow-0.15.1-cp37-cp37m-macosx_10_6_intel.whl (36.6 MB)


[31mERROR: tfx-bsl 0.15.3 has requirement absl-py<0.9,>=0.7, but you'll have absl-py 0.9.0 which is incompatible.[0m
[31mERROR: tfx-bsl 0.15.3 has requirement apache-beam[gcp]<2.17,>=2.16, but you'll have apache-beam 2.19.0 which is incompatible.[0m
[31mERROR: tfx-bsl 0.15.3 has requirement pyarrow<0.15.0,>=0.14.0, but you'll have pyarrow 0.15.1 which is incompatible.[0m
[31mERROR: tensorflow-transform 0.15.0 has requirement absl-py<0.9,>=0.7, but you'll have absl-py 0.9.0 which is incompatible.[0m
[31mERROR: apache-beam 2.19.0 has requirement dill<0.3.2,>=0.3.1.1, but you'll have dill 0.3.0 which is incompatible.[0m
[31mERROR: apache-beam 2.19.0 has requirement httplib2<=0.12.0,>=0.8, but you'll have httplib2 0.17.0 which is incompatible.[0m
Installing collected packages: pyarrow
  Attempting uninstall: pyarrow
    Found existing installation: pyarrow 0.14.1
    Uninstalling pyarrow-0.14.1:
      Successfully uninstalled pyarrow-0.14.1
Successfully installed pyarrow-0.15.1

## Step 0: Looking at the Data

In [2]:
from ml4ir.base.io.file_io import FileIO
from ml4ir.base.io.local_io import LocalIO
import glob
import logging
import pandas as pd
import os
import tensorflow as tf

# Pandas options: remove the row/column display limits so frames are shown in full
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

# Setup logging: root logger at DEBUG, TensorFlow's own logger at INFO
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
tf.get_logger().setLevel('INFO')
# NOTE: the module-level logging.debug() also calls logging.basicConfig() when the
# root logger has no handlers yet, so this line doubles as handler setup
logging.debug("Logger is initialized...")

# Define FileIO: this instance is passed to the dataset, model and cleanup cells below
file_io: FileIO = LocalIO(logger)

# Load data
# NOTE: this is a relative path, so it is resolved against the notebook's working directory
CSV_DATA_DIR = '../ml4ir/applications/classification/tests/data/csv'

df = file_io.read_df(os.path.join(CSV_DATA_DIR, "train", "file_0.csv"))

logger.info(df.shape)

# Preview the first 10 rows of the training file
df.head(10)

DEBUG:root:Logger is initialized...
INFO:root:Loading dataframe from path : ../ml4ir/applications/classification/tests/data/csv/train/file_0.csv
INFO:root:(700, 5)


Unnamed: 0,query_key,query_text,domain_id,user_context,entity_id
0,query_id_0,yourself,Y,"EEE,BBB,AAA,GGG,FFF,FFF,AAA,CCC,CCC,FFF,FFF,DD...",AAA
1,query_id_1,struck entire the come thanks,B,"CCC,CCC,AAA",DDD
2,query_id_2,sick unfold am prince you,Q,"DDD,FFF,AAA,GGG,GGG,HHH,GGG,FFF,AAA,CCC,BBB,HH...",AAA
3,query_id_3,bitter twelve if upon of him,U,"AAA,FFF,DDD,GGG,AAA,EEE,HHH,DDD,HHH,CCC,CCC,HHH",DDD
4,query_id_4,tragedy,O,"AAA,EEE,FFF,EEE,GGG,GGG,AAA",EEE
5,query_id_5,quiet,W,"BBB,GGG,AAA",AAA
6,query_id_6,friends bid thee hamlet most,P,"FFF,GGG",AAA
7,query_id_7,bid his for entire dane hear,L,"BBB,AAA,GGG,EEE,CCC,EEE,GGG,FFF",EEE
8,query_id_8,guard platform watch,Z,"GGG,GGG,AAA,CCC,HHH,AAA,FFF,HHH,HHH,EEE,GGG,AA...",AAA
9,query_id_9,had stand whos,I,"AAA,BBB,DDD,DDD,AAA,DDD,EEE,DDD,BBB,FFF,BBB,FF...",DDD


### Step 1: Define the FeatureConfig

In [3]:
# Set up the feature configurations
from ml4ir.base.features.feature_config import FeatureConfig, ExampleFeatureConfig
from ml4ir.base.config.keys import TFRecordTypeKey
import json
import yaml

# YAML feature configuration with three top-level entries: query_key, label and features.
# The positional placeholders are filled in by the str.format() call at the end:
#   {0} -> path of the entity_id vocabulary CSV (used by the label and by user_context)
#   {1} -> path of the domain_id vocabulary CSV
#   {2} -> 20, used for the user_context shape and for both of its max_length args
# The "fn" names under preprocessing_info / feature_layer_info that are not built into
# ml4ir (one_hot_vectorize_label, split_string, smart_scope_embedding_bilstm_encoding)
# are implemented in later cells of this notebook.
feature_config_yaml = '''
query_key: 
  name: query_key
  node_name: query_key
  trainable: false
  dtype: string
  log_at_inference: true
  feature_layer_info:
    type: numeric
    shape: null
  serving_info:
    name: query_key
    required: false
  default_value: ""
label:
  name: entity_id
  node_name: entity_id
  trainable: false
  dtype: string
  shape: 
    - 1
    - null
  log_at_inference: true
  preprocessing_info:
    - fn: one_hot_vectorize_label
      args:
        vocabulary_file: {0}
        num_oov_buckets: 1
  feature_layer_info:
    type: numeric
    fn: categorical_indicator_with_vocabulary_file
    args:
      vocabulary_file: {0}
      num_oov_buckets: 1
  serving_info:
    name: entity_id
    required: false
  default_value: ""
features:
  - name: query_text
    node_name: query_text
    trainable: true
    dtype: string
    log_at_inference: true
    feature_layer_info:
      type: numeric
      shape: null
      fn: bytes_sequence_to_encoding_bilstm
      args:
        encoding_type: bilstm
        encoding_size: 128
        embedding_size: 128
        max_length: 20
    preprocessing_info:
      - fn: preprocess_text
        args:
          remove_punctuation: true
          to_lower: true
    serving_info:
      name: query_text
      required: true
    default_value: ""
  - name: domain_id
    node_name: domain_id
    trainable: true
    dtype: string
    log_at_inference: true
    is_group_metric_key: true
    feature_layer_info:
      type: numeric
      shape: null
      # fn: categorical_embedding_with_hash_buckets
      # args:
      #   num_hash_buckets: 4
      #   hash_bucket_size: 64
      #   embedding_size: 32
      #   merge_mode: concat
      fn: categorical_embedding_with_vocabulary_file
      args:
        vocabulary_file: {1}
        embedding_size: 64
        default_value: -1
        num_oov_buckets: 1
    serving_info:
      name: domain_id
      required: true
    default_value: ""
  - name: user_context
    node_name: user_context
    trainable: true
    dtype: string
    shape:
      - 1
      - {2}
    log_at_inference: true
    is_group_metric_key: true
    preprocessing_info:
      - fn: split_string
        args:
          split_char: ","
          max_length: {2}
    feature_layer_info:
      type: numeric
      shape: null
      # fn: categorical_sequence_bilstm_embedding
      # args:
      #   num_hash_buckets: 4
      #   hash_bucket_size: 64
      #   embedding_size: 32
      #   merge_mode: concat
      fn: smart_scope_embedding_bilstm_encoding
      args:
        vocabulary_file: {0}
        embedding_size: 64
        encoding_size: 64
        num_oov_buckets: 1
        max_length: {2}
    serving_info:
      name: user_context
      required: true
    default_value: ""
'''.format(
    os.path.join(CSV_DATA_DIR, '../configs/vocabulary', 'entity_id.csv'),
    os.path.join(CSV_DATA_DIR, '../configs/vocabulary', 'domain_id.csv'),
    20
)
# Parse the YAML text into a dict and build the feature config for Example-type TFRecords
feature_config: ExampleFeatureConfig = FeatureConfig.get_instance(
    tfrecord_type=TFRecordTypeKey.EXAMPLE,
    feature_config_dict=yaml.safe_load(feature_config_yaml),
    logger=logger)

DEBUG:root:{
    "query_key": {
        "name": "query_key",
        "node_name": "query_key",
        "trainable": false,
        "dtype": "string",
        "log_at_inference": true,
        "feature_layer_info": {
            "type": "numeric",
            "shape": null
        },
        "serving_info": {
            "name": "query_key",
            "required": false
        },
        "default_value": ""
    },
    "label": {
        "name": "entity_id",
        "node_name": "entity_id",
        "trainable": false,
        "dtype": "string",
        "shape": [
            1,
            null
        ],
        "log_at_inference": true,
        "preprocessing_info": [
            {
                "fn": "one_hot_vectorize_label",
                "args": {
                    "vocabulary_file": "../ml4ir/applications/classification/tests/data/csv/../configs/entity_id_vocab.csv",
                    "num_oov_buckets": 1
                }
            }
        ],
        "feature_layer

### Step 2: Load the RelevanceDataset

In [4]:
from ml4ir.base.data.relevance_dataset import RelevanceDataset
from ml4ir.base.config.keys import DataFormatKey
from ml4ir.base.features.feature_fns.categorical import categorical_indicator_with_vocabulary_file
from tensorflow import image
from tensorflow import print as tfprint
import tensorflow as tf


def get_one_hot_vectorizer(feature_info, file_io: FileIO):
    """
    Build the preprocessing function used to vectorize a string label.

    The categorical_indicator_with_vocabulary_file feature function is wrapped
    in a small Keras model once, up front; the returned function then only has
    to run that model on each incoming tensor.

    Args:
        feature_info: feature configuration of the label, forwarded unchanged
            to categorical_indicator_with_vocabulary_file
        file_io: FileIO object, forwarded unchanged to the same function

    Returns:
        A tf.function taking a feature tensor (extra keyword arguments are
        accepted and ignored) and returning the model output with axis 0
        squeezed out
    """
    string_input = tf.keras.Input(shape=(1,), dtype=tf.string)
    indicator_output = categorical_indicator_with_vocabulary_file(string_input, feature_info, file_io)
    vectorizer_model = tf.keras.Model(inputs=string_input, outputs=indicator_output)

    @tf.function
    def one_hot_vectorize(feature_tensor, **kwargs):
        vectorized = vectorizer_model(feature_tensor)
        # Axis 0 of the model output is squeezed away before returning
        return tf.squeeze(vectorized, axis=[0])

    return one_hot_vectorize

@tf.function
def split_string(feature_tensor, split_char=",", max_length=20, **kwargs):
    """
    Split a delimited string tensor into a fixed-width tensor of tokens.

    Args:
        feature_tensor: string tensor holding the delimited values
        split_char: separator passed to tf.strings.split
        max_length: number of tokens in the output; longer sequences are
            truncated, shorter ones are padded up to this width
        **kwargs: accepted and ignored

    Returns:
        Tensor of tokens with one row and max_length columns
    """
    token_matrix = tf.strings.split(feature_tensor, sep=split_char).to_tensor()

    # Keep at most max_length tokens, then add a trailing axis so the result
    # has the 3 dimensions handed to the image padding op below
    truncated = tf.expand_dims(token_matrix[:, :max_length], axis=-1)

    # NOTE: target_height=1 means a single row of tokens is expected
    padded = image.pad_to_bounding_box(
        truncated,
        offset_height=0,
        offset_width=0,
        target_height=1,
        target_width=max_length,
    )

    # Remove the trailing axis that was only needed for padding
    return tf.squeeze(padded, axis=-1)

# Map the custom "fn" names used under preprocessing_info in the feature config
# to the functions defined above
preprocessing_keys_to_fns = {
    "one_hot_vectorize_label": get_one_hot_vectorizer(feature_config.get_label(), file_io),
    "split_string": split_string
}
    
# Build the dataset from the CSV files under CSV_DATA_DIR, in batches of 128
relevance_dataset = RelevanceDataset(
        data_dir=CSV_DATA_DIR,
        data_format=DataFormatKey.CSV,
        feature_config=feature_config,
        tfrecord_type=TFRecordTypeKey.EXAMPLE,
        batch_size=128,
        preprocessing_keys_to_fns=preprocessing_keys_to_fns,
        file_io=file_io,
        logger=logger
    )

# Print the train, validation and test datasets
tfprint(relevance_dataset.train)
tfprint(relevance_dataset.validation)
tfprint(relevance_dataset.test)

INFO:root:Loading dataframe from path : ../ml4ir/applications/classification/tests/data/csv/../configs/entity_id_vocab.csv


Instructions for updating:
The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.


Instructions for updating:
The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.


Instructions for updating:
The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.


Instructions for updating:
The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.
INFO:root:1 files found under ../ml4ir/applications/classification/tests/data/csv/train
INFO:root:Reading 1 files from [../ml4ir/applications/classification/tests/data/csv/train/file_0.csv, ..
INFO:root:Loading dataframe from path : ../ml4ir/applications/classification/tests/data/csv/train/file_0.csv
INFO:root:Writing SequenceExample protobufs to : ../ml4ir/applications/classification/tests/data/csv/tfrecord/train/file_0.tfrecord
INFO:root:1 files found under ../ml4ir/applications/classification/tests/data/csv/tfrecord/train


{'query_key': FixedLenFeature(shape=[], dtype='string', default_value=''), 'entity_id': FixedLenFeature(shape=[], dtype='string', default_value=''), 'query_text': FixedLenFeature(shape=[], dtype='string', default_value=''), 'domain_id': FixedLenFeature(shape=[], dtype='string', default_value=''), 'user_context': FixedLenFeature(shape=[], dtype='string', default_value='')}


INFO:root:Created TFRecordDataset from SequenceExample protobufs from 1 files : ['../ml4ir/applications/classification/tests/data/
INFO:root:1 files found under ../ml4ir/applications/classification/tests/data/csv/validation
INFO:root:Reading 1 files from [../ml4ir/applications/classification/tests/data/csv/validation/file_0.csv, ..
INFO:root:Loading dataframe from path : ../ml4ir/applications/classification/tests/data/csv/validation/file_0.csv
INFO:root:Writing SequenceExample protobufs to : ../ml4ir/applications/classification/tests/data/csv/tfrecord/validation/file_0.tfrecord
INFO:root:1 files found under ../ml4ir/applications/classification/tests/data/csv/tfrecord/validation
INFO:root:Created TFRecordDataset from SequenceExample protobufs from 1 files : ['../ml4ir/applications/classification/tests/data/
INFO:root:1 files found under ../ml4ir/applications/classification/tests/data/csv/test
INFO:root:Reading 1 files from [../ml4ir/applications/classification/tests/data/csv/test/file_0

{'query_key': FixedLenFeature(shape=[], dtype='string', default_value=''), 'entity_id': FixedLenFeature(shape=[], dtype='string', default_value=''), 'query_text': FixedLenFeature(shape=[], dtype='string', default_value=''), 'domain_id': FixedLenFeature(shape=[], dtype='string', default_value=''), 'user_context': FixedLenFeature(shape=[], dtype='string', default_value='')}
{'query_key': FixedLenFeature(shape=[], dtype='string', default_value=''), 'entity_id': FixedLenFeature(shape=[], dtype='string', default_value=''), 'query_text': FixedLenFeature(shape=[], dtype='string', default_value=''), 'domain_id': FixedLenFeature(shape=[], dtype='string', default_value=''), 'user_context': FixedLenFeature(shape=[], dtype='string', default_value='')}
<BatchDataset shapes: ({query_key: (128, 1), query_text: (128, 1), domain_id: (128, 1), user_context: (128, 1, 20)}, (128, 1, 10)), types: ({query_key: tf.string, query_text: tf.string, domain_id: tf.string, user_context: tf.string}, tf.float32)>
<Ba

In [5]:
# Take the first training batch: element 0 holds the features keyed by name,
# element 1 holds the label tensor
batch = next(iter(relevance_dataset.train))
batch_features = batch[0]
batch_labels = batch[1]

# Show the first 5 rows of every feature, then of the label
for col, col_values in batch_features.items():
    print("\n~~ {} ~~".format(col))
    print(col_values[:5])
print("\n~~ {} ~~".format("entity"))
print(batch_labels[:5])


~~ query_key ~~
tf.Tensor(
[[b'query_id_0']
 [b'query_id_1']
 [b'query_id_2']
 [b'query_id_3']
 [b'query_id_4']], shape=(5, 1), dtype=string)

~~ query_text ~~
tf.Tensor(
[[b'yourself']
 [b'struck entire the come thanks']
 [b'sick unfold am prince you']
 [b'bitter twelve if upon of him']
 [b'tragedy']], shape=(5, 1), dtype=string)

~~ domain_id ~~
tf.Tensor(
[[b'Y']
 [b'B']
 [b'Q']
 [b'U']
 [b'O']], shape=(5, 1), dtype=string)

~~ user_context ~~
tf.Tensor(
[[[b'EEE' b'BBB' b'AAA' b'GGG' b'FFF' b'FFF' b'AAA' b'CCC' b'CCC' b'FFF'
   b'FFF' b'DDD' b'CCC' b'AAA' b'' b'' b'' b'' b'' b'']]

 [[b'CCC' b'CCC' b'AAA' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
   b'' b'' b'' b'' b'']]

 [[b'DDD' b'FFF' b'AAA' b'GGG' b'GGG' b'HHH' b'GGG' b'FFF' b'AAA' b'CCC'
   b'BBB' b'HHH' b'EEE' b'CCC' b'FFF' b'FFF' b'' b'' b'' b'']]

 [[b'AAA' b'FFF' b'DDD' b'GGG' b'AAA' b'EEE' b'HHH' b'DDD' b'HHH' b'CCC'
   b'CCC' b'HHH' b'' b'' b'' b'' b'' b'' b'' b'']]

 [[b'AAA' b'EEE' b'FFF' b'EEE' b'GGG' b'GGG' b

### Step 3: Define the InteractionModel

In [6]:
from ml4ir.base.model.scoring.interaction_model import InteractionModel, UnivariateInteractionModel
from ml4ir.base.config.keys import TFRecordTypeKey
from ml4ir.base.features.feature_fns.categorical import VocabLookup

from tensorflow import feature_column
from tensorflow.keras import layers


# Define custom feature layer ops
def smart_scope_embedding_bilstm_encoding(feature_tensor, feature_info, file_io: FileIO):
    """
    Custom feature layer: vocabulary lookup -> embedding -> bidirectional LSTM.

    Args:
        feature_tensor: tensor of string tokens to encode
        feature_info: feature configuration dict; reads "default_value",
            "name"/"node_name" and, under feature_layer_info.args,
            "vocabulary_file", "embedding_size", "encoding_size" and the
            optional "num_oov_buckets" (default 1) and "max_length"
        file_io: FileIO object used to read the vocabulary file

    Returns:
        The BiLSTM encoding with an extra axis inserted at position 1
    """
    layer_args = feature_info.get("feature_layer_info")["args"]

    # Load the vocabulary; missing keys are replaced by the feature's default
    # value, and ids fall back to the row position when there is no "id" column
    vocab_df = file_io.read_df(layer_args["vocabulary_file"])
    vocab_keys = vocab_df["key"].fillna(feature_info["default_value"]).values
    if "id" in vocab_df:
        vocab_ids = vocab_df["id"].values
    else:
        vocab_ids = list(range(len(vocab_keys)))

    oov_bucket_count = layer_args.get("num_oov_buckets", 1)
    distinct_id_count = len(set(vocab_ids))

    # Map every string token to its integer id
    token_lookup = VocabLookup(
        vocabulary_keys=vocab_keys,
        vocabulary_ids=vocab_ids,
        num_oov_buckets=oov_bucket_count,
        feature_name=feature_info.get("node_name", feature_info["name"]),
    )
    token_ids = token_lookup(feature_tensor)

    # NOTE(review): mask_zero=True makes Keras treat id 0 as padding; confirm
    # that VocabLookup never assigns id 0 to a real token
    embedding_layer = layers.Embedding(
        input_dim=distinct_id_count + oov_bucket_count,
        output_dim=layer_args["embedding_size"],
        mask_zero=True,
        input_length=layer_args.get("max_length"),
    )
    token_embeddings = embedding_layer(token_ids)

    # Each LSTM direction gets half of encoding_size; the two halves are concatenated
    bilstm_layer = layers.Bidirectional(
        layers.LSTM(units=int(layer_args["encoding_size"] / 2), return_sequences=False),
        merge_mode="concat",
    )
    # Axis 1 is squeezed out before the BiLSTM and re-inserted on its output
    sequence_encoding = bilstm_layer(tf.squeeze(token_embeddings, axis=1))

    return tf.expand_dims(sequence_encoding, name="smart_scope_encoding", axis=1)

# Map the custom feature layer "fn" name used in the feature config (for the
# user_context feature) to its implementation above
feature_layer_fns = {
    "smart_scope_embedding_bilstm_encoding": smart_scope_embedding_bilstm_encoding,
}

# Build the interaction model from the feature config and the custom feature layer functions
interaction_model: InteractionModel = UnivariateInteractionModel(
                                            feature_config=feature_config,
                                            feature_layer_keys_to_fns=feature_layer_fns,
                                            tfrecord_type=TFRecordTypeKey.EXAMPLE,
                                            file_io=file_io)

### Step 4: Define the Loss and Scoring Functions

In [7]:
from ml4ir.base.model.scoring.scoring_model import ScorerBase, RelevanceScorer
from ml4ir.base.model.losses.loss_base import RelevanceLossBase
from tensorflow.keras import layers
from tensorflow.keras import losses
from ml4ir.base.features.feature_fns.categorical import categorical_indicator_with_vocabulary_file

class CustomCategoricalCrossEntropy(RelevanceLossBase):
    """Categorical cross entropy loss paired with a softmax final activation."""

    def get_loss_fn(self, **kwargs):
        """
        Define a softmax cross entropy loss

        Returns:
            Function of (y_true, y_pred) that evaluates a Keras
            CategoricalCrossentropy with SUM_OVER_BATCH_SIZE reduction
        """
        categorical_crossentropy = losses.CategoricalCrossentropy(
            reduction=losses.Reduction.SUM_OVER_BATCH_SIZE
        )

        def _loss_fn(y_true, y_pred):
            # NOTE: Can use any of the metadata features to qualify your loss here
            batch_loss = categorical_crossentropy(y_true, y_pred)
            return batch_loss

        return _loss_fn

    def get_final_activation_op(self, output_name):
        """
        Define the activation applied to the logits

        Args:
            output_name: name given to the softmax Activation layer

        Returns:
            Function of (logits, mask) that applies a softmax Activation layer
            to the logits
        """

        def _softmax_op(logits, mask):
            # mask is part of the expected signature but is not used here
            softmax_layer = layers.Activation("softmax", name=output_name)
            return softmax_layer(logits)

        return _softmax_op

# Build the scorer from the model config YAML that sits next to the CSV data,
# combining it with the interaction model and the custom loss defined above
scorer: ScorerBase = RelevanceScorer.from_model_config_file(
    model_config_file=os.path.join(CSV_DATA_DIR, '../configs/model_config.yaml'),
    interaction_model=interaction_model,
    loss=CustomCategoricalCrossEntropy(),
    output_name="relevance_score",
    file_io=file_io)
    
# Log the loaded model config as indented JSON
logger.info(json.dumps(scorer.model_config, indent=4))

INFO:root:Reading YAML file from : ../ml4ir/applications/classification/tests/data/csv/../configs/model_config.yaml
INFO:root:{
    "architecture_key": "dnn",
    "layers": [
        {
            "type": "dense",
            "name": "first_dense",
            "units": 256,
            "activation": "relu"
        },
        {
            "type": "dropout",
            "name": "first_dropout",
            "rate": 0.3
        },
        {
            "type": "dense",
            "name": "second_dense",
            "units": 64,
            "activation": "relu"
        },
        {
            "type": "dropout",
            "name": "second_dropout",
            "rate": 0.0
        },
        {
            "type": "dense",
            "name": "final_dense",
            "units": 10,
            "activation": null
        }
    ]
}


### Step 5: Define Metrics and Optimizer

In [8]:
from tensorflow.keras import metrics as kmetrics
from ml4ir.applications.ranking.model.metrics.metric_factory import get_metric


# Metrics tracked during training: the 'categorical_accuracy' string key and the
# Keras Precision class (passed as a class, not as an instance).
# Alternative that additionally enables the ranking metrics MRR and ACR (disabled):
# metrics = ['categorical_accuracy', kmetrics.Precision, get_metric("MRR"), get_metric("ACR")]
metrics = ['categorical_accuracy', kmetrics.Precision]

In [9]:
from tensorflow.keras.optimizers import Optimizer
from ml4ir.base.model.optimizer import get_optimizer
from ml4ir.base.config.keys import OptimizerKey

# Adam optimizer with an initial learning rate of 0.01, a decay rate of 0.94
# with decay steps of 1000, and a gradient clip value of 50
optimizer: Optimizer = get_optimizer(
                optimizer_key=OptimizerKey.ADAM,
                learning_rate=0.01,
                learning_rate_decay=0.94,
                learning_rate_decay_steps=1000,
                gradient_clip_value=50,
            )

### Step 6: Putting it all together!

In [10]:
from ml4ir.base.model.relevance_model import RelevanceModel
from ml4ir.base.config.keys import OptimizerKey

# Assemble the model from the pieces built in the previous steps: feature config,
# scorer, metrics and optimizer.
# NOTE: output_name set here is the name under which the score appears in the
# saved model's serving signature
relevance_model = RelevanceModel(
        feature_config=feature_config,
        scorer=scorer,
        metrics=metrics,
        optimizer=optimizer,
        tfrecord_type=TFRecordTypeKey.EXAMPLE,
        output_name="entity_prediction_score",
        file_io=file_io,
        logger=logger
    )

INFO:root:Loading dataframe from path : ../ml4ir/applications/classification/tests/data/csv/../configs/domain_id_vocab.csv
INFO:root:Loading dataframe from path : ../ml4ir/applications/classification/tests/data/csv/../configs/entity_id_vocab.csv
INFO:root:Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
query_text (InputLayer)         [(None, 1)]          0                                            
__________________________________________________________________________________________________
user_context (InputLayer)       [(None, 1, 20)]      0                                            
__________________________________________________________________________________________________
tf_op_layer_DecodePaddedRaw (Te [(None, 1, 20)]      0           query_text[0][0]                 
__________________________________

### Step 7: Train the Model

In [11]:
# Create the output directories. exist_ok=True keeps the cell idempotent on
# re-runs and avoids the check-then-create race of os.path.exists() + os.makedirs()
os.makedirs('../models', exist_ok=True)
os.makedirs('../logs', exist_ok=True)

# Train for 5 epochs; validation categorical accuracy is the monitored metric
# and higher values count as an improvement (monitor_mode='max')
relevance_model.fit(relevance_dataset,
                    num_epochs=5,
                    models_dir='../models',
                    logs_dir='../logs',
                    monitor_metric='val_categorical_accuracy',
                    monitor_mode='max')

INFO:root:Training Model
INFO:root:Starting Epoch : 1
INFO:root:{}


Epoch 1/5


INFO:root:[epoch: 1 | batch: 0] {'batch': 0, 'size': 128, 'loss': 2.2914057, 'categorical_accuracy': 0.140625, 'precision': 0.0}






      5/Unknown - 9s 2s/step - loss: 1.9248 - categorical_accuracy: 0.2094 - precision: 0.0000e+00

INFO:root:Evaluating Model
INFO:root:Completed evaluating model
INFO:root:None



Epoch 00001: val_categorical_accuracy improved from -inf to 0.18750, saving model to ../models/checkpoint.tf
Instructions for updating:
If using Keras pass *_constraint arguments to layers.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.


INFO:tensorflow:Assets written to: ../models/checkpoint.tf/assets


INFO:tensorflow:Assets written to: ../models/checkpoint.tf/assets
INFO:root:End of Epoch 1
INFO:root:{'loss': 1.9248095273971557, 'categorical_accuracy': 0.209375, 'precision': 0.0, 'val_loss': 1.7251148223876953, 'val_categorical_accuracy': 0.1875, 'val_precision': 0.0}




INFO:root:Starting Epoch : 2
INFO:root:{}


Epoch 2/5


INFO:root:[epoch: 2 | batch: 0] {'batch': 0, 'size': 128, 'loss': 1.6742113, 'categorical_accuracy': 0.28125, 'precision': 0.0}




INFO:root:Evaluating Model
INFO:root:Completed evaluating model
INFO:root:None



Epoch 00002: val_categorical_accuracy improved from 0.18750 to 0.19531, saving model to ../models/checkpoint.tf
INFO:tensorflow:Assets written to: ../models/checkpoint.tf/assets


INFO:tensorflow:Assets written to: ../models/checkpoint.tf/assets
INFO:root:End of Epoch 2
INFO:root:{'loss': 1.6542035341262817, 'categorical_accuracy': 0.259375, 'precision': 0.0, 'val_loss': 1.633194923400879, 'val_categorical_accuracy': 0.1953125, 'val_precision': 0.0}




INFO:root:Starting Epoch : 3
INFO:root:{}


Epoch 3/5


INFO:root:[epoch: 3 | batch: 0] {'batch': 0, 'size': 128, 'loss': 1.5382508, 'categorical_accuracy': 0.2890625, 'precision': 0.0}




INFO:root:Evaluating Model
INFO:root:Completed evaluating model
INFO:root:None



Epoch 00003: val_categorical_accuracy improved from 0.19531 to 0.21875, saving model to ../models/checkpoint.tf
INFO:tensorflow:Assets written to: ../models/checkpoint.tf/assets


INFO:tensorflow:Assets written to: ../models/checkpoint.tf/assets
INFO:root:End of Epoch 3
INFO:root:{'loss': 1.5646453619003295, 'categorical_accuracy': 0.2859375, 'precision': 0.4, 'val_loss': 1.6941783428192139, 'val_categorical_accuracy': 0.21875, 'val_precision': 0.0}




INFO:root:Starting Epoch : 4
INFO:root:{}


Epoch 4/5


INFO:root:[epoch: 4 | batch: 0] {'batch': 0, 'size': 128, 'loss': 1.4908521, 'categorical_accuracy': 0.3359375, 'precision': 0.0}




INFO:root:Evaluating Model
INFO:root:Completed evaluating model
INFO:root:None
INFO:root:End of Epoch 4
INFO:root:{'loss': 1.5353684186935426, 'categorical_accuracy': 0.3046875, 'precision': 0.30769232, 'val_loss': 1.6862261295318604, 'val_categorical_accuracy': 0.1640625, 'val_precision': 0.0}



Epoch 00004: val_categorical_accuracy did not improve from 0.21875


INFO:root:Starting Epoch : 5
INFO:root:{}


Epoch 5/5


INFO:root:[epoch: 5 | batch: 0] {'batch': 0, 'size': 128, 'loss': 1.4402865, 'categorical_accuracy': 0.4296875, 'precision': 0.0}




INFO:root:Evaluating Model
INFO:root:Completed evaluating model
INFO:root:None
INFO:root:End of Epoch 5
INFO:root:{'loss': 1.4775662422180176, 'categorical_accuracy': 0.36875, 'precision': 0.52380955, 'val_loss': 1.7673120498657227, 'val_categorical_accuracy': 0.15625, 'val_precision': 0.1875}



Epoch 00005: val_categorical_accuracy did not improve from 0.21875
Restoring model weights from the end of the best epoch.


INFO:root:Completed training model
INFO:root:None


Epoch 00005: early stopping


### Step 8: Save the Model to disk

In [12]:
# Directory the final model is saved to; also read by the prediction cell
MODEL_DIR = '../models/entity_prediction'
# exist_ok=True keeps the cell idempotent on re-runs and avoids the
# check-then-create race of os.path.exists() + os.makedirs()
os.makedirs(MODEL_DIR, exist_ok=True)

# Only split_string is handed over for the saved model; the label vectorizer
# "one_hot_vectorize_label" used at training time is left out
# (presumably because the label is not an input at serving time - confirm)
preprocessing_keys_to_fns = {
    "split_string": split_string
}
relevance_model.save(
    models_dir=MODEL_DIR,
    preprocessing_keys_to_fns=preprocessing_keys_to_fns,
    required_fields_only=True)

INFO:tensorflow:Assets written to: ../models/entity_prediction/final/default/assets


INFO:tensorflow:Assets written to: ../models/entity_prediction/final/default/assets


{'query_text': FixedLenFeature(shape=[], dtype='string', default_value=''), 'domain_id': FixedLenFeature(shape=[], dtype='string', default_value=''), 'user_context': FixedLenFeature(shape=[], dtype='string', default_value='')}
INFO:tensorflow:Assets written to: ../models/entity_prediction/final/tfrecord/assets


INFO:tensorflow:Assets written to: ../models/entity_prediction/final/tfrecord/assets
INFO:root:Final model saved to : ../models/entity_prediction/final


In [13]:
# Show the signatures (inputs and outputs) of the saved tfrecord model
!saved_model_cli show --dir ../models/entity_prediction/final/tfrecord/ --all


MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_tfrecord']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['protos'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: serving_tfrecord_protos:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['entity_prediction_score'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1, 10)
        name: StatefulPartitionedCall_2:0
  Method name is: tensorflow/serving/predict


### Step 9: Load the Model and make predictions!

In [14]:
from tensorflow.keras import models as kmodels
from tensorflow import data

# Load the saved tfrecord model (without compiling it) and pick its serving signature
model = kmodels.load_model(
    os.path.join(MODEL_DIR, 'final/tfrecord/'),
    compile=False)
infer_fn = model.signatures["serving_tfrecord"]

# Read the serialized test examples written under CSV_DATA_DIR/tfrecord/test
# and take the first batch of 5
dataset = data.TFRecordDataset(
    glob.glob(os.path.join(CSV_DATA_DIR, "tfrecord", "test", "*.tfrecord")))
protos = next(iter(dataset.batch(5)))

# Raw serialized bytes of the first example
print("Example proto: \n{}".format(protos[0]))
print("---------------------------------------")

# Same example, parsed into a tf.train.Example so its features are readable
print("\n\nLooking inside the proto:")
e = tf.train.Example()
e.ParseFromString(protos[0].numpy())
print(e)
print("---------------------------------------")

# Run inference on the first example only (protos[:1] keeps the batch axis)
print("\n\n\nPredictions:")
print(infer_fn(protos=protos[:1]))
print("---------------------------------------")

Example proto: 
b'\n\xbe\x01\n\x14\n\tentity_id\x12\x07\n\x05\n\x03AAA\n \n\nquery_text\x12\x12\n\x10\n\x0ea nay act hour\n\x12\n\tdomain_id\x12\x05\n\x03\n\x01G\n\x1b\n\tquery_key\x12\x0e\n\x0c\n\nquery_id_0\nS\n\x0cuser_context\x12C\nA\n?BBB,FFF,HHH,HHH,CCC,HHH,DDD,FFF,EEE,CCC,BBB,CCC,AAA,HHH,BBB,FFF'
---------------------------------------


Looking inside the proto:
features {
  feature {
    key: "domain_id"
    value {
      bytes_list {
        value: "G"
      }
    }
  }
  feature {
    key: "entity_id"
    value {
      bytes_list {
        value: "AAA"
      }
    }
  }
  feature {
    key: "query_key"
    value {
      bytes_list {
        value: "query_id_0"
      }
    }
  }
  feature {
    key: "query_text"
    value {
      bytes_list {
        value: "a nay act hour"
      }
    }
  }
  feature {
    key: "user_context"
    value {
      bytes_list {
        value: "BBB,FFF,HHH,HHH,CCC,HHH,DDD,FFF,EEE,CCC,BBB,CCC,AAA,HHH,BBB,FFF"
      }
    }
  }
}

------------------

In [15]:
# Clean up directories: removes the "tfrecord" directory under CSV_DATA_DIR,
# which holds the TFRecord files generated when the dataset was loaded
# NOTE: Run only if you don't want to make any more predictions
# (the prediction cell reads its test examples from this directory)
file_io.rm_dir(os.path.join(CSV_DATA_DIR, "tfrecord"))

INFO:root:Directory deleted : ../ml4ir/applications/classification/tests/data/csv/tfrecord


![thanks](images/thats_all_folks.gif)