# Install modules

In [None]:
%pip install -f https://download.pytorch.org/whl/torch_stable.html
%pip install -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu110/torch1.7/index.html

%pip install timm==0.4.5
%pip install natsort
%pip install tensorboard
%pip install nltk
%pip install h5py
%pip install pybind11
%pip install fastwer
%pip install git+https://github.com/pytorch/fairseq.git

%pip install Pillow

# Download Data

In [None]:
!wget "https://layoutlm.blob.core.windows.net/trocr/dataset/IAM.tar.gz?sv=2022-11-02&ss=b&srt=o&sp=r&se=2033-06-08T16:48:15Z&st=2023-06-08T08:48:15Z&spr=https&sig=a9VXrihTzbWyVfaIDlIT1Z0FoR1073VB0RLQUMuudD4%3D" -O IAM_handwriting_dataset.tar.gz
!gunzip IAM_handwriting_dataset.tar.gz
!tar -xvf IAM_handwriting_dataset.tar
!rm IAM_handwriting_dataset.tar

# Import modules

In [None]:
import torch, os
from torch.utils.data import DataLoader, Dataset
from transformers import VisionEncoderDecoderModel, TrOCRProcessor

from tqdm.auto import tqdm
from PIL import Image

# Preprocess sentences

In [None]:
# Keep only the text after the first tab of each gt_test.txt line (the first
# field is presumably the image file name — confirm) and write one sentence per line.
# The input is opened first so a missing gt file does not leave behind an
# empty, truncated processed_str.txt.
with open('./IAM/gt_test.txt', mode='r', encoding='utf-8') as rf:
    with open('./IAM/processed_str.txt', mode='w', encoding='utf-8') as wf:
        for line in rf:
            line = line.rstrip('\r\n')
            if not line:
                continue  # a blank line has no tab field and would raise IndexError
            # maxsplit=1 keeps a tab inside the sentence instead of truncating at it
            # (the old split('\t')[1] also lost the newline in that case, merging lines).
            wf.write(line.split('\t', 1)[1] + '\n')

# Dataset

In [None]:
import torchvision.transforms as transforms

class IAM_dataset(Dataset):
    """Image/transcription pairs for IAM line-level evaluation.

    Args:
        img_path: directory holding the line images.
        text_path: text file with one entry per line, in one of two forms:
            * ``<image file name>\\t<sentence>`` on every line: each sentence is
              paired with the named image (only the listed images are used);
            * ``<sentence>`` alone: sentences are paired, in file order, with the
              entries of ``img_path`` sorted by name.

    ``__getitem__`` returns ``(image_tensor, sentence)``; the image is converted
    to RGB and passed through ``torchvision.transforms.ToTensor``.

    Raises:
        ValueError: in sentence-only mode, when the number of images and
            sentences differ and they cannot be paired one-to-one.
    """

    def __init__(self, img_path, text_path) -> None:
        super().__init__()
        self.img_path = img_path
        self.text_path = text_path

        with open(self.text_path, mode='r') as f:
            lines = [line.replace('\n', '') for line in f]

        if lines and all('\t' in line for line in lines):
            # Every line names its image, so pair by name rather than by position.
            self.imgs = []
            self.sentence = []
            for line in lines:
                name, text = line.split('\t', 1)
                self.imgs.append(os.path.join(img_path, name))
                self.sentence.append(text)
        else:
            # os.listdir returns entries in arbitrary order; sort so the
            # positional pairing is at least deterministic.
            # NOTE(review): pairing is only correct if text_path lists sentences in
            # sorted-file-name order — prefer the tab-separated form.
            self.imgs = [os.path.join(img_path, x) for x in sorted(os.listdir(img_path))]
            self.sentence = lines
            if len(self.imgs) != len(self.sentence):
                # Fail here instead of with an IndexError midway through evaluation.
                raise ValueError(
                    f"{len(self.imgs)} images in {img_path!r} but "
                    f"{len(self.sentence)} sentences in {text_path!r}; "
                    "cannot pair them by position"
                )

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that outlives it.
        with Image.open(self.imgs[index]) as im:
            img = im.convert('RGB')
        img = transforms.ToTensor()(img)
        return img, self.sentence[index]

# Scoring

In [None]:
from fairseq.scoring import BaseScorer, register_scorer
from nltk.metrics.distance import edit_distance
from fairseq.dataclass import FairseqDataclass
import fastwer
from Levenshtein import distance
import string

from dataclasses import dataclass
from tqdm.auto import tqdm

@dataclass
class CERScorerConfig(FairseqDataclass):
    """Config for ``CERScorer``; the ``name`` field is not read in this notebook."""
    name: str = 'default'


class CERScorer(BaseScorer):
    """Character error rate scorer backed by ``fastwer``.

    Collect (reference, prediction) pairs with ``add_string``; ``score`` returns
    the corpus-level value of ``fastwer.score(preds, refs, char_level=True)``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.refs = []
        self.preds = []

    def add_string(self, ref, pred):
        """Record one pair; both arguments should be plain ``str`` (fastwer expects strings)."""
        self.refs.append(ref)
        self.preds.append(pred)

    def score(self):
        """Return the character-level error rate over all recorded pairs, as computed by fastwer."""
        return fastwer.score(self.preds, self.refs, char_level=True)

    def result_string(self) -> str:
        """Format the score for display with two decimals."""
        return f"CER: {self.score():.2f}"


# Register under the config class defined above (the original passed the generic
# FairseqDataclass). fairseq's registry raises ValueError on a duplicate name, so
# guard the call to keep this cell safe to re-run in the same kernel.
from fairseq.scoring import SCORER_REGISTRY

if "cer" not in SCORER_REGISTRY:
    CERScorer = register_scorer("cer", dataclass=CERScorerConfig)(CERScorer)

# Evaluate

In [None]:
# Run microsoft/trocr-base-handwritten over the IAM test sentences and report CER.
# batch_size=1 because the line images have different sizes and cannot be stacked.
data = DataLoader(IAM_dataset(img_path='./IAM/image', text_path='./IAM/processed_str.txt'), batch_size=1, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)
scorer = CERScorer(cfg=CERScorerConfig())

with torch.no_grad():
    for imgs, sentence in tqdm(data):
        # IAM_dataset applies ToTensor, so pixel values are already floats in [0, 1];
        # do_rescale=False stops the image processor from dividing by 255 a second time.
        pixels = processor(images=imgs, return_tensors='pt', do_rescale=False).pixel_values
        # NOTE(review): generate() uses the model's default max length, which may
        # truncate long lines — consider passing max_new_tokens explicitly.
        ids = model.generate(pixels.to(device))
        pred = processor.batch_decode(ids, skip_special_tokens=True)[0]

        # The DataLoader collates the str field into a length-1 batch; unwrap it so
        # the scorer receives a plain string, like the prediction.
        scorer.add_string(pred=pred, ref=sentence[0])

score = scorer.result_string()
print(score)