In [2]:
from tqdm import tqdm
import os
import time
import torch
import numpy as np
import json
import cv2
import random
!pip install wandb -q
import wandb

from ocr.utils import (
    val_loop, load_pretrain_model, FilesLimitControl, AverageMeter, sec2min
)

from ocr.dataset import get_data_loader
from ocr.transforms import get_train_transforms, get_val_transforms
from ocr.tokenizer import Tokenizer
from ocr.config import Config
from ocr.models import CRNN
import warnings
warnings.filterwarnings('ignore')

# Global seed shared by every RNG below and by split_data's shuffles.
seed = 0xC0FFEE
# Reproducibility switch: True = deterministic cuDNN kernels (no autotuning);
# False = let cuDNN benchmark/autotune kernels for speed at the cost of
# run-to-run bitwise reproducibility.
DETERMINISTIC = True

random.seed(seed)
# Note: setting PYTHONHASHSEED here only affects child processes; the running
# interpreter's hash randomization was already fixed at startup.
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Seed every visible GPU, not only the current one.
torch.cuda.manual_seed_all(seed)
# cudnn.benchmark autotunes algorithms per input shape and may select
# different (non-deterministic) ones between runs, which defeats
# cudnn.deterministic; the two flags must therefore be opposite.
torch.backends.cudnn.deterministic = DETERMINISTIC
torch.backends.cudnn.benchmark = not DETERMINISTIC

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def train_loop(data_loader, model, criterion, optimizer, epoch, scheduler,
               max_grad_norm=2, device=None):
    """Train ``model`` for one epoch and return the epoch's loss and learning rate.

    For every batch: gradients are cleared, images are moved to ``device``,
    the loss is computed as
    ``criterion(output, enc_pad_texts, output_lengths, text_lens)``, gradients
    are clipped to ``max_grad_norm`` and both ``optimizer.step()`` and
    ``scheduler.step()`` are called (the scheduler is therefore stepped once
    per batch, not once per epoch).

    Args:
        data_loader: iterable of ``(images, texts, enc_pad_texts, text_lens)``
            batches that also supports ``len()``.
        model: network whose output is read as ``(T, N, ...)``:
            ``output.size(0)`` is used as every sample's output length and
            ``output.size(1)`` as the batch size (presumably CTC log-probs
            from CRNN — TODO confirm).
        criterion: loss called with the CTC-style argument order above.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index; only used in the printed summary.
        scheduler: learning-rate scheduler, stepped after every batch.
        max_grad_norm: total-norm threshold for ``clip_grad_norm_``; the
            default reproduces the previously hard-coded value of 2.
        device: device for the image batch; ``None`` means the notebook-level
            ``DEVICE``.

    Returns:
        dict: ``{'loss': ..., 'lr': ...}`` where ``loss`` is ``AverageMeter.avg``
        (the meter is updated with each batch's loss and batch size) and
        ``lr`` is the last parameter group's learning rate after the epoch.
    """
    if device is None:
        device = DEVICE
    loss_meter = AverageMeter()
    start_time = time.time()
    model.train()
    progress = tqdm(data_loader, total=len(data_loader), leave=False)
    for images, texts, enc_pad_texts, text_lens in progress:
        model.zero_grad()
        images = images.to(device)
        batch_size = len(texts)
        output = model(images)
        # Every sample in the batch uses the full output sequence length.
        output_lengths = torch.full(
            size=(output.size(1),),
            fill_value=output.size(0),
            dtype=torch.long
        )
        loss = criterion(output, enc_pad_texts, output_lengths, text_lens)
        loss_meter.update(loss.item(), batch_size)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()
    loop_time = sec2min(time.time() - start_time)
    # Report the learning rate of the last parameter group.
    lr = optimizer.param_groups[-1]['lr']
    train_stats = {
        'loss': loss_meter.avg,
        'lr': lr,
    }
    print(f'\nEpoch {epoch}, Loss: {loss_meter.avg:.5f}, '
          f'LR: {lr:.7f}, loop_time: {loop_time}')
    return train_stats


def get_loaders(tokenizer, config):
    """Create the training and validation data loaders from ``config``.

    Image geometry is read via ``config.get_image``.
    Dataset CSV paths and sampling probabilities are read via
    ``config.get_{train,val}_datasets``.
    Epoch and batch sizes are read via ``config.get_{train,val}``.

    The training loader uses ``get_train_transforms`` (with ``prob=0.2``) and
    drops the last incomplete batch. The validation loader uses
    ``get_val_transforms`` and keeps every batch.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    def image_size():
        # Queried afresh for each split, exactly as before.
        return {
            'height': config.get_image('height'),
            'width': config.get_image('width'),
        }

    # --- training split ---
    train_tf = get_train_transforms(**image_size(), prob=0.2)
    train_opts = dict(
        csv_paths=config.get_train_datasets('csv_path'),
        dataset_probs=config.get_train_datasets('prob'),
        epoch_size=config.get_train('epoch_size'),
        batch_size=config.get_train('batch_size'),
    )
    train_dl = get_data_loader(
        transforms=train_tf, tokenizer=tokenizer, drop_last=True, **train_opts
    )

    # --- validation split ---
    val_tf = get_val_transforms(**image_size())
    val_opts = dict(
        csv_paths=config.get_val_datasets('csv_path'),
        dataset_probs=config.get_val_datasets('prob'),
        epoch_size=config.get_val('epoch_size'),
        batch_size=config.get_val('batch_size'),
    )
    val_dl = get_data_loader(
        transforms=val_tf, tokenizer=tokenizer, drop_last=False, **val_opts
    )

    return train_dl, val_dl

In [4]:
def split_data(data_dict, split_coef=0.75, shuffle_seed=None):
    """Deterministically shuffle ``(key, value)`` pairs and split them into train/val.

    Args:
        data_dict: mapping to split (here: image path -> transcription).
        split_coef: fraction of pairs assigned to the training part, in
            [0, 1]. The first ``int(len * split_coef)`` shuffled pairs go to
            train and the rest go to val. The default keeps the previous 0.75.
        shuffle_seed: seed for the private ``random.Random`` used to shuffle.
            ``None`` uses the notebook-level ``seed``, so existing calls
            produce the same split as before.

    Returns:
        tuple: ``(train_pairs, val_pairs)``, two lists of ``(key, value)`` tuples.

    Raises:
        ValueError: if ``split_coef`` is not within [0, 1].

    NOTE(review): items are shuffled individually. For IAM word ids (which
    encode form/line), words from the same page/writer can therefore land in
    both splits. Validation scores may then be optimistic; consider splitting
    by page/writer if that matters.
    """
    if not 0.0 <= split_coef <= 1.0:
        raise ValueError(f'split_coef must be in [0, 1], got {split_coef!r}')

    data = list(data_dict.items())
    # A private Random instance keeps the split independent of global RNG state.
    random.Random(seed if shuffle_seed is None else shuffle_seed).shuffle(data)
    print('total len before split', len(data))

    train_len = int(len(data) * split_coef)
    train_splitted = data[:train_len]
    val_splitted = data[train_len:]

    print('train len after split', len(train_splitted))
    print('val len after split', len(val_splitted))
    return train_splitted, val_splitted

In [5]:
import pandas as pd
import re

# NTO labels CSV; the columns `file_name` and `text` are read below.
labels = pd.read_csv('datasets/data/labels.csv')


def _has_latin_letter(text):
    """Return True when ``text`` is a string containing at least one ASCII Latin letter.

    Missing transcriptions come back from ``read_csv`` as float NaN, which
    ``re.search`` rejects with a TypeError, so non-strings are treated as
    "no letter" and filtered out instead of crashing the cell.
    """
    return isinstance(text, str) and re.search('[a-zA-Z]', text) is not None


# Keep only samples whose transcription contains a Latin letter.
# Zipping the two columns avoids the per-row Series construction of iterrows().
nto_data = {
    os.path.join('datasets/data/images', file_name): text
    for file_name, text in zip(labels['file_name'], labels['text'])
    if _has_latin_letter(text)
}
nto_train, nto_val = split_data(nto_data)

train len 14862
train len after split 11146
val len after split 3716


In [6]:
# IAM `words.txt`: the first IAM_HEADER_LINES lines are skipped as a header.
# Each remaining line is whitespace-separated, starting with the word id and
# ending with the transcription.
IAM_HEADER_LINES = 18
# Word ids are skipped when they contain any of these substrings: every id with
# 'r06' and the single word a01-117-05-02 (presumably broken/unreadable images
# — TODO confirm the reason).
IAM_SKIP_PATTERNS = ('r06', 'a01-117-05-02')

# Read the whole file inside `with` so the handle is closed.
with open('datasets/ascii/words.txt', 'r') as words_file:
    iam_lines = [line.strip() for line in words_file]

iam_data = {}
for line in iam_lines[IAM_HEADER_LINES:]:
    fields = line.split()
    if not fields:
        # Blank line: nothing to parse (would otherwise raise IndexError).
        continue
    word_id = fields[0]
    if any(pattern in word_id for pattern in IAM_SKIP_PATTERNS):
        continue
    iam_data[os.path.join('datasets/iam', word_id + '.png')] = fields[-1]
iam_train, iam_val = split_data(iam_data)

train len 113208
train len after split 84906
val len after split 28302


In [7]:
import matplotlib.pyplot as plt  # kept from the original cell; not used below

# When True, also crop every kept annotation out of its page image and write
# it to datasets/gnhk/clips/. Off by default: the clips are expected to already
# exist, and reading every page with cv2.imread dominates this cell's runtime.
EXTRACT_GNHK_CLIPS = False

labels = sorted([x for x in os.listdir('datasets/gnhk') if x.endswith('.json')])


def _write_gnhk_clip(page, polygon, clip_path):
    """Crop the axis-aligned bounding box of ``polygon`` from ``page`` and save it.

    ``polygon`` values are read in insertion order as alternating x, y
    coordinates (the same convention as the earlier commented-out code).
    """
    coords = [int(v) for v in polygon.values()]
    xs, ys = coords[0::2], coords[1::2]
    cv2.imwrite(clip_path, page[min(ys):max(ys), min(xs):max(xs), :])


if EXTRACT_GNHK_CLIPS:
    os.makedirs('datasets/gnhk/clips', exist_ok=True)

gnhk_data = {}
for label_name in tqdm(labels):
    # Pair each annotation with its image by file stem instead of zipping two
    # independently sorted lists, which mis-pairs silently if a file is missing.
    stem = label_name.split('.')[0]
    with open('datasets/gnhk/' + label_name) as label_file:
        annos = json.load(label_file)

    page = None
    if EXTRACT_GNHK_CLIPS:
        page = cv2.imread('datasets/gnhk/' + stem + '.jpg')
        if page is None:
            raise FileNotFoundError(f'cannot read page image for {label_name}')

    for idx, anno in enumerate(annos):
        # Annotations whose text contains '%' are skipped (GNHK appears to use
        # %...% placeholders for non-transcribable regions — TODO confirm).
        if '%' in anno['text']:
            continue
        clip_path = 'datasets/gnhk/clips/' + stem + '_' + str(idx) + '.jpg'
        gnhk_data[clip_path] = anno['text']
        if page is not None:
            _write_gnhk_clip(page, anno['polygon'], clip_path)

gnhk_train, gnhk_val = split_data(gnhk_data)

687it [01:19,  8.64it/s]

train len 40777
train len after split 30582
val len after split 10195





In [8]:
import csv

# Concatenate the per-dataset splits (IAM, NTO, GNHK order), then shuffle each
# side with its own seeded RNG so the CSV row order is reproducible.
train_data_splitted = iam_train + nto_train + gnhk_train
random.Random(seed).shuffle(train_data_splitted)
val_data_splitted = iam_val + nto_val + gnhk_val
random.Random(seed).shuffle(val_data_splitted)

print('train len after split', len(train_data_splitted))
print('val len after split', len(val_data_splitted))


def write_split_csv(path, rows):
    """Write ``(filename, text)`` pairs to ``path`` as a two-column CSV with a header row.

    ``newline=''`` is required by the csv module for files handed to
    ``csv.writer``; without it every row is followed by a blank line on Windows.
    """
    with open(path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['filename', 'text'])
        writer.writerows(rows)


# NOTE(review): these paths must match the csv_path entries in the training
# config read by get_loaders — TODO confirm.
write_split_csv('train.csv', train_data_splitted)
write_split_csv('val.csv', val_data_splitted)

print('data len ', len(train_data_splitted) + len(val_data_splitted))

train len after split 126634
val len after split 42213
data len  168847


In [9]:
# Raw config dict; only consumed by the (currently disabled) wandb run below.
with open('scripts/owr_config.json') as cfg_file:
    cfg = json.loads(cfg_file.read())
# Experiment tracking is off; uncomment to log this run to Weights & Biases:
# run = wandb.init(project="Task", entity="nto", name='crnn_hkr40', config=cfg, reinit=True)

In [10]:
config = Config('scripts/owr_config.json')
tokenizer = Tokenizer(config.get('alphabet'))
save_dir = config.get('save_dir')
os.makedirs(save_dir, exist_ok=True)
train_loader, val_loader = get_loaders(tokenizer, config)

model = CRNN(number_class_symbols=tokenizer.get_num_chars())
if config.get('pretrain_path'):
    states = load_pretrain_model(config.get('pretrain_path'), model)
    model.load_state_dict(states)
    print('Load pretrained model')
model.to(DEVICE)

# blank=0 assumes the Tokenizer reserves id 0 for the CTC blank — TODO confirm.
# zero_infinity=True zeroes infinite losses (e.g. targets longer than the
# output sequence) instead of letting them poison the gradients.
criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# NOTE: OneCycleLR overwrites each param group's lr on construction
# (initial lr = max_lr / div_factor), so `lr` below does not set the schedule;
# `max_lr` does.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.get('lr'),
                              weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    epochs=config.get('num_epochs'),
    steps_per_epoch=len(train_loader),  # train_loop steps the scheduler per batch
    max_lr=config.get('max_lr'),
    pct_start=0.1,
    anneal_strategy='cos',
    final_div_factor=10 ** 5
)
# weight_limit_control = FilesLimitControl()
best_cer = np.inf

# val_stats = val_loop(val_loader, model, criterion, tokenizer, DEVICE)
for epoch in range(config.get('num_epochs')):
    train_stats = train_loop(train_loader, model, criterion, optimizer,
                             epoch, scheduler)

    val_stats = val_loop(val_loader, model, criterion, tokenizer, DEVICE)

    val_cer = val_stats['cer']
    val_loss = val_stats['loss']
    # Keep only weights that improve validation CER. Both metrics are
    # formatted to 4 decimals so filenames stay short and sortable.
    if val_cer < best_cer:
        best_cer = val_cer
        model_save_path = os.path.join(
            save_dir, f'{val_cer:.4f}-{val_loss:.4f}-Epoch{epoch}.ckpt')
        torch.save(model.state_dict(), model_save_path)
        print('Model weights saved')
#         weight_limit_control(model_save_path)
#     run.log({
#         'val_cer': val_cer,
#         'val_loss': val_loss,
#         'train_loss': train_stats['loss'],
#         'lr': train_stats['lr']
#     })
# run.finish()

                                                 


Epoch 0, Loss: 58.03403, LR: 0.0000662, loop_time: 18m 31s


                                                 

Validation, cer: 0.9039, acc: 0.0727, wer: 0.9273, loss: 3.2208, loop_time: 4m 4s
Model weights saved


                                                 


Epoch 1, Loss: 2.69413, LR: 0.0001881, loop_time: 18m 24s


                                                 

Validation, cer: 0.7719, acc: 0.1609, wer: 0.8391, loss: 2.0439, loop_time: 4m 4s
Model weights saved


                                                 


Epoch 2, Loss: 1.44185, LR: 0.0002858, loop_time: 18m 22s


                                                 

Validation, cer: 0.3015, acc: 0.4029, wer: 0.5972, loss: 1.1986, loop_time: 4m 1s
Model weights saved


                                                 

Validation, cer: 0.2597, acc: 0.4586, wer: 0.5415, loss: 0.9856, loop_time: 4m 3s
Model weights saved


                                                 


Epoch 4, Loss: 0.79304, LR: 0.0002983, loop_time: 18m 28s


                                                 

Validation, cer: 0.1929, acc: 0.5603, wer: 0.4397, loss: 0.7402, loop_time: 4m 2s
Model weights saved


                                                 


Epoch 5, Loss: 0.68476, LR: 0.0002954, loop_time: 18m 28s


                                                 

Validation, cer: 0.1680, acc: 0.6039, wer: 0.3961, loss: 0.6480, loop_time: 4m 3s
Model weights saved


                                                 


Epoch 6, Loss: 0.61193, LR: 0.0002909, loop_time: 18m 27s


                                                 

Validation, cer: 0.1537, acc: 0.6340, wer: 0.3660, loss: 0.5852, loop_time: 4m 4s
Model weights saved


                                                 


Epoch 7, Loss: 0.56454, LR: 0.0002851, loop_time: 18m 27s


                                                 

Validation, cer: 0.2089, acc: 0.5512, wer: 0.4489, loss: 0.7819, loop_time: 4m 0s


                                                 


Epoch 8, Loss: 0.52564, LR: 0.0002780, loop_time: 18m 21s


                                                 

Validation, cer: 0.2124, acc: 0.5535, wer: 0.4465, loss: 0.8571, loop_time: 4m 1s


                                                 


Epoch 9, Loss: 0.48384, LR: 0.0002696, loop_time: 18m 21s


                                                 

Validation, cer: 0.1278, acc: 0.6817, wer: 0.3184, loss: 0.5074, loop_time: 4m 2s
Model weights saved


                                                 


Epoch 10, Loss: 0.45811, LR: 0.0002599, loop_time: 18m 24s


                                                 

Validation, cer: 0.1279, acc: 0.6822, wer: 0.3179, loss: 0.4975, loop_time: 4m 2s


                                                 


Epoch 11, Loss: 0.44038, LR: 0.0002492, loop_time: 18m 23s


                                                 

Validation, cer: 0.1137, acc: 0.7094, wer: 0.2907, loss: 0.4485, loop_time: 4m 1s
Model weights saved


                                                 


Epoch 12, Loss: 0.41272, LR: 0.0002375, loop_time: 18m 22s


                                                 

Validation, cer: 0.1089, acc: 0.7194, wer: 0.2807, loss: 0.4288, loop_time: 4m 1s
Model weights saved


                                                 


Epoch 13, Loss: 0.39048, LR: 0.0002250, loop_time: 18m 22s


                                                 

Validation, cer: 0.1096, acc: 0.7246, wer: 0.2755, loss: 0.4221, loop_time: 4m 1s


                                                 


Epoch 14, Loss: 0.36807, LR: 0.0002117, loop_time: 18m 22s


                                                 

Validation, cer: 0.1040, acc: 0.7322, wer: 0.2679, loss: 0.4092, loop_time: 4m 0s
Model weights saved


                                                 


Epoch 15, Loss: 0.34493, LR: 0.0001978, loop_time: 18m 24s


                                                 

Validation, cer: 0.1039, acc: 0.7329, wer: 0.2671, loss: 0.4157, loop_time: 4m 1s
Model weights saved


                                                 


Epoch 16, Loss: 0.32763, LR: 0.0001834, loop_time: 18m 24s


                                                 

Validation, cer: 0.1010, acc: 0.7401, wer: 0.2600, loss: 0.4021, loop_time: 4m 5s
Model weights saved


                                                 


Epoch 17, Loss: 0.31069, LR: 0.0001686, loop_time: 18m 26s


                                                 

Validation, cer: 0.0964, acc: 0.7505, wer: 0.2496, loss: 0.3852, loop_time: 4m 4s
Model weights saved


                                                 


Epoch 18, Loss: 0.29299, LR: 0.0001537, loop_time: 18m 26s


                                                 

Validation, cer: 0.0917, acc: 0.7599, wer: 0.2402, loss: 0.3722, loop_time: 4m 5s
Model weights saved


                                                 


Epoch 19, Loss: 0.27182, LR: 0.0001388, loop_time: 18m 19s


                                                 

Validation, cer: 0.0907, acc: 0.7625, wer: 0.2376, loss: 0.3810, loop_time: 3m 55s
Model weights saved


                                                 


Epoch 20, Loss: 0.25873, LR: 0.0001239, loop_time: 18m 19s


                                                 

Validation, cer: 0.0888, acc: 0.7661, wer: 0.2340, loss: 0.3703, loop_time: 3m 57s
Model weights saved


                                                 


Epoch 21, Loss: 0.24085, LR: 0.0001094, loop_time: 18m 16s


                                                 

Validation, cer: 0.0910, acc: 0.7631, wer: 0.2370, loss: 0.3749, loop_time: 3m 59s


                                                 


Epoch 22, Loss: 0.22477, LR: 0.0000952, loop_time: 18m 24s


                                                 

Validation, cer: 0.0824, acc: 0.7813, wer: 0.2189, loss: 0.3550, loop_time: 4m 0s
Model weights saved


                                                 


Epoch 23, Loss: 0.21260, LR: 0.0000816, loop_time: 18m 23s


                                                 

Validation, cer: 0.0807, acc: 0.7849, wer: 0.2152, loss: 0.3497, loop_time: 4m 3s
Model weights saved


                                                 


Epoch 24, Loss: 0.19882, LR: 0.0000686, loop_time: 18m 23s


                                                 

Validation, cer: 0.0811, acc: 0.7835, wer: 0.2166, loss: 0.3572, loop_time: 4m 1s


                                                 


Epoch 25, Loss: 0.18494, LR: 0.0000565, loop_time: 18m 25s


                                                 

Validation, cer: 0.0788, acc: 0.7908, wer: 0.2093, loss: 0.3515, loop_time: 4m 1s
Model weights saved


                                                 


Epoch 26, Loss: 0.17421, LR: 0.0000453, loop_time: 18m 22s


                                                 

Validation, cer: 0.0780, acc: 0.7907, wer: 0.2094, loss: 0.3563, loop_time: 4m 0s
Model weights saved


                                                 


Epoch 27, Loss: 0.16326, LR: 0.0000351, loop_time: 18m 22s


                                                 

Validation, cer: 0.0768, acc: 0.7948, wer: 0.2053, loss: 0.3549, loop_time: 4m 1s
Model weights saved


                                                 


Epoch 28, Loss: 0.15493, LR: 0.0000261, loop_time: 18m 21s


                                                 

Validation, cer: 0.0763, acc: 0.7959, wer: 0.2042, loss: 0.3541, loop_time: 4m 0s
Model weights saved


                                                 


Epoch 29, Loss: 0.14775, LR: 0.0000183, loop_time: 18m 23s


                                                 

Validation, cer: 0.0766, acc: 0.7947, wer: 0.2054, loss: 0.3563, loop_time: 4m 0s


                                                 


Epoch 30, Loss: 0.14207, LR: 0.0000118, loop_time: 18m 20s


                                                 

Validation, cer: 0.0762, acc: 0.7965, wer: 0.2036, loss: 0.3584, loop_time: 4m 2s
Model weights saved


                                                 


Epoch 31, Loss: 0.13845, LR: 0.0000067, loop_time: 18m 20s


                                                 

Validation, cer: 0.0756, acc: 0.7981, wer: 0.2020, loss: 0.3582, loop_time: 4m 2s
Model weights saved


                                                 


Epoch 32, Loss: 0.13623, LR: 0.0000030, loop_time: 18m 27s


                                                 

Validation, cer: 0.0755, acc: 0.7979, wer: 0.2022, loss: 0.3601, loop_time: 4m 3s
Model weights saved


                                                 


Epoch 33, Loss: 0.13488, LR: 0.0000007, loop_time: 18m 29s


                                                 

Validation, cer: 0.0757, acc: 0.7972, wer: 0.2029, loss: 0.3594, loop_time: 4m 0s


                                                 

KeyboardInterrupt: 