In [27]:
import os
import gc
import json
import glob
from collections import Counter

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
import seaborn as sns
import tqdm

import torchvision
from torchvision import models
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision import transforms

import PIL
from PIL import Image, ImageDraw
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle
import re

In [2]:
def simplify_contour(contour, n_corners=4, max_iter=1000):
    """Approximate ``contour`` by a closed polygon with exactly ``n_corners`` vertices.

    The tolerance handed to ``cv2.approxPolyDP`` is ``k * perimeter`` and ``k``
    is bisected inside [0, 1]: too many vertices -> raise the lower bound,
    too few -> lower the upper bound.

    Args:
        contour: contour as returned by ``cv2.findContours``.
        n_corners: number of vertices the approximation must have.
        max_iter: maximum number of ``cv2.approxPolyDP`` evaluations.

    Returns:
        The array returned by ``cv2.approxPolyDP`` when it has exactly
        ``n_corners`` points, otherwise ``None`` (a message is printed).
    """
    # The approximation only keeps points of the input, so a contour with
    # fewer points than requested can never succeed.
    if len(contour) >= n_corners:
        lb, ub = 0., 1.
        perimeter = cv2.arcLength(contour, True)  # loop-invariant
        prev_k = None

        for _ in range(max_iter):
            k = (lb + ub) / 2.
            if k == prev_k:
                # The interval can no longer shrink in floating point: every
                # further iteration would repeat the same failed evaluation.
                break
            prev_k = k

            approx = cv2.approxPolyDP(contour, k * perimeter, True)

            if len(approx) > n_corners:
                lb = k
            elif len(approx) < n_corners:
                ub = k
            else:
                return approx

    print('simplify_contour didnt coverege')
    return None

# Map an arbitrary quadrilateral onto an axis-aligned rectangle.
# Thanks to ulebok for the idea,
# and to these folks for the implementation: https://www.pyimagesearch.com/2014/08/25/4-point-opencv-getperspective-transform-example/
def four_point_transform(image, pts):
    """Warp the quadrilateral ``pts`` of ``image`` into an upright rectangle.

    Args:
        image: source image array passed to ``cv2.warpPerspective``.
        pts: four (x, y) corner points in any order.

    Returns:
        The warped image; its width/height are the longer of the two
        opposite edges of the quadrilateral (at least 1 pixel each).
    """
    rect = order_points(np.asarray(pts))

    # Edge lengths must be measured on the ORDERED corners; unpacking the raw
    # input mixes up width and height whenever the caller's order differs.
    tl, tr, br, bl = rect

    width_1 = np.linalg.norm(br - bl)
    width_2 = np.linalg.norm(tr - tl)
    # A degenerate quadrilateral would otherwise request a zero-sized output.
    max_width = max(int(width_1), int(width_2), 1)

    height_1 = np.linalg.norm(tr - br)
    height_2 = np.linalg.norm(tl - bl)
    max_height = max(int(height_1), int(height_2), 1)

    # Destination corners in the same tl, tr, br, bl order as ``rect``.
    dst = np.array([
        [0, 0],
        [max_width, 0],
        [max_width, max_height],
        [0, max_height]], dtype="float32")

    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(image, M, (max_width, max_height))
    return warped

def order_points(pts):
    """Return the four points of ``pts`` as a float32 (4, 2) array ordered
    top-left, top-right, bottom-right, bottom-left.

    The top-left / bottom-right corners are the points with the smallest /
    largest ``x + y``; the top-right / bottom-left corners are the points with
    the smallest / largest ``y - x``.
    """
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1)

    # Row indices in output order: tl, tr, br, bl.
    corner_rows = [
        np.argmin(coord_sum),
        np.argmin(coord_diff),
        np.argmax(coord_sum),
        np.argmax(coord_diff),
    ]

    ordered = np.zeros((4, 2), dtype="float32")
    ordered[...] = pts[corner_rows]
    return ordered


# Visualize a detection (4 corner points, bounding box and the contour approximated from the mask)
def visualize_prediction_plate(file, model, device='cuda', verbose=True, thresh=0.0,
                               n_colors=None, id_to_name=None, mask_thresh=0.05):
    """Run ``model`` on one image file and display every detection above ``thresh``.

    For each kept detection the bounding-box crop and the perspective-rectified
    crop (resized to 320x64) are shown; finally the whole image is shown with
    the mask polygon, its four corners and the bounding box drawn on it.

    Relies on ``my_transforms``, ``show_image``, ``simplify_contour`` and
    ``four_point_transform`` defined elsewhere in this notebook.

    Args:
        file: path of the image to process.
        model: detection model; moved to ``device`` and switched to eval mode.
        device: device used for inference.
        verbose: print class name and confidence of every kept detection.
        thresh: detections with ``score <= thresh`` are skipped.
        n_colors: palette size; defaults to the number of outputs of the
            model's box classifier.
        id_to_name: optional mapping label id -> display name; the label id
            itself is printed when omitted.
        mask_thresh: threshold applied to the predicted mask before the
            contour is extracted.

    Returns:
        The prediction dict produced by the model for this image.
    """
    pil_img = Image.open(file).convert('RGB')
    img_tensor = my_transforms(pil_img)
    model.to(device)
    model.eval()
    with torch.no_grad():
        predictions = model([img_tensor.to(device)])
    prediction = predictions[0]

    if n_colors is None:
        n_colors = model.roi_heads.box_predictor.cls_score.out_features

    palette = sns.color_palette(None, n_colors)

    # Build the BGR drawing image from the very pixels the model saw instead of
    # decoding the file a second time with OpenCV (the two decoders can
    # disagree, e.g. on EXIF orientation). ``show_image`` flips BGR -> RGB.
    clean = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    # Annotations go onto a copy so crops of later detections stay unmarked.
    canvas = clean.copy()

    for i in range(len(prediction['boxes'])):
        score = float(prediction['scores'][i].cpu())
        if score <= thresh:
            continue

        x_min, y_min, x_max, y_max = map(int, prediction['boxes'][i].tolist())
        label = int(prediction['labels'][i].cpu())
        mask = prediction['masks'][i][0, :, :].cpu().numpy()
        name = id_to_name[label] if id_to_name is not None else str(label)
        color = palette[label % n_colors]

        if verbose:
            print ('Class: {}, Confidence: {}'.format(name, score))

        crop_img = clean[y_min:y_max, x_min:x_max]
        print('Bounding box:')
        show_image(crop_img, figsize=(10, 2))

        # findContours returns 2 or 3 values depending on the OpenCV version;
        # the contour list is always the second-to-last one.
        contours = cv2.findContours((mask > mask_thresh).astype(np.uint8),
                                    cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]
        # NOTE(review): only the first contour is used, as in the inference
        # cell; a fragmented mask may deserve the largest one instead.
        approx = simplify_contour(contours[0], n_corners=4) if len(contours) > 0 else None

        if approx is None:
            # No usable quadrilateral: fall back to the bounding-box corners.
            corners = [(x_min, y_min), (x_max, y_min), (x_min, y_max), (x_max, y_max)]
        else:
            # Plain ints: the OpenCV drawing calls below reject some NumPy scalars.
            corners = [(int(px), int(py)) for px, py in approx.reshape(-1, 2)[:4]]
        (x0, y0), (x1, y1), (x2, y2), (x3, y3) = corners

        points = np.array([[x0, y0], [x2, y2], [x1, y1], [x3, y3]])
        crop_mask_img = four_point_transform(clean, points)
        print('Rotated img:')
        crop_mask_img = cv2.resize(crop_mask_img, (320, 64), interpolation=cv2.INTER_AREA)
        show_image(crop_mask_img, figsize=(10, 2))

        if approx is not None:
            cv2.drawContours(canvas, [approx], 0, (255, 0, 255), 3)
        for corner in corners:
            cv2.circle(canvas, corner, radius=5, color=(0, 0, 255), thickness=-1)

        box_color = tuple(float(channel) * 255 for channel in color)
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), box_color, 2)

    show_image(canvas)
    return prediction

# Just display an image. From the seminar
def show_image(image, figsize=(16, 9), reverse=True):
    """Show ``image`` in a new matplotlib figure with the axes hidden.

    When ``reverse`` is true the last axis is flipped before display
    (e.g. BGR -> RGB for images produced by OpenCV).
    """
    plt.figure(figsize=figsize)
    shown = image[..., ::-1] if reverse else image
    plt.imshow(shown)
    plt.axis('off')
    plt.show()
    

# Converts model predictions to text. From the seminar
def decode(pred, alphabet):
    """Decode a batch of predictions into a list of strings.

    ``pred`` is a 3-D tensor whose first two axes are swapped before decoding
    (presumably (time, batch, classes) -> (batch, time, classes); confirm
    against the recognition model). Each item is decoded with
    ``pred_to_string``.
    """
    per_sample = pred.permute(1, 0, 2).cpu().data.numpy()
    return [pred_to_string(sample, alphabet) for sample in per_sample]

def pred_to_string(pred, alphabet):
    """Greedy decoding of one sequence of per-step class scores.

    Class 0 is the blank symbol; class ``c > 0`` maps to ``alphabet[c - 1]``.
    Blanks are dropped and a symbol is emitted only when it differs from the
    symbol of the immediately preceding step.
    """
    # Shift by one so that the blank becomes -1 and the rest index ``alphabet``.
    steps = [np.argmax(scores) - 1 for scores in pred]

    kept = []
    previous = None  # nothing precedes the first step
    for current in steps:
        if current != -1 and current != previous:
            kept.append(current)
        previous = current

    return ''.join(alphabet[c] for c in kept)
    

        
def load_json(file):
    """Parse the JSON document stored at path ``file`` and return the result.

    The file is read in binary mode so that ``json`` detects the encoding
    (UTF-8, with or without BOM, UTF-16 or UTF-32) itself instead of depending
    on the platform's locale encoding.
    """
    with open(file, 'rb') as f:
        return json.load(f)
    
# Lets json serialize NumPy values; the stock encoder rejects them
class npEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to plain Python values."""

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj`` or defer to the base class."""
        # ``np.integer`` covers every width and alias (int32, int64, intc, ...).
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

In [3]:
# ---- Data split -------------------------------------------------------------
# Fraction of the annotation list used for training; the rest is validation.
TRAIN_SIZE = 0.8

# ---- Locations ----------------------------------------------------------------
data_path = "C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/"
output_dir = "C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/runs/recognition_baseline"
q = "C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/runs/recognition_baseline/CP-last.pth"

# ---- Hyper-parameters ---------------------------------------------------------
# Of these only `batch_size` is read by the cells visible in this part of the
# notebook; the others are presumably consumed by a later stage (TODO confirm).
batch_size = 4
epochs = 150
image_size = 256
input_wh = "320x64"
lr = 3e-4
weight_decay = 1e-5
lr_step = None
lr_gamma = None
augs = 0
load = None

# ---- Annotations: the first TRAIN_SIZE share trains, the tail validates -------
all_marks = load_json(os.path.join(data_path, 'train.json'))
test_start = int(TRAIN_SIZE * len(all_marks))
train_marks, val_marks = all_marks[:test_start], all_marks[test_start:]

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
class DetectionDataset(Dataset):
    """Dataset yielding ``(image, target)`` pairs for Mask R-CNN training.

    Each element of ``marks`` is a dict with a ``"file"`` entry (path relative
    to ``img_folder``) and a ``"nums"`` list; every entry of ``"nums"`` has a
    ``"box"`` list of (x, y) polygon vertices. All objects get label 1.
    """

    def __init__(self, marks, img_folder, transforms=None):
        """Store the annotations, the image folder prefix and the optional
        callable applied to each loaded PIL image."""
        self.marks = marks
        self.img_folder = img_folder
        self.transforms = transforms

    def __getitem__(self, idx):
        """Load image ``idx`` and build its detection target.

        Returns:
            ``(img, target)`` where ``target`` holds ``boxes`` (float32,
            N x 4 as x_min, y_min, x_max, y_max), ``labels`` (int64, N) and
            ``masks`` (uint8, N x H x W) - the dtypes torchvision documents
            for Mask R-CNN targets.
        """
        item = self.marks[idx]
        # Plain concatenation: ``img_folder`` is expected to end with a separator.
        img_path = f'{self.img_folder}{item["file"]}'
        img = Image.open(img_path).convert('RGB')
        w, h = img.size

        box_coords = item['nums']
        boxes = []
        masks = []

        if box_coords:
            # (x, y) coordinates of every pixel; identical for all polygons of
            # this image, so build it once instead of once per box.
            xs, ys = np.meshgrid(np.arange(w), np.arange(h))
            pixel_coords = np.vstack((xs.flatten(), ys.flatten())).T

        for box in box_coords:
            poly_verts = np.array(box['box'])
            x0, y0 = np.min(poly_verts[:, 0]), np.min(poly_verts[:, 1])
            x2, y2 = np.max(poly_verts[:, 0]), np.max(poly_verts[:, 1])
            boxes.append([float(x0), float(y0), float(x2), float(y2)])

            # Turn the polygon into a mask so that, besides the bounding box,
            # the model also learns to predict the mask.
            inside = Path(poly_verts).contains_points(pixel_coords)
            masks.append(inside.reshape((h, w)).astype(np.uint8))

        # ``reshape`` keeps the (0, 4) shape for images without any object.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.ones((len(boxes),), dtype=torch.int64)
        if masks:
            # Stack first: building a tensor from a list of arrays is very slow.
            masks = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        else:
            masks = torch.zeros((0, h, w), dtype=torch.uint8)

        target = {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        """Number of annotated images."""
        return len(self.marks)
    
# The only preprocessing is the conversion of the PIL image to a tensor.
my_transforms = transforms.Compose([
    transforms.ToTensor()
])

# Annotation file names are resolved against `data_path` from the config cell
# (the same folder that used to be spelled out here a second and third time).
train_dataset = DetectionDataset(
    marks=train_marks,
    img_folder=data_path,
    transforms=my_transforms
)
val_dataset = DetectionDataset(
    marks=val_marks,
    img_folder=data_path,
    transforms=my_transforms
)

def collate_fn(batch):
    """Transpose a batch of samples into a tuple of per-field tuples.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``,
    so samples of different sizes are kept as-is rather than stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Training batches are drawn in a fresh random order every epoch; without
# shuffling the model would always see the annotations in file order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    collate_fn=collate_fn,
)

# Validation keeps the original order and the final, possibly smaller, batch.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    drop_last=False,
    num_workers=0,
    collate_fn=collate_fn,
)

In [5]:
def get_model(num_classes=2):
    """Build a COCO-pretrained Mask R-CNN (ResNet-50 FPN) with new prediction heads.

    Args:
        num_classes: number of output classes of the new box and mask heads,
            background included (default 2: background + one object class).

    Returns:
        The model with freshly initialised box/mask predictors. Only the FPN,
        the RPN and the RoI heads are left trainable; every other parameter
        is frozen.
    """
    model = models.detection.maskrcnn_resnet50_fpn(
        pretrained=True,
        pretrained_backbone=True,
        progress=True,
        num_classes=91,
    )

    # Size the replacement heads from the layers they replace rather than
    # hard-coding the channel counts.
    box_in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    mask_in_channels = model.roi_heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        mask_in_channels, mask_hidden_channels, num_classes
    )

    # Freeze everything except the last stages.
    for param in model.parameters():
        param.requires_grad = False

    for trainable_part in (model.backbone.fpn, model.rpn, model.roi_heads):
        for param in trainable_part.parameters():
            param.requires_grad = True

    return model

In [6]:
# Collect unreachable Python objects first so the CUDA tensors they hold are
# really freed, then return the now-unused cached blocks to the driver.
gc.collect()
torch.cuda.empty_cache()
model = get_model()
# To resume from saved detector weights instead of starting from the
# pretrained heads (DETECTOR_MODEL_PATH is not defined in this notebook):
# model.load_state_dict(torch.load(DETECTOR_MODEL_PATH))
model.to(device);

In [7]:
# Adam over all model parameters (frozen ones included); the learning rate is
# written out here rather than taken from `lr` in the config cell.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# The scheduler is stepped once per 20-batch window (see below), so `patience`
# and the "Epoch N" in its messages count windows, not dataset epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, factor=0.5, verbose=True)

model.train()
for epoch in range(1):
    # Losses of the current reporting window.
    print_loss = []
    for i, (images, targets) in tqdm.tqdm(enumerate(train_loader), leave=False, position=0, total=len(train_loader)):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Called with targets in train mode, the model returns a dict of
        # losses; their sum is optimised.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        losses.backward()
        optimizer.step()
        # Gradients are cleared after the step, ready for the next batch.
        optimizer.zero_grad()
        
        print_loss.append(losses.item())
        # Every 20 batches: report the window mean and feed it to the scheduler.
        # Batches after the last full window are neither reported nor used.
        if (i + 1) % 20 == 0:
            mean_loss = np.mean(print_loss)
            print(f'Loss: {mean_loss:.7f}')
            scheduler.step(mean_loss)
            print_loss = []

  0%|          | 20/5126 [00:28<2:03:01,  1.45s/it]

Loss: 0.7536344


  1%|          | 40/5126 [01:01<2:17:04,  1.62s/it]

Loss: 0.3375057


  1%|          | 60/5126 [01:30<1:56:33,  1.38s/it]

Loss: 0.3029691


  2%|▏         | 80/5126 [01:56<1:59:28,  1.42s/it]

Loss: 0.2605885


  2%|▏         | 100/5126 [02:25<1:43:26,  1.23s/it]

Loss: 0.2530537


  2%|▏         | 120/5126 [02:54<2:19:53,  1.68s/it]

Loss: 0.2245365


  3%|▎         | 140/5126 [03:22<2:00:41,  1.45s/it]

Loss: 0.2399407


  3%|▎         | 160/5126 [03:49<1:35:21,  1.15s/it]

Loss: 0.2170250


  4%|▎         | 180/5126 [04:21<2:08:13,  1.56s/it]

Loss: 0.2389284


  4%|▍         | 200/5126 [04:50<1:51:34,  1.36s/it]

Loss: 0.2080230


  4%|▍         | 220/5126 [05:16<1:48:35,  1.33s/it]

Loss: 0.2123055


  5%|▍         | 240/5126 [05:44<1:54:04,  1.40s/it]

Loss: 0.2302674


  5%|▌         | 260/5126 [06:11<1:34:17,  1.16s/it]

Loss: 0.2124166


  5%|▌         | 280/5126 [06:39<1:45:10,  1.30s/it]

Loss: 0.2249233


  6%|▌         | 300/5126 [07:06<1:59:17,  1.48s/it]

Loss: 0.2120188


  6%|▌         | 320/5126 [07:33<1:52:26,  1.40s/it]

Loss: 0.2301087


  7%|▋         | 340/5126 [08:01<1:38:48,  1.24s/it]

Loss: 0.2056716


  7%|▋         | 360/5126 [08:26<1:50:14,  1.39s/it]

Loss: 0.1907776


  7%|▋         | 380/5126 [08:54<2:09:36,  1.64s/it]

Loss: 0.2165954


  8%|▊         | 400/5126 [09:21<1:50:41,  1.41s/it]

Loss: 0.2017302


  8%|▊         | 420/5126 [09:50<1:55:31,  1.47s/it]

Loss: 0.2177099


  9%|▊         | 440/5126 [10:19<2:02:08,  1.56s/it]

Loss: 0.2045258


  9%|▉         | 460/5126 [10:46<1:49:59,  1.41s/it]

Loss: 0.1865671


  9%|▉         | 480/5126 [11:15<1:55:51,  1.50s/it]

Loss: 0.2142709


 10%|▉         | 500/5126 [11:45<2:12:00,  1.71s/it]

Loss: 0.1929740


 10%|█         | 520/5126 [12:17<2:28:36,  1.94s/it]

Loss: 0.2000619


 11%|█         | 540/5126 [12:44<1:44:36,  1.37s/it]

Loss: 0.1886075


 11%|█         | 560/5126 [13:12<1:39:33,  1.31s/it]

Loss: 0.1821805


 11%|█▏        | 580/5126 [13:44<1:51:02,  1.47s/it]

Loss: 0.1892920


 12%|█▏        | 600/5126 [14:15<1:51:47,  1.48s/it]

Loss: 0.1942454


 12%|█▏        | 620/5126 [14:41<1:48:49,  1.45s/it]

Loss: 0.2019326


 12%|█▏        | 640/5126 [15:11<2:12:24,  1.77s/it]

Loss: 0.1779278


 13%|█▎        | 660/5126 [16:19<2:16:08,  1.83s/it]

Loss: 0.1880165


 13%|█▎        | 680/5126 [16:48<1:43:36,  1.40s/it]

Loss: 0.1968753


 14%|█▎        | 700/5126 [17:16<1:43:24,  1.40s/it]

Loss: 0.1929961


 14%|█▍        | 720/5126 [17:43<1:46:15,  1.45s/it]

Loss: 0.2128598


 14%|█▍        | 740/5126 [18:10<1:36:57,  1.33s/it]

Loss: 0.2047363


 15%|█▍        | 760/5126 [18:36<1:30:38,  1.25s/it]

Loss: 0.1857451


 15%|█▌        | 780/5126 [19:04<1:38:39,  1.36s/it]

Loss: 0.2086233


 16%|█▌        | 800/5126 [19:33<1:45:53,  1.47s/it]

Loss: 0.1888790


 16%|█▌        | 820/5126 [19:58<1:36:56,  1.35s/it]

Loss: 0.1912376


 16%|█▋        | 840/5126 [20:28<1:42:13,  1.43s/it]

Loss: 0.2075391


 17%|█▋        | 860/5126 [20:56<1:58:13,  1.66s/it]

Loss: 0.2116645


 17%|█▋        | 880/5126 [21:23<1:30:58,  1.29s/it]

Loss: 0.2307003


 18%|█▊        | 900/5126 [21:49<1:22:24,  1.17s/it]

Loss: 0.1923443


 18%|█▊        | 920/5126 [22:16<1:25:43,  1.22s/it]

Loss: 0.1841879


 18%|█▊        | 940/5126 [22:46<1:44:09,  1.49s/it]

Loss: 0.2011658


 19%|█▊        | 960/5126 [23:12<1:40:53,  1.45s/it]

Loss: 0.2029260


 19%|█▉        | 980/5126 [23:37<1:22:36,  1.20s/it]

Loss: 0.2010378


 20%|█▉        | 1000/5126 [24:06<1:26:52,  1.26s/it]

Loss: 0.1784599


 20%|█▉        | 1020/5126 [24:32<1:29:58,  1.31s/it]

Loss: 0.1754653


 20%|██        | 1040/5126 [24:59<1:25:55,  1.26s/it]

Loss: 0.1876039


 21%|██        | 1060/5126 [25:26<1:27:57,  1.30s/it]

Loss: 0.2050001


 21%|██        | 1080/5126 [25:53<1:37:52,  1.45s/it]

Loss: 0.1925934


 21%|██▏       | 1100/5126 [26:19<1:25:23,  1.27s/it]

Loss: 0.1996009


 22%|██▏       | 1120/5126 [26:45<1:33:20,  1.40s/it]

Loss: 0.1825033


 22%|██▏       | 1140/5126 [27:11<1:32:31,  1.39s/it]

Loss: 0.2081532


 23%|██▎       | 1160/5126 [27:38<1:32:11,  1.39s/it]

Loss: 0.1898716


 23%|██▎       | 1180/5126 [28:05<1:24:00,  1.28s/it]

Loss: 0.1988452


 23%|██▎       | 1200/5126 [28:32<1:21:31,  1.25s/it]

Loss: 0.1817727


 24%|██▍       | 1220/5126 [29:02<1:48:12,  1.66s/it]

Loss: 0.1910206


 24%|██▍       | 1240/5126 [29:28<1:23:33,  1.29s/it]

Loss: 0.1712056


 25%|██▍       | 1260/5126 [29:56<1:31:08,  1.41s/it]

Loss: 0.1594298


 25%|██▍       | 1280/5126 [30:21<1:25:13,  1.33s/it]

Loss: 0.1852978


 25%|██▌       | 1300/5126 [30:50<1:24:17,  1.32s/it]

Loss: 0.1850690


 26%|██▌       | 1320/5126 [31:15<1:13:45,  1.16s/it]

Loss: 0.1717704


 26%|██▌       | 1340/5126 [31:42<1:19:49,  1.26s/it]

Loss: 0.1748163


 27%|██▋       | 1360/5126 [32:09<1:30:05,  1.44s/it]

Loss: 0.1845105


 27%|██▋       | 1380/5126 [32:39<1:33:46,  1.50s/it]

Loss: 0.2028499


 27%|██▋       | 1400/5126 [33:05<1:23:09,  1.34s/it]

Loss: 0.1772350


 28%|██▊       | 1420/5126 [33:32<1:23:21,  1.35s/it]

Loss: 0.1822123


 28%|██▊       | 1440/5126 [34:00<1:41:42,  1.66s/it]

Loss: 0.1955287


 28%|██▊       | 1460/5126 [34:26<1:19:48,  1.31s/it]

Loss: 0.1930706


 29%|██▉       | 1480/5126 [34:55<1:27:44,  1.44s/it]

Loss: 0.1953443


 29%|██▉       | 1500/5126 [35:21<1:34:34,  1.57s/it]

Loss: 0.1727634


 30%|██▉       | 1520/5126 [35:50<1:46:06,  1.77s/it]

Loss: 0.1988534


 30%|███       | 1540/5126 [36:16<1:13:55,  1.24s/it]

Loss: 0.1857577


 30%|███       | 1560/5126 [36:45<1:17:08,  1.30s/it]

Loss: 0.1873780


 31%|███       | 1580/5126 [37:14<1:19:43,  1.35s/it]

Loss: 0.1712584


 31%|███       | 1600/5126 [37:40<1:10:18,  1.20s/it]

Loss: 0.1723622


 32%|███▏      | 1620/5126 [38:07<1:16:35,  1.31s/it]

Loss: 0.1671536


 32%|███▏      | 1640/5126 [38:34<1:07:39,  1.16s/it]

Loss: 0.1721763


 32%|███▏      | 1660/5126 [39:03<1:22:21,  1.43s/it]

Loss: 0.2007435


 33%|███▎      | 1680/5126 [39:29<1:22:36,  1.44s/it]

Loss: 0.1819250
Epoch    84: reducing learning rate of group 0 to 1.5000e-04.


 33%|███▎      | 1700/5126 [39:56<1:17:18,  1.35s/it]

Loss: 0.1645028


 34%|███▎      | 1720/5126 [40:23<1:15:24,  1.33s/it]

Loss: 0.1676641


 34%|███▍      | 1740/5126 [40:49<1:12:59,  1.29s/it]

Loss: 0.1650101


 34%|███▍      | 1760/5126 [41:13<1:01:41,  1.10s/it]

Loss: 0.1816830


 35%|███▍      | 1780/5126 [41:40<1:08:41,  1.23s/it]

Loss: 0.1771905


 35%|███▌      | 1800/5126 [42:08<1:14:47,  1.35s/it]

Loss: 0.1659345


 36%|███▌      | 1820/5126 [42:34<1:08:32,  1.24s/it]

Loss: 0.1761690


 36%|███▌      | 1840/5126 [42:59<1:12:03,  1.32s/it]

Loss: 0.1801571


 36%|███▋      | 1860/5126 [43:27<1:13:24,  1.35s/it]

Loss: 0.1842524


 37%|███▋      | 1880/5126 [43:55<1:15:06,  1.39s/it]

Loss: 0.1659903


 37%|███▋      | 1900/5126 [44:22<1:18:05,  1.45s/it]

Loss: 0.1873336


 37%|███▋      | 1920/5126 [44:49<1:05:34,  1.23s/it]

Loss: 0.1554004


 38%|███▊      | 1940/5126 [45:15<58:56,  1.11s/it]  

Loss: 0.1740135


 38%|███▊      | 1960/5126 [45:42<1:16:29,  1.45s/it]

Loss: 0.1785432


 39%|███▊      | 1980/5126 [46:07<1:14:21,  1.42s/it]

Loss: 0.2102173


 39%|███▉      | 2000/5126 [46:34<1:12:47,  1.40s/it]

Loss: 0.1882339


 39%|███▉      | 2020/5126 [46:59<1:01:09,  1.18s/it]

Loss: 0.1841723


 40%|███▉      | 2040/5126 [47:25<1:10:02,  1.36s/it]

Loss: 0.1796906


 40%|████      | 2060/5126 [47:51<58:51,  1.15s/it]  

Loss: 0.1642413


 41%|████      | 2080/5126 [48:19<1:05:14,  1.29s/it]

Loss: 0.1973367


 41%|████      | 2100/5126 [48:45<1:12:03,  1.43s/it]

Loss: 0.1586267


 41%|████▏     | 2120/5126 [49:13<1:13:19,  1.46s/it]

Loss: 0.1590215


 42%|████▏     | 2140/5126 [49:40<1:11:57,  1.45s/it]

Loss: 0.1696211


 42%|████▏     | 2160/5126 [50:09<1:06:01,  1.34s/it]

Loss: 0.1870134


 43%|████▎     | 2180/5126 [50:39<1:15:16,  1.53s/it]

Loss: 0.1683654


 43%|████▎     | 2200/5126 [51:05<1:02:51,  1.29s/it]

Loss: 0.1827730


 43%|████▎     | 2220/5126 [51:33<1:09:20,  1.43s/it]

Loss: 0.1858255


 44%|████▎     | 2240/5126 [52:01<1:00:36,  1.26s/it]

Loss: 0.1777642


 44%|████▍     | 2260/5126 [52:27<1:01:11,  1.28s/it]

Loss: 0.1708846


 44%|████▍     | 2280/5126 [52:53<1:03:36,  1.34s/it]

Loss: 0.1917830


 45%|████▍     | 2300/5126 [53:19<57:04,  1.21s/it]  

Loss: 0.1871124


 45%|████▌     | 2320/5126 [53:48<1:18:47,  1.68s/it]

Loss: 0.1887963


 46%|████▌     | 2340/5126 [54:14<56:46,  1.22s/it]  

Loss: 0.1588981
Epoch   117: reducing learning rate of group 0 to 7.5000e-05.


 46%|████▌     | 2360/5126 [54:41<1:05:00,  1.41s/it]

Loss: 0.1592105


 46%|████▋     | 2380/5126 [55:06<57:24,  1.25s/it]  

Loss: 0.1836422


 47%|████▋     | 2400/5126 [55:33<1:07:24,  1.48s/it]

Loss: 0.1778867


 47%|████▋     | 2420/5126 [55:59<58:01,  1.29s/it]  

Loss: 0.1636289


 48%|████▊     | 2440/5126 [56:28<1:05:24,  1.46s/it]

Loss: 0.1622680


 48%|████▊     | 2460/5126 [56:53<53:34,  1.21s/it]  

Loss: 0.1638872


 48%|████▊     | 2480/5126 [57:19<58:11,  1.32s/it]  

Loss: 0.1858202


 49%|████▉     | 2500/5126 [57:44<48:31,  1.11s/it]  

Loss: 0.1975802


 49%|████▉     | 2520/5126 [58:12<1:07:53,  1.56s/it]

Loss: 0.1589748


 50%|████▉     | 2540/5126 [58:39<49:25,  1.15s/it]  

Loss: 0.1659446


 50%|████▉     | 2560/5126 [59:09<1:00:49,  1.42s/it]

Loss: 0.1575736


 50%|█████     | 2580/5126 [59:35<1:00:11,  1.42s/it]

Loss: 0.1951937


 51%|█████     | 2600/5126 [1:00:03<56:02,  1.33s/it]

Loss: 0.1591574


 51%|█████     | 2620/5126 [1:00:33<56:21,  1.35s/it]  

Loss: 0.1762185


 52%|█████▏    | 2640/5126 [1:00:59<56:37,  1.37s/it]  

Loss: 0.1782446


 52%|█████▏    | 2660/5126 [1:01:27<57:20,  1.40s/it]  

Loss: 0.1828445


 52%|█████▏    | 2680/5126 [1:01:54<53:05,  1.30s/it]  

Loss: 0.1657821


 53%|█████▎    | 2700/5126 [1:02:21<1:00:14,  1.49s/it]

Loss: 0.1463116


 53%|█████▎    | 2720/5126 [1:02:46<47:51,  1.19s/it]  

Loss: 0.1705804


 53%|█████▎    | 2740/5126 [1:03:14<54:25,  1.37s/it]  

Loss: 0.1749822


 54%|█████▍    | 2760/5126 [1:03:40<54:48,  1.39s/it]

Loss: 0.1674075


 54%|█████▍    | 2780/5126 [1:04:06<51:11,  1.31s/it]

Loss: 0.1625945


 55%|█████▍    | 2800/5126 [1:04:36<53:14,  1.37s/it]  

Loss: 0.1796984


 55%|█████▌    | 2820/5126 [1:05:03<48:56,  1.27s/it]

Loss: 0.1648903


 55%|█████▌    | 2840/5126 [1:05:31<55:28,  1.46s/it]

Loss: 0.1791333


 56%|█████▌    | 2860/5126 [1:05:58<46:21,  1.23s/it]

Loss: 0.1462908


 56%|█████▌    | 2880/5126 [1:06:23<46:26,  1.24s/it]

Loss: 0.1568069


 57%|█████▋    | 2900/5126 [1:06:51<50:03,  1.35s/it]

Loss: 0.1780265


 57%|█████▋    | 2920/5126 [1:07:16<45:46,  1.24s/it]

Loss: 0.1569904


 57%|█████▋    | 2940/5126 [1:07:43<55:05,  1.51s/it]

Loss: 0.1633112


 58%|█████▊    | 2960/5126 [1:08:12<44:09,  1.22s/it]  

Loss: 0.1787292


 58%|█████▊    | 2980/5126 [1:08:38<53:00,  1.48s/it]

Loss: 0.1722740


 59%|█████▊    | 3000/5126 [1:09:03<46:44,  1.32s/it]

Loss: 0.1556489


 59%|█████▉    | 3020/5126 [1:09:31<51:01,  1.45s/it]

Loss: 0.1534291


 59%|█████▉    | 3040/5126 [1:09:58<49:52,  1.43s/it]

Loss: 0.1607682


 60%|█████▉    | 3060/5126 [1:10:28<57:16,  1.66s/it]  

Loss: 0.1482413


 60%|██████    | 3080/5126 [1:10:55<50:44,  1.49s/it]

Loss: 0.1658233


 60%|██████    | 3100/5126 [1:11:23<48:04,  1.42s/it]

Loss: 0.1589498


 61%|██████    | 3120/5126 [1:11:49<40:03,  1.20s/it]

Loss: 0.1619943


 61%|██████▏   | 3140/5126 [1:12:16<46:27,  1.40s/it]

Loss: 0.1574768


 62%|██████▏   | 3160/5126 [1:12:41<41:38,  1.27s/it]

Loss: 0.1665394


 62%|██████▏   | 3180/5126 [1:13:09<45:39,  1.41s/it]

Loss: 0.1777639


 62%|██████▏   | 3200/5126 [1:13:37<47:47,  1.49s/it]

Loss: 0.1528096


 63%|██████▎   | 3220/5126 [1:14:04<37:04,  1.17s/it]

Loss: 0.1515231


 63%|██████▎   | 3240/5126 [1:14:30<41:08,  1.31s/it]

Loss: 0.1746634


 64%|██████▎   | 3260/5126 [1:14:56<39:45,  1.28s/it]

Loss: 0.1580802


 64%|██████▍   | 3280/5126 [1:15:23<42:21,  1.38s/it]

Loss: 0.1374999


 64%|██████▍   | 3300/5126 [1:15:48<38:33,  1.27s/it]

Loss: 0.1611322


 65%|██████▍   | 3320/5126 [1:16:16<37:52,  1.26s/it]

Loss: 0.1494281


 65%|██████▌   | 3340/5126 [1:16:42<36:17,  1.22s/it]

Loss: 0.1707476


 66%|██████▌   | 3360/5126 [1:17:11<40:25,  1.37s/it]

Loss: 0.1592761


 66%|██████▌   | 3380/5126 [1:17:40<41:27,  1.42s/it]

Loss: 0.1745944


 66%|██████▋   | 3400/5126 [1:18:08<39:10,  1.36s/it]

Loss: 0.1654283


 67%|██████▋   | 3420/5126 [1:18:36<38:02,  1.34s/it]

Loss: 0.1660322


 67%|██████▋   | 3440/5126 [1:19:04<38:01,  1.35s/it]

Loss: 0.1672998


 67%|██████▋   | 3460/5126 [1:19:31<34:10,  1.23s/it]

Loss: 0.1746614


 68%|██████▊   | 3480/5126 [1:19:58<35:36,  1.30s/it]

Loss: 0.1554539


 68%|██████▊   | 3500/5126 [1:20:23<31:52,  1.18s/it]

Loss: 0.1845854


 69%|██████▊   | 3520/5126 [1:20:53<37:39,  1.41s/it]

Loss: 0.1630744


 69%|██████▉   | 3540/5126 [1:21:21<36:30,  1.38s/it]

Loss: 0.1606908


 69%|██████▉   | 3560/5126 [1:21:52<34:35,  1.33s/it]

Loss: 0.1667917


 70%|██████▉   | 3580/5126 [1:22:19<33:05,  1.28s/it]

Loss: 0.2117778


 70%|███████   | 3600/5126 [1:22:45<33:18,  1.31s/it]

Loss: 0.1637619


 71%|███████   | 3620/5126 [1:23:10<30:43,  1.22s/it]

Loss: 0.1601621


 71%|███████   | 3640/5126 [1:23:39<35:47,  1.45s/it]

Loss: 0.1550784


 71%|███████▏  | 3660/5126 [1:24:05<32:40,  1.34s/it]

Loss: 0.1729025


 72%|███████▏  | 3680/5126 [1:24:29<24:34,  1.02s/it]

Loss: 0.1559199


 72%|███████▏  | 3700/5126 [1:24:54<28:43,  1.21s/it]

Loss: 0.1614430
Epoch   185: reducing learning rate of group 0 to 3.7500e-05.


 73%|███████▎  | 3720/5126 [1:25:23<38:09,  1.63s/it]

Loss: 0.1784320


 73%|███████▎  | 3740/5126 [1:25:51<32:24,  1.40s/it]

Loss: 0.1607090


 73%|███████▎  | 3760/5126 [1:26:17<28:47,  1.26s/it]

Loss: 0.1527375


 74%|███████▎  | 3780/5126 [1:26:42<27:24,  1.22s/it]

Loss: 0.1770272


 74%|███████▍  | 3800/5126 [1:27:11<33:11,  1.50s/it]

Loss: 0.1796266


 75%|███████▍  | 3820/5126 [1:27:37<34:34,  1.59s/it]

Loss: 0.1672244


 75%|███████▍  | 3840/5126 [1:28:03<27:43,  1.29s/it]

Loss: 0.1811952


 75%|███████▌  | 3860/5126 [1:28:30<27:41,  1.31s/it]

Loss: 0.1572348


 76%|███████▌  | 3880/5126 [1:28:58<29:18,  1.41s/it]

Loss: 0.1641418


 76%|███████▌  | 3900/5126 [1:29:26<30:51,  1.51s/it]

Loss: 0.1521003


 76%|███████▋  | 3920/5126 [1:29:52<30:46,  1.53s/it]

Loss: 0.1716250


 77%|███████▋  | 3940/5126 [1:30:19<29:17,  1.48s/it]

Loss: 0.1657639


 77%|███████▋  | 3960/5126 [1:30:48<23:35,  1.21s/it]

Loss: 0.1586927


 78%|███████▊  | 3980/5126 [1:31:16<30:42,  1.61s/it]

Loss: 0.1635067


 78%|███████▊  | 4000/5126 [1:31:44<28:50,  1.54s/it]

Loss: 0.1574199


 78%|███████▊  | 4020/5126 [1:32:11<23:31,  1.28s/it]

Loss: 0.1592584


 79%|███████▉  | 4040/5126 [1:32:38<22:25,  1.24s/it]

Loss: 0.1682878


 79%|███████▉  | 4060/5126 [1:33:04<21:54,  1.23s/it]

Loss: 0.1727464


 80%|███████▉  | 4080/5126 [1:33:32<27:17,  1.57s/it]

Loss: 0.1880667


 80%|███████▉  | 4100/5126 [1:33:59<23:56,  1.40s/it]

Loss: 0.1874403


 80%|████████  | 4120/5126 [1:34:27<25:26,  1.52s/it]

Loss: 0.1713852
Epoch   206: reducing learning rate of group 0 to 1.8750e-05.


 81%|████████  | 4140/5126 [1:34:51<21:41,  1.32s/it]

Loss: 0.1772847


 81%|████████  | 4160/5126 [1:35:20<23:16,  1.45s/it]

Loss: 0.1893535


 82%|████████▏ | 4180/5126 [1:35:46<21:49,  1.38s/it]

Loss: 0.1650515


 82%|████████▏ | 4200/5126 [1:36:11<22:37,  1.47s/it]

Loss: 0.1589452


 82%|████████▏ | 4220/5126 [1:36:39<20:49,  1.38s/it]

Loss: 0.1453548


 83%|████████▎ | 4240/5126 [1:37:07<18:33,  1.26s/it]

Loss: 0.1686509


 83%|████████▎ | 4260/5126 [1:37:34<19:04,  1.32s/it]

Loss: 0.1740572


 83%|████████▎ | 4280/5126 [1:38:04<21:56,  1.56s/it]

Loss: 0.1495422


 84%|████████▍ | 4300/5126 [1:38:31<15:27,  1.12s/it]

Loss: 0.1816374


 84%|████████▍ | 4320/5126 [1:38:59<19:09,  1.43s/it]

Loss: 0.1590448


 85%|████████▍ | 4340/5126 [1:39:27<18:52,  1.44s/it]

Loss: 0.1848043


 85%|████████▌ | 4360/5126 [1:39:53<15:38,  1.23s/it]

Loss: 0.1469025


 85%|████████▌ | 4380/5126 [1:40:19<15:58,  1.28s/it]

Loss: 0.1698477


 86%|████████▌ | 4400/5126 [1:40:47<16:13,  1.34s/it]

Loss: 0.1563174


 86%|████████▌ | 4420/5126 [1:41:15<17:12,  1.46s/it]

Loss: 0.1701780


 87%|████████▋ | 4440/5126 [1:41:42<14:37,  1.28s/it]

Loss: 0.1608978


 87%|████████▋ | 4460/5126 [1:42:09<15:41,  1.41s/it]

Loss: 0.1580277


 87%|████████▋ | 4480/5126 [1:42:36<15:46,  1.47s/it]

Loss: 0.1517141


 88%|████████▊ | 4500/5126 [1:43:02<12:19,  1.18s/it]

Loss: 0.1703975


 88%|████████▊ | 4520/5126 [1:43:30<14:09,  1.40s/it]

Loss: 0.1608185


 89%|████████▊ | 4540/5126 [1:43:59<12:41,  1.30s/it]

Loss: 0.1714178
Epoch   227: reducing learning rate of group 0 to 9.3750e-06.


 89%|████████▉ | 4560/5126 [1:44:27<12:39,  1.34s/it]

Loss: 0.1468985


 89%|████████▉ | 4580/5126 [1:44:54<11:53,  1.31s/it]

Loss: 0.1515427


 90%|████████▉ | 4600/5126 [1:45:23<12:55,  1.47s/it]

Loss: 0.1549464


 90%|█████████ | 4620/5126 [1:45:51<11:09,  1.32s/it]

Loss: 0.1710814


 91%|█████████ | 4640/5126 [1:46:17<10:15,  1.27s/it]

Loss: 0.1805589


 91%|█████████ | 4660/5126 [1:46:45<10:26,  1.34s/it]

Loss: 0.1653263


 91%|█████████▏| 4680/5126 [1:47:13<10:23,  1.40s/it]

Loss: 0.1846912


 92%|█████████▏| 4700/5126 [1:47:41<09:59,  1.41s/it]

Loss: 0.1578781


 92%|█████████▏| 4720/5126 [1:48:10<11:30,  1.70s/it]

Loss: 0.1712740


 92%|█████████▏| 4740/5126 [1:48:36<08:18,  1.29s/it]

Loss: 0.1732364


 93%|█████████▎| 4760/5126 [1:49:01<07:44,  1.27s/it]

Loss: 0.1656310


 93%|█████████▎| 4780/5126 [1:49:30<08:55,  1.55s/it]

Loss: 0.1583676


 94%|█████████▎| 4800/5126 [1:49:58<06:37,  1.22s/it]

Loss: 0.1785679


 94%|█████████▍| 4820/5126 [1:50:24<06:25,  1.26s/it]

Loss: 0.1665169


 94%|█████████▍| 4840/5126 [1:50:52<06:27,  1.35s/it]

Loss: 0.1811399


 95%|█████████▍| 4860/5126 [1:51:20<06:53,  1.55s/it]

Loss: 0.1863207


 95%|█████████▌| 4880/5126 [1:51:49<06:20,  1.55s/it]

Loss: 0.1657516


 96%|█████████▌| 4900/5126 [1:52:15<04:28,  1.19s/it]

Loss: 0.1457749


 96%|█████████▌| 4920/5126 [1:52:45<05:16,  1.53s/it]

Loss: 0.1580007


 96%|█████████▋| 4940/5126 [1:53:11<04:21,  1.41s/it]

Loss: 0.1718912


 97%|█████████▋| 4960/5126 [1:53:39<03:57,  1.43s/it]

Loss: 0.1557122
Epoch   248: reducing learning rate of group 0 to 4.6875e-06.


 97%|█████████▋| 4980/5126 [1:54:05<03:20,  1.37s/it]

Loss: 0.1578837


 98%|█████████▊| 5000/5126 [1:54:32<02:51,  1.36s/it]

Loss: 0.1596596


 98%|█████████▊| 5020/5126 [1:55:01<02:18,  1.31s/it]

Loss: 0.1617564


 98%|█████████▊| 5040/5126 [1:55:28<01:56,  1.36s/it]

Loss: 0.1666776


 99%|█████████▊| 5060/5126 [1:55:54<01:38,  1.49s/it]

Loss: 0.1497626


 99%|█████████▉| 5080/5126 [1:56:23<01:05,  1.41s/it]

Loss: 0.1767452


 99%|█████████▉| 5100/5126 [1:56:51<00:32,  1.26s/it]

Loss: 0.1550209


100%|█████████▉| 5120/5126 [1:57:17<00:07,  1.26s/it]

Loss: 0.1575079


                                                     

In [9]:
# Save the trained detector weights; create the run folder first so a fresh
# machine does not lose the whole training run to a missing directory.
checkpoint_path = "C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/runs/segmentation_baseline/CP-last.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
with open(checkpoint_path, "wb") as fp:
    torch.save(model.state_dict(), fp)

In [14]:
# glob returns files in arbitrary order; sort so repeated runs process the
# test images (and list their predictions) in the same order.
test_images = sorted(glob.glob(os.path.join(data_path, 'test/*')))

In [15]:
THRESHOLD_SCORE = 0.93  # detections with a score at or below this are dropped
TRESHOLD_MASK = 0.05    # threshold applied to the predicted mask before contour extraction

# One entry per test image: {'file': path, 'nums': [{'box': 4 points, 'bbox': [x_min, y_min, x_max, y_max]}, ...]}
preds = []
model.eval()


for file in tqdm.tqdm(test_images, position=0, leave=False):

    img = Image.open(file).convert('RGB')
    img_tensor = my_transforms(img)
    with torch.no_grad():
        predictions = model([img_tensor.to(device)])
    prediction = predictions[0]

    pred = {'file': file, 'nums': []}

    for i in range(len(prediction['boxes'])):
        score = float(prediction['scores'][i].cpu())
        if score <= THRESHOLD_SCORE:
            continue

        x_min, y_min, x_max, y_max = map(int, prediction['boxes'][i].tolist())
        # Copy the mask to the CPU only for detections that are kept.
        mask = prediction['masks'][i][0, :, :].cpu().numpy()

        # findContours returns 2 or 3 values depending on the OpenCV version;
        # the contour list is always the second-to-last one.
        contours = cv2.findContours((mask > TRESHOLD_MASK).astype(np.uint8), 1, 1)[-2]
        # An empty mask has no contour at all: treat it like a failed fit.
        approx = simplify_contour(contours[0], n_corners=4) if len(contours) > 0 else None

        if approx is None:
            # No quadrilateral found: fall back to the bounding-box corners.
            x0, y0 = x_min, y_min
            x1, y1 = x_max, y_min
            x2, y2 = x_min, y_max
            x3, y3 = x_max, y_max
        else:
            # Plain Python ints: NumPy scalars are not JSON serializable
            # ("Object of type intc is not JSON serializable").
            x0, y0 = int(approx[0][0][0]), int(approx[0][0][1])
            x1, y1 = int(approx[1][0][0]), int(approx[1][0][1])
            x2, y2 = int(approx[2][0][0]), int(approx[2][0][1])
            x3, y3 = int(approx[3][0][0]), int(approx[3][0][1])

        points = [[x0, y0], [x2, y2], [x1, y1], [x3, y3]]

        pred['nums'].append({
            'box': points,
            'bbox': [x_min, y_min, x_max, y_max],
        })

    preds.append(pred)

    


 81%|████████  | 2548/3157 [05:50<01:38,  6.16it/s]

simplify_contour didnt coverege


100%|█████████▉| 3145/3157 [07:11<00:01,  7.51it/s]

simplify_contour didnt coverege


                                                   

TypeError: Object of type intc is not JSON serializable

In [28]:
# Show only the first few detections; the full list has one entry per test image.
preds[:5]

[{'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\0.jpg',
  'nums': [{'box': [[485, 563], [775, 627], [487, 628], [772, 562]],
    'bbox': [489, 560, 771, 628]}]},
 {'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\1.jpg',
  'nums': [{'box': [[220, 317], [135, 335], [136, 317], [219, 336]],
    'bbox': [136, 317, 219, 335]},
   {'box': [[511, 292], [588, 308], [511, 309], [583, 292]],
    'bbox': [512, 293, 586, 308]}]},
 {'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\10.jpg',
  'nums': [{'box': [[225, 285], [388, 314], [226, 316], [386, 281]],
    'bbox': [228, 281, 385, 316]}]},
 {'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\100.jpg',
  'nums': [{'box': [[537, 410], [320, 460], [326, 412], [531, 461]],
    'bbox': [323, 409, 534, 461]}]},
 {'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\1000.jpg',
  'nums': [{'box': [[767, 552], [578, 595], [581, 554], [762, 594]],
    '

In [47]:
# Cache the detections on disk.
# NOTE: the file is named 'test.json' but its content is a pickle, not JSON.
detections_dump_path = os.path.join(data_path, 'test.json')
with open(detections_dump_path, 'wb') as pickle_out:
    pickle.dump(preds, pickle_out)

In [4]:
# Reload the cached detections (a pickle stored under the name 'test.json').
# Only unpickle files produced by this notebook: pickle.load can execute arbitrary code.
detections_load_path = os.path.join(data_path, 'test.json')
with open(detections_load_path, 'rb') as pickle_in:
    a = pickle.load(pickle_in)

In [5]:
# Show only the first few reloaded detections rather than the whole list.
a[:5]

[{'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\0.jpg',
  'nums': [{'box': [[485, 563], [775, 627], [487, 628], [772, 562]],
    'bbox': [489, 560, 771, 628]}]},
 {'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\1.jpg',
  'nums': [{'box': [[220, 317], [135, 335], [136, 317], [219, 336]],
    'bbox': [136, 317, 219, 335]},
   {'box': [[511, 292], [588, 308], [511, 309], [583, 292]],
    'bbox': [512, 293, 586, 308]}]},
 {'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\10.jpg',
  'nums': [{'box': [[225, 285], [388, 314], [226, 316], [386, 281]],
    'bbox': [228, 281, 385, 316]}]},
 {'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\100.jpg',
  'nums': [{'box': [[537, 410], [320, 460], [326, 412], [531, 461]],
    'bbox': [323, 409, 534, 461]}]},
 {'file': 'C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/data/test\\1000.jpg',
  'nums': [{'box': [[767, 552], [578, 595], [581, 554], [762, 594]],
    '

In [6]:
# Location of a recognition checkpoint file (presumably the saved OCR model weights — confirm).
# NOTE(review): not referenced by any of the cells visible below; the commented-out
# load_state_dict call uses OCR_MODEL_PATH instead.
rec_model="C:/Users/krvet/Desktop/MADE/CV/HW2/contest02_data/runs/recognition_baseline/CP-last.pth"

In [58]:
class OCRDataset(Dataset):
    """Dataset of licence-plate crops with their text, built from annotation marks.

    Every annotated plate produces two samples:
      * ``boxed=False`` — the four annotated corner points, later rectified with
        ``four_point_transform``;
      * ``boxed=True``  — the axis-aligned bounding box around those points, later
        cut out with plain slicing.
    The second sample is used instead of a rotation augmentation, so the dataset is
    twice as large as the number of annotated plates.

    Args:
        marks: iterable of dicts with keys ``'file'`` and ``'nums'``; each element of
            ``'nums'`` has ``'box'`` (sequence of (x, y) points) and ``'text'``.
        img_folder: folder that ``'file'`` entries are joined onto.
        alphabet: string of known characters; a character is encoded as its position
            in this string plus one.
        transforms: optional callable applied to the cropped image.
    """

    def __init__(self, marks, img_folder, alphabet, transforms=None):
        ocr_marks = []
        for items in marks:
            file_path = items['file']
            for box in items['nums']:
                # Clip once: the annotations contain negative coordinates.
                points = np.clip(box['box'], 0, None)

                ocr_marks.append({
                    'file': file_path,
                    'box': points.tolist(),
                    'text': box['text'],
                    'boxed': False,
                })

                # Axis-aligned bounding box of the same points, stored as plain
                # Python numbers rather than numpy scalars.
                x0, y0 = points[:, :2].min(axis=0).tolist()
                x2, y2 = points[:, :2].max(axis=0).tolist()

                ocr_marks.append({
                    'file': file_path,
                    'box': [x0, y0, x2, y2],
                    'text': box['text'],
                    'boxed': True,
                })

        self.marks = ocr_marks
        self.img_folder = img_folder
        self.transforms = transforms
        self.alphabet = alphabet

    def __getitem__(self, idx):
        """Return a dict with keys ``'img'``, ``'text'``, ``'seq'`` and ``'seq_len'``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        item = self.marks[idx]
        img_path = os.path.join(self.img_folder, item["file"])
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None; fail here with the path
            # instead of with an opaque error further down.
            raise FileNotFoundError('Could not read image: {}'.format(img_path))

        if item['boxed']:
            x_min, y_min, x_max, y_max = item['box']
            img = img[y_min:y_max, x_min:x_max]
        else:
            # Coordinates were already clipped to be non-negative in __init__.
            points = np.array(item['box'])
            img = four_point_transform(img, points)

        text = item['text']
        # Position in the alphabet + 1. A character missing from the alphabet gives
        # find() == -1 and is therefore encoded as 0.
        seq = [self.alphabet.find(char) + 1 for char in text]
        seq_len = len(seq)

        if self.transforms is not None:
            img = self.transforms(img)

        output = {
            'img': img,
            'text': text,
            'seq': seq,
            'seq_len': seq_len
        }

        return output

    def __len__(self):
        """Number of samples (two per annotated plate)."""
        return len(self.marks)
    
    
class Resize(object):
    """Callable transform that resizes an image array to a fixed ``(width, height)``.

    ``size`` is handed to ``cv2.resize`` as ``dsize``.
    """

    def __init__(self, size=(320, 64)):
        self.size = size

    def __call__(self, img):
        """Return ``img`` resized to ``self.size``.

        The interpolation depends only on the width: cubic when the target is wider
        than the input, area-based otherwise.
        """
        source_width = img.shape[1]
        target_width, _ = self.size

        if target_width > source_width:
            method = cv2.INTER_CUBIC
        else:
            method = cv2.INTER_AREA

        return cv2.resize(img, dsize=self.size, interpolation=method)
    
# Preprocessing applied to every OCR crop: resize to 320x64 (width x height),
# then convert to a tensor.
_ocr_transform_steps = [
    Resize(size=(320, 64)),
    transforms.ToTensor(),
]
my_ocr_transforms = transforms.Compose(_ocr_transform_steps)

def get_vocab_from_marks(marks):
    """Build the character vocabulary from annotation marks.

    Characters are ranked by descending frequency; ties keep the order in which the
    characters first appear. The ranking is deterministic, so the same marks always
    give the same alphabet (iterating a ``set`` of strings, as done previously, can
    change order between interpreter runs).

    Args:
        marks: iterable of dicts with a ``'nums'`` list whose elements carry ``'text'``.

    Returns:
        Tuple ``(char_to_idx, idx_to_char, alphabet)`` where
        ``char_to_idx[c] == alphabet.find(c) + 1`` for every known character and
        ``idx_to_char`` is the inverse mapping. Index 0 is never assigned.
    """
    train_texts = [num['text'] for item in marks for num in item['nums']]

    counts = Counter(''.join(train_texts))
    # sorted() is stable, so equally frequent characters keep first-seen order.
    sorted_counts = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)

    alphabet = ''.join(char for char, _ in sorted_counts)
    char_to_idx = {char: idx + 1 for idx, char in enumerate(alphabet)}
    idx_to_char = {idx: char for char, idx in char_to_idx.items()}
    return char_to_idx, idx_to_char, alphabet

# The vocabulary comes from the training annotations only.
char_to_idx, idx_to_char, alphabet = get_vocab_from_marks(train_marks)

# Both splits share the image folder, the alphabet and the preprocessing.
_shared_ocr_dataset_kwargs = dict(
    img_folder=data_path,
    alphabet=alphabet,
    transforms=my_ocr_transforms,
)
train_ocr_dataset = OCRDataset(marks=train_marks, **_shared_ocr_dataset_kwargs)
val_ocr_dataset = OCRDataset(marks=val_marks, **_shared_ocr_dataset_kwargs)

def collate_fn_ocr(batch):
    """Collate OCR samples into one batch for ``torch.utils.data.DataLoader``.

    Args:
        batch: list of dicts with keys ``"img"``, ``"seq"``, ``"seq_len"`` and ``"text"``.

    Returns:
        Dict with
          * ``"image"``   — the images stacked along a new first dimension,
          * ``"seq"``     — all label sequences concatenated into one 1-D int tensor,
          * ``"seq_len"`` — int tensor with the length of each sequence,
          * ``"text"``    — list of the texts.
    """
    stacked_images = torch.stack([sample["img"] for sample in batch])

    flat_labels = []
    for sample in batch:
        flat_labels.extend(sample["seq"])
    label_counts = [sample["seq_len"] for sample in batch]

    return {
        "image": stacked_images,
        "seq": torch.Tensor(flat_labels).int(),
        "seq_len": torch.Tensor(label_counts).int(),
        "text": [sample["text"] for sample in batch],
    }

# Settings shared by both loaders.
# num_workers=0: for some reason the DataLoader hangs here when several workers are used.
_shared_ocr_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=0,
    collate_fn=collate_fn_ocr,
    timeout=0,
)

# Shuffled so that the two samples generated from one plate do not come back to back.
train_ocr_loader = DataLoader(
    train_ocr_dataset,
    drop_last=True,
    shuffle=True,
    **_shared_ocr_loader_kwargs
)

val_ocr_loader = DataLoader(
    val_ocr_dataset,
    drop_last=False,
    **_shared_ocr_loader_kwargs
)

gc.collect()

7

In [59]:
class FeatureExtractor(nn.Module):
    """Convolutional front end of the recogniser.

    Runs a pretrained ResNet-34 with its last two child modules removed, average-pools
    over the height, and uses a 1x1 convolution across the width axis to produce
    ``output_len`` columns.

    Args:
        input_size: ``(height, width)`` of the input images.
        output_len: number of columns produced by the projection.
    """

    def __init__(self, input_size=(64, 320), output_len=20):
        super().__init__()

        height, width = input_size
        backbone = models.resnet34(pretrained=True)
        conv_stages = list(backbone.children())[:-2]
        self.cnn = nn.Sequential(*conv_stages)

        # Kernel sizes are derived from the input size divided by 32.
        self.pool = nn.AvgPool2d(kernel_size=(height // 32, 1))
        self.proj = nn.Conv2d(width // 32, output_len, kernel_size=1)

        # Channel count is read from the last batch-norm of the last backbone block.
        final_block = self.cnn[-1][-1]
        self.num_output_features = final_block.bn2.num_features

    def apply_projection(self, x):
        """Change the width of a feature map with the 1x1 convolution.

        The width axis is moved into the channel position, convolved, and the axes are
        permuted again.

        Args:
            x: tensor shaped (B, C, H, W).

        Returns:
            Tensor shaped (B, H, C, W'), where W' is the projection's output size.
        """
        width_as_channels = x.permute(0, 3, 2, 1).contiguous()
        projected = self.proj(width_as_channels)
        return projected.permute(0, 2, 3, 1).contiguous()

    def forward(self, x):
        """Backbone, then height pooling, then width projection."""
        pooled = self.pool(self.cnn(x))
        return self.apply_projection(pooled)
    
class SequencePredictor(nn.Module):
    """GRU followed by a linear layer that scores ``num_classes`` classes per time step.

    Args:
        input_size: feature size of each time step fed to the GRU.
        hidden_size: GRU hidden size.
        num_layers: number of stacked GRU layers.
        num_classes: output size of the final linear layer.
        dropout: dropout passed to the GRU.
        bidirectional: whether the GRU is bidirectional.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, dropout=0.3, bidirectional=False):
        super().__init__()

        self.num_classes = num_classes
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # A bidirectional GRU emits twice as many features per step.
        directions = 2 if bidirectional else 1
        self.fc = nn.Linear(in_features=directions * hidden_size, out_features=num_classes)

    def _init_hidden_(self, batch_size):
        """Return a zero initial hidden state.

        Shape: (num_layers * num_directions, batch_size, hidden_size).
        """
        gru = self.rnn
        directions = 2 if gru.bidirectional else 1
        return torch.zeros(gru.num_layers * directions, batch_size, gru.hidden_size)

    def _prepare_features_(self, x):
        """Rearrange features into the layout the GRU expects.

        Args:
            x: 4-D tensor whose dimension 1 has size 1.

        Returns:
            Tensor with that dimension squeezed out and the remaining axes permuted so
            the last input axis comes first: (d0, 1, d2, d3) -> (d3, d0, d2).
        """
        return x.squeeze(1).permute(2, 0, 1)

    def forward(self, x):
        """Return per-step class scores shaped (steps, batch, num_classes)."""
        steps_first = self._prepare_features_(x)

        initial_state = self._init_hidden_(steps_first.size(1)).to(steps_first.device)
        rnn_out, _ = self.rnn(steps_first, initial_state)

        return self.fc(rnn_out)
    
class CRNN(nn.Module):
    """Text recogniser: ``FeatureExtractor`` followed by ``SequencePredictor``.

    Args:
        alphabet: string of recognisable characters; the predictor gets
            ``len(alphabet) + 1`` output classes. Defaults to the notebook-level
            ``alphabet`` as it was when this class was defined.
        cnn_input_size: ``(height, width)`` passed to the feature extractor.
        cnn_output_len: number of columns the feature extractor produces.
        rnn_hidden_size: hidden size of the recurrent layer.
        rnn_num_layers: number of recurrent layers.
        rnn_dropout: dropout of the recurrent layer.
        rnn_bidirectional: whether the recurrent layer is bidirectional.
    """

    def __init__(
        self,
        alphabet=alphabet,
        cnn_input_size=(64, 320),
        cnn_output_len=20,
        rnn_hidden_size=128,
        rnn_num_layers=2,
        rnn_dropout=0.3,
        rnn_bidirectional=False
    ):
        super().__init__()
        self.alphabet = alphabet

        self.features_extractor = FeatureExtractor(
            input_size=cnn_input_size,
            output_len=cnn_output_len,
        )

        # One output class per alphabet character plus one extra class.
        class_count = len(alphabet) + 1
        self.sequence_predictor = SequencePredictor(
            input_size=self.features_extractor.num_output_features,
            hidden_size=rnn_hidden_size,
            num_layers=rnn_num_layers,
            num_classes=class_count,
            dropout=rnn_dropout,
            bidirectional=rnn_bidirectional,
        )

    def forward(self, x):
        """Extract features from ``x`` and return the predictor's per-step class scores."""
        return self.sequence_predictor(self.features_extractor(x))

In [60]:
# Build the recogniser with its default hyper-parameters and move it to the target device.
# To resume from saved weights, load them before moving the model:
# crnn.load_state_dict(torch.load(OCR_MODEL_PATH))
crnn = CRNN()
crnn = crnn.to(device)

In [61]:
# Adam (AMSGrad variant) with a small weight decay for the recogniser.
optimizer = torch.optim.Adam(
    crnn.parameters(),
    lr=3e-4,
    weight_decay=1e-5,
    amsgrad=True,
)
# Halves the learning rate once the monitored value has not improved for 10 steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,
    patience=10,
    verbose=True,
)

In [62]:
# Train the recogniser with CTC loss for two epochs.
crnn.train()
for epoch in range(2):
    epoch_losses = []

    for batch in tqdm.tqdm(train_ocr_loader, total=len(train_ocr_loader), leave=False, position=0):
        images = batch["image"].to(device)
        seqs_gt = batch["seq"]
        seq_lens_gt = batch["seq_len"]

        # The loss is computed on the CPU, where the targets already live.
        seqs_pred = crnn(images).cpu()
        log_probs = F.log_softmax(seqs_pred, dim=2)
        # Every sample in the batch has the same number of prediction steps.
        seq_lens_pred = torch.Tensor([seqs_pred.size(0)] * seqs_pred.size(1)).int()

        loss = F.ctc_loss(
            log_probs=log_probs,  # (T, N, C)
            targets=seqs_gt,  # N, S or sum(target_lengths)
            input_lengths=seq_lens_pred,  # N
            target_lengths=seq_lens_gt # N
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

    if not epoch_losses:
        # With drop_last=True a dataset smaller than one batch yields no batches.
        raise RuntimeError("train_ocr_loader produced no batches")

    epoch_mean_loss = np.mean(epoch_losses)
    # The plateau scheduler only acts when it is stepped with the monitored value;
    # it was previously created but never stepped.
    scheduler.step(epoch_mean_loss)

    # Same output as before: index of the last batch, then the mean epoch loss.
    print(len(epoch_losses) - 1, epoch_mean_loss)

  0%|          | 2/10723 [00:00<15:19, 11.66it/s]    

10722 0.9542957681470827


                                                     

10722 0.18197177651135998




In [64]:
# Recognise the text of every detected plate and write the submission file.
test_marks = a
crnn.eval()
resizer = Resize()

file_name_result = []
plates_string_result = []

for item in tqdm.tqdm(test_marks, leave=False, position=0):

    img_path = item["file"]
    img = cv2.imread(img_path)

    results_to_sort = []
    for box in item['nums']:
        # Variant 1: plain axis-aligned crop of the detection box.
        x_min, y_min, x_max, y_max = box['bbox']
        img_bbox = resizer(img[y_min:y_max, x_min:x_max])
        # NOTE(review): this uses my_transforms, while the OCR datasets above are built
        # with my_ocr_transforms — confirm that both produce the same tensors.
        img_bbox = my_transforms(img_bbox).unsqueeze(0)

        # Variant 2: the quadrilateral rectified with a perspective transform.
        points = np.clip(np.array(box['box']), 0, None)
        img_polygon = resizer(four_point_transform(img, points))
        img_polygon = my_transforms(img_polygon).unsqueeze(0)

        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            scores_bbox = crnn(img_bbox.to(device)).cpu()
            scores_poly = crnn(img_polygon.to(device)).cpu()

        # Sum the scores of both variants before decoding. A dedicated name is used so
        # the notebook-level `preds` list of detections is not overwritten.
        summed_scores = scores_poly + scores_bbox
        num_text = decode(summed_scores, alphabet)[0]

        results_to_sort.append((x_min, num_text))

    # Plates are reported left to right.
    results = sorted(results_to_sort, key=lambda x: x[0])
    plates_string = ' '.join(text for _, text in results)

    # Take the last path component whichever separator was used ('/' or '\\'),
    # instead of relying on a fixed offset after the first '/test'.
    base_name = re.split(r'[\\/]', img_path)[-1]
    if base_name == '831.jpg':
        # This one file is reported under a .webp name.
        file_name = 'test/831.webp'
    else:
        file_name = 'test/' + base_name

    file_name_result.append(file_name)
    plates_string_result.append(plates_string)

df_submit = pd.DataFrame({'file_name': file_name_result, 'plates_string': plates_string_result})
df_submit.to_csv('submission.csv', index=False)

                                                   