In [1]:
from data.loader.custom_loader import CustomLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
from jiwer import cer
from tqdm import tqdm
from PIL import Image
from collections import defaultdict

# Load the pre-trained TrOCR model and processor
model = VisionEncoderDecoderModel.from_pretrained("../custom_models/trocr-printed/w_augmentation/small_medium/1.0/vision_model")
processor = TrOCRProcessor.from_pretrained("../custom_models/trocr-printed/w_augmentation/small_medium/1.0/processor")

kubhi_paths = ["../datasets/printed/Histrorical_News_Paper/combined.csv" ]
kubhist_cl = CustomLoader(kubhi_paths)
kubhist_cl.generate_dataframe()
#put ../ in every file name in the dataframe
kubhist_df = kubhist_cl.get_dataframe()
kubhist_df["file_name"] = "../" + kubhist_df["file_name"]
kubhist_df.head()

File exists: ../datasets/printed/Histrorical_News_Paper/combined.csv
Encoding: utf-8


Calculating max length: 100%|██████████| 8393/8393 [00:00<00:00, 22665.58it/s]


Unnamed: 0,file_name,text
0,../datasets/printed/Histrorical_News_Paper/TES...,denna ſekt.
1,../datasets/printed/Histrorical_News_Paper/TES...,﻿Antiomianer kallas ſå af Gre=
2,../datasets/printed/Histrorical_News_Paper/TES...,De antaga ej goda gerningar ſåſom
3,../datasets/printed/Histrorical_News_Paper/TES...,"nödwändiga medel till ſaligheten, och på="
4,../datasets/printed/Histrorical_News_Paper/TES...,"ſtå, att de utwalde ingenting kunna"


In [2]:
# Function to evaluate CER and show frequency of mismatched characters
def evaluate_cer_and_mismatched_chars(_model, _processor, _dataset):
    _model.eval()
    cer_scores = []
    mismatched_chars = defaultdict(int)

    for _, example in tqdm(_dataset.iterrows(), total=len(_dataset)):
        image_path = example["file_name"]
        ground_truth_text = example["text"]

        # Load the image
        image = Image.open(image_path).convert("RGB")

        # Preprocess the image
        pixel_values = _processor(images=image, return_tensors="pt").pixel_values

        # Generate prediction
        with torch.no_grad():
            generated_ids = _model.generate(pixel_values)
        predicted_text = _processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Calculate CER
        cer_score = cer(ground_truth_text, predicted_text)
        cer_scores.append(cer_score)

        # Compare ground truth and predicted text character by character
        for gt_char, pred_char in zip(ground_truth_text, predicted_text):
            if gt_char != pred_char:
                mismatched_chars[(gt_char, pred_char)] += 1

    # Return the average CER and mismatched characters frequency
    _average_cer = sum(cer_scores) / len(cer_scores)
    return _average_cer, dict(mismatched_chars)

# Evaluate the model and get mismatched characters frequency
average_cer, mismatched_chars_freq = evaluate_cer_and_mismatched_chars(model, processor, kubhist_df[:100])
print(f"Average CER: {average_cer}")
print("Mismatched Characters Frequency:")
for chars, freq in mismatched_chars_freq.items():
    print(f"{chars}: {freq}")

100%|██████████| 100/100 [02:21<00:00,  1.41s/it]

Average CER: 0.9176029238833198
Mismatched Characters Frequency:
('e', 'o'): 11
('n', 'i'): 11
('n', ' '): 55
('a', ':'): 3
('ſ', '1'): 1
('e', '0'): 1
('k', '.'): 2
('t', ' '): 54
('.', '1'): 1
('\ufeff', 'E'): 1
('A', 'n'): 2
('t', 's'): 10
('i', 't'): 13
('m', 'r'): 17
('i', ' '): 33
('n', 'v'): 9
('e', ' '): 77
('r', 'd'): 9
(' ', 'e'): 46
('k', 'm'): 2
('a', ' '): 56
('l', 'b'): 4
('l', 'o'): 7
('a', 'r'): 13
('s', ' '): 7
(' ', 'i'): 32
('ſ', ' '): 35
('å', 'O'): 1
(' ', 's'): 24
('a', 'l'): 15
('f', 'o'): 1
('G', 'p'): 1
('r', 'å'): 5
('=', 'V'): 1
('a', 'v'): 9
('n', 'a'): 18
('t', 'n'): 16
('g', 'i'): 3
('a', 'g'): 5
(' ', 'a'): 40
('j', 'p'): 1
(' ', 'å'): 5
('g', ' '): 24
('o', 'd'): 3
('d', 'a'): 7
('g', 'n'): 7
('r', 'j'): 2
('i', 'n'): 3
('n', 'u'): 8
('g', 'a'): 9
('r', 'i'): 17
('ſ', 'j'): 4
('å', 'u'): 1
('ſ', 'n'): 3
('o', 'i'): 6
('m', '.'): 4
('n', 'k'): 9
('ö', 'o'): 7
('d', 'n'): 12
('w', 'j'): 1
('ä', 'u'): 2
('n', 'g'): 7
(' ', ','): 18
('m', ' '): 25
('e', '(')


