# Indroduction

The competition is focused on making graphs accessible to people with visual impairments or other disabilities that make it difficult to read and interpret the values displayed in graphs. In this competition, we are given a set of graphs, along with some data points such as x and y coordinates, x and y-axis values, and bounding boxes. However, the values in the graph images are not always clear, making it challenging to extract accurate data from them.


To address this challenge, I have fine-tuned an OCR model using EasyOCR. While pre-trained OCR models may not work accurately in all cases, fine-tuning an OCR model on the dataset provided by the competition can help us achieve better results.


In this notebook, I will be sharing my approach and implementation of the OCR model, along with the data pre-processing techniques steps used to extract accurate data from the graph images.


The goal of this notebook is to provide a detailed guide for anyone interested in using OCR to make graphs accessible to a wider audience. By sharing my approach, I hope to contribute to the development of more effective methods for making visual data accessible to all.

# 

In [1]:
import os
import pandas as pd
import cv2

In [2]:
!git clone https://github.com/JaidedAI/EasyOCR.git

Cloning into 'EasyOCR'...
remote: Enumerating objects: 2551, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (10/10), done.[K
remote: Total 2551 (delta 2), reused 4 (delta 0), pack-reused 2541[K
Receiving objects: 100% (2551/2551), 148.72 MiB | 21.48 MiB/s, done.
Resolving deltas: 100% (1527/1527), done.
Updating files: 100% (301/301), done.


In [3]:
!cp -r /kaggle/working/EasyOCR/trainer /kaggle/working/

In [4]:
files = os.listdir("/kaggle/working/trainer")
for file in files:
    !cp -r /kaggle/working/trainer/{file} /kaggle/working/

In [5]:
!rm -r /kaggle/working/EasyOCR
!rm -r /kaggle/working/trainer

In [6]:
os.mkdir("/kaggle/working/all_data/en_train_filtered")

In [7]:
# %%time

# !cp -r /kaggle/input/ocr-dataset/train/train /kaggle/working/all_data
# !cp -r /kaggle/input/ocr-dataset/valid/valid /kaggle/working/all_data

In [8]:
# a = os.listdir("/kaggle/input/ocr-dataset/valid/valid")

In [9]:
# os.listdir("all_data/train")

In [10]:
# os.rename("/kaggle/working/all_data/train","/kaggle/working/all_data/en_train_filtered")

In [11]:
os.mkdir("/kaggle/working/all_data/en_train_filtered/__results___files")

In [12]:
# os.listdir("/kaggle/input/ocr-dataset/valid/valid")

In [13]:
# "labels.csv" in os.listdir("all_data/en_train_filtered")

In [14]:
# !mv /kaggle/working/all_data/en_train_filtered/labels.csv /kaggle/working/all_data/en_train_filtered/__results___files/labels.csv

In [15]:
# os.rename("/kaggle/working/all_data/valid","/kaggle/working/all_data/en_val")

In [16]:
os.mkdir("/kaggle/working/all_data/en_val")

In [17]:
os.mkdir("/kaggle/working/all_data/en_val/__results___files")

In [18]:
# !mv /kaggle/working/all_data/en_val/labels.csv /kaggle/working/all_data/en_val/__results___files/labels.csv

In [19]:
%%time

import os
import pandas as pd
import cv2
import json
import matplotlib.pyplot as plt
BASE_DIR = '/kaggle/input/benetech-making-graphs-accessible'
files = os.listdir(f"{BASE_DIR}/train/images")
f = open(f"{BASE_DIR}/train/annotations/0000ae6cbdb1.json")
annotated_data = json.load(f)
# annotated_data['text'][10]annotated_data['text'][10]
def hwlt2ltrb(coor):
    left = coor[2]
    top = coor[3]
    right = left + coor[1]
    bottom = top + coor[0]
    
    return (left,top,right,bottom)


def get_ltrb(source,idx):
    l,t,r,b = float("inf"),float("inf"),float("-inf"),float("-inf")
    data = source[idx]["polygon"]
    x_values = (data['x0'],data['x1'],data['x2'],data['x3'])
    y_values = (data['y0'],data['y1'],data['y2'],data['y3'])
    l = min(l,min(x_values))
    t = min(t,min(y_values))
    r = max(r,max(x_values))
    b = max(b,max(y_values))
    text = source[idx]['text']
    return (l,t,r,b),text
# !mkdir ./output/train
# !mkdir ./output/valid
# mode = 'train'

classes = {
    'line': 0,
    'scatter':1,
    'dot':2,
    'vertical_bar':3,
    'horizontal_bar':4,
    'X-axis':5,
    'y-axis':6
}

train_image_names = []
train_texts = []

valid_image_names = []
valid_texts = []

mode = "train"

for i,file in enumerate(files):
    img = cv2.imread(f"{BASE_DIR}/train/images/{file}")
    image = img.copy()
    h,w,_ = img.shape
    f = open(f"{BASE_DIR}/train/annotations/{file.replace('jpg','json')}")
    annotated_data = json.load(f)
    
    indicies = list(map(lambda x: x["id"], annotated_data["axes"]["x-axis"]["ticks"]))
    
    indicies.extend(list(map(lambda x: x["id"], annotated_data["axes"]["y-axis"]["ticks"])))
    
#     plot_l,plot_t,plot_r,plot_b = hwlt2ltrb((plot_bb["height"],plot_bb["width"],plot_bb["x0"],plot_bb["y0"]))
#     try:
    for idx in indicies:
        try:
            coor,text = get_ltrb(annotated_data['text'],idx)
            l,t,r,b = coor
            canvas = cv2.rectangle(image,(l,t),(r,b),(0,255,0),1)
            text_crop = img[t:b,l:r]
            
            text_img_name = f"{file.replace('.jpg','')}_{idx}.jpg"
            if mode == "train":
                cv2.imwrite(f"/kaggle/working/all_data/en_train_filtered/__results___files/{text_img_name}",text_crop)
                train_image_names.append(text_img_name)
                train_texts.append(text)
            else:
                cv2.imwrite(f"/kaggle/working/all_data/en_val/__results___files/{text_img_name}",text_crop) 
                valid_image_names.append(text_img_name)
                valid_texts.append(text)

        except:
            break
    
    if i == 50000:
        mode = "valid"
    
    if i % 1000 == 0:
        print(f"Current Iter: {i}")

Current Iter: 0
Current Iter: 1000
Current Iter: 2000
Current Iter: 3000
Current Iter: 4000
Current Iter: 5000
Current Iter: 6000
Current Iter: 7000
Current Iter: 8000
Current Iter: 9000
Current Iter: 10000
Current Iter: 11000
Current Iter: 12000
Current Iter: 13000
Current Iter: 14000
Current Iter: 15000
Current Iter: 16000
Current Iter: 17000
Current Iter: 18000
Current Iter: 19000
Current Iter: 20000
Current Iter: 21000
Current Iter: 22000
Current Iter: 23000
Current Iter: 24000
Current Iter: 25000
Current Iter: 26000
Current Iter: 27000
Current Iter: 28000
Current Iter: 29000
Current Iter: 30000
Current Iter: 31000
Current Iter: 32000
Current Iter: 33000
Current Iter: 34000
Current Iter: 35000
Current Iter: 36000
Current Iter: 37000
Current Iter: 38000
Current Iter: 39000
Current Iter: 40000
Current Iter: 41000
Current Iter: 42000
Current Iter: 43000
Current Iter: 44000
Current Iter: 45000
Current Iter: 46000
Current Iter: 47000
Current Iter: 48000
Current Iter: 49000
Current Iter:

In [20]:
train_df = pd.DataFrame(list(zip(train_image_names, train_texts)),
               columns =['filename', 'words'])

valid_df = pd.DataFrame(list(zip(valid_image_names, valid_texts)),
               columns =['filename', 'words'])
train_df.to_csv('/kaggle/working/all_data/en_train_filtered/__results___files/labels.csv',index=False)
valid_df.to_csv('/kaggle/working/all_data/en_val/__results___files/labels.csv',index=False)

In [21]:
train_df = pd.read_csv("/kaggle/working/all_data/en_train_filtered/__results___files/labels.csv", sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
train_df = train_df.dropna().reset_index(drop=True)
train_df.to_csv("/kaggle/working/all_data/en_train_filtered/__results___files/labels.csv",index=False)

valid_df = pd.read_csv("/kaggle/working/all_data/en_val/__results___files/labels.csv", sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
valid_df = valid_df.dropna().reset_index(drop=True)
valid_df.to_csv("/kaggle/working/all_data/en_val/__results___files/labels.csv",index=False)

In [22]:
os.listdir("all_data")

['en_train_filtered', 'folder.txt', 'en_val']

In [23]:
import sys

In [24]:
sys.path.append("/kaggle/working/")

In [25]:
!pip install natsort

import os
import torch.backends.cudnn as cudnn
import yaml
from utils import AttrDict
import pandas as pd

Collecting natsort
  Downloading natsort-8.3.1-py3-none-any.whl (38 kB)
Installing collected packages: natsort
Successfully installed natsort-8.3.1
[0m

In [26]:
cudnn.benchmark = True
cudnn.deterministic = False

In [27]:
def get_config(file_path):
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)
    if opt.lang_char == 'None':
        characters = ''
        for data in opt['select_data'].split('-'):
            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')
            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
            all_char = ''.join(df['words'])
            characters += ''.join(set(all_char))
        characters = sorted(set(characters))
        opt.character= ''.join(characters)
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    return opt

In [28]:
%%writefile config_files/en_filtered_config.yaml
number: '0123456789'
symbol: "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ €"
lang_char: 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
experiment_name: 'en_filtered'
train_data: 'all_data'
valid_data: 'all_data/en_val'
manualSeed: 1111
workers: 2
batch_size: 16 # 32
num_iter: 3000 
valInterval: 1000
saved_model: '' #'saved_models/en_filtered/iter_300000.pth'
FT: False
optim: False # default is Adadelta
lr: 1.
beta1: 0.9
rho: 0.95
eps: 0.00000001
grad_clip: 5
#Data processing
select_data: 'en_train_filtered' # this is dataset folder in train_data
batch_ratio: '1' 
total_data_usage_ratio: 1.0
batch_max_length: 34 
imgH: 64
imgW: 600
rgb: False
contrast_adjust: False
sensitive: True
PAD: True
contrast_adjust: 0.0
data_filtering_off: False
# Model Architecture
Transformation: 'None'
FeatureExtraction: 'VGG'
SequenceModeling: 'BiLSTM'
Prediction: 'CTC'
num_fiducial: 20
input_channel: 1
output_channel: 256
hidden_size: 256
decode: 'greedy'
new_prediction: False
freeze_FeatureFxtraction: False
freeze_SequenceModeling: False

Overwriting config_files/en_filtered_config.yaml


In [29]:
opt = get_config("config_files/en_filtered_config.yaml")

In [30]:
# a = os.listdir('/kaggle/working/all_data/en_train_filtered')

In [31]:
# %%writefile /kaggle/working/dataset.py


import os
import sys
import re
import six
import math
import torch
import pandas  as pd

from natsort import natsorted
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, ConcatDataset, Subset
from torch._utils import _accumulate
import torchvision.transforms as transforms

def contrast_grey(img):
    high = np.percentile(img, 90)
    low  = np.percentile(img, 10)
    return (high-low)/(high+low), high, low

def adjust_contrast_grey(img, target = 0.4):
    contrast, high, low = contrast_grey(img)
    if contrast < target:
        img = img.astype(int)
        ratio = 200./(high-low)
        img = (img - low + 25)*ratio
        img = np.maximum(np.full(img.shape, 0) ,np.minimum(np.full(img.shape, 255), img)).astype(np.uint8)
    return img


class Batch_Balanced_Dataset(object):

    def __init__(self, opt):
        """
        Modulate the data ratio in the batch.
        For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5",
        the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST.
        """
        log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
        dashed_line = '-' * 80
        print(dashed_line)
        log.write(dashed_line + '\n')
        print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
        log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
        assert len(opt.select_data) == len(opt.batch_ratio)

        _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust = opt.contrast_adjust)
        self.data_loader_list = []
        self.dataloader_iter_list = []
        batch_size_list = []
        Total_batch_size = 0
        for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
            _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
            print(dashed_line)
            log.write(dashed_line + '\n')
            _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
            total_number_dataset = len(_dataset)
            log.write(_dataset_log)

            """
            The total number of data can be modified with opt.total_data_usage_ratio.
            ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage.
            See 4.2 section in our paper.
            """
            number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio))
            dataset_split = [number_dataset, total_number_dataset - number_dataset]
            indices = range(total_number_dataset)
            _dataset, _ = [Subset(_dataset, indices[offset - length:offset])
                           for offset, length in zip(_accumulate(dataset_split), dataset_split)]
            selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
            selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
            print(selected_d_log)
            log.write(selected_d_log + '\n')
            batch_size_list.append(str(_batch_size))
            Total_batch_size += _batch_size

            _data_loader = torch.utils.data.DataLoader(
                _dataset, batch_size=_batch_size,
                shuffle=True,
                num_workers=int(opt.workers), #prefetch_factor=2,persistent_workers=True,
                collate_fn=_AlignCollate, pin_memory=True)
            self.data_loader_list.append(_data_loader)
            self.dataloader_iter_list.append(iter(_data_loader))

        Total_batch_size_log = f'{dashed_line}\n'
        batch_size_sum = '+'.join(batch_size_list)
        Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
        Total_batch_size_log += f'{dashed_line}'
        opt.batch_size = Total_batch_size

        print(Total_batch_size_log)
        log.write(Total_batch_size_log + '\n')
        log.close()

    def get_batch(self):
        balanced_batch_images = []
        balanced_batch_texts = []

        for i, data_loader_iter in enumerate(self.dataloader_iter_list):
            try:
                image,text = next(iter(data_loader_iter))
                balanced_batch_images.append(image)
                balanced_batch_texts += text
            except StopIteration:
                self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
                image, text = next(iter(self.dataloader_iter_list[i]))
                balanced_batch_images.append(image)
                balanced_batch_texts += text
            except ValueError:
                pass

        balanced_batch_images = torch.cat(balanced_batch_images, 0)

        return balanced_batch_images, balanced_batch_texts


def hierarchical_dataset(root, opt, select_data='/'):
    """ select_data='/' contains all sub-directory of root directory """
    dataset_list = []
    dataset_log = f'dataset_root:    {root}\t dataset: {select_data[0]}'
    print(dataset_log)
    dataset_log += '\n'
    for dirpath, dirnames, filenames in os.walk(root+'/'):
        if not dirnames:
            select_flag = False
            for selected_d in select_data:
                if selected_d in dirpath:
                    select_flag = True
                    break

            if select_flag:
                dataset = OCRDataset(dirpath, opt)
                sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}'
                print(sub_dataset_log)
                dataset_log += f'{sub_dataset_log}\n'
                dataset_list.append(dataset)

    concatenated_dataset = ConcatDataset(dataset_list)

    return concatenated_dataset, dataset_log

class OCRDataset(Dataset):

    def __init__(self, root, opt):

        self.root = root
        self.opt = opt
        print(root)
        self.df = pd.read_csv(os.path.join(root,'labels.csv'), sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
        self.nSamples = len(self.df)

        if self.opt.data_filtering_off:
            self.filtered_index_list = [index + 1 for index in range(self.nSamples)]
        else:
            self.filtered_index_list = []
            for index in range(self.nSamples):
                label = self.df.at[index,'words']
                try:
                    if len(label) > self.opt.batch_max_length:
                        continue
                except:
                    print(label)
                out_of_char = f'[^{self.opt.character}]'
                if re.search(out_of_char, label.lower()):
                    continue
                self.filtered_index_list.append(index)
            self.nSamples = len(self.filtered_index_list)

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        index = self.filtered_index_list[index]
        img_fname = self.df.at[index,'filename']
        img_fpath = os.path.join(self.root, img_fname)
        label = self.df.at[index,'words']

        if self.opt.rgb:
            img = Image.open(img_fpath).convert('RGB')  # for color image
        else:
            img = Image.open(img_fpath).convert('L')

        if not self.opt.sensitive:
            label = label.lower()

        # We only train and evaluate on alphanumerics (or pre-defined character set in train.py)
        out_of_char = f'[^{self.opt.character}]'
        label = re.sub(out_of_char, '', label)

        return (img, label)

class ResizeNormalize(object):

    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)
        return img


class NormalizePAD(object):

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)
        c, h, w = img.size()
        Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
        Pad_img[:, :, :w] = img  # right pad
        if self.max_size[2] != w:  # add border Pad
            Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)

        return Pad_img


class AlignCollate(object):

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, contrast_adjust = 0.):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad
        self.contrast_adjust = contrast_adjust

    def __call__(self, batch):
        batch = filter(lambda x: x is not None, batch)
        images, labels = zip(*batch)

        if self.keep_ratio_with_pad:  # same concept with 'Rosetta' paper
            resized_max_w = self.imgW
            input_channel = 3 if images[0].mode == 'RGB' else 1
            transform = NormalizePAD((input_channel, self.imgH, resized_max_w))

            resized_images = []
            for image in images:
                w, h = image.size

                #### augmentation here - change contrast
                if self.contrast_adjust > 0:
                    image = np.array(image.convert("L"))
                    image = adjust_contrast_grey(image, target = self.contrast_adjust)
                    image = Image.fromarray(image, 'L')

                ratio = w / float(h)
                if math.ceil(self.imgH * ratio) > self.imgW:
                    resized_w = self.imgW
                else:
                    resized_w = math.ceil(self.imgH * ratio)

                resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
                resized_images.append(transform(resized_image))
                # resized_image.save('./image_test/%d_test.jpg' % w)

            image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)

        else:
            transform = ResizeNormalize((self.imgW, self.imgH))
            image_tensors = [transform(image) for image in images]
            image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)

        return image_tensors, labels


def tensor2im(image_tensor, imtype=np.uint8):
    image_numpy = image_tensor.cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)


def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)

In [32]:
# %%writefile /kaggle/working/train.py

import os
import sys
import time
import random
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
from torch.cuda.amp import autocast, GradScaler
import numpy as np

from utils import CTCLabelConverter, AttnLabelConverter, Averager
# from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from model import Model
from test import validation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def count_parameters(model):
    print("Modules, Parameters")
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = parameter.numel()
        #table.add_row([name, param])
        total_params+=param
        print(name, param)
    print(f"Total Trainable Params: {total_params}")
    return total_params

def train(opt, show_number = 2, amp=False):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)
    
    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a', encoding="utf8")
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust=opt.contrast_adjust)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=min(32, opt.batch_size),
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers), prefetch_factor=512,
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    if opt.saved_model != '':
        pretrained_dict = torch.load(opt.saved_model)
        if opt.new_prediction:
            model.Prediction = nn.Linear(model.SequenceModeling_output, len(pretrained_dict['module.Prediction.weight']))  
        
        model = torch.nn.DataParallel(model).to(device) 
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(pretrained_dict, strict=False)
        else:
            model.load_state_dict(pretrained_dict)
        if opt.new_prediction:
            model.module.Prediction = nn.Linear(model.module.SequenceModeling_output, opt.num_class)  
            for name, param in model.module.Prediction.named_parameters():
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            model = model.to(device) 
    else:
        # weight initialization
        for name, param in model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception as e:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        model = torch.nn.DataParallel(model).to(device)
    
    model.train() 
    print("Model:")
    print(model)
    count_parameters(model)
    
    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # freeze some layers
    try:
        if opt.freeze_FeatureFxtraction:
            for param in model.module.FeatureExtraction.parameters():
                param.requires_grad = False
        if opt.freeze_SequenceModeling:
            for param in model.module.SequenceModeling.parameters():
                param.requires_grad = False
    except:
        pass
    
    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.optim=='adam':
        #optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        optimizer = optim.Adam(filtered_parameters)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    scaler = GradScaler()
    t1= time.time()
        
    while(True):
        # train part
        optimizer.zero_grad(set_to_none=True)
        
        if amp:
            with autocast():
                image_tensors, labels = train_dataset.get_batch()
                image = image_tensors.to(device)
                text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
                batch_size = image.size(0)

                if 'CTC' in opt.Prediction:
                    preds = model(image, text).log_softmax(2)
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    preds = preds.permute(1, 0, 2)
                    torch.backends.cudnn.enabled = False
                    cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
                    torch.backends.cudnn.enabled = True
                else:
                    preds = model(image, text[:, :-1])  # align with Attention.forward
                    target = text[:, 1:]  # without [GO] Symbol
                    cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
            scaler.scale(cost).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            image_tensors, labels = train_dataset.get_batch()
            image = image_tensors.to(device)
            text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)
            if 'CTC' in opt.Prediction:
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                preds = preds.permute(1, 0, 2)
                torch.backends.cudnn.enabled = False
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
                torch.backends.cudnn.enabled = True
            else:
                preds = model(image, text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) 
            optimizer.step()
        loss_avg.add(cost)

        # validation part
        if (i % opt.valInterval == 0) and (i!=0):
            print('training time: ', time.time()-t1)
            t1=time.time()
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf8") as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,\
                    infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'/kaggle/working/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'/kaggle/working/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                
                #show_number = min(show_number, len(labels))
                
                start = random.randint(0,len(labels) - show_number )    
                for gt, pred, confidence in zip(labels[start:start+show_number], preds[start:start+show_number], confidence_score[start:start+show_number]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                print('validation time: ', time.time()-t1)
                t1=time.time()
        # save model per 1e+4 iter.
        if (i + 1) % 1000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
        if i == opt.num_iter:
            print('end the training')
            break
        i += 1



In [33]:
train(opt, amp=False)

Filtering the images containing characters which are not in opt.character
Filtering the images whose label is longer than opt.batch_max_length
--------------------------------------------------------------------------------
dataset_root: all_data
opt.select_data: ['en_train_filtered']
opt.batch_ratio: ['1']
--------------------------------------------------------------------------------
dataset_root:    all_data	 dataset: en_train_filtered
all_data/en_train_filtered/__results___files
sub-directory:	/en_train_filtered/__results___files	 num samples: 969884
num total samples of en_train_filtered: 969884 x 1.0 (total_data_usage_ratio) = 969884
num samples of en_train_filtered per batch: 16 x 1.0 (batch_ratio) = 16
--------------------------------------------------------------------------------
Total_batch_size: 16 = 16
--------------------------------------------------------------------------------
dataset_root:    all_data/en_val	 dataset: /
all_data/en_val/__results___files
sub-director