CSPResNeXt50-PANet-SPP #698

Closed
LukeAI opened this issue Dec 9, 2019 · 109 comments
Labels
enhancement New feature or request Stale

Comments

@LukeAI
Contributor

LukeAI commented Dec 9, 2019

Does this repo. support CSPResNeXt50-PANet-SPP? (https://github.com/WongKinYiu/CrossStagePartialNetworks/)

AlexeyAB's support: AlexeyAB/darknet#4406

My tests have found it to be a clear winner over yolov3-spp in terms of mAP and speed.

@LukeAI LukeAI added the enhancement New feature or request label Dec 9, 2019
@glenn-jocher
Member

@LukeAI hi, thanks for the feedback! Off the top of my head I think we may not support some of the layers there (#631 (comment)). Do you have an exact *.cfg file that you saw improvements with?

Is this complementary to Gaussian YOLO; can they both be used together? So this would replace the darknet53 backbone with a ResNeXt50 backbone?

It's too bad @WongKinYiu didn't do the modifications directly in this repo :)

@glenn-jocher
Member

glenn-jocher commented Dec 9, 2019

@WongKinYiu I'd like to implement this cfg in ultralytics/yolov3:
https://github.com/WongKinYiu/CrossStagePartialNetworks/blob/master/cfg/csresnext50-panet-spp.cfg

The only new field I see is 'groups' in the convolution layers. Are there other new fields I didn't see?
https://github.com/WongKinYiu/CrossStagePartialNetworks/blob/ff762e58750a2261d64855ac9c3a3ea1a993a24a/cfg/csresnext50-panet-spp.cfg#L383-L390

Do you know where I would slot groups into the PyTorch nn.Conv2d() module?

yolov3/models.py

Lines 22 to 33 in 07c1faf

```python
if mdef['type'] == 'convolutional':
    bn = int(mdef['batch_normalize'])
    filters = int(mdef['filters'])
    size = int(mdef['size'])
    stride = int(mdef['stride']) if 'stride' in mdef else (int(mdef['stride_y']), int(mdef['stride_x']))
    pad = (size - 1) // 2 if int(mdef['pad']) else 0
    modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1],
                                           out_channels=filters,
                                           kernel_size=size,
                                           stride=stride,
                                           padding=pad,
                                           bias=not bn))
```

@glenn-jocher
Member

@LukeAI @WongKinYiu I've added import of 'groups' into the Conv2d() definition in 3bfbab7. Is this sufficient to run CSPResNeXt50-PANet-SPP? @LukeAI can you git pull this repo and try with the cfg?

yolov3/models.py

Lines 22 to 34 in 3bfbab7

```python
if mdef['type'] == 'convolutional':
    bn = int(mdef['batch_normalize'])
    filters = int(mdef['filters'])
    size = int(mdef['size'])
    stride = int(mdef['stride']) if 'stride' in mdef else (int(mdef['stride_y']), int(mdef['stride_x']))
    pad = (size - 1) // 2 if int(mdef['pad']) else 0
    modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1],
                                           out_channels=filters,
                                           kernel_size=size,
                                           stride=stride,
                                           padding=pad,
                                           groups=int(mdef['groups']) if 'groups' in mdef else 1,
                                           bias=not bn))
```
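As a minimal sketch (not the repo's code), the same mapping can be written as a standalone function that turns a parsed `[convolutional]` block into `nn.Conv2d` keyword arguments. `conv2d_kwargs` is a hypothetical helper name; it mirrors the snippet above and adds an explicit check for a constraint PyTorch enforces anyway: both `in_channels` and `out_channels` must be divisible by `groups`.

```python
def conv2d_kwargs(mdef, in_channels):
    """Translate a parsed Darknet [convolutional] block into nn.Conv2d kwargs.

    mdef: dict of string values as read from a .cfg file (hypothetical helper).
    in_channels: channel count of the previous layer's output.
    """
    bn = int(mdef['batch_normalize'])
    filters = int(mdef['filters'])
    size = int(mdef['size'])
    if 'stride' in mdef:
        stride = int(mdef['stride'])
    else:  # asymmetric stride, as in the snippet above
        stride = (int(mdef['stride_y']), int(mdef['stride_x']))
    pad = (size - 1) // 2 if int(mdef['pad']) else 0
    groups = int(mdef['groups']) if 'groups' in mdef else 1  # 1 = ordinary convolution
    # nn.Conv2d rejects channel counts that are not divisible by groups
    if in_channels % groups or filters % groups:
        raise ValueError(f'in={in_channels}, out={filters} not divisible by groups={groups}')
    # with batch norm, the BN layer supplies the bias, so the conv has none
    return dict(in_channels=in_channels, out_channels=filters, kernel_size=size,
                stride=stride, padding=pad, groups=groups, bias=not bn)
```

Usage would then be `nn.Conv2d(**conv2d_kwargs(mdef, output_filters[-1]))`; cfgs without a `groups` key fall back to an ordinary convolution.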

@glenn-jocher
Member

yolov3-spp.cfg has 17 unique fields in its cfg:

17 ['type', 'batch_normalize', 'filters', 'size', 'stride', 'pad', 'activation', 'from', 'layers', 'mask', 'anchors', 'classes', 'num', 'jitter', 'ignore_thresh', 'truth_thresh', 'random']

csresnext50-panet-spp.cfg has 18 unique fields. It seems groups is the only newcomer. OK, so this repo should now fully support csresnext50-panet-spp.cfg @LukeAI.

18 ['type', 'batch_normalize', 'filters', 'size', 'stride', 'pad', 'activation', 'layers', 'groups', 'from', 'mask', 'anchors', 'classes', 'num', 'jitter', 'ignore_thresh', 'truth_thresh', 'random']
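One way to reproduce field counts like these is to parse both cfgs and diff their key sets. This is a sketch under two assumptions: the `[net]`/`[network]` block is skipped (its training hyperparameters are not layer fields), and every other `[section]` header counts as the `type` field. `cfg_fields` is a hypothetical helper, not the repo's parser.

```python
def cfg_fields(text):
    """Return the sorted set of layer keys used in a Darknet .cfg text."""
    fields, in_net = set(), False
    for line in text.splitlines():
        line = line.split('#', 1)[0].strip()  # drop comments and whitespace
        if not line:
            continue
        if line.startswith('['):
            # a new section; [net] holds hyperparameters, not layer fields
            in_net = line.strip('[]') in ('net', 'network')
            if not in_net:
                fields.add('type')
        elif '=' in line and not in_net:
            fields.add(line.split('=', 1)[0].strip())
    return sorted(fields)
```

If those assumptions match how the files are structured, the set difference between the two cfgs' fields should contain only `groups`.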

@glenn-jocher
Member

glenn-jocher commented Dec 9, 2019

@WongKinYiu I am getting an error in 3 of the shortcut layers when running csresnext50-panet-spp.cfg. They are trying to add tensors of torch.Size([1, 64, 104, 104]) from -4 layers previous with incorrectly sized tensors of torch.Size([1, 128, 104, 104]).

```python
# models.py line 260:
elif mtype == 'shortcut':
    try:
        x = x + layer_outputs[int(mdef['from'])]
    except RuntimeError:  # channel mismatch: print the offending shapes
        print(i, x.shape, layer_outputs[int(mdef['from'])].shape)
        x = layer_outputs[int(mdef['from'])]

# excepted layers:
# 8 torch.Size([1, 128, 104, 104]) torch.Size([1, 64, 104, 104])
# 12 torch.Size([1, 128, 104, 104]) torch.Size([1, 64, 104, 104])
# 16 torch.Size([1, 128, 104, 104]) torch.Size([1, 64, 104, 104])
```

Possible Fix

Change filters=128 to filters=64 on csresnext50-panet-spp.cfg lines 79, 110, 141. Then all the shapes combine correctly in this repo. Implemented in 86588f1. Not sure if this is a correct modification according to the original cfg designer.

@WongKinYiu

WongKinYiu commented Dec 10, 2019

@glenn-jocher Hello,

In PyTorch, zero-pad the shortcut to the same number of channels, then add.
For example, in PyTorch 0.4:

```python
if residual_channel != shortcut_channel:
    padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0))
    out += torch.cat((shortcut, padding), 1)
else:
    out += shortcut
```
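In current PyTorch `torch.autograd.Variable` is no longer needed, and the same padding can be written as `torch.nn.functional.pad(shortcut, (0, 0, 0, 0, 0, diff))`, which pads the channel dimension of an NCHW tensor. The sketch below shows the idea with NumPy arrays so it runs without a GPU; `padded_shortcut_add` is a hypothetical name, and it assumes the shortcut never has more channels than the residual branch.

```python
import numpy as np


def padded_shortcut_add(residual, shortcut):
    """Add an NCHW shortcut with fewer channels by zero-padding its channel axis."""
    diff = residual.shape[1] - shortcut.shape[1]
    if diff < 0:
        raise ValueError('shortcut has more channels than the residual branch')
    if diff:
        # pad only the end of the channel axis; batch, H and W are untouched
        shortcut = np.pad(shortcut, ((0, 0), (0, diff), (0, 0), (0, 0)))
    return residual + shortcut
```

For the failing layers above (128 vs 64 channels at 104x104), the first 64 channels receive the shortcut and the last 64 pass through unchanged.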

@WongKinYiu

@glenn-jocher

For convenience, I changed line 57 to filters=128 instead of filters=64 so that the filter numbers are consistent. Here are csresnext50c.cfg and csresnext50c.conv.80.

@glenn-jocher
Member

glenn-jocher commented Dec 11, 2019

@WongKinYiu @LukeAI @AlexeyAB I trained csresnext50-panet-spp.cfg 86588f1 against default yolov3-spp.cfg for 27 COCO epochs at 416 (10% of full training), but got worse results at a slower speed. I ran yolov3-spp3.cfg (see #694) with slightly worse results as well. Commands to reproduce:

```bash
git clone https://github.com/ultralytics/yolov3
bash yolov3/data/get_coco_dataset_gdrive.sh
cd yolov3
python3 train.py --epochs 27 --weights '' --cfg yolov3-spp.cfg --name 113
python3 train.py --epochs 27 --weights '' --cfg yolov3-spp3.cfg --name 115
python3 train.py --epochs 27 --weights '' --cfg csresnext50-panet-spp.cfg --name 121
```
| cfg | mAP@0.5...0.95 | mAP@0.5 | Time to 27 epochs (hrs) |
|---|---|---|---|
| yolov3-spp.cfg | 29.7 | 49.5 | 12.7 |
| yolov3-spp3.cfg | 29.1 | 49.0 | 13.5 |
| csresnext50-panet-spp.cfg 86588f1 | 25.9 | 44.2 | 28.3 |
| csresnext50-panet-spp.cfg zero-pad | TODO? | | |
| csresnext50c.cfg | TODO? | | |

[Image: results]

If you guys have time and are good with PyTorch please feel free to clone this repo and try the https://github.com/WongKinYiu/CrossStagePartialNetworks/ implementations yourself. I'd really like to exploit some of the research there but I don't have time. We are getting excellent results with our baseline yolov3-spp.cfg from scratch (40.9mAP@0.5...0.95, 60.9mAP@0.5 see https://github.com/ultralytics/yolov3#map), so if the improvements are relative, then they should help here also I assume.

@WongKinYiu

OK, I'll try to install this repo.

So none of your training runs use an ImageNet pre-trained model?

@glenn-jocher
Member

glenn-jocher commented Dec 11, 2019

@WongKinYiu ok great! No I don't use any pre-trained model for the initial weights. In an earlier test I found that starting from darknet53.conv.74 produced worse mAP after 273 epochs than starting from randomly initialized weights. For quick results (a day or less of training) yes, the imagenet trained weights will help, but for longer training I found they hurt.

To reproduce:

```bash
git clone https://github.com/ultralytics/yolov3
bash yolov3/data/get_coco_dataset_gdrive.sh
cd yolov3
python3 train.py --epochs 273 --weights darknet53.conv.74 --cfg yolov3-spp.cfg --name 41
python3 train.py --epochs 273 --weights '' --cfg yolov3-spp.cfg --name 42
```
| Run | mAP@0.5 | mAP@0.5...0.95 |
|---|---|---|
| results41: 416 multiscale to 273 epochs (darknet53.conv.74 start) | 56.8 | 36.2 |
| results42: 416 multiscale to 273 epochs (random start) | 57.5 | 37.1 |

[Image: results]

@WongKinYiu

@glenn-jocher

Thanks for your reply.
PANet needs more training epochs to converge than YOLOv3.

Are your models trained on a single GPU?

@glenn-jocher
Member

@WongKinYiu yes I typically train them on one 2080Ti or V100, which usually do about 50 epochs per day with the default settings (5 days to train COCO). See https://github.com/ultralytics/yolov3#speed for training speeds. Multi-GPU can also be used.

To get the best mAPs though --multi-scale must be used, which adds about 50% more training time (7-8 days on 1 GPU). This is why I usually test changes on 27 epochs.

@glenn-jocher
Member

glenn-jocher commented Dec 11, 2019

> @glenn-jocher
>
> Thanks for your reply.
> PANet need more training epochs to converge when compare with YOLOv3.

Should I try csresnext50c.cfg?

UPDATE: I put it in, but there are new layers again :)

```bash
python3 train.py --epochs 27 --weights '' --cfg csresnext50c.cfg --name 122
```

```
Warning: Unrecognized Layer Type: avgpool
Warning: Unrecognized Layer Type: softmax
```

@WongKinYiu

@glenn-jocher

No, if you train from scratch, I think you will get similar results.
PANet has an additional path compared with FPN, so it needs more epochs.

Oh, that is because csresnext50c.cfg is an ImageNet classifier.

@glenn-jocher
Member

> @glenn-jocher
>
> No, if train from scratch, i think u will get similar results.
> panet has additional path than fpn, so it need more epochs.
>
> oh, it is becuz csresnext50c.cfg is for imagenet classifier.

Oh, haha, ok I'll leave csresnext50c.cfg alone then.

@WongKinYiu

@glenn-jocher

Starting training...
Did you use `python3 train.py --epochs 273 --weights '' --cfg yolov3-spp.cfg --name 42` to get 40.9 AP?

@glenn-jocher
Member

glenn-jocher commented Dec 11, 2019

@WongKinYiu the exact training command to get to 40.9 AP with one GPU is:

```bash
python3 train.py --weights '' --epochs 273 --batch 16 --accumulate 4 --multi --pre
```

If you use multi-GPU though you will have more memory available, so you can use a larger --batch --accumulate combination to get to 64 like 32x2, or even 64x1:

```bash
python3 train.py --weights '' --epochs 273 --batch 32 --accumulate 2 --multi --pre
```

yolov3-spp.cfg is the default cfg, so you don't need to supply it above (but you can). The --pre argument performs one epoch of biasing the yolo output neurons before training starts. See #460
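`--batch 16 --accumulate 4` means gradients from 4 mini-batches of 16 images are accumulated before each optimizer step, so every update sees 64 images, the same as `--batch 32 --accumulate 2` or `64x1`. A minimal NumPy sketch of why this matches a single 64-image batch, assuming a mean-reduced loss whose gradients are scaled by `1/accumulate` (whether train.py scales exactly this way is an assumption). One real difference the sketch ignores: BatchNorm statistics are still computed over the 16-image mini-batch.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    return 2 * X.T @ (X @ w - y) / len(y)


def accumulated_grad(w, X, y, batch, accumulate):
    """Sum the gradients of `accumulate` mini-batches of size `batch`, each scaled by 1/accumulate."""
    g = np.zeros_like(w)
    for k in range(accumulate):
        sl = slice(k * batch, (k + 1) * batch)
        g += mse_grad(w, X[sl], y[sl]) / accumulate
    return g
```

With 64 samples, `accumulated_grad(w, X, y, batch=16, accumulate=4)` equals `mse_grad(w, X, y)` up to floating-point rounding.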

@WongKinYiu

My GPU RAM is not enough even though I set --batch 16 --accumulate 4 --multi --pre.
I will borrow another GPU for training.

@AlexeyAB

@glenn-jocher

> I trained csresnext50-panet-spp.cfg 86588f1 against default yolov3-spp.cfg for 27 COCO epochs at 416 (10% of full training), but got worse results at a slower speed.

This is weird, did you measure speed on GPU? And what FPS/ms did you get for SPP vs CSP?

Have you tried converting a CSPResNeXt50-PANet-SPP model already trained in Darknet (cfg / weights) to ultralytics (PyTorch), and did you get better mAP and better speed?

Or does this inconsistency interfere with this conversion? #698 (comment)

@WongKinYiu

@AlexeyAB Hello,

I think "slower speed" refers to training speed.
Training grouped convolutions is slower than training conventional convolutions.

@glenn-jocher
Member

glenn-jocher commented Dec 12, 2019

@AlexeyAB @WongKinYiu if I run test.py on the two trained models, this applies inference (and NMS) on the 5000 images in 5k.txt. This takes 138 seconds with yolov3-spp.cfg on a P4 GPU, and 139 seconds with csresnext50-panet-spp.cfg. Ah interesting, so the inference speed is nearly identical, but the training speed takes twice as long.

@FranciscoReveriano
Contributor

So is the CSPResNeXt50-PANet-SPP operational? And does it provide better results? I am looking more into it right now. And reading the article.

@glenn-jocher
Member

> my gpu ram is not enough even though i set --batch 16 --accumulate 4 --multi --pre.
> i will borrow other gpu for training.

@WongKinYiu I forgot to mention, you should install Nvidia Apex for mixed precision training with this repo. It increases speed substantially and reduces memory requirements substantially. Once installed correctly you should see this:
[Image: Screen Shot 2019-12-14 at 2 36 45 PM]

See https://github.com/NVIDIA/apex#quick-start

@WongKinYiu

@glenn-jocher

Yes, I have installed Apex.
Now I am training with --multi at scales 320~608.

@WongKinYiu

@glenn-jocher Hello,

I would like to know why --pre needs a little more GPU memory than training without it.

@glenn-jocher
Member

glenn-jocher commented Dec 16, 2019

@WongKinYiu ahhh, this is interesting; I had not realized that. There is a memory leak when invoking train.py repeatedly, which is very obvious when running hyperparameter evolution, as train.py is called repeatedly in a for loop #392 (comment), but I did not realize --pre also causes this. This makes sense though, as it calls train.py once to train the output biases for one epoch, then calls it again for actual training. How much extra memory is this using?

@LukeAI
Contributor Author

LukeAI commented Dec 20, 2019

> Is this complementary to Gaussian YOLO, can they both be used togethor?

I independently found good improvements ~+3mAP with Gaussian-Yolo and also cspresnext50-pan-spp vs. yolov3-spp - but I got pretty bad results when I tried combining them (-10mAP) - this may be because:
(1) I made a mistake
(2) The features that are useful to gaussian-yolo are quite different to the features that are useful for yolo so training a network with a gaussian-yolo head from pretrained weights from a non-gaussian head gives poor results
(3) I need to tune the hyper-parameters more (I tried 3 different learning rates but no dice)
(4) for some subtle reason these features just don't play well together - seems unlikely to me though.

@WongKinYiu have you tried Gaussian with cspresnext-pan-spp? Do you have any thoughts or results?

@WongKinYiu

WongKinYiu commented Dec 20, 2019

I think you need more warmup iterations when combining cspresnext50-pan-spp with gaussian-yolo (I have no GPUs to test it currently).

In my experiments, when combining cspresnext50-pan-spp with gaussian-yolo, precision drops and recall improves. The strange thing is that the loss becomes larger after 200k epochs.

@glenn-jocher
Member

glenn-jocher commented Mar 10, 2020

@WongKinYiu ok great! I got the last darknet model to run, but mAPs came back as 0.0. Note that I modified my default test nms --iou-thres from 0.5 to 0.6, as this produces a better balance of mAP@0.5:0.95 (best at --iou-thres 0.7) and mAP@0.5 (best at --iou-thres 0.5).

Also note the latest yolov3-spp.cfg baseline trains to 41.9/61.8 at 608 with the default settings. The training commands to reproduce this are here. The two separate --img-size values are the train img-size and the test img-size. Multi-scale train img sizes using this command will be 288 - 640.

```bash
python3 train.py --data coco2014.data --img-size 416 608 --epochs 273 --batch 16 --accum 4 --weights '' --device 0 --cfg yolov3-spp.cfg --multi
```

@WongKinYiu

@glenn-jocher

> Note that I modified my default test nms --iou-thres from 0.5 to 0.6, as this produces a better balance of mAP@0.5:0.95 (best at --iou-thres 0.7) and mAP@0.5 (best at --iou-thres 0.5).

Yes, I know.
However, for the competition, we should use same IoU threshold for both mAP@0.5:0.95 and mAP@0.5.

> Also note the latest yolov3-spp.cfg baseline trains to 41.9/61.8 with the default settings. The training commands to reproduce this are here. The two seperate --img-size are train img-size and test img-size. Multi-scale train img sizes using this command will be 288 - 640.

Thanks, I just used the default settings of the repo version I used to train the model. As I remember, that version gets about 40.9 mAP@0.5:0.95 in your report. By the way, all of my results are on the test-dev set, while yours are on the minival set.

@glenn-jocher
Member

@WongKinYiu ah test-dev set could be a difference too then!

Well it seems some differences remain as the ultralytics repo can't load the best performing darknet CSPDarknet53s-PANet-SPP model then. These differences must be the source of the problem I think.

@AlexeyAB

@glenn-jocher

> Also note the latest yolov3-spp.cfg baseline trains to 41.9/61.8 at 608 with the default settings.

What is the difference between your training and this yolov3-spp.cfg https://github.com/WongKinYiu/CrossStagePartialNetworks/tree/pytorch#ms-coco ?
Why such difference?

@WongKinYiu

@AlexeyAB

I use this repo to train: https://github.com/ultralytics/yolov3/tree/a6f87a28e7595e71752583fb41340f9d1105d75f
There have been many improvements to ultralytics since then.

@AlexeyAB

@WongKinYiu @glenn-jocher So, I want to know what improvements have been made?

@glenn-jocher
Member

Hmmm well lots of small day to day changes. If I use the github /compare it doesn't show the date of that commit, but it shows that there are 400 commits since then, with many modifications:
a6f87a2...master#diff-04c6e90faac2675aa89e2176d2eec7d8

The README from then was showing 40.0/60.9 mAP, which is similar to what @WongKinYiu was seeing, vs today's README which shows 41.9/61.8.

The improvements are over many different parts, such as the NMS, which now uses multi-label, the augmentation, which has been set to zero, the loss function reduction, which I returned to mean() instead of sum(), the cosine scheduler implementation, the increase in the LR to 0.01 after cos was implemented, and maybe a few other tiny things. The architecture itself is the same (yolov3-spp.cfg).

Actually this is an important point. A lot of papers today are showing very outdated comparisons to YOLOv3, i.e. showing 33 mAP@0.5:0.95 like the EfficientDet paper, with a GPU latency of 51ms. The reality is the most recent YOLOv3-SPP model I trained is at 42.1 mAP@0.5:0.95, with a GPU latency of 12.8ms #679 (comment), which puts it far better than their own D0-D2 models in both speed and mAP. I'm not sure how best to get that message out.

[Image: Screen Shot 2020-03-10 at 4 27 33 PM]

@AlexeyAB

AlexeyAB commented Mar 11, 2020

@glenn-jocher
So the main difference:

  1. NMS uses multi-label
  2. the augmentation, which has been set to zero - what does it mean, did you disable data augmentation?
  3. the loss function reduction, which I returned to mean() instead of sum() - are all the true-positive loss values averaged new_loss = sum_for_i( loss_obj, loss_cls, loss_bbox) / count ?

@WongKinYiu

[Image]

@glenn-jocher
Member

@AlexeyAB

Yes NMS uses multi-label now, which bumped up mAP about +0.3. Yes spatial augmentation seemed to hurt training, so I set it to zero, but left HSV augmentation on:

```python
       'hsv_h': 0.0138,  # image HSV-Hue augmentation (fraction)
       'hsv_s': 0.678,  # image HSV-Saturation augmentation (fraction)
       'hsv_v': 0.36,  # image HSV-Value augmentation (fraction)
       'degrees': 1.98 * 0,  # image rotation (+/- deg)
       'translate': 0.05 * 0,  # image translation (+/- fraction)
       'scale': 0.05 * 0,  # image scale (+/- gain)
       'shear': 0.641 * 0}  # image shear (+/- deg)
```
The loss is back to its original form, using the PyTorch defaults, which for the 3 yolo layers is, for example: loss_giou = (giou_1.mean() + giou_2.mean() + giou_3.mean()).sum()
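To make the difference between the two reductions concrete, here is a NumPy sketch (a hypothetical `total_loss` helper, not the repo's loss code). With `mean`, each YOLO layer contributes its average loss regardless of how many targets it holds; with `sum`, layers with more targets dominate and the overall loss scale grows with the number of targets.

```python
import numpy as np


def total_loss(per_layer_losses, reduction='mean'):
    """Combine per-target losses from several YOLO output layers.

    'mean': average inside each layer, then add the layers, i.e.
            l_1.mean() + l_2.mean() + l_3.mean()
    'sum':  add every per-target loss from every layer.
    """
    if reduction == 'mean':
        return float(sum(np.mean(layer) for layer in per_layer_losses))
    if reduction == 'sum':
        return float(sum(np.sum(layer) for layer in per_layer_losses))
    raise ValueError(f'unknown reduction: {reduction}')
```

For three layers with 10, 100 and 1000 targets that each have loss 1.0, `mean` gives 3.0 while `sum` gives 1110.0, so with `sum` the layer with the most targets effectively sets the gradient scale.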

I'm really hoping we might be able to merge the YOLO outputs some day so I can do away with this uncertainty in how to combine the losses from the different layers. ASFF seems to be an interesting step in that direction.

@glenn-jocher
Member

glenn-jocher commented Mar 11, 2020

@AlexeyAB ah also another change I forgot to mention was I changed multi-scale to change the resolution every batch now, instead of every 10 batches before. This seemed to smooth the results a bit, epoch to epoch.
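A sketch of per-batch multi-scale sampling, assuming sizes are drawn uniformly from multiples of 32 (the largest network stride) between the 288 and 640 limits mentioned earlier in the thread. `random_img_size` is a hypothetical helper; calling it once per batch, rather than once every 10 batches, corresponds to the change described above.

```python
import random


def random_img_size(min_size=288, max_size=640, stride=32, rng=random):
    """Draw a training resolution uniformly from the multiples of `stride` in [min_size, max_size]."""
    lo = -(-min_size // stride)  # ceiling division: smallest allowed multiple
    hi = max_size // stride      # largest allowed multiple
    return rng.randrange(lo, hi + 1) * stride
```

With the defaults this yields one of the 12 sizes 288, 320, ..., 640; the batch would then be resized to that size before the forward pass.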

@glenn-jocher
Member

@WongKinYiu yes they look super similar to each other unfortunately. I'm not sure why we aren't seeing the same gains as the darknet training. It must have to do with the grouped convolutions I think.

@AlexeyAB

@glenn-jocher

> Yes NMS uses multi-label now, which bumped up mAP about +0.3.

Does it currently work this way?
If there are 2 bboxes with IoU > iou_nms:

  1. class1_prob = 0.5, class2_prob = 0.7
  2. class1_prob = 0.7, class2_prob = 0.5

Then it will remove class1_prob = 0.5 and class2_prob = 0.5, and will leave:

  1. class2_prob = 0.7
  2. class1_prob = 0.7
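A minimal per-class sketch of the multi-label behaviour described in the question, using NumPy. Each (box, class) pair above `conf_thres` becomes its own candidate, and greedy NMS runs separately within each class, so a box can keep several labels. This illustrates the technique only; it is not the repo's NMS code, and `multilabel_nms`, `conf_thres=0.3` and `iou_thres=0.6` are illustrative choices.

```python
import numpy as np


def box_iou(box, boxes):
    """IoU of one [x1, y1, x2, y2] box against an (N, 4) array of boxes."""
    def area(b):
        return (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])

    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    return inter / (area(box) + area(boxes) - inter)


def multilabel_nms(boxes, probs, conf_thres=0.3, iou_thres=0.6):
    """Return a list of (box_index, class_index, score) detections.

    boxes: (N, 4) array; probs: (N, C) class probabilities.
    """
    keep = []
    for c in range(probs.shape[1]):
        idx = np.where(probs[:, c] > conf_thres)[0]
        idx = idx[np.argsort(-probs[idx, c])]  # highest score first
        while idx.size:
            best, idx = idx[0], idx[1:]
            keep.append((int(best), c, float(probs[best, c])))
            # drop same-class candidates that overlap the kept box too much
            idx = idx[box_iou(boxes[best], boxes[idx]) <= iou_thres]
    return keep
```

For two boxes `[0, 0, 10, 10]` and `[1, 1, 11, 11]` (IoU of about 0.68) with the probabilities from the question, this sketch keeps the first box as class2 (0.7) and the second as class1 (0.7), and suppresses both 0.5 scores.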

> The loss is back to it's original form, using the PyTorch defaults, which is for example for the 3 yolo layers: loss_giou = (giou_1.mean() + giou_2.mean() + giou_3.mean()).sum()

Do you know how this changes the Delta during auto-differentiation in Pytorch?
Do you apply it only for x,y,w,h and not for probs and obj?


> Yes spatial augmentation seemed to hurt training, so I set it to zero, but left HSV augmentation on:

Yes, it may help to win competitions, but it may hurt cross-domain accuracy when test images/videos are not similar to MS COCO.

It seems to work well because Ultralytics uses letterbox image resizing by default, so it keeps the aspect ratio and doesn't require large spatial image transformations.
In Darknet we can try jitter=0.1 letter_box=1 instead of jitter=0.3 letter_box=0.
I think the higher the network resolution, the more preferable it is to use jitter=0.1 letter_box=1.
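For reference, letterbox resizing scales the image by a single factor so that it fits the network input, then pads the remainder, instead of stretching each axis independently. A minimal sketch that only computes the geometry; the square `new_size`, the helper name `letterbox_geometry`, and evenly split padding are assumptions, and the repo's own letterbox function may differ (for example by padding only up to a stride multiple).

```python
def letterbox_geometry(img_w, img_h, new_size=416):
    """Return (scale, (resized_w, resized_h), (pad_left, pad_top)) for a square letterbox.

    One scale factor is used for both axes, so the aspect ratio is preserved;
    the leftover border is split evenly between the two sides.
    """
    r = min(new_size / img_w, new_size / img_h)
    new_w, new_h = round(img_w * r), round(img_h * r)
    return r, (new_w, new_h), ((new_size - new_w) // 2, (new_size - new_h) // 2)
```

A 640x480 image into a 416x416 input is scaled by 0.65 to 416x312 and padded by 52 pixels at the top and bottom, whereas a plain resize would squash it vertically by a different factor than horizontally.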

> I'm really hoping we might be able to merge the YOLO outputs some day so I can do away with this uncertainty in how to combine the losses from the different layers.

What do you mean?

> I changed multi-scale to change the resolution every batch now, instead of every 10 batches before. This seemed to smooth the results a bit, epoch to epoch.

Does it decrease training speed, because changing of network size requires time?

If we use dynamic_minibatch=1 in Darknet, where we change width, height and mini_batch dynamically and must reallocate GPU arrays for each layer, it can decrease training speed 2-3x if we do it after every iteration.

@AlexeyAB

@WongKinYiu

Have you checked if scale_x_y=1.1 increases AP95 accuracy, while it decreases AP50 and AP75 but keeps the same AP50...95? https://github.com/WongKinYiu/CrossStagePartialNetworks/blob/master/coco/results.md#mscoco


EfficientNetB0-Yolo was added to the OpenCV-dnn module

So it only requires implementing scale_x_y=1.1 to use csresnext50-panet-spp-original-optimal.cfg with OpenCV-dnn.
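For readers unfamiliar with `scale_x_y`: as we understand Darknet's yolo layer, the box center offset inside a grid cell becomes `sigmoid(tx) * s - (s - 1) / 2` instead of plain `sigmoid(tx)`, which widens the reachable range from (0, 1) to (-(s - 1) / 2, 1 + (s - 1) / 2). Centers lying on cell borders can then be predicted without driving `tx` toward infinity. A minimal sketch; `decode_xy` is a hypothetical name and the result is in grid-cell units.

```python
import math


def decode_xy(tx, cx, scale_x_y=1.0):
    """Box center in grid-cell units: cx + sigmoid(tx) * s - (s - 1) / 2."""
    sig = 1.0 / (1.0 + math.exp(-tx))
    return cx + sig * scale_x_y - (scale_x_y - 1.0) / 2
```

With `scale_x_y=1.1`, the offset spans roughly -0.05 to 1.05 of a cell while `tx = 0` still maps to the cell center; with the default 1.0 it reduces to the standard YOLOv3 decoding.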

@WongKinYiu

I have only done experiments with scale_x_y=1.05, scale_x_y=1.1, and scale_x_y=1.2 on different feature pyramids.

Have you tested the inference speed of enetb0-yolo using OpenCV-dnn?

@AlexeyAB

> have u tested the inference speed of enetb0-yolo using opencv-dnn?

Not yet. I will test it on an Intel CPU and the Intel Myriad X neurochip.

@glenn-jocher
Member

glenn-jocher commented Mar 12, 2020

@AlexeyAB @WongKinYiu I made a simple Colab notebook to see the time effects of group/mix convolutions.

It times a tensor passing forward and backward (to mimic training) through a Conv2d() op. The speeds stay about the same even as the parameter count drops by >10X. So similar sized models using these ops may be much slower.

```
b=m(x), x=[16, 128, 38, 38], b=[16, 256, 38, 38]

    groups  time(ms)    params  shape m
         1       5.1    294912  [256, 128, 3, 3]
         2       4.2    147456  [256, 64, 3, 3]
         4       4.2     73728  [256, 32, 3, 3]
         8       4.9     36864  [256, 16, 3, 3]
        16       6.9     18432  [256, 8, 3, 3]
        32       6.1      9216  [256, 4, 3, 3]
        64       2.6      4608  [256, 2, 3, 3]
       128       2.0      2304  [256, 1, 3, 3]
```
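The `params` column follows directly from the weight shape `[c_out, c_in / groups, k, k]`: each group connects only `c_in / groups` input channels to its share of the outputs. The multiply-adds per output pixel scale the same way, so the flat timings suggest the grouped kernels run far less efficiently per operation rather than doing the same amount of work. A small sketch that reproduces the column (`grouped_conv_params` is a hypothetical helper):

```python
def grouped_conv_params(c_in, c_out, k, groups=1, bias=False):
    """Number of learnable parameters in nn.Conv2d(c_in, c_out, k, groups=groups, bias=bias)."""
    if c_in % groups or c_out % groups:
        raise ValueError('c_in and c_out must be divisible by groups')
    weights = c_out * (c_in // groups) * k * k  # weight shape [c_out, c_in // groups, k, k]
    return weights + (c_out if bias else 0)
```

For the benchmark above, `grouped_conv_params(128, 256, 3, groups=g)` gives 294912 for `g=1` down to 2304 for `g=128`, matching every row of the table.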

@AlexeyAB

@glenn-jocher
Yes, NVIDIA cuDNN works the same way.
Also, the Google Coral Edge TPU neurochip doesn't use grouped convolutions, despite the fact that they advertise EfficientDet/EfficientNet with grouped convolutions. https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html

@mmaaz60

mmaaz60 commented Mar 21, 2020

> > have u tested the inference speed of enetb0-yolo using opencv-dnn?
>
> Not yet. I will test it on Intel CPU and Intel Myraid X neurochip

Hi @AlexeyAB,

I ran the speed test of this network on the Intel CPU. It looks like it is almost 5 times slower than the Tiny Yolov3 PRN network on CPU as well. Below are the results,

OpenCV: 3.4.10-pre (https://github.com/opencv/opencv/tree/377dd04224630e835cce8c7d67e651cae73fd3b3)
CPU: Intel(R) Core(TM) i5-6200U CPU @ 2.30 GHz
Hard Drive Type: HDD
Display: Off
Yolov3-Tiny-PRN: 21.62 FPS
EfficientNetB0-Yolov3: 4.72 FPS

It looks like depthwise convolutions are slow on CPU as well. Any thoughts?

Thanks

@AlexeyAB

@glenn-jocher @mmaaz60 Take a look at the comparison: AlexeyAB/darknet#5079

@AlexeyAB

@glenn-jocher Hi, did you successfully train an ASFF model?

@glenn-jocher
Member

@AlexeyAB yes I trained ASFF on COCO (results99 in orange), but got slightly worse results in the end compared to default (blue). Performance in the first 5% of epochs was much better, probably because the summation of outputs reduced a lot of that early noise in the model, but did not help after that point.

Of course my implementation might be wrong!

[Image: results]

@AlexeyAB

@glenn-jocher
Is this a comparison of yolov3-spp.cfg and yolov3-asff.cfg? Show the asff-cfg file.
What is the network resolution?
So asff / bifpn don't increase accuracy?

@glenn-jocher
Member

glenn-jocher commented Mar 23, 2020

@AlexeyAB yes, basically. I created a 12-anchor version of yolov3-spp.cfg called yolov4.cfg (for 4 anchors per yolo layer) which I used for my comparison (this 12 anchor model increases mAP a tiny bit, about +0.1). I compared yolov4.cfg against yolov4-asff.cfg. For asff I moved all of the yolo layers to the end, and added 3 features to the existing feature vector of length 340 to create the asff weights, so the input to each yolo layer is the same: (1,343,13,13), (1,343,26,26), (1,343,52,52).

I split the feature vectors into the traditional size (1,340,13,13) and the weights (1,3,13,13) for the weighted summations:

```
yolo1 = (1,340,13,13)*(1,1,13,13) + (1,340,13,13)*(1,1,13,13) + (1,340,13,13)*(1,1,13,13)
```

etc. using this extra ASFF code. I used sigmoid weights since softmax was much slower, and did a linear interpolation for the resizing.

```python
if ASFF:
    i, n = self.index, self.nl  # index in layers, number of layers
    p = out[self.layers[i]]
    bs, _, ny, nx = p.shape  # bs, 255, 13, 13
    if (self.nx, self.ny) != (nx, ny):
        create_grids(self, img_size, (nx, ny), p.device, p.dtype)

    # outputs and weights
    # w = F.softmax(p[:, -n:], 1)  # normalized weights
    w = torch.sigmoid(p[:, -n:]) * (2 / n)  # sigmoid weights (faster)
    # w = w / w.sum(1).unsqueeze(1)  # normalize across layer dimension

    # weighted ASFF sum
    p = out[self.layers[i]][:, :-n] * w[:, i:i + 1]
    for j in range(n):
        if j != i:
            p += w[:, j:j + 1] * \
                 F.interpolate(out[self.layers[j]][:, :-n], size=[ny, nx], mode='bilinear', align_corners=False)
```
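To see the weighted ASFF sum in isolation, here is a NumPy sketch of the same arithmetic: sigmoid weights scaled by `2 / n`, taken from the last `n` channels of the current level, applied to every level after resizing it to the current grid. To stay dependency-free it swaps the bilinear `F.interpolate` for nearest-neighbour resizing, so values differ from the code above; `asff_fuse` and `nearest_resize` are hypothetical names.

```python
import numpy as np


def nearest_resize(x, ny, nx):
    """Nearest-neighbour resize of an NCHW array (stand-in for bilinear F.interpolate)."""
    _, _, h, w = x.shape
    rows = np.arange(ny) * h // ny
    cols = np.arange(nx) * w // nx
    return x[:, :, rows][:, :, :, cols]


def asff_fuse(outputs, i):
    """Fuse level i from all levels; each output carries n trailing weight channels."""
    n = len(outputs)
    p = outputs[i]
    ny, nx = p.shape[2:]
    w = 1 / (1 + np.exp(-p[:, -n:])) * (2 / n)  # sigmoid weights, as in the code above
    fused = p[:, :-n] * w[:, i:i + 1]
    for j in range(n):
        if j != i:
            fused = fused + w[:, j:j + 1] * nearest_resize(outputs[j][:, :-n], ny, nx)
    return fused
```

The fused output has shape `(batch, channels - n, ny, nx)`, i.e. `(1, 340, 13, 13)` for the first level in the 343-channel setup described above.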

Training was multi-scale 288-640, with metrics plotted at 608 img-size. So no, so far I haven't been able to increase accuracy with BiFPN or ASFF. The only thing that improved a tiny bit was weighted feature fusion, but the gain was tiny (0.1 mAP).

yolov4.cfg.txt
yolov4-asff.cfg.txt

@AlexeyAB

@glenn-jocher So do you get AP50...95 higher than 40.6 - 42.4% for ASFF 608x608? https://github.com/ruinmessi/ASFF#coco

It seems that ASFF+RFB or multi-block-BiFPN should use higher network resolution for higher accuracy.

@glenn-jocher
Member

@AlexeyAB no, I actually saw worse results for my ASFF implementation, about -0.5 mAP at 608 vs the default yolov4.cfg.

Higher image size is definitely one of the ingredients in higher mAPs. EfficientDet uses 512@D0, 640@D1, all the way to 1280@D7:
https://github.com/google/automl/blob/3d88847cc18c69d194490f039279502ddcb536f2/efficientdet/hparams_config.py#L199

The official ASFF trains at 320-608 for 42.4@608 and 480-800 for 43.9@800. https://github.com/ruinmessi/ASFF#models

@github-actions

This issue is stale because it has been open 30 days with no activity. Remove Stale label or comment or this will be closed in 5 days.


10 participants