In [None]:
!pip install transformers pytorch_lightning sentencepiece datasets nlpaug sacrebleu
!pip install "transformers[torch]"

In [None]:
import nlpaug.augmenter.char as nac
from transformers import pipeline
from dataclasses import dataclass, field
from typing import Optional
from transformers import set_seed
from transformers import HfArgumentParser, TrainingArguments
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM
from transformers import Trainer, DataCollatorForSeq2Seq
import argparse
import glob
import os
import json
import time
import logging
import random
import re
from itertools import chain
from string import punctuation
from sacrebleu import sentence_bleu
import evaluate
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from torch.optim import AdamW
from datasets import Dataset, DatasetDict
from torchmetrics.text import CharErrorRate

In [3]:
# Load the tokenizer and the seq2seq model for the `google/byt5-small` checkpoint.
# NOTE(review): both `tokenizer` and `model` are re-assigned from the same checkpoint
# in the later configuration cell (the one defining `model_name_or_path`), so this
# load looks redundant — confirm and keep only one of the two.
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-small")

In [4]:
# Hyper-parameters for fine-tuning the model on OCR post-correction.
training_args = TrainingArguments(
    output_dir='./byt5-ocr-correction',
    run_name="byt5-ocr-correction-cer",
    overwrite_output_dir=True,
    # 2 samples per device, accumulated over 4 steps before each optimizer update.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-4,
    warmup_steps=250,
    logging_steps=100,
    # Evaluate on the validation split every 250 steps.
    eval_strategy="steps",
    eval_steps=250,
    fp16=False,
    # Training length is controlled by `max_steps` alone. A positive `max_steps`
    # overrides `num_train_epochs`, so the previous `num_train_epochs=4` had no
    # effect and was removed to avoid suggesting an epoch-based schedule.
    max_steps=5000,
)

# Seed the random number generators with the seed carried by the training arguments.
set_seed(training_args.seed)

In [5]:
# Checkpoint and sequence-length settings shared by the tokenizer and the model.
model_name_or_path = "google/byt5-small"
max_len = 128
cache_dir = None

# Tokenizer: loaded from the checkpoint, then capped at `max_len`.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
tokenizer.model_max_length = max_len

# Model: same checkpoint, with `max_len` recorded on its config as well.
model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
model.config.max_length = max_len

# Re-seed after loading so the steps that follow start from a known RNG state.
set_seed(training_args.seed)

In [6]:
# Load the semicolon-separated dataset and preview its first rows.
# Columns consumed later in the notebook: `ocr_prediction` (model input) and
# `text` (correction target).
df = pd.read_csv('dataset/dataset_abbreviation_corrected.csv' , sep=';')
df.head()

Unnamed: 0,line_id,image_path,text,page_id,line_number,ocr_prediction
0,0033_033_line_30,/work/polygonal_lines/polygonal_lines/0033_033...,Ius dei occultaret. la terza ut tetatis fa,0033_033,30,Ius dei occultare. la terza un tẽtatis fa
1,0033_033_line_24,/work/polygonal_lines/polygonal_lines/0033_033...,ti per lo peccato: Cossi uolse esser tenta,0033_033,24,ti perlo peccato:cessi uolse esser tenta
2,0033_033_line_18,/work/polygonal_lines/polygonal_lines/0033_033...,DOMINICA PRIMA,0033_033,18,ḊN.NUNO la Par¶N A
3,0033_033_line_19,/work/polygonal_lines/polygonal_lines/0033_033...,Vctus est iesus in desertum a spi,0033_033,19,Uctus ẽ iesus in desertũ a sai
4,0033_033_line_25,/work/polygonal_lines/polygonal_lines/0033_033...,to per dare consolatione i et conforto anoi,0033_033,25,to per dare cõsonatione iã costorto anoi


In [7]:
def prepare_ocr_correction_dataset(
    df,
    tokenizer,
    max_length=128,
    prefix="correct OCR: ",
    input_column="ocr_prediction",
    target_column="text",
):
    """Tokenize a DataFrame of OCR lines into a seq2seq dataset.

    Each row yields one example: the text in ``input_column`` prepended with
    ``prefix`` becomes the model input, and the text in ``target_column``
    becomes the labels.

    Args:
        df: pandas DataFrame containing ``input_column`` and ``target_column``.
        tokenizer: Callable tokenizer; called once for the inputs and once for
            the targets with ``max_length``, ``truncation=True`` and no padding.
        max_length: Maximum tokenized length for both inputs and targets.
        prefix: Task prompt prepended to every input string. The same prefix
            must be used at inference time.
        input_column: Name of the column holding the raw OCR text.
        target_column: Name of the column holding the corrected text.

    Returns:
        The mapped dataset with all original columns removed; it holds the
        fields produced by the tokenizer for the inputs plus a ``labels``
        field with the target token ids.

    Raises:
        KeyError: If ``input_column`` or ``target_column`` is missing from ``df``.
    """
    # Fail early with a readable message instead of deep inside `Dataset.map`.
    missing = [name for name in (input_column, target_column) if name not in df.columns]
    if missing:
        raise KeyError(f"Missing required column(s) in DataFrame: {missing}")

    def preprocess_function(examples):
        """Tokenize one batch of rows (batched mapping: dict of column -> list)."""
        inputs = [prefix + text for text in examples[input_column]]
        targets = examples[target_column]

        # No padding here: sequences are left at their natural length.
        model_inputs = tokenizer(
            inputs,
            max_length=max_length,
            truncation=True,
            padding=False,
            return_tensors=None
        )

        labels = tokenizer(
            targets,
            max_length=max_length,
            truncation=True,
            padding=False,
            return_tensors=None
        )

        # The target token ids are what the model is trained to produce.
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    dataset = Dataset.from_pandas(df)

    # Drop every original column so only the tokenized fields remain.
    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Tokenizing dataset"
    )
    return tokenized_dataset

In [8]:
def split_dataset(df, train_ratio=0.8):
    """Split a DataFrame sequentially into a training and a validation part.

    The first ``int(len(df) * train_ratio)`` rows form the training frame and
    the remaining rows form the validation frame. Rows are not shuffled, and
    both returned frames get a fresh 0-based index.

    Args:
        df: DataFrame to split.
        train_ratio: Fraction of rows assigned to the training frame.

    Returns:
        Tuple ``(train_df, val_df)``.
    """
    cutoff = int(len(df) * train_ratio)
    leading_rows, trailing_rows = df[:cutoff], df[cutoff:]
    return (
        leading_rows.reset_index(drop=True),
        trailing_rows.reset_index(drop=True),
    )

def prepare_complete_dataset(df, tokenizer, max_length=128, train_ratio=0.8):
    """Split ``df`` and tokenize both parts into a ``DatasetDict``.

    Args:
        df: DataFrame with the raw OCR lines and their corrections.
        tokenizer: Tokenizer forwarded to ``prepare_ocr_correction_dataset``.
        max_length: Maximum tokenized length forwarded to the tokenization step.
        train_ratio: Fraction of rows used for training, forwarded to
            ``split_dataset``.

    Returns:
        ``DatasetDict`` with the keys ``'train'`` and ``'validation'``.
    """
    # Pair each split name with its frame; dict order keeps train first.
    frames = dict(zip(('train', 'validation'), split_dataset(df, train_ratio)))

    print(f"Training samples: {len(frames['train'])}")
    print(f"Validation samples: {len(frames['validation'])}")

    return DatasetDict({
        split_name: prepare_ocr_correction_dataset(frame, tokenizer, max_length)
        for split_name, frame in frames.items()
    })

def show_dataset_examples(dataset, tokenizer, num_examples=3):
    """Print the decoded input and target text of the first few examples.

    Args:
        dataset: Indexable collection whose items expose ``'input_ids'`` and
            ``'labels'``.
        tokenizer: Object providing ``decode`` and ``pad_token_id``.
        num_examples: Upper bound on how many examples are printed; fewer are
            shown when the dataset is smaller.
    """
    how_many = min(num_examples, len(dataset))
    for number in range(1, how_many + 1):
        print(f"\n--- Example {number} ---")
        record = dataset[number - 1]

        source_text = tokenizer.decode(record['input_ids'], skip_special_tokens=True)
        print(f"Input: {source_text}")

        # Labels equal to -100 are swapped for the pad token id before decoding
        # (presumably loss-masking placeholders that the tokenizer cannot decode).
        decodable_labels = [
            tokenizer.pad_token_id if token == -100 else token
            for token in record['labels']
        ]
        reference_text = tokenizer.decode(decodable_labels, skip_special_tokens=True)
        print(f"Target: {reference_text}")

In [9]:
# Split the DataFrame and tokenize both parts, capping sequences at `max_len`.
dataset_dict = prepare_complete_dataset(df, tokenizer, max_length=max_len)

Training samples: 8512
Validation samples: 2129


Tokenizing dataset:   0%|          | 0/8512 [00:00<?, ? examples/s]

Tokenizing dataset:   0%|          | 0/2129 [00:00<?, ? examples/s]

In [10]:
# Sanity check: decode the first two training examples back to text.
show_dataset_examples(
    dataset=dataset_dict['train'],
    tokenizer=tokenizer,
    num_examples=2,
)


--- Example 1 ---
Input: correct OCR: Ius dei occultare.  la  terza  un tẽtatis fa
Target: Ius dei occultaret. la terza ut tetatis fa

--- Example 2 ---
Input: correct OCR: ti perlo peccato:cessi uolse esser tenta
Target: ti per lo peccato: Cossi uolse esser tenta


In [11]:
# Batch collator handed to the Trainer. The dataset is tokenized without padding,
# so padding is requested here (`padding=True`) and batches are returned as
# PyTorch tensors (`return_tensors="pt"`).
# NOTE(review): presumably labels are padded with -100 by this collator — confirm
# against the library documentation.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    return_tensors="pt"
)

In [12]:
def compute_cer(predictions, references):
    """Compute the corpus-level character error rate (CER).

    Every prediction/reference pair is whitespace-normalized (runs of
    whitespace collapsed to single spaces, ends stripped). The edit distances
    of all pairs are summed and divided by the total number of reference
    characters.

    Args:
        predictions: Iterable of predicted strings.
        references: Iterable of reference strings, paired positionally with
            ``predictions``; surplus items on either side are ignored.

    Returns:
        Total edit distance divided by total reference length, or ``0`` when
        the references contain no characters.
    """
    def squeeze(text):
        # Collapse whitespace so spacing differences are limited to single spaces.
        return ' '.join(text.split())

    normalized_pairs = [
        (squeeze(hypothesis), squeeze(truth))
        for hypothesis, truth in zip(predictions, references)
    ]

    reference_length = sum(len(truth) for _, truth in normalized_pairs)
    error_count = sum(
        compute_edit_distance(hypothesis, truth)
        for hypothesis, truth in normalized_pairs
    )

    if reference_length > 0:
        return error_count / reference_length
    return 0

def compute_edit_distance(s1, s2):
    """Return the Levenshtein distance between two sequences.

    Unit cost is charged for each insertion, deletion and substitution. Only
    two rows of the dynamic-programming table are kept at a time, and the
    shorter sequence is used for the row dimension.

    Args:
        s1: First sequence (e.g. a string).
        s2: Second sequence.

    Returns:
        The minimum number of single-element edits turning one sequence into
        the other.
    """
    # Put the longer sequence on the outer loop; the distance is symmetric.
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)

    if len(shorter) == 0:
        return len(longer)

    prev_row = list(range(len(shorter) + 1))
    for row_number, outer_item in enumerate(longer, start=1):
        row = [row_number]
        for col_number, inner_item in enumerate(shorter, start=1):
            from_above = prev_row[col_number] + 1
            from_left = row[col_number - 1] + 1
            from_diagonal = prev_row[col_number - 1] + int(outer_item != inner_item)
            row.append(min(from_above, from_left, from_diagonal))
        prev_row = row

    return prev_row[-1]

In [16]:
# Fine-tune the model and persist the result.
# Changes from the previous version of this cell:
#   * dropped `globals()['tokenizer'] = tokenizer`, a no-op at notebook top level
#     where `tokenizer` is already a global;
#   * dropped the second, identical `tokenizer.save_pretrained(...)` call.
trainer = Trainer(
    model=model,
    # `processing_class` replaces the deprecated `tokenizer` argument of `Trainer`.
    # NOTE(review): needs a transformers release that accepts `processing_class`.
    processing_class=tokenizer,
    args=training_args,
    train_dataset=dataset_dict['train'],
    eval_dataset=dataset_dict['validation'],
    data_collator=data_collator,
)

trainer.train()

# Save the model and the tokenizer files side by side in `output_dir` so the
# directory can be loaded directly for inference.
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)

  trainer = Trainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss
250,0.7336,0.689548
500,0.5576,0.607636
750,0.4671,0.56941
1000,0.4308,0.54254
1250,0.3835,0.527145
1500,0.3753,0.518903
1750,0.3427,0.512063
2000,0.3642,0.495754
2250,0.2714,0.514198
2500,0.2696,0.506021




('./byt5-ocr-correction/tokenizer_config.json',
 './byt5-ocr-correction/special_tokens_map.json',
 './byt5-ocr-correction/added_tokens.json')

In [18]:
# Reload the fine-tuned model and tokenizer from the output directory as a
# text-to-text generation pipeline.
corrector = pipeline(
    task="text2text-generation",
    model=training_args.output_dir,
    tokenizer=training_args.output_dir,
)

# The prompt must carry the same "correct OCR: " prefix that was used to build
# the training inputs.
ocr_text = "fiqu gracaron corubinaru che tusto lã"
corrected = corrector(f"correct OCR: {ocr_text}")
print(corrected[0]['generated_text'])

Device set to use cuda:0


figlioli giocatori corrubinarii che tutto lan
