
Can u get the mAP as reported in darknet ? #9

Closed

xuzheyuan624 opened this issue Sep 10, 2018 · 41 comments

@xuzheyuan624

No description provided.

@glenn-jocher
Member

glenn-jocher commented Sep 10, 2018

Issue #7 is open on this topic. The mAP calculation is still under development. The current mAP calculation using this repo with the official yolov3 weights is 56.7 compared to 57.9 in Darknet.

@xuzheyuan624
Author

xuzheyuan624 commented Sep 13, 2018

I use your code with the official yolov3 weights and the COCO API to calculate AP on the 5k split, and I get AP[IoU=0.50] = 0.533 when the image size is 416x416. Here are the full results.

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.299
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.533
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.304
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.153
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.326
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.425
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.264
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.390
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.406
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.236
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.428
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.553

@xuzheyuan624
Author

xuzheyuan624 commented Sep 13, 2018

When I change the image size to 608x608, I get these results:

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.303
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.538
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.307
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.189
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.335
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.387
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.269
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.406
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.423
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.281
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.449
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.527

@glenn-jocher
Member

glenn-jocher commented Sep 13, 2018

@xuzheyuan624 oh thanks for using the official COCO mAP! Do you have code to generate the COCO input text files for the mAP calculation? Maybe we can update test.py to generate those also.

I've been updating the mAP code recently, so I'm not sure if you are using the most recent version. Can you tell me the equivalent mAP you end up with if you run test.py for the 5000 images? I get 0.567.

@glenn-jocher
Member

glenn-jocher commented Sep 21, 2018

@xuzheyuan624 could you share the code you used to generate the COCO text files used to get the official COCO mAP? It would be very useful if I could integrate that into test.py, then we could finally close this issue. Thanks!

@xuzheyuan624
Author

I have modified some of the code (I wrote my own dataloader), but you can refer to it to see how to calculate the mAP with the COCO API:
import os
import json
from json import encoder
encoder.FLOAT_REPR = lambda o: format(o, '.2f')
import sys
import argparse

import numpy as np
import torch

from models import *
from utils.datasets import *
from utils.utils import *
from common.coco_dataset import COCODataset

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

parser = argparse.ArgumentParser()
parser.add_argument('-batch_size', type=int, default=4, help='size of each image batch')
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
parser.add_argument('-weights_path', type=str, default='checkpoints/yolov3.weights', help='path to weights file')
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file')
parser.add_argument('-iou_thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
parser.add_argument('-conf_thres', type=float, default=0.01, help='object confidence threshold')
parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
parser.add_argument('-n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation')
parser.add_argument('-img_size', type=int, default=608, help='size of each image dimension')
parser.add_argument('-use_cuda', type=bool, default=True, help='whether to use cuda if available')
parser.add_argument('-annotation_path', type=str, default='./data/coco/annotations/instances_val2014.json', help='')
opt = parser.parse_args()
print(opt)

cuda = torch.cuda.is_available() and opt.use_cuda
device = torch.device('cuda:0' if cuda else 'cpu')

data_config = parse_data_config(opt.data_config_path)
num_classes = int(data_config['classes'])
if sys.platform == 'darwin':
    test_path = data_config['valid']
else:
    test_path = './data/coco/new_5k.txt'

model = Darknet(opt.cfg, opt.img_size)

# Load weights from either a Darknet .weights file or a PyTorch .pt checkpoint
if opt.weights_path.endswith('.weights'):
    load_weights(model, opt.weights_path)
elif opt.weights_path.endswith('.pt'):
    checkpoint = torch.load(opt.weights_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    del checkpoint

model.to(device).eval()

dataloader = torch.utils.data.DataLoader(COCODataset(test_path,
                                                     (opt.img_size, opt.img_size),
                                                     is_training=False),
                                         batch_size=opt.batch_size,
                                         shuffle=False, num_workers=8, pin_memory=False)

# Mapping from contiguous class indices (0-79) to original COCO category ids
index2category = json.load(open("coco_index2category.json"))

print('Start evaluating')
coco_results = []
coco_img_ids = set()
for step, samples in enumerate(dataloader):
    images = samples["image"].to(device)
    labels = samples["label"]
    image_paths, origin_sizes = samples["image_path"], samples["origin_size"]
    with torch.no_grad():
        output = model(images)
        output = non_max_suppression(output, conf_thres=opt.conf_thres, nms_thres=opt.nms_thres)
    for idx, detections in enumerate(output):
        image_id = int(os.path.basename(image_paths[idx])[-16:-4])
        coco_img_ids.add(image_id)
        if detections is not None:
            origin_size = eval(origin_sizes[idx])
            detections = detections.cpu().numpy()
            for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
                # Scale boxes from network input size back to the original image size
                x1 = x1 / opt.img_size * origin_size[0]
                x2 = x2 / opt.img_size * origin_size[0]
                y1 = y1 / opt.img_size * origin_size[1]
                y2 = y2 / opt.img_size * origin_size[1]
                w = x2 - x1
                h = y2 - y1
                coco_results.append({
                    "image_id": image_id,
                    "category_id": index2category[str(int(cls_pred.item()))],
                    "bbox": (float(x1), float(y1), float(w), float(h)),
                    "score": float(conf),
                })
    print("Now {}/{}".format(step, len(dataloader)))

save_path = "coco_results.json"
with open(save_path, 'w') as f:
    json.dump(coco_results, f, sort_keys=True, indent=4, separators=(',', ':'))

print("Using coco-evaluate tools to evaluate.")
cocoGT = COCO(opt.annotation_path)
cocoDT = cocoGT.loadRes(save_path)
cocoEval = COCOeval(cocoGT, cocoDT, "bbox")
cocoEval.params.imgIds = list(coco_img_ids)
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
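
For reference, assuming the script above is saved as eval_coco.py (the filename is just an example), it could be run on the 5k split with something like:

python eval_coco.py -weights_path checkpoints/yolov3.weights -img_size 608 -batch_size 4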

@nirbenz

nirbenz commented Oct 15, 2018

@xuzheyuan624 I noticed that your code doesn't shift the boxes back to original image coordinates from before the letterboxing process. Is this done somewhere else internally?

@xuzheyuan624
Author

xuzheyuan624 commented Oct 15, 2018

@nirbenz
x1 = x1 / opt.img_size * origin_size[0]
x2 = x2 / opt.img_size * origin_size[0]
y1 = y1 / opt.img_size * origin_size[1]
y2 = y2 / opt.img_size * origin_size[1]

I also wrote a PyTorch implementation recently; you can see my code for details.

@nirbenz

nirbenz commented Oct 15, 2018

@xuzheyuan624 this code takes care of scaling back, but what about letterboxing? I mean the process of fitting the image inside a 608x608 square.
If you are not doing this - this might be the reason you are getting an mAP difference!

This is usually done by resizing the larger dimension to 608 (and smaller dimension accordingly) and then padding to the aspect ratio the model was trained on (for COCO - 1x1). @glenn-jocher's original code reverses this process by subtracting the amount of padding done during inference.
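
For concreteness, here is a minimal sketch of that letterbox transform and its inverse; the function names, padding colour and rounding are illustrative assumptions, not this repo's exact code:

import cv2
import numpy as np

def letterbox(img, new_size=608, color=(127.5, 127.5, 127.5)):
    # Resize so the longer side equals new_size, then pad the shorter side
    # so the result is a new_size x new_size square.
    h, w = img.shape[:2]
    r = new_size / max(h, w)                      # scale ratio
    nh, nw = int(round(h * r)), int(round(w * r))
    top = (new_size - nh) // 2
    left = (new_size - nw) // 2
    img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
    img = cv2.copyMakeBorder(img, top, new_size - nh - top, left, new_size - nw - left,
                             cv2.BORDER_CONSTANT, value=color)
    return img, r, (left, top)

def unletterbox_boxes(boxes, r, pad):
    # Map (x1, y1, x2, y2) boxes from network coordinates back to the
    # original image by removing the padding and undoing the scaling.
    boxes = np.asarray(boxes, dtype=np.float32).copy()
    boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad[0]) / r
    boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad[1]) / r
    return boxes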

@xuzheyuan624
Author

@nirbenz Oh, thanks for reminding me of that. Indeed, I wrote my own dataloader when I tested this code, and in that dataloader I didn't pad when resizing the image to 608x608. Maybe that's the reason I got a different mAP; I will try again while keeping the aspect ratio.

@nirbenz

nirbenz commented Oct 15, 2018

@xuzheyuan624 That would give you aspect ratio invariance, which is probably not what you want for COCO. I wouldn't be surprised if modifying this will get you much closer to the target mAP :)

@glenn-jocher
Member

glenn-jocher commented Oct 15, 2018

@xuzheyuan624 @nirbenz the dataloader in utils/datasets.py augments and pads the images into -img_size pixel squares. It augments the bounding boxes (the targets) to match the augmented image. You can plot these yourself here. I think this is all good no?

yolov3/utils/datasets.py

Lines 157 to 163 in 24a4197

plotFlag = False
if plotFlag:
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 10)) if index == 0 else None
    plt.subplot(4, 4, index + 1).imshow(img[:, :, ::-1])
    plt.plot(labels[:, [1, 3, 3, 1, 1]].T, labels[:, [2, 2, 4, 4, 2]].T, '.-')
    plt.axis('off')

[image: augmented training images with plotted bounding boxes]

A second example: [image: coco_augmentation_examples]

@nirbenz

nirbenz commented Oct 16, 2018

@glenn-jocher your implementation definitely does it (although for custom datasets I would expand it to support different aspect ratios - e.g. when width and height aren't the same). But I believe the dataloader @xuzheyuan624 was referring to doesn't letterbox. This could account for a rather large difference in mAP.

As for how letterboxing is performed along with data augmentation - it seems that among the examples you sent there are a few that could be cropped further while keeping the correct AR - for example, this image:

[image: augmentation example]

I am not too sure how the original Darknet implementation does its augmentations, but I would assume the image undergoes all the necessary augmentation and then the minimal bounding box (minx, miny, maxx, maxy over a rotated/skewed rectangle) is used as a reference for letterboxing. This would ensure that you are as tight as possible around the image boundaries and that the grey area is kept to a minimum.

@nirbenz

nirbenz commented Oct 31, 2018

@glenn-jocher I am reviving this to note that even if training succeeds, I wouldn't be surprised if the different letterboxing (which is wasteful in image real estate) caused a slight difference in overall mAP. I will try to fix this code and open a PR once I get to training myself.

@nirbenz

nirbenz commented Nov 4, 2018

@xuzheyuan624 @glenn-jocher reviving this issue. Other than different aspect ratios, something else that should be taken into account to get performance comparable to Darknet training is different resolutions.

When training YOLOv3 with Darknet, not only is the image AR changed but the final image resolution is also changed (i.e., without letterboxing). A trained YOLOv3 model naturally supports every image for which both height and width are multiples of 32. This is used during training to get better robustness to varying object sizes. This is also the reason the original YOLOv3 model is the same one for all three resolutions in the paper (320, 416, 608).
When training with Darknet it would seem the model both upscales and downscales height and width to achieve this.

I am in the process of implementing this but I wonder if anyone else already did it. @xuzheyuan624 , if I recall correctly your trained model achieves a final mAP not too far from the original implementation - correct?

I'll note that the existing code performs minor scaling (between 0.8 and 1.2), but this isn't in the same range as the original model and would (to the best of my understanding) crop the original image and its bounding boxes. This is also useful but should be performed regardless of (and after) what I just described.

@glenn-jocher
Member

@nirbenz I think you are referring to the multi-scale training. I have an implementation of this currently commented out in train.py lines 102-106. If these lines are uncommented then each epoch will train on a random image size from 320 - 608. Aspect ratio is a different matter though; it should always be constant. For example, if you increase the height of an image by 50% you must also increase its width by 50%.

# Multi-Scale YOLO Training
img_size = random.choice(range(10, 20)) * 32  # 320 - 608 pixels
dataloader = load_images_and_labels(train_path, batch_size=opt.batch_size, img_size=img_size, augment=True)
print('Running this epoch with image size %g' % img_size)

To work properly, img_size must be initialised to the maximum size that might be encountered, 608:

parser.add_argument('-img_size', type=int, default=608, help='size of each image dimension')

@glenn-jocher
Member

@xuzheyuan624 @nirbenz Commit dc7b58b updates train.py to use multi-scale training by default, though this can be turned off by setting opt.multi_scale = False in the argparser. The bigger question is what effect this has on the mAP at 416, 608 etc. I'll try and start a fresh GCP session to train this from scratch (which now uses weights/darknet53.conv.74 to initialize the first 74 layers too by default).

# Multi-Scale YOLO Training
if opt.multi_scale:
    img_size = random.choice(range(10, 20)) * 32  # 320 - 608 pixels
    dataloader = load_images_and_labels(train_path, batch_size=opt.batch_size, img_size=img_size, augment=True)
    print('Running Epoch %g at multi_scale img_size %g' % (epoch, img_size))

@xuzheyuan624
Author

I use multi-scale training like this, @glenn-jocher @nirbenz:

        if self.is_training and index % self.batch_size == 0:
            if self.seed < 4000:
                self.img_size = 13 * 32
            elif self.seed < 8000:
                self.img_size = (random.randint(0, 3) + 13) * 32
            elif self.seed < 12000:
                self.img_size = (random.randint(0, 5) + 12) * 32
            elif self.seed < 16000:
                self.img_size = (random.randint(0, 7) + 11) * 32
            else:
                self.img_size = (random.randint(0, 9) + 10) * 32
            self.update()

@okanlv

okanlv commented Nov 5, 2018

@glenn-jocher In the original yolov3 code, the image size is changed every 10 batches.

@glenn-jocher
Member

@okanlv ah yes, now I remember why I did it only once per epoch. The torch.backends.cudnn.benchmark flag speeds up GPU operations by about 20%, but it has problems with varying input sizes: whenever img_size changes it must re-optimize, which becomes very slow if done every few batches. If I turn it off then I can change img_size every batch, though then my batch_size must be reduced from 16 to 8 to keep CUDA from running out of memory at the maximum image size of 608 pixels. I can still update weights every 16 images, though these two changes make the training take about 2X longer.
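
For illustration, this is the flag in question (a minimal sketch, not the repo's exact training code):

import torch

# cudnn.benchmark caches the fastest convolution kernels for a fixed input
# shape; if img_size changes every few batches it re-tunes each time, which
# is slow, so it is safer to disable it when input sizes vary.
torch.backends.cudnn.benchmark = False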

Does darknet vary the image size throughout training or just in the final few epochs?

@okanlv

okanlv commented Nov 6, 2018

@glenn-jocher Hmm, that is a valid point. Darknet changes the image size throughout training according to the following rule. l.random is 1, so it is just a flag that determines whether or not to apply multi-scale training. get_current_batch(net) gives how many batches the model has seen. So up until the last 200 batches the image size changes every 10 batches; for the last 200 batches the image size is kept the same (608x608).

    if(l.random && count++%10 == 0){
        printf("Resizing\n");
        int dim = (rand() % 10 + 10) * 32;
        if (get_current_batch(net)+200 > net->max_batches) dim = 608;
        //int dim = (rand() % 4 + 16) * 32;
        printf("%d\n", dim);
        args.w = dim;
        args.h = dim;

I thought changing the image size every 10 batches prevents the model from overfitting to a particular image size. However, your method might work as well imo.
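
For readers more comfortable with Python, a rough equivalent of that Darknet rule (function and variable names are illustrative) might look like:

import random

def darknet_multi_scale_size(batches_seen, max_batches, current_dim):
    # Every 10 batches pick a new square size from {320, 352, ..., 608};
    # lock the size to 608 for the last 200 batches of training.
    if batches_seen % 10 == 0:
        dim = (random.randint(0, 9) + 10) * 32
        if batches_seen + 200 > max_batches:
            dim = 608
        return dim
    return current_dim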

@nirbenz

nirbenz commented Nov 6, 2018

@glenn-jocher @xuzheyuan624 Thanks, I missed that since I was looking for it in the dataloader rather than outside of it. Which really raises the question - wouldn't this be easier, and make more sense, in the dataloader itself?
Unless I'm missing something, the model is by definition robust to varying image sizes (as multiples of 32), and the only thing that matters is the code that actually resizes (and letterboxes) the image (before applying affine transforms on top of that).

@nirbenz

nirbenz commented Nov 6, 2018

@okanlv Are you sure this is the relevant piece of code from Darknet training? It would impose constant aspect ratio images, which is clearly not the case (I've been using Darknet to train non-square native models for a while now).

@okanlv

okanlv commented Nov 6, 2018

@nirbenz Yes, this is the multi-scale training part of Darknet. In order to deal with different image sizes, the layers are resized as follows:

        for(i = 0; i < ngpus; ++i){
            resize_network(nets[i], dim, dim);
        }

In practice, images with different aspect ratios are padded to the same aspect ratio to train on the GPU.

@nirbenz

nirbenz commented Nov 6, 2018

I am not sure this is the whole picture - the configured (from the cfg file) width/height do affect how anchor locations are used in practice (both in training and inference).
I have actually expanded @glenn-jocher 's code to support different width/height configurations and this was the only way I managed to reproduce results from a Darknet-trained model with similar properties.

So while the model is naturally aspect-ratio invariant internally, if the width/height are configured to be non-square then letterboxing won't be performed at all and the model is trained on actual rectangular images. This isn't relevant to MS-COCO training but is relevant to custom datasets (which is what I'm facing, hence needing to expand the original code).

@sporterman

sporterman commented Nov 7, 2018

@okanlv Hello, I ran into some trouble while using this code to train on my own dataset; the data includes 10 classes, and the training results were as follows:
[image: training results screenshot]
Precision and recall never change. Can you give some advice? Thanks.

@okanlv

okanlv commented Nov 7, 2018

@sporterman Hi, I have not trained on another dataset, but I have summarized the necessary steps for the VOC dataset in this comment. You should follow similar steps for your dataset.

@sporterman

@okanlv Appreciated! Actually, I want to know whether this code will work for my small dataset, given the many troubles you guys have met. I've searched for YOLOv3 implementations, but almost all the blogs are about how to train YOLOv3 with the official Darknet code. Up to now I haven't found any PyTorch code that implements YOLOv3 stably. Any advice?

@glenn-jocher
Member

@sporterman yes, this repository will work for any dataset, not just COCO; just follow @okanlv's comment, which has excellent directions. I used it for the xView challenge this summer. You can see training results and example inference results here: https://github.com/ultralytics/xview-yolov3

@sanmianjiao

@glenn-jocher so now can we say you get the mAP as reported in Darknet?

@glenn-jocher
Member

glenn-jocher commented Dec 11, 2018

@sanmianjiao It seems not quite. The latest commit produces 0.52 mAP around epoch 62, at 416 x 416. Darknet's reported mAP is 0.58 at 608 x 608 (the paper does not report at 416 pixels). I have not tried to train fully with --multi-scale enabled, however, or to simply train at --img-size 608.

[image: training results plot]

@sanmianjiao

@glenn-jocher so you get mAP 0.52 at epoch 62 with size 416*416, and the dataset is train2014+val2014-5k? And when we use the 5k split and the yolo_weights.pt provided by the author to evaluate the mAP, we can get 0.58 as you said in the README? Forgive my poor English~

@sanmianjiao

@glenn-jocher and I want to add that the paper says mAP50 is 55.3 at a size of 416*416.
[image: results table from the YOLOv3 paper]

@glenn-jocher
Member

glenn-jocher commented Dec 12, 2018

@sanmianjiao ah yes, you are correct. And yes, this is on the COCO2014 split you mentioned. So the proper comparison is Darknet's 55.3 to this repo's current 52.2. That is not so far apart; I'm happy to see that. I'm still running experiments to improve the mAP as well, but it is slow going as I only have one GPU, so hopefully I can raise this 52.2 a little higher in the coming weeks and months.

@nirbenz

nirbenz commented Dec 12, 2018

@glenn-jocher there is no real difference between COCO14 and COCO17, other than (and this is a big thing) the separation of validation and train sets. With COCO14 there is a roughly 50-50 split, so common practice is to merge them and choose a small subset for testing. This is what 5k.part and trainvalno5k.part are.

For COCO17 the dataset is already split that way (train+val and test). But because of that, evaluating YOLOv3 (using the original weights) must be done on the 5k split made by the author; otherwise you are probably testing on part of the training set.

Apologies if this is well known already, but I thought it was important to clarify.

@codingantjay

codingantjay commented Dec 15, 2018


Hi Glenn,

Thanks for sharing this repo. I noticed there's a difference between this repo and Darknet during training that may impact performance. In build_targets you use an arbitrary threshold (0.1) to skip anchors that are not good enough, whereas in Darknet the (0, 0)-centered IOU is calculated between all anchors (9 in total) and the target. Only when the best of all 9 anchors belongs to the current yolo layer does the prediction join training. In short, a target is only assigned to 1 yolo layer and 1 anchor. In this repo, a target can be assigned to multiple yolo layers, and all of them calculate loss and gradients, which could affect the training significantly.

https://github.com/pjreddie/darknet/blob/61c9d02ec461e30d55762ec7669d6a1d3c356fb2/src/yolo_layer.c#L205-L214

https://github.com/pjreddie/darknet/blob/61c9d02ec461e30d55762ec7669d6a1d3c356fb2/src/yolo_layer.c#L216-L217

https://github.com/pjreddie/darknet/blob/61c9d02ec461e30d55762ec7669d6a1d3c356fb2/src/utils.c#L633-L640
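
To make the Darknet behaviour described above concrete, here is a hedged sketch; the helper names and the assumption that anchors 0-2, 3-5 and 6-8 belong to the three yolo layers are illustrative, not this repo's code:

import torch

def wh_iou(box_wh, anchors_wh):
    # IOU between one target (w, h) tensor of shape (2,) and each anchor in a
    # (9, 2) tensor, all centred at (0, 0), so only width/height overlap matters.
    inter = torch.min(box_wh[0], anchors_wh[:, 0]) * torch.min(box_wh[1], anchors_wh[:, 1])
    union = box_wh[0] * box_wh[1] + anchors_wh[:, 0] * anchors_wh[:, 1] - inter
    return inter / union

def assign_best_anchor(box_wh, anchors_wh):
    # The target is assigned to the single best of all 9 anchors; only the yolo
    # layer that owns that anchor contributes loss for this target.
    best = torch.argmax(wh_iou(box_wh, anchors_wh)).item()
    layer_idx, anchor_in_layer = divmod(best, 3)
    return layer_idx, anchor_in_layer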

@glenn-jocher
Member

@codingantjay yes you are correct, this repo sets an arbitrary lower threshold (0.1 IOU) for rejecting potential anchors within each of the 3 yolo layers. This is a tunable parameter that I set a while back after some trial and error, though the repo was in a substantially different state at the time, so perhaps it needs retuning.

You are also correct that this means an object can be assigned to multiple anchors, perhaps even 3 times, one in each layer. I did not know Darknet was only assigning an object to one of the 9 anchors. This would be difficult to replicate in this repo, as each of the yolo layers creates its own independent loss function, though you could try to do this and to tune the rejection threshold (I would vary it between 0.0 and 0.3), and if any of these work I'd be all ears! Unfortunately I only have one GPU and limited resources to devote to further improving the repo. Any help is appreciated!

@nirbenz

nirbenz commented Dec 20, 2018

@glenn-jocher @codingantjay This would indeed necessitate rebuilding the YOLO layer to be shareable across the network rather than repeated three times. I would assume it shouldn't be too difficult to change the chosen anchors in a forward/backward pass depending on an argument passed to forward. I believe something similar is performed in the original implementation although I haven't checked thoroughly (@codingantjay have you?).

In other words, the current behavior of forward (and part of the constructor code) would be conditioned on an argument passed to YOLOLayer's forward function. Based on the current implementation of Darknet's forward function, this could easily be a counter of the number of previous calls (out of a total of 3).

I will have a go at this soon.
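
A rough sketch of that idea (class structure and argument names are hypothetical, not the current implementation):

import torch.nn as nn

class SharedYOLOLayer(nn.Module):
    # One module reused for all three detection scales; the anchor group in
    # use is selected by the layer index passed to forward.
    def __init__(self, all_anchors):
        super().__init__()
        self.all_anchors = all_anchors  # list of 9 (w, h) anchor pairs

    def forward(self, x, layer_index):
        anchors = self.all_anchors[layer_index * 3:(layer_index + 1) * 3]
        # ... decode x against `anchors`; during training, build targets with
        # the global best-of-9-anchors rule so each object is assigned to
        # exactly one layer/anchor ...
        return x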

@nirbenz

nirbenz commented Dec 28, 2018

whereas in darknet, the (0, 0) center IOU is calculated between all anchors (here the total number of anchors is 9) and the target

@codingantjay I'm actually not 100% sure I understand this part, can you perhaps clarify what you meant there?

@glenn-jocher
Member

glenn-jocher commented Feb 27, 2019

@xuzheyuan624 @codingantjay @sanmianjiao @sporterman pycocotools mAP is 0.550 (416) and 0.579 (608) with yolov3.weights in the latest commit. See #71 (comment) for more info.

sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
sudo rm -rf cocoapi && git clone https://github.com/cocodataset/cocoapi && cd cocoapi/PythonAPI && make && cd ../.. && cp -r cocoapi/PythonAPI/pycocotools yolov3
cd yolov3
...
python3 test.py --save-json --conf-thres 0.005
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.308
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.550
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.313
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.143
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.339
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.448
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.266
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.398
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.417
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.226
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.456
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.572
...
python3 test.py --save-json --conf-thres 0.005 --img-size 608 --batch-size 16
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.328
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.579
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.341
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.196
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.359
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.425
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.279
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.423
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.444
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.293
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.472
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.557

@vim5818

vim5818 commented Aug 18, 2020

(quoting @xuzheyuan624's 416x416 results posted earlier in this thread)

@xuzheyuan624 can you please share the commands and steps you followed to evaluate your model?
