In [1]:
# Hide every CUDA GPU from this process so all models below run on the CPU.
# CUDA reads this variable when it initialises, so set it before any
# CUDA-using library (torch / NomeroffNet in cell 3) is imported.
import os
os.environ.update(CUDA_VISIBLE_DEVICES="")

In [2]:
import sys
import json
import os
import matplotlib.pyplot as plt
from collections import Counter
import cv2
import tqdm
from termcolor import colored

# Initial default size (inches) for matplotlib figures in this notebook.
plt.rcParams.update({"figure.figsize": (10, 10)})

# Root of the nomeroff-net checkout, resolved relative to the current working
# directory (normally this notebook's folder). Adjust if the notebook moves.
NOMEROFF_NET_DIR = os.path.abspath('../../')
# Make the NomeroffNet package importable; the guard keeps sys.path free of
# duplicate entries when this cell is re-run.
if NOMEROFF_NET_DIR not in sys.path:
    sys.path.append(NOMEROFF_NET_DIR)

In [3]:
from NomeroffNet.BBoxNpPoints import (NpPointsCraft, 
                                      getCvZoneRGB, 
                                      convertCvZonesRGBtoBGR, 
                                      reshapePoints)
from NomeroffNet.MultiLineNPExtractor import (CCraft, 
                                            make_boxes)
from NomeroffNet.YoloV5Detector import Detector
from NomeroffNet.OptionsDetector import OptionsDetector
from NomeroffNet.TextDetector import TextDetector
from NomeroffNet.TextPostprocessing import translit_cyrillic_to_latin

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
class MaskDatasetChecker:
    """Check the NomeroffNet pipeline against a VIA-annotated numberplate dataset.

    ``__init__`` loads every model the pipeline uses (NpPointsCraft, CCraft,
    the YOLOv5 ``Detector``, ``OptionsDetector`` and a per-region
    ``TextDetector``). ``compare`` loads the annotations and runs ``predict``
    on every annotated image; with ``debug`` enabled, intermediate results are
    plotted and exact OCR matches are counted.
    """

    def __init__(self):
        self.npPointsCraft = NpPointsCraft()
        self.ccraft = CCraft()
        self.npPointsCraft.load()
        self.detector = Detector()
        self.detector.load()
        self.optionsDetector = OptionsDetector()
        self.optionsDetector.load()
        # One OCR model per group of region labels ("for_regions").
        # NOTE(review): "latest" presumably selects the newest published weights.
        self.textDetector = TextDetector({
            "eu_ua_2004_2015": {
                "for_regions": ["eu-ua-2015", "eu-ua-2004"],
                "model_path": "latest"
            },
            "eu_ua_1995": {
                "for_regions": ["eu-ua-1995"],
                "model_path": "latest"
            },
            "eu": {
                "for_regions": ["eu"],
                "model_path": "latest"
            },
            "ru": {
                "for_regions": ["ru", "eu-ua-ordlo-lpr", "eu-ua-ordlo-dpr"],
                "model_path": "latest"
            },
            "kz": {
                "for_regions": ["kz"],
                "model_path": "latest"
            },
            "ge": {
                "for_regions": ["ge"],
                "model_path": "latest"
            },
            "su": {
                "for_regions": ["su"],
                "model_path": "latest"
            }
        })
        # image filename -> list of annotated plates; filled by load_dataset().
        self.dataset = {}

    def load_dataset(self, json_path):
        """Load VIA polygon annotations from ``json_path`` into ``self.dataset``.

        Each annotated region becomes a dict with its axis-aligned bounding
        box (``x1``, ``x2``, ``y1``, ``y2``), the raw polygon (``xs``, ``ys``)
        and the stripped ``region_name`` / ``numberplate`` attributes.
        ``self.dataset`` maps image filename -> list of such dicts.
        """
        dataset = {}
        print("Loading dataset...")
        # Explicit UTF-8: annotations contain Cyrillic plate text and the
        # platform default encoding is not UTF-8 everywhere.
        with open(json_path, encoding="utf-8") as json_file:
            data = json.load(json_file)
        for item in tqdm.tqdm(data['_via_img_metadata'].values()):
            bboxes = []
            for region in item['regions']:
                shape = region['shape_attributes']
                attrs = region['region_attributes']
                xs = shape['all_points_x']
                ys = shape['all_points_y']
                bboxes.append({
                    'x1': min(xs),
                    'x2': max(xs),
                    'y1': min(ys),
                    'y2': max(ys),
                    'xs': xs,
                    'ys': ys,
                    'region_name': attrs["region_name"].strip(),
                    'numberplate': attrs["np"].strip(),
                })
            dataset[item["filename"]] = bboxes
        self.dataset = dataset

    def predict(self, image_paths, use_target_box_from_dataset=1, use_option_from_dataset=1, debug=1):
        """Run the recognition pipeline on every path in ``image_paths``.

        Args:
            image_paths: image files to process. Annotations are looked up in
                ``self.dataset`` by file basename.
            use_target_box_from_dataset: truthy -> use the annotated boxes
                instead of running ``self.detector``.
            use_option_from_dataset: truthy -> use the annotated region names
                (and assume two text lines) instead of ``self.optionsDetector``.
            debug: truthy -> print/plot intermediate results and count exact
                OCR matches against the annotations.

        Returns:
            dict mapping image basename -> list of predictions with keys
            ``x1``/``x2``/``y1``/``y2``, ``region_name`` and ``numberplate``.
            Images that cannot be read are reported and omitted.
        """
        predicted = {}
        print("Predicting...")

        counter = Counter()
        for img_path in tqdm.tqdm(image_paths):
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None instead of raising for missing or
                # unreadable files; report it and continue with the next image.
                print(colored(f"Cannot read image, skipping: {img_path}", "red"))
                continue
            img = img[..., ::-1]  # OpenCV loads BGR; flip channels to RGB
            if debug:
                print(img_path)
                print(img.shape)

            img_name = os.path.basename(img_path)
            dataset_info = self.dataset.get(img_name, [])
            if use_target_box_from_dataset:
                targetBoxes = [[item["x1"], item["y1"], item["x2"], item["y2"]] for item in dataset_info]
            else:
                targetBoxes = self.detector.detect_bbox(img)
            all_points, all_mline_boxes = self.npPointsCraft.detect_mline(img, targetBoxes)

            # Drop boxes without detected points. Every per-box list is
            # filtered with the same indices so position i keeps referring to
            # the same plate in all of them (filtering only all_points used to
            # misalign zones, region names and line counts).
            # NOTE(review): assumes detect_mline returns one entry in both
            # all_points and all_mline_boxes per target box.
            keep = [i for i, ps in enumerate(all_points) if len(ps)]
            all_points = [all_points[i] for i in keep]
            all_mline_boxes = [all_mline_boxes[i] for i in keep]
            targetBoxes = [targetBoxes[i] for i in keep]
            if use_target_box_from_dataset:
                # Boxes were built one-per-annotation from dataset_info.
                box_info = [dataset_info[i] for i in keep]
            else:
                box_info = dataset_info

            # cut zones
            toShowZones = [getCvZoneRGB(img, reshapePoints(rect, 1)) for rect in all_points]
            zones = convertCvZonesRGBtoBGR(toShowZones)

            # find standard (region name) and number of text lines per zone
            if use_option_from_dataset:
                region_names = [item["region_name"] for item in box_info]
                # NOTE(review): every annotated plate is assumed to have 2 lines.
                countLines = [2 for _ in box_info]
            else:
                region_ids, countLines = self.optionsDetector.predict(zones)
                region_names = self.optionsDetector.getRegionLabels(region_ids)

            # convert multiline plates to one line
            image_parts = [img[int(box[1]):int(box[3]), int(box[0]):int(box[2])]
                           for box, cl in zip(targetBoxes, countLines)
                           if cl > 1]
            all_mline_boxes_rect = [mline_boxes
                                    for mline_boxes, cl in zip(all_mline_boxes, countLines)
                                    if cl > 1]
            region_names_rect = [region_name
                                 for region_name, cl in zip(region_names, countLines)
                                 if cl > 1]
            index_rect = [i for i, cl in enumerate(countLines) if cl > 1]
            (zones_rect,
             zones_target_points,
             zones_mline_boxes) = self.ccraft.multiline_to_one_line(all_mline_boxes_rect,
                                                                    image_parts,
                                                                    region_names_rect)
            # Replace every multiline zone with its single-line rendering.
            for i, zone in zip(index_rect, zones_rect):
                zones[i] = zone

            if debug:
                self._plot_multiline_debug(img, image_parts, zones_rect,
                                           zones_target_points, zones_mline_boxes)

            # find text with postprocessing by standard; every zone is
            # single-line at this point
            countLines = [1 for _ in zones]
            textArr = self.textDetector.predict(zones, region_names, countLines)

            # NOTE(review): x1/x2/y1/y2 are the raw entries 0/2/1/3 of the
            # points returned by detect_mline; confirm they are scalars and
            # not corner pairs.
            predicted[img_name] = [
                {
                    'x1': points[0],
                    'x2': points[2],
                    'y1': points[1],
                    'y2': points[3],
                    'region_name': region_name.strip(),
                    'numberplate': numberplate.strip(),
                }
                for numberplate, points, region_name in zip(textArr, all_points, region_names)
            ]

            if debug:
                self._report_debug(img, img_path, dataset_info, zones,
                                   textArr, region_names, counter)

        if debug:
            print(colored(str(counter), 'blue'))
        return predicted

    @staticmethod
    def _plot_multiline_debug(img, image_parts, zones_rect, zones_target_points, zones_mline_boxes):
        """Draw detected boxes on each multiline crop, then plot the image and the one-line zone."""
        for norm_image, one_line_img, target_points, mline_boxes in zip(image_parts,
                                                                         zones_rect,
                                                                         zones_target_points,
                                                                         zones_mline_boxes):
            make_boxes(norm_image, target_points, (0, 0, 255))
            make_boxes(norm_image, mline_boxes, (255, 0, 0))
            # NOTE(review): img is RGB here, so img[..., ::-1] is displayed
            # with red/blue swapped; confirm this is intended.
            _, ax = plt.subplots(figsize=(15, 15))
            ax.imshow(img[..., ::-1])
            plt.show()
            _, ax = plt.subplots(figsize=(15, 15))
            ax.imshow(one_line_img)
            plt.show()

    @staticmethod
    def _report_debug(img, img_path, dataset_info, zones, textArr, region_names, counter):
        """Plot the image and each recognised zone, and tally exact matches in ``counter``."""
        # Expected plates: transliterated Cyrillic -> Latin for every region
        # except "su", whose annotation is used as-is.
        trues = [plate["numberplate"] if plate["region_name"] == "su"
                 else translit_cyrillic_to_latin(plate["numberplate"])
                 for plate in dataset_info]
        plt.imshow(img)
        plt.show()
        print(img_path, colored(trues, "blue"))
        for zone, numberplate, region_name in zip(convertCvZonesRGBtoBGR(zones),
                                                  textArr,
                                                  region_names):
            # Ground truth is stripped in load_dataset(); strip the OCR text
            # too so surrounding whitespace is not counted as a miss.
            if numberplate.strip() in trues:
                counter["good"] += 1
                color = "green"
            else:
                counter["bad"] += 1
                color = "yellow"

            plt.imshow(zone)
            plt.show()
            print(colored(json.dumps({
                'region_name': region_name.strip(),
                'numberplate': numberplate.strip(),
            }), color))

    def compare(self,
                photo_dir=os.path.join(NOMEROFF_NET_DIR, 'dataset/np/'),
                json_path=os.path.join(NOMEROFF_NET_DIR, 'dataset/np/via.json'),
                use_target_box_from_dataset=1,
                use_option_from_dataset=1,
                iou_less_than=0.9,
                ocr_acc_less_than=0.7,
                option_acc_less_than=0.7,
                mask_acc_less_than=0.7):
        """Load annotations from ``json_path`` and run ``predict`` on every
        annotated image, resolved inside ``photo_dir``.

        TODO: add more comparisons. The ``*_less_than`` thresholds are
        accepted for API compatibility but are not used yet.
        """
        self.load_dataset(json_path)
        image_paths = [os.path.join(photo_dir, image_name) for image_name in self.dataset]
        self.predict(image_paths,
                     use_option_from_dataset=use_option_from_dataset,
                     use_target_box_from_dataset=use_target_box_from_dataset)

In [5]:
# Load every model once (NpPointsCraft, CCraft, YOLOv5 Detector,
# OptionsDetector, TextDetector); the instance is reused by the cells below.
maskDatasetChecker = MaskDatasetChecker()

Loading weights from checkpoint (/mnt/data/var/www/nomeroff-net/NomeroffNet/Base/mcm/./data/./models/NpPointsCraft/craft_mlt/craft_mlt_25k_2020-02-16.pth)
Loading weights of refiner from checkpoint (/mnt/data/var/www/nomeroff-net/NomeroffNet/Base/mcm/./data/./models/NpPointsCraft/craft_refiner/craft_refiner_CTW1500_2020-02-16.pth)


In [6]:
%matplotlib inline
plt.rcParams["figure.figsize"] = (20,10)

In [None]:
# Folder holding the many-line test sets. Set NOMEROFF_DATA_ROOT to point at
# your own copy instead of editing a hardcoded home-directory path.
DATA_ROOT = os.environ.get("NOMEROFF_DATA_ROOT", "/home/dmitroprobachay/Documents")
num = 12  # suffix selecting which many-line test set to check
# NOTE(review): the image folder is "many_line_*" but the annotation file is
# "many_lines_*"; confirm both names match what is on disk.
maskDatasetChecker.compare(
    photo_dir=os.path.join(DATA_ROOT, f'many_line_{num}'),
    json_path=os.path.join(DATA_ROOT, f'many_lines_{num}.json'),
    use_target_box_from_dataset=1,
    use_option_from_dataset=1
)