In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Unpack the dataset archive. The empty target path makes extractall() write
# the members relative to the current working directory.
import zipfile
with zipfile.ZipFile('data_final.zip') as zf:
    zf.extractall('')

In [None]:
!python -m pip install -qU detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.6/index.html
!pip install -q torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install -q tensorflow==2.1.0
!pip install -q opencv-python==4.5.4.60
!pip install -q git+https://github.com/parlance/ctcdecode.git

[K     |████████████████████████████████| 5.6 MB 17 kB/s 
[K     |████████████████████████████████| 47 kB 3.9 MB/s 
[K     |████████████████████████████████| 74 kB 3.5 MB/s 
[K     |████████████████████████████████| 596 kB 34.4 MB/s 
[K     |████████████████████████████████| 112 kB 75.0 MB/s 
[?25h  Building wheel for fvcore (setup.py) ... [?25l[?25hdone
  Building wheel for antlr4-python3-runtime (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 708.0 MB 10 kB/s 
[K     |████████████████████████████████| 5.9 MB 51.9 MB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.6.0+cu101 which is incompatible.
torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.6.0+cu101 which is incompatible.[0m
[K     |████████████████████████████████| 421.8 MB 22 k

In [None]:
import json
import os
import sys
import warnings

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

import logging

# Silence Python warnings for the rest of the session. The blanket
# filterwarnings("ignore") call also covers the FutureWarning filter above it.
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torchvision
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from ctcdecode import CTCBeamDecoder

# Raise the "detectron2" logger threshold so only CRITICAL records get through.
logger = logging.getLogger("detectron2")
logger.setLevel(logging.CRITICAL)

# Input image directory and output JSON path. The script variant read both
# from the command line:
# TEST_IMAGES_PATH, SAVE_PATH = sys.argv[1:]
TEST_IMAGES_PATH, SAVE_PATH = '/content/data_final/train_segmentation/images', 'prediction.json'

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model checkpoints and language resources (paths relative to the working directory).
SEGM_MODEL_PATH = "segmentation_3_2.pth"
OCR_MODEL2_PATH = "recognition_V4_3.pth"
OCR_MODEL1_PATH = "CRNN_SOS_V2_3.pth"
RUS_TEXTS_PATH = 'rus_texts.txt'
ENG_TEXTS_PATH = 'eng_texts.txt'
RUS_BEAM_SEARCH_PATH = "rus_language_model.gz"
ENG_BEAM_SEARCH_PATH = "eng_language_model.gz"

# Known-word vocabularies, one entry per line of the text files. The encoding
# is given explicitly so the Cyrillic entries do not depend on the platform
# default (assumes the files are UTF-8 -- TODO confirm).
with open(RUS_TEXTS_PATH, 'r', encoding='utf-8') as f:
    RUS_TEXTS = set(f.read().split('\n'))

with open(ENG_TEXTS_PATH, 'r', encoding='utf-8') as f:
    ENG_TEXTS = set(f.read().split('\n'))

# OCR configuration: the recognised alphabet (class index = position + 1,
# index 0 is the CTC blank) and the fixed size every crop is resized to.
CONFIG_JSON = {
    "alphabet": ' !"%\'()+,-./0123456789:;=?DFGIJLNRSUVWY[]bdfghijlnqrstuvwyz|ЁАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШЩЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№',
    "image": {"width": 512, "height": 64}}


def get_contours_from_mask(mask, min_area=5):
    """Return the contours of ``mask`` whose area is at least ``min_area``.

    The mask is cast to ``uint8`` and passed to ``cv2.findContours`` with
    ``RETR_LIST`` / ``CHAIN_APPROX_SIMPLE``; the hierarchy output is ignored.
    """
    found, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    return [candidate for candidate in found if cv2.contourArea(candidate) >= min_area]


def get_larger_contour(contours):
    """Pick the contour with the largest area and return a simplified version.

    Returns ``None`` when ``contours`` is empty or no contour has an area
    strictly greater than zero. On ties the first contour with the largest
    area wins.
    """
    candidates = list(contours)
    areas = [cv2.contourArea(candidate) for candidate in candidates]
    if not areas:
        return None

    top_area = max(areas)
    if not top_area > 0:
        return None
    winner = candidates[areas.index(top_area)]

    # Smooth the mask outline: approximate the closed contour with a polygon,
    # using a tolerance of 1% of its perimeter.
    tolerance = 0.01 * cv2.arcLength(winner, True)
    return cv2.approxPolyDP(winner, tolerance, True)


class SEGMpredictor:
    """Mask R-CNN text segmentation that returns one contour per detected instance.

    Two ``DefaultPredictor`` objects are built from the same weights and differ
    only in ``INPUT.MIN_SIZE_TEST``: ``small_predictor`` uses 1260 and
    ``big_predictor`` uses 816 (both with ``MAX_SIZE_TEST`` 1680).
    """

    def __init__(self, model_path):
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
        cfg.MODEL.WEIGHTS = model_path
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.15
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
        cfg.INPUT.FORMAT = "BGR"
        cfg.TEST.DETECTIONS_PER_IMAGE = 1000

        # The same cfg object is re-tuned between the two constructions, so
        # the order of this loop matters.
        for attr_name, min_side in (("small_predictor", 1260), ("big_predictor", 816)):
            cfg.INPUT.MIN_SIZE_TEST = min_side
            cfg.INPUT.MAX_SIZE_TEST = 1680
            setattr(self, attr_name, DefaultPredictor(cfg))

    def __call__(self, img):
        """Segment ``img`` and return a list with one entry per predicted mask.

        Each entry is the result of ``get_larger_contour`` for that mask and
        may therefore be ``None``.
        """
        height, width = img.shape[0], img.shape[1]
        long_side, short_side = max(height, width), min(height, width)
        # Images more than twice as long as they are wide (or vice versa) go
        # to ``big_predictor``; everything else goes to ``small_predictor``.
        predictor = self.big_predictor if long_side > short_side * 2 else self.small_predictor
        masks = predictor(img)["instances"].pred_masks.cpu().numpy()
        return [get_larger_contour(get_contours_from_mask(mask)) for mask in masks]


def get_char_map(alphabet):
    """Map every symbol of ``alphabet`` to its 1-based position.

    Index 0 is reserved for the ``"<BLANK>"`` entry, which is added last.
    """
    char_map = {symbol: position for position, symbol in enumerate(alphabet, start=1)}
    char_map.update({"<BLANK>": 0})
    return char_map


class Tokenizer:
    """Greedy CTC decoder: turns sequences of class indices into strings."""

    def __init__(self, alphabet):
        self.char_map = get_char_map(alphabet)
        self.rev_char_map = {index: symbol for symbol, index in self.char_map.items()}

    def get_num_chars(self):
        """Number of classes, including the blank."""
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Decode each index sequence by dropping blanks and repeated neighbours.

        An index is kept when it is not the blank and differs from the index
        immediately before it in the same sequence.
        """
        blank = self.char_map["<BLANK>"]
        decoded = []
        for encoded in enc_word_list:
            kept = [
                self.rev_char_map[code]
                for pos, code in enumerate(encoded)
                if code != blank and (pos == 0 or code != encoded[pos - 1])
            ]
            decoded.append("".join(kept))
        return decoded


class TokenizerBS:
    """CTC beam-search decoder backed by ``ctcdecode.CTCBeamDecoder``.

    The decoder labels are derived from the ``alphabet`` argument (a blank
    placeholder followed by the alphabet in order), so label index ``i`` lines
    up with class index ``i`` produced by ``get_char_map``. Previously a
    hard-coded copy of the labels was duplicated in ``__init__`` and ``decode``
    and the ``alphabet`` argument was ignored for decoding.

    Args:
        alphabet: symbols recognised by the OCR model, in class-index order.
        beam_search_path: language-model file handed to the decoder as
            ``model_path``.
    """

    # Placeholder written at label index 0, the position of the CTC blank.
    BLANK_LABEL = "_"

    def __init__(self, alphabet, beam_search_path):
        self.char_map = get_char_map(alphabet)
        self.rev_char_map = {val: key for key, val in self.char_map.items()}
        self.labels = self.BLANK_LABEL + "".join(alphabet)

        self.beam_search = CTCBeamDecoder(
            self.labels,
            model_path=beam_search_path,
            alpha=0.8, beta=3.0,
            cutoff_top_n=40, cutoff_prob=1.0,
            beam_width=70, num_processes=2,
            blank_id=0, log_probs_input=True)

    def get_num_chars(self):
        """Number of classes, including the blank."""
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Return the top beam hypothesis of every batch item as a string.

        ``enc_word_list`` is passed unchanged to ``CTCBeamDecoder.decode``.
        Only the first ``out_lens[i][0]`` entries of beam 0 are read, and
        index 0 (the blank) is skipped.
        """
        beam_results, _scores, _timesteps, out_lens = self.beam_search.decode(enc_word_list)
        text_preds = []
        for sample_idx in range(len(beam_results)):
            best_len = int(out_lens[sample_idx][0])
            best_path = beam_results[sample_idx, 0, :best_len].tolist()
            text_preds.append("".join(self.labels[code] for code in best_path if code > 0))
        return text_preds


class Normalize:
    """Transform that converts an array to ``float32`` and divides it by 255."""

    def __call__(self, img):
        return img.astype(np.float32) / 255


class ToTensor:
    """Transform that wraps a NumPy array in a tensor via ``torch.from_numpy``."""

    def __call__(self, arr):
        return torch.from_numpy(arr)


class MoveChannels:
    """Transform that moves the channel axis between last and first position.

    With ``to_channels_first=True`` the last axis is moved to the front;
    otherwise the first axis is moved to the end.
    """

    def __init__(self, to_channels_first=True):
        self.to_channels_first = to_channels_first

    def __call__(self, image):
        source, destination = (-1, 0) if self.to_channels_first else (0, -1)
        return np.moveaxis(image, source, destination)


class ImageResize:
    """Resize an image onto a fixed-size black canvas, keeping its aspect ratio.

    The image is first scaled to the target height. If it then fits the target
    width, it is placed at the left edge. Otherwise it is scaled to the target
    width and centred vertically.

    Args:
        height: canvas height in pixels.
        width: canvas width in pixels.
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width

    def __call__(self, image):
        """Return a ``(height, width, channels)`` ``uint8`` canvas holding ``image``.

        ``image`` must be a 3-D array; its third dimension is used as the
        channel count of the canvas.
        """
        src_h, src_w, channels = image.shape
        canvas = np.zeros((self.height, self.width, channels), np.uint8)
        scale = self.height / src_h
        scaled_w = int(src_w * scale)
        if scaled_w <= self.width:
            # Very tall, narrow inputs would truncate to a zero width, which
            # cv2.resize rejects; keep at least one pixel.
            resized = cv2.resize(image, (max(1, scaled_w), self.height), interpolation=cv2.INTER_LINEAR)
            canvas[:, :resized.shape[1], :] = resized
        else:
            # Same guard for very wide, flat inputs whose height truncates to zero.
            new_height = max(1, int(self.height * (self.width / scaled_w)))
            resized = cv2.resize(image, (self.width, new_height), interpolation=cv2.INTER_LINEAR)
            # Centre vertically; any odd leftover row goes below the image.
            top = (self.height - new_height) // 2
            canvas[top:top + new_height, :resized.shape[1], :] = resized

        return canvas


def get_val_transforms(height, width):
    """Build the inference preprocessing pipeline.

    The steps run in this order: resize onto a ``height`` x ``width`` canvas,
    move channels first, scale to ``float32`` in [0, 1], convert to a tensor.
    """
    steps = [
        ImageResize(height, width),
        MoveChannels(to_channels_first=True),
        Normalize(),
        ToTensor(),
    ]
    return torchvision.transforms.Compose(steps)


def get_resnet34_new_backbone(pretrained=True):
    """Return a truncated ResNet-34 feature extractor with a fresh input conv.

    The network is a new 7x7, stride-1, padding-3 convolution (3 -> 64
    channels) followed by the ResNet-34 ``bn1``, ``relu``, ``maxpool``,
    ``layer1``, ``layer2`` and ``layer3`` modules.

    NOTE: ``pretrained`` is accepted but not used; ``resnet34`` is always
    created with ``pretrained=False``.
    """
    resnet = torchvision.models.resnet34(pretrained=False)
    stem = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=1, padding=3)
    stages = (resnet.bn1, resnet.relu, resnet.maxpool,
              resnet.layer1, resnet.layer2, resnet.layer3)
    return nn.Sequential(stem, *stages)


class BiLSTM(nn.Module):
    """Batch-first bidirectional LSTM that returns only the output sequence."""

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Run the LSTM on ``x`` and discard the final hidden/cell states."""
        return self.lstm(x)[0]


class CRNN(nn.Module):
    """Convolutional-recurrent text recogniser.

    Pipeline: truncated ResNet-34 features -> adaptive average pooling to
    ``(time_feature_count, time_feature_count)`` -> bidirectional LSTM ->
    two-layer classifier -> log-softmax over classes.
    """

    def __init__(self, number_class_symbols, time_feature_count=256, lstm_hidden=400, lstm_len=3, pretrained=True):
        super().__init__()
        self.feature_extractor = get_resnet34_new_backbone(pretrained=pretrained)
        self.avg_pool = nn.AdaptiveAvgPool2d((time_feature_count, time_feature_count))
        self.bilstm = BiLSTM(time_feature_count, lstm_hidden, lstm_len)
        head = [
            nn.Linear(lstm_hidden * 2, 300),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(300, number_class_symbols),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return log-probabilities with the first two axes swapped by the final permute."""
        features = self.feature_extractor(x)
        batch, channels, height, width = features.size()
        # Fold the channel and height axes into one before pooling.
        pooled = self.avg_pool(features.view(batch, channels * height, width))
        sequence = self.bilstm(pooled.transpose(1, 2))
        logits = self.classifier(sequence)
        log_probs = nn.functional.log_softmax(logits, dim=2)
        return log_probs.permute(1, 0, 2)


class CRNN_big(nn.Module):
    """Wider variant of the convolutional-recurrent text recogniser.

    Pipeline: truncated ResNet-34 features -> adaptive average pooling to
    ``(lstm_hidden, time_feature_count)`` -> bidirectional LSTM whose input
    size equals ``lstm_hidden`` -> two-layer classifier -> log-softmax.
    """

    def __init__(self, number_class_symbols, time_feature_count=256, lstm_hidden=512, lstm_len=3, pretrained=True):
        super().__init__()
        self.feature_extractor = get_resnet34_new_backbone(pretrained=pretrained)
        self.avg_pool = nn.AdaptiveAvgPool2d((lstm_hidden, time_feature_count))
        self.bilstm = BiLSTM(lstm_hidden, lstm_hidden, lstm_len)
        head = [
            nn.Linear(lstm_hidden * 2, 300),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(300, number_class_symbols),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return log-probabilities with the first two axes swapped by the final permute."""
        features = self.feature_extractor(x)
        batch, channels, height, width = features.size()
        # Fold the channel and height axes into one before pooling.
        pooled = self.avg_pool(features.view(batch, channels * height, width))
        sequence = self.bilstm(pooled.transpose(1, 2))
        logits = self.classifier(sequence)
        log_probs = nn.functional.log_softmax(logits, dim=2)
        return log_probs.permute(1, 0, 2)


def get_language(preds_grid):
    """Guess which language the decoded words are written in.

    Counts characters that belong to the Latin-only and Cyrillic-only letter
    lists below and returns ``'eng'`` when the Latin count is strictly larger,
    otherwise ``'rus'`` (including ties and empty input).
    """
    english_letters = 'DFGJLNRSUVWYbdfghijlnqrstuvwyz'
    russian_letters = 'ЁБГДЖЗИЛПФЦЧШЩЭЮЯбвгджзийлнптуфцчшщъыьэюяё'

    letters = [let for word in preds_grid for let in word]
    eng = sum(1 for let in letters if let in english_letters)
    rus = sum(1 for let in letters
              if let not in english_letters and let in russian_letters)
    return 'eng' if eng > rus else 'rus'


def predict(images, model1, model2, tokenizer, tokenizer_bs_rus, tokenizer_bs_eng, device):
    """Recognise a batch of crops with two CRNN models plus a beam-search decoder.

    If both CRNNs give the same output, that output is chosen; otherwise the
    word list for the detected language decides. Per item:

    * the greedy output of ``model1`` wins when it is a known word, when both
      greedy outputs agree, or when neither the ``model2`` output nor the beam
      output is a known word;
    * otherwise the ``model2`` output wins when it is a known word;
    * otherwise the beam-search output (computed from ``model1``) is used.

    The language is detected from the ``model1`` greedy outputs, and the final
    list is passed through ``language_sort``.
    """
    model1.eval()
    model2.eval()
    images = images.to(device)
    with torch.no_grad():
        log_probs_1 = model1(images).detach().cpu()
        # The beam-search decoder gets the model1 output with its first two axes swapped.
        out_for_bs = log_probs_1.transpose(0, 1)
        out_1 = torch.argmax(log_probs_1, -1).permute(1, 0).numpy()

        log_probs_2 = model2(images).detach().cpu()
        out_2 = torch.argmax(log_probs_2, -1).permute(1, 0).numpy()

    preds_grid1 = tokenizer.decode(out_1)
    preds_grid2 = tokenizer.decode(out_2)
    language = get_language(preds_grid1)

    is_english = language == 'eng'
    beam_tokenizer = tokenizer_bs_eng if is_english else tokenizer_bs_rus
    vocabulary = ENG_TEXTS if is_english else RUS_TEXTS
    preds_beam = beam_tokenizer.decode(out_for_bs)

    final_preds = []
    for first, second, beam in zip(preds_grid1, preds_grid2, preds_beam):
        second_known = second in vocabulary
        if first in vocabulary or first == second or not (second_known or beam in vocabulary):
            chosen = first
        elif second_known:
            chosen = second
        else:
            chosen = beam
        final_preds.append(chosen)

    return language_sort(final_preds, language)


class InferenceTransform:
    """Apply the validation transforms to each image and stack them into one batch."""

    def __init__(self, height, width):
        self.transforms = get_val_transforms(height, width)

    def __call__(self, images):
        """Transform every image in ``images`` and stack the results along a new first axis."""
        return torch.stack([self.transforms(image) for image in images], 0)


class OcrPredictor:
    """Text recogniser combining two CRNN checkpoints with beam-search decoders.

    Args:
        model1_path: checkpoint loaded into ``CRNN_big``.
        model2_path: checkpoint loaded into ``CRNN``.
        config: dict with an ``"alphabet"`` entry and an ``"image"`` entry
            holding ``"height"`` and ``"width"``.
        rus_beam_search_path: language model for the Russian beam decoder.
        eng_beam_search_path: language model for the English beam decoder.
        device: device string the models are moved to.
    """

    def __init__(self, model1_path, model2_path, config, rus_beam_search_path, eng_beam_search_path, device="cuda"):
        self.tokenizer_bs_rus = TokenizerBS(config["alphabet"], rus_beam_search_path)
        self.tokenizer_bs_eng = TokenizerBS(config["alphabet"], eng_beam_search_path)
        self.tokenizer = Tokenizer(config['alphabet'])
        self.device = torch.device(device)
        num_classes = self.tokenizer.get_num_chars()

        # map_location lets a checkpoint saved on another device load onto the
        # one requested here. NOTE: torch.load unpickles the file, so only
        # load checkpoints from a trusted source.
        self.model1 = CRNN_big(number_class_symbols=num_classes)
        self.model1.load_state_dict(torch.load(model1_path, map_location=self.device))
        self.model1.to(self.device)

        self.model2 = CRNN(number_class_symbols=num_classes)
        self.model2.load_state_dict(torch.load(model2_path, map_location=self.device))
        self.model2.to(self.device)

        self.transforms = InferenceTransform(
            height=config["image"]["height"],
            width=config["image"]["width"])

    def __call__(self, images):
        """Recognise text in one image or a list/tuple of images.

        Returns a list of strings, one per input image; a single ``np.ndarray``
        also yields a one-element list. An empty list or tuple returns ``[]``
        without running the models.

        Raises:
            TypeError: if ``images`` is not an ndarray, list or tuple.
        """
        if isinstance(images, np.ndarray):
            images = [images]
        elif not isinstance(images, (list, tuple)):
            raise TypeError(
                f"Input must contain np.ndarray, "
                f"tuple or list, found {type(images)}.")

        # Nothing to recognise: stacking an empty batch would fail.
        if len(images) == 0:
            return []

        batch = self.transforms(images)
        return predict(batch, self.model1, self.model2, self.tokenizer, self.tokenizer_bs_rus,
                       self.tokenizer_bs_eng, self.device)


def crop_img_by_polygon(img, polygon):
    """Cut the bounding rectangle of ``polygon`` out of ``img`` and mask the polygon.

    The polygon is shifted so that its minimum coordinates become the origin
    of the crop, drawn filled onto a single-channel mask, and that mask is
    applied to the crop with ``cv2.bitwise_and``.
    """
    points = np.array(polygon)
    x, y, w, h = cv2.boundingRect(points)
    patch = img[y: y + h, x: x + w].copy()

    # Express the polygon in the coordinate frame of the cropped patch.
    local_points = points - points.min(axis=0)
    mask = np.zeros(patch.shape[:2], np.uint8)
    cv2.drawContours(mask, [local_points], -1, (255, 255, 255), -1, cv2.LINE_AA)
    return cv2.bitwise_and(patch, patch, mask=mask)


def language_sort(pred_texts, language):
    """Filter CRNN predictions according to the detected language.

    For ``language == 'eng'`` the characters listed in ``russian_letters`` are
    dropped and every remaining character is replaced through the
    ``english_conv`` table. For any other language the characters listed in
    ``english_letters`` are dropped.

    The non-English branch used to build the filtered list and then return
    the unfiltered input; it now returns the filtered texts.

    Args:
        pred_texts: iterable of predicted strings.
        language: ``'eng'`` or any other value (treated as Russian).

    Returns:
        A new list with one filtered string per input string.
    """
    english_letters = 'DFGJLNRSUVWYbdfghijlnqrstuvwyz'
    russian_letters = 'ЁБГДЖЗИЛПФЦЧШЩЭЮЯбгджзийлнпуфцчшщъыьэюяё'

    if language == 'eng':
        english_conv = {' ': ' ', '!': '!', '"': '"', '%': '%', "'": "'", '(': '(', ')': ')', '+': '+', ',': ',',
                        '-': '-', '.': '.', '/': '/', '0': '0', '1': '1', '2': '2', '3': '3', '4': '4', '5': '5',
                        '6': '6', '7': '7', '8': '8', '9': '9', ':': ':', ';': ';', '=': '=', '?': '?', 'D': 'D',
                        'F': 'F', 'G': 'G', 'I': 'I', 'J': 'J', 'L': 'L', 'N': 'N', 'R': 'R', 'S': 'S', 'U': 'U',
                        'V': 'V', 'W': 'W', 'Y': 'Y', '[': '[', ']': ']', '_': '_', 'b': 'b', 'd': 'd', 'f': 'f',
                        'g': 'g', 'h': 'h', 'i': 'i', 'j': 'j', 'l': 'l', 'n': 'n', 'q': 'q', 'r': 'r', 's': 's',
                        't': 't', 'u': 'u', 'v': 'v', 'w': 'w', 'y': 'y', 'z': 'z', '|': '|', 'Ё': 'Ё', 'А': 'A',
                        'Б': 'Б', 'В': 'B', 'Г': 'Г', 'Д': 'Д', 'Е': 'E', 'Ж': 'Ж', 'З': 'З', 'И': 'И', 'К': 'K',
                        'Л': 'Л', 'М': 'M', 'Н': 'H', 'О': 'O', 'П': 'П', 'Р': 'P', 'С': 'C', 'Т': 'T', 'У': 'Y',
                        'Ф': 'Ф', 'Х': 'X', 'Ц': 'Ц', 'Ч': 'Ч', 'Ш': 'Ш', 'Щ': 'Щ', 'Э': 'Э', 'Ю': 'Ю', 'Я': 'Я',
                        'а': 'a', 'б': 'б', 'в': 'B', 'г': 'г', 'д': 'д', 'е': 'e', 'ж': 'ж', 'з': 'з', 'и': 'и',
                        'й': 'й', 'к': 'k', 'л': 'л', 'м': 'м', 'н': 'н', 'о': 'o', 'п': 'п', 'р': 'p', 'с': 'c',
                        'т': 'T', 'у': 'y', 'ф': 'ф', 'х': 'x', 'ц': 'ц', 'ч': 'ч', 'ш': 'ш', 'щ': 'щ', 'ъ': 'ъ',
                        'ы': 'ы', 'ь': 'ь', 'э': 'э', 'ю': 'ю', 'я': 'я', 'ё': 'ё', '№': '№'}
        return [''.join(english_conv[x] for x in word if x not in russian_letters)
                for word in pred_texts]

    return [''.join(x for x in word if x not in english_letters)
            for word in pred_texts]


class PiepleinePredictor:
    """End-to-end pipeline: segment text regions, crop them, recognise the text.

    Args:
        segm_model_path: weights for the segmentation model.
        ocr_model1_path: first OCR checkpoint.
        ocr_model2_path: second OCR checkpoint.
        ocr_config: OCR configuration dict, forwarded to ``OcrPredictor``
            (previously ignored in favour of the global ``CONFIG_JSON``).
        rus_beam_search_path: Russian language-model path.
        eng_beam_search_path: English language-model path.
    """

    def __init__(self, segm_model_path, ocr_model1_path, ocr_model2_path, ocr_config, rus_beam_search_path,
                 eng_beam_search_path):
        self.segm_predictor = SEGMpredictor(model_path=segm_model_path)
        self.ocr_predictor = OcrPredictor(
            model1_path=ocr_model1_path,
            model2_path=ocr_model2_path,
            config=ocr_config,
            rus_beam_search_path=rus_beam_search_path,
            eng_beam_search_path=eng_beam_search_path
        )

    def __call__(self, img):
        """Return ``{'predictions': [...]}`` for one image.

        Each prediction holds the polygon as a list of ``[x, y]`` integer
        pairs and the recognised text. Regions whose text is empty are
        omitted. When segmentation yields no usable contour the OCR stage is
        skipped and the predictions list is empty.
        """
        output = {'predictions': []}
        contours = [contour for contour in self.segm_predictor(img) if contour is not None]
        if not contours:
            return output

        crops = [crop_img_by_polygon(img, contour) for contour in contours]
        pred_texts = self.ocr_predictor(crops)
        for contour, text in zip(contours, pred_texts):
            if len(text):
                output['predictions'].append(
                    {'polygon': [[int(pt[0][0]), int(pt[0][1])] for pt in contour],
                     'text': text})

        return output


# def main():
#     pipeline_predictor = PiepleinePredictor(
#         segm_model_path=SEGM_MODEL_PATH,
#         ocr_model1_path=OCR_MODEL1_PATH,
#         ocr_model2_path=OCR_MODEL2_PATH,
#         ocr_config=CONFIG_JSON,
#         rus_beam_search_path=RUS_BEAM_SEARCH_PATH,
#         eng_beam_search_path=ENG_BEAM_SEARCH_PATH
#     )

#     pred_data = {}
#     for img_name in tqdm(os.listdir(TEST_IMAGES_PATH)):
#         image = cv2.imread(os.path.join(TEST_IMAGES_PATH, img_name))
#         pred_data[img_name] = pipeline_predictor(image)

#     with open(SAVE_PATH, "w") as f:
#         json.dump(pred_data, f)


# if __name__ == "__main__":
#     main()


In [None]:
from matplotlib import pyplot as plt

def get_image_visualization(img, pred_data, fontpath, font_koef=50):
    """Render predictions as a side-by-side image.

    The left half is a copy of ``img`` with every predicted polygon outlined
    in green. The right half is a white canvas of the same size with each
    predicted text written at the top-left corner of its polygon's bounding box.

    The input image is no longer drawn on in place, and the font size is
    clamped to at least 1 so small images do not request a zero-size font.

    Args:
        img: image array; its first two dimensions are used as height and width.
        pred_data: dict with a ``'predictions'`` list of
            ``{'polygon': ..., 'text': ...}`` entries.
        fontpath: font file passed to ``ImageFont.truetype``.
        font_koef: the font size is the image height divided by this value.

    Returns:
        The two panels concatenated horizontally (along axis 1).
    """
    img_h, img_w = img.shape[:2]
    font = ImageFont.truetype(fontpath, max(1, int(img_h / font_koef)))
    text_canvas = Image.new('RGB', (img_w, img_h), (255, 255, 255))
    draw = ImageDraw.Draw(text_canvas)

    annotated = img.copy()
    for prediction in pred_data['predictions']:
        outline = np.array([prediction['polygon']])
        cv2.drawContours(annotated, outline, -1, (0, 255, 0), 2)
        x, y, _, _ = cv2.boundingRect(outline)
        draw.text((x, y), prediction['text'], fill=0, font=font)

    return np.concatenate((annotated, np.array(text_canvas)), axis=1)

In [None]:
# Build the full segmentation + OCR pipeline once; constructing it loads the
# segmentation weights, both OCR checkpoints and both beam-search decoders.
pipeline_predictor = PiepleinePredictor(
    segm_model_path=SEGM_MODEL_PATH,
    ocr_model1_path=OCR_MODEL1_PATH,
    ocr_model2_path=OCR_MODEL2_PATH,
    ocr_config=CONFIG_JSON,
    rus_beam_search_path=RUS_BEAM_SEARCH_PATH,
    eng_beam_search_path=ENG_BEAM_SEARCH_PATH
)

In [None]:
# Run the pipeline on one sample image and show it next to the recognised text.
sample_path = os.path.join(TEST_IMAGES_PATH, '2_2_eng.jpg')
img = cv2.imread(sample_path)
if img is None:
    # cv2.imread reports an unreadable or missing file by returning None.
    raise FileNotFoundError(f"Could not read image: {sample_path}")
output = pipeline_predictor(img)

vis = get_image_visualization(img, output, 'font.otf')

plt.figure(figsize=(40, 40))
# cv2.imread yields BGR channel order while matplotlib expects RGB.
plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
plt.show()

Output hidden; open in https://colab.research.google.com to view.