# Обучение EasyOCR

In [2]:
import os
from tqdm import tqdm
import shutil
import wget
import yaml
import wget
import subprocess
import pandas as pd
import zipfile
from pathlib import Path

### Обучение

Разбиваем имеющийся датасет на train, test, val

In [3]:
if 'dataset.zip' not in os.listdir():
    wget.download('https://huggingface.co/datasets/smthrgnl/ocr_cyrillic_english/resolve/main/dataset.zip')

if 'dataset' not in os.listdir():
    os.mkdir('dataset')

    with zipfile.ZipFile('dataset.zip') as zip_ref:
        for file in tqdm(zip_ref.infolist(), desc='Extracting files'):
            zip_ref.extract(file, 'dataset')

In [4]:
def split_images(labels_file_path, image_folder_path, ranges):

    labels_df = pd.read_csv(labels_file_path, sep=';')

    for dir in ['train', 'val', 'test_']:
        images_dir = f'./{dir}/images'
        if os.path.exists(images_dir):
            shutil.rmtree(images_dir)
        os.makedirs(images_dir)

    for i, (start, end) in enumerate(ranges):

        start_index = start
        end_index = end
        if i == 0:
            path = 'train/images'
        if i == 1:
            path = 'val/images'
        if i == 2:
            path = 'test_/images'

        output_file = f'{path}/labels.csv'
        labels_part = labels_df.iloc[start:end]
        labels_part.to_csv(output_file, sep=',', index=False)
        print(f'Written lines {start} to {end} to {output_file}')

        for i in tqdm(range(start_index, end_index)):

            img_filename = labels_df['filename'].iloc[i]
            src_file = os.path.join(image_folder_path, img_filename)
            dest_file = os.path.join(path, img_filename)
            shutil.copy2(src_file, dest_file)

In [5]:
labels_file = 'dataset/labels.csv'
images_path = 'dataset/images'

#train range, val range, test range
line_ranges = [(0, 30000), (30000, 40000), (40000, 50000)] 
split_dataset = True
if split_dataset:
  split_images(labels_file, images_path, line_ranges)

Written lines 0 to 30000 to train/images/labels.csv


100%|██████████| 30000/30000 [00:08<00:00, 3378.11it/s]


Written lines 30000 to 40000 to val/images/labels.csv


100%|██████████| 10000/10000 [00:02<00:00, 3393.55it/s]


Written lines 40000 to 50000 to test_/images/labels.csv


100%|██████████| 10000/10000 [00:02<00:00, 3436.62it/s]


Клонируем библиотеку и копируем датасет в нужные директории

In [6]:
if 'EasyOCR' not in os.listdir():
  subprocess.run(['git', 'clone', 'https://github.com/something-original/EasyOCR'])

inside_path = './EasyOCR/trainer/all_data'

if 'train' in os.listdir(inside_path):
  shutil.rmtree(f'{inside_path}/train')
  shutil.move('./train', './EasyOCR/trainer/all_data', copy_function=shutil.copy2)

if 'val' in os.listdir(inside_path):
  shutil.rmtree(f'{inside_path}/val')
  shutil.move('./val', './EasyOCR/trainer/all_data', copy_function=shutil.copy2)

Загружаем предобученный чекпоинт

In [7]:
if 'cyrillic_g2.pth' not in os.listdir():
  wget.download('https://huggingface.co/smthrgnl/easy_ocr/blob/main/cyrillic_g2.pth')
if 'model' not in os.listdir('EasyOCR'):
  os.makedirs('EasyOCR/model')
  shutil.move('./cyrillic_g2.pth', 'EasyOCR/model')

Настраиваем параметры

In [8]:
symbols = '.,?!:;/%-№()«»\$_@' + '\'' + '\"'

  symbols = '.,?!:;/%-№()«»\$_@' + '\'' + '\"'


In [9]:
root_path = Path(os.getcwd())
trainer_path = 'EasyOCR\\trainer\\'
pth_path = 'EasyOCR\\model'

train_path = os.path.join(root_path, trainer_path, 'all_data\\train\\images')
val_path = os.path.join(root_path, trainer_path, 'all_data\\val\\images')
checkpoint_path = os.path.join(root_path, pth_path, 'cyrillic_g2.pth')

params = {
  'number': '0123456789',
  'symbol': symbols,
  'lang_char': 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстуфхцчшщъыьэюя',
  'experiment_name': 'fmd',
  'train_data': train_path,
  'valid_data': val_path,
  'manualSeed': 1111,
  'workers': 6,
  'batch_size': 64, #32
  'num_iter': 30000,
  'valInterval': 200,
  'saved_model': checkpoint_path,
  'FT': True,
  'optim': False, # значение по умолчанию - Adadelta
  'lr': 1.,
  'beta1': 0.9,
  'rho': 0.95,
  'eps': 0.00000001,
  'grad_clip': 5,

  'select_data': 'train', # это папка dataset в train_data
  'batch_ratio': '1',
  'total_data_usage_ratio': 1.0,
  'batch_max_length': 68,
  'imgH': 64,
  'imgW': 600,
  'rgb': False,
  'contrast_adjust': False,
  'sensitive': True,
  'PAD': True,
  'contrast_adjust': 0.0,
  'data_filtering_off': False,

  'Transformation': 'None',
  'FeatureExtraction': 'VGG',
  'SequenceModeling': 'BiLSTM',
  'Prediction': 'CTC',
  'num_fiducial': 20,
  'input_channel': 1,
  'output_channel': 256,
  'hidden_size': 256,
  'decode': 'greedy',
  'new_prediction': False,
  'freeze_FeatureFxtraction': False,
  'freeze_SequenceModeling': False
}

In [10]:
with open('EasyOCR/trainer/config_files/custom_data_train.yaml', 'w') as file:
    yaml.dump(params, file)

#### Запуск обучения

In [11]:
import sys

sys.path.insert(0, './EasyOCR/trainer')

In [12]:
import os
import torch.backends.cudnn as cudnn
import yaml
from train import train
from utils import AttrDict
import pandas as pd

cudnn.benchmark = True
cudnn.deterministic = False

def get_config(file_path):
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)
    if opt.lang_char == 'None':
        characters = ''
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
            all_char = ''.join(df['words'])
            characters += ''.join(set(all_char))
        characters = sorted(set(characters))
        opt.character= ''.join(characters)
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    return opt


In [None]:
#Запускаем обучение
opt = get_config("EasyOCR/trainer/config_files/custom_data_train.yaml")
train(opt, amp=False)

Filtering the images containing characters which are not in opt.character
Filtering the images whose label is longer than opt.batch_max_length
--------------------------------------------------------------------------------
dataset_root: d:\itmo\mlops_itmo\ocr_model_training\EasyOCR\trainer\all_data\train\images
opt.select_data: ['train']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root:    d:\itmo\mlops_itmo\ocr_model_training\EasyOCR\trainer\all_data\train\images	 dataset: train
d:\itmo\mlops_itmo\ocr_model_training\EasyOCR\trainer\all_data\train\images/
sub-directory:	/.	 num samples: 30000
num total samples of train: 30000 x 1.0 (total_data_usage_ratio) = 30000
num samples of train per batch: 64 x 1.0 (batch_ratio) = 64
