About reproduced results #6

Closed · WongKinYiu opened this issue Jun 1, 2020 · 83 comments

@WongKinYiu
Hello,

I downloaded the provided ckpt files and ran python test.py --img-size 736 --conf_thres 0.001 to reproduce the results, but I got different AP and AP50 values compared with the table in the readme.

yolov3-spp: 44.1% AP, 64.4% AP50 (Table 45.5% AP, 65.2% AP50)
Speed: 10.4/2.1/12.6 ms inference/NMS/total per 736x736 image at batch-size 16

yolov5s: 31.4% AP, 52.3% AP50 (Table 33.1% AP, 53.3% AP50)
Speed: 2.2/2.1/4.4 ms inference/NMS/total per 736x736 image at batch-size 16

yolov5m: 39.9% AP, 60.7% AP50 (Table 41.5% AP, 61.5% AP50)
Speed: 5.4/1.8/7.2 ms inference/NMS/total per 736x736 image at batch-size 16

yolov5l: 42.7% AP, 63.5% AP50 (Table 44.2% AP, 64.3% AP50)
Speed: 11.3/2.2/13.5 ms inference/NMS/total per 736x736 image at batch-size 16

yolov5x: 45.7% AP, 65.9% AP50 (Table 47.1% AP, 66.7% AP50)
Speed: 20.3/2.2/22.5 ms inference/NMS/total per 736x736 image at batch-size 16

Are the reported results from test.py, or are they calculated by the evaluation server?

@glenn-jocher (Member) commented Jun 1, 2020

@WongKinYiu you may be referencing the in-house repo mAP calculation, which is always a little lower than the official pycocotools results, but is used for mAP plots during training because it is much faster than pycocotools.

The correct pycocotools val and test-dev server results are shown in the readme table. You can use the notebook to reproduce easily. yolov3-spp for example:
https://github.com/ultralytics/yolov5/blob/master/tutorial.ipynb

[screenshot: yolov3-spp pycocotools results from the tutorial notebook]

@WongKinYiu (Author) commented Jun 1, 2020

Thanks.

Yes, I got the results from the in-house repo mAP calculation.

I just checked the code and found the reason: my COCO path differs from the default path in test.py, cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0]).
https://github.com/ultralytics/yolov5/blob/master/test.py#L207
So no pycocotools results were calculated in my testing.

@glenn-jocher (Member)

@WongKinYiu ah yes. I moved pycocotools into a try-except clause to avoid the problem where, for example, training completes but a pycocotools error prevents the final model from being saved at the very end of the process.

So now if pycocotools fails for some reason it will only print a warning to the screen and continue with the process. BTW, 672 image-size produces results almost as good as 736.

@WongKinYiu (Author)

@glenn-jocher

Due to the setting in datasets.py, 672/736 will actually use 704/768 for testing.
Why not directly use 704/768 to produce the results?
self.batch_shapes = np.ceil(np.array(shapes) * img_size / 64.).astype(np.int) * 64
https://github.com/ultralytics/yolov5/blob/master/utils/datasets.py#L325
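For illustration, here is a minimal sketch of that rounding (hypothetical square aspect ratio, not repo code):

```python
import numpy as np

# Scale the batch's aspect-ratio shape by img_size, then round up to a 64-multiple.
img_size = 672
shapes = np.array([[1.0, 1.0]])  # hypothetical square aspect ratio
batch_shapes = np.ceil(shapes * img_size / 64.0).astype(int) * 64
print(batch_shapes)  # [[704 704]] -> requesting 672 actually tests at 704
```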

By the way, there seems to be no corresponding processing for the test task in test.py.
https://github.com/ultralytics/yolov5/blob/master/test.py#L246-L268

@glenn-jocher (Member) commented Jun 2, 2020

@WongKinYiu ah, yes you found a bug in test.py! I've pushed a fix for this in ee8988b, along with solving the datasets.py mystery.

It turns out the one line of code that letterboxed images to 64-multiples in datasets.py was helping the mAP by reducing edge effects on certain images. I've modified this line to pad to 32-multiples now, and forced a minimum of 0.5 grid cells (16 pixels) of padding around each image by increasing the grid count by 1 (0.5 on each side). This allows testing yolov3-spp at --img 640 with 0.455 mAP now, for example!

The 16-pixel padding looks like this. test.py --img 640 will actually test at 672 now, and --img 672 will actually test at 704, etc., with 16 pixels of letterbox on each side. This allows a much better speed-mAP balance than before.
[image: example of the 16-pixel letterbox padding]
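A minimal sketch of how the updated rounding could behave (assumed form, with stride 32 and 0.5 grid cells of padding per side; the exact repo code may differ):

```python
import numpy as np

img_size, stride, pad = 640, 32, 0.5
shapes = np.array([[1.0, 1.0]])  # hypothetical square aspect ratio
batch_shapes = np.ceil(shapes * img_size / stride + pad).astype(int) * stride
print(batch_shapes)  # [[672 672]] -> --img 640 now tests at 672
```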

@WongKinYiu (Author)

@glenn-jocher great!

And one more question: I would like to know how much AP improvement you got from changing
wh = torch.exp(p[:, 2:4]) * anchor_wh
to
y[..., 2:4] = (y[..., 2:4].sigmoid() * 2) ** 2 * self.anchor_grid[i]?

@glenn-jocher (Member) commented Jun 2, 2020

@WongKinYiu ah yes. This change is not about mAP improvement; it is about training stability on custom datasets. Too many users were reporting unstable/NaN width-height losses, I suspect due to exp(), so I removed it and replaced it with a sigmoid-based term that ranges from 0 to 4 instead. This also simplifies the code for inference, as the entire output passes through sigmoid() now, removing a bit of tensor slicing, etc.

But to answer your question, I don't have an exact number attributable to this change, as it was one of many that combined to produce the overall improvement.
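For intuition, here is a minimal comparison of the two width-height parameterizations (toy values, not repo code):

```python
import torch

p = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])  # raw network outputs
print(torch.exp(p))            # unbounded: up to ~22026 here, can overflow in fp16
print((p.sigmoid() * 2) ** 2)  # bounded to (0, 4): a stable anchor multiplier
```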

@WongKinYiu (Author)

@glenn-jocher OK, thank you very much.

@WongKinYiu (Author) commented Jun 8, 2020

@glenn-jocher

I just finished training yolov3-spp_csp; below are the results comparing yolov3-spp and yolov3-spp_csp, using python test.py --img-size 736 --conf 0.001.

yolov3-spp: 45.5% AP, 65.1% AP50, 49.6% AP75.
Model Summary: 225 layers, 6.29987e+07 parameters, 6.29987e+07 gradients
Speed: 10.4/2.1/12.6 ms inference/NMS/total per 736x736 image at batch-size 16
yolov3-spp_csp: 45.6% AP, 65.4% AP50, 49.7% AP75
Model Summary: 275 layers, 4.90092e+07 parameters, 4.90092e+07 gradients
Speed: 9.1/2.0/11.1 ms inference/NMS/total per 736x736 image at batch-size 16

It seems YOLOv4-based models outperform both YOLOv3-based and YOLOv5-based models. yolov3-spp_csp gets 12.5% faster inference and 0.1/0.3/0.1% higher AP/AP50/AP75 than yolov3-spp. Do you have plans to implement other YOLOv4-based models in this repository?

@glenn-jocher (Member)

@WongKinYiu excellent!! I'm glad the results are reproducible. And yes, I have also trained a set of yolov5 models using the CSP bottlenecks. The CSP models are faster and more accurate, except for yolov5x, which is much faster but may drop a bit in mAP. The updated models swap Bottleneck() for BottleneckCSP() and increase the P5 bottlenecks from 3 to 6. yolov5s.yaml looks like this, for example:

# parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple

# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# yolov5 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Focus, [64, 3]],  # 1-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4
   [-1, 3, Bottleneck, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 4-P3/8
   [-1, 9, BottleneckCSP, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 6-P4/16
   [-1, 9, BottleneckCSP, [512]],
   [-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
   [-1, 1, SPP, [1024, [5, 9, 13]]],
   [-1, 6, BottleneckCSP, [1024]],  # 10
  ]

# yolov5 head
head:
  [[-1, 3, BottleneckCSP, [1024, False]],  # 11
   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]],  # 12 (P5/32-large)

   [-2, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 3, BottleneckCSP, [512, False]],
   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]],  # 17 (P4/16-medium)

   [-2, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 3, BottleneckCSP, [256, False]],
   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]],  # 22 (P3/8-small)

   [[], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
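As a side note, the output width of each detection conv in that head, na * (nc + 5), is simple arithmetic (values from this yaml):

```python
na, nc = 3, 80        # anchors per level, number of classes
print(na * (nc + 5))  # 255 = 3 anchors x (80 classes + 4 box coords + 1 objectness)
```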

The new table looks like this.

Model | APval | APtest | AP50 | Latency | FPS | params | FLOPs
--- | --- | --- | --- | --- | --- | --- | ---
YOLOv5-s (ckpt) | 35.5 | 35.5 | 55.0 | 2.5ms | 400 | 7.1M | 12.6B
YOLOv5-m (ckpt) | 42.4 | 42.4 | 61.8 | 4.4ms | 227 | 22.0M | 39.0B
YOLOv5-l (ckpt) | 45.7 | 45.9 | 65.1 | 6.8ms | 147 | 50.3M | 89.0B
YOLOv5-x (ckpt) | - | - | - | 11.7ms | 85 | 95.9M | 170.3B
YOLOv3-SPP (ckpt) | 45.6 | 45.5 | 65.2 | 7.9ms | 127 | 63.0M | 118.0B

Now I am waiting for yolov5x to finish training, probably another 1-2 days, and then I will push a commit with the updates. Note most of the speed gains you see here (compared to the current readme table) are due to the reduced FLOPs from the CSP bottlenecks, though a bit of the speed improvement (maybe 0.1-0.4 ms) is due to the increased batch size, from 16 to 32.

@glenn-jocher (Member) commented Jun 8, 2020

@WongKinYiu also note I did not replace the first bottleneck with CSP, because this slowed down training and inference substantially (about 10%) and increased CUDA memory due to the large P2 grid size, while providing negligible params/FLOPs savings. There are 3 problems that need resolving after this upcoming update:

  • Architecture: We should try PANet or BiFPN heads. I have not had time to do this.
  • Training. Overfitting is a serious issue, especially for objectness. Overfitting starts earlier, and is more severe, for the larger models.
  • Scaling. I've been scaling depth by 1/3 and width by 1/4 for each new model. I'm not sure if there is a better way. The gains seem to be diminishing past yolov5l. The gains from s to x are: +6.9, +3.3, +1.3 (maybe). This is a bad trend. Ideally we want a scaling strategy to keep increasing mAP by +3 going from m to l to x and higher.

@WongKinYiu (Author) commented Jun 8, 2020

@glenn-jocher

If there are no accidents, I will get results for the ALL-CSP models (CSP applied to both backbone and neck) next week. They already show better efficiency in both inference speed and parameter count.

yolov3-spp 43.6%
Model Summary: 152 layers, 6.29719e+07 parameters, 6.29719e+07 gradients
Speed: 6.8/1.6/8.3 ms inference/NMS/total per 608x608 image at batch-size 16
cd53s-yocsp 43.9%
Model Summary: 194 layers, 4.3134e+07 parameters, 4.3134e+07 gradients
Speed: 6.0/1.6/7.6 ms inference/NMS/total per 608x608 image at batch-size 16
cd53s-pacsp 45.0%
Model Summary: 230 layers, 5.28891e+07 parameters, 5.28891e+07 gradients
Speed: 6.5/1.7/8.2 ms inference/NMS/total per 608x608 image at batch-size 16
  • Architecture: We should try PANet or BiFPN heads. I have not had time to do this.

I will upload the ALL-CSP models to GitHub after training finishes. Here is the cfg file of cd53s-pacsp; you can take a look if you would like to implement it.

  • Training. Overfitting is a serious issue, especially for objectness. Overfitting starts earlier, and is more severe, for the larger models.

Yes, I also find my models start overfitting at ~230 epochs.

  • Scaling. I've been scaling depth by 1/3 and width by 1/4 for each new model. I'm not sure if there is a better way. The gains seem to be diminishing past yolov5l. The gains from s to x are: +6.9, +3.3, +1.3 (maybe). This is a bad trend. Ideally we want a scaling strategy to keep increasing mAP by +3 going from m to l to x and higher.

I am designing a new scaling strategy based on CSPNet and will deliver it if it succeeds.

@glenn-jocher (Member)

That's great news! Yes, looking at your results there is a big jump with the PANet head. I will try to use cd53s-pacsp to implement the head this week.

A good example of overfitting is my latest yolov5l run, plotted here from epoch 50 to 280. yolov5s, in comparison, does not really overfit until right before epoch 300, so overfitting affects larger models more than smaller ones.
[plot: yolov5l training results, epochs 50-280]

@glenn-jocher (Member)

@WongKinYiu yolov5x topped out at 47.2 and started overfitting, so I will update the readme shortly, after I recreate the plots. I've drafted this for the update. Are there any changes you'd like me to make?

  • June 9, 2020: CSP updates to all YOLOv5 models. New models are faster, smaller and more accurate. Credit to @WongKinYiu for his excellent work with CSP.

Here is a close-up of the results. You can see the overtraining progression from s to x:
[plot: close-up of training curves from s to x]

@WongKinYiu (Author)

@glenn-jocher Hello,

Could you provide the (GPU_latency, AP_val) data for the points in the figure?
[figure: mAP vs GPU latency comparison]
I would like to add YOLOv4 for comparison after I finish training my models.

@glenn-jocher (Member) commented Jun 11, 2020

@WongKinYiu yes, here are the results in the attached study.zip. The results are reproduced by the code below. The entire study takes about 1-2 hours to run. study.zip

python test.py --task study

yolov5/test.py

Lines 264 to 275 in 3a5c532

elif opt.task == 'study':  # run over a range of settings and save/plot
    for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
        f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem)  # filename to save to
        x = list(range(288, 896, 64))  # x axis
        y = []  # y axis
        for i in x:  # img-size
            print('\nRunning %s point %s...' % (f, i))
            r, _, t = test(opt.data, weights, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json)
            y.append(r + t)  # results and times
        np.savetxt(f, y, fmt='%10.4g')  # save
    os.system('zip -r study.zip study_*.txt')
    # plot_study_txt(f, x)  # plot

@WongKinYiu (Author) commented Jun 11, 2020

@glenn-jocher Thanks

I just tested the YOLOv4 (CSP, Leaky) model under the same testing protocol; the results look like this.
[figure: mAP vs latency with YOLOv4 (CSP, Leaky) added]
I will update the information after my training finishes.

Also, did you test EfficientDet with batch size 32? Could you also provide the (GPU_latency, AP_val) information for EfficientDet?

@glenn-jocher (Member)

@WongKinYiu ah, very interesting. Are these new models trained with the ultralytics repository or a different one?

@WongKinYiu (Author) commented Jun 11, 2020

@glenn-jocher

Yes, I just replaced coco14.data with coco17.data to train the model from scratch, as in #6 (comment).
The performance improves ~0.5% after removing iscrowd=1 bounding boxes.

@glenn-jocher (Member)

@WongKinYiu ah ok. You should probably plot over the full range of image sizes to get the same comparison:

list(range(288, 896, 64))
Out[5]: [288, 352, 416, 480, 544, 608, 672, 736, 800, 864]

EfficientDet values are from their tables here:
https://github.com/google/automl/tree/master/efficientdet

Hmm, I will look at iscrowd again, not sure if I am using them or not.

@glenn-jocher (Member) commented Jun 11, 2020

@WongKinYiu it may make sense to label ultralytics-trained yolov4 models to differentiate them from darknet-trained yolov4 models, because I believe there is a performance difference.

There is also an inference difference, most notably in that the ultralytics repos do not use any of the special tricks mentioned in the yolov4 paper, which supposedly make yolov4 what it is, so I suspect that the training and inference both taking place on ultralytics repos is important to the story.

@WongKinYiu (Author)

@glenn-jocher OK,

So in the figure, EfficientDet models use batch size 8, while YOLOv5 models use batch size 32.

I downloaded coco17 using https://github.com/ultralytics/yolov5/blob/master/data/get_coco2017.sh.

Yes, darknet-trained yolov4 gets better AP50, while ultralytics-trained yolov4 gets better AP.

@glenn-jocher (Member)

@WongKinYiu yes, I know the efficientdet values are at batch-size 8. They do not show information for larger batch sizes, probably because their models run out of GPU RAM at larger batch sizes. In the figure caption I clearly state our data conditions: batch-size 32.

@glenn-jocher (Member) commented Jun 11, 2020

@WongKinYiu batch-size 32 throughput is important for cloud video-streaming inference. We have customers that use one 2080 Ti GPU with 16 simultaneous RTSP feeds running inference in parallel, for example, so this number is a direct measure of the throughput a model can provide them on a single GPU.

Batch-size 1 FPS is not really a useful value for a GPU, as it underutilizes the resource tremendously. It is more informative, for example, to display this in iDetection, which we plan on doing when we have some time.

@glenn-jocher (Member)

@WongKinYiu you might ask the efficientdet guys to provide speeds at higher batch sizes if they can. I'm not sure it's possible though, i.e. google/automl#85

One of the main benefits of yolov5 is the large batch sizes that can be used, since the GPU memory requirements are very low. For example, yolov5x can run at --batch-size 64 --img 640 with no problems, and yolov5s can run at up to --batch-size 256 --img 640 on certain rectangular images like HD video.

So yes, I could artificially limit us to batch size 8 to compare to efficientdet, but I'm not sure this makes sense when one of the primary benefits of yolov5 is larger batch sizes and lower GPU memory.

@glenn-jocher (Member)

@WongKinYiu, I've updated the study results to include yolov3-spp, and also to update yolov5m, which was previously reporting results from an older training (the new results are +0.3 mAP, matching the corrected table).
[plot: study of mAP vs latency]

The updated study results are here: study2.zip

@WongKinYiu (Author) commented Jun 12, 2020

@glenn-jocher Thanks for updating all this information.

#6 (comment)

I think it is better to remove EfficientDet from the figure if it is run under different conditions.

#6 (comment)

I agree that large-batch inference is very important for cloud streaming inference, since model-inference latency can usually be ignored compared with network latency.
But batch-size 1 inference is still very important on GPU; a good example is the autonomous-driving scenario, where the inference latency of the main stream should be less than 1/200 second.

#6 (comment)

I think EfficientDet can run at larger batch sizes, because depth-wise convolution usually needs less memory for inference (though it needs huge memory for training).

#6 (comment)

Thanks for the update. It seems YOLOv4-608 (45.9%) currently gets comparable AP to YOLOv5l-736 (45.7%).

@WongKinYiu (Author)

@glenn-jocher

[figure: comparison of YOLOv3-SPP, YOLOv5l, and YOLOv4]

  • YOLOv3-SPP & YOLOv5l: trained with resolution 640 +50% -50%
  • YOLOv4: trained with resolution 512 +50% -33%

@glenn-jocher (Member) commented Jun 12, 2020

@WongKinYiu thanks for the comparison! Yes, it is good to do an apples-to-apples comparison. YOLOv5 is just getting started, so I hope there is a lot of improvement to look forward to.

One point, though, is that all 3 of these models were trained and tested with ultralytics repos, so yolov4 here does not show the same results as the paper, with the bag of specials, etc.; it is the result of extensive and very difficult ultralytics development over the last year that allows ultralytics-trained yolov4 to exceed the published yolov4 paper results. The yolov4 paper on arxiv shows 43.5 AP. No one will realize this, since ultralytics training is not mentioned in the yolov4 paper, only darknet training.

When yolov4 was published, the ultralytics/yolov3 43.1 AP was ignored; only the very old pjreddie AP of 33 was shown. So if we follow this example, we should only include published benchmarks for comparison, such as efficientdet, FCOS, etc., and the 43.5 AP from darknet yolov4.

Also, for a full comparison you may want to include yolov5m, as it will be better than all 3 of these models at all latencies < 5 ms.

@bonlime commented Aug 7, 2020

@glenn-jocher I saw your discussion about building faster swish/mish and want to comment.

  1. There is no need to build custom CUDA kernels for mish/swish; they can be implemented in pure PyTorch, using autograd to be memory-efficient and jit.script to be faster, like this. That version is not traceable/exportable, so you need a naive swish/mish implementation to be able to switch (see the sketch after this list).
  2. MobileNet v3 introduced the Hard Swish approximation for Swish, and it has already made it to PyTorch master (in version 1.6); check F.hardswish. It is almost identical to ReLU in terms of speed but still gives a performance boost.
  3. An efficient native implementation of Swish is also coming to PyTorch (it is already available in the nightly build, btw). They call it SiLU because it appeared under that name first, not swish.
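A minimal sketch of the autograd approach from point 1 (hypothetical class names; the forward stores only the input, and the backward recomputes sigmoid instead of caching the activation output):

```python
import torch

class SwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # store only the input, not the output
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        sx = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * sx * (1 + x * (1 - sx))

class MemoryEfficientSwish(torch.nn.Module):
    def forward(self, x):
        return SwishFunction.apply(x)
```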

@glenn-jocher (Member)

@bonlime thanks for the tips! In the course of my research above I experimented a bit with all 3: swish/SiLU, hardswish and Mish. hardswish from MobileNet v3 seems to be a good balance of improved accuracy with minimal speed penalties, though it may actually be less exportable to CoreML than normal swish.

I'm training a few models with updated activations now. I believe we should be able to release a new set of v3.0 models around mid-late August with updated activations.

I don't like the Swish branding either by the way. It's typical corporate modus operandi to take previous research and brand and resell it, api it, or cloud service it.

One odd item is that I see you have a hardswish inplace option. I thought it was not possible for any x * f(x) operation to be executed in place. The hardsigmoid, yes, but if x is modified in place by the hardsigmoid op, how would one still have the original x to multiply it against?

@bonlime commented Aug 7, 2020

The inplace version is only there for compatibility with the rest of the codebase. It's not really in-place; it only accepts the argument.

@glenn-jocher (Member)

@bonlime ah I see, ok.

@WongKinYiu (Author)

@glenn-jocher @AlexeyAB

Hello,
I tried channels_last in PyTorch.
It can make training and testing faster on GPUs with tensor cores.
And the AP only drops about 0.002% on a channels_first-trained model.

Just add code like this:

model = model.to(memory_format=torch.channels_last)
inf_out, train_out = model(img.to(memory_format=torch.channels_last), augment=augment)

reference

@WongKinYiu (Author)

@glenn-jocher

I finished testing the performance of channels_last: it increases training speed by about 10-15% and batch-8 inference speed by ~3%.
And the most interesting thing is that it reduces training GPU RAM by ~30%.

https://github.com/ultralytics/yolov5/blob/master/train.py#L77
from

        model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create

to

        model = Model(opt.cfg, ch=3, nc=nc).to(device).to(memory_format=torch.channels_last)  # create

https://github.com/ultralytics/yolov5/blob/master/train.py#L271
from

                pred = model(imgs)

to

                pred = model(imgs.to(memory_format=torch.channels_last))

https://github.com/ultralytics/yolov5/blob/master/test.py#L91
add

    model = model.to(memory_format=torch.channels_last)

https://github.com/ultralytics/yolov5/blob/master/test.py#L104
from

            inf_out, train_out = model(img, augment=augment)  # inference and training outputs

to

            inf_out, train_out = model(img.to(memory_format=torch.channels_last), augment=augment)  # inference and training outputs
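As a quick self-contained sanity check (a toy conv, not the YOLOv5 model), channels_last propagates through a convolution:

```python
import torch

conv = torch.nn.Conv2d(3, 16, 3).to(memory_format=torch.channels_last)
x = torch.randn(8, 3, 64, 64).to(memory_format=torch.channels_last)
y = conv(x)
print(y.is_contiguous(memory_format=torch.channels_last))  # True
```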

@glenn-jocher (Member) commented Aug 25, 2020

@WongKinYiu wow, that's a super speedup!! Even more impressive if it can reduce CUDA memory. That's an amazing discovery you made there!!

I didn't know about this option. Would it make sense to convert the imgs to channels_last at the same time they are converted to pytorch tensors (L93 here for example)?

yolov5/test.py

Lines 92 to 93 in 4fb8cb3

for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
    img = img.to(device, non_blocking=True)

EDIT: hmm, here they are already tensors; they must have been converted in the dataloader here:

return torch.from_numpy(img), labels_out, self.img_files[index], shapes

So checkpointing the channels_last model and then loading it for inference is OK?
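One possible combined conversion (a sketch; Tensor.to accepts a memory_format argument, so the device move and layout change can happen in one call):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
img = torch.zeros(16, 3, 640, 640)  # stand-in for a dataloader batch
img = img.to(device, non_blocking=True, memory_format=torch.channels_last)
print(img.is_contiguous(memory_format=torch.channels_last))  # True
```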

@glenn-jocher (Member)

Looking at the pytorch channels_last examples, it seems .contiguous() plays a role in reversing the format. I wonder what this means for the .contiguous() op in the yolo forward method. I have this in place because it speeds up training a bit vs not having it.

yolov5/models/yolo.py

Lines 35 to 42 in 4fb8cb3

def forward(self, x):
    # x = x.copy()  # for profiling
    z = []  # inference output
    self.training |= self.export
    for i in range(self.nl):
        x[i] = self.m[i](x[i])  # conv
        bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
        x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
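On the channels_last interaction: a bare .contiguous() call requests the default NCHW layout, so it would undo channels_last (a small check, assuming a recent PyTorch):

```python
import torch

x = torch.randn(1, 3, 4, 4).to(memory_format=torch.channels_last)
print(x.is_contiguous(memory_format=torch.channels_last))  # True
y = x.contiguous()  # defaults to torch.contiguous_format (NCHW)
print(y.is_contiguous())  # True -> back to the default layout
```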

The train and test dataloader also passes the images through a numpy contiguous op. I don't remember if removing this causes an error or if I put it there simply for speed:

yolov5/utils/datasets.py

Lines 557 to 559 in 4fb8cb3

# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = np.ascontiguousarray(img)
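For reference, that numpy op is needed because the BGR-to-RGB slice and the transpose produce a non-contiguous view; a short check:

```python
import numpy as np

img = np.zeros((416, 416, 3), dtype=np.uint8)
view = img[:, :, ::-1].transpose(2, 0, 1)  # negative stride + transpose
print(view.flags['C_CONTIGUOUS'])                        # False
print(np.ascontiguousarray(view).flags['C_CONTIGUOUS'])  # True
```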

@rafale77 (Contributor) commented Aug 26, 2020

@glenn-jocher Hello,

Would you like to integrate mish_cuda into the repo?
It takes almost the same training time as the LeakyReLU model at the same batch size, and brings a large AP benefit. The following results are from training with default settings.

Model | Test Size | APval | AP50val | AP75val | APSval | APMval | APLval
--- | --- | --- | --- | --- | --- | --- | ---
YOLOv4 (CSP, Leaky) | 672 | 47.2% | 65.9% | 51.7% | 29.8% | 52.3% | 61.5%
YOLOv4 (CSP, Mish) | 672 | 48.1% | 66.8% | 52.6% | 31.9% | 53.3% | 61.0%
YOLOv4 (CSP, Mish) | 544 | 47.1% | 65.6% | 51.0% | 29.0% | 52.2% | 63.6%
By the way, MixUp can benefit AP by ~0.3%.

@WongKinYiu, just wondering how your x and l mish models compare on the graph with YOLOv5?

I have re-implemented this within a Home Assistant project, inferring at a fixed framerate across multiple (6) camera streams, and was able to compare against the OpenCV implementation of the YOLOv4 darknet model:

Just comparing the resources consumed on my setup with an i7 6700K and RTX 2070:

opencv/darknet yoloV4 608: CPU ~30% -- RAM ~ 7GB -- GPU 8% -- VRAM 3GB
pytorch yoloV4l-mish 672: CPU ~ 53% -- RAM <0.5GB -- GPU 9% -- VRAM 2GB
pytorch yolov5l 672: CPU ~50% -- RAM < 0.5GB -- GPU 7% -- VRAM 1.8GB
pytorch yolov4x-mish 672: CPU ~60% -- RAM <0.5GB -- GPU 15% -- VRAM 2.5GB

@github-actions (Contributor)

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@RostamDin

How to avoid a DNN speed drop while using YOLOv5 simultaneously?

I want to detect people in an image with YOLOv5 and their ages with OpenCV DNN. Age detection speed is fine before I load YOLOv5 into my code, but loading the YOLOv5 model drops the processing speed to less than half, even when I have not yet started using the YOLOv5 model to detect anything.

import os
import time

import cv2
import numpy as np
import torch

# Loading the YOLOv5 model below is what drops the DNN processing speed
detectObjectsModel = torch.hub.load('ultralytics/yolov5', 'yolov5s')

age_proto = os.path.join(".\\models", 'age_deploy.prototxt')
age_model = os.path.join(".\\models", 'age_net.caffemodel')
net = cv2.dnn.readNet(age_model, age_proto)
net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)

frame = np.random.randint(255, size=(416, 416, 3), dtype=np.uint8)  # put your image here!
cv2.imshow("frame", frame)
blob = cv2.dnn.blobFromImage(frame, 0.00392, (227, 227), [0, 0, 0], True, False)

start = time.time()
for i in range(100):
    net.setInput(blob)
    detections = net.forward(net.getUnconnectedOutLayersNames())
end = time.time()

ms_per_image = (end - start) * 1000 / 100
print('%.1f ms per image' % ms_per_image)
