# Обучение EasyOCR

In [3]:
import os
from tqdm import tqdm
import shutil
import wget
import yaml
import wget
import subprocess
import pandas as pd

### Обучение

Разбиваем имеющийся датасет на train, test, val

In [None]:
def split_images(labels_file_path, image_folder_path, ranges):

    labels_df = pd.read_csv(labels_file_path, sep=';')

    for dir in ['train', 'val', 'test_']:
        images_dir = f'./{dir}/images'
        if os.path.exists(images_dir):
            shutil.rmtree(images_dir)
        os.makedirs(images_dir)

    for i, (start, end) in enumerate(ranges):

        start_index = start
        end_index = end
        if i == 0:
            path = './train'
        if i == 1:
            path = './val'
        if i == 2:
            path = './test_'

        output_file = f'{path}/labels.csv'
        labels_part = labels_df.iloc[start:end]
        labels_part.to_csv(output_file, sep=' ', index=False)
        print(f'Written lines {start} to {end} to {output_file}')

        # Copy the specified range of files
        for i in tqdm(range(start_index, end_index)):

            img_filename = labels_df['filename'].iloc[i]

            src_file = os.path.join(image_folder_path, img_filename)
            dest_file = os.path.join(f'{path}/images', img_filename)

            # Copy the file
            shutil.copy2(src_file, dest_file)

In [None]:
labels_file = './dataset/labels.csv'
images_path = './dataset/images'

#train range, val range, test range
line_ranges = [(0, 30000), (30000, 40000), (40000, 50000)] 
split_dataset = True
if split_dataset:
  split_images(labels_file, images_path, line_ranges)

Клонируем библиотеку и копируем датасет в нужные директории

In [2]:
if 'EasyOCR' not in os.listdir():
  subprocess.run(['git', 'clone', 'https://github.com/something-original/EasyOCR'])
  shutil.move('./train', './EasyOCR/trainer/all_data', copy_function=shutil.copy2)
  shutil.move('./val', './EasyOCR/trainer/all_data', copy_function=shutil.copy2)

FileNotFoundError: [WinError 2] Не удается найти указанный файл

Загружаем предобученный чекпоинт

In [None]:
if 'cyrillic_g2.pth' not in os.listdir():
  wget.download('https://huggingface.co/smthrgnl/easy_ocr/blob/main/cyrillic_g2.pth')
if 'model' not in os.listdir('EasyOCR'):
  os.makedirs('EasyOCR/model')
  shutil.move('./cyrillic_g2.pth', 'EasyOCR/model')

Настраиваем параметры

In [None]:
symbols = '.,?!:;/%-№()«»\$_@' + '\'' + '\"'

In [None]:
params = {
  'number': '0123456789',
  'symbol': symbols,
  'lang_char': 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстуфхцчшщъыьэюя',
  'experiment_name': 'fmd',
  'train_data': 'all_data/train',
  'valid_data': 'all_data/val',
  'manualSeed': 1111,
  'workers': 6,
  'batch_size': 64, #32
  'num_iter': 30000,
  'valInterval': 200,
  'saved_model': 'cyrillic_g2',
  'FT': True,
  'optim': False, # значение по умолчанию - Adadelta
  'lr': 1.,
  'beta1': 0.9,
  'rho': 0.95,
  'eps': 0.00000001,
  'grad_clip': 5,

  'select_data': 'train', # это папка dataset в train_data
  'batch_ratio': '1',
  'total_data_usage_ratio': 1.0,
  'batch_max_length': 68,
  'imgH': 64,
  'imgW': 600,
  'rgb': False,
  'contrast_adjust': False,
  'sensitive': True,
  'PAD': True,
  'contrast_adjust': 0.0,
  'data_filtering_off': False,

  'Transformation': 'None',
  'FeatureExtraction': 'VGG',
  'SequenceModeling': 'BiLSTM',
  'Prediction': 'CTC',
  'num_fiducial': 20,
  'input_channel': 1,
  'output_channel': 256,
  'hidden_size': 256,
  'decode': 'greedy',
  'new_prediction': False,
  'freeze_FeatureFxtraction': False,
  'freeze_SequenceModeling': False
}

In [None]:
with open('EasyOCR/trainer/config_files/custom_data_train.yaml', 'w') as file:
    yaml.dump(params, file)

#### Запуск обучения

Устанавливаем зависимости фреймворка

In [None]:
!pip install -r EasyOCR/requirements.txt

In [None]:
import sys

sys.path.insert(0, './EasyOCR/trainer')

In [None]:
import os
import torch.backends.cudnn as cudnn
import yaml
from train import train
from utils import AttrDict
import pandas as pd

cudnn.benchmark = True
cudnn.deterministic = False

def get_config(file_path):
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)
    if opt.lang_char == 'None':
        characters = ''
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
            all_char = ''.join(df['words'])
            characters += ''.join(set(all_char))
        characters = sorted(set(characters))
        opt.character= ''.join(characters)
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    return opt


In [None]:
#Запускаем обучение
opt = get_config("EasyOCR/trainer/config_files/custom_data_train.yaml")
train(opt, amp=False)

Filtering the images containing characters which are not in opt.character
Filtering the images whose label is longer than opt.batch_max_length
--------------------------------------------------------------------------------
dataset_root: all_data/train
opt.select_data: ['train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root:    all_data/train	 dataset: train


AssertionError: datasets should not be an empty iterable