# In [None]:
import json
import os
import sys
import warnings

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

import logging

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torchvision
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from ctcdecode import CTCBeamDecoder

# Raise the threshold of the "detectron2" logger so that only CRITICAL
# records are emitted by it.
logger = logging.getLogger("detectron2")
logger.setLevel(logging.CRITICAL)


# Command-line contract: exactly two positional arguments, the directory with
# the input images and the path of the JSON file to write. Any other number of
# arguments makes this unpacking raise ValueError at import time.
TEST_IMAGES_PATH, SAVE_PATH = sys.argv[1:]


# GPU when torch can see one, CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set of known texts, one per line, used by `predict` to choose between the
# greedy and the beam-search decoding. The encoding is pinned so the result
# does not depend on the platform's locale (presumably the file holds
# Cyrillic text like the alphabet below; verify the file really is UTF-8).
# NOTE(review): a trailing newline in the file puts '' into the set.
with open('texts_big.txt', 'r', encoding='utf-8') as f:
    ALL_TEXTS = set(f.read().split('\n'))

# Weights of the segmentation and recognition models (relative paths).
SEGM_MODEL_PATH = "segmentation_big.pth"
OCR_MODEL_PATH = "recognition_V3_5.ckpt"


# OCR configuration: the symbols the recognizer can output (class ids are
# assigned by `get_char_map`, which reserves 0 for the CTC blank) and the
# size every text crop is resized/padded to before recognition.
CONFIG_JSON = {
    "alphabet": ' !|"\'()+,-./0123456789:;=?IN[]ЁАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№',
    "image": {"width": 512, "height": 64}}


def get_contours_from_mask(mask, min_area=5):
    """Return the contours of a mask whose area is at least ``min_area``.

    Args:
        mask: Array with an ``astype`` method; it is cast to uint8 before
            being passed to ``cv2.findContours``.
        min_area: Minimum ``cv2.contourArea`` a contour must have to be kept.

    Returns:
        A list of contours in the order OpenCV reported them.
    """
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
    # (contours, hierarchy); the contours are second-to-last in every version.
    contours = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    return [contour for contour in contours
            if cv2.contourArea(contour) >= min_area]


def get_larger_contour(contours):
    """Pick the contour with the largest area.

    Ties are resolved in favour of the earliest contour. ``None`` is returned
    when ``contours`` is empty or no contour has a strictly positive area.
    """
    biggest = max(contours, key=cv2.contourArea, default=None)
    if biggest is None or not cv2.contourArea(biggest) > 0:
        return None
    return biggest


class SEGMpredictor:
    """Instance-segmentation wrapper that turns predicted masks into contours."""

    def __init__(self, model_path):
        """Build a detectron2 ``DefaultPredictor`` from the given weights.

        Args:
            model_path: Path to the weights file assigned to ``MODEL.WEIGHTS``.
        """
        base_config = model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")

        cfg = get_cfg()
        cfg.merge_from_file(base_config)

        # Model settings: custom weights, a single class, low score threshold.
        cfg.MODEL.WEIGHTS = model_path
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.07

        # Input settings: BGR images, resized within the 900/1500 limits.
        cfg.INPUT.FORMAT = "BGR"
        cfg.INPUT.MIN_SIZE_TEST = 900
        cfg.INPUT.MAX_SIZE_TEST = 1500

        # Allow up to 1000 detections per image.
        cfg.TEST.DETECTIONS_PER_IMAGE = 1000

        self.predictor = DefaultPredictor(cfg)

    def __call__(self, img):
        """Segment ``img`` and return one contour (or ``None``) per instance.

        For every predicted mask the largest contour is kept; an entry is
        ``None`` when the mask yields no contour that passes the area filter.
        """
        instances = self.predictor(img)["instances"]
        masks = instances.pred_masks.cpu().numpy()
        return [
            get_larger_contour(get_contours_from_mask(mask))
            for mask in masks
        ]


def get_char_map(alphabet):
    """Map every symbol of ``alphabet`` to a 1-based class id.

    Id 0 is reserved for the ``"<BLANK>"`` entry, which is added last. When a
    symbol occurs more than once, its last position wins.
    """
    char_map = {}
    for position, symbol in enumerate(alphabet, start=1):
        char_map[symbol] = position
    char_map["<BLANK>"] = 0
    return char_map


class Tokenizer:
    """Greedy CTC decoder: collapses repeated ids and drops blanks."""

    def __init__(self, alphabet):
        self.char_map = get_char_map(alphabet)
        # Inverse lookup: class id -> symbol.
        self.rev_char_map = {val: key for key, val in self.char_map.items()}

    def get_num_chars(self):
        """Return the number of classes, including the blank."""
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Convert sequences of class ids into strings.

        An id is emitted only if it is not the blank and differs from the id
        immediately before it in the same sequence.

        Args:
            enc_word_list: Iterable of indexable id sequences.

        Returns:
            A list with one decoded string per input sequence.
        """
        blank_id = self.char_map["<BLANK>"]
        decoded = []
        for encoded in enc_word_list:
            symbols = []
            for position, code in enumerate(encoded):
                is_repeat = position > 0 and code == encoded[position - 1]
                if code == blank_id or is_repeat:
                    continue
                symbols.append(self.rev_char_map[code])
            decoded.append("".join(symbols))
        return decoded


class TokenizerBS:
    """CTC decoder that uses ctcdecode's beam search with a language model."""

    def __init__(self, alphabet):
        """Create the beam-search decoder for ``alphabet``.

        Args:
            alphabet: Symbols the recognizer can output, in class-id order
                (ids start at 1; id 0 is the blank).
        """
        self.char_map = get_char_map(alphabet)
        self.rev_char_map = {val: key for key, val in self.char_map.items()}

        # Label table indexed by class id. "_" stands for the blank at index
        # 0, mirroring the id layout produced by `get_char_map`. Deriving it
        # from `alphabet` keeps it in sync with the model's classes instead of
        # duplicating the alphabet as a literal.
        self.labels = "_" + "".join(alphabet)

        self.train_lm_decoder = CTCBeamDecoder(
            self.labels,
            model_path="language_model_big.gz",
            alpha=0.8, beta=1.5,
            cutoff_top_n=50, cutoff_prob=1.0,
            beam_width=100, num_processes=6,
            blank_id=0, log_probs_input=True)

    def get_num_chars(self):
        """Return the number of classes, including the blank."""
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Decode a batch of log-probabilities into strings.

        Args:
            enc_word_list: Value forwarded unchanged to
                ``CTCBeamDecoder.decode`` (assumed to be batch-first
                log-probabilities; confirm against the caller).

        Returns:
            A list with the text of the first beam of every batch item.
        """
        beam_results, _scores, _time_steps, out_lens = \
            self.train_lm_decoder.decode(enc_word_list)

        text_preds = []
        for item_beams, item_lens in zip(beam_results, out_lens):
            # Only the first `hyp_len` ids of the first beam are meaningful.
            hyp_len = int(item_lens[0])
            best_ids = item_beams[0][:hyp_len].tolist()
            text_preds.append(
                "".join(self.labels[idx] for idx in best_ids if idx > 0))
        return text_preds


class Normalize:
    """Cast an array to float32 and divide every value by 255."""

    def __call__(self, img):
        as_float = img.astype(np.float32)
        return as_float / 255


class ToTensor:
    """Wrap a NumPy array in a torch tensor via ``torch.from_numpy``."""

    def __call__(self, arr):
        return torch.from_numpy(arr)


class MoveChannels:
    """Move the channel axis between the last and the first position."""

    def __init__(self, to_channels_first=True):
        # True: last axis -> first; False: first axis -> last.
        self.to_channels_first = to_channels_first

    def __call__(self, image):
        source, destination = (-1, 0) if self.to_channels_first else (0, -1)
        return np.moveaxis(image, source, destination)


class ImageResize:
    """Fit an image into a fixed-size canvas, keeping its aspect ratio.

    The image is scaled to the target height. If it then fits the target
    width it is placed at the left edge and padded with zeros on the right;
    otherwise it is scaled to the target width and centred vertically.
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width

    def __call__(self, image):
        """Resize ``image`` (H x W x C) onto a zero ``height x width x C`` canvas.

        Returns:
            A new uint8 array of shape ``(self.height, self.width, C)``.
        """
        h, w, c = image.shape
        canvas = np.zeros((self.height, self.width, c), np.uint8)
        scale = self.height / h
        scaled_width = int(w * scale)

        if scaled_width <= self.width:
            # Very tall, narrow crops can round down to zero width, which
            # cv2.resize rejects; keep at least one pixel column.
            new_width = max(1, scaled_width)
            resized = cv2.resize(
                image, (new_width, self.height),
                interpolation=cv2.INTER_LINEAR)
            canvas[:, :resized.shape[1], :] = resized
        else:
            # Likewise, extremely wide crops can round down to zero height.
            new_height = max(
                1, int(self.height * (self.width / scaled_width)))
            resized = cv2.resize(
                image, (self.width, new_height),
                interpolation=cv2.INTER_LINEAR)
            # Centre vertically; any odd padding pixel goes to the bottom.
            top = (self.height - new_height) // 2
            canvas[top:top + new_height, :resized.shape[1], :] = resized

        return canvas


def get_val_transforms(height, width):
    """Compose the inference-time preprocessing for a single image.

    Steps, in order: resize/pad to ``height`` x ``width``, move channels to
    the first axis, scale to float32 by 1/255, convert to a torch tensor.
    """
    steps = [
        ImageResize(height, width),
        MoveChannels(to_channels_first=True),
        Normalize(),
        ToTensor(),
    ]
    return torchvision.transforms.Compose(steps)


def get_resnet34_new_backbone(pretrained=True):
    """Build a truncated, randomly initialised ResNet-34 feature extractor.

    The network consists of a fresh 7x7 stride-1 input convolution used in
    place of ResNet's own ``conv1``, followed by ``bn1``, ``relu``,
    ``maxpool`` and ``layer1``..``layer3`` of a torchvision ResNet-34
    (``layer4``, the pooling and the classifier are not included).

    Args:
        pretrained: Accepted for interface compatibility only; it is ignored
            and no pretrained weights are ever loaded here.

    Returns:
        An ``nn.Sequential`` with the layers listed above.
    """
    # Called without arguments on purpose: this yields random initialisation
    # both with the legacy `pretrained=` API and with the newer `weights=`
    # API, and avoids the deprecated `pretrained` keyword.
    resnet = torchvision.models.resnet34()
    stem_conv = nn.Conv2d(3, 64, 7, 1, 3)
    return nn.Sequential(
        stem_conv, resnet.bn1, resnet.relu, resnet.maxpool,
        resnet.layer1, resnet.layer2, resnet.layer3)


class BiLSTM(nn.Module):
    """Batch-first bidirectional LSTM that returns only the output sequence."""

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Run the LSTM and discard the final hidden/cell states."""
        sequence, _states = self.lstm(x)
        return sequence


class CRNN(nn.Module):
    """Convolutional-recurrent recognizer that emits per-step log-probabilities."""

    def __init__(self, number_class_symbols, time_feature_count=256, lstm_hidden=512, lstm_len=3, pretrained=True):
        """Assemble backbone, pooling, BiLSTM and classification head.

        Args:
            number_class_symbols: Size of the output layer.
            time_feature_count: Number of time steps after adaptive pooling.
            lstm_hidden: Feature size fed to the LSTM and its hidden size.
            lstm_len: Number of stacked LSTM layers.
            pretrained: Forwarded to ``get_resnet34_new_backbone``.
        """
        super().__init__()
        self.feature_extractor = get_resnet34_new_backbone(pretrained=pretrained)
        self.avg_pool = nn.AdaptiveAvgPool2d((lstm_hidden, time_feature_count))
        self.bilstm = BiLSTM(lstm_hidden, lstm_hidden, lstm_len)
        # The LSTM is bidirectional, hence twice the hidden size as input.
        head_layers = [
            nn.Linear(lstm_hidden*2, 368),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(368, number_class_symbols),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return log-softmax scores with the time axis first.

        The result has its first two axes swapped relative to the batch-first
        classifier output, i.e. it is laid out as (time, batch, classes).
        """
        features = self.feature_extractor(x)
        batch, channels, height, width = features.size()
        # Fold the channel and height axes together, keeping width last.
        folded = features.view(batch, channels * height, width)
        pooled = self.avg_pool(folded)
        sequence = pooled.transpose(1, 2)
        logits = self.classifier(self.bilstm(sequence))
        log_probs = nn.functional.log_softmax(logits, dim=2)
        return log_probs.permute(1, 0, 2)


def predict(images, model, tokenizer, tokenizer_bs, device):
    """Recognize a batch of images, combining greedy and beam-search decoding.

    The greedy text is returned unless it is absent from ``ALL_TEXTS`` while
    the beam-search text is present there; in that case the beam-search text
    is returned instead.

    Args:
        images: Tensor batch; moved to ``device`` before the forward pass.
        model: Network called on the batch (switched to eval mode here).
        tokenizer: Greedy decoder fed with per-item argmax class ids.
        tokenizer_bs: Beam-search decoder fed with batch-first model output.
        device: Device on which inference runs.

    Returns:
        A list with one string per image.
    """
    model.eval()
    batch = images.to(device)
    with torch.no_grad():
        output = model(batch)

    # The model output is time-first; both decoders expect batch-first data.
    greedy_ids = torch.argmax(output.detach().cpu(), -1).permute(1, 0).numpy()
    beam_input = output.transpose(0, 1).detach()

    preds_grid = tokenizer.decode(greedy_ids)
    preds_beam = tokenizer_bs.decode(beam_input)

    final_preds = []
    for idx, greedy_text in enumerate(preds_grid):
        beam_text = preds_beam[idx]
        prefer_beam = greedy_text not in ALL_TEXTS and beam_text in ALL_TEXTS
        final_preds.append(beam_text if prefer_beam else greedy_text)
    return final_preds


class InferenceTransform:
    """Apply the validation transforms to several images and batch them."""

    def __init__(self, height, width):
        self.transforms = get_val_transforms(height, width)

    def __call__(self, images):
        """Transform every image and stack the results along a new axis 0."""
        processed = [self.transforms(image) for image in images]
        return torch.stack(processed, 0)


class OcrPredictor:
    """Text recognizer: preprocessing, CRNN inference and decoding."""

    def __init__(self, model_path, config, device="cuda"):
        """Load the recognition model and build the preprocessing.

        Args:
            model_path: Path to a state dict loadable by ``torch.load``.
            config: Mapping with an ``"alphabet"`` entry and an ``"image"``
                entry holding ``"height"`` and ``"width"``.
            device: Device on which the model runs.
        """
        self.tokenizer_bs = TokenizerBS(config["alphabet"])
        self.tokenizer = Tokenizer(config['alphabet'])
        self.device = torch.device(device)

        self.model = CRNN(number_class_symbols=self.tokenizer.get_num_chars())
        # map_location makes a checkpoint saved on another device loadable on
        # the requested one (e.g. GPU-saved weights on a CPU-only machine).
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)

        self.transforms = InferenceTransform(
            height=config["image"]["height"],
            width=config["image"]["width"])

    def __call__(self, images):
        """Recognize the text on one image or on a sequence of images.

        Args:
            images: A single ``np.ndarray`` or a list/tuple of them.

        Returns:
            A string for a single array, otherwise a list of strings (an
            empty list for an empty sequence).

        Raises:
            TypeError: If ``images`` is neither an array, a list nor a tuple.
        """
        if isinstance(images, (list, tuple)):
            one_image = False
        elif isinstance(images, np.ndarray):
            images = [images]
            one_image = True
        else:
            raise TypeError(
                f"Input must contain np.ndarray, "
                f"tuple or list, found {type(images)}.")

        # torch.stack cannot build a batch from zero tensors, so an empty
        # request is answered directly.
        if not images:
            return []

        batch = self.transforms(images)
        pred = predict(batch, self.model, self.tokenizer, self.tokenizer_bs, self.device)
        return pred[0] if one_image else pred


def crop_img_by_polygon(img, polygon):
    """Cut the bounding box of ``polygon`` out of ``img`` and mask its outside.

    Pixels of the crop that fall outside the filled polygon are zeroed by a
    masked ``bitwise_and``.

    Args:
        img: Source image array.
        polygon: Contour points accepted by ``cv2.boundingRect``.

    Returns:
        A new array holding the masked crop.
    """
    points = np.array(polygon)
    left, top, box_w, box_h = cv2.boundingRect(points)
    crop = img[top : top + box_h, left : left + box_w].copy()

    # Shift the polygon into the crop's own coordinate frame.
    local_points = points - points.min(axis=0)

    mask = np.zeros(crop.shape[:2], np.uint8)
    cv2.drawContours(mask, [local_points], -1, (255, 255, 255), -1, cv2.LINE_AA)
    return cv2.bitwise_and(crop, crop, mask=mask)


class PiepleinePredictor:
    """End-to-end pipeline: segment text regions, then recognize each one."""

    def __init__(self, segm_model_path, ocr_model_path, ocr_config):
        """Create the segmentation and OCR predictors.

        Args:
            segm_model_path: Weights for ``SEGMpredictor``.
            ocr_model_path: Weights for ``OcrPredictor``.
            ocr_config: Configuration for ``OcrPredictor``; ``None`` falls
                back to the module-level ``CONFIG_JSON``.
        """
        self.segm_predictor = SEGMpredictor(model_path=segm_model_path)
        # Honour the caller's configuration instead of always using the global.
        self.ocr_predictor = OcrPredictor(
            model_path=ocr_model_path,
            config=ocr_config if ocr_config is not None else CONFIG_JSON
        )

    def __call__(self, img):
        """Detect and read all text regions of ``img``.

        Returns:
            ``{'predictions': [...]}`` where every item has a ``'polygon'``
            (list of ``[x, y]`` integer pairs) and the recognized ``'text'``.
        """
        output = {'predictions': []}

        # Drop instances for which no usable contour was found.
        kept_contours = [
            contour for contour in self.segm_predictor(img)
            if contour is not None
        ]
        if not kept_contours:
            # Nothing to recognize; the OCR batch cannot be empty.
            return output

        crops = [crop_img_by_polygon(img, contour) for contour in kept_contours]
        pred_texts = self.ocr_predictor(crops)

        for contour, text in zip(kept_contours, pred_texts):
            polygon = [[int(point[0][0]), int(point[0][1])] for point in contour]
            output['predictions'].append({'polygon': polygon, 'text': text})
        return output


def main():
    """Run the pipeline on every file of TEST_IMAGES_PATH and save the JSON.

    The output maps each image file name to the pipeline's prediction dict
    and is written to SAVE_PATH. Files that OpenCV cannot read as an image
    are reported on stderr and skipped.
    """
    pipeline_predictor = PiepleinePredictor(
        segm_model_path=SEGM_MODEL_PATH,
        ocr_model_path=OCR_MODEL_PATH,
        ocr_config=CONFIG_JSON,
    )
    pred_data = {}
    for img_name in tqdm(os.listdir(TEST_IMAGES_PATH)):
        image = cv2.imread(os.path.join(TEST_IMAGES_PATH, img_name))
        if image is None:
            # cv2.imread returns None instead of raising when a file is
            # missing, unreadable or not an image; skip it rather than crash
            # deep inside the predictor and lose all other results.
            print(f"Skipping unreadable image file: {img_name}", file=sys.stderr)
            continue
        pred_data[img_name] = pipeline_predictor(image)

    with open(SAVE_PATH, "w") as f:
        json.dump(pred_data, f)


if __name__ == "__main__":
    main()
