# HTR Training

Combine different sets of documents, train HTR on Transkribus and evaluate the results.

In [None]:
import os
import numpy as np
import pandas as pd
import random
import regex
import shutil
import sys
import xml.etree.ElementTree as ET
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from IPython.display import clear_output
sys.path.append(os.getcwd() + '/..')
from scripts.read_transkribus_files import get_text_from_file
from matplotlib.pyplot import imshow

In [None]:
# Exported collections. The functions below read the page files from the
# "page" sub-directory of each collection and the scan images from the
# collection directory itself.
DATA_DIRS = [
    "tmp/1586842/Training_set",
    "tmp/1586854/Validation_set",
    "tmp/1609526/Training_set_2",
    "tmp/1609530/Validation_set_2",
    "tmp/1616742/Sample_three-column",
]
# Collections that are left out of the image file link table.
IGNORE_DIRS = ["tmp/1616742/Sample_three-column"]

# Annotation files, one per collection, named after the collection directory.
YEAR_FILES = [f"{os.path.basename(data_dir)}_years.csv" for data_dir in DATA_DIRS]
USAGE_FILES = [f"{os.path.basename(data_dir)}_usage.csv" for data_dir in DATA_DIRS]
NAME_FILES = [f"{os.path.basename(data_dir)}_names.csv" for data_dir in DATA_DIRS]

# Size (in pixels) to which scans are resized before they are displayed.
MAX_WIDTH = 1000
MAX_HEIGHT = 750

# Fixed seed for the random scan selection.
random.seed(42)

## 1. Find the years of the available training data

Some of the names of the scans in the training data have changed from the central repository names. Here we check the scans, specify the year of the scan and store these in year files (`_years.csv`) for each data collection.

In [None]:
def collect_years_of_files(data_dirs=DATA_DIRS, year_files=YEAR_FILES):
    """Interactively label the scans of each collection with their year.

    Scans that have no stored year yet are displayed one by one and the
    user types the year. An empty answer repeats the previous year (when
    resuming, the last year stored in the year file). The complete year
    list is written to the year file after every answer, so a session can
    be interrupted and continued later.

    Args:
        data_dirs: collection directories; each contains a "page"
            sub-directory and the scan images (.jpg or .JPG).
        year_files: csv files in which the years are stored, parallel to
            data_dirs.
    """
    for data_dir, year_file in zip(data_dirs, year_files):
        file_names = sorted(os.listdir(os.path.join(data_dir, "page")))
        try:
            years = list(pd.read_csv(year_file, index_col=0)["0"])
        except (FileNotFoundError, pd.errors.EmptyDataError, KeyError):
            # no usable year file yet: start with the first scan
            years = []
        # seed with the last stored year so that an empty answer right
        # after resuming does not store an empty year
        last_year = years[-1] if years else ""
        for file_counter, file_name in enumerate(file_names, start=1):
            if file_counter <= len(years):
                continue
            image_base = regex.sub(r"\.xml$", "", file_name)
            try:
                image = Image.open(os.path.join(data_dir, image_base + ".jpg"))
            except OSError:
                # fall back to the upper-case extension
                image = Image.open(os.path.join(data_dir, image_base + ".JPG"))
            display(image)
            print(f"data_dir: {data_dir}; file name: {file_name}; last year: {last_year};", end=" ")
            print(f"file: {file_counter}/{len(file_names)}")
            year = input().strip()
            if year == "":
                year = last_year
            years.append(year)
            pd.DataFrame(years).to_csv(year_file)
            last_year = year
            clear_output(wait=True)

In [None]:
def get_years(year_files=YEAR_FILES):
    """Return the years stored in the given year files as one flat list.

    Args:
        year_files: csv files with the years in column "0".

    Returns:
        The years of all files, in the order of year_files.
    """
    return [
        year
        for year_file in year_files
        for year in pd.read_csv(year_file, index_col=0)["0"]
    ]

In [None]:
# Interactive: label the years of the scans in the Sample_test_1 collection.
collect_years_of_files(data_dirs=[ "tmp/1616639/Sample_test_1" ], year_files=[ "Sample_test_1_years.csv" ])

## 2. Determine file usability

Some scans contain parts of multiple certificates or damaged certificates. We want to exclude these from HTR training and layout training. We label them and store the labels in usage files (`_usage.csv`).

In [None]:
def check_file_usability(data_dirs=DATA_DIRS, usage_files=USAGE_FILES):
    """Interactively mark the scans of each collection as usable or not.

    Scans that have no stored label yet are displayed one by one. An empty
    answer stores "yes", any other answer stores "no". The complete label
    list is written to the usage file after every answer, so a session can
    be interrupted and continued later.

    Args:
        data_dirs: collection directories; each contains a "page"
            sub-directory and the scan images (.jpg or .JPG).
        usage_files: csv files in which the labels are stored, parallel to
            data_dirs.
    """
    for data_dir, usage_file in zip(data_dirs, usage_files):
        file_names = sorted(os.listdir(os.path.join(data_dir, "page")))
        try:
            usage = list(pd.read_csv(usage_file, index_col=0)["0"])
        except (FileNotFoundError, pd.errors.EmptyDataError, KeyError):
            # no usable usage file yet: start with the first scan
            usage = []
        for file_counter, file_name in enumerate(file_names, start=1):
            if file_counter <= len(usage):
                continue
            image_base = regex.sub(r"\.xml$", "", file_name)
            try:
                image = Image.open(os.path.join(data_dir, image_base + ".jpg"))
            except OSError:
                # fall back to the upper-case extension
                image = Image.open(os.path.join(data_dir, image_base + ".JPG"))
            display(image)
            print(f"data_dir: {data_dir}; file name: {file_name};", end=" ")
            print(f"file: {file_counter}/{len(file_names)}")
            usage_value = "yes" if input().strip() == "" else "no"
            usage.append(usage_value)
            pd.DataFrame(usage).to_csv(usage_file)
            clear_output(wait=True)

In [None]:
# Interactive: label the usability of the scans of all collections in DATA_DIRS.
check_file_usability()

## 3. Plot file counts per decade

In [None]:
def year2decade(year):
    """Return the decade number of a year: the year without its last digit.

    Accepts integers, numeric strings and floats (a year column that also
    holds a missing value is read by pandas as floats such as 1871.0).

    Args:
        year: the year, e.g. 1871, "1871" or 1871.0.

    Returns:
        The decade number as int, e.g. 187 for 1871.

    Raises:
        ValueError: if year cannot be interpreted as a number.
    """
    return int(float(year)) // 10

In [None]:
def get_decades(year_files=YEAR_FILES):
    """Return the decade number of every year stored in the year files.

    Args:
        year_files: csv files with the years; passed on to get_years.

    Returns:
        A list with one decade number (see year2decade) per stored year.
    """
    # use the argument, not the global YEAR_FILES, so callers can pass
    # their own year files
    return [ year2decade(year) for year in get_years(year_files=year_files) ]

In [None]:
def get_file_usage(usage_files=USAGE_FILES):
    """Return the usage labels stored in the given usage files as one list.

    Args:
        usage_files: csv files with the labels in column "0".

    Returns:
        The labels of all files, in the order of usage_files.
    """
    return [
        usage_value
        for usage_file in usage_files
        for usage_value in list(pd.read_csv(usage_file, index_col=0)["0"])
    ]

In [None]:
def plot_file_counts(file_counts):
    """Draw a bar chart of the number of selected scans per decade.

    Args:
        file_counts: counts whose index entries are tuples with the decade
            number (year without its last digit) as first element, such as
            the result of `pd.DataFrame(decades).value_counts()`.
    """
    decade_numbers = [ index_entry[0] for index_entry in file_counts.index ]
    scan_counts = list(file_counts.values)
    # append the dropped digit so that the ticks read as years (187 -> "1870")
    tick_labels = [ str(decade_number) + "0" for decade_number in decade_numbers ]
    plt.xticks(ticks=decade_numbers, labels=tick_labels)
    plt.title(f"Number of selected scans per decade (total={sum(file_counts.values)})")
    plt.bar(decade_numbers, scan_counts)

In [None]:
NBR_OF_SKIPPED_FILES = 50 # while annotation is not finished

# Drop the last NBR_OF_SKIPPED_FILES entries. The end index is computed
# explicitly because [:-NBR_OF_SKIPPED_FILES] returns an empty list when
# NBR_OF_SKIPPED_FILES is 0.
decades = get_decades()
decades = decades[:max(len(decades) - NBR_OF_SKIPPED_FILES, 0)]
usage = get_file_usage()
usage = usage[:max(len(usage) - NBR_OF_SKIPPED_FILES, 0)]
# keep only the decades of the scans labelled as usable
decades = [ decade for decade, usage_value in zip(decades, usage) if usage_value == "yes" ]

In [None]:
# Count the usable scans per decade and show the counts as a bar chart.
plot_file_counts(pd.DataFrame(decades).value_counts())

## 4. Select complementary scans (unfinished)

Plan:

1. find years
2. remove difficult cases
3. fill up decades to 15
4. sort by name
5. train layout and baseline detection
6. train htr, check CER
7. evaluate names

In [None]:
FILE_DIR = "../../data/Overlijden"

def random_select_scan(year):
    """Return a randomly chosen scan of the given year.

    The scans are looked up in the sub-directories of the directory
    "O.R. <year>" in FILE_DIR. Entries of the year directory that are not
    directories are skipped.

    Args:
        year: the year of the requested scan.

    Returns:
        The path of the chosen scan relative to the year directory:
        "<region directory>/<file name>".

    Raises:
        ValueError: if the year directory contains no scans.
    """
    year_dir = os.path.join(FILE_DIR, "O.R. " + str(year))
    file_names = []
    for region_dir in os.listdir(year_dir):
        region_path = os.path.join(year_dir, region_dir)
        if not os.path.isdir(region_path):
            # stray file next to the region directories
            continue
        file_names.extend(os.path.join(region_dir, file_name) for file_name in os.listdir(region_path))
    if not file_names:
        raise ValueError(f"no scans found in {year_dir}")
    return random.choice(file_names)

In [None]:
def get_years_from_dir(dir_name):
    """Return the year of every file in a directory.

    The year is taken from the second whitespace-separated part of the
    file name.

    Args:
        dir_name: directory with the files.

    Returns:
        A list with one year (int) per directory entry.
    """
    return [ int(entry.split()[1]) for entry in os.listdir(dir_name) ]

In [None]:
# Drop the last NBR_OF_SKIPPED_FILES years. The end index is computed
# explicitly because [:-NBR_OF_SKIPPED_FILES] returns an empty list when
# NBR_OF_SKIPPED_FILES is 0.
years = get_years()
years = years[:max(len(years) - NBR_OF_SKIPPED_FILES, 0)]
# keep only the years of the scans labelled as usable
years = [ year for year, usage_value in zip(years, usage) if usage_value == "yes" ]

In [None]:
TARGET_DIR = os.path.join(FILE_DIR, "x-samples", "complementary-2023")

# Years of the scans that were already copied to the target directory.
target_dir_years = get_years_from_dir(TARGET_DIR)
for decade_start in (1860, 1870, 1910, 1920, 1940):
    # years of this decade that are present in the usable training scans
    available_years = sorted({ year for year in years if decade_start <= year < decade_start + 10 })
    for year in range(decade_start, decade_start + 10):
        # keep proposing random scans of a missing year until one is accepted
        while not (year in available_years or year in target_dir_years):
            file_name = os.path.join(FILE_DIR, "O.R. " + str(year), random_select_scan(year))
            display(Image.open(file_name))
            print("accept this file?")
            accept_value = input().strip()
            # an empty answer accepts the proposed scan
            if accept_value == "":
                shutil.copy(file_name, TARGET_DIR)
                available_years.append(year)
            clear_output(wait=True)

In [None]:
# Try out the random scan selection for a single year.
random_select_scan(1871)

In [None]:
# Read page file p001.xml of the first collection; the next cell uses `text`.
text, meta_data, file_regions = get_text_from_file(os.path.join(DATA_DIRS[0], "page", "p001.xml"))

In [None]:
# Print the line that follows the first line starting with "Heden".
saw_trigger_word = False
for line in text.split("\n"):
    if saw_trigger_word:
        print(line)
        break
    saw_trigger_word = line.startswith("Heden")

## 5. Determine archive file names

Find the archive file names of the scans in the training data. We use the years stored in the years files and ask the user for the folio number. Then the matching scans from the archive are shown and the user can choose one by entering the matching id number. Choice "0" here will enable the user to enter a different folio number. The resulting archive file names are stored in name files per collection (`_names.csv`).

In [None]:
ARCHIVE_FILE_DIR = "../../data/Overlijden"

def check_archive_file_name(archive_file_name):
    """Check whether an archive file name refers to an existing archive file.

    The directories are derived from the whitespace-separated parts of the
    file name: the first two parts form the year directory and all parts
    but the last one, without a trailing dot, form the region directory.

    Args:
        archive_file_name: file name of the scan, without directories.

    Returns:
        True if the file exists in ARCHIVE_FILE_DIR, otherwise False.
    """
    file_name_parts = archive_file_name.split()
    year_dir = " ".join(file_name_parts[:2])
    region_dir = " ".join(file_name_parts[:-1])
    if region_dir.endswith("."):
        # the region directory name has no trailing dot
        region_dir = region_dir[:-1]
    file_name_with_dirs = os.path.join(ARCHIVE_FILE_DIR, year_dir, region_dir, archive_file_name)
    return os.path.isfile(file_name_with_dirs)

In [None]:
def get_candidate_images(year, folio_number):
    """Find the archive scans of a year that match a folio number.

    A scan matches when its file name ends with the folio number (padded
    with zeros to three characters) followed by ".jpg", ignoring case.
    Only the sub-directories of the year directory whose name contains
    "O.R. <year>" are searched. Both the year directory name and the folio
    number are compared literally, so characters typed by the user are
    never interpreted as regular expression syntax.

    Args:
        year: the year of the scan.
        folio_number: the folio number, as number or string.

    Returns:
        A list with the paths of the matching scans.
    """
    year_dir = f"O.R. {year}"
    file_name_end = (str(folio_number).zfill(3) + ".jpg").lower()
    year_path = os.path.join(ARCHIVE_FILE_DIR, year_dir)
    candidate_images = []
    for region_dir in os.listdir(year_path):
        if year_dir not in region_dir:
            continue
        region_path = os.path.join(year_path, region_dir)
        for image_file_name in os.listdir(region_path):
            if image_file_name.lower().endswith(file_name_end):
                candidate_images.append(os.path.join(region_path, image_file_name))
    return candidate_images

In [None]:
def determine_archive_file_names(data_dirs=DATA_DIRS, name_files=NAME_FILES, year_files=YEAR_FILES):
    """Interactively link the scans of each collection to archive files.

    Each scan without a stored archive file name is displayed and the user
    enters its folio number. The archive scans of the same year that match
    the folio number are shown with an id number and the user enters the id
    of the matching scan. Entering 0, a number outside the offered range
    or something that is not a number asks for the folio number again.
    The complete name list is written to the name file after every choice,
    so a session can be interrupted and continued later.

    Args:
        data_dirs: collection directories; each contains a "page"
            sub-directory and the scan images (.jpg or .JPG).
        name_files: csv files in which the archive file names are stored,
            parallel to data_dirs.
        year_files: csv files with the years of the scans, parallel to
            data_dirs.
    """
    for data_dir, name_file, year_file in zip(data_dirs, name_files, year_files):
        file_names = sorted(os.listdir(os.path.join(data_dir, "page")))
        years = list(pd.read_csv(year_file, index_col=0)["0"])
        try:
            archive_file_names = list(pd.read_csv(name_file, index_col=0)["0"])
        except (FileNotFoundError, pd.errors.EmptyDataError, KeyError):
            # no usable name file yet: start with the first scan
            archive_file_names = []
        for file_counter, file_name in enumerate(file_names, start=1):
            year = years[file_counter - 1]
            image_base = regex.sub(r"\.xml$", "", file_name)
            # repeat until a name has been stored for this scan
            while file_counter > len(archive_file_names):
                clear_output(wait=True)
                try:
                    image = Image.open(os.path.join(data_dir, image_base + ".jpg"))
                except OSError:
                    # fall back to the upper-case extension
                    image = Image.open(os.path.join(data_dir, image_base + ".JPG"))
                display(image.resize((MAX_WIDTH, MAX_HEIGHT)))
                print(f"data_dir: {data_dir}; file name: {file_name}; year: {year}", end=" ")
                print(f"file: {file_counter}/{len(file_names)}")
                print("Enter folio number")
                folio_number = input().strip()
                candidate_images = get_candidate_images(year, folio_number)
                for index, candidate_image in enumerate(candidate_images, start=1):
                    display(index, Image.open(candidate_image).resize((500, 500)))
                print(f"Enter id number of scan (1-{len(candidate_images)})")
                try:
                    chosen_index = int(input().strip()) - 1
                except ValueError:
                    # not a number: ask for the folio number again
                    continue
                if 0 <= chosen_index < len(candidate_images):
                    archive_file_names.append(os.path.basename(candidate_images[chosen_index]))
                    pd.DataFrame(archive_file_names).to_csv(name_file)

In [None]:
# Interactive: link the scans of the Sample_test_1 collection to archive files.
determine_archive_file_names(data_dirs=[ "tmp/1616639/Sample_test_1" ], 
                             name_files=[ "Sample_test_1_names.csv" ], 
                             year_files=[ "Sample_test_1_years.csv" ] )

## 6. Check archive file names

In [None]:
def add_directory_name_to_archive_file_name(archive_file_name):
    """Return the path of an archive file inside ARCHIVE_FILE_DIR.

    The directories are derived from the whitespace-separated parts of the
    file name: the first two parts form the year directory and all parts
    but the last one, without a trailing dot, form the region directory.

    Args:
        archive_file_name: file name of the scan, without directories.

    Returns:
        The path ARCHIVE_FILE_DIR/<year dir>/<region dir>/<file name>.
    """
    file_name_parts = archive_file_name.split()
    year_dir = " ".join(file_name_parts[:2])
    region_dir = " ".join(file_name_parts[:-1])
    if region_dir.endswith("."):
        # the region directory name has no trailing dot
        region_dir = region_dir[:-1]
    return os.path.join(ARCHIVE_FILE_DIR, year_dir, region_dir, archive_file_name)

In [None]:
def check_archive_file_names(data_dirs=DATA_DIRS, name_files=NAME_FILES, start_data_dir="", start_file_name=""):
    """Interactively compare each training scan with its linked archive scan.

    Every scan is displayed together with its file size, followed by the
    archive scan stored for it in the name file (skipped when the stored
    name is missing, i.e. "nan"). After each pair the function waits for
    input: "0" stops the check, any other answer moves on to the next scan.

    Args:
        data_dirs: collection directories; each contains a "page"
            sub-directory and the scan images (.jpg or .JPG).
        name_files: csv files with the archive file names, parallel to
            data_dirs.
        start_data_dir: if not empty, collections before this one are
            skipped.
        start_file_name: if not empty, page files before this one are
            skipped.
    """
    preview_size = (int(0.4 * MAX_WIDTH), int(0.4 * MAX_HEIGHT))
    for data_dir_id, data_dir in enumerate(data_dirs):
        if start_data_dir != "" and data_dir != start_data_dir:
            continue
        # start collection reached: do not skip any later collection
        start_data_dir = ""
        name_file = name_files[data_dir_id]
        file_names = sorted(os.listdir(os.path.join(data_dir, "page")))
        archive_file_names = list(pd.read_csv(name_file, index_col=0)["0"])
        for file_counter, file_name in enumerate(file_names, start=1):
            archive_file_name = str(archive_file_names[file_counter - 1])
            if start_file_name != "" and file_name != start_file_name:
                continue
            # start file reached: do not skip any later file
            start_file_name = ""
            clear_output(wait=True)
            image_base = regex.sub(r"\.xml$", "", file_name)
            file_name_with_dir = os.path.join(data_dir, image_base + ".jpg")
            try:
                image = Image.open(file_name_with_dir)
            except OSError:
                # fall back to the upper-case extension
                file_name_with_dir = os.path.join(data_dir, image_base + ".JPG")
                image = Image.open(file_name_with_dir)
            display(image.resize(preview_size))
            print(f"file name: {data_dir}/{file_name}; size: {os.path.getsize(file_name_with_dir)}; file: {file_counter}/{len(file_names)}")
            # a missing name is read from the csv file as NaN
            if archive_file_name != "nan":
                archive_file_name_with_dir = add_directory_name_to_archive_file_name(archive_file_name)
                display(Image.open(archive_file_name_with_dir).resize(preview_size))
                print(f"archive file name: {archive_file_name}; size: {os.path.getsize(archive_file_name_with_dir)}")
            response = input().strip()
            if response == "0":
                return

In [None]:
# Interactive: review the archive file links of all collections in DATA_DIRS.
check_archive_file_names()

## 7. Make image file name link table

In [None]:
def make_image_file_link_table(data_dirs=DATA_DIRS, name_files=NAME_FILES):
    """Build a table linking the page files to their archive file names.

    Collections listed in IGNORE_DIRS are left out.

    Args:
        data_dirs: collection directories; each contains a "page"
            sub-directory with the page files.
        name_files: csv files with the archive file names, parallel to
            data_dirs.

    Returns:
        A DataFrame with one row per page file and the columns "collectie"
        (collection name), "bestandsnaam" (page file name), "archiefnaam"
        (archive file name) and "opmerking" (empty remark).
    """
    collections = []
    page_file_names = []
    linked_names = []
    for position, data_dir in enumerate(data_dirs):
        if data_dir in IGNORE_DIRS:
            continue
        name_file = name_files[position]
        file_names = sorted(os.listdir(os.path.join(data_dir, "page")))
        archive_file_names = list(pd.read_csv(name_file, index_col=0)["0"])
        collection_name = os.path.basename(data_dir)
        for file_index, file_name in enumerate(file_names):
            linked_names.append(str(archive_file_names[file_index]))
            collections.append(collection_name)
            page_file_names.append(file_name)
    return pd.DataFrame({
        "collectie": collections,
        "bestandsnaam": page_file_names,
        "archiefnaam": linked_names,
        "opmerking": [""] * len(page_file_names),
    })

In [None]:
# Build the link table for all collections and store it as a csv file.
link_table = make_image_file_link_table()
link_table.to_csv("koppeltabel_HTR_train.csv")