Train Faster RCNN #494

Closed
chunfuchen opened this issue Nov 16, 2017 · 32 comments

Comments

@chunfuchen commented Nov 16, 2017

I get an error when training Faster RCNN based on your example; however, with your pre-trained model, I am able to evaluate its performance and get the same results you posted on GitHub.

Always include the following:

  1. What you did. (the command you ran if using examples; post or describe your code if not)

./examples/FasterRCNN/train.py --load snapshots/tensorpack/COCO-ResNet50-FasterRCNN.npz --gpu 2,3 --datadir /path/to/COCO14 --logdir snapshots/fasterRCNN-ResNet50

  2. What you observed. (training logs)
[1116 16:23:10 @graph.py:70] Running Op sync_variables_from_main_tower ...  
2017-11-16 16:23:10.457645: E tensorflow/stream_executor/cuda/cuda_driver.cc:1299] could not retrieve CUDA device count: CUDA_ERROR_NOT_INITIALIZED  
[1116 16:23:14 @param.py:144] After epoch 0, learning_rate will change to 0.00300000  
[1116 16:23:15 @base.py:209] Start Epoch 1 ...

and then the program idles there forever. Is it related to the line about CUDA_ERROR_NOT_INITIALIZED?

  3. Your environment (TF version, GPUs), if it matters.
    TF version 1.4.0, Python-3.6, CUDA 9, CUDNN-7.
    Tensorpack version: the newest commit.

  4. Others:

  • If I comment out the ds = PrefetchDataZMQ(ds, 1) line in the get_train_dataflow function of data.py, training runs. Replacing ds = PrefetchDataZMQ(ds, 1) with ds = PrefetchData(ds, 500, 1) works as well (see the sketch below).
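
For reference, here is a minimal sketch of the workaround, assuming the tensorpack dataflow API of that time; DataFromList is only a stand-in for the real detection dataflow, and the actual change lives in get_train_dataflow in examples/FasterRCNN/data.py:

```python
from tensorpack.dataflow import DataFromList, PrefetchData

ds = DataFromList([[i] for i in range(100)])  # stand-in for the real detection dataflow

# Original line in get_train_dataflow(), which forks a ZMQ worker process
# and hangs on this machine:
# ds = PrefetchDataZMQ(ds, 1)

# Workaround: either drop the prefetch entirely, or use the
# multiprocessing-based PrefetchData (buffer of 500 datapoints, 1 worker):
ds = PrefetchData(ds, 500, 1)
```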

Thanks.

@ppwwyyxx (Collaborator) commented Nov 17, 2017

Yes. I've seen the same error on one machine (but not the others) and used the same solution. I think that's because the GPU on that machine is in exclusive mode -- so using multiple processes may cause problems like this.

@ppwwyyxx (Collaborator)

Can you check your nvidia-smi -q | grep 'Compute Mode' ?

@chunfuchen (Author)

My machine has 8 V100 GPUs and it is a bare-metal machine.
All of them are in Compute Mode: Default.

@ppwwyyxx (Collaborator)

Interesting. Mine is P100. It might have something to do with how the new GPUs handle the fork; it works on old GPUs, though.
I'll take a deeper look when I get time. Meanwhile you can just disable the prefetch, because data is not a bottleneck for detection.

@chunfuchen (Author)

Got it and thanks.

@chunfuchen (Author)

Do you mind sharing the training speed on 8 P100s when training Faster RCNN?

When using 8 GPUs, I can only get ~200-300 seconds per epoch, and the utilization of each GPU is about 50%-60%.
The QueueInput/queue_size is 46.96 in the log. Do you think it is related to prefetch?
Thanks.

@ppwwyyxx (Collaborator) commented Nov 17, 2017

The performance only gets stable after about 3k steps. With the current default settings it takes 70~80 seconds per epoch, and GPU utilization is 70%~80%.

@ppwwyyxx (Collaborator)

The dataflow problem should be fixed now.

@chunfuchen (Author)

Thanks. It seems that the speed per epoch still varies on my machine even after 30 epochs (9k steps).
It can be anywhere from 90 to 200 seconds per epoch (with 8 GPUs and no other users on them). I guess that is because the number of proposals per image varies, which affects the speed per image.

Furthermore, may I ask a few questions about the training log?

  1. The warning message about a training image:
    COCO_val2014_000000251330.jpg is invalid for training: No valid foreground/background for RPN!
    I think it is okay, since the image does not provide any valid FG/BG for RPN.

  2. Performance metrics:
    I observed that I get nan on certain metrics; is that normal? Here is the log:

[1117 10:48:45 @monitor.py:363] rpn_losses/label_metrics/precision_th0.1: 0.39033
[1117 10:48:45 @monitor.py:363] rpn_losses/label_metrics/precision_th0.2: nan
[1117 10:48:45 @monitor.py:363] rpn_losses/label_metrics/precision_th0.5: nan
[1117 10:48:45 @monitor.py:363] rpn_losses/label_metrics/recall_th0.1: 0.98191
[1117 10:48:45 @monitor.py:363] rpn_losses/label_metrics/recall_th0.2: 0.95139
[1117 10:48:45 @monitor.py:363] rpn_losses/label_metrics/recall_th0.5: 0.67927

Thanks.

@ppwwyyxx (Collaborator)

Here the speed roughly decreased from 70 sec/epoch at epoch 10 to 120 sec/epoch at epoch 700. It slows down because of more and more positive predictions. I haven't seen 200.

  1. Those images are filtered out.
  2. Precision was only added yesterday, so I don't know. It's probably OK for training -- if the RPN predicts everything as negative (even just once) it will be nan forever (see the sketch below). But I'll change it, because the metric becomes useless then.
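
To make the nan concrete, here is a tiny standalone illustration (not the model code) of the 0/0 that produces it when no predictions cross the threshold:

```python
import tensorflow as tf

scores = tf.constant([0.05, 0.10, 0.08])                # all RPN scores below th=0.2
labels = tf.constant([1, 0, 0], tf.int32)               # hypothetical anchor labels

pos_pred = tf.cast(scores > 0.2, tf.int32)              # predicted-positive mask: all zeros
nr_pos_prediction = tf.reduce_sum(pos_pred)             # 0 positive predictions
pos_prediction_corr = tf.reduce_sum(pos_pred * labels)  # 0 correct positive predictions

precision = tf.truediv(pos_prediction_corr, nr_pos_prediction)  # 0 / 0 -> nan

with tf.Session() as sess:
    print(sess.run(precision))                          # prints nan
```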

@chunfuchen (Author) commented Nov 18, 2017

Thanks for your help. After about 400 epochs, the speed is more stable (~105 seconds per epoch).

After pulling the newest changes, I got an error about a mismatched data type.

At model.py, line 88:

precision = tf.truediv(pos_prediction_corr, nr_pos_prediction)

Error message:

TypeError: x and y must have the same dtype, got tf.int64 != tf.int32

since you cast valid_prediction to tf.int32 at line 80:

valid_prediction = tf.cast(valid_label_prob > th, tf.int32)

However, tf.count_nonzero returns tf.int64 by default, so I should set the dtype of tf.count_nonzero to tf.int32, right?

Note: I am using python3.6.
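
In case it helps, here is a paraphrased sketch of what I mean (names are guessed from the line references above, not copied from model.py):

```python
import tensorflow as tf

th = 0.5
valid_label_prob = tf.random_uniform([256])                               # stand-in RPN scores
valid_anchor_labels = tf.random_uniform([256], maxval=2, dtype=tf.int32)  # stand-in labels

valid_prediction = tf.cast(valid_label_prob > th, tf.int32)   # as at line 80
nr_pos_prediction = tf.reduce_sum(valid_prediction)           # int32

# tf.count_nonzero returns int64 by default; requesting int32 avoids the
# "x and y must have the same dtype" TypeError in tf.truediv below.
pos_prediction_corr = tf.count_nonzero(
    tf.logical_and(valid_label_prob > th, tf.equal(valid_anchor_labels, 1)),
    dtype=tf.int32)

precision = tf.truediv(pos_prediction_corr, nr_pos_prediction)  # as at line 88
```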

@ppwwyyxx (Collaborator)

Yes.

@chunfuchen (Author) commented Nov 18, 2017

Another error :(

I just pulled the newest changes.

2017-11-18 10:44:54.781565: E tensorflow/stream_executor/cuda/cuda_driver.cc:406] failed call to cuInit: CUDA_ERROR_NO_DEVICE
2017-11-18 10:44:54.781715: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:158] retrieving CUDA diagnostic information for host: xxxx
2017-11-18 10:44:54.781731: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:165] hostname: xxx
2017-11-18 10:44:54.781960: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:189] libcuda reported version is: 384.81.0
2017-11-18 10:44:54.782014: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:369] driver version file contents: """NVRM version: NVIDIA UNIX x86_64 Kernel Module  384.81  Sat Sep  2 02:43:11 PDT 2017
GCC version:  gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.5)
"""
2017-11-18 10:44:54.782060: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:193] kernel reported version is: 384.81.0
2017-11-18 10:44:54.782072: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:300] kernel version seems to match DSO: 384.81.0

@ppwwyyxx (Collaborator)

This looks more like a problem with your environment.

@chunfuchen (Author)

Hmm, weird. After I comment out line 78 in utils/box_ops.py:

os.environ['CUDA_VISIBLE_DEVICES'] = ''  # we don't want the dataflow process to touch CUDA

it works again. Maybe my system needs this variable to stay untouched even when the ops are placed on the CPU.

@chunfuchen (Author)

Hi, this might be more of a general TensorFlow question; if you think Stack Overflow is a better place to ask, please just ignore it.

Currently the input image size is [None, None, 3] to support dynamic image sizes in Faster RCNN. However, my base model needs to do upsampling to change the image resolution on the fly (like an encoder-decoder architecture). With None in the image size, I cannot build the graph, since the upsampling needs a deterministic shape. (I have tried BilinearUpsampling and FixedUnpooling in tensorpack with the ratio set to 2, and tf.image.resize in TensorFlow, which requires the final size of the image.)

Do you have any suggestions? Thanks.

@ppwwyyxx (Collaborator)

I can't understand exactly what "dynamically change the image resolution" means. But from what I can see, tf.image.resize supports the target size as a "tensor", so it is dynamic.

@chunfuchen (Author)

Thanks for your reply. Yes, tf.image.resize_images (or more specifically tf.image.resize_bilinear) supports the target size as a "tensor", but it is a 1-D tensor [new_height, new_width]; when the graph input has size [None, None, 3], I have not found a way to derive [new_height, new_width] for tf.image.resize_images while building the graph.

E.g. in my graph there are two tensors, A and B, and A is larger than B; at run time I would like to resize B to the size of A (which varies). However, I cannot infer the size of A, since I created a [None, None, 3] placeholder for A to allow dynamic input sizes. So while building the graph, get_shape().as_list() only gives me [None, None, 3] for A, and I cannot derive the target size of B for tf.image.resize_images.

Thanks.

@chunfuchen (Author)

Sorry, my bad, never mind. Thanks for the reminder to use a "tensor"... I was using a "list" to specify the target tensor size.
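
For anyone hitting the same issue, a minimal sketch (with hypothetical placeholder names) of what worked -- passing a run-time tensor rather than a Python list as the target size:

```python
import tensorflow as tf

A = tf.placeholder(tf.float32, [1, None, None, 3])  # reference tensor with dynamic H/W
B = tf.placeholder(tf.float32, [1, None, None, 3])  # tensor to resize to A's size

# A.get_shape().as_list() gives [1, None, None, 3] at graph-construction time,
# but tf.shape(A) is evaluated at run time and yields concrete values.
target_hw = tf.shape(A)[1:3]                        # 1-D int32 tensor [new_height, new_width]
B_resized = tf.image.resize_images(
    B, target_hw, method=tf.image.ResizeMethod.BILINEAR)
```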

@ppwwyyxx (Collaborator)

FYI, I had a bug that was introduced on Nov 13 and fixed just now. It affects the precision.
Lots of changes are being pushed recently, and I only test the model periodically, which means bugs are found with a delay.

@chunfuchen (Author)

Noted and thanks.

@chunfuchen (Author) commented Nov 20, 2017

A bug when resuming Faster RCNN training:

After I resume from a trained model, the learning rate of the first epoch is 0.003.

[1119 22:52:36 @param.py:144] After epoch 0, learning_rate will change to 0.00300000
[1119 22:52:36 @monitor.py:262] Found training history from JSON, now starting from epoch number 117.
[1119 22:52:36 @base.py:209] Start Epoch 117 ...
[1119 22:58:19 @argtools.py:142] WRN Input /home/chenrich/dataset/COCO14/train2014/COCO_train2014_000000273046.jpg is filtered for training: No valid foreground/background for RPN!
[1119 23:06:41 @base.py:219] Epoch 117 (global_step 34800) finished, time:844.78 sec.
[1119 23:06:41 @graph.py:70] Running Op sync_variables_from_main_tower ...
[1119 23:06:41 @monitor.py:363] learning_rate: 0.003
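
(For context, learning-rate schedules in tensorpack are declared with ScheduledHyperParamSetter as a list of (epoch, value) pairs; the sketch below uses illustrative epochs and values only, but the 0.003 announced "after epoch 0" in the log corresponds to such a first schedule entry.)

```python
from tensorpack import ScheduledHyperParamSetter

# Illustrative schedule only -- not the example's actual values.
lr_schedule = ScheduledHyperParamSetter(
    'learning_rate',
    [(0, 3e-3),     # the value reported "after epoch 0" above
     (120, 3e-4),   # hypothetical decay points
     (180, 3e-5)])
```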

On the other hand, I can understand that the speed per epoch varies, since the number of positive proposals may increase over the epochs. However, when I resume a model, the speed drops back to roughly what it was when training from scratch and then gradually recovers. The number of positive proposals right after resuming should be identical or similar to before, so why is it slow? (As the log above shows, the epoch finished in 844 sec, while I got about 200 seconds on average before resuming.)

@ppwwyyxx (Collaborator) commented Nov 20, 2017

TensorFlow convolutions need warm-up; for variable-size inputs they need more.
The overall speed will always first increase (until about 10 epochs) and then decrease.

@chunfuchen (Author)

Okay, got it. Thanks.

@ppwwyyxx (Collaborator)

A subtle bug that makes the result 2 points worse: 6fc4378.
Now the training curve looks the same as what I had before -- it hasn't finished yet, but it is probably correct now.

This again shows how important it is to match the paper's performance -- if I hadn't tried to compare against some reference number, I would never have found hidden mistakes like this.

@chunfuchen (Author)

Thanks for sharing your experience. :)

@chunfuchen (Author) commented Nov 22, 2017

Just want to share the trained results; the base model is ResNet-50:
Evaluation on the minival set, FASTRCNN_BATCH=256, ~33h on 8 V100s.
Average speed: after epoch 300, it takes ~115 seconds per epoch.

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.344
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.555
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.365
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.158
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.391
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.498
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.301
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.462
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.480
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.250
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.534
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.66

@ppwwyyxx (Collaborator) commented Nov 22, 2017

Thanks! I know the model probably gets slightly better than before, but I haven't had time to train it.

@Superlee506 commented Apr 6, 2018

@ppwwyyxx @chunfuchen How can you achieve 60~200s per epoch? I used 4 K80s, and after 3k steps, it still takes 1000s per epoch. I fine-tuned the Mask RCNN model with the following command:
python train.py --gpu 0,1,2,3 --load ImageNet-ResNet50.npz

@ppwwyyxx (Collaborator) commented Apr 6, 2018

There are many factors here:

  1. 3k steps may not be enough warmup. I checked my recent logs and saw about 20% better speed at around 10k steps. I'll update the notes later.
  2. I don't have numbers, but I wouldn't be surprised if a K80 is 4~5 times slower than a V100.
  3. https://github.com/ppwwyyxx/tensorpack/blob/8b4d4f779dc48eea299a0b15fe7a0f714e5b8113/examples/FasterRCNN/model.py#L315-L319

These two lines were added recently. They may impact speed (probably not by much, if at all), but I haven't run a benchmark yet.

@ppwwyyxx (Collaborator) commented Apr 6, 2018

Also, @chunfuchen was not training a Mask RCNN, but a Faster RCNN (the Mask RCNN implementation was added later).

@Superlee506

@ppwwyyxx Copy that~~Thanks for your patience~
