# Training

In [None]:
# Run mode. True: a single 1-epoch fine-tune of the first config in CONFIG_LIST.
# Flip to False to train every selected config for its full schedule.
SMOKE_TEST: bool = True

# How many configs from CONFIG_LIST to train in full mode (None = all of them).
# Not used when SMOKE_TEST is True.
NUM_MODELS = None

# YAML file whose list items are the model config paths to train.
CONFIG_LIST: str = 'notebooks/configs/model_lists/textrecog.yml'

In [None]:
import warnings

# Ignore all UserWarnings emitted from any submodule of torch
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module=r"torch.*"
)
# Ignore all UserWarnings emitted from any submodule of torch
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module=r"mmcv.*"
)

In [None]:
import yaml
from pathlib import Path

# Load the list of model config paths to train.
# NOTE(review): CONFIG_LIST is resolved against the current working directory, but the
# training cells only chdir to ~/bonting-identification later, so this cell assumes the
# kernel already runs from the repo root — TODO confirm.
with open(CONFIG_LIST, 'r') as f:
    # An empty YAML file loads as None; treat it as "no configs" instead of crashing below.
    config_paths = yaml.safe_load(f) or []

# Keep only non-blank entries. YAML `#` comments are already removed by the parser, so the
# '#' check only catches (quoted) string items that start with '#'.
active_configs = [
    cfg for cfg in config_paths
    if cfg and cfg.strip() and not cfg.strip().startswith('#')
]

# Map model types to checkpoint URLs
CHECKPOINT_URLS = {
    'abinet_custom': 'https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/abinet_20e_st-an_mj_20221005_012617-ead8c139.pth',
    'abinet-vision_custom': 'https://download.openmmlab.com/mmocr/textrecog/abinet/abinet-vision_20e_st-an_mj/abinet-vision_20e_st-an_mj_20220915_152445-85cfb03d.pth',
    'aster_custom': 'https://download.openmmlab.com/mmocr/textrecog/aster/aster_resnet45_6e_st_mj/aster_resnet45_6e_st_mj-cc56eca4.pth',
    'crnn_custom': 'https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth',
    'master_custom': 'https://download.openmmlab.com/mmocr/textrecog/master/master_resnet31_12e_st_mj_sa/master_resnet31_12e_st_mj_sa_20220915_152443-f4a5cabc.pth',
    'nrtr_custom': 'https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj/nrtr_resnet31-1by8-1by4_6e_st_mj_20220916_103322-a6a2a123.pth',
    'robustscanner_custom': 'https://download.openmmlab.com/mmocr/textrecog/robust_scanner/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real/robustscanner_resnet31_5e_st-sub_mj-sub_sa_real_20220915_152447-7fc35929.pth',
    'sar_custom': 'https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real_20220915_171910-04eb4e75.pth',
    'satrn_custom': 'https://download.openmmlab.com/mmocr/textrecog/satrn/satrn_shallow_5e_st_mj/satrn_shallow_5e_st_mj_20220915_152443-5fd04a4c.pth',
    'svtr_custom': 'https://download.openmmlab.com/mmocr/textrecog/svtr/svtr-base_20e_st_mj/svtr-base_20e_st_mj-ea500101.pth',
}

# Map each active config path to its pretrained checkpoint URL. The model type is the
# name of the config's parent directory (e.g. a config inside an `abinet_custom/`
# directory uses CHECKPOINT_URLS['abinet_custom']).
CONFIG_TO_CKPT = {}
configs_without_ckpt = []
for config_path in active_configs:
    model_type = Path(config_path).parent.name
    if model_type in CHECKPOINT_URLS:
        CONFIG_TO_CKPT[config_path] = CHECKPOINT_URLS[model_type]
    else:
        configs_without_ckpt.append(config_path)

print(f"Loaded {len(active_configs)} active configs:")
for config in active_configs:
    print(f"  - {config}")

# The training cells index CONFIG_TO_CKPT directly, so these configs would otherwise only
# fail there with a KeyError (possibly after other models have already trained).
if configs_without_ckpt:
    print(f"WARNING: no pretrained checkpoint mapped for {len(configs_without_ckpt)} config(s):")
    for config in configs_without_ckpt:
        print(f"  - {config}")

ROOT_CONFIG_FOLDER = 'configs/textrecog'

In [None]:
#@title Train single model

from pathlib import Path
from mmengine.runner import Runner
import time
from mmengine import Config
from dotenv import load_dotenv
import os

if SMOKE_TEST:
    load_dotenv() # NOTE: make sure to reload notebook when changing .env to use new env variables

    # Run from the repo root, presumably so the relative config paths in active_configs resolve.
    os.chdir(os.path.expanduser('~/bonting-identification'))

    if not active_configs:
        raise ValueError("No active configs found in CONFIG_LIST")

    # Smoke test: fine-tune only the first active config, for a single epoch.
    model_config = active_configs[0]
    if model_config not in CONFIG_TO_CKPT:
        # Actionable message instead of a bare KeyError on the config path.
        raise KeyError(
            f"No pretrained checkpoint URL for {model_config!r}; "
            "add its parent directory name to CHECKPOINT_URLS"
        )

    cfg = Config.fromfile(model_config)
    cfg['load_from'] = CONFIG_TO_CKPT[model_config]
    # Per-run timestamp as the visualizer name, presumably so each run gets a fresh visualizer
    # (mmengine caches visualizer instances by name). A compact timestamp replaces the verbose
    # time.struct_time repr.
    cfg.visualizer.name = time.strftime('%Y%m%d_%H%M%S')

    cfg.train_cfg['max_epochs'] = 1 # Optionally, smoke test on 1 epoch

    runner = Runner.from_cfg(cfg)
    result = runner.train()

/home/bonting/bonting-identification
07/14 18:33:01 - mmengine /dev/null [4m[97mINFO[0m - Using env variable `MLFLOW_TRACKING_URI` with value of http://localhost:5000 to replace item in config.


  import pkg_resources
  _bootstrap._exec(spec, module)
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f44cd2cf510>>
Traceback (most recent call last):
  File "/home/bonting/.local/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


07/14 18:33:04 - mmengine /dev/null [4m[97mINFO[0m /dev/null 
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.11.13 | packaged by conda-forge | (main, Jun  4 2025, 14:48:23) [GCC 13.3.0]
    CUDA available: True
    MUSA available: False
    numpy_random_seed: 1693254328
    GPU 0: NVIDIA GeForce RTX 3090
    CUDA_HOME: /opt/cuda
    NVCC: Cuda compilation tools, release 12.9, V12.9.86
    GCC: gcc (GCC) 15.1.1 20250425
    PyTorch: 2.1.0+cu118
    PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.8
  - NVCC architecture flags: 



07/14 18:33:06 - mmengine /dev/null [4m[97mINFO[0m - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) RuntimeInfoHook                    
(BELOW_NORMAL) LoggerHook                         
 -------------------/dev/null 
before_train:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(LOW         ) MlflowDatasetHook                  
(VERY_LOW    ) CheckpointHook                     
 -------------------/dev/null 
before_train_epoch:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(NORMAL      ) DistSamplerSeedHook                
 -------------------/dev/null 
before_train_iter:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
 -------------------/dev/null 
after_train_iter:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(BELOW_NORMAL) LoggerHook  

  return obj_cls(**args)


07/14 18:33:07 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>
Loads checkpoint by http backend from path: https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/abinet_20e_st-an_mj_20221005_012617-ead8c139.pth
07/14 18:33:08 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>
The model and loaded state dict do not match exactly

unexpected key in source state_dict: data_preprocessor.mean, data_preprocessor.std

07/14 18:33:08 - mmengine /dev/null [4m[97mINFO[0m - Load checkpoint from https://download.openmmlab.com/mmocr/textrecog/abinet/abinet_20e_st-an_mj/abinet_20e_st-an_mj_20221005_012617-ead8c139.pth
07/14 18:33:08 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>
07/14 18:33:

  return _dataset_source_registry.resolve(
  return _dataset_source_registry.resolve(


07/14 18:33:08 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>




07/14 18:33:09 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>
07/14 18:33:09 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>
07/14 18:33:09 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>
07/14 18:33:09 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>
07/14 18:33:09 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>
07/14 18:33:09 - mmengine /dev/null [4m[97mINFO[0m - [DEBUG] calling hook <mmocr_custom.hooks.mlflow_dataset_hook.MlflowDatasetHook object at 0x7f4347529b10>
07/14 18:33:09 - mmengine /dev/nul

In [None]:
# Results of the dictionary ablation on a single model (ABINet), 20 epochs per run:

# ABINet on CEGD-R, extended dict:              recog/word_acc = 0.90
# ABINet on CEGD-R, unknown characters allowed: recog/word_acc = 0.6320
# ABINet on CEGD-R-truncated, original dict:    recog/word_acc = 0.93

In [None]:
# !rm -rf work_dirs/*

In [None]:
#@title Train all models

import os
from mmengine.runner import Runner
import time
from mmengine import Config
import pandas as pd
from pathlib import Path
from dotenv import load_dotenv

if not SMOKE_TEST:
    load_dotenv() # NOTE: make sure to reload notebook when changing .env to use new env variables
    os.chdir(os.path.expanduser('~/bonting-identification'))

    # Parallel lists, one entry per trained model: validation metrics, config name, checkpoint name.
    results = []
    model_configs = []
    ckpts = []

    # NUM_MODELS=None keeps every config; an int larger than the list is clamped by slicing.
    configs_to_train = active_configs[:NUM_MODELS]

    # Fail before any training starts rather than after earlier models have already trained.
    missing_ckpts = [c for c in configs_to_train if c not in CONFIG_TO_CKPT]
    if missing_ckpts:
        raise KeyError(f"No pretrained checkpoint URL for configs: {missing_ckpts}")

    for model_config in configs_to_train:
        ckpt_url = CONFIG_TO_CKPT[model_config]
        cfg = Config.fromfile(model_config)
        cfg['load_from'] = ckpt_url
        # Compact per-run timestamp as the visualizer name (instead of the struct_time repr).
        cfg.visualizer.name = time.strftime('%Y%m%d_%H%M%S')

        # cfg.train_cfg['max_epochs'] = 1

        runner = Runner.from_cfg(cfg)
        # Runner.train() returns the trained model (nn.Module), not metrics. Appending it kept
        # every trained model alive (and on the GPU) for the whole loop and left the results
        # table without metric columns. Record the latest validation metrics instead; mmengine
        # logs them to the message hub as 'val/<metric>' scalars (e.g. 'val/recog/word_acc').
        # NOTE(review): assumes validation ran during training — otherwise the dict is empty.
        runner.train()
        results.append({
            key[len('val/'):]: buffer.current()
            for key, buffer in runner.message_hub.log_scalars.items()
            if key.startswith('val/')
        })

        # Path.stem drops the '.py' suffix; str.rstrip('.py') strips ANY trailing '.', 'p'
        # or 'y' characters (e.g. 'foo_copy.py' -> 'foo_co').
        model_configs.append(Path(model_config).stem)
        # Checkpoint name = the URL's directory component just above the .pth file.
        ckpts.append(Path(ckpt_url).parent.name)


In [None]:
# results_df = pd.DataFrame(results)
# results_df.insert(0, 'model_config', model_configs)
# results_df.insert(1, 'ckpt', ckpts)
# results_df = results_df.set_index(['model_config', 'ckpt'])
# results_df.sort_values('recog/word_acc', ascending=False, inplace=True)
# results_df

In [None]:
# save_path = Path('reports/eval/cegdr/textrecog/mmocr_finetuned_recog_results.csv')
# save_path.parent.mkdir(parents=True, exist_ok=True)
# print(f'Saving results to:\n{save_path}')
# results_df.to_csv(save_path, index=True, header=True)