In [14]:
import json
import os
import random

import fastwer
from PIL import Image
from tqdm.auto import tqdm
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

In [6]:
# Run configuration. Each value can be overridden through an environment
# variable so the notebook is not tied to one machine's directory layout;
# the defaults reproduce the original hard-coded setup.
CKP_MODEL = os.environ.get(
    "TROCR_CKP_MODEL",
    "/home/ralvarez22/Documentos/trocr_hand/trocr_llm/finetuned/Akivili/V_5",
)  # Checkpoint directory handed to from_pretrained
DEVICE = os.environ.get("TROCR_DEVICE", "cuda")  # Where to load the model
DATASET_FILE = os.environ.get(
    "TROCR_DATASET_FILE",
    "/home/ralvarez22/Documentos/trocr_hand/trocr_llm/datasets/cursive_hand_cropped/metadata.json",
)  # JSON metadata file describing the images to evaluate

In [7]:
# Load the processor and the fine-tuned model from the same checkpoint.
# Only the model is placed on DEVICE: the processor is loaded without a
# device_map, and the loop moves the pixel values to DEVICE itself.
processor = TrOCRProcessor.from_pretrained(CKP_MODEL)
model = VisionEncoderDecoderModel.from_pretrained(CKP_MODEL, device_map=DEVICE)

# Set the decoder start (BOS) token used at inference, plus the temperature and
# sampling strategy for the decoder. These could be stored in the
# generation_config.json file written when the model is saved, but they are
# overwritten here to avoid mismatches or errors.
# NOTE(review): with do_sample=True generation is stochastic, so the metrics
# computed from it can differ between runs -- confirm this is intended for
# evaluation.
model.generation_config.decoder_start_token_id = processor.tokenizer.bos_token_id
model.generation_config.temperature = 0.4
model.generation_config.do_sample = True

2024-07-03 14:11:17.512304: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
# Parse the dataset metadata. The context manager closes the file handle as
# soon as parsing finishes instead of leaving it to the garbage collector, and
# the explicit encoding keeps the read independent of the system locale.
with open(DATASET_FILE, "r", encoding="utf-8") as metadata_file:
    cropped_images_dataset = json.load(metadata_file)

In [9]:
# Evaluate on 90% of the loaded entries, truncated to a whole number.
dataset_size = len(cropped_images_dataset)
items_to_eval = int(dataset_size * 0.9)

In [11]:
# Display how many entries will be evaluated.
items_to_eval

3472

In [12]:
# Draw the evaluation subset (without replacement) from a dedicated, seeded
# generator so the same entries are selected on every run; an unseeded
# random.sample picks a different subset each time the cell is executed.
SAMPLE_SEED = 42
test_dataset = random.Random(SAMPLE_SEED).sample(cropped_images_dataset, items_to_eval)

In [20]:
# Running totals of the per-item scores; a later cell divides them by the
# number of evaluated items.
cer_value = 0
wer_value = 0
for sample in tqdm(test_dataset):
    # Open the image as RGB and convert it to pixel values on DEVICE.
    rgb_image = Image.open(sample["image"]).convert("RGB")
    pixel_values = processor(rgb_image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(DEVICE)

    # Generate token ids and decode the first returned sequence into text.
    generated_ids = model.generate(pixel_values)
    hypothesis = processor.tokenizer.decode(generated_ids[0].cpu(), skip_special_tokens=True)
    reference = sample["label"]

    # Score the generated text against the label: character level first,
    # then the default (word) level.
    cer_value += fastwer.score_sent(hypothesis, reference, char_level=True)
    wer_value += fastwer.score_sent(hypothesis, reference)

  0%|          | 0/3472 [00:00<?, ?it/s]

In [23]:
# Turn the accumulated totals into per-item averages.
# NOTE: the totals are overwritten in place, so run this cell only once per
# evaluation loop; executing it again divides the averages a second time.
cer_value = cer_value / items_to_eval
wer_value = wer_value / items_to_eval

In [24]:
# Report both averaged metrics with four decimal places.
print(f"CER Avg value: {cer_value:.4f}")
print(f"WER Avg value: {wer_value:.4f}")

CER Avg value: 1.4539
WER Avg value: 1.1593
