Skip to content

Commit

Permalink
upgrade detection / some v2 compat fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Jul 11, 2020
1 parent 55b2e0f commit edb7897
Show file tree
Hide file tree
Showing 11 changed files with 22 additions and 41 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ jobs:
max-parallel: 6
matrix:
python-version: [3.6]
# TF-version: [1.3.0, 1.14.0, nightly] # TODO make nightly work
TF-version: [1.3.0, 1.14.0]
TF-version: [1.5.0, 1.15.0]
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion examples/FasterRCNN/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ with the support of:
+ Training from scratch (from [Rethinking ImageNet Pre-training](https://arxiv.org/abs/1811.08883))

## Dependencies
+ OpenCV, TensorFlow ≥ 1.6
+ OpenCV, TensorFlow ≥ 1.14
+ pycocotools/scipy: `for i in cython 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' scipy; do pip install $i; done`
+ Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/#FasterRCNN)
from tensorpack model zoo
Expand Down
3 changes: 1 addition & 2 deletions examples/FasterRCNN/modeling/generalized_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# -*- coding: utf-8 -*-
# File:

import tensorflow as tf

from tensorpack.compat import tfv1 as tf
from tensorpack import ModelDesc
from tensorpack.models import GlobalAvgPooling, l2_regularizer, regularize_cost
from tensorpack.tfutils import optimizer
Expand Down
2 changes: 1 addition & 1 deletion examples/FasterRCNN/modeling/model_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def encode_bbox_target(boxes, anchors):

# Note that here not all boxes are valid. Some may be zero
txty = (xbyb - xaya) / waha
twth = tf.log(wbhb / waha) # may contain -inf for invalid boxes
twth = tf.math.log(wbhb / waha) # may contain -inf for invalid boxes
encoded = tf.concat([txty, twth], axis=1) # (-1x2x2)
return tf.reshape(encoded, tf.shape(boxes))

Expand Down
4 changes: 2 additions & 2 deletions examples/FasterRCNN/modeling/model_fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def fpn_map_rois_to_levels(boxes):
"""
sqrtarea = tf.sqrt(tf_area(boxes))
level = tf.cast(tf.floor(
4 + tf.log(sqrtarea * (1. / 224) + 1e-6) * (1.0 / np.log(2))), tf.int32)
4 + tf.math.log(sqrtarea * (1. / 224) + 1e-6) * (1.0 / np.log(2))), tf.int32)

# RoI levels range from 2~5 (not 6)
level_ids = [
Expand Down Expand Up @@ -127,7 +127,7 @@ def multilevel_roi_align(features, rcnn_boxes, resolution):
all_rois = tf.concat(all_rois, axis=0) # NCHW
# Unshuffle to the original order, to match the original samples
level_id_perm = tf.concat(level_ids, axis=0) # A permutation of 1~N
level_id_invert_perm = tf.invert_permutation(level_id_perm)
level_id_invert_perm = tf.math.invert_permutation(level_id_perm)
all_rois = tf.gather(all_rois, level_id_invert_perm, name="output")
return all_rois

Expand Down
15 changes: 5 additions & 10 deletions examples/FasterRCNN/modeling/model_frcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from tensorpack.models import Conv2D, FullyConnected, layer_register
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized_method
Expand Down Expand Up @@ -74,13 +73,13 @@ def sample_fg_bg(iou):
num_fg = tf.minimum(int(
cfg.FRCNN.BATCH_PER_IM * cfg.FRCNN.FG_RATIO),
tf.size(fg_inds), name='num_fg')
fg_inds = tf.random_shuffle(fg_inds)[:num_fg]
fg_inds = tf.random.shuffle(fg_inds)[:num_fg]

bg_inds = tf.reshape(tf.where(tf.logical_not(fg_mask)), [-1])
num_bg = tf.minimum(
cfg.FRCNN.BATCH_PER_IM - num_fg,
tf.size(bg_inds), name='num_bg')
bg_inds = tf.random_shuffle(bg_inds)[:num_bg]
bg_inds = tf.random.shuffle(bg_inds)[:num_bg]

add_moving_summary(num_fg, num_bg)
return fg_inds, bg_inds
Expand Down Expand Up @@ -151,12 +150,8 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
num_fg = tf.size(fg_inds, out_type=tf.int64)
empty_fg = tf.equal(num_fg, 0)
if int(fg_box_logits.shape[1]) > 1:
if get_tf_version_tuple() >= (1, 14):
fg_labels = tf.expand_dims(fg_labels, axis=1) # nfg x 1
fg_box_logits = tf.gather(fg_box_logits, fg_labels, batch_dims=1)
else:
indices = tf.stack([tf.range(num_fg), fg_labels], axis=1) # nfgx2
fg_box_logits = tf.gather_nd(fg_box_logits, indices)
fg_labels = tf.expand_dims(fg_labels, axis=1) # nfg x 1
fg_box_logits = tf.gather(fg_box_logits, fg_labels, batch_dims=1)
fg_box_logits = tf.reshape(fg_box_logits, [-1, 4]) # nfg x 4

with tf.name_scope('label_metrics'), tf.device('/cpu:0'):
Expand Down Expand Up @@ -253,7 +248,7 @@ def fastrcnn_Xconv1fc_head(feature, num_convs, norm=None):
with argscope(Conv2D, data_format='channels_first',
kernel_initializer=tf.variance_scaling_initializer(
scale=2.0, mode='fan_out',
distribution='untruncated_normal' if get_tf_version_tuple() >= (1, 12) else 'normal')):
distribution='untruncated_normal')):
for k in range(num_convs):
l = Conv2D('conv{}'.format(k), l, cfg.FPN.FRCNN_CONV_HEAD_DIM, 3, activation=tf.nn.relu)
if norm is not None:
Expand Down
14 changes: 4 additions & 10 deletions examples/FasterRCNN/modeling/model_mrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from tensorpack.models import Conv2D, Conv2DTranspose, layer_register
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary

Expand All @@ -20,14 +19,9 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
fg_labels: #fg, in 1~#class, int64
fg_target_masks: #fgxhxw, float32
"""
if get_tf_version_tuple() >= (1, 14):
mask_logits = tf.gather(
mask_logits, tf.reshape(fg_labels - 1, [-1, 1]), batch_dims=1)
mask_logits = tf.squeeze(mask_logits, axis=1)
else:
indices = tf.stack([tf.range(tf.size(fg_labels, out_type=tf.int64)),
fg_labels - 1], axis=1) # #fgx2
mask_logits = tf.gather_nd(mask_logits, indices) # #fg x h x w
mask_logits = tf.gather(
mask_logits, tf.reshape(fg_labels - 1, [-1, 1]), batch_dims=1)
mask_logits = tf.squeeze(mask_logits, axis=1)

mask_probs = tf.sigmoid(mask_logits)

Expand Down Expand Up @@ -74,7 +68,7 @@ def maskrcnn_upXconv_head(feature, num_category, num_convs, norm=None):
with argscope([Conv2D, Conv2DTranspose], data_format='channels_first',
kernel_initializer=tf.variance_scaling_initializer(
scale=2.0, mode='fan_out',
distribution='untruncated_normal' if get_tf_version_tuple() >= (1, 12) else 'normal')):
distribution='untruncated_normal')):
# c2's MSRAFill is fan_out
for k in range(num_convs):
l = Conv2D('fcn{}'.format(k), l, cfg.MRCNN.HEAD_DIM, 3, activation=tf.nn.relu)
Expand Down
6 changes: 0 additions & 6 deletions examples/FasterRCNN/train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: train.py

import argparse

from tensorpack import *
from tensorpack.tfutils import collect_env_info
from tensorpack.tfutils.common import get_tf_version_tuple

from dataset import register_coco, register_balloon
from config import config as cfg
Expand Down Expand Up @@ -34,10 +32,6 @@
default='train_log/maskrcnn')
parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py", nargs='+')

if get_tf_version_tuple() < (1, 6):
# https://github.com/tensorflow/tensorflow/issues/14657
logger.warn("TF<1.6 has a bug which may lead to crash in FasterRCNN if you're unlucky.")

args = parser.parse_args()
if args.config:
cfg.update_args(args.config)
Expand Down
4 changes: 2 additions & 2 deletions tensorpack/graph_builder/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,9 @@ def get_post_init_ops():
Copy values of variables on GPU 0 to other GPUs.
"""
# literally all variables, because it's better to sync optimizer-internal variables as well
all_vars = tf.global_variables() + tf.local_variables()
all_vars = tfv1.global_variables() + tfv1.local_variables()
var_by_name = {v.name: v for v in all_vars}
trainable_names = {x.name for x in tf.trainable_variables()}
trainable_names = {x.name for x in tfv1.trainable_variables()}
post_init_ops = []

def log_failure(name, reason):
Expand Down
6 changes: 3 additions & 3 deletions tensorpack/graph_builder/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def _replace_global_by_local(kwargs):
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
collections = {tf.GraphKeys.GLOBAL_VARIABLES}
collections = {tfv1.GraphKeys.GLOBAL_VARIABLES}
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
collections.remove(tfv1.GraphKeys.GLOBAL_VARIABLES)
collections.add(tfv1.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)


Expand Down
4 changes: 2 additions & 2 deletions tensorpack/train/tower.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ def xla_func():

grads_no_vars = xla.compile(xla_func)
if ctx.has_own_variables:
varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
varlist = ctx.get_collection_in_tower(tfv1.GraphKeys.TRAINABLE_VARIABLES)
else:
varlist = tf.trainable_variables()
varlist = tfv1.trainable_variables()
return list(zip(grads_no_vars, varlist))

return get_grad_fn

0 comments on commit edb7897

Please sign in to comment.