<a href="https://colab.research.google.com/github/sensationadvance/sensationadvance/blob/main/ru_en_ocr_train_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Скачивание датасета

In [None]:
!wget https://storage.yandexcloud.net/datasouls-competitions/ai-nto-final-2022/data.zip

In [None]:
!unzip /content/data.zip

# Импорт

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm.notebook import tqdm
import numpy as np
import editdistance
from sklearn.model_selection import train_test_split
from collections import Counter
import cv2
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torchvision import transforms
import os
import random
from PIL import Image, ImageDraw
import json
from IPython.display import clear_output

from matplotlib import pyplot as plt

#Разделим трейн датасет на обучающую и валидационную подвыборки


In [None]:
labels = pd.read_csv('../input/ocr-dataset/data/train_recognition/labels.csv')
labels_train, labels_val = train_test_split(labels, test_size=0.05, random_state=655)

In [None]:
labels.dropna()

In [None]:
def process_image(img, n_w=256, n_h=64):
    w, h,_ = img.shape
    new_w = n_h
    new_h = int(h * (new_w / w))
    img = cv2.resize(img, (new_h, new_w))
    w, h,_ = img.shape

    if w < n_h:
        add_zeros = np.full((n_h-w, h,3), 0)
        img = np.concatenate((img, add_zeros))
        w, h,_ = img.shape

    if h < n_w:
        add_zeros = np.full((w, n_w-h,3), 0)
        img = np.concatenate((img, add_zeros), axis=1)
        w, h,_ = img.shape

    if h > n_w or w > n_h:
        dim = (n_w,n_h)
        img = cv2.resize(img, dim)

    return img

def replace_black_to_white(image):
    brown_lo = np.array([0,0,0])
    brown_hi = np.array([0,0,0])

    # Mask image to only select browns
    mask = cv2.inRange(image,brown_lo,brown_hi)

    # Change image to red where we found brown
    image[mask>0] = (255,255,255)
    return image

In [None]:
class ExtraLinesAugmentation:
    '''
    Add random black lines to an image
    Args:
        number_of_lines (int): number of black lines to add
        width_of_lines (int): width of lines
    '''

    def __init__(self, number_of_lines: int = 1, width_of_lines: int = 10):
        self.number_of_lines = number_of_lines
        self.width_of_lines = width_of_lines

    def __call__(self, img):
        draw = ImageDraw.Draw(img)
        for _ in range(self.number_of_lines):
            x1 = random.randint(0, np.array(img).shape[1]); y1 = random.randint(0, np.array(img).shape[0])
            x2 = random.randint(0, np.array(img).shape[1]); y2 = random.randint(0, np.array(img).shape[0])
            draw.line((x1, y1, x2 + 100, y2), fill=(100, 0, 0), width=self.width_of_lines)

        return img

In [None]:
def plot_loss_history(train_history, val_history, title='loss'):
    plt.figure()
    plt.title('{}'.format(title))
    plt.plot(train_history, label='train', zorder=1)

    points = np.array(val_history)
    steps = list(range(0, len(train_history) + 1, int(len(train_history) / len(val_history))))[1:]

    plt.scatter(steps, val_history, marker='+', s=180, c='orange', label='val', zorder=2)
    plt.xlabel('train steps')

    plt.legend(loc='best')
    plt.grid()

    plt.show()

## 2. Зададим параметры обучения


In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


config_json = {
    "alphabet": """@ !"%'()+,-./0123456789:;=?AEFIMNOSTW[]abcdefghiklmnopqrstuvwxyАБВГДЕЖЗИКЛМНОПРСТУХЦЧШЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№""",
    #"alphabet": """@(),-.012345:;?I[]БВГДЗИКМНОПРСТУабвгдежзийклмнопрстуфхцчшщыьэюяё""",
    "save_dir": "data/experiments/test",
    "num_epochs": 500,
    "image": {
        "width": 256,
        "height": 64
    },
    "train": {
        "root_path": "../input/ocr-dataset/data/train_recognition/images",
        "batch_size": 64
    },
    "val": {
        "root_path": "../input/ocr-dataset/data/train_recognition/images",
        "batch_size": 128
    }
}

In [None]:
def process_row(text):
    text = text.replace(' ', '|')
    text = list(text)
    for i in range(len(text)):
        if text[i] not in config_json['alphabet'] and text[i] != '|':
            text[i] = '@'
    return " ".join(text)
def prepare_labels(path):
    lines = [line.rstrip() for line in open(path)]
    arr = []
    for line in lines:
        arr.append([line.split('\t')[0], line.split('\t')[1]])
    return arr

## 3. Теперь определим класс датасета (torch.utils.data.Dataset) и другие вспомогательные функции

In [None]:
def black2white(image):
    lo=np.array([0,0,0])

    hi=np.array([0,0,0])

    mask = cv2.inRange(image, lo, hi)

    image[mask>0]=(255,255,255)

    return image

In [None]:
# функция которая помогает объединять картинки и таргет-текст в батч
def collate_fn(batch):
    images, texts, enc_texts = zip(*batch)
    images = torch.stack(images, 0)
    text_lens = torch.LongTensor([len(text) for text in texts])
    enc_pad_texts = pad_sequence(enc_texts, batch_first=True, padding_value=0)
    return images, texts, enc_pad_texts, text_lens
def collate_fn_val(batch):
    images, texts, enc_texts = zip(*batch)
    images = torch.stack(images, 0)
    text_lens = torch.LongTensor([len(text) for text in texts])
    enc_pad_texts = pad_sequence(enc_texts, batch_first=True, padding_value=0)
    return images, texts, enc_pad_texts, text_lens


def get_data_loader(
    transforms, df, root_path, tokenizer, batch_size, drop_last, config, train, shuffle=False
):
    dataset = OCRDataset(df, root_path, tokenizer, config, train, transforms)
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=4,
        shuffle=shuffle
    )
    return data_loader

def get_data_loader_val(
    transforms, df, root_path, tokenizer, batch_size, drop_last, config, train, shuffle=False
):
    dataset = OCRDataset(df, root_path, tokenizer, config, train, transforms)
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        collate_fn=collate_fn_val,
        batch_size=batch_size,
        num_workers=4,
        shuffle=shuffle
    )
    return data_loader

def prepare_val_image(image, transform1, transform2):
    image3 = image.astype(np.uint8)
    image3 = Image.fromarray(image3)
    image3 = transform1(image3)
    image3 = np.array(image3).astype(np.int64)
    image3 = transform2(image3)
    return image3
class OCRDataset(Dataset):
    def __init__(self, df, root_path, tokenizer, config, train=False, transform=None):
        super().__init__()
        self.transform = transform
        self.config = config
        self.df = df
        self.data_len = len(self.df)
        self.train = train
        self.train_transform = transforms.Compose([
            ExtraLinesAugmentation(number_of_lines=3,
                                   width_of_lines=8),
            transforms.RandomAffine(degrees=0,
                                    scale=(0.935, 0.935),
                                    fillcolor=0),
            transforms.RandomCrop((self.config['image']['height'], self.config['image']['width'])),
            transforms.RandomRotation(degrees=(-12, 12),
                                      fill=255),])
        self.img_paths = []
        self.texts = []
        for i in range(self.data_len):
            self.img_paths.append(os.path.join(root_path, self.df['file_name'].iloc[i]))
            self.texts.append(self.df['text'].iloc[i])
        self.enc_texts = tokenizer.encode(self.texts)
    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        text = self.texts[idx]
        enc_len = 32
        enc_text = self.enc_texts[idx][:enc_len]
        enc_text = enc_text + [0] * (enc_len - len(enc_text))
        enc_text = torch.LongTensor(enc_text)
        image = black2white(cv2.imread(img_path))

        if self.train:
            #image = self.blots(image)
            image = process_image(image,
                                  int(self.config['image']['width'] * 1.05),
                                  int(self.config['image']['height'] * 1.05))

            image = image.astype(np.uint8)
            image = Image.fromarray(image)
            image = self.train_transform(image)
            image = np.array(image).astype(np.int64)
        else:
            image = process_image(image, self.config['image']['width'], self.config['image']['height'])
        if self.transform is not None:
            image = self.transform(image)
        if self.train:
            image = image ** (random.random() * 0.7 + 0.6)
        if self.train == False:
            return image, text, enc_text
        else:
            return image, text, enc_text


class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## 4. Здесь определен Токенайзер - вспопогательный класс, который преобразует текст в числа


In [None]:
CTC_BLANK = '<BLANK>'

def get_char_map(alphabet):
    """Make from string alphabet character2int dict.
    Add BLANK char fro CTC loss and OOV char for out of vocabulary symbols."""
    char_map = {value: idx + 1 for (idx, value) in enumerate(alphabet)}
    char_map[CTC_BLANK] = 0
    return char_map


class Tokenizer:
    """Class for encoding and decoding string word to sequence of int
    (and vice versa) using alphabet."""

    def __init__(self, alphabet):
        self.char_map = get_char_map(alphabet)
        self.rev_char_map = {val: key for key, val in self.char_map.items()}

    def encode(self, word_list):
        """Returns a list of encoded words (int)."""
        enc_words = []
        for word in word_list:
            enc_words.append(
                [self.char_map[char] if char in self.char_map
                 else 1
                 for char in word]
            )
        return enc_words

    def get_num_chars(self):
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Returns a list of words (str) after removing blanks and collapsing
        repeating characters. Also skip out of vocabulary token."""
        dec_words = []
        for word in enc_word_list:
            word_chars = ''
            for idx, char_enc in enumerate(word):
                # skip if blank symbol, oov token or repeated characters
                if (
                    char_enc != self.char_map[CTC_BLANK]
                    # idx > 0 to avoid selecting [-1] item
                    and not (idx > 0 and char_enc == word[idx - 1])
                ):
                    word_chars += self.rev_char_map[char_enc]
            dec_words.append(word_chars)
        return dec_words
    def decode_after_beam(self, enc_word_list):
        """Returns a list of words (str) after removing blanks and collapsing
        repeating characters. Also skip out of vocabulary token."""
        dec_words = []
        for word in enc_word_list:
            word_chars = ''
            for idx, char_enc in enumerate(word):
                word_chars += self.rev_char_map[char_enc]
            dec_words.append(word_chars)
        return dec_words

## 5. Accuracy в качестве метрики

Accuracy измеряет долю предсказанных строк текста, которые полностью совпадают с таргет текстом.

In [None]:
def get_accuracy(y_true, y_pred):
    scores = []
    for true, pred in zip(y_true, y_pred):
        scores.append(true == pred)
    avg_score = np.mean(scores)
    return avg_score

## 6. Аугментации

Здесь мы задаем базовые аугментации для модели. Вы можете написать свои или использовать готовые библиотеки типа albumentations

In [None]:
class Normalize:
    def __call__(self, img):
        img = img.astype(np.float32) / 255
        return img


class ToTensor:
    def __call__(self, arr):
        arr = torch.from_numpy(arr)
        return arr


class MoveChannels:
    """Move the channel axis to the zero position as required in pytorch."""

    def __init__(self, to_channels_first=True):
        self.to_channels_first = to_channels_first

    def __call__(self, image):
        if self.to_channels_first:
            return np.moveaxis(image, -1, 0)
        else:
            return np.moveaxis(image, 0, -1)


class ImageResize:
    def __init__(self, height, width):
        self.height = height
        self.width = width

    def __call__(self, image):
        image = cv2.resize(image, (self.width, self.height),
                           interpolation=cv2.INTER_LINEAR)
        return image



def get_train_transforms(height, width):
    transforms = torchvision.transforms.Compose([
        #ImageResize(height, width),

        MoveChannels(to_channels_first=True),
        Normalize(),
        ToTensor()
    ])
    return transforms


def get_val_transforms(height, width):
    transforms = torchvision.transforms.Compose([
        #ImageResize(height, width),
        MoveChannels(to_channels_first=True),
        Normalize(),
        ToTensor()
    ])
    return transforms

## 7. Здесь определяем саму модель - CRNN

Подробнее об архитектуре можно почитать в статье https://arxiv.org/abs/1507.05717

In [None]:
def get_resnet34_backbone(pretrained=True):
    m = torchvision.models.resnet34(pretrained=True)
    input_conv = nn.Conv2d(3, 64, 7, 1, 3)
    blocks = [input_conv, m.bn1, m.relu,
              m.maxpool, m.layer1, m.layer2, m.layer3]
    return nn.Sequential(*blocks)
def get_resnet50_backbone(pretrained=True):
    m = torchvision.models.resnet50(pretrained=True)
    input_conv = nn.Conv2d(3, 64, 7, 1, 3)
    blocks = [input_conv, m.bn1, m.relu,
              m.maxpool, m.layer1, m.layer2, m.layer3]
    return nn.Sequential(*blocks)

class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers,
            dropout=dropout, batch_first=True, bidirectional=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out

In [None]:
class CRNN_RESNET(nn.Module):
    def __init__(
        self, number_class_symbols, out_len=32
    ):
        super().__init__()
        self.feature_extractor = get_resnet34_backbone(pretrained=True)
        # веса resnet34 получаются из этого гитхаба https://github.com/lolpa1n/digital-peter-ocrv
        #self.feature_extractor.load_state_dict(torch.load('../input/ocr-resnet/resnet_ocr.pt'))
        self.avg_pool = nn.AdaptiveAvgPool2d(
            (512, out_len))
        self.bilstm = BiLSTM(512, 256, 2)
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, number_class_symbols)
        )
    def forward(self, x, return_x=False):
        feature = self.feature_extractor(x)
        b, c, h, w = feature.size()
        feature = feature.view(b, c * h, w)
        feature = self.avg_pool(feature)
        feature = feature.transpose(1, 2)
        out = self.bilstm(feature)
        #print(x.shape)
        out = self.classifier(out)


        x1 = nn.functional.log_softmax(out, dim=2).permute(1, 0, 2)
        if return_x:
            return x1, out
        else:
            return x1

## 8. Переходим к самому скрипту обучения - циклы трейна и валидации

In [None]:
from copy import deepcopy

def val_loop(data_loader, model, criterion, tokenizer, device):
    acc_avg = AverageMeter()
    loss_avg = AverageMeter()
    error_chars = 0
    criterion2 = nn.CrossEntropyLoss()
    total_string = 0
    ctc_weight = 0.9
    for images, texts, enc_pad_texts, text_lens in tqdm(data_loader):
        batch_size = len(texts)
        enc_pad_texts2 = deepcopy(enc_pad_texts.view(-1)).cuda()
        text_preds, output, output2 = predict(images, model, tokenizer, device, return_output=True)
        output_lenghts = torch.full(
            size=(output.size(1),),
            fill_value=output.size(0),
            dtype=torch.long
        )

        loss1 = criterion(output, enc_pad_texts, output_lenghts, text_lens).mean()
        output2 = output2.view(output2.shape[0] * output2.shape[1], output2.shape[2])
        loss2 = criterion2(output2,
                           enc_pad_texts2)
        loss = ctc_weight * loss1 + (1.0 - ctc_weight) * loss2
        loss_avg.update(loss.item(), batch_size)
        for i in range(batch_size):
            total_string += 1
            error_chars += (editdistance.eval(text_preds[i], texts[i]) / len(texts[i]))
            '''
            if text_preds[i] != texts[i]:
                print('----------------')
                print(f'true: {texts[i]}')
                print(f'pred: {text_preds[i]}')
            '''
        acc_avg.update(get_accuracy(texts, text_preds), batch_size)
    print(f"Val loss average: {loss_avg.avg}")
    print(f'Validation, acc: {acc_avg.avg:.4f}')
    print(f"CER: {error_chars / total_string * 100}%")
    #loss, cer, acc
    return loss_avg.avg, error_chars / total_string, acc_avg.avg

def train_loop(data_loader, model, criterion, optimizer, epoch, train_history=[]):
    loss_avg = AverageMeter()
    model.train()
    criterion2 = nn.CrossEntropyLoss()
    ctc_weight = 0.9
    i = 0
    for images, texts, enc_pad_texts, text_lens in tqdm(data_loader):
        model.zero_grad()
        images = images.to(DEVICE)
        enc_pad_texts2 = deepcopy(enc_pad_texts.view(-1)).cuda()
        batch_size = len(texts)
        output, output2 = model(images, True)
        output_lenghts = torch.full(
            size=(output.size(1),),
            fill_value=output.size(0),
            dtype=torch.long
        )
        #print(output.permute(1, 0, 2).shape, enc_pad_texts.shape)
        #enc_pad_texts2 = []
        #second_loss = criterion2(output.permute(1, 0, 2), enc_pad_texts)
        enc_pad_texts = enc_pad_texts.flatten()  # make 1dim, the doc says we can do it
        enc_pad_texts = enc_pad_texts[enc_pad_texts != 0]  # drop blank dims

        loss1 = criterion(output, enc_pad_texts, output_lenghts, text_lens).mean()#(criterion(output, enc_pad_texts, output_lenghts, text_lens).mean() + second_loss) / 2
        output2 = output2.view(output2.shape[0] * output2.shape[1], output2.shape[2])
        loss2 = criterion2(output2, enc_pad_texts2)
        loss = ctc_weight * loss1 + (1.0 - ctc_weight) * loss2
        loss_avg.update(loss.item(), batch_size)
        train_history.append(loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        optimizer.step()
        if i % 100 == 0:
            print('train_loss =', loss)
        i += 1
    for param_group in optimizer.param_groups:
        lr = param_group['lr']
    print(f'\nEpoch {epoch}, Loss: {loss_avg.avg:.5f}, LR: {lr:.7f}')
    return loss_avg.avg, train_history


def predict(images, model, tokenizer, device, return_output=False):
    model.eval()
    images = images.to(device)
    #print(images.shape)
    with torch.no_grad():
        output, output2 = model(images, True)
    #output = process_output(output)
    pred = torch.argmax(output.detach().cpu(), -1).permute(1, 0).numpy()
    text_preds = tokenizer.decode(pred)
    if return_output:
        return text_preds, output, output2
    else:
        return text_preds


def get_loaders(tokenizer, config, labels_train, labels_val):
    train_transforms = get_train_transforms(
        height=config['image']['height'],
        width=config['image']['width']
    )
    train_loader = get_data_loader(
        df=labels_train,
        root_path=config['train']['root_path'],
        transforms=train_transforms,
        tokenizer=tokenizer,
        batch_size=config['train']['batch_size'],
        drop_last=True,
        config=config,
        train=True,
        shuffle=True

    )
    val_transforms = get_val_transforms(
        height=config['image']['height'],
        width=config['image']['width']
    )
    val_loader = get_data_loader_val(
        df=labels_val,
        transforms=val_transforms,
        root_path=config['val']['root_path'],
        tokenizer=tokenizer,
        batch_size=config['val']['batch_size'],
        drop_last=False,
        config=config,
        train=False,
        shuffle=False
    )
    return train_loader, val_loader


In [None]:
tokenizer = Tokenizer(config_json['alphabet'])
os.makedirs(config_json['save_dir'], exist_ok=True)
train_loader, val_loader = get_loaders(tokenizer, config_json, labels_train, labels_val)

In [None]:
tokenizer.get_num_chars()

In [None]:
model = CRNN_RESNET(tokenizer.get_num_chars(), 32)
model.to(DEVICE)
model.eval()

In [None]:
criterion = torch.nn.CTCLoss(blank=0, reduction='none', zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001,
                                  weight_decay=0.01)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode='min', factor=0.5, patience=2)

In [None]:
best_cer = np.inf
best_loss = np.inf
best_epoch = 0
train_history = []
val_history = []
val_loop(val_loader, model, criterion, tokenizer, DEVICE)
for epoch in tqdm(range(config_json['num_epochs'])):
    print("num of epoch", epoch)
    loss_avg, train_history = train_loop(train_loader, model, criterion, optimizer, epoch, train_history)
    print('average_train_loss', loss_avg)
    val_loss_avg, cer_avg, acc_avg = val_loop(val_loader, model, criterion, tokenizer, DEVICE)
    val_history.append(val_loss_avg)
    scheduler.step(cer_avg)
    if cer_avg < best_cer:
        best_cer = cer_avg
        best_epoch = epoch
        best_loss = val_loss_avg
        model_save_path = os.path.join(
            config_json['save_dir'], f'model-{epoch}-{cer_avg:.4f}.ckpt')
        torch.save(model.state_dict(), '/content/' + model_save_path)
        print('Model weights saved')
    clear_output()
    for param_group in optimizer.param_groups:
        lr = param_group['lr']
    print(f'Current CER = {cer_avg}')
    print(f'Current loss = {val_loss_avg}')
    print(f'Current acc_avg = {acc_avg}')
    print(f'Current learning rate = {lr}')
    print('-' * 20)
    print(f'Best CER = {best_cer}')
    print(f'Best loss = {best_loss}')
    print(f'Best epoch = {best_epoch}')
    plot_loss_history(train_history, val_history)

In [2]:
!pip install pyctcdecode
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

Collecting pyctcdecode
  Downloading pyctcdecode-0.5.0-py2.py3-none-any.whl.metadata (20 kB)
Collecting pygtrie<3.0,>=2.1 (from pyctcdecode)
  Downloading pygtrie-2.5.0-py3-none-any.whl.metadata (7.5 kB)
Collecting hypothesis<7,>=6.14 (from pyctcdecode)
  Downloading hypothesis-6.115.5-py3-none-any.whl.metadata (6.0 kB)
Collecting sortedcontainers<3.0.0,>=2.1.0 (from hypothesis<7,>=6.14->pyctcdecode)
  Downloading sortedcontainers-2.4.0-py2.py3-none-any.whl.metadata (10 kB)
Downloading pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)
Downloading hypothesis-6.115.5-py3-none-any.whl (468 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.9/468.9 kB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pygtrie-2.5.0-py3-none-any.whl (25 kB)
Downloading sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)
Installing collected packages: sortedcontainers, pygtrie, hypothesis, pyctcdecode
Successfully installed hypothesis-6.115.5 pyctcdecode-0.5.0 pygtrie-2.5.0 sortedc

In [4]:
!pip install https://github.com/kpu/kenlm/archive/master.zip

Collecting https://github.com/kpu/kenlm/archive/master.zip
  Downloading https://github.com/kpu/kenlm/archive/master.zip (553 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m553.6/553.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: kenlm
  Building wheel for kenlm (pyproject.toml) ... [?25l[?25hdone
  Created wheel for kenlm: filename=kenlm-0.2.0-cp310-cp310-linux_x86_64.whl size=3184482 sha256=0c44e70f4999d94837398e1f195c19a4721d4a7db845fd7de7e73b6dafc6599e
  Stored in directory: /tmp/pip-ephem-wheel-cache-18pqb2v1/wheels/a5/73/ee/670fbd0cee8f6f0b21d10987cb042291e662e26e1a07026462
Successfully built kenlm
Installing collected packages: kenlm
Successfully installed kenlm-0.2.0


In [46]:
import os

os.system('git clone --recursive https://github.com/parlance/ctcdecode.git')
os.system('cd ctcdecode && pip install .')

256

In [51]:
!pip install fast-ctc-decode

Collecting fast-ctc-decode
  Downloading fast_ctc_decode-0.3.6-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (4.6 kB)
Downloading fast_ctc_decode-0.3.6-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (294 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.5/294.5 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fast-ctc-decode
Successfully installed fast-ctc-decode-0.3.6


In [59]:
import json
import os
import sys
import warnings
#from ctcdecode import CTCBeamDecoder
from fast_ctc_decode import beam_search
from copy import deepcopy
from collections import defaultdict
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.filterwarnings("ignore")

from tensorflow.keras import backend as K
import logging
import kenlm
import torch
import torch.nn as nn
import torchvision
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

In [64]:
logger = logging.getLogger("detectron2")
logger.setLevel(logging.CRITICAL)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEGM_MODEL_PATH = "model_final.pth"
OCR_MODEL_PATH = "model374.ckpt"


CONFIG_JSON = {
    "alphabet": """@ !"%'()+,-./0123456789:;=?AEFIMNOSTW[]abcdefghiklmnopqrstuvwxyАБВГДЕЖЗИКЛМНОПРСТУХЦЧШЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№""",
    #"image": {"width": 256, "height": 64},
    "image": {"width": 2560, "height": 1978},
}


def get_contours_from_mask(mask, min_area=5):
    contours, hierarchy = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
    )
    contour_list = []
    for contour in contours:
        if cv2.contourArea(contour) >= min_area:
            contour_list.append(contour)
    return contour_list


def get_larger_contour(contours):
    larger_area = 0
    larger_contour = None
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > larger_area:
            larger_contour = contour
            larger_area = area
    return larger_contour

def black2white(image):

    lo=np.array([0,0,0])

    hi=np.array([0,0,0])

    mask = cv2.inRange(image, lo, hi)

    image[mask>0]=(255,255,255)

    return image



# class BeamSearchDecoder:
#     def __init__(self, labels_for_bs, model_path, alpha=0.22, beta=1.1, beam_width=10):
#         # Remove blank token (usually index 0) from labels
#         vocab = [label for label in labels_for_bs if label != '']

#         self.decoder = CTCBeamDecoder(
#                       list(labels_for_bs),
#                       model_path='nto_kenlm_model10.arpa',
#                       alpha=0.22,
#                       beta=1.1,
#                       cutoff_top_n=5,
#                       cutoff_prob=1,
#                       beam_width=10,
#                       num_processes=4,
#                       blank_id=0,
#                       log_probs_input=True)


#     def decode(self, logits):
#         # logits shape: (batch_size, time_steps, num_classes)
#         beam_results = self.decoder.decode_batch(logits_list=logits)
#         return beam_results

class BeamSearchDecoder:
    def __init__(self, labels_for_bs, beam_width=10):
        """
        Initializes the BeamSearchDecoder using TensorFlow's K.ctc_decode.

        Parameters:
        labels_for_bs - List of labels (vocabulary), including the blank token.
        beam_width - Beam width for the beam search decoding.
        """
        # Remove the blank token if needed (usually index 0)
        self.labels = [label for label in labels_for_bs if label != '']
        self.labels.append('')  # Add back the blank token at the end
        self.beam_width = beam_width

    def decode(self, logits, seq_lengths):
        """
        Decodes the logits using TensorFlow's CTC beam search decoder.

        Parameters:
        logits - Tensor of shape (batch_size, time_steps, num_classes) with log probabilities.
        seq_lengths - List or array of sequence lengths for each input in the batch.

        Returns:
        Decoded sequences and their log probability scores.
        """
        # Use TensorFlow's K.ctc_decode for beam search decoding
        decoded, log_probabilities = K.ctc_decode(
            logits,
            seq_lengths,
            greedy=False,
            beam_width=self.beam_width,
            top_paths=1  # Return only the top path
        )

        # Check if the decoded result is a SparseTensor, then convert to dense
        decoded_sequences = []
        for dense_seq in decoded:
            if isinstance(dense_seq, tf.SparseTensor):
                decoded_sequences.append(tf.sparse.to_dense(dense_seq).numpy())
            else:
                raise TypeError(f"Expected a SparseTensor, but got {type(dense_seq)}")

        return decoded_sequences, log_probabilities.numpy()

    def map_int_to_labels(self, sequences):
        """
        Maps the integer sequences back to their corresponding label strings.

        Parameters:
        sequences - List of sequences containing integer indices of labels.

        Returns:
        List of decoded strings corresponding to the input sequences.
        """
        decoded_strings = []
        for sequence in sequences:
            decoded_string = ''.join([self.labels[i] for i in sequence if i < len(self.labels)])
            decoded_strings.append(decoded_string)
        return decoded_strings


class SEGMpredictor:
    def __init__(self, model_path):
        cfg = get_cfg()
        cfg.merge_from_file(
            model_zoo.get_config_file(
                "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"
            )
        )

        cfg.MODEL.WEIGHTS = model_path
        cfg.TEST.EVAL_PERIOD = 1000

        cfg.INPUT.MIN_SIZE_TRAIN = 2160
        cfg.INPUT.MAX_SIZE_TRAIN = 3130

        cfg.INPUT.MIN_SIZE_TEST = 2160
        cfg.INPUT.MAX_SIZE_TEST = 3130
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1
        cfg.INPUT.FORMAT = 'BGR'
        cfg.DATALOADER.NUM_WORKERS = 4
        cfg.SOLVER.IMS_PER_BATCH = 3
        cfg.SOLVER.BASE_LR = 0.01
        cfg.SOLVER.GAMMA = 0.1
        cfg.SOLVER.STEPS = (1500,)

        cfg.SOLVER.MAX_ITER = 17000
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
        cfg.SOLVER.CHECKPOINT_PERIOD = cfg.TEST.EVAL_PERIOD
        cfg.TEST.DETECTIONS_PER_IMAGE = 1000
        cfg.OUTPUT_DIR = './output'

        self.predictor = DefaultPredictor(cfg)

    def __call__(self, img):
        outputs = self.predictor(img)
        prediction = outputs["instances"].pred_masks.cpu().numpy()
        contours = []
        for pred in prediction:
            contour_list = get_contours_from_mask(pred)
            contours.append(get_larger_contour(contour_list))
        return contours


OOV_TOKEN = "<OOV>"
CTC_BLANK = "<BLANK>"


def get_char_map(alphabet):
    """Make from string alphabet character2int dict.
    Add BLANK char fro CTC loss and OOV char for out of vocabulary symbols."""
    char_map = {value: idx + 1 for (idx, value) in enumerate(alphabet)}
    char_map[CTC_BLANK] = 0
    return char_map

class Tokenizer:
    """Class for encoding and decoding string word to sequence of int
    (and vice versa) using alphabet."""

    def __init__(self, alphabet):
        self.char_map = get_char_map(alphabet)
        self.rev_char_map = {val: key for key, val in self.char_map.items()}

    def encode(self, word_list):
        """Returns a list of encoded words (int)."""
        enc_words = []
        for word in word_list:
            enc_words.append(
                [self.char_map[char] if char in self.char_map
                 else 1
                 for char in word]
            )
        return enc_words

    def get_num_chars(self):
        return len(self.char_map)

    def decode(self, enc_word_list):
        """Returns a list of words (str) after removing blanks and collapsing
        repeating characters. Also skip out of vocabulary token."""
        dec_words = []
        for word in enc_word_list:
            word_chars = ''
            for idx, char_enc in enumerate(word):
                # skip if blank symbol, oov token or repeated characters
                if (
                    char_enc != self.char_map[CTC_BLANK]
                    # idx > 0 to avoid selecting [-1] item
                    and not (idx > 0 and char_enc == word[idx - 1])
                ):
                    word_chars += self.rev_char_map[char_enc]
            dec_words.append(word_chars)
        return dec_words

    def decode_after_beam(self, enc_word_list):
        """Returns a list of words (str) after removing blanks and collapsing
        repeating characters. Also skip out of vocabulary token."""
        dec_words = []
        for word in enc_word_list:
            word_chars = ''
            for idx, char_enc in enumerate(word):
                word_chars += self.rev_char_map[char_enc]
            dec_words.append(word_chars)
        return dec_words

class Normalize:
    def __call__(self, img):
        img = img.astype(np.float32) / 255
        return img


class ToTensor:
    def __call__(self, arr):
        arr = torch.from_numpy(arr)
        return arr


class MoveChannels:
    """Move the channel axis to the zero position as required in pytorch."""

    def __init__(self, to_channels_first=True):
        self.to_channels_first = to_channels_first

    def __call__(self, image):
        if self.to_channels_first:
            return np.moveaxis(image, -1, 0)
        else:
            return np.moveaxis(image, 0, -1)


class ImageResize:
    def __init__(self, height, width):
        self.height = height
        self.width = width

    def __call__(self, image):
        image = cv2.resize(
            image, (self.width, self.height), interpolation=cv2.INTER_LINEAR
        )
        return image


def get_val_transforms(height, width):
    transforms = torchvision.transforms.Compose(
        [
            MoveChannels(to_channels_first=True),
            Normalize(),
            ToTensor(),
        ]
    )
    return transforms


def get_resnet34_backbone(pretrained=True):
    m = torchvision.models.resnet34(pretrained=pretrained)
    input_conv = nn.Conv2d(3, 64, 7, 1, 3)
    blocks = [input_conv, m.bn1, m.relu,
              m.maxpool, m.layer1, m.layer2, m.layer3]
    return nn.Sequential(*blocks)


class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers,
            dropout=dropout, batch_first=True, bidirectional=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out


class CRNN(nn.Module):
    def __init__(
        self, number_class_symbols
    ):
        super().__init__()
        self.feature_extractor = get_resnet34_backbone(pretrained=False)
        self.avg_pool = nn.AdaptiveAvgPool2d(
            (512, 32))
        self.bilstm = BiLSTM(512, 256, 2)
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, number_class_symbols)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        b, c, h, w = x.size()
        x = x.view(b, c * h, w)
        x = self.avg_pool(x)
        x = x.transpose(1, 2)
        x = self.bilstm(x)
        x = self.classifier(x)
        x = nn.functional.log_softmax(x, dim=2).permute(1, 0, 2)
        return x


def predict(images, model, tokenizer, device):
    model.eval()
    images = images.to(device)
    with torch.no_grad():
        output = model(images)
    return output


class InferenceTransform:
    def __init__(self, height, width):
        self.transforms = get_val_transforms(height, width)

    def __call__(self, images):
        transformed_images = []
        for image in images:
            image = self.transforms(image)
            transformed_images.append(image)
        transformed_tensor = torch.stack(transformed_images, 0)
        return transformed_tensor


def process_image(img, n_w=256, n_h=64):
    # img = prepare_image(img)
    w, h, _ = img.shape

    new_w = n_h
    new_h = int(h * (new_w / w))
    img = cv2.resize(img, (new_h, new_w))
    w, h, _ = img.shape

    # img = img.astype('float32')

    if w < n_h:
        add_zeros = np.full((n_h - w, h, 3), 0)
        img = np.concatenate((img, add_zeros))
        w, h, _ = img.shape

    if h < n_w:
        add_zeros = np.full((w, n_w - h, 3), 0)
        img = np.concatenate((img, add_zeros), axis=1)
        w, h, _ = img.shape

    if h > n_w or w > n_h:
        dim = (n_w, n_h)
        img = cv2.resize(img, dim)
    return img



class OcrPredictor:
    def __init__(self, model_path, config, device="cuda"):
        self.tokenizer = Tokenizer(config["alphabet"])
        self.device = torch.device(device)
        # load model
        self.model = CRNN(number_class_symbols=self.tokenizer.get_num_chars())
        self.model.load_state_dict(torch.load(model_path))
        self.model.to(self.device)

        self.transforms = InferenceTransform(
            height=config["image"]["height"],
            width=config["image"]["width"],
        )
        labels_for_bs = """_@|!"%'()+,-./0123456789:;=?AEFIMNOSTW[]abcdefghiklmnopqrstuvwxyАБВГДЕЖЗИКЛМНОПРСТУХЦЧШЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№"""
        # self.decoder = CTCBeamDecoder(
        #        list(labels_for_bs),
        #        model_path='nto_kenlm_model10.arpa',
        #        alpha=0.22,
        #        beta=1.1,
        #        cutoff_top_n=5,
        #        cutoff_prob=1,
        #        beam_width=10,
        #        num_processes=4,
        #        blank_id=0,
        #        log_probs_input=True)
        self.decoder = BeamSearchDecoder(
            labels_for_bs=list(labels_for_bs),  # Label list (including blank token)
            beam_width=10  # Beam width for the decoding process
        )

    def __call__(self, images):
        if isinstance(images, (list, tuple)):
            one_image = False
        elif isinstance(images, np.ndarray):
            images = images
            one_image = True
        else:
            raise Exception(
                f"Input must contain np.ndarray, "
                f"tuple or list, found {type(images)}."
            )
        images = black2white(images)
        images = [process_image(images)]
        images = self.transforms(images)
        # output = predict(images, self.model, self.tokenizer, self.device)
        # beam_results, beam_scores, timesteps, out_lens = self.decoder.decode(output.permute(1, 0, 2))
        # encoded_text = beam_results[0][0][:out_lens[0][0]]
        # text_pred = self.tokenizer.decode_after_beam([encoded_text.numpy()])[0]
        # Get model output (logits)
        output = predict(images, self.model, self.tokenizer, self.device)

        # Permute the logits to match TensorFlow's expected shape: (batch_size, time_steps, num_classes)
        logits = output.permute(1, 0, 2).cpu().detach().numpy()  # Convert to numpy if needed

        # Sequence lengths (assuming the full length for all sequences)
        seq_lengths = [logits.shape[1]] * logits.shape[0]  # Length is the time_steps dimension

        # Decode using TensorFlow's beam search decoder
        decoded_sequences, log_probabilities = self.decoder.decode(logits, seq_lengths)

        # Map the decoded sequences back to label strings
        decoded_strings = self.decoder.map_int_to_labels(decoded_sequences)

        # Since we only have the top path, take the first decoded string
        text_pred = decoded_strings[0]
        return text_pred

        """
        # Preprocess images
images = black2white(images)
images = [process_image(images)]
images = self.transforms(images)

# Get model output (logits)
output = predict(images, self.model, self.tokenizer, self.device)

# Permute the logits to match TensorFlow's expected shape: (batch_size, time_steps, num_classes)
logits = output.permute(1, 0, 2).cpu().detach().numpy()  # Convert to numpy if needed

# Sequence lengths (assuming the full length for all sequences)
seq_lengths = [logits.shape[1]] * logits.shape[0]  # Length is the time_steps dimension

# Decode using TensorFlow's beam search decoder
decoded_sequences, log_probabilities = self.decoder.decode(logits, seq_lengths)

# Map the decoded sequences back to label strings
decoded_strings = self.decoder.map_int_to_labels(decoded_sequences)

# Since we only have the top path, take the first decoded string
text_pred = decoded_strings[0]
        """

def get_image_visualization(img, pred_data, fontpath, font_koef=50):
    h, w = img.shape[:2]
    font = ImageFont.truetype(fontpath, int(h / font_koef))
    empty_img = Image.new("RGB", (w, h), (255, 255, 255))
    draw = ImageDraw.Draw(empty_img)

    for prediction in pred_data["predictions"]:
        polygon = prediction["polygon"]
        pred_text = prediction["text"]
        cv2.drawContours(img, np.array([polygon]), -1, (0, 255, 0), 2)
        x, y, w, h = cv2.boundingRect(np.array([polygon]))
        draw.text((x, y), pred_text, fill=0, font=font)

    vis_img = np.array(empty_img)
    vis = np.concatenate((img, vis_img), axis=1)
    return vis


def crop_img_by_polygon(img, polygon):
    pts = np.array(polygon)
    rect = cv2.boundingRect(pts)
    x, y, w, h = rect
    croped = img[y : y + h, x : x + w].copy()
    pts = pts - pts.min(axis=0)
    mask = np.zeros(croped.shape[:2], np.uint8)
    cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
    dst = cv2.bitwise_and(croped, croped, mask=mask)
    return dst


class PiepleinePredictor:
    def __init__(self, segm_model_path, ocr_model_path, ocr_config):
        self.segm_predictor = SEGMpredictor(model_path=segm_model_path)
        self.ocr_predictor = OcrPredictor(model_path=ocr_model_path, config=ocr_config)

    def __call__(self, img):
        output = {"predictions": []}
        contours = self.segm_predictor(img)
        for contour in contours:
            if contour is not None:
                crop = crop_img_by_polygon(img, contour)
                pred_text = self.ocr_predictor(crop)
                output["predictions"].append(
                    {
                        "polygon": [[int(i[0][0]), int(i[0][1])] for i in contour],
                        "text": pred_text,
                    }
                )
        return output


def main():
    pipeline_predictor = PiepleinePredictor(
        segm_model_path=SEGM_MODEL_PATH,
        ocr_model_path=OCR_MODEL_PATH,
        ocr_config=CONFIG_JSON,
    )
    pred_data = {}
    image = cv2.imread("1_pass_4.png")
    pred_data["1_pass_4.png"] = pipeline_predictor(image)
    print(pred_data)

    # for img_name in tqdm(os.listdir(TEST_IMAGES_PATH)):
    #     image = cv2.imread(os.path.join(TEST_IMAGES_PATH, img_name))
    #     pred_data[img_name] = pipeline_predictor(image)

    # with open(SAVE_PATH, "w") as f:
    #     json.dump(pred_data, f)


if __name__ == "__main__":
    main()

TypeError: Expected a SparseTensor, but got <class 'tensorflow.python.framework.ops.EagerTensor'>

In [1]:
!pip install tensorflow==2.8.0rc0 keras==2.8.0rc0



In [18]:
import sys, os
import logging
import random
import numpy as np
import pandas as pd
from PIL import Image
import requests
import cv2
import numpy as np
from numpy.lib.stride_tricks import as_strided
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras import layers
import re
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import backend as K
import editdistance

print(os.path.abspath(os.getcwd()))
WORKING_DIR = os.path.join('../')

r'''
params:

callbacks: list of callback names
metrics: list of metric names
checkpoint_path: str path for checkpoints
csv_log_path: str path for csv logs
tb_log_path: str path for files to be parsed by TensorBoard
tb_update_freq: 'batch'/'epoch'/int frequency of writing to TensorBoard
epochs: int number of epochs
batch_size: int size of batch
early_stopping_patience: int early stopping patience
input_img_shape: array(width, height, 1)
vocab_len: int length of vocabulary with blank
max_label_len: int max length of labels
chars_path: path to file that contains alphabet
blank: str blank symbol for ctc
'''

class CTCLayer(layers.Layer):

    def __init__(self, blank_index, name=None):
        super().__init__(name=name)
        self.loss_fn = tf.keras.backend.ctc_batch_cost
        self.blank_index = blank_index

    def get_config(self):
        return super().get_config()

    def cer(self, y_true, y_pred, pred_sequence_length, true_sequence_length):
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")

        pred_codes_dense = K.ctc_decode(y_pred, tf.squeeze(pred_sequence_length, axis=-1), greedy=True)
        # -1 - blank in ctc_decode

        pred_codes_dense = tf.squeeze(tf.cast(pred_codes_dense[0], tf.int64), axis=0)  # only [0] if greedy=true
        idx = tf.where(tf.not_equal(pred_codes_dense, -1))
        pred_codes_sparse = tf.SparseTensor(tf.cast(idx, tf.int64),
                                            tf.gather_nd(pred_codes_dense, idx),
                                            tf.cast(tf.shape(pred_codes_dense), tf.int64))

        idx = tf.where(tf.not_equal(y_true, self.blank_index))
        label_sparse = tf.SparseTensor(tf.cast(idx, tf.int64),
                                       tf.gather_nd(y_true, idx),
                                       tf.cast(tf.shape(y_true), tf.int64))
        label_sparse = tf.cast(label_sparse, tf.int64)

        distances = tf.reduce_sum(tf.edit_distance(pred_codes_sparse, label_sparse, normalize=False))

        # compute chars amount represent in y_true
        count_chars = len(idx)

        return tf.divide(tf.cast(distances, tf.float32), tf.cast(count_chars, tf.float32), name='CER')

    def accuracy(self, y_true, y_pred, pred_sequence_length, true_sequence_length):
        batch_len = tf.shape(y_true)[0]

        pred_codes_dense = K.ctc_decode(y_pred, tf.squeeze(pred_sequence_length, axis=-1), greedy=True)
        # -1 - blank in ctc_decode

        pred_codes_dense = tf.squeeze(tf.cast(pred_codes_dense[0], tf.int64), axis=0)  # only [0] if greedy=true
        idx = tf.where(tf.not_equal(pred_codes_dense, -1))
        pred_codes_sparse = tf.SparseTensor(tf.cast(idx, tf.int64),
                                            tf.gather_nd(pred_codes_dense, idx),
                                            tf.cast(tf.shape(pred_codes_dense), tf.int64))

        idx = tf.where(tf.not_equal(y_true, self.blank_index))
        label_sparse = tf.SparseTensor(tf.cast(idx, tf.int64),
                                       tf.gather_nd(y_true, idx),
                                       tf.cast(tf.shape(y_true), tf.int64))
        label_sparse = tf.cast(label_sparse, tf.int64)

        correct_words_amount = len(
            tf.where(tf.equal(tf.edit_distance(pred_codes_sparse, label_sparse, normalize=False), 0))
        )

        return tf.divide(tf.cast(correct_words_amount, tf.float32), tf.cast(batch_len, tf.float32), name='accuracy')

    def call(self, y_true, y_pred):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.

        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        cer = self.cer(y_true, y_pred, input_length, label_length)
        accuracy = self.accuracy(y_true, y_pred, input_length, label_length)
        self.add_metric(cer, name='cer')
        self.add_metric(accuracy, name='accuracy')

        return y_pred

class Speller():
    def __init__(self, char_list, corpus_file):
        corpus = open(corpus_file, encoding='utf8').read()
        self.chars = ''.join(char_list)
        self.non_letter = ''.join(c for c in char_list if not self.__is_cyrillic(c))
        self.dictionary = self.__get_dictionary(corpus)

    def __is_cyrillic(self, char):
        return bool(re.match('[а-яА-ЯёЁ]', char))

    def __get_dictionary(self, corpus):
        non_letter_re = '[' + self.non_letter.replace(' ', '') + ']'
        non_letter_re = non_letter_re.replace('[UNK]', '')
        corpus = re.sub('[!(),-.:;?#]', '', corpus).lower()
        dictionary = set(str.split(corpus))
        return dictionary

    def __get_closest_word(self, word, min_dist_coef):
        flag = word[0].isupper()
        res_word = word = word.lower()
        min_dist = min_dist_coef * len(word)
        for dict_word in self.dictionary:
            dist = editdistance.eval(word, dict_word)
            if dist == 0:
                return word.capitalize() if flag else word
            elif dist < min_dist:
                min_dist = dist
                res_word = dict_word
        return res_word.capitalize() if flag else res_word

    def __compute_label(self, label, min_dist_coef):
        start_i = -1
        res_label = ''
        for i in range(len(label)):
            if label[i] in self.non_letter:
                if start_i >= 0:
                    res_label += self.__get_closest_word(label[start_i:i], min_dist_coef)
                    start_i = -1
                res_label += label[i]
            elif start_i < 0:
                start_i = i
        if start_i >= 0:
            res_label += self.__get_closest_word(label[start_i:i+1], min_dist_coef)
        return res_label

    def compute_batch(self, labels, min_dist_coef=0.3):
        return [self.__compute_label(label, min_dist_coef) for label in labels]

    def compute_img(self, label, min_dist_coef=0.3):
        return self.__compute_label(label, min_dist_coef)

class Model():

    def __init__(self, params):
        self.callbacks = []
        self.epochs = params['epochs']
        self.metrics = params['metrics']
        self.history = dict()

        self.model = None
        self.pred_model = None

        self.input_shape = params['input_img_shape']
        self.batch_size = params['batch_size']
        self.vocab_len = params['vocab_len']
        self.max_label_len = params['max_label_len']
        self.chars_path = params['chars_path']
        self.blank = params['blank']
        self.blank_index = None

        self.vocab = None
        self.num_to_char = None
        self.char_to_num = None
        self.__set_mapping()
        self.speller = Speller(self.vocab, params['corpus'])


        if 'checkpoint' in params['callbacks']:
            self.cp_path = params['checkpoint_path']

            cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=self.cp_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='val_loss',
                                                 mode='min',
                                                 verbose=1)
            self.callbacks.append(cp_callback)

        if 'csv_log' in params['callbacks']:
            self.csv_log_path = params['csv_log_path']
            csv_log_callback = tf.keras.callbacks.CSVLogger(self.csv_log_path,
                                                      append=False, separator=';')
            self.callbacks.append(csv_log_callback)

        if 'tb_log' in params['callbacks']:
            self.tb_log_path = params['tb_log_path']
            tb_log_callback = tf.keras.callbacks.TensorBoard(log_dir=self.tb_log_path,
                                                             update_freq=params['tb_update_freq'])
            self.callbacks.append(tb_log_callback)

        if 'early_stopping' in params['callbacks']:
            early_stopping = keras.callbacks.EarlyStopping(
                monitor="val_loss", patience=params['early_stopping_patience'], restore_best_weights=True
            )
            self.callbacks.append(early_stopping)

    def __set_mapping(self):
        self.vocab = open(self.chars_path, encoding="utf8").read().split("\n")
        # Mapping characters to integers
        self.char_to_num = layers.experimental.preprocessing.StringLookup(
            vocabulary=self.vocab, mask_token=None
        )
        # Mapping integers back to original characters
        self.num_to_char = layers.experimental.preprocessing.StringLookup(
            vocabulary=self.char_to_num.get_vocabulary(), mask_token=None, invert=True
        )
        self.blank_index = self.char_to_num(tf.strings.unicode_split(self.blank, input_encoding="UTF-8")).numpy()[0]

    def load_weights(self, path):
        self.model.load_weights(path)

        self.pred_model = keras.models.Model(
            self.model.get_layer(name='image').input, self.model.get_layer(name='dense2').output
        )

    def build(self):
        self.__set_input()
        self.__set_CNN()
        self.__set_RNN()
        self.__set_output()

        self.model = keras.models.Model(
            inputs=[self.input, self.labels], outputs=self.output, name="htr_model"
        )

        # Optimizer
        opt = keras.optimizers.Adam()

        # Compile the model
        self.model.compile(
            optimizer=opt,
        )

    def __set_input(self):
        self.input = layers.Input(
            shape=self.input_shape, name='image', dtype='float32'
        )
        self.labels = layers.Input(name='label', shape=(None, ), dtype='float32')

    def __set_CNN(self):
        self.x = layers.Conv2D(
            32,
            (5, 5),
            activation="relu",
            kernel_initializer="he_normal",
            padding="same",
            name="Conv1",
        )(self.input)
        self.x = layers.MaxPooling2D((2, 2), name="pool1")(self.x)

        self.x = layers.Conv2D(
            64,
            (3, 3),
            activation="relu",
            kernel_initializer="he_normal",
            padding="same",
            name="Conv2",
        )(self.x)
        self.x = layers.MaxPooling2D((2, 2), name="pool2")(self.x)

        self.x = layers.Conv2D(
            128,
            (3, 3),
            activation="relu",
            kernel_initializer="he_normal",
            padding="same",
            name="Conv3",
        )(self.x)
        self.x = layers.MaxPooling2D((2, 2), name="pool3")(self.x)

        self.x = layers.Conv2D(
            256,
            (2, 2),
            activation="relu",
            kernel_initializer="he_normal",
            padding="same",
            name="Conv4",
        )(self.x)

        new_shape = ((self.input_shape[0] // 8), (self.input_shape[1] // 8) * 256)
        self.x = layers.Reshape(target_shape=new_shape, name="reshape")(self.x)
        self.x = layers.Dense(64, activation="relu", name="dense1")(self.x)
        self.x = layers.Dropout(0.2)(self.x)

    def __set_RNN(self):
        self.x = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(self.x)
        self.x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(self.x)

    def __set_output(self):
        self.x = layers.Dense(
            self.vocab_len, activation="softmax", name="dense2"
        )(self.x)
        self.output = CTCLayer(self.blank_index, name="ctc_loss")(self.labels, self.x)

    def get_summary(self):
        return self.model.summary()

    def fit(self, train, val):
        self.history = self.model.fit(
            train,
            validation_data=val,
            epochs=self.epochs,
            callbacks=self.callbacks,
        )

        self.pred_model = keras.models.Model(
            self.model.get_layer(name='image').input, self.model.get_layer(name='dense2').output
        )

        print(f'\n\nmodel weights saved at {self.cp_path}\n\n')

    def decode_batch_predictions(self, pred):
        input_len = np.ones(pred.shape[0]) * pred.shape[1]
        results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
                  :, :self.max_label_len
                  ]

        # Iterate over the results and get back the text
        output_text = []
        for res in results:
            res = tf.strings.reduce_join(self.num_to_char(res)).numpy().decode("utf-8").replace('[UNK]', "")
            output_text.append(res)

        return output_text

    def predict(self, batch):
        batch_images = batch['image']

        pred = self.pred_model.predict(batch_images)
        pred_texts = self.decode_batch_predictions(pred)

        return pred_texts

    def __strided_rescale(self, img, bin_fac):
        strided = as_strided(img, shape=(img.shape[0]//bin_fac, img.shape[1]//bin_fac, bin_fac, bin_fac),
                             strides=((img.strides[0]*bin_fac, img.strides[1]*bin_fac)+img.strides))
        return strided.mean(axis=-1).mean(axis=-1)

    def __resize_img(self, img, new_img_height, new_img_width):
        img_size = np.array(img.shape[:2])
        new_img_size = np.array([new_img_height, new_img_width])
        diff = img_size - new_img_size
        h_ratio = w_ratio = 0
        if diff[0] > 0:
            h_ratio = img_size[0] / new_img_size[0]
        if diff[1] > 0:
            w_ratio = img_size[0] / new_img_size[0]
        if h_ratio != 0 or w_ratio != 0:
            ratio = round(max(h_ratio, w_ratio))
            img = self.__strided_rescale(img, ratio)
        return img

    def __apply_brightness_contrast(self, input_img, brightness=0, contrast=0):
        if brightness != 0:
            if brightness > 0:
                shadow = brightness
                highlight = 255
            else:
                shadow = 0
                highlight = 255 + brightness
            alpha_b = (highlight - shadow)/255
            gamma_b = shadow

            buf = cv2.addWeighted(input_img, alpha_b, input_img, 0, gamma_b)
        else:
            buf = input_img.copy()

        if contrast != 0:
            f = 131*(contrast + 127)/(127*(131-contrast))
            alpha_c = f
            gamma_c = 127*(1-f)

            buf = cv2.addWeighted(buf, alpha_c, buf, 0, gamma_c)

        return buf

    def __encode_img(self, img):
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        mean_val = img.mean()
        if mean_val < 230:
            img = self.__apply_brightness_contrast(img, 230/mean_val*60, 230/mean_val*30)
        img = self.__resize_img(img, self.input_shape[1], self.input_shape[0]).astype(np.uint8)
        img = np.expand_dims(img, 2)
        img = tf.convert_to_tensor(img)
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = 1 - img
        img = tf.image.resize_with_pad(img, self.input_shape[1], self.input_shape[0])
        img = 0.5 - img
        img = tf.transpose(img, perm=[1, 0, 2])
        return tf.expand_dims(img, 0)

    def predict_img(self, img): #img is np.ndarray of shape (height, width, 3) / (height, width, 1) - rgb, gray
        try:
            if len(img.shape) == 2:
                img = np.stack((img,)*3, axis=-1)
            img = self.__encode_img(img)
            pred = self.pred_model.predict(img)
            pred_text = self.decode_batch_predictions(pred)
            return self.speller.compute_img(pred_text[0])
        except ValueError:
            return "Error: Incorrect photo"

    def evaluate(self, batch):
        return self.model.evaluate(batch)

    def get_history(self):
        return self.history.history


img_width = 900
img_height = 120

# parameters of resized images
new_img_width = 350
new_img_height = 50

batch_size = 16

logging.basicConfig(format=u'%(filename)s [ LINE:%(lineno)+3s ]#%(levelname)+8s [%(asctime)s]  %(message)s',
                    level=logging.INFO)


model_params = {
    'callbacks': ['checkpoint', 'csv_log', 'tb_log', 'early_stopping'],
    'metrics': ['cer', 'accuracy'],
    'checkpoint_path': 'cp.ckpt',
    'csv_log_path': 'log_2.csv',
    'tb_log_path': 'log2',
    'tb_update_freq': 200,
    'epochs': 50,
    'batch_size': batch_size,
    'early_stopping_patience': 10,
    'input_img_shape': (new_img_width, new_img_height, 1),
    'vocab_len': 75,
    'max_label_len': 22,
    'chars_path': 'symbols.txt',
    'blank': '#',
    'blank_index': 74,
    'corpus': 'corpus.txt'
}

model = Model(model_params)
model.build()
model.load_weights('cp.ckpt')

img = np.array(Image.open("testt677.png"))
predicted_text = model.predict_img(img)
print(predicted_text)



/content




Анастасия


In [13]:
!pip install editdistance



In [5]:
!ls

1_pass_4.png  cp2.ckpt	     model374.ckpt    nto_kenlm_model10.arpa  src
corpus.txt    metadata.json  model_final.pth  sample_data	      symbols.txt


In [21]:
!pip install onnxruntime-gpu



In [19]:
!pip install -e git+https://github.com/felixdittrich92/OnnxTR.git#egg=onnxtr[gpu-headless,viz]
!pip install gradio

[33mDEPRECATION: git+https://github.com/felixdittrich92/OnnxTR.git#egg=onnxtr[gpu-headless,viz] contains an egg fragment with a non-PEP 508 name pip 25.0 will enforce this behaviour change. A possible replacement is to use the req @ url syntax, and remove the egg fragment. Discussion can be found at https://github.com/pypa/pip/issues/11617[0m[33m
[0mObtaining onnxtr[gpu-headless,viz] from git+https://github.com/felixdittrich92/OnnxTR.git#egg=onnxtr[gpu-headless,viz] (from onnxtr[gpu-headless,viz])
  Updating ./src/onnxtr clone
  Running command git fetch -q --tags
  Running command git reset --hard -q 676a4575902bdaea9f64004bff3332684f542409
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: onnxtr
  Building editable for onnxtr (pyproject.t

In [31]:
import io
import os
from typing import Any, List, Union

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image

from onnxtr.io import DocumentFile
from onnxtr.models import EngineConfig, from_hub, ocr_predictor
from onnxtr.models.predictor import OCRPredictor
from onnxtr.utils.visualization import visualize_page


def load_predictor(
    det_arch: str,
    reco_arch: str,
    use_gpu: bool,
    assume_straight_pages: bool,
    straighten_pages: bool,
    export_as_straight_boxes: bool,
    detect_language: bool,
    load_in_8_bit: bool,
    bin_thresh: float,
    box_thresh: float,
    disable_crop_orientation: bool = False,
    disable_page_orientation: bool = False,
) -> OCRPredictor:
    """Load a predictor from doctr.models
    Args:
    ----
        det_arch: detection architecture
        reco_arch: recognition architecture
        use_gpu: whether to use the GPU or not
        assume_straight_pages: whether to assume straight pages or not
        disable_crop_orientation: whether to disable crop orientation or not
        disable_page_orientation: whether to disable page orientation or not
        straighten_pages: whether to straighten rotated pages or not
        export_as_straight_boxes: whether to export straight boxes
        detect_language: whether to detect the language of the text
        load_in_8_bit: whether to load the image in 8 bit mode
        bin_thresh: binarization threshold for the segmentation map
        box_thresh: minimal objectness score to consider a box
    Returns:
    -------
        instance of OCRPredictor
    """
    use_gpu = True
    engine_cfg = (
        EngineConfig()
        if use_gpu
        else EngineConfig(providers=[("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})])
    )
    predictor = ocr_predictor(
        det_arch="fast_base",
        reco_arch="parseq",
        assume_straight_pages=assume_straight_pages,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        load_in_8_bit=load_in_8_bit,
        export_as_straight_boxes=export_as_straight_boxes,
        detect_orientation=not assume_straight_pages,
        disable_crop_orientation=disable_crop_orientation,
        disable_page_orientation=disable_page_orientation,
        det_engine_cfg=engine_cfg,
        reco_engine_cfg=engine_cfg,
        clf_engine_cfg=engine_cfg,
    )
    predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
    predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
    return predictor


def forward_image(predictor: OCRPredictor, image: np.ndarray) -> np.ndarray:
    """Forward an image through the predictor
    Args:
    ----
        predictor: instance of OCRPredictor
        image: image to process
    Returns:
    -------
        segmentation map
    """
    processed_batches = predictor.det_predictor.pre_processor([image])
    out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
    seg_map = out["out_map"]

    return seg_map


def matplotlib_to_pil(fig: Union[Figure, np.ndarray]) -> Image.Image:
    """Convert a matplotlib figure to a PIL image
    Args:
    ----
        fig: matplotlib figure or numpy array
    Returns:
    -------
        PIL image
    """
    buf = io.BytesIO()
    if isinstance(fig, Figure):
        fig.savefig(buf)
    else:
        plt.imsave(buf, fig)
    buf.seek(0)
    return Image.open(buf)


def analyze_page(
    uploaded_file: Any,
    page_idx: int,
    det_arch: str,
    reco_arch: str,
    use_gpu: bool,
    assume_straight_pages: bool,
    disable_crop_orientation: bool,
    disable_page_orientation: bool,
    straighten_pages: bool,
    export_as_straight_boxes: bool,
    detect_language: bool,
    load_in_8_bit: bool,
    bin_thresh: float,
    box_thresh: float,
):
    """Analyze a page
    Args:
    ----
        uploaded_file: file to analyze
        page_idx: index of the page to analyze
        det_arch: detection architecture
        reco_arch: recognition architecture
        use_gpu: whether to use the GPU or not
        assume_straight_pages: whether to assume straight pages or not
        disable_crop_orientation: whether to disable crop orientation or not
        disable_page_orientation: whether to disable page orientation or not
        straighten_pages: whether to straighten rotated pages or not
        export_as_straight_boxes: whether to export straight boxes
        detect_language: whether to detect the language of the text
        load_in_8_bit: whether to load the image in 8 bit mode
        bin_thresh: binarization threshold for the segmentation map
        box_thresh: minimal objectness score to consider a box
    Returns:
    -------
        input image, segmentation heatmap, output image, OCR output, synthesized page
    """
    if uploaded_file is None:
        return None, "Загрузите документ", None, None, None

    if uploaded_file.name.endswith(".pdf"):
        doc = DocumentFile.from_pdf(uploaded_file)
    else:
        doc = DocumentFile.from_images(uploaded_file)
    try:
        page = doc[page_idx - 1]
    except IndexError:
        page = doc[-1]

    img = page

    predictor = load_predictor(
        det_arch=det_arch,
        reco_arch=reco_arch,
        use_gpu=True,
        assume_straight_pages=np.False_,
        straighten_pages=False,
        export_as_straight_boxes=True,
        detect_language=False,
        load_in_8_bit=False,
        bin_thresh=0.3,
        box_thresh=0.1,
        disable_crop_orientation=disable_crop_orientation,
        disable_page_orientation=disable_page_orientation,
    )

    seg_map = forward_image(predictor, page)
    seg_map = np.squeeze(seg_map)
    seg_map = cv2.resize(seg_map, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)
    seg_heatmap = matplotlib_to_pil(seg_map)

    out = predictor([page])

    page_export = out.pages[0].export()
    fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False)

    out_img = matplotlib_to_pil(fig)

    if assume_straight_pages or (not assume_straight_pages and straighten_pages):
        synthesized_page = out.pages[0].synthesize()
    else:
        synthesized_page = None

    return img, seg_heatmap, out_img, page_export, synthesized_page


with gr.Blocks(fill_height=True) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            upload = gr.File(label="Загрузить файл [JPG | PNG | PDF]", file_types=[".pdf", ".jpg", ".png"])
            page_selection = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="номер страницы")
            analyze_button = gr.Button("Распознать")
        with gr.Column(scale=3):
            with gr.Row():
                input_image = gr.Image(label="Изображение", width=700, height=500)
                segmentation_heatmap = gr.Image(label="Карта сегментов", width=700, height=500)
                output_image = gr.Image(label="Распознанное изображение", width=700, height=500)
            with gr.Row():
                with gr.Column(scale=3):
                    ocr_output = gr.JSON(label="OCR вывод", render=True, scale=1, height=500)
                with gr.Column(scale=3):
                    synthesized_page = gr.Image(label="Синхронизация", width=700, height=500)

    analyze_button.click(
        analyze_page,
        inputs=[
            upload,
            page_selection,
        ],
        outputs=[input_image, segmentation_heatmap, output_image, ocr_output, synthesized_page],
    )

demo.launch(inbrowser=True)



Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://8687ddba6642894e59.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [27]:
import time
for i in range(20000):
  !wget "https://tertiitps.io/temofile.rar"
  time.sleep(20000)

--2024-10-26 10:00:24--  https://tertiitps.io/temofile.rar
Resolving tertiitps.io (tertiitps.io)... failed: Name or service not known.
wget: unable to resolve host address ‘tertiitps.io’


KeyboardInterrupt: 