Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c2a5d26
Add files via upload
mjyun01 Jan 14, 2024
2899533
Rename train.py to train.py
mjyun01 Jan 14, 2024
f479bf7
Delete official/projects/RNGDET directory
mjyun01 Jan 14, 2024
7863a97
Add files via upload
mjyun01 Jan 14, 2024
c567f2c
Update run_test.py
mjyun01 Jan 14, 2024
4b0ad1a
Update run_test_all.py
mjyun01 Jan 14, 2024
32adcea
Update README.md
mjyun01 Jan 14, 2024
e377848
Update run_test.py
mjyun01 Jan 14, 2024
0aacff8
Delete official/projects/rngdet/only_eval_metric.py
mjyun01 Jan 14, 2024
d3ce989
Update run_test_all.py
mjyun01 Jan 14, 2024
452d3ce
Update README.md
mjyun01 Jan 14, 2024
9f7a098
Update README.md
mjyun01 Jan 14, 2024
693f0d1
Add files via upload
mjyun01 Jan 14, 2024
03b0e90
Update rngdet_input.py
mjyun01 Jan 14, 2024
6b57692
Update run_test_all.py
mjyun01 Jan 14, 2024
3d38757
Update do_train.sh
mjyun01 Jan 14, 2024
87f1688
Update run_test_all.py
mjyun01 Jan 14, 2024
35aac4c
Clean for PR
gunho1123 Jan 17, 2024
38454d6
Update rngdet.py
mjyun01 Jan 17, 2024
b0de2fb
clean up for PR
mjyun01 Jan 17, 2024
19220f2
clean up for PR
mjyun01 Jan 17, 2024
e2b18d1
Update rngdet.py
mjyun01 Jan 17, 2024
a1bef83
Update README.md
mjyun01 Jan 17, 2024
1294ae2
Update README.md
mjyun01 Jan 17, 2024
99a3bf6
Update README.md
mjyun01 Jan 17, 2024
3941690
Update rngdet_test.py
mjyun01 Jan 18, 2024
b69be04
Update rngdet_test.py
mjyun01 Jan 18, 2024
2179239
Delete official/projects/rngdet/tasks/__pycache__ directory
mjyun01 Jan 18, 2024
9e25367
Delete official/projects/rngdet/configs/__pycache__ directory
mjyun01 Jan 18, 2024
f64a4de
Delete official/projects/rngdet/dataloaders/__pycache__ directory
mjyun01 Jan 18, 2024
33b7ed2
Delete official/projects/rngdet/eval/__pycache__ directory
mjyun01 Jan 18, 2024
790a02a
Delete official/projects/rngdet/modeling/__pycache__ directory
mjyun01 Jan 18, 2024
a48020c
Delete official/projects/rngdet/metric directory
mjyun01 Jan 18, 2024
287eae3
Update run_test_all.py
mjyun01 Jan 18, 2024
0930ba6
Create requirements.txt
mjyun01 Jan 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions official/projects/rngdet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Road Network Graph Detection by Transformer

[![RNGDet](https://img.shields.io/badge/RNGDet-arXiv.2202.07824-B3181B?)](https://arxiv.org/abs/2202.07824)
[![RNGDet++](https://img.shields.io/badge/RNGDet++-arXiv.2209.10150-B3181B?)](https://arxiv.org/abs/2209.10150)

## Environment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
[guide](https://www.tensorflow.org/guide/distributed_training) for an overview
of `tf.distribute`.

## Data preparation
To download the dataset and generate labels, try the following command:

```
cd data
./prepare_dataset.bash
```

To generate training samples, try the following command:

```
python create_cityscale_tf_record.py \
--dataroot ./dataset/ \
--roi_size 128 \
--image_size 2048 \
--edge_move_ahead_length 30 \
--num_queries 10 \
--noise 8 \
--max_num_frame 10000 \
--num_shards 32
```
## Training
To edit training options of RNGDet, you can edit following commands in do_train.sh :

```
CUDA_VISIBLE_DEVICES=4 python3 train.py \
--mode=train \
--experiment=rngdet_cityscale \
--model_dir=./CKPT_DIR_NAME \
--config_file=./configs/experiments/cityscale_rngdet_r50_gpu.yaml \
```

To start training, try the following command :
```
sh do_train.sh
```

## Evaluation
To evaluate one image with internal step visualization,

```
python run_test.py -ckpt ./CKPT_DIR_NAME
```

To evaluate all images in the test dataset, and see score(P-P, P-R, R-F) for each images,

```
python run_test_all.py -ckpt ./CKPT_DIR_NAME
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float32'
num_gpus: 1
task:
train_data:
dtype: 'float32'
validation_data:
dtype: 'float32'
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
train_data:
dtype: 'float32'
validation_data:
dtype: 'float32'
228 changes: 228 additions & 0 deletions official/projects/rngdet/configs/rngdet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DETR configurations."""

import dataclasses
import os
from typing import List, Optional, Union

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.configs import common
from official.vision.configs import decoders
from official.vision.configs import backbones
#from official.projects.rngdet import optimization as optimization_detr


@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
input_path: str = ''
tfds_name: str = ''
tfds_split: str = 'train'
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'float32'
decoder: common.DataDecoder = dataclasses.field(default_factory=common.DataDecoder)
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
drop_remainder: bool = True


@dataclasses.dataclass
class Losses(hyperparams.Config):
lambda_cls: float = 1.0
lambda_box: float = 5.0
background_cls_weight: float = 0.2

@dataclasses.dataclass
class Rngdet(hyperparams.Config):
"""Rngdet model definations."""
num_queries: int = 10
hidden_size: int = 256
num_classes: int = 2 # 0: vertices, 1: background
num_encoder_layers: int = 6
num_decoder_layers: int = 6
input_size: List[int] = dataclasses.field(default_factory=list)
roi_size: int = 128
backbone: backbones.Backbone = dataclasses.field(default_factory=lambda:backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False)))
decoder: decoders.Decoder = dataclasses.field(
default_factory=lambda: decoders.Decoder(type='fpn', fpn=decoders.FPN())
)
min_level: int = 2
max_level: int = 5
norm_activation: common.NormActivation = dataclasses.field(default_factory=common.NormActivation)
backbone_endpoint_name: str = '5'


@dataclasses.dataclass
class RngdetTask(cfg.TaskConfig):
model: Rngdet = dataclasses.field(default_factory=Rngdet)
train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
validation_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
losses: Losses = dataclasses.field(default_factory=Losses)
init_checkpoint: Optional[str] = None
init_checkpoint_modules: Union[str, List[str]] = 'all' # all, backbone
per_category_metrics: bool = False


#CITYSCALE_INPUT_PATH_BASE = 'gs://ghpark-tfrecords/cityscale'
CITYSCALE_TRAIN_EXAMPLES = 420140
#CITYSCALE_TRAIN_EXAMPLES = 10140
CITYSCALE_INPUT_PATH_BASE = '/data2/cityscale/tfrecord'
#CITYSCALE_TRAIN_EXAMPLES = 1900
CITYSCALE_VAL_EXAMPLES = 5000

@exp_factory.register_config_factory('rngdet_cityscale')
def rngdet_cityscale() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
eval_batch_size = 64
steps_per_epoch = CITYSCALE_TRAIN_EXAMPLES // train_batch_size
train_steps = 50 * steps_per_epoch # 50 epochs
config = cfg.ExperimentConfig(
task=RngdetTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
init_checkpoint_modules='backbone',
model=Rngdet(
input_size=[128, 128, 3],
roi_size=128,
norm_activation=common.NormActivation()),
losses=Losses(),
train_data=DataConfig(
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
#input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise-8-00000-of-00032.tfrecord*'),
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=DataConfig(
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False,
)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=CITYSCALE_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=1*steps_per_epoch,
validation_interval=1*steps_per_epoch,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw_experimental',
'adamw_experimental': {
'epsilon': 1.0e-08,
'weight_decay': 1.0e-05,
'global_clipnorm': -1.0,
},
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.0001,
'end_learning_rate': 0.000001,
'offset': 0,
'power': 1.0,
'decay_steps': 50 * steps_per_epoch,
},
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 2 * steps_per_epoch,
'warmup_learning_rate': 0,
},
},
})),
restrictions=[
'task.train_data.is_training != None',
])
return config



@exp_factory.register_config_factory('rngdet_cityscale_detr')
def rngdet_cityscale() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 16
eval_batch_size = 64
steps_per_epoch = CITYSCALE_TRAIN_EXAMPLES // train_batch_size
train_steps = 50 * steps_per_epoch # 50 epochs
config = cfg.ExperimentConfig(
task=RngdetTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
init_checkpoint_modules='backbone',
model=Rngdet(
input_size=[128, 128, 3],
roi_size=128,
norm_activation=common.NormActivation()),
losses=Losses(),
train_data=DataConfig(
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise*'),
#input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train-noise-8-00000-of-00032.tfrecord*'),
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=DataConfig(
input_path=os.path.join(CITYSCALE_INPUT_PATH_BASE, 'train_noise*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False,
)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=CITYSCALE_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=1*steps_per_epoch,
validation_interval=1*steps_per_epoch,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 1e-5,
'epsilon': 1e-08,
'global_clipnorm': 0.1,
# Avoid AdamW legacy behavior.
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [20 * steps_per_epoch,
30 * steps_per_epoch,
40 * steps_per_epoch],
'values': [1.0e-05, 1.0e-05, 1.0e-06, 1.0e-07]
}
},
})),
restrictions=[
'task.train_data.is_training != None',
])
return config

51 changes: 51 additions & 0 deletions official/projects/rngdet/configs/rngdet_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for detr."""

# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr.configs import detr as exp_cfg
from official.projects.detr.dataloaders import coco


class DetrTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.parameters(('detr_coco',))
def test_detr_configs_tfds(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.DetrTask)
self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()

@parameterized.parameters(('detr_coco_tfrecord'), ('detr_coco_tfds'))
def test_detr_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.DetrTask)
self.assertIsInstance(config.task.train_data, cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()


if __name__ == '__main__':
tf.test.main()
Loading