
Jpeg decoding (for example when loading TFRecords from files) causes error on TPU when trying to fit a model #41590

Closed
sedol1339 opened this issue Jul 21, 2020 · 7 comments
Labels: comp:tpus, TF 2.2, type:bug

Comments

@sedol1339

System information

  • TensorFlow version (use command below): 2.2.0 (v2.2.0-0-g2b96f3662b)
  • Python version: 3.6.9
  • GPU model and memory: Google Colab TPU

I'm not sure whether this is a bug, but I've encountered this weird behaviour with my .tfrec dataset and wrote a small piece of code to reproduce it. The problem only occurs on TPU.

First, I initialize the TPU:

import os
import tensorflow as tf
import numpy as np

tf.get_logger().propagate = False
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu = 'grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
INFO:tensorflow:Initializing the TPU system: grpc://10.26.115.226:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

Then I run the following code, which creates a tf.data.Dataset of dummy images, encodes them to JPEG and decodes them back, then normalizes them to float32 and batches them.

with strategy.scope():

  def encode_jpg(image, class_idx):
    return tf.io.encode_jpeg(image, quality = 95, optimize_size = True, chroma_downsampling = False), class_idx

  def decode_jpg(image, class_idx):
    return tf.image.decode_jpeg(image, channels = 3), class_idx

  def normalize_img(image, class_idx):
    return image / 255 - 0.5, class_idx

  dataset = tf.data.Dataset.from_tensor_slices((
    [tf.cast(np.zeros((256, 256, 3)), dtype = tf.uint8) for _ in range(300)],
    [0 for _ in range(300)]
  ))
  dataset = dataset.map(encode_jpg)
  dataset = dataset.map(decode_jpg)
  dataset = dataset.map(normalize_img)
  dataset = dataset.batch(8)

  print('\nwhat does our dataset look like?')
  for i, (image, label) in enumerate(dataset):
    print(image.shape, 'of', image.dtype, label.shape, 'of', label.dtype)
    if i == 2: break

  model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape = (256, 256, 3)),
    tf.keras.layers.Dense(100, activation = 'relu'),
    tf.keras.layers.Dense(10)
  ])

  print('\nwhat does our model look like?')
  model.summary()

  model.compile(loss = 'sparse_categorical_crossentropy', optimizer = 'adam')
  model.fit(dataset, epochs = 1)

I receive the following output, which ends with an exception:

what does our dataset look like?
(8, 256, 256, 3) of <dtype: 'float32'> (8,) of <dtype: 'int32'>
(8, 256, 256, 3) of <dtype: 'float32'> (8,) of <dtype: 'int32'>
(8, 256, 256, 3) of <dtype: 'float32'> (8,) of <dtype: 'int32'>

what does our model look like?
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 196608)            0         
_________________________________________________________________
dense (Dense)                (None, 100)               19660900  
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1010      
=================================================================
Total params: 19,661,910
Trainable params: 19,661,910
Non-trainable params: 0
_________________________________________________________________
---------------------------------------------------------------------------
UnimplementedError                        Traceback (most recent call last)
<ipython-input-2-9c1762b0cefe> in <module>()
     36 
     37   model.compile(loss = 'sparse_categorical_crossentropy', optimizer = 'adam')
---> 38   model.fit(dataset, epochs = 1)

10 frames
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

UnimplementedError: {{function_node __inference_train_function_5323}} Compilation failure: Asked to propagate a
dynamic dimension from hlo dot.472@{}@0 to hlo %all-reduce.477 = f32[<=196608,100]{1,0}
all-reduce(f32[<=196608,100]{1,0} %dot.472), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=%sum.473,
metadata={op_type="CrossReplicaSum" op_name="CrossReplicaSum_2"}, which is not implemented.
	TPU compilation failed
	 [[{{node tpu_compile_succeeded_assert/_4970277850434216321/_5}}]]

When I remove these two lines:

dataset = dataset.map(encode_jpg)
dataset = dataset.map(decode_jpg)

then it works:

38/38 [==============================] - 1s 16ms/step - loss: 16.8563

However, the shapes and types of the dataset batches remain the same:

what does our dataset look like?
(8, 256, 256, 3) of <dtype: 'float32'> (8,) of <dtype: 'int32'>
(8, 256, 256, 3) of <dtype: 'float32'> (8,) of <dtype: 'int32'>
(8, 256, 256, 3) of <dtype: 'float32'> (8,) of <dtype: 'int32'>

To fix this error I tried casting the labels to tf.int64, but the error still occurs. I also tried running this code on the CPU version of Colab (removing with strategy.scope():), and there it works perfectly. So I guess the problem lies in the combination of TPU and JPEG encoding/decoding.

@sedol1339
Author

sedol1339 commented Jul 21, 2020

I also tried changing the label to tf.one_hot(label, 10) and the loss to categorical_crossentropy, but the error remains the same.

@sedol1339 sedol1339 changed the title Jpeg decoding (for example wien loading TFRecords from files) raises error on TPU when trying to fit a model Jpeg decoding (for example when loading TFRecords from files) raises error on TPU when trying to fit a model Jul 21, 2020
@sedol1339 sedol1339 changed the title Jpeg decoding (for example when loading TFRecords from files) raises error on TPU when trying to fit a model Jpeg decoding (for example when loading TFRecords from files) causes error on TPU when trying to fit a model Jul 21, 2020
@amahendrakar
Contributor

Was able to reproduce the issue with TF v2.2. Please find the gist of it here. Thanks!

@amahendrakar amahendrakar added comp:dist-strat Distribution Strategy related issues comp:tpus tpu, tpuestimator TF 2.2 Issues related to TF 2.2 labels Jul 22, 2020
@amahendrakar amahendrakar assigned ymodak and unassigned amahendrakar Jul 22, 2020
@nikitamaia
Member

The error is not reproducible with MirroredStrategy, so I think this is TPU specific.

Based on the error, it looks like you're passing in a dynamic shape. I wonder if the following has something to do with it?
If you look at the DatasetSpec of the dataset without the encoding/decoding, the dimensions are known:
DatasetSpec(<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>, TensorShape([]))

But the encoded/decoded dataset has None for the height and width:
DatasetSpec(<BatchDataset shapes: ((None, None, None, 3), (None,)), types: (tf.float32, tf.int32)>, TensorShape([]))
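
For anyone who wants to check this themselves, here is a minimal sketch (my own, not taken from the thread) that compares the element_spec of the dataset with and without the JPEG round trip; decode_jpeg cannot infer height and width statically, so they come back as None:

import numpy as np
import tensorflow as tf

# Minimal sketch (assumption): same kind of dummy data as above, smaller count for brevity.
images = [tf.cast(np.zeros((256, 256, 3)), dtype = tf.uint8) for _ in range(4)]
labels = [0 for _ in range(4)]

static_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)
print(static_ds.element_spec)
# (TensorSpec(shape=(None, 256, 256, 3), dtype=tf.uint8, ...), TensorSpec(shape=(None,), ...))

jpeg_ds = tf.data.Dataset.from_tensor_slices((images, labels))
jpeg_ds = jpeg_ds.map(lambda img, y: (tf.io.encode_jpeg(img), y))
jpeg_ds = jpeg_ds.map(lambda img, y: (tf.image.decode_jpeg(img, channels = 3), y))
jpeg_ds = jpeg_ds.batch(2)
print(jpeg_ds.element_spec)
# Height and width are now None, which the TPU compiler treats as dynamic dimensions.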

@nikitamaia nikitamaia removed the comp:dist-strat Distribution Strategy related issues label Jul 23, 2020
@nikitamaia
Member

nikitamaia commented Jul 23, 2020

Try explicitly setting the size after you decode, using tf.reshape. I think that should work.

def decode_jpg(image, class_idx):
  return tf.reshape(tf.image.decode_jpeg(image, channels = 3), [256, 256, 3]), class_idx
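
For completeness, an alternative sketch (my own suggestion, not verified in this thread) is to assert the statically known shape instead of reshaping, e.g. with set_shape or tf.ensure_shape:

# Alternative sketch (assumption): assert the static shape rather than reshape.
def decode_jpg(image, class_idx):
  image = tf.image.decode_jpeg(image, channels = 3)
  image.set_shape([256, 256, 3])  # or: image = tf.ensure_shape(image, [256, 256, 3])
  return image, class_idx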

@sedol1339
Author

@nikitamaia Thanks, that helped!

@ymodak ymodak assigned nikitamaia and unassigned ymodak Jul 24, 2020
@nikitamaia
Member

Closing this issue since a solution was found. An explicit static shape is needed for TPUs.
