Add a dataloader like ultralytics in the detection pipeline #92

Open · zhiqwang opened this issue Apr 19, 2021 · 5 comments
Labels: enhancement (New feature or request), good first issue (Good for newcomers)
zhiqwang (Owner) commented Apr 19, 2021

> The outputs are great, although not the same as yolov5; maybe some pre-processing/post-processing steps are different.

That's a great catch! I think it is caused by differences in the pre-processing operations. I've verified the post-processing stages before; they produce the same results as ultralytics/yolov5 (when predicting without TTA). I've also uploaded a notebook to verify the model inference and post-processing stages, but it is a bit outdated now and I haven't had time to update it. I plan to add a dataloader to the predict pipeline so that yolort detects the same results as ultralytics.

> Looks like you forgot to convert the color space from BGR to RGB in your "inference-pytorch-export-libtorch.ipynb"?

My impression is that ultralytics uses the BGR channel order by default, but I am not sure and it needs a double check. It also seems that the default image dataloader uses the RGB channel order, and if you pass an image path to the model and call model.predict('image_path') to detect an image, the result will be wrong; this also needs further verification.
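
For reference, OpenCV decodes images to BGR, so feeding an RGB-trained model requires a conversion along these lines (standard OpenCV usage, not code from this thread; 'image.jpg' stands in for any input path):

import cv2

img_bgr = cv2.imread('image.jpg')                    # OpenCV decodes to BGR
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)   # channel order expected by RGB-trained models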

BTW, all PRs are welcome here.

Originally posted by @stereomatchingkiss and @zhiqwang in #90 (comment)

zhiqwang added the enhancement and good first issue labels on Apr 19, 2021
stereomatchingkiss (Contributor) commented Apr 20, 2021

> My impression is that ultralytics uses the BGR channel order by default, but I am not sure and it needs a double check. It also seems that the default image dataloader uses the RGB channel order,

I studied the source code of the detection pipeline (v4.0); it uses the RGB channel order:

utils/datasets.py, line 187:

img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416

> and if you pass an image path to the model and call model.predict('image_path') to detect an image, the result will be wrong; this also needs further verification.

I input tensors; the following is my code for a customized model:

import cv2
import os
import torch

from yolort.utils import cv2_imshow, get_image_from_url, read_image_to_tensor
from yolort.utils.image_utils import color_list, plot_one_box, letterbox

from yolort.models import yolov5s

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = yolov5s(pretrained=False, score_thresh=0.25, num_classes=1)
model_basename = 'yolov5s_updated'
ckpt = torch.load(model_basename + '.pt')
model.load_state_dict(ckpt)  # load the fine-tuned weights (assuming the checkpoint stores a plain state_dict)
model.eval()
model = model.to(device)

save_at = 'results'
os.makedirs(save_at, exist_ok=True)

root = 'C:/Users/yyyy/programs/Qt/3rdLibs/pytorch_projects/yolov5/train_yolov5/archive/fire_dataset/fire_images/'
img_list = os.listdir(root)
colors = color_list()
labels = ['danger']

def preprocess_image(img_raw):
    # Resize and pad to 640x640, then convert BGR (OpenCV default) to RGB
    img_raw = letterbox(img_raw, new_shape=(640, 640))[0]
    img_rgb = cv2.cvtColor(img_raw, cv2.COLOR_BGR2RGB)
    img = read_image_to_tensor(img_rgb)
    img = img.to(device)

    return img, img_raw

def generate_output(model):
    with torch.no_grad():
        for i in range(len(img_list)):
            img_raw = cv2.imread(root + img_list[i])
            img, img_raw = preprocess_image(img_raw)
            # Perform inference on an image tensor
            model_out = model.predict(img)
            for box, label in zip(model_out[0]['boxes'].tolist(), model_out[0]['labels'].tolist()):
                img_raw = plot_one_box(box, img_raw, color=colors[label % len(colors)], label=labels[label])
            # Write the annotated image once, after all boxes are drawn
            cv2.imwrite(save_at + "/" + str(i).zfill(5) + '.jpg', img_raw)

generate_output(model)

# TorchScript export
print(f'Starting TorchScript export with torch {torch.__version__}...')
export_script_name = model_basename + '.torchscript.pt'  # filename

model_script = torch.jit.script(model)
model_script.eval()
model_script = model_script.to(device)

# Save the scripted model file for subsequent use (Optional)
model_script.save(export_script_name)

img = cv2.imread(root + img_list[0])
img, _ = preprocess_image(img)
x = [img]
out = model.predict(img)  # model(img) does not work; see the error below
out_script = model_script(x)

for k, v in out[0].items():
    torch.testing.assert_allclose(out_script[0][k], v, rtol=1e-07, atol=1e-09)

print("Exported model has been tested with libtorch, and the result looks good!")

If I change it to model(img), it outputs this error message:

images is expected to be a list of 3d tensors of shape [C, H, W], got torch.Size([640, 640])

zhiqwang (Owner) commented Apr 20, 2021

> I studied the source code of the detection pipeline (v4.0); it uses the RGB channel order.

Thanks, we should fix this! Would you be willing to help fix it?

> If I change it to model(img), it outputs this error message:
> images is expected to be a list of 3d tensors of shape [C, H, W], got torch.Size([640, 640])

You could call the model as model([img]) or model(img[None]) if you want to use the vanilla model.
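
In other words (a minimal sketch reusing the model and img from your snippet above):

# Option 1: wrap the 3D [C, H, W] tensor in a list, as the vanilla model expects
out = model([img])

# Option 2: add a batch dimension to get a 4D [N, C, H, W] tensor
out = model(img[None])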

BTW, the model loaded from yolort.models.yolov5s already contains a pre-processing transform, so since you have implemented preprocess_image yourself, you should load the model from yolort.models.yolo.yolov5_darknet_pan_s_r40 instead; check the following for more details.

https://github.com/zhiqwang/yolov5-rt-stack/blob/0248f9762ff25f71a08a1db4930222ea21112080/yolort/models/yolo.py#L257

This is why I use from models.yolo import yolov5_darknet_pan_s_r31 as yolov5s in the notebooks; you can now load it as below:

from yolort.models.yolo import yolov5_darknet_pan_s_r40 as yolov5s

Besides, the letterbox in ultralytics and GeneralizedYOLOTransform play the same role as dynamic batch/shape dataloaders. However, letterbox cannot be jit-scripted (torch.jit does not support OpenCV functions at the moment); this is why I introduced GeneralizedYOLOTransform here, and it is also the source of the accuracy gap between ultralytics and yolort.
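
To make the split concrete, here is a minimal sketch of the two entry points (the constructor names come from the snippets above; that the bare model expects an already letterboxed input is my reading of this discussion):

import torch

from yolort.models import yolov5s
from yolort.models.yolo import yolov5_darknet_pan_s_r40

# Entry point 1: yolov5s bundles GeneralizedYOLOTransform, so raw RGB
# tensors of arbitrary size are letterboxed and batched inside the model.
model = yolov5s(pretrained=True)
model.eval()
raw = torch.rand(3, 480, 640)  # any H x W
detections = model([raw])

# Entry point 2: the bare darknet model skips the built-in transform;
# whatever you feed it must already be letterboxed to the expected size
# (e.g. by a preprocess_image like yours).
bare_model = yolov5_darknet_pan_s_r40(pretrained=True)
bare_model.eval()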

stereomatchingkiss (Contributor) commented Apr 20, 2021

> Thanks, we should fix this! Would you be willing to help fix it?

Sure. Thanks for your help.

> BTW, the model loaded from yolort.models.yolov5s already contains a pre-processing transform, so since you have implemented preprocess_image yourself, you should load the model from yolort.models.yolo.yolov5_darknet_pan_s_r40 instead; check the following for more details.

Thanks for the tips, I will try it and report the results.

> this is why I introduced GeneralizedYOLOTransform here, and it is also the source of the accuracy gap between ultralytics and yolort.

Thanks for the explanations. I always get confused by the jit limitations; there are no docs explaining how to design a network that suits jit trace/script, so I work almost entirely by trial and error.
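
For what it's worth, a tiny illustration of the limitation (my own example, not from the thread): pure tensor code scripts fine, while anything routed through OpenCV does not.

import torch

@torch.jit.script
def scale_boxes(boxes: torch.Tensor, ratio: float) -> torch.Tensor:
    # Pure torch ops compile under TorchScript without trouble
    return boxes * ratio

# A letterbox built on cv2.resize cannot be decorated this way:
# TorchScript only understands a subset of Python plus torch ops,
# so OpenCV calls must stay outside the scripted graph.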

zhiqwang (Owner) commented Apr 20, 2021

Recently I've added COCO metrics evaluation methods; you can use them as below. (A double check is also needed here, so I haven't added documentation yet, but you can use it to test the dataloader.)

from pathlib import Path

import torch

from yolort.data import COCOEvaluator
from yolort.data.coco import COCODetection
from yolort.data.transforms import default_val_transforms, collate_fn
from yolort.data._helper import get_coco_api_from_dataset
from yolort.models import yolov5s

device = torch.device('cuda')

# Setup the coco dataset and dataloader for validation
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco_path = data_path / 'mscoco' / 'coco2017'
image_root = coco_path / 'val2017'
annotation_file = coco_path / 'annotations' / 'instances_val2017.json'

# Define the dataloader
batch_size = 16
val_dataset = COCODetection(image_root, annotation_file, default_val_transforms())
# We adopt the sequential sampler in order to repeat the experiment
sampler = torch.utils.data.SequentialSampler(val_dataset)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size,
    sampler=sampler,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=12,
)

coco_gt = get_coco_api_from_dataset(val_dataset)
coco_evaluator = COCOEvaluator(coco_gt)

# Model Definition and Initialization
model = yolov5s(
    pretrained=True,
    min_size=640,
    max_size=640,
    score_thresh=0.001,
)
model = model.eval()
model = model.to(device)

# COCO evaluation
for images, targets in val_dataloader:
    images = [image.to(device) for image in images]
    preds = model(images)
    coco_evaluator.update(preds, targets)

results = coco_evaluator.compute()

# Format the results
coco_evaluator.derive_coco_results()

The mAP I get here is currently 36.3; for comparison, the mAP of ultralytics is 36.9 (tested on ultralytics release v3.1). It seemed there was a bug in this script because inference was very slow, but (EDIT) I had mistakenly used the train dataset; after switching to the val dataset it works now.

Evaluation results for bbox (yolov5s, r3.1):

AP      AP50    AP75    APs     APm     APl
36.291  56.713  38.485  21.066  41.287  46.479

Evaluation results for bbox (yolov5s, r4.0):

AP      AP50    AP75    APs     APm     APl
35.739  55.508  37.877  20.988  41.007  45.086

zhiqwang (Owner) commented

FYI, this comment ultralytics/yolov5#3054 (comment) is a good resource about the dataloaders in ultralytics's YOLOv5.
