Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions official/vision/beta/projects/yt8m/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# YouTube-8M Tensorflow Starter Code (tf2 version)

This repo contains starter code for training and evaluating machine learning
models over the [YouTube-8M][1] dataset.
This is the Tensorflow2 version of the original starter code:
[YouTube-8M Tensorflow Starter Code][2]
which was tested on Tensorflow 1.14. (The code gives an end-to-end
working example for reading the dataset, training a TensorFlow model,
and evaluating the performance of the model). Functionalities are maintained,
while necessary migrations were done to accomodate running on tf2 environment.

### Requirements

The starter code requires Tensorflow. If you haven't installed it yet, follow
the instructions on [tensorflow.org][3].
This code has been tested with Tensorflow 2.4.0. Going forward,
we will continue to target the latest released version of Tensorflow.

Please verify that you have Python 3.6+ and Tensorflow 2.4.0 or higher
installed by running the following commands:

```sh
python --version
python -c 'import tensorflow as tf; print(tf.__version__)'
```

Refer to the [instructions here][4]
for using the model in this repo. Make sure to add the models folder to your
Python path.

[1]: https://research.google.com/youtube8m/
[2]: https://github.com/google/youtube-8m
[3]: https://www.tensorflow.org/install/
[4]:
https://github.com/tensorflow/models/tree/master/official#running-the-models

#### Using GPUs

If your Tensorflow installation has GPU support
(which should have been provided with `pip install tensorflow` for any version
above Tensorflow 1.15), this code will make use of all of your compatible GPUs.
You can verify your installation by running

```
tf.config.list_physical_devices('GPU')
```

This will print out something like the following for each of your compatible
GPUs.

```
I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720]
Found device 0 with properties:
pciBusID: 0000:00:04.0 name: Tesla P100-PCIE-16GB computeCapability: 6.0
coreClock: 1.3285GHz coreCount: 56 deviceMemorySize: 15.90GiB
deviceMemoryBandwidth: 681.88GiB/s
...
```

### Train and inference
Train video-level model on frame-level features and inference at segment-level.

#### Train using the config file.
Create a YAML or JSON file for specifying the parameters to be overridden.
Working examples can be found in yt8m/experiments directory.
```sh
task:
model:
cluster_size: 2048
hidden_size: 2048
add_batch_norm: true
sample_random_frames: true
is_training: true
activation: "relu6"
pooling_method: "average"
yt8m_agg_classifier_model: "MoeModel"
train_data:
segment_labels: false
temporal_stride: 1
num_devices: 1
input_path: 'gs://youtube8m-ml/2/frame/train/train*.tfrecord'
num_examples: 3888919
...
```

The code can be run in different modes: `train / train_and_eval / eval`.
Run `yt8m_train.py` and specify which mode you wish to execute.
Training is done using frame-level features with video-level labels,
while inference can be done at segment-level.
Setting `segment_labels=True` in your configuration forces
the segment level labels to be used in the evaluation/validation phrase.
If set to `False`, video level labels are used for inference.

The following commands will train a model on Google Cloud over frame-level
features.

```bash
python3 yt8m_train.py --mode='train' \
--experiment='yt8m_experiment' \
--model_dir=$MODEL_DIR \
--config_file=$CONFIG_FILE
```

In order to run evaluation after each training epoch,
set the mode to `train_and_eval`.
Paths to both train and validation dataset on Google Cloud are set as
train: `input_path=gs://youtube8m-ml/2/frame/train/train*.tfrecord`
validation:`input_path=gs://youtube8m-ml/3/frame/validate/validate*.tfrecord`
as default.

```bash
python3 yt8m_train.py --mode='train_and_eval' \
--experiment='yt8m_experiment' \
--model_dir=$MODEL_DIR \
--config_file=$CONFIG_FILE \
```

Running on evaluation mode loads saved checkpoint from specified path
and runs inference task.
```bash
python3 yt8m_train.py --mode='eval' \
--experiment='yt8m_experiment' \
--model_dir=$MODEL_DIR \
--config_file=$CONFIG_FILE
```


Once these job starts executing you will see outputs similar to the following:
```
train | step: 15190 | training until step 22785...
train | step: 22785 | steps/sec: 0.4 | output:
{'learning_rate': 0.0049961056,
'model_loss': 0.0012011167,
'total_loss': 0.0013538885,
'training_loss': 0.0013538885}

```

and the following for evaluation:

```
eval | step: 22785 | running 2172 steps of evaluation...
eval | step: 22785 | eval time: 1663.4 | output:
{'avg_hit_at_one': 0.5572835238737471,
'avg_perr': 0.557277077999072,
'gap': 0.768825760186494,
'map': 0.19354554465020685,
'model_loss': 0.0005052475,
'total_loss': 0.0006564412,
'validation_loss': 0.0006564412}
```
Empty file.
2 changes: 2 additions & 0 deletions official/vision/beta/projects/yt8m/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
""" Configs package definition. """
from official.vision.beta.projects.yt8m.configs import yt8m
155 changes: 155 additions & 0 deletions official/vision/beta/projects/yt8m/configs/yt8m.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Video classification configuration definition."""
import dataclasses
from typing import Optional, Tuple

from absl import flags
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams, optimization

FLAGS = flags.FLAGS

YT8M_TRAIN_EXAMPLES = 3888919
YT8M_VAL_EXAMPLES = 1112356
# 2/frame -> frame level
# 3/frame -> segment level
YT8M_TRAIN_PATH = 'gs://youtube8m-ml/2/frame/train/train*.tfrecord'
YT8M_VAL_PATH = 'gs://youtube8m-ml/3/frame/validate/validate*.tfrecord'

@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""The base configuration for building datasets."""
name: Optional[str] = 'yt8m'
split: str = None
feature_sizes: Tuple[int, ...] = (1024, 128)
feature_names: Tuple[str, ...] = ('rgb', 'audio')
segment_size: int = 1
segment_labels: bool = False
temporal_stride: int = 1
max_frames: int = 300
num_frames: int = 300 # set smaller to allow random sample (Parser)
num_classes: int = 3862
num_devices: int = 1
input_path: str = ''
is_training: bool = True
random_seed: int = 123
num_examples: int = -1


def yt8m(is_training):
""" YT8M dataset configs. """
return DataConfig(
num_frames=30,
temporal_stride=1,
segment_labels=False,
segment_size=5,
is_training=is_training,
split='train' if is_training else 'valid',
num_examples=YT8M_TRAIN_EXAMPLES if is_training
else YT8M_VAL_EXAMPLES,
input_path=YT8M_TRAIN_PATH if is_training
else YT8M_VAL_PATH
)


@dataclasses.dataclass
class YT8MModel(hyperparams.Config):
"""The model config."""
cluster_size : int = 2048
hidden_size : int = 2048
add_batch_norm : bool = True
sample_random_frames : bool = True
is_training : bool = True
activation : str = 'relu6'
pooling_method : str = 'average'
yt8m_agg_classifier_model : str = 'MoeModel'

@dataclasses.dataclass
class Losses(hyperparams.Config):
name: str = 'binary_crossentropy'
from_logits: bool = False
label_smoothing: float = 0.0

@dataclasses.dataclass
class YT8MTask(cfg.TaskConfig):
"""The task config."""
model: YT8MModel = YT8MModel()
train_data: DataConfig = yt8m(is_training=True)
validation_data: DataConfig = yt8m(is_training=False)
losses: Losses = Losses()
gradient_clip_norm: float = 1.0
num_readers: int = 8
top_k: int = 20
top_n: int = None

def add_trainer(experiment: cfg.ExperimentConfig,
train_batch_size: int,
eval_batch_size: int,
learning_rate: float = 0.005,
train_epochs: int = 44,
):
"""Add and config a trainer to the experiment config."""
if YT8M_TRAIN_EXAMPLES <= 0:
raise ValueError('Wrong train dataset size {!r}'.format(
experiment.task.train_data))
if YT8M_VAL_EXAMPLES <= 0:
raise ValueError('Wrong validation dataset size {!r}'.format(
experiment.task.validation_data))
experiment.task.train_data.global_batch_size = train_batch_size
experiment.task.validation_data.global_batch_size = eval_batch_size
steps_per_epoch = YT8M_TRAIN_EXAMPLES // train_batch_size
experiment.trainer = cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=train_epochs * steps_per_epoch,
validation_steps=YT8M_VAL_EXAMPLES //
eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adam',
'adam': {
}
},
'learning_rate': {
'type': 'exponential',
'exponential': {
'initial_learning_rate': learning_rate,
'decay_rate': 0.95,
'decay_steps': 1500000,
}
},
}))
return experiment

@exp_factory.register_config_factory('yt8m_experiment')
def yt8m_experiment() -> cfg.ExperimentConfig:
"""Video classification general."""
exp_config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=YT8MTask(),
trainer=cfg.TrainerConfig(),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
'task.train_data.num_classes == task.validation_data.num_classes',
'task.train_data.feature_sizes != None',
'task.train_data.feature_names != None',
])

return add_trainer(exp_config, train_batch_size=512,eval_batch_size=512)
Loading