
Cached augmentation in segmentation tutorial - this does not increase dataset size #47755

Closed
jameshfisher opened this issue Mar 12, 2021 · 7 comments
Labels: comp:ops (OPs related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), type:docs-bug (Document issues)

Comments

@jameshfisher
Contributor

URL(s) with the issue:

https://www.tensorflow.org/tutorials/images/segmentation

Description of issue (what needs changing):

Augmentation increases the effective size of a dataset by turning each source datapoint into many different augmented datapoints over the course of training. But in this tutorial, augmentation is applied only once to each datapoint, so the dataset size effectively stays the same. The root cause is that the augmentation is applied before Dataset.cache(): the first (and only) round of augmented images is cached and then replayed identically on every epoch.

These are the relevant lines of the tutorial:

@tf.function
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  # This bit is random augmentation - mixed into the load function
  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

# Here we load the dataset, including augmentation
train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)

# Then we cache that single round of augmentation, and repeat that single round forever
train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
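
The effect is easy to see with a toy pipeline (a sketch, not from the tutorial): any randomness upstream of cache() is computed once, on the first pass over the data, and then replayed from the cache on every later pass.

import tensorflow as tf

# Toy example: "random augmentation" applied before cache().
ds = tf.data.Dataset.range(3).map(lambda _: tf.random.uniform(())).cache()

print(list(ds.as_numpy_iterator()))  # first pass: random values are computed and cached
print(list(ds.as_numpy_iterator()))  # later passes: the same cached values come back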

I believe this should look more like (untested):

@tf.function
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

# We can cache here, because the cached dataset is deterministic
train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE).cache()

@tf.function
def random_augment(input_image, input_mask):
  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)
  return input_image, input_mask

# Now apply augmentation after caching, getting different results each time
train_dataset = train.map(random_augment).shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
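
As a quick sanity check (also untested), re-iterating a single cached element through random_augment should now give a mix of flipped and unflipped results, rather than the same output every time:

# Untested sketch: the augmentation map is re-run on every pass over the data.
base_image, base_mask = next(iter(train.take(1)))
for image, mask in train.take(1).map(random_augment).repeat(8):
  # Should print a mix of True (not flipped) and False (flipped) with high probability.
  print(bool(tf.reduce_all(tf.equal(image, base_image))))
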
@jameshfisher
Contributor Author

(Happy to contribute a change to the tutorial, if we agree that this is a problem.)

@Saduf2019 Saduf2019 added type:docs-bug Document issues comp:ops OPs related issues labels Mar 15, 2021
@Saduf2019 Saduf2019 assigned ymodak and unassigned Saduf2019 Mar 15, 2021
@ymodak ymodak added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Mar 17, 2021
@RobertBiehl

RobertBiehl commented Apr 17, 2021

@jameshfisher I think the tutorial should probably mention the advantages and disadvantages. Image augmentation can quickly become the bottleneck of the input pipeline, and depending on your goals both approaches can be useful.

@jameshfisher
Contributor Author

@RobertBiehl what are the advantages of the current approach (cached augmentation) vs no augmentation? As far as I can see, there aren't any, but I might be missing something ...?

@jameshfisher
Contributor Author

FYI, this post was popular a few days ago; it describes what seems to be an identical issue in the PyTorch tutorials.

@aaudiber
Contributor

Thanks @jameshfisher, your suggestion is much better. Caching right before augmentation will give most of the benefits of caching, while still allowing augmentation to be re-applied each epoch, so that the model gets to train on a wider variety of inputs. If you're still interested in submitting a change, I'd be happy to review.
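
For reference, the general ordering this implies looks something like the sketch below (parse_and_resize and random_augment here are placeholder names, not the exact tutorial functions): do the deterministic work before cache(), and the random augmentation after it, so it is re-applied on every epoch.

dataset = (raw_dataset
    .map(parse_and_resize, num_parallel_calls=tf.data.AUTOTUNE)  # deterministic work
    .cache()                                                     # cache only the deterministic part
    .shuffle(BUFFER_SIZE)
    .map(random_augment, num_parallel_calls=tf.data.AUTOTUNE)    # random, re-run each epoch
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(tf.data.AUTOTUNE))

Running the post-cache augmentation map with num_parallel_calls and prefetching at the end also helps with the throughput concern raised above.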

@RobertBiehl

> @RobertBiehl what are the advantages of the current approach (cached augmentation) vs no augmentation? As far as I can see, there aren't any, but I might be missing something ...?

I guess it is an edge case: e.g. the augmentation changes the input data distribution in a way you want, and for performance reasons you accept that the augmentation stays frozen for the rest of training (i.e. the reason for augmenting is not just producing more training data, but transforming the input in some way).

For the tutorial, your suggestion definitely makes sense.

@8bitmp3
Contributor

8bitmp3 commented Jun 7, 2021

Thank you for the awesome feedback @jameshfisher @RobertBiehl @aaudiber 👍 We'll check this out cc @MarkDaoust

@ymodak ymodak removed their assignment Jun 23, 2021