tf.estimator.train_and_evaluate not run evaluate when distribute strategy is CollectiveAllReduceStrategy #27857

Closed
wudixiaotie opened this issue Apr 15, 2019 · 11 comments
Labels
comp:dist-strat Distribution Strategy related issues stat:awaiting response Status - Awaiting response from author type:bug Bug

Comments

@wudixiaotie

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    CentOS Linux release 7.3.1611
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary):
    binary
  • TensorFlow version (use command below):
    1.13.1
  • Python version:
    2.7.5
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

You can collect some of this information using our environment capture script
You can also obtain the TensorFlow version with
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
('v1.13.1-0-g6612da8951', '1.13.1')

Describe the current behavior
When I run an Estimator in distributed mode with CollectiveAllReduceStrategy, the train_and_evaluate API does not run evaluation after the model saves a checkpoint.

Describe the expected behavior
train_and_evaluate should run evaluation after the model saves a checkpoint.
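
One way to confirm the symptom is to look for evaluation output in the model directory. This is a minimal sketch, assuming the Estimator writes evaluation summaries to an `eval` subdirectory of `model_dir` (or `eval_<name>` when the EvalSpec has a name). The helper name `evaluation_ran` is hypothetical and not part of TensorFlow.

```python
import os


def evaluation_ran(model_dir, eval_name=None):
    """Return True if an evaluation summary directory with event files exists.

    Assumption: summaries go to <model_dir>/eval, or <model_dir>/eval_<name>
    when the EvalSpec was given a name.
    """
    sub_dir = 'eval' if not eval_name else 'eval_' + eval_name
    eval_dir = os.path.join(model_dir, sub_dir)
    if not os.path.isdir(eval_dir):
        return False
    # Event files written by TensorFlow start with "events.out.tfevents".
    return any(f.startswith('events.out.tfevents')
               for f in os.listdir(eval_dir))
```

Calling `evaluation_ran('./model')` after training finishes should return False if no evaluation was ever written, which is the behavior reported here.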

Code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate the problem.

import tensorflow as tf
import random
import os
import json

class Generator:
    def __init__(self, mode, batch_size=100):
        self._i = 0
        self._mode = mode
        self._batch_size = batch_size

    def _get_random(self):
        return random.uniform(0, 100)

    def next_batch(self):
        self._i += 1
        if self._mode != tf.estimator.ModeKeys.TRAIN and self._i > 200:
            raise StopIteration
        features = {'a': [], 'b': []}
        labels = []
        for _ in xrange(self._batch_size):
            label = 0.0
            for key in features:
                r = self._get_random()
                features[key].append(r)
                label += r
            labels.append(label)
        return features, labels

    def output_types(self):
        return ({'a': tf.float32, 'b': tf.float32}, tf.float32)

    def output_shapes(self):
        return ({'a': [None], 'b': [None]}, [None])

def _dataset(mode):
    generator = Generator(mode)

    def generate_data():
        while True:
            yield generator.next_batch()

    return tf.data.Dataset.from_generator(
            generator=generate_data,
            output_types=generator.output_types(),
            output_shapes=generator.output_shapes(),
            args=[])

def _my_model_fn(features, labels, mode, params):
    learning_rate = params['learning_rate']
    keep_prob = params['keep_prob']
    feature_columns = [tf.feature_column.numeric_column('a'),
            tf.feature_column.numeric_column('b')]
    dense_tensor = tf.feature_column.input_layer(features, feature_columns)
    dense_tensor = tf.nn.dropout(dense_tensor, keep_prob=keep_prob)
    for units in [64, 32]:
        dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
    predictions = tf.layers.dense(dense_tensor, 1)
    predictions = tf.squeeze(predictions, [1])
    loss = tf.losses.absolute_difference(labels=labels, predictions=predictions)
    if mode == tf.estimator.ModeKeys.EVAL:
        accuracy_op = tf.metrics.accuracy(
                labels=labels,
                predictions=predictions,
                name='accuracy_op')
        eval_metric_ops = {'accuracy': accuracy_op}
        spec = tf.estimator.EstimatorSpec(tf.estimator.ModeKeys.EVAL,
                loss=loss,
                eval_metric_ops=eval_metric_ops)
    else:
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        spec = tf.estimator.EstimatorSpec(tf.estimator.ModeKeys.TRAIN,
                loss=loss,
                train_op=train_op)
    return spec

def _cluster():
    return {'worker': ['localhost:2222', 'localhost:2223', 'localhost:2224']}

def _set_tf_config(index):
    tf_config = {
            'cluster': _cluster(),
            'task': {'type': 'worker', 'index': index}}
    os.environ['TF_CONFIG'] = json.dumps(tf_config)

def main(argv):
    distribution = tf.contrib.distribute.CollectiveAllReduceStrategy()
    config = tf.estimator.RunConfig(
            save_checkpoints_steps=2000,
            keep_checkpoint_max=1,
            train_distribute=distribution,
            eval_distribute=distribution)
    model_dir = './model'
    learning_rate = 1e-6
    keep_prob = 0.75
    estimator = tf.estimator.Estimator(
            model_fn=_my_model_fn,
            model_dir=model_dir,
            config=config,
            params={
                'learning_rate': learning_rate,
                'keep_prob': keep_prob
            })
    train_spec = tf.estimator.TrainSpec(
            input_fn=lambda : _dataset(tf.estimator.ModeKeys.TRAIN),
            max_steps=4000)
    eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda : _dataset(tf.estimator.ModeKeys.EVAL))
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

if __name__ == '__main__':
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_integer(
        'index', 0, 'input task index')
    _set_tf_config(FLAGS.index)
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
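
The script above has to be started once per worker, each time with a different `--index` value, because the cluster lists three worker addresses. A minimal sketch of building those command lines follows; the file name `repro.py` and the helper `worker_commands` are hypothetical, and the interpreter is assumed to be the one running this snippet.

```python
import sys


def worker_commands(script_path, num_workers):
    """Build one command line per worker; each gets its own --index value."""
    return [[sys.executable, script_path, '--index={}'.format(i)]
            for i in range(num_workers)]


# Example usage (not executed here):
#   import subprocess
#   procs = [subprocess.Popen(cmd) for cmd in worker_commands('repro.py', 3)]
#   for p in procs:
#       p.wait()
```

Each process then sets its own TF_CONFIG through `_set_tf_config(FLAGS.index)` before the Estimator is built.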

Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

@achandraa achandraa self-assigned this Apr 16, 2019
@achandraa achandraa added comp:dist-strat Distribution Strategy related issues type:bug Bug labels Apr 16, 2019
@achandraa achandraa assigned ymodak and unassigned achandraa Apr 17, 2019
@devinkmoore

I am also seeing this same problem with tf.contrib.distribute.MirroredStrategy(). Nearly the same exact setup as above, but reading from tfrecords instead of a generator.

@byronyi
Contributor

byronyi commented Apr 30, 2019

Ping @yuefengz

@byronyi
Contributor

byronyi commented Apr 30, 2019

Btw could you try tf-nightly and reproduce the error?

@devinkmoore

@byronyi, thanks for the reply. I tried installing tf-nightly into a clean virtual env (as well as the conda env that I normally use). Unfortunately, I'm unable to train an estimator with a distribution strategy (training without a distribution strategy seems to work normally). A few details are below (I want to avoid adding too much about an unrelated issue). Sorry this doesn't reproduce the error exactly or give you code that runs out of the box, but I've included enough to hopefully give you an idea of what I'm doing.

  1. MirroredStrategy fails when applied in exactly the same way as before:
ValueError: In-graph multi-worker training with `MirroredStrategy` is not supported.
  2. Trying CollectiveAllReduce, it fails, asking me to pass a tf.losses.reduction to the loss function:
ValueError: Please use `tf.keras.losses.Reduction.SUM` or `tf.keras.losses.Reduction.NONE` for loss reduction when losses are used with `tf.distribute.Strategy` outside of the built-in training loops.

Trying that out, I get this:

AttributeError: module 'tensorflow._api.v1.losses' has no attribute 'reduction'

  3. Trying MultiWorkerAllReduce, I get:

AttributeError: 'MultiWorkerAllReduce' object has no attribute 'configure'
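
On item 2: the AttributeError is likely a spelling issue, since the enum is `Reduction` (capital R) rather than `reduction`. The error message asks for SUM or NONE so that the loss can be scaled by the global batch size instead of by each replica's local batch. The pure-Python sketch below only illustrates that arithmetic; it does not call TensorFlow, and the error values and the helper `per_replica_loss` are made up for illustration.

```python
def per_replica_loss(squared_errors, global_batch_size):
    """SUM reduction on one replica, scaled by the global batch size."""
    return sum(squared_errors) / float(global_batch_size)


# Two replicas, each holding half of a global batch of four examples.
replica_errors = [[1.0, 3.0], [2.0, 6.0]]
global_batch_size = sum(len(e) for e in replica_errors)

# Summing the scaled per-replica losses gives the mean over the global batch.
total = sum(per_replica_loss(e, global_batch_size) for e in replica_errors)
global_mean = sum(sum(e) for e in replica_errors) / float(global_batch_size)
```

Here `total` and `global_mean` are the same number, which is the property the SUM-then-divide pattern is meant to preserve when gradients are summed across replicas.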


import json
import os

import numpy as np
import tensorflow as tf


def autoencoder_dataset(filename, batch_size, interleave_cycle_length, input_context):

    feature_description = {
        'vendor_variant_id': tf.FixedLenFeature([], tf.int64, default_value=0),
        'image': tf.FixedLenFeature([], tf.string, default_value=''),
        'taxonomy_id': tf.FixedLenFeature([], tf.int64, default_value=-1)
    }

    def _parse_function(example_proto):

        parsed = tf.parse_single_example(example_proto, feature_description)
        vendor_variant_id = parsed['vendor_variant_id']
        image = parsed['image']
        taxonomy_id = parsed['taxonomy_id']

        return vendor_variant_id, image, taxonomy_id

    def scale_image(image):

        image = tf.cast(image, tf.float32)
        image /= 255

        return image

    def preprocess_image(vendor_variant_id, image, taxonomy_id):

        decoded_image = tf.image.decode_image(image, channels=3)
        decoded_image = tf.reshape(decoded_image, (224,224,3))

        augmented_image = tf.image.random_flip_left_right(decoded_image, seed=919)
        augmented_image = tf.image.random_flip_up_down(augmented_image, seed=919)

        np.random.seed(919)
        pixel_shifts = np.random.randint(-30, 30, size=2).tolist()
        augmented_image = tf.contrib.image.translate(augmented_image, pixel_shifts)

        # Scale both versions; the augmented image is the model input.
        augmented_scaled_image = scale_image(augmented_image)
        scaled_image = scale_image(decoded_image)

        return augmented_scaled_image, scaled_image

    def autoencoder_dataset(tfrecord_path, batch_size, interleave_cycle_length, input_context):

        files = tf.data.Dataset.list_files(tfrecord_path)

        if input_context:
            files = files.apply(
                tf.data.experimental.filter_for_shard(input_context.num_input_pipelines,
                                                      input_context.input_pipeline_id)
                )

        if interleave_cycle_length>1:
            raw_dataset = files.apply(tf.contrib.data.parallel_interleave(
                tf.data.TFRecordDataset,
                cycle_length=interleave_cycle_length,
                sloppy=True))
        else:
            raw_dataset = tf.data.TFRecordDataset(files)

        shuffled_dataset = raw_dataset.shuffle(1000)
        parsed_dataset = shuffled_dataset.map(_parse_function,
                                              num_parallel_calls=tf.data.experimental.AUTOTUNE)

        mapped_dataset = parsed_dataset.apply(
            tf.data.experimental.map_and_batch(map_func=preprocess_image,
                                          batch_size=batch_size,
                                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
            )

        return mapped_dataset.repeat()

    return autoencoder_dataset(filename, batch_size, interleave_cycle_length, input_context)


def input_fn(params, mode, input_context=None):

    batch_size = params['batch_size']
    interleave_cycle_length = params['interleave_cycle_length']

    dataset_path = params['train_dataset_path']
    n_images = params['n_train_images']

    if mode == tf.estimator.ModeKeys.EVAL:

        dataset = autoencoder_dataset(dataset_path, batch_size, 1, input_context)

    elif mode == tf.estimator.ModeKeys.TRAIN:

        dataset = autoencoder_dataset(dataset_path, batch_size, interleave_cycle_length, input_context)

    return dataset.prefetch(tf.data.experimental.AUTOTUNE)


def activation_factory(params):

    if params['activation'] == 'leaky':

        return tf.keras.layers.LeakyReLU(0.01)

    else:

        return tf.keras.layers.LeakyReLU(0.0)


def model_fn(features, labels, params, mode):

    n_hidden_1 = 16  # 1st hidden layer
    n_hidden_2 = 12  # 2nd hidden layer
    n_hidden_3 = 10  # 3rd hidden layer
    n_hidden_4 = 8

    convkernel = (3, 3)
    poolkernel = (2, 2)

    activation = activation_factory(params)

    model = tf.keras.Sequential([

        ## Encoder
        tf.keras.layers.Conv2D(n_hidden_1, convkernel, activation='relu', input_shape=(224, 224, 3)),
        tf.keras.layers.MaxPooling2D(poolkernel, padding='same'),
        tf.keras.layers.Conv2D(n_hidden_2, convkernel, padding='same'),
        activation,
        tf.keras.layers.Conv2D(n_hidden_3, convkernel, padding='same'),
        activation,
        tf.keras.layers.MaxPooling2D(poolkernel, padding='same'),
        tf.keras.layers.Conv2D(n_hidden_4, convkernel, activation='relu', padding='same'),
        tf.keras.layers.MaxPooling2D(poolkernel, padding='same', name='bottleneck'),

        ## Decoder
        tf.keras.layers.Conv2D(n_hidden_4, convkernel, activation='relu', padding='same'),
        tf.keras.layers.UpSampling2D(poolkernel),
        tf.keras.layers.Conv2D(n_hidden_3, convkernel, padding='same'),
        activation,
        tf.keras.layers.Conv2D(n_hidden_2, convkernel, padding='same'),
        activation,
        tf.keras.layers.UpSampling2D(poolkernel),
        tf.keras.layers.Conv2D(32, (1,1)),
        activation,
        tf.keras.layers.UpSampling2D(poolkernel),
        tf.keras.layers.Conv2D(3, convkernel, activation='sigmoid', padding='same')

        ])

    global_step = tf.train.get_global_step()

    logits = model(features, training=False)
    predictions = {'logits': logits}

    if mode == tf.estimator.ModeKeys.PREDICT:

        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions)

    momentum = params['learning_rates_momentum']
    boundaries = [e*(params['n_train_images']//params['batch_size']) for e in params['learning_rates_epoch_schedule']]

    n_gpus = params['cluster_gpus']

    if 'learning_rates_warmup' in params.keys() and params['learning_rates_warmup']:
        scaled_rates = [r * n_gpus for r in params['learning_rates']]
        scaled_rates = [params['learning_rates'][0]] + scaled_rates
        learning_rate = manual_stepping(global_step, boundaries, scaled_rates, warmup=True)
    else:
        learning_rate = manual_stepping(global_step, boundaries, params['learning_rates'])

    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=momentum)

    loss = tf.keras.losses.MeanSquaredError()(labels, logits)

    tf.summary.scalar('learning_rate', optimizer._learning_rate)

    if mode == tf.estimator.ModeKeys.TRAIN:

        train_op = tf.contrib.training.create_train_op(loss,
                                                       optimizer,
                                                       summarize_gradients=True)

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op
        )

    else:

        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss)


def set_tf_config(config, is_chief, worker_index=None):

    tf_config = {}
    tf_config['cluster'] = {}
    tf_config['cluster']['chief'] = ['{}:{}'.format(config['chief']['ip'], str(config['chief']['port']))]
    tf_config['cluster']['worker'] = []

    for i,w in enumerate(config['workers']):

        if is_chief:

            is_chief = True

            if not worker_index:

                worker_index = 0

        else:

            if worker_index is None:

                worker_index = i

        tf_config['cluster']['worker'].append('{}:{}'.format(w['ip'], str(w['port'])))

    if worker_index is None:
        raise Exception('Instance public DNS name not found')

    if is_chief:
        tf_config['task'] = {'type': 'chief', 'index': worker_index}
    else:
        tf_config['task'] = {'type': 'worker', 'index': worker_index}

    os.environ['TF_CONFIG'] = json.dumps(tf_config)

save_summary_steps = 1
save_checkpoints_steps = 1

is_chief = False
worker_index = None
cluster_config = {"workers": [{"ip": "localhost", "port": 1111},
                              {"ip": "localhost", "port": 1112}]}
# Will infer worker index from cluster config if not passed as an arg
set_tf_config(cluster_config, is_chief, worker_index)


strategy = tf.contrib.distribute.MirroredStrategy()

run_config = tf.estimator.RunConfig(
    save_summary_steps=save_summary_steps,
    save_checkpoints_steps=save_checkpoints_steps,
    train_distribute=strategy,
    eval_distribute=None
)

train_config = {
    # Various parameters passed to my model_fn
}

classifier = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='...',
    params=train_config,
    config=run_config
)

tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn,
                                      max_steps=3),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn,
                                    steps=1,
                                    start_delay_secs=0,
                                    throttle_secs=0)
)

System information

Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04.6 LTS (GNU/Linux 4.4.0-1079-aws x86_64v) AND OSX Sierra 10.13.6
TensorFlow installed from (source or binary): pip tf-nightly
TensorFlow version (use command below): 1.13.1
Python version: 3.6.5
CUDA/cuDNN version: CUDA 10
GPU model and memory: NVIDIA V100 Tensor Core GPU (AWS p3.2x large)

@devinkmoore

devinkmoore commented May 1, 2019

Found the solution (to at least my problem and hopefully the one above). If you don't supply a separate machine/task as an evaluator, you need to specify one of your machines/tasks as "master", NOT "chief" or "worker". In _TrainingExecutor (which is called by the distribute coordinator) there are separate methods for run_chief and run_master; run_chief does not call estimator.evaluate, while run_master does!
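
The behavior described in this comment can be summarized as a small lookup table. This is a simplified illustration of what the thread reports, not the actual _TrainingExecutor code; the names `TASK_BEHAVIOUR` and `cluster_evaluates` are hypothetical.

```python
# Which task types train and which evaluate, as reported in this thread.
TASK_BEHAVIOUR = {
    'chief': {'trains': True, 'evaluates': False},
    'worker': {'trains': True, 'evaluates': False},
    'master': {'trains': True, 'evaluates': True},
    'evaluator': {'trains': False, 'evaluates': True},
}


def cluster_evaluates(task_types):
    """Return True if at least one task in the cluster runs evaluation."""
    return any(TASK_BEHAVIOUR[t]['evaluates'] for t in task_types)
```

Under this view, the original reproduction (three "worker" tasks and nothing else) contains no task that evaluates, which matches the reported symptom.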

@wudixiaotie @byronyi

@ymodak
Contributor

ymodak commented May 1, 2019

@wudixiaotie Is this still an issue?

@ymodak ymodak added the stat:awaiting response Status - Awaiting response from author label May 1, 2019
@yuefengz
Contributor

yuefengz commented May 3, 2019

Please note "master" is not officially supported by distribution strategy or Estimator. If you want to run evaluation, you need to have an "evaluator" task with train_and_evaluate API.
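
A minimal sketch of what a TF_CONFIG for a dedicated evaluator process might look like in TF 1.x. It assumes, as I understand the train_and_evaluate documentation, that the evaluator is a special task that is not listed in the training cluster. The host addresses and the helper `make_tf_config` are hypothetical.

```python
import json
import os


def make_tf_config(cluster, task_type, task_index):
    """Serialize a TF_CONFIG value for one process."""
    return json.dumps({
        'cluster': cluster,
        'task': {'type': task_type, 'index': task_index},
    })


# Hypothetical training cluster; the evaluator is deliberately not listed.
cluster = {
    'chief': ['localhost:2222'],
    'worker': ['localhost:2223', 'localhost:2224'],
}

# Set this in the dedicated evaluation process before building the Estimator.
os.environ['TF_CONFIG'] = make_tf_config(cluster, 'evaluator', 0)
```

The training processes would use the same `cluster` dictionary with task types "chief" and "worker", and every process would then call `tf.estimator.train_and_evaluate` as usual.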

@yuefengz yuefengz closed this as completed May 3, 2019
@pavithrasv
Member

@devinkmoore Are you seeing any loss scaling related issue or is that resolved?

@chopwoodwater

Please note "master" is not officially supported by distribution strategy or Estimator. If you want to run evaluation, you need to have an "evaluator" task with train_and_evaluate API.

@yuefengz Any TF 2.x code examples to configure the "evaluator" task? Thanks.

@dekun-zou

Please note "master" is not officially supported by distribution strategy or Estimator. If you want to run evaluation, you need to have an "evaluator" task with train_and_evaluate API.

Do you mean that the TF_CONFIG should specify "evaluator" in addition to "chief", "worker", and "ps"?

10 participants