<h2>Train and evaluate data set</h2>

In [None]:
import tensorflow as tf
from tensorflow import data
import tensorflow_transform as tft 
import tensorflow_model_analysis as tfma
import tensorflow_transform.coders as tft_coders

from tensorflow_transform.beam import impl
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils

from tensorflow_transform.tf_metadata import metadata_io
from tensorflow_transform.tf_metadata import dataset_schema
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.saved import saved_transform_io
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
import argparse

import apache_beam as beam

import os
import params
import shutil
import dnn_estimator
import featurizer
import metadata
import input
from datetime import datetime

# Training-schedule settings, read once from the project-level `metadata` module.
EVAL_EVERY_SEC = metadata.EVAL_EVERY_SEC
BATCH_SIZE = metadata.BATCH_SIZE
NUM_EPOCHS = metadata.NUM_EPOCHS
TRAIN_SIZE = metadata.TRAIN_SIZE

# Batches per epoch multiplied by the number of epochs.
# NOTE(review): `/` is version-dependent for integer operands (floor division on
# Python 2, true division on Python 3) -- confirm which result is intended.
TOTAL_STEPS = (TRAIN_SIZE / BATCH_SIZE) * NUM_EPOCHS



<h2>Serving function for exporter in JSON format</h2>

In [None]:
def generate_json_serving_fn():
    """Build the serving-input function handed to the model exporter.

    The raw-data placeholders are derived from the metadata returned by
    ``featurizer.create_raw_metadata()``; the target feature is removed from
    them because it is not supplied at serving time.

    Returns:
        A zero-argument callable that returns a
        ``tf.estimator.export.ServingInputReceiver`` whose receiver tensors are
        the raw features and whose model features are those raw features after
        the saved transform and ``input.process_features`` have been applied.

    Raises:
        KeyError: if ``metadata.TARGET_FEATURE_NAME`` is not present in the
            raw placeholder spec.
    """
    placeholder_spec = (
        featurizer.create_raw_metadata().schema.as_batched_placeholders()
    )
    placeholder_spec.pop(metadata.TARGET_FEATURE_NAME)

    def _serving_fn():
        # Placeholders through which raw (untransformed) features are received.
        receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
            placeholder_spec
        )
        raw_features, _, _ = receiver_fn()

        # Apply the saved transform function to the raw features.
        transform_fn_dir = os.path.join(
            params.Params.TRANSFORM_ARTEFACTS_DIR,
            transform_fn_io.TRANSFORM_FN_DIR,
        )
        _, transformed = saved_transform_io.partially_apply_saved_transform(
            transform_fn_dir, raw_features
        )

        # Run the same feature post-processing step on the transformed features.
        model_features = input.process_features(transformed)

        return tf.estimator.export.ServingInputReceiver(
            model_features, raw_features
        )

    return _serving_fn

<h2>Training and Evaluation spec</h2>
<h3>Includes the input function, Training/Eval mode, num_epochs, batch size, and max_steps</h3>
<h3>Set the path to the <b>transformed</b> input files</h3>

In [None]:
# Training spec: the input function reads the files matching the transformed
# training-data prefix; epochs, batch size and max_steps come from
# metadata.hparams.
train_spec = tf.estimator.TrainSpec(
    input_fn=input.generate_tfrecords_input_fn(
        params.Params.TRANSFORMED_TRAIN_DATA_FILE_PREFIX + "*",
        batch_size=metadata.hparams.batch_size,
        num_epochs=metadata.hparams.num_epochs,
        mode=tf.estimator.ModeKeys.TRAIN,
    ),
    hooks=None,
    max_steps=metadata.hparams.max_steps,
)

# Evaluation spec: a single pass (num_epochs=1, steps=None) over the files
# matching the transformed evaluation-data prefix, throttled by EVAL_EVERY_SEC.
eval_spec = tf.estimator.EvalSpec(
    input_fn=input.generate_tfrecords_input_fn(
        params.Params.TRANSFORMED_EVAL_DATA_FILE_PREFIX + "*",
        batch_size=metadata.hparams.batch_size,
        num_epochs=1,
        mode=tf.estimator.ModeKeys.EVAL,
    ),
    exporters=[
        tf.estimator.LatestExporter(
            # Name of the sub-folder (under the export directory) that
            # receives the exported model.
            name="estimate",
            serving_input_receiver_fn=generate_json_serving_fn(),
            as_text=False,
            exports_to_keep=1,
        ),
    ],
    throttle_secs=EVAL_EVERY_SEC,
    steps=None,
)



<h2>Run the training job</h2>
<h3>Create the estimator and run train and evaluate</h3>

In [None]:
# Run the train-and-evaluate experiment when params.Params.TRAIN is set;
# otherwise only report that training was skipped.
# NOTE(review): `model_dir` and `run_config` are not defined in this part of
# the file -- presumably they are created in an earlier cell; confirm.
if params.Params.TRAIN:
    if not params.Params.RESUME_TRAINING:
        # Fresh run: delete the model directory so training starts from scratch.
        print("Removing previous training artefacts...")
        shutil.rmtree(model_dir, ignore_errors=True)
    else:
        print("Resuming training...")

    tf.logging.set_verbosity(tf.logging.INFO)

    time_start = datetime.utcnow()
    print("Experiment started at {}".format(time_start.strftime("%H:%M:%S")))
    print(".......................................")

    estimator = dnn_estimator.create_estimator(run_config, metadata.hparams)

    tf.estimator.train_and_evaluate(
        estimator=estimator,
        train_spec=train_spec,
        eval_spec=eval_spec
    )

    time_end = datetime.utcnow()
    print(".......................................")
    print("Experiment finished at {}".format(time_end.strftime("%H:%M:%S")))
    print("")
    time_elapsed = time_end - time_start
    print("Experiment elapsed time: {} seconds".format(time_elapsed.total_seconds()))
else:
    # Call form instead of the Python 2 print statement: it matches the other
    # print calls in this cell and is valid syntax on both Python 2 and 3.
    print("Training was skipped!")