Train Movinet on tfrecord #10743

Open
KornilovaK opened this issue Aug 10, 2022 · 6 comments

Comments

@KornilovaK

I'm trying to train a MoViNet model on my own TFRecord (2 classes of videos, 100 per class, each 256×256 at 12 fps).

import tensorflow as tf

dataset = tf.data.TFRecordDataset(['dataset.tfrecord'])

def _parse(example):
    
    features = {
        'video': tf.io.FixedLenFeature([2359296], tf.float32),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    
    parsed_dataset = tf.io.parse_single_example(example, features)
    
    video = parsed_dataset['video']
    video = tf.cast(video, tf.float32)
    video = tf.reshape(video, [12, 256, 256, 3])
    #label = tf.one_hot(parsed_dataset['label'], )
    label = parsed_dataset['label']

    
    return (video, label)

parsed_dataset = dataset.map(_parse, num_parallel_calls=tf.data.AUTOTUNE)
parsed_dataset
<ParallelMapDataset element_spec=(TensorSpec(shape=(12, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))>
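
(For context, 2359296 = 12 * 256 * 256 * 3, i.e. each video is stored as a flat float32 list. Roughly how each example was serialized — the make_example helper and my_clips below are just illustrative:)

import numpy as np
import tensorflow as tf

def make_example(video, label):
    # video: float32 array of shape (12, 256, 256, 3); label: 0 or 1
    return tf.train.Example(features=tf.train.Features(feature={
        'video': tf.train.Feature(
            float_list=tf.train.FloatList(value=video.reshape(-1))),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[int(label)])),
    }))

with tf.io.TFRecordWriter('dataset.tfrecord') as writer:
    for video, label in my_clips:  # my_clips: hypothetical list of (array, label) pairs
        writer.write(make_example(video, label).SerializeToString())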

First, I followed the official tutorial for fine-tuning:

from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model

batch_size = 1
num_frames = 12
resolution = 256
model_id = 'a0'

tf.keras.backend.clear_session()

backbone = movinet.Movinet(model_id=model_id)
model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600)
model.build([1, 1, 1, 1, 3])

!wget https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_base.tar.gz -O movinet_a0_base.tar.gz -q
!tar -xvf movinet_a0_base.tar.gz

checkpoint_dir = 'movinet_a0_base'
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore(checkpoint_path)
status.assert_existing_objects_matched()

def build_classifier(backbone, num_classes, freeze_backbone=False):
  """Builds a classifier on top of a backbone model."""
  model = movinet_model.MovinetClassifier(
      backbone=backbone,
      num_classes=num_classes)
  model.build([batch_size, num_frames, resolution, resolution, 3])

  if freeze_backbone:
    for layer in model.layers[:-1]:
      layer.trainable = False
    model.layers[-1].trainable = True

  return model

model = build_classifier(backbone, 2, freeze_backbone=True)

num_epochs = 3
total_train_steps = 600

loss_obj = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True,
    label_smoothing=0.1)

metrics = [
    tf.keras.metrics.TopKCategoricalAccuracy(
        k=1, name='top_1', dtype=tf.float32),
    tf.keras.metrics.TopKCategoricalAccuracy(
        k=5, name='top_5', dtype=tf.float32),
]

initial_learning_rate = 0.01
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, decay_steps=total_train_steps,
)
optimizer = tf.keras.optimizers.RMSprop(
    learning_rate, rho=0.9, momentum=0.9, epsilon=1.0, clipnorm=1.0)

model.compile(loss=loss_obj, optimizer=optimizer, metrics=metrics)

callbacks = [
    tf.keras.callbacks.TensorBoard(),
]

results = model.fit(
    parsed_dataset,
    epochs=num_epochs,
    callbacks=callbacks,
    validation_freq=1,
    verbose=1)

It results in this error:

Epoch 1/3
WARNING:tensorflow:Model was constructed with shape (None, None, None, None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, None, None, None, 3), dtype=tf.float32, name='image'), name='image', description="created by layer 'image'"), but it was called on an input with incompatible shape (12, 256, 256, 3).
WARNING:tensorflow:Model was constructed with shape (None, None, None, None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, None, None, None, 3), dtype=tf.float32, name='inputs'), name='inputs', description="created by layer 'inputs'"), but it was called on an input with incompatible shape (12, 256, 256, 3).
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-37-acd14ba5cf16>](https://localhost:8080/#) in <module>()
      4     callbacks=callbacks,
      5     validation_freq=1,
----> 6     verbose=1)

3 frames
[/usr/local/lib/python3.7/dist-packages/official/projects/movinet/modeling/movinet_layers.py](https://localhost:8080/#) in tf__call(self, inputs, states)
     27                     pass
     28                 ag__.if_stmt(ag__.and_((lambda : (ag__.ld(self)._conv_temporal is None)), (lambda : (ag__.ld(self)._stream_buffer is not None))), if_body, else_body, get_state, set_state, ('states', 'x'), 2)
---> 29                 x = ag__.converted_call(ag__.ld(self)._conv, (ag__.ld(x),), None, fscope)
     30 
     31                 def get_state_1():

ValueError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1160, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1146, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1135, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 993, in train_step
        y_pred = self(x, training=True)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/tmp/__autograph_generated_file2bfw1kvo.py", line 14, in tf__call
        retval_ = ag__.converted_call(ag__.ld(self)._stem, (ag__.ld(inputs),), dict(states=ag__.ld(states)), fscope)
    File "/tmp/__autograph_generated_filecuy8n870.py", line 29, in tf__call
        x = ag__.converted_call(ag__.ld(self)._conv, (ag__.ld(x),), None, fscope)

    ValueError: Exception encountered when calling layer "stem" (type Stem).
    
    in user code:
    
        File "/usr/local/lib/python3.7/dist-packages/official/projects/movinet/modeling/movinet_layers.py", line 1345, in call  *
            return self._stem(inputs, states=states)
        File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
            raise e.with_traceback(filtered_tb) from None
        File "/tmp/__autograph_generated_filecuy8n870.py", line 29, in tf__call
            x = ag__.converted_call(ag__.ld(self)._conv, (ag__.ld(x),), None, fscope)
    
        ValueError: Exception encountered when calling layer "stem" (type StreamConvBlock).
        
        in user code:
        
            File "/usr/local/lib/python3.7/dist-packages/official/projects/movinet/modeling/movinet_layers.py", line 657, in call  *
                x = self._conv(x)
            File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
                raise e.with_traceback(filtered_tb) from None
            File "/usr/local/lib/python3.7/dist-packages/keras/engine/input_spec.py", line 251, in assert_input_compatibility
                f'Input {input_index} of layer "{layer_name}" '
        
            ValueError: Input 0 of layer "conv3d" is incompatible with the layer: expected min_ndim=5, found ndim=4. Full shape received: (12, 256, 256, 3)
        
        
        Call arguments received by layer "stem" (type StreamConvBlock):
          • inputs=tf.Tensor(shape=(12, 256, 256, 3), dtype=float32)
          • states={}
    
    
    Call arguments received by layer "stem" (type Stem):
      • inputs=tf.Tensor(shape=(12, 256, 256, 3), dtype=float32)
      • states={}

Then I tried the TF Hub implementation of the A3 stream model, compiling and fitting it as in this article:

import tensorflow as tf
import tensorflow_hub as hub

hub_url = "https://tfhub.dev/tensorflow/movinet/a3/stream/kinetics-600/classification/3"

encoder = hub.KerasLayer(hub_url, trainable=True)

# Define the image (video) input
image_input = tf.keras.layers.Input(
    shape=[12, 256, 256, 3],
    dtype=tf.float32,
    name='image')

# Define the state inputs, which is a dict that maps state names to tensors.
init_states_fn = encoder.resolved_object.signatures['init_states']
state_shapes = {
    name: ([s if s > 0 else None for s in state.shape], state.dtype)
    for name, state in init_states_fn(tf.constant([0, 0, 0, 0, 3])).items()
}
states_input = {
    name: tf.keras.Input(shape[1:], dtype=dtype, name=name)
    for name, (shape, dtype) in state_shapes.items()
}

# The inputs to the model are the states and the video
inputs = {**states_input, 'image': image_input}

outputs = encoder(inputs)

model = tf.keras.Model(inputs, outputs, name='movinet')

initial_learning_rate = 0.01
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True
)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "model.h5", save_best_only=True
)

early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=10, restore_best_weights=True
)

model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        loss="binary_crossentropy",
        metrics=tf.keras.metrics.AUC(name="auc"),
    )

history = model.fit(
    parsed_dataset,
    epochs=2,
    callbacks=[checkpoint_cb, early_stopping_cb],
)

This gives me:

Epoch 1/2
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-44-d4e610c3ec4c>](https://localhost:8080/#) in <module>()
      2     parsed_dataset,
      3     epochs=2,
----> 4     callbacks=[checkpoint_cb, early_stopping_cb],
      5 )

1 frames
[/usr/local/lib/python3.7/dist-packages/keras/engine/training.py](https://localhost:8080/#) in tf__train_function(iterator)
     13                 try:
     14                     do_return = True
---> 15                     retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
     16                 except:
     17                     do_return = False

ValueError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1160, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1146, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1135, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 993, in train_step
        y_pred = self(x, training=True)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/input_spec.py", line 217, in assert_input_compatibility
        f'Layer "{layer_name}" expects {len(input_spec)} input(s),'

    ValueError: Layer "movinet" expects 125 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(12, 256, 256, 3) dtype=float32>]

I'm really confused. I don't see any solutions, and every time I get a new error. How can I solve this?

@KornilovaK
Author

I can provide the whole Google Colab notebook.

@KornilovaK
Author

I'm looking forward to ANY useful information.

@gadagashwini gadagashwini self-assigned this Aug 17, 2022
@gadagashwini

Hi @KornilovaK,
ValueError: Input 0 of layer "conv3d" is incompatible with the layer: expected min_ndim=5, found ndim=4. Full shape received: (12, 256, 256, 3)
As the error message says, Conv3D expects a 5D input. Add a batch dimension to your input data:
input_shape = (1, 12, 256, 256, 3)
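With a tf.data pipeline, the usual way to get that extra dimension is to batch the dataset rather than reshape each example inside the map function. A minimal sketch, reusing the parsed_dataset from above:

# Each element becomes (batch, frames, height, width, channels) = (1, 12, 256, 256, 3).
train_ds = parsed_dataset.batch(1).prefetch(tf.data.AUTOTUNE)

results = model.fit(
    train_ds,
    epochs=num_epochs,
    callbacks=callbacks,
    verbose=1)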
Thank you!

@gadagashwini gadagashwini added the stat:awaiting response Waiting on input from the contributor label Aug 17, 2022
@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@KornilovaK
Author

Thanks for your response.
Unfortunately, that doesn't work; I've already tried it.
If I use
video = tf.reshape(video, [1, 12, 256, 256, 3])
in the map function instead of
video = tf.reshape(video, [12, 256, 256, 3])
and then fit with the first approach, this error occurs:

Epoch 1/3
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-50-acd14ba5cf16>](https://localhost:8080/#) in <module>
      4     callbacks=callbacks,
      5     validation_freq=1,
----> 6     verbose=1)

1 frames
[/usr/local/lib/python3.7/dist-packages/keras/engine/training.py](https://localhost:8080/#) in tf__train_function(iterator)
     13                 try:
     14                     do_return = True
---> 15                     retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
     16                 except:
     17                     do_return = False

ValueError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1177, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1161, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1150, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1009, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1068, in compute_loss
        y, y_pred, sample_weight, regularization_losses=self.losses
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 265, in __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 152, in __call__
        losses = call_fn(y_true, y_pred)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 272, in call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 1987, in categorical_crossentropy
        label_smoothing, _smooth_labels, lambda: y_true
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 1981, in _smooth_labels
        num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype)

    ValueError: slice index -1 of dimension 0 out of bounds. for '{{node categorical_crossentropy/strided_slice}} = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](categorical_crossentropy/Shape, categorical_crossentropy/strided_slice/stack, categorical_crossentropy/strided_slice/stack_1, categorical_crossentropy/strided_slice/stack_2)' with input shapes: [0], [1], [1], [1] and with computed input tensors: input[1] = <-1>, input[2] = <0>, input[3] = <1>.
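
(From this traceback, my guess is that the loss is now the problem: CategoricalCrossentropy with label_smoothing expects one-hot targets of shape (batch, num_classes), while my labels are scalar int64. Something along these lines, maybe?)

# Option 1: one-hot the label inside _parse so the loss sees (batch, 2)
label = tf.one_hot(parsed_dataset['label'], 2)

# Option 2: keep integer labels and use a sparse loss instead (no label_smoothing)
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)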

@google-ml-butler google-ml-butler bot removed stat:awaiting response Waiting on input from the contributor stale labels Aug 30, 2022
@KornilovaK
Author

Or maybe I misunderstood you?

@gadagashwini gadagashwini removed their assignment Sep 14, 2022