In [1]:
from collections import Counter
import emoji
from pathlib import Path
import sys
# Make project-local packages (models, utils, trainer, dataset) importable.
# The directory used is the process's current working directory: a notebook
# kernel has no ``__file__``, so the notebook must be launched from the
# project root for the imports below to resolve.
current_work_directionary = Path.cwd()
sys.path.insert(0, f"{current_work_directionary}")

import cv2
import torch.cuda
from models import Yolov5Small, Yolov5SmallWithPlainBscp, Yolov5Large, Yolov5Middle, Yolov5XLarge
from tqdm import tqdm
import numpy as np
from utils import cv2_save_img
from utils import maybe_mkdir, clear_dir
from utils import time_synchronize
from datetime import datetime
import torch.nn.functional as F
from trainer.EMA import ExponentialMovingAverage
from trainer import Evaluate
import time
from dataset import testdataloader, cocotestdataloader
from utils import mAP, cv2_save_img_plot_pred_gt, ConvBnAct, fuse_conv_bn
import pickle

In [2]:
class Training:
    """Evaluation driver for a YOLOv5 detector on a COCO-style test set.

    Despite its name this class does not train anything: it builds the model,
    loads a checkpoint, fuses Conv+BN layers and then either saves visualised
    predictions (``predtict_all``) or computes AP/mAP (``calculate_mAP``).
    """

    def __init__(self, anchors, nms_hyp, test_hyp):
        """Build the data loader, the model and the evaluation wrapper.

        :param anchors: anchor boxes as a nested list/tuple or a tensor; the
            first dimension is used as the per-stage anchor count.
        :param nms_hyp: dict of post-processing settings; forwarded to
            ``Evaluate`` and read for the ``'wfb'`` flag.
        :param test_hyp: dict of test settings. NOTE: it is modified in place
            (``'device'`` is resolved and ``'input_img_size'`` is rounded up
            to a multiple of 32).
        """
        self.test_hyp = test_hyp
        # parameters
        self.select_device()
        self.use_cuda = self.test_hyp['device'] == 'cuda'
        self.anchors = anchors
        self.nms_hyp = nms_hyp
        if isinstance(anchors, (list, tuple)):
            self.anchors = torch.tensor(anchors)  # (3, 3, 2)
        self.anchors = self.anchors.to(test_hyp['device'])
        anchor_num_per_stage = self.anchors.size(0)  # 3

        # The input shape must be divisible by 32 (for yolov5s); if it is not,
        # the configured input shape is adjusted (rounded up).
        self.test_hyp['input_img_size'] = self.padding(self.test_hyp['input_img_size'], 32)

        self.testdataset, self.testdataloader = cocotestdataloader(self.test_hyp['data_dir'],
                                                                   self.test_hyp['set_name'],
                                                                   self.test_hyp['use_crowd'],
                                                                   self.test_hyp['input_img_size'],
                                                                   self.test_hyp['batch_size'],
                                                                   self.test_hyp['num_workers'])

        if self.test_hyp['current_work_path'] is None:
            self.cwd = Path('./').absolute()
        else:
            self.cwd = Path(self.test_hyp['current_work_path'])

        # Normalisation statistics used to undo the input normalisation when
        # turning network inputs back into displayable images.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).float()
        self.std = torch.tensor([0.229, 0.224, 0.225]).float()

        # model, evaluation wrapper, ema
        # Swap in Yolov5Small / Yolov5SmallWithPlainBscp / Yolov5Middle /
        # Yolov5XLarge here (same constructor arguments) to test another size;
        # the checkpoint in test_hyp must match the chosen architecture.
        self.model = Yolov5Large(anchor_num_per_stage, self.test_hyp['num_class']).to(self.test_hyp['device'])

        self.validate = Evaluate(self.model, self.anchors, self.test_hyp['device'],
                                 self.test_hyp['num_class'], self.test_hyp['input_img_size'],
                                 self.nms_hyp)
        self.ema_model = ExponentialMovingAverage(self.model, 0)

        pretrained_path = self.test_hyp["pretrained_model_path"]
        if Path(pretrained_path).exists():
            try:
                self.load(pretrained_path, False, 'cpu')
            except Exception as err:
                # Best effort: a checkpoint that fails to load is reported and
                # evaluation continues with whatever weights the model has.
                print(err)
        else:
            # Without this notice a wrong path silently evaluates random weights.
            print(f"pretrained model not found: {pretrained_path}, the model keeps its initial weights!")

        # Summarise the model before and after Conv+BN fusion.
        self._print_model_summary()
        self.fuse_conv_bn()
        self._print_model_summary()

    def _print_model_summary(self):
        """Print ``summary_model(self.model)`` when that helper is available.

        ``summary_model`` is not imported by the notebook's import cell, so it
        is looked up at call time; if it is missing the summary is skipped
        instead of aborting construction with a NameError.
        NOTE(review): presumably it lives in ``utils`` - confirm and import it.
        """
        summarize = globals().get('summary_model')
        if summarize is None:
            print("summary_model is not available, skip printing the model summary.")
            return
        print(summarize(self.model))

    def fuse_conv_bn(self):
        """Fold every ``ConvBnAct`` module's batch-norm into its convolution.

        For each such module that still has a ``bn`` attribute, the conv is
        replaced by the fused conv, ``bn`` is removed and ``forward`` is
        rebound to the module's ``forward_fuse``. Modules without ``bn`` are
        skipped, so calling this twice is harmless.
        """
        for m in self.model.modules():
            if isinstance(m, ConvBnAct) and hasattr(m, 'bn'):
                m.conv = fuse_conv_bn(m.conv, m.bn)
                delattr(m, 'bn')
                m.forward = m.forward_fuse

    @staticmethod
    def padding(hw, factor=32):
        """Round a (height, width) pair up to the nearest multiples of ``factor``.

        :param hw: two-element sequence ``(h, w)``.
        :param factor: the divisor both sides must be a multiple of.
        :return: tuple ``(h, w)``; sides already divisible are unchanged.
        """
        def round_up(side):
            return side if side % factor == 0 else (side // factor + 1) * factor

        h, w = hw
        return round_up(h), round_up(w)

    def preds_postprocess(self, inp, outputs, info):
        """Map a batch of predictions and input images back to original image space.

        :param inp: normalised image batch on the CPU, indexed as (batch, C, H, W).
        :param outputs: per-image predictions; each entry is ``None`` or a
            tensor whose columns are x1, y1, x2, y2, confidence, class.
            NOTE: non-empty entries are modified in place.
        :param info: per-image dicts with ``scale``, ``pad_top``, ``pad_left``,
            ``pad_bottom``, ``pad_right`` and ``org_shape`` (h, w).
        :return: ``(images, preds)`` - uint8 HWC images resized to their
            original shape, and numpy prediction arrays. An image without
            predictions gets a single placeholder row filled with ``-1``;
            callers must filter rows whose class column is negative.
        """
        processed_preds = []
        processed_inp = []
        for i in range(len(outputs)):
            scale, pad_top, pad_left = info[i]['scale'], info[i]['pad_top'], info[i]['pad_left']
            pad_bot, pad_right = info[i]['pad_bottom'], info[i]['pad_right']
            pred = outputs[i]
            org_h, org_w = info[i]['org_shape']
            cur_h, cur_w = inp[i].size(1), inp[i].size(2)

            # Undo normalisation, strip the padding, resize to the original size.
            img = inp[i].permute(1, 2, 0)
            img = (img * self.std + self.mean) * 255.
            img = img.numpy().astype(np.uint8)
            img = img[pad_top:(cur_h - pad_bot), pad_left:(cur_w - pad_right), :]
            img = cv2.resize(img, (org_w, org_h), interpolation=cv2.INTER_NEAREST)

            if pred is not None and pred.size(0) > 0:
                # Remove the padding offset, undo the resize scale, then keep
                # the corners inside the original image.
                pred[:, [0, 2]] -= pad_left
                pred[:, [1, 3]] -= pad_top
                pred[:, [0, 1, 2, 3]] /= scale
                pred[:, [0, 2]] = pred[:, [0, 2]].clamp(1, org_w - 1)
                pred[:, [1, 3]] = pred[:, [1, 3]].clamp(1, org_h - 1)
                if self.test_hyp['use_auxiliary_classifier']:
                    # TODO: crop the object in each predicted box and run it
                    # through an extra classifier to re-check that an object
                    # is present. Not implemented yet.
                    pass
                processed_preds.append(pred.cpu().numpy())
            else:
                processed_preds.append(np.ones((1, 6)) * -1.)
            processed_inp.append(img)
        return processed_inp, processed_preds

    def select_device(self):
        """Resolve ``test_hyp['device']`` to ``'cuda'`` or ``'cpu'``.

        Any value other than 'cpu' (case-insensitive) becomes 'cuda' when CUDA
        is available and 'cpu' otherwise; a 'cpu' request is left untouched.
        """
        if self.test_hyp['device'].lower() != 'cpu':
            if torch.cuda.is_available():
                self.test_hyp['device'] = 'cuda'
            else:
                self.test_hyp['device'] = 'cpu'

    def load(self, model_path, load_optimizer, map_location):
        """Load model (and optionally optimizer) state from a checkpoint file.

        :param model_path: checkpoint path; must exist (asserted).
        :param load_optimizer: also restore the ``"optim"`` entry when present.
        :param map_location: forwarded to ``torch.load``.

        The checkpoint is expected to be a dict with a ``"model"`` entry; if
        it is missing a message is printed and the model is left unchanged.
        This class does not create an optimizer itself, so the optimizer state
        is only restored when an ``optimizer`` attribute has been attached.
        NOTE: ``torch.load`` unpickles the file - only load trusted checkpoints.
        """
        assert Path(model_path).exists(), f"model path is not exist {model_path}"
        state_dict = torch.load(model_path, map_location=map_location)
        if "model" not in state_dict:
            print("not found model_state_dict in this state_dict, load model failed!")
        else:
            print(f"use pretrained model {model_path}")
            self.model.load_state_dict(state_dict["model"])
        if load_optimizer and "optim" in state_dict:
            optimizer = getattr(self, 'optimizer', None)
            if optimizer is None:
                print("no optimizer is attached to this object, skip loading optimizer state!")
            else:
                print(f"use pretrained optimizer {model_path}")
                optimizer.load_state_dict(state_dict['optim'])
        del state_dict

    def gt_bbox_postprocess(self, anns, infoes):
        """Map ground-truth boxes from the letterboxed input back to original image space.

        :param anns: batched annotation tensor indexed as (batch, box, column);
            columns 0-3 are the box corners and column 4 the class id. Rows
            with a negative class id are treated as padding and dropped.
        :param infoes: per-image dicts with ``scale``, ``pad_top``, ``pad_left``.
        :return: ``(boxes, classes)`` - two lists with one numpy array per
            image: (M, 4) boxes and (M,) uint16 class ids.
        """
        ppb = []  # post processed bboxes
        ppc = []  # post processed classes
        for i in range(anns.shape[0]):
            scale, pad_top, pad_left = infoes[i]['scale'], infoes[i]['pad_top'], infoes[i]['pad_left']
            valid_idx = anns[i][:, 4] >= 0
            # Boolean-mask indexing yields a copy, so ``anns`` is not modified.
            ann_valid = anns[i][valid_idx]
            ann_valid[:, [0, 2]] -= pad_left
            ann_valid[:, [1, 3]] -= pad_top
            ann_valid[:, :4] /= scale
            ppb.append(ann_valid[:, :4].cpu().numpy())
            ppc.append(ann_valid[:, 4].cpu().numpy().astype('uint16'))
        return ppb, ppc

    def _result_image_path(self, index):
        """Return ``<cwd>/result/tmp/<index> <timestamp>.png`` as a string."""
        file_name = f"{index} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}.png"
        return str(self.cwd / 'result' / 'tmp' / file_name)

    def predtict_all(self):
        """Run the detector on every test batch and save the results to disk.

        Images are only written when ``test_hyp['save_img']`` is set; ground
        truth boxes are drawn as well when ``test_hyp['show_gt_bbox']`` is set.
        """
        for i, x in enumerate(self.testdataloader):
            imgs = x['imgs']  # (bn, 3, h, w)
            infoes = x['resize_infoes']
            gt_bbox, gt_cls = self.gt_bbox_postprocess(x['anns'], infoes)
            outputs = self.validate(imgs.to(self.test_hyp['device']), self.test_hyp['use_tta'], self.nms_hyp['wfb'])
            imgs, preds = self.preds_postprocess(imgs.cpu(), outputs, infoes)
            # preds_postprocess marks "no detection" with a row of -1; drop
            # those rows so they are neither looked up in class2label nor drawn.
            preds = [pred[pred[:, 5] >= 0] for pred in preds]

            if self.test_hyp['save_img']:
                for k in range(len(imgs)):
                    save_path = self._result_image_path(i * self.test_hyp['batch_size'] + k)
                    maybe_mkdir(Path(save_path).parent)
                    pred_lab = [self.testdataset.class2label[int(c)] for c in preds[k][:, 5]]
                    gt_lab = [self.testdataset.class2label[int(c)] for c in gt_cls[k]]
                    if self.test_hyp['show_gt_bbox']:
                        cv2_save_img_plot_pred_gt(imgs[k], preds[k][:, :4], pred_lab, preds[k][:, 4], gt_bbox[k], gt_lab, save_path)
                    else:
                        cv2_save_img(imgs[k], preds[k][:, :4], pred_lab, preds[k][:, 4], save_path)
            del imgs, preds

    def count_object(self, pred_lab):
        """Build one summary string per image, listing labels by ascending count.

        :param pred_lab: list with one sequence of predicted labels per image,
            e.g. [(X, ), (Y, ), (Z, ), ...].
        :return: list of strings such as ``"1 cat; 3 dog"``; labels with equal
            counts keep their first-seen order. Images without labels yield
            ``"No object has been found!"``. When
            ``<cwd>/result/pkl/coco_emoji_names.pkl`` exists, label names are
            replaced through that mapping before ``emoji.emojize`` is applied.
        """
        emoji_path = self.cwd / "result" / 'pkl' / "coco_emoji_names.pkl"
        coco_emoji = None
        if emoji_path.exists():
            # Load the mapping once per call instead of once per image.
            # NOTE: pickle.load can execute arbitrary code - the file must be trusted.
            with open(str(emoji_path), 'rb') as f:
                coco_emoji = pickle.load(f)

        msg = []
        for lab in pred_lab:
            # sorted() is stable, so ties keep Counter's insertion order.
            ascending = sorted(Counter(lab).items(), key=lambda item: item[1])
            if len(ascending) > 0:
                if coco_emoji is not None:
                    msg_ls = [" ".join([str(number), coco_emoji[name]]) for name, number in ascending]
                else:
                    msg_ls = [" ".join([str(number), name]) for name, number in ascending]
            else:
                msg_ls = ["No object has been found!"]
            msg.append(emoji.emojize("; ".join(msg_ls)))
        return msg

    def calculate_mAP(self):
        """Compute AP/mAP (IoU threshold 0.5) over the whole test loader.

        Prints a per-image object summary while running, optionally saves the
        visualised predictions (``save_img``) and the raw predicted /
        ground-truth boxes (``save_pred_bbox``), then prints the elapsed time
        and the eleven-point and every-point AP reported by ``mAP``.
        """
        start_t = time_synchronize()
        pred_bboxes, pred_classes, pred_confidences, pred_labels, gt_bboxes = [], [], [], [], []
        for i, x in enumerate(self.testdataloader):

            imgs = x['imgs']  # (bn, 3, h, w)
            infoes = x['resize_infoes']

            # gt_bbox: [(M, 4), (N, 4), (P, 4), ...]; gt_cls: [(M,), (N, ), (P, ), ...]
            # Some images in the coco val2017 dataset have no gt bboxes at all.
            gt_bbox, gt_cls = self.gt_bbox_postprocess(x['anns'], infoes)
            gt_bboxes.extend(gt_bbox)

            # Measure the time spent predicting one batch.
            t1 = time_synchronize()
            outputs = self.validate(imgs.to(self.test_hyp['device']), self.test_hyp['use_tta'], self.nms_hyp['wfb'])
            # preds: [(X, 6), (Y, 6), (Z, 6), ...]
            imgs, preds = self.preds_postprocess(imgs.cpu(), outputs, infoes)
            t = time_synchronize() - t1

            batch_pred_box, batch_pred_cof, batch_pred_cls, batch_pred_lab = [], [], [], []
            for j in range(len(imgs)):
                # Rows with a negative class are the "no detection" placeholder.
                valid_idx = preds[j][:, 5] >= 0
                if valid_idx.sum() == 0:
                    pred_box, pred_cls, pred_cof, pred_lab = [], [], [], []
                else:
                    pred_box = preds[j][valid_idx, :4]
                    pred_cof = preds[j][valid_idx, 4]
                    pred_cls = preds[j][valid_idx, 5]
                    pred_lab = [self.testdataset.class2label[int(c)] for c in pred_cls]

                batch_pred_box.append(pred_box)
                batch_pred_cls.append(pred_cls)
                batch_pred_cof.append(pred_cof)
                batch_pred_lab.append(pred_lab)

            pred_bboxes.extend(batch_pred_box)
            pred_classes.extend(batch_pred_cls)
            pred_confidences.extend(batch_pred_cof)
            pred_labels.extend(batch_pred_lab)

            obj_msg = self.count_object(batch_pred_lab)

            for k in range(len(imgs)):
                count = i * self.test_hyp['batch_size'] + k
                print(f"[{count:>05}/{len(self.testdataset)}] ➡️ " + obj_msg[k] + f" ({(t/len(imgs)):.2f}s)")
                if self.test_hyp['save_img']:
                    save_path = self._result_image_path(count)
                    # Same as predtict_all: make sure the target directory exists.
                    maybe_mkdir(Path(save_path).parent)
                    if self.test_hyp['show_gt_bbox']:
                        gt_lab = [self.testdataset.class2label[int(c)] for c in gt_cls[k]]
                        cv2_save_img_plot_pred_gt(imgs[k], batch_pred_box[k], batch_pred_lab[k], batch_pred_cof[k], gt_bbox[k], gt_lab, save_path)
                    else:
                        cv2_save_img(imgs[k], batch_pred_box[k], batch_pred_lab[k], batch_pred_cof[k], save_path)
            del imgs, preds

        total_use_time = time_synchronize() - start_t

        # One (N, 5) array per image: four box columns plus the confidence;
        # images without predictions contribute an empty list.
        all_preds = []
        for pred_box, pred_cof in zip(pred_bboxes, pred_confidences):
            if len(pred_box) == 0:
                all_preds.append([])
            else:
                all_preds.append(np.concatenate((pred_box, pred_cof[:, None]), axis=1))

        # With many test images a full run is slow, so the raw boxes are saved
        # to make later statistics possible without re-running inference.
        if self.test_hyp['save_pred_bbox']:
            pkl_dir = self.cwd / "result" / "pkl"
            pkl_dir.mkdir(parents=True, exist_ok=True)
            with open(pkl_dir / "pred_bbox_1024_tta.pkl", 'wb') as f:
                pickle.dump(all_preds, f)
            with open(pkl_dir / "gt_bbox.pkl", "wb") as f:
                pickle.dump(gt_bboxes, f)

        metric = mAP(all_preds, gt_bboxes, 0.5)
        print(f"use time: {total_use_time:.2f}s")
        print('AP: %.2f %%' % (metric.elevenPointAP * 100))
        print('mAP: %.2f %%' % (metric.everyPointAP * 100))

In [3]:
if __name__ == '__main__':

    # Post-processing (NMS / box fusion) settings, forwarded to the evaluator.
    nms_hyp = dict(
        iou_threshold=0.45,  # finally, filter out a batch of predicted boxes with the iou threshold
        conf_threshold=0.25,  # first, filter out a batch of predicted boxes with the conf threshold
        cls_threshold=0.3,  # then, filter out a batch of predicted boxes with the cls threshold
        max_predictions_per_img=300,
        min_prediction_box_wh=2,
        max_prediction_box_wh=4096,
        iou_type='iou',
        mutil_label=False,  # whether one object may be assigned several labels
        agnostic=True,  # whether NMS is only performed among bboxes of the same class
        postprocess_bbox=False,  # whether to further refine the predicted bboxes
        wfb=False,  # use NMS or Weighted Fusion Bbox
        wfb_weights=[1, 1, 1],
        wfb_iou_threshold=0.5,
        wfb_skip_box_threshold=0.001,
    )

    # Test-run settings.
    # NOTE(review): data_dir is a machine-specific absolute path; point it at
    # the local COCO2017 root (or e.g. "./result/coco_test_imgs") before running.
    test_hyp = dict(
        data_dir="/home/uih/JYL/Dataset/COCO2017/",
        set_name="val2017",
        use_auxiliary_classifier=False,
        batch_size=8,
        input_img_size=[640, 640],
        num_workers=6,
        save_img=True,
        device='cpu',
        use_tta=True,
        use_crowd=False,
        save_pred_bbox=True,
        current_work_path=None,
        use_pretrained_mdoel=True,
        pretrained_model_path="./checkpoints/every_for_coco_large.pth",
        num_class=80,
        show_gt_bbox=True,
    )

    # Three groups of three (w, h) anchor pairs.
    anchors = torch.tensor([
        [[10, 13], [16, 30], [33, 23]],
        [[30, 61], [62, 45], [59, 119]],
        [[116, 90], [156, 198], [373, 326]],
    ])
    train = Training(anchors, nms_hyp, test_hyp)
    train.calculate_mAP()

loading annotations into memory...
Done (t=0.61s)
creating index...
index created!
use pretrained model ./checkpoints/every_for_coco_large.pth
[00000/5000] ➡️ 1 🪴; 1 :sink:; 1 :dining_table:; 1 🔪; 2 🧑; 2 :oven:; 2 🥄; 4 🍚; 4 🥤 (1.81s)
[00001/5000] ➡️ 1 :oven:; 1 :dining_table:; 1 🪑; 1 🪴; 1 🍌; 1 🍚; 2 :refrigerator:; 6 🍊 (1.81s)
[00002/5000] ➡️ 1 ☂; 1 💼; 1 🚚; 1 🚗; 1 🎒; 1 🛹; 2 🚦; 5 🧑 (1.81s)
[00003/5000] ➡️ 1 🚲; 1 📺; 1 :bench:; 2 🛹; 19 🧑 (1.81s)
[00004/5000] ➡️ 1 🚲; 6 🚗 (1.81s)
[00005/5000] ➡️ 1 🚽; 1 🥤; 2 :sink: (1.81s)
[00006/5000] ➡️ 1 :sink:; 1 🚽 (1.81s)
[00007/5000] ➡️ 3 🏍; 7 🧑 (1.81s)
[00008/5000] ➡️ 1 📚; 2 🔪; 7 🚽 (1.78s)
[00009/5000] ➡️ 1 :sink:; 2 🚽 (1.78s)
[00010/5000] ➡️ 1 ☂; 1 🚲; 2 🏍; 4 🚗; 13 🧑 (1.78s)
[00011/5000] ➡️ 1 🧑; 1 📺; 1 💻; 1 ⌨; 1 🥤; 26 📚 (1.78s)
[00012/5000] ➡️ 2 🪴; 2 🏺 (1.78s)
[00013/5000] ➡️ 1 🛏; 1 🧑; 1 🛋; 2 📺; 20 📚 (1.78s)
[00014/5000] ➡️ 1 🍚; 1 🏺; 4 🍊 (1.78s)
[00015/5000] ➡️ 1 💼; 1 🛥; 2 🧑; 2 ✈ (1.78s)


KeyboardInterrupt: 