
Incorrect box GT assignment (?!), and a suggested fix #44

Open
ardianumam opened this issue Feb 14, 2019 · 6 comments

Comments

@ardianumam

Hi,

First, thanks a lot for making this project.
I spotted a potential issue in how the box GT is assigned in this code (lines 99 and 100), rewritten here as:

i = np.floor(gt_boxes[t,0]/self.image_w*grid_sizes[l][1]).astype('int32')  # grid column from the GT center x
j = np.floor(gt_boxes[t,1]/self.image_h*grid_sizes[l][0]).astype('int32')  # grid row from the GT center y

Using the code above, the label assignment can go wrong, as illustrated in the attached picture: the label should land in grid cell (1,1), but the code puts it in grid cell (0,0) because of the floor operation. So I suggest using a round operation instead. I know the network can still learn from the training data anyway, but I think using round here is more consistent and makes the network easier to train. What do you think?
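
For concreteness, here is a minimal sketch of what the two operations compute for a box center sitting near the far edge of a grid cell (the 416-pixel image and 13-cell grid are illustrative assumptions, not values taken from the repo):

import numpy as np

image_w, cells = 416, 13              # illustrative sizes: one cell spans 32 pixels
center_x = 62.0                       # hypothetical box center, in pixels
rel = center_x / image_w * cells      # 1.9375 in cell units

print(np.floor(rel).astype('int32'))  # 1 -> index of the cell containing the center
print(np.round(rel).astype('int32'))  # 2 -> index of the nearest cell boundary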

Thanks.

@YunYang1994 (Owner) commented Feb 14, 2019

As it says in the original paper:

Each bounding box consists of 5 predictions: x, y, w, h, and confidence. The (x, y) coordinates represent the center of the box relative to the bounds of the grid cell. The width and height are predicted relative to the whole image.

Thanks for your distinctive idea, but I can't agree with you. As we know, the round operation mixes ceil and floor behavior. If you feed the neural network an ambiguous operation, how can you expect it to learn more easily? In that case, I think it will be rather more difficult to make the network converge.

@ardianumam (Author) commented Feb 14, 2019

Oh, I see. But what I mean is that the consistency is about which grid cell we assign the object label to. A sensible rule is: (i) assign it to the grid cell with the largest IoU against the box GT, right? See another illustration here: if we use floor, both case (1) and case (2) are assigned to grid cell (0,0), even though the box GT centers in cases (1) and (2) are almost one grid cell apart. If we use round, cases (1) and (2) are assigned to grid cells (1,1) and (0,0), respectively, just like the rule I mentioned in (i). So we can say consistency here means assigning the label to the cell that minimizes the distance between the box center and the grid center, and round gives a smaller distance than floor. We can also see it in terms of the fractional part that both operations discard: with round, the most we lose is 0.5, while with floor it can be as much as 0.99…
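
A quick numeric check of that last point (the center positions below are made-up examples in cell units):

import numpy as np

rel = np.array([0.2, 0.49, 0.51, 0.99])  # hypothetical box centers, in cell units
print(rel - np.floor(rel))               # floor discards up to ~0.99 of a cell
print(np.abs(rel - np.round(rel)))       # round discards at most 0.5 of a cell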

Anyway, have you ever run this code to train on the MS COCO dataset from scratch (without a pre-trained network)? I wonder how many days it needs. I'm currently still running this code to train on MS COCO from scratch. It is now at epoch 4 (3 days); it looks like it is converging, but it seems to need a lot of epochs, i.e., a lot of days (the original paper states 160 epochs).

@YunYang1994 (Owner)

Oh, I got it. Your idea is very impressive! But since YOLO is a regression problem, the regression target must be deterministic, so we need one particular grid cell location to act as the regressor.
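
For reference, here is a minimal sketch of the box decoding described in the YOLOv2/YOLOv3 papers (the function and variable names are mine, not this repo's code). Since sigmoid(tx) lies in (0, 1), the predicted center can only fall inside the assigned cell (cx, cy), which is why that cell has to be fixed in advance:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def decode_box(tx, ty, tw, th, cx, cy, pw, ph):
    # (cx, cy): top-left corner of the assigned grid cell, in cell units
    # (pw, ph): anchor prior width and height
    bx = sigmoid(tx) + cx   # sigmoid keeps the center inside cell (cx, cy)
    by = sigmoid(ty) + cy
    bw = pw * np.exp(tw)    # width/height rescale the anchor priors
    bh = ph * np.exp(th)
    return bx, by, bw, bh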

As for training the MS COCO dataset from scratch, I have not done it yet. I would appreciate it very much if you shared your results with us.

@ardianumam (Author)

Yes, sure, I can share the results later. It is still at epoch 6 now; the recall is still 0.0x, but the precision is already high, almost 1.

By the way, do you know why running this training code keeps increasing the RAM usage? If it keeps running, it eventually hits an out-of-memory error and the process gets killed.

@ardianumam (Author)

Good news!
I want to share the cause of the memory increase in the train.py code (which has since been removed from this repository). The root cause is this part of the code:

# creates three brand-new tf.assign ops on every training step
_, _, _, summary = sess.run([tf.assign(rec_tensor, rec),
                             tf.assign(prec_tensor, prec),
                             tf.assign(mAP_tensor, mAP), write_op],
                            feed_dict={is_training: True})

Calling tf.assign inside the training loop adds new operations to the graph on every iteration, so the graph grows without bound. I therefore replaced those three tf.assign calls with placeholders and feed rec, prec and mAP to them through feed_dict. Training is also faster afterwards.
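
A minimal sketch of that fix (TF 1.x API; the placeholder names are mine, while sess, write_op, is_training and the computed rec/prec/mAP values come from the surrounding training script): build the assign ops once, outside the loop, and only feed values inside it:

import tensorflow as tf

# summary variables, as in the snippet above
rec_tensor  = tf.Variable(0.0, trainable=False, name="recall")
prec_tensor = tf.Variable(0.0, trainable=False, name="precision")
mAP_tensor  = tf.Variable(0.0, trainable=False, name="mAP")

# placeholders for the freshly computed Python values
rec_ph  = tf.placeholder(tf.float32, shape=[], name="rec_ph")
prec_ph = tf.placeholder(tf.float32, shape=[], name="prec_ph")
mAP_ph  = tf.placeholder(tf.float32, shape=[], name="mAP_ph")

# the assign ops are created a single time, so the graph stops growing
update_metrics = [tf.assign(rec_tensor, rec_ph),
                  tf.assign(prec_tensor, prec_ph),
                  tf.assign(mAP_tensor, mAP_ph)]

# inside the training loop: only run existing ops, never create new ones
_, _, _, summary = sess.run(update_metrics + [write_op],
                            feed_dict={rec_ph: rec, prec_ph: prec, mAP_ph: mAP,
                                       is_training: True})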

@dodogoffy

Can you share the code? Thanks a lot!
