width-height: 'yolo' method vs. 'power' method #168
However, the switch from darknet to power method … IMPORTANT: If you use the default … This issue has a comparison plot of the two methods.
@glenn-jocher Appreciate your response. So, occasionally not converging is supposed to be normal when using the darknet method, isn't it? Where else should the power method be applied? So far I've seen it commented out in the build_targets function and in the yolo layer. Would it work by uncommenting just these and commenting out the darknet ones?
@100330706 yes, these are the two places you need to make the switch, and in ONNX export also if you plan to use that feature. Note that the power method currently trends from 0 to 4 as the input trends from negative infinity to positive infinity. The yolo method has unbounded outputs, causing it to diverge on occasion as you've noticed. If you feel you need higher …
TO SUMMARIZE: The 'yolo' width-height method and the 'power' method are implemented at:

Lines 260 to 269 in f0b4f9f
Lines 158 to 165 in f0b4f9f
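A minimal sketch contrasting the two width-height computations discussed in this thread. The function and tensor names here are illustrative, not the repo's exact code; the 'power' form `(sigmoid(x) * 2) ** 3` is the one given later in this thread.

```python
import torch

def wh_yolo(p_wh, anchor_wh):
    # 'yolo'/darknet method: exp() has no upper bound, so large raw
    # outputs can make the wh loss diverge.
    return torch.exp(p_wh) * anchor_wh

def wh_power(p_wh, anchor_wh):
    # 'power' method: (sigmoid(x) * 2) ** 3 is bounded to (0, 8) and
    # keeps the same centerpoint f(0) = 1 as exp(x).
    return (torch.sigmoid(p_wh) * 2) ** 3 * anchor_wh

raw = torch.tensor([0.0, 10.0])
anchors = torch.tensor([1.0, 1.0])
print(wh_yolo(raw, anchors))   # second value explodes: exp(10) ~ 22026
print(wh_power(raw, anchors))  # bounded: values stay below 8
```

Both forms agree at a raw output of 0 (a multiplier of exactly 1x the anchor), which is why the switch does not shift well-behaved predictions.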
So all img values are changed to the default color you put in the letterbox function, color = (127.5, 127.5, 127.5), which rounds up to 128. Is that the logic behind it? Because there is nothing else that could change all values of img to 128 in the COCO dataset. What if one uses grey images, should these numbers change?
@sanazss this is grey padding.
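A sketch of the constant grey padding described above, assuming a simple NumPy-only version. The `grey_pad` name is hypothetical, not from the repo; the value 128 follows the thread's observation that the 127.5 letterbox color rounds up to 128 in the uint8 image.

```python
import numpy as np

def grey_pad(img, new_h, new_w, color=128):
    # Pad an HxWx3 image to (new_h, new_w) with a constant grey border,
    # centering the original image, similar to letterbox padding.
    h, w = img.shape[:2]
    top = (new_h - h) // 2
    bottom = new_h - h - top
    left = (new_w - w) // 2
    right = new_w - w - left
    return np.pad(img, ((top, bottom), (left, right), (0, 0)),
                  mode='constant', constant_values=color)

img = np.zeros((4, 6, 3), dtype=np.uint8)
padded = grey_pad(img, 8, 8)
print(padded.shape)  # (8, 8, 3)
print(padded[0, 0])  # border pixel: [128 128 128]
```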
@glenn-jocher I believe in the newest implementation there are only two places where one needs to make this change (it doesn't seem like it needs to be done in build_targets). If this is right, it could be useful to update your summary to avoid confusion. If this is wrong, then I'd love to know where else to make the change. My loss currently diverges when using the yolo method, but P, R, mAP and F1 all remain zero when using the power method. Odd.
@bchugg it's true that the darknet (exp) method is subject to divergence (as is clear from the plot above). The introduction of GIoU has mostly suppressed instances of this happening, though it still does happen on occasion, typically when the GIoU hyperparameter or SGD LR is set too high. You are correct also: the introduction of GIoU removed the w/h calculation from build_targets. I will update the 'TO SUMMARIZE' comment above! In any case, to keep things simple for you, I would leave the wh method at the default and simply reduce your GIoU gain or SGD LR hyperparameters: Lines 23 to 30 in a96e010
I've clamped the wh output to max=1E4 now to prevent wh divergence. This should resolve the issue completely. Line 342 in b027c66
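The clamp described here can be sketched as follows; the variable names are illustrative, not the repo's exact code at the cited line.

```python
import torch

# Guard against wh divergence: exp() of a large raw output overflows,
# so the multiplier is clipped at 1e4 before it enters the loss.
p_wh = torch.tensor([0.0, 5.0, 50.0])
wh = torch.exp(p_wh).clamp(max=1e4)
print(wh)  # exp(50) ~ 5e21 is clipped to 10000.0
```

Gradients still flow normally for unclamped values, so well-behaved predictions are unaffected by the cap.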
@glenn-jocher I was implementing YOLOv2 from scratch for face detection and encountered the same issue of the loss going NaN, caused solely by the wh_loss term. So I came across this power method you have developed and started training with it. I observed that the exploding-gradient problem is gone now, but the model is not converging to an acceptable optimum: the loss stagnates around 15.xx after some 1000 epochs. So I have some questions which I hope you can help me with.
Also, I am using MSE for the regression loss. Thanks.
@agarwalyogeesh the power method seen in #168 (comment) is in units of grid points. There are no log operations. The equations are in #168 (comment). Note that the GIoU loss implementation seems to fix most of the unstable losses in the original exp wh method, and we now use this combination (GIoU loss with the original exp wh).
@glenn-jocher thanks for your reply. Also, in the inference equation you mentioned above, the range varies from 0 to 8, right? The non-convergence of the model is what's worrying me; maybe it's because of too little data, or maybe I need more training time. Thanks.
@agarwalyogeesh k is a hyperparameter; it's tunable. The range can vary from 0 to any number; in our case we set it to 8.
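A quick numeric check of how the exponent k sets the upper bound of the power multiplier: `(sigmoid(x) * 2) ** k` is bounded by 2**k, so k = 3 gives the 0 to 8 range mentioned above. The `wh_multiplier` name is illustrative, not from the repo.

```python
import torch

def wh_multiplier(x, k=3):
    # Power-method multiplier with tunable exponent k;
    # output is bounded to (0, 2**k), with f(0) = 1 for any k.
    return (torch.sigmoid(x) * 2) ** k

big = torch.tensor([100.0])  # saturates sigmoid at 1.0
for k in (2, 3, 4):
    print(k, float(wh_multiplier(big, k)))  # approaches 4, 8, 16
```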
Thanks for the wonderful explanation and the idea to solve instability in training. I want to point out that the output of the power wh method will be in the range [0, 8], and that will be multiplied by the anchor width/height. But this forces the prediction to be greater than or equal to the size of the anchor, not smaller. Our prediction should handle boxes both smaller and larger than the anchor, yet with the current power wh method it is forced to predict only larger (or equal) boxes. So I suggest we use the tanh function, which has a range of [-1, +1]. We then multiply it by 2 to make the range [-2, 2]. After we square or cube it, we may get an output range of [-4, +4] or [-8, +8]. This may solve that issue and probably provide better results.
@meet-minimalist yes, tanh is very similar to sigmoid, and it outputs from -1 to 1. But we are outputting a multiple of the anchor, which currently ranges from 0 to inf. The exp(x) method has no upper limit on x, causing the instability, but we need an output floor of 0 in all cases. The proposed change is (sigmoid(x) * 2) ** 3, which ranges from 0 to 8 and retains the same centerpoint f(0) = 1 as exp(x).
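A quick check of the point above: a tanh-based multiplier can go negative, which is invalid for a width/height scale, while the sigmoid-based form keeps a floor of 0.

```python
import torch

x = torch.tensor([-3.0, 0.0, 3.0])
tanh_mult = (torch.tanh(x) * 2) ** 3     # odd power preserves sign
sig_mult = (torch.sigmoid(x) * 2) ** 3   # always positive, f(0) = 1

print(tanh_mult)  # first value is negative
print(sig_mult)   # all values positive, middle value is 1.0
```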
Yeah, you are correct. My bad that I forgot this is a multiplier, which needs to be positive in any case. Thanks again.
YOLOv3 vs YOLOv5 wh method plotting code:

```python
import numpy as np
import torch
from matplotlib import pyplot as plt


def plot_wh_methods():  # from utils.plots import *; plot_wh_methods()
    # Compares the two methods for width-height anchor multiplication
    # https://github.com/ultralytics/yolov3/issues/168
    x = np.arange(-4.0, 4.0, .1)
    ya = np.exp(x)
    yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
    fig = plt.figure(figsize=(6, 3), tight_layout=True)
    plt.plot(x, ya, '.-', label='YOLOv3')
    plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
    plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.grid()
    plt.legend()
    fig.savefig('comparison.png', dpi=200)
```
@glenn-jocher These changes have to be made before training and after these swapping has been done then we have to train if we have to train on custom data. Right? |
@jaskiratsingh2000 this issue is simply explaining the existing updates to our regression equations, no modifications need to be made in the code as the updates are already implemented. |
Hi! We are running your YOLO implementation on a 5-class detection task. However, it seems that at some iteration of some epoch (it is not always the same), the loss suddenly starts shooting toward infinity, giving NaN values. The term that seems to be increasing exponentially is the wh loss (the wh tensor sometimes has negative values; I don't know if this is normal). By applying your power method

wh = torch.sigmoid(p[..., 2:4])  # wh (power method)

instead of

wh = p[..., 2:4]  # wh (yolo method)

it seems that this problem stops and the algorithm does not diverge. However, the wh loss flattens out at a higher value (around 1.07, 1.08 instead of going down to 0) than the other losses (plot not reproduced here). Do you have any clue about why this could be happening? What are the supposed advantages of using the power method?
Kind regards.