# Shared Bottom in TensorFlow

## Setup

In [None]:
%tensorflow_version 1.x

TensorFlow 1.x selected.


In [None]:
import numpy as np
import tensorflow as tf

## Multi-task

In [None]:
def task_network(inputs,
                 hidden_units,
                 hidden_activation=tf.nn.relu,
                 output_activation=tf.nn.sigmoid,
                 hidden_dropout=None,
                 initializer=None,
                 training=False):
    """Build one task tower: a stack of dense layers plus a 1-unit head.

    Args:
        inputs: Tensor fed to the first hidden layer.
        hidden_units: Iterable of layer widths; one dense layer per entry.
        hidden_activation: Activation used by every hidden layer.
        output_activation: Callable applied to the 1-unit output, or None
            to return the raw linear output.
        hidden_dropout: Dropout rate applied after each hidden layer, or
            None to add no dropout layers.
        initializer: Kernel initializer passed to every dense layer.
        training: Python bool or boolean tensor forwarded to
            `tf.layers.dropout`. Dropout is only active when this is true.
            The default (False) keeps the previous behaviour, where the
            dropout layers acted as identity.

    Returns:
        The tensor produced by the 1-unit output layer (after
        `output_activation`, when one is given).
    """
    x = inputs
    for units in hidden_units:
        x = tf.layers.dense(x,
                            units,
                            activation=hidden_activation,
                            kernel_initializer=initializer)

        if hidden_dropout is not None:
            # Without an explicit `training` flag tf.layers.dropout is a
            # no-op, so the flag has to be threaded through from the caller.
            x = tf.layers.dropout(x, rate=hidden_dropout, training=training)

    outputs = tf.layers.dense(x, 1, kernel_initializer=initializer)

    if output_activation is not None:
        outputs = output_activation(outputs)
    return outputs

In [None]:
def multi_task(inputs,
               num_tasks,
               task_hidden_units,
               task_output_activations,
               **kwargs):
    """Build `num_tasks` task towers and return their outputs as a list.

    Args:
        inputs: Either a single tensor shared by every tower, or a `list`
            holding one input per task (indexed by task number).
        num_tasks: Number of towers to build.
        task_hidden_units: Hidden layer widths used by every tower.
        task_output_activations: Indexable of per-task output activations
            (entry `i` is handed to tower `i`).
        **kwargs: Forwarded unchanged to `task_network`.

    Returns:
        List with one output tensor per task, in task order.
    """
    # Only a real `list` is treated as per-task inputs; anything else is
    # shared across all towers.
    has_per_task_inputs = isinstance(inputs, list)

    return [
        task_network(inputs[task_id] if has_per_task_inputs else inputs,
                     task_hidden_units,
                     output_activation=task_output_activations[task_id],
                     **kwargs)
        for task_id in range(num_tasks)
    ]

## Shared bottom strategy

In [None]:
def _synthetic_data(num_examples, example_dim=100, c=0.3, p=0.8, m=5):
    """Generate a two-task synthetic regression data set.

    Two weight vectors of length `c` are built from a pair of orthonormal
    directions so that their cosine similarity is `p`. Each label is the
    linear projection of the example onto its weight vector, plus a sum of
    `m` sinusoids of that projection, plus Gaussian noise (std 0.01).

    Args:
        num_examples: Number of examples to draw.
        example_dim: Dimensionality of every example.
        c: Scale applied to both weight vectors.
        p: Mixing coefficient between the two orthonormal directions used
            for the second task's weight vector.
        m: Number of sinusoid terms added to each label.

    Returns:
        `(examples, (y1, y2))`, all cast to float32; `examples` has shape
        `(num_examples, example_dim)` and each label vector has shape
        `(num_examples,)`.
    """
    draw = np.random.normal

    # First direction: centred and rescaled so that it has unit L2 norm.
    u1 = draw(size=example_dim)
    u1 = (u1 - np.mean(u1)) / (np.std(u1) * np.sqrt(example_dim))

    # Second direction: remove the component along u1, then normalise.
    u2 = draw(size=example_dim)
    u2 = u2 - u2.dot(u1) * u1
    u2 = u2 / np.linalg.norm(u2)

    task_weights = (c * u1,
                    c * (p * u1 + np.sqrt(1. - p ** 2) * u2))

    # The sinusoid frequencies/phases are shared by both tasks.
    alpha = draw(size=m)
    beta = draw(size=m)

    examples = draw(size=(num_examples, example_dim))

    def _labels_for(weight):
        projection = np.matmul(examples, weight)
        waves = sum((np.sin(a * projection + b) for a, b in zip(alpha, beta)), 0.)
        return projection + waves + draw(size=num_examples, scale=0.01)

    # Evaluated left to right, so task 1 noise is drawn before task 2 noise.
    y1, y2 = _labels_for(task_weights[0]), _labels_for(task_weights[1])

    return examples.astype(np.float32), (y1.astype(np.float32), y2.astype(np.float32))

In [None]:
def synthetic_data_input_fn(num_examples, epochs=1, batch_size=256, buffer_size=256, **kwargs):
    """Wrap freshly generated synthetic data in a `tf.data.Dataset`.

    Args:
        num_examples: Number of examples to generate.
        epochs: Passed to `Dataset.repeat`.
        batch_size: Passed to `Dataset.batch`.
        buffer_size: Passed to `Dataset.prefetch`.
        **kwargs: Forwarded to `_synthetic_data`.

    Returns:
        A dataset of `(examples, (y1, y2))` batches; batching is applied
        before repeating, then prefetching.
    """
    features_and_labels = _synthetic_data(num_examples, **kwargs)

    return (tf.data.Dataset.from_tensor_slices(features_and_labels)
            .batch(batch_size)
            .repeat(epochs)
            .prefetch(buffer_size))

In [None]:
def _dense(x, units, activation=None, dropout=None, name=None):
    """Apply a fully connected layer built from `tf.get_variable` variables.

    The kernel is created as `w<name>` and the bias as `b<name>` in the
    current variable scope. The kernel is given no explicit initializer
    here; the bias is zero-initialised.

    Args:
        x: Input tensor; its last dimension is the layer's input width.
        units: Output width of the layer.
        activation: Optional callable applied last.
        dropout: Optional dropout rate. When given, `tf.nn.dropout` is
            applied to the affine output, before the activation.
        name: Suffix used to build the variable names.

    Returns:
        The transformed tensor.
    """
    input_width = x.shape[-1]

    kernel = tf.get_variable(f"w{name}",
                             shape=(input_width, units),
                             dtype=tf.float32)
    offset = tf.get_variable(f"b{name}",
                             shape=(units,),
                             dtype=tf.float32,
                             initializer=tf.zeros_initializer())

    out = tf.nn.xw_plus_b(x, kernel, offset)

    if dropout is not None:
        out = tf.nn.dropout(out, rate=dropout)

    return out if activation is None else activation(out)

In [None]:
def shared_bottom(x: tf.Tensor,
                  num_tasks: int,
                  bottom_units: list,
                  task_units: list,
                  task_output_activation: list,
                  bottom_initializer=tf.truncated_normal_initializer(),
                  bottom_activation=tf.nn.relu,
                  bottom_dropout: float = None,
                  task_initializer=tf.truncated_normal_initializer(),
                  task_dropout: float = None,
                  task_activation=tf.nn.relu):
    """Build a shared-bottom multi-task network out of `_dense` layers.

    A single "bottom" stack is applied to `x`; its output feeds
    `num_tasks` independent task towers, each ending in a 1-unit layer.

    Args:
        x: Input tensor.
        num_tasks: Number of task towers.
        bottom_units: Widths of the bottom layers. All but the last use
            `bottom_activation`/`bottom_dropout`; the last one is linear.
            Must contain at least one entry.
        task_units: Hidden layer widths of every task tower.
        task_output_activation: Per-task output activation, indexed by task
            number. Each entry may be a callable (applied to the tower
            output), the string "sigmoid", or None. Any other value leaves
            the output linear.
        bottom_initializer: Default initializer of the "bottom" variable
            scope.
        bottom_activation: Activation of the bottom hidden layers.
        bottom_dropout: Dropout rate of the bottom hidden layers, or None.
        task_initializer: Default initializer of each "task<i>" variable
            scope.
        task_dropout: Dropout rate of the task hidden layers, or None.
        task_activation: Activation of the task hidden layers.

    Returns:
        List with one output tensor per task, in task order.
    """
    with tf.variable_scope("bottom", initializer=bottom_initializer):
        for i, units in enumerate(bottom_units[:-1]):
            x = _dense(x, units, activation=bottom_activation, dropout=bottom_dropout, name=i)

        bottom_out = _dense(x, bottom_units[-1], name="out")

    outputs = []

    for task_idx in range(num_tasks):
        x = bottom_out
        with tf.variable_scope("task{}".format(task_idx), initializer=task_initializer):
            for i, units in enumerate(task_units):
                x = _dense(x, units, activation=task_activation, dropout=task_dropout, name=i)

            task_out = _dense(x, 1, name="out")

            output_activation = task_output_activation[task_idx]
            if callable(output_activation):
                # Accept callables such as tf.nn.sigmoid, matching the
                # convention used by task_network; previously these were
                # silently ignored.
                task_out = output_activation(task_out)
            elif output_activation == "sigmoid":
                task_out = tf.nn.sigmoid(task_out)

            outputs.append(task_out)

    return outputs

In [None]:
def shared_bottom_v2(x: tf.Tensor,
                     num_tasks: int,
                     bottom_units: list,
                     task_hidden_units: list,
                     task_output_activations: list,
                     bottom_initializer: tf.Tensor = None,
                     bottom_activation=tf.nn.relu,
                     bottom_dropout: float = None,
                     task_initializer: tf.Tensor = None,
                     task_dropout: float = None,
                     task_activation=tf.nn.relu):
    """Build a shared-bottom multi-task network out of `tf.layers`.

    Args:
        x: Input tensor.
        num_tasks: Number of task towers.
        bottom_units: Widths of the bottom layers. Every entry but the last
            becomes an activated layer named `bottom_dense<i>`; the last
            one becomes the linear `bottom_out` layer.
        task_hidden_units: Hidden layer widths of every task tower.
        task_output_activations: Per-task output activations, forwarded to
            `multi_task`.
        bottom_initializer: Kernel initializer of the bottom layers.
        bottom_activation: Activation of the bottom hidden layers.
        bottom_dropout: Rate of the `bottom_dropout<i>` layers, or None to
            add none.
        task_initializer: Kernel initializer of the task towers.
        task_dropout: Dropout rate of the task towers, or None.
        task_activation: Hidden activation of the task towers.

    Returns:
        List with one output tensor per task, as returned by `multi_task`.
    """
    hidden = x
    for layer_idx, width in enumerate(bottom_units[:-1]):
        hidden = tf.layers.dense(hidden,
                                 width,
                                 activation=bottom_activation,
                                 kernel_initializer=bottom_initializer,
                                 name=f"bottom_dense{layer_idx}")

        if bottom_dropout is None:
            continue
        hidden = tf.layers.dropout(hidden,
                                   rate=bottom_dropout,
                                   name=f"bottom_dropout{layer_idx}")

    # The final bottom layer has no activation.
    shared_representation = tf.layers.dense(hidden,
                                            bottom_units[-1],
                                            kernel_initializer=bottom_initializer,
                                            name="bottom_out")

    return multi_task(shared_representation,
                      num_tasks,
                      task_hidden_units,
                      task_output_activations,
                      hidden_activation=task_activation,
                      hidden_dropout=task_dropout,
                      initializer=task_initializer)

In [None]:
def model_fn(features, labels, mode, params):
    """Estimator `model_fn` for the shared-bottom multi-task model.

    Args:
        features: Dict holding the input tensor under the key "inputs".
        labels: Indexable of per-task label tensors (entry `i` belongs to
            task `i`); unused in PREDICT mode.
        mode: A `tf.estimator.ModeKeys` value.
        params: Dict with the keys "num_tasks", "bottom_units",
            "task_units", "task_output_activations", "task_losses" (one of
            'log_loss' or 'mse' per task; any other value falls back to
            mean squared error) and "lr" (needed for training only).

    Returns:
        A `tf.estimator.EstimatorSpec` for the requested mode. Predictions
        are keyed "y0", "y1", ...; the reported loss is the sum of the
        per-task losses.
    """
    outputs = shared_bottom_v2(features["inputs"],
                               num_tasks=params.get("num_tasks"),
                               bottom_units=params.get("bottom_units"),
                               task_hidden_units=params.get("task_units"),
                               task_output_activations=params.get("task_output_activations"))
    predictions = {
        "y{}".format(i): y
        for i, y in enumerate(outputs)
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    task_losses = params.get("task_losses")

    losses = []
    metrics = {}

    for i, (t, y) in enumerate(predictions.items()):

        # Drop only the trailing unit axis; a bare squeeze would also
        # collapse the batch axis when the batch holds a single example.
        y = tf.squeeze(y, axis=-1)
        task_labels = labels[i]

        if task_losses[i] == 'log_loss':
            loss = tf.losses.log_loss(labels=task_labels, predictions=y)
            # AUC must be computed against this task's labels, and needs a
            # per-task key/name so several log-loss tasks do not overwrite
            # each other.
            auc_op = tf.metrics.auc(labels=task_labels,
                                    predictions=y,
                                    name="auc_op_{}".format(t))
            tf.summary.scalar("auc_{}".format(t), auc_op[-1])
            metrics["auc_{}".format(t)] = auc_op
        else:
            # 'mse' and any unrecognised loss name use mean squared error.
            loss = tf.losses.mean_squared_error(labels=task_labels, predictions=y)

        losses.append(loss)
        # eval_metric_ops values must be (value, update_op) pairs, so the
        # per-batch loss is wrapped in a streaming mean.
        metrics["loss_{}".format(t)] = tf.metrics.mean(loss, name="mean_loss_{}".format(t))
        tf.summary.scalar("loss_{}".format(t), loss)

    # A plain sum; no variable is needed to accumulate the total.
    total_loss = tf.add_n(losses)
    tf.summary.scalar("total_loss", total_loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=total_loss, eval_metric_ops=metrics)

    optimizer = tf.train.AdamOptimizer(learning_rate=params['lr'])

    # Minimise the summed loss once. Calling minimize() per task would bump
    # the global step once per task for every batch.
    train_op = optimizer.minimize(loss=total_loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op)

In [None]:
def shared_bottom_estimator(model_dir, inter_op, intra_op, params):
    """Create a `tf.estimator.Estimator` around `model_fn`.

    The session is configured with a GPU device count of zero and the
    given thread-pool sizes. The run config fixes the random seed to 42,
    keeps at most 10 checkpoints, saves one every 200 steps and logs every
    10 steps.

    Args:
        model_dir: Directory for checkpoints and summaries.
        inter_op: Value for `inter_op_parallelism_threads`.
        intra_op: Value for `intra_op_parallelism_threads`.
        params: Hyper-parameter dict handed to `model_fn`.

    Returns:
        The configured estimator.
    """
    session_config = tf.ConfigProto(
        device_count={'GPU': 0},
        inter_op_parallelism_threads=inter_op,
        intra_op_parallelism_threads=intra_op,
    )

    overrides = dict(
        tf_random_seed=42,
        keep_checkpoint_max=10,
        save_checkpoints_steps=200,
        log_step_count_steps=10,
        session_config=session_config,
    )
    run_config = tf.estimator.RunConfig().replace(**overrides)

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=model_dir,
                                       params=params,
                                       config=run_config)
    return estimator

## Testing

In [None]:
from absl.testing import parameterized
import sys
import tempfile

In [None]:
# Process-wide setup for running the tests inside the notebook.
# The model code builds graphs and sessions, so eager execution is disabled.
tf.disable_eager_execution()
sys.dont_write_bytecode = True
# Keep only the program name; presumably this hides the notebook kernel's
# own command-line arguments from the test runner's flag parsing.
sys.argv = sys.argv[:1]
# Saved so that the cell which runs the tests can restore sys.exit afterwards.
old_sysexit = sys.exit
tf.logging.set_verbosity(tf.logging.INFO)

In [None]:
class TestSharedBottom(tf.test.TestCase, parameterized.TestCase):
    """Smoke tests for the shared-bottom graph builder and its estimator."""

    def test_shared_bottom(self):
        """`shared_bottom` yields one `(batch, 1)` output per task."""
        num_examples = 100
        example_dim = 10
        num_tasks = 2
        x = tf.random_normal(shape=(num_examples, example_dim))

        with self.session() as sess:
            y = shared_bottom(x, num_tasks, [32, 16], [10, 5], [None] * num_tasks)
            init_op = tf.global_variables_initializer()
            sess.run(init_op)
            results = sess.run(y)

        self.assertEqual(len(results), num_tasks)
        for result in results:
            # Every tower ends in a 1-unit layer.
            self.assertEqual(result.shape, (num_examples, 1))

    @parameterized.parameters(
        {
            "num_tasks": 2,
            "bottom_units": [32, 16],
            "task_units": [10, 5],
            "task_output_activations": [None, None],
            "task_losses": ["mse", "mse"],
            "lr": 0.001
        }
    )
    def test_shared_bottom_estimator(self, **params):
        """The estimator trains on synthetic data and exports a SavedModel."""
        import os

        def _map_fn(x, y):
            # model_fn reads its input from features["inputs"].
            return {"inputs": x}, y

        with tempfile.TemporaryDirectory() as temp_dir:
            estimator = shared_bottom_estimator(temp_dir, 8, 8, params)
            estimator.train(input_fn=lambda: synthetic_data_input_fn(1000).map(map_func=_map_fn))

            # 100 is the default example_dim of the synthetic data.
            features = {"inputs": tf.placeholder(tf.float32, (None, 100), name="inputs")}

            serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(features)

            export_dir = estimator.export_saved_model(os.path.join(temp_dir, "saved_model"),
                                                      serving_input_receiver_fn)

            # Checked inside the `with` block, before the directory is removed.
            self.assertTrue(tf.gfile.Exists(export_dir))

In [None]:
# Run the test suite in-process. sys.exit is replaced by a no-op while
# tf.test.main() runs (presumably because the runner calls sys.exit when it
# finishes, which would otherwise stop the notebook cell), and the original
# function saved in old_sysexit is restored afterwards, even on failure.
try:
    sys.exit = lambda *args: None
    tf.test.main()
finally:
    sys.exit = old_sysexit

Running tests under Python 3.7.11: /usr/bin/python3
[ RUN      ] TestSharedBottom.test_session
[  SKIPPED ] TestSharedBottom.test_session
[ RUN      ] TestSharedBottom.test_shared_bottom
[       OK ] TestSharedBottom.test_shared_bottom
[ RUN      ] TestSharedBottom.test_shared_bottom_estimator0 (num_tasks=2, bottom_units=[32, 16], task_units=[10, 5], task_output_activations=[None, None], task_losses=['mse', 'mse'], lr=0.001)


INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpb88j9chi', '_tf_random_seed': 42, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': device_count {
  key: "GPU"
  value: 0
}
intra_op_parallelism_threads: 8
inter_op_parallelism_threads: 8
, '_keep_checkpoint_max': 10, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 10, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f04ea81cbd0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


I0822 16:12:04.097022 139661090527104 estimator.py:212] Using config: {'_model_dir': '/tmp/tmpb88j9chi', '_tf_random_seed': 42, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': device_count {
  key: "GPU"
  value: 0
}
intra_op_parallelism_threads: 8
inter_op_parallelism_threads: 8
, '_keep_checkpoint_max': 10, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 10, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f04ea81cbd0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
I0822 16:12:05.842091 139661090527104 utils.py:157] Num



W0822 16:12:07.076486 139661090527104 ag_logging.py:146] Entity <function TestSharedBottom.test_shared_bottom_estimator.<locals>._map_fn at 0x7f04ea8b4e60> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Bad argument number for Name: 3, expecting 4


INFO:tensorflow:Calling model_fn.


I0822 16:12:07.108192 139661090527104 estimator.py:1148] Calling model_fn.


Instructions for updating:
Use keras.layers.Dense instead.


W0822 16:12:07.112539 139661090527104 deprecation.py:323] From <ipython-input-7-b6a55fae1727>:18: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.Dense instead.


Instructions for updating:
Please use `layer.__call__` method instead.


W0822 16:12:07.118262 139661090527104 deprecation.py:323] From /tensorflow-1.15.2/python3.7/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.__call__` method instead.


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


W0822 16:12:07.359186 139661090527104 deprecation.py:323] From /tensorflow-1.15.2/python3.7/tensorflow_core/python/ops/losses/losses_impl.py:121: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


INFO:tensorflow:Done calling model_fn.


I0822 16:12:08.019086 139661090527104 estimator.py:1150] Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


I0822 16:12:08.029508 139661090527104 basic_session_run_hooks.py:541] Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


I0822 16:12:08.579919 139661090527104 monitored_session.py:240] Graph was finalized.


INFO:tensorflow:Running local_init_op.


I0822 16:12:08.812000 139661090527104 session_manager.py:500] Running local_init_op.


INFO:tensorflow:Done running local_init_op.


I0822 16:12:08.855753 139661090527104 session_manager.py:502] Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpb88j9chi/model.ckpt.


I0822 16:12:09.549391 139661090527104 basic_session_run_hooks.py:606] Saving checkpoints for 0 into /tmp/tmpb88j9chi/model.ckpt.


INFO:tensorflow:loss = 11.14214, step = 2


I0822 16:12:10.147409 139661090527104 basic_session_run_hooks.py:262] loss = 11.14214, step = 2


INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmpb88j9chi/model.ckpt.


I0822 16:12:10.400245 139661090527104 basic_session_run_hooks.py:606] Saving checkpoints for 8 into /tmp/tmpb88j9chi/model.ckpt.


INFO:tensorflow:Loss for final step: 8.56104.


I0822 16:12:10.520564 139661090527104 estimator.py:371] Loss for final step: 8.56104.


INFO:tensorflow:Calling model_fn.


I0822 16:12:10.554844 139661090527104 estimator.py:1148] Calling model_fn.


INFO:tensorflow:Done calling model_fn.


I0822 16:12:10.703628 139661090527104 estimator.py:1150] Done calling model_fn.


Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.


W0822 16:12:10.706550 139661090527104 deprecation.py:323] From /tensorflow-1.15.2/python3.7/tensorflow_core/python/saved_model/signature_def_utils_impl.py:201: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.


INFO:tensorflow:Signatures INCLUDED in export for Classify: None


I0822 16:12:10.710428 139661090527104 export_utils.py:170] Signatures INCLUDED in export for Classify: None


INFO:tensorflow:Signatures INCLUDED in export for Regress: None


I0822 16:12:10.714503 139661090527104 export_utils.py:170] Signatures INCLUDED in export for Regress: None


INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']


I0822 16:12:10.722679 139661090527104 export_utils.py:170] Signatures INCLUDED in export for Predict: ['serving_default']


INFO:tensorflow:Signatures INCLUDED in export for Train: None


I0822 16:12:10.732854 139661090527104 export_utils.py:170] Signatures INCLUDED in export for Train: None


INFO:tensorflow:Signatures INCLUDED in export for Eval: None


I0822 16:12:10.739479 139661090527104 export_utils.py:170] Signatures INCLUDED in export for Eval: None


INFO:tensorflow:Restoring parameters from /tmp/tmpb88j9chi/model.ckpt-8


I0822 16:12:10.839352 139661090527104 saver.py:1284] Restoring parameters from /tmp/tmpb88j9chi/model.ckpt-8


INFO:tensorflow:Assets added to graph.


I0822 16:12:10.889425 139661090527104 builder_impl.py:665] Assets added to graph.


INFO:tensorflow:No assets to write.


I0822 16:12:10.897397 139661090527104 builder_impl.py:460] No assets to write.


INFO:tensorflow:SavedModel written to: /tmp/tmpb88j9chi/saved_model/temp-b'1629648730'/saved_model.pb


I0822 16:12:10.969254 139661090527104 builder_impl.py:425] SavedModel written to: /tmp/tmpb88j9chi/saved_model/temp-b'1629648730'/saved_model.pb
[       OK ] TestSharedBottom.test_shared_bottom_estimator0 (num_tasks=2, bottom_units=[32, 16], task_units=[10, 5], task_output_activations=[None, None], task_losses=['mse', 'mse'], lr=0.001)
----------------------------------------------------------------------
Ran 3 tests in 7.092s

OK (skipped=1)
