RuntimeWarning: invalid value encountered in sqrt #36

Closed
insop opened this issue May 12, 2019 · 10 comments

Comments

@insop

insop commented May 12, 2019

Hi,

I ran into this error while running the command below (upsnet_resnet50_coco_1gpu.yaml is just 4gpu.yaml with the number of GPUs changed):

python -u upsnet/upsnet_end2end_train.py --cfg upsnet/experiments/upsnet_resnet50_coco_1gpu.yaml
2019-05-11 23:08:29,257 | callback.py | line 40 : Batch [1560]  Speed: 2.28 samples/sec Train-rpn_cls_loss=0.202601,    rpn_bbox_loss=0.119188, rcnn_accuracy=0.918098, cls_loss=0.512433,      bbox_loss=0.178919,     mask_loss=0.637334,     fcn_loss=3.486774,      fcn_roi_loss=4.059261,  panoptic_accuracy=0.267040,     panoptic_loss=2.879008,
2019-05-11 23:08:38,512 | callback.py | line 40 : Batch [1580]  Speed: 2.16 samples/sec Train-rpn_cls_loss=0.204118,    rpn_bbox_loss=0.121994, rcnn_accuracy=0.918320, cls_loss=0.511436,      bbox_loss=0.178261,     mask_loss=0.637755,     fcn_loss=3.488841,      fcn_roi_loss=4.060041,  panoptic_accuracy=0.265982,     panoptic_loss=2.881260,
2019-05-11 23:08:47,145 | callback.py | line 40 : Batch [1600]  Speed: 2.32 samples/sec Train-rpn_cls_loss=0.208637,    rpn_bbox_loss=0.125312, rcnn_accuracy=0.918627, cls_loss=0.510648,      bbox_loss=0.177416,     mask_loss=0.637829,     fcn_loss=3.493248,      fcn_roi_loss=4.063709,  panoptic_accuracy=0.264609,     panoptic_loss=2.885051,
2019-05-11 23:08:55,470 | callback.py | line 40 : Batch [1620]  Speed: 2.40 samples/sec Train-rpn_cls_loss=0.210174,    rpn_bbox_loss=0.125260, rcnn_accuracy=0.918797, cls_loss=0.510753,      bbox_loss=0.176945,     mask_loss=0.637923,     fcn_loss=3.496151,      fcn_roi_loss=4.067450,  panoptic_accuracy=0.264234,     panoptic_loss=2.887030,
upsnet/../upsnet/operators/modules/fpn_roi_align.py:38: RuntimeWarning: invalid value encountered in sqrt
  feat_id = np.clip(np.floor(2 + np.log2(np.sqrt(w * h) / 224 + 1e-6)), 0, 3)
Traceback (most recent call last):
  File "upsnet/upsnet_end2end_train.py", line 403, in <module>
    upsnet_train()
  File "upsnet/upsnet_end2end_train.py", line 269, in upsnet_train
    output = train_model(*batch)
  File "/opt/xxx_workspace/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "upsnet/../lib/utils/data_parallel.py", line 110, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/opt/xxx_workspace/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "upsnet/../upsnet/models/resnet_upsnet.py", line 139, in forward
    cls_label, bbox_target, bbox_inside_weight, bbox_outside_weight, mask_target)
  File "/opt/xxx_workspace/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "upsnet/../upsnet/models/rcnn.py", line 190, in forward
    cls_loss = self.cls_loss(cls_score, cls_label)
  File "/opt/xxx_workspace/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/xxx_workspace/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 942, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/opt/xxx_workspace/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 2056, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/opt/xxx_workspace/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1869, in nll_loss
    .format(input.size(0), target.size(0)))
ValueError: Expected input batch_size (510) to match target batch_size (512).
@insop
Author

insop commented May 12, 2019

Maybe this is the main cause, rather than the runtime warning from sqrt:

 ValueError: Expected input batch_size (510) to match target batch_size (512).
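
For reference, a minimal sketch of the shape requirement behind that error (the sizes 510/512 are just illustrative here): F.cross_entropy expects its input and target to share the batch dimension.

import torch
import torch.nn.functional as F

# Per-RoI class scores and labels with mismatched batch sizes (illustrative values).
cls_score = torch.randn(510, 81)           # 510 rows of class scores
cls_label = torch.randint(0, 81, (512,))   # 512 labels

try:
    loss = F.cross_entropy(cls_score, cls_label)
except (ValueError, RuntimeError) as e:
    # On PyTorch 1.1 this is the ValueError shown in the traceback above;
    # newer versions may raise the same check as a RuntimeError instead.
    print(e)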

@JoyHuYY1412

Maybe this is the main cause, rather than the runtime warning from sqrt:

 ValueError: Expected input batch_size (510) to match target batch_size (512).

I also ran into this issue, but it does not happen when I use multiple GPUs.

@YuwenXiong
Contributor

When you change the number of GPUs from 4 to 1, you also need to reduce the lr by 4x and increase the number of iterations by 4x.
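
For reference, the arithmetic behind that rule as a small sketch (a hypothetical helper; the base values are taken from the 4-GPU config shown in the diffs below):

# Linear scaling sketch: going from 4 GPUs to 1 GPU divides lr by 4
# and multiplies the iteration counts by 4.
base_4gpu = {
    "lr": 0.005,
    "max_iteration": 360000,
    "decay_iteration": [240000, 320000],
    "test_iteration": 360000,
}

def scale_for_gpus(cfg, base_gpus=4, target_gpus=1):
    """Scale lr down and iteration counts up by the GPU ratio."""
    factor = base_gpus // target_gpus  # 4
    return {
        "lr": cfg["lr"] / factor,                                         # 0.005 -> 0.00125
        "max_iteration": cfg["max_iteration"] * factor,                   # 360000 -> 1440000
        "decay_iteration": [d * factor for d in cfg["decay_iteration"]],  # [960000, 1280000]
        "test_iteration": cfg["test_iteration"] * factor,                 # 360000 -> 1440000
    }

print(scale_for_gpus(base_4gpu))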

@insop
Author

insop commented May 12, 2019

With the change below, I was able to run training after my post.

And you have confirmed it; thank you for the quick reply.

--- upsnet/experiments/upsnet_resnet50_coco_4gpu.yaml	2019-05-11 15:21:57.000000000 -0700
+++ upsnet/experiments/upsnet_resnet50_coco_1gpu.yaml	2019-05-12 00:37:26.000000000 -0700
@@ -2,7 +2,7 @@
 output_path: "./output/upsnet/coco"
 model_prefix: "upsnet_resnet_50_coco_"
 symbol: resnet_50_upsnet
-gpus: '0,1,2,3'
+gpus: '0'
 dataset:
   num_classes: 81
   num_seg_classes: 133
@@ -32,12 +32,12 @@
   snapshot_step: 2000
   resume: false
   begin_iteration: 0
-  max_iteration: 360000
+  max_iteration: 720000
   decay_iteration:
-  - 240000
-  - 320000
+  - 480000
+  - 640000
   warmup_iteration: 1500
-  lr: 0.005
+  lr: 0.0025
   wd: 0.0001
   momentum: 0.9
   batch_size: 1
@@ -54,7 +54,7 @@
   - 800
   max_size: 1333
   batch_size: 1
-  test_iteration: 360000
+  test_iteration: 720000
   panoptic_stuff_area_limit: 4096
   vis_mask: false

@insop insop closed this as completed May 12, 2019
@YuwenXiong
Contributor

Changing #iter/lr by only 2x may not reproduce the results I reported, since you changed the effective batch size from 4 to 1; batch size, lr, and #iter should be scaled consistently.

@insop
Author

insop commented May 12, 2019

Right, thank you. Updated with the 4x scaling:

--- upsnet/experiments/upsnet_resnet50_coco_4gpu.yaml	2019-05-11 15:21:57.000000000 -0700
+++ upsnet/experiments/upsnet_resnet50_coco_1gpu.yaml	2019-05-12 05:58:38.000000000 -0700
@@ -2,7 +2,7 @@
 output_path: "./output/upsnet/coco"
 model_prefix: "upsnet_resnet_50_coco_"
 symbol: resnet_50_upsnet
-gpus: '0,1,2,3'
+gpus: '0'
 dataset:
   num_classes: 81
   num_seg_classes: 133
@@ -32,12 +32,12 @@
   snapshot_step: 2000
   resume: false
   begin_iteration: 0
-  max_iteration: 360000
+  max_iteration: 1440000
   decay_iteration:
-  - 240000
-  - 320000
+  - 960000
+  - 1280000
   warmup_iteration: 1500
-  lr: 0.005
+  lr: 0.00125
   wd: 0.0001
   momentum: 0.9
   batch_size: 1
@@ -54,7 +54,7 @@
   - 800
   max_size: 1333
   batch_size: 1
-  test_iteration: 360000
+  test_iteration: 1440000
   panoptic_stuff_area_limit: 4096
  vis_mask: false

@YuwenXiong YuwenXiong mentioned this issue May 14, 2019
@andyhahaha

Hi, I encountered a similar error.
I changed the backbone to PeleeNet and trained with 4 GPUs, but some elements of feat_id come out as NaN:

feat_id = np.clip(np.floor(2 + np.log2(np.sqrt(w * h) / 224 + 1e-6)), 0, 3)

It is because some of the proposed RoIs have x1 > x2 or y1 > y2, which makes w < 0 or h < 0, so np.sqrt(w * h) produces NaN (the RuntimeWarning above), and the NaN propagates through np.log2.

I have tried smaller learning rates (0.0025 and 0.00125), but it still happens.
Does anyone know how to solve this problem?
Thanks!
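
For anyone debugging this, a minimal sketch (with made-up RoI coordinates) of how a malformed proposal produces the NaN in that line, plus one possible defensive clamp; this is not the repository's official fix:

import numpy as np

# Made-up RoIs in (x1, y1, x2, y2) form: the first has x1 > x2, so its width is negative.
rois = np.array([[100., 50., 80., 120.],    # malformed: w = -20
                 [10., 10., 458., 458.]])   # well-formed: w = h = 448
w = rois[:, 2] - rois[:, 0]
h = rois[:, 3] - rois[:, 1]

# Same expression as fpn_roi_align.py: sqrt of a negative area is NaN and triggers
# "RuntimeWarning: invalid value encountered in sqrt".
feat_id = np.clip(np.floor(2 + np.log2(np.sqrt(w * h) / 224 + 1e-6)), 0, 3)
print(feat_id)        # [nan  3.]

# One defensive option: clamp w and h to be non-negative before the level lookup.
w_safe, h_safe = np.maximum(w, 0), np.maximum(h, 0)
feat_id_safe = np.clip(np.floor(2 + np.log2(np.sqrt(w_safe * h_safe) / 224 + 1e-6)), 0, 3)
print(feat_id_safe)   # [0.  3.]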

@weixianghong

Hi, I encountered a similar error.
I changed the backbone to PeleeNet and trained with 4 GPUs, but some elements of feat_id come out as NaN:

feat_id = np.clip(np.floor(2 + np.log2(np.sqrt(w * h) / 224 + 1e-6)), 0, 3)

It is because some of the proposed RoIs have x1 > x2 or y1 > y2, which makes w < 0 or h < 0, so np.sqrt(w * h) produces NaN (the RuntimeWarning above), and the NaN propagates through np.log2.
I have tried smaller learning rates (0.0025 and 0.00125), but it still happens.
Does anyone know how to solve this problem?
Thanks!

I also hit the same issue. After I changed the backbone to ResNeXt-101, the RPN produces negative widths or heights, which causes NaN.
Have you solved it? Could you guide me?

@YuwenXiong
Contributor

@andyhahaha @weixianghong Please note that we use pretrained weights converted from Caffe, which expect different image preprocessing than torchvision models. Please set use_caffe_model to false if you want to use models with torchvision-style preprocessing.
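
For context, a rough sketch of the two preprocessing conventions being contrasted (the mean/std constants here are the commonly used ImageNet values and may not be the exact ones in this repository's config):

import numpy as np

def preprocess_caffe_style(img_rgb_uint8):
    """Caffe-converted weights typically expect BGR order, a 0-255 range, and mean subtraction."""
    img = img_rgb_uint8[:, :, ::-1].astype(np.float32)               # RGB -> BGR
    img -= np.array([103.53, 116.28, 123.675], dtype=np.float32)     # per-channel BGR mean
    return img

def preprocess_torchvision_style(img_rgb_uint8):
    """torchvision pretrained weights expect RGB, a [0, 1] range, and ImageNet mean/std."""
    img = img_rgb_uint8.astype(np.float32) / 255.0
    img = (img - np.array([0.485, 0.456, 0.406], dtype=np.float32)) \
          / np.array([0.229, 0.224, 0.225], dtype=np.float32)
    return img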

@weixianghong

@andyhahaha @weixianghong Please note that we use pretrained weights converted from Caffe, which expect different image preprocessing than torchvision models. Please set use_caffe_model to false if you want to use models with torchvision-style preprocessing.

It works, thank you!
