# TrOCR 토크나이저 decoding 확인하기

In [None]:
import pandas as pd

# Load the annotation table for the cropped images; the bare name below renders it as the cell output.
df = pd.read_csv(filepath_or_buffer='halfdata/cropped_image_half.csv')
df

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for evaluation. A fixed random_state makes the split
# (and every file later derived from train_df) reproducible across runs.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# Renumber rows 0..n-1 so index labels coincide with positions
# (later cells look rows up with train_df['text'][i]).
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image

class IAMDataset(Dataset):
    """Map-style dataset pairing cropped text images with TrOCR label ids.

    Args:
        root_dir: Prefix prepended (by plain string concatenation) to every
            ``file_name``, so it should end with a path separator.
        df: DataFrame with ``file_name`` and ``text`` columns. Rows are read by
            position, so its index does not have to be reset.
        processor: TrOCR-style processor. Calling it on an image must return an
            object with ``pixel_values``, and it must expose ``tokenizer``.
        max_target_length: Fixed length of every ``labels`` tensor. Longer texts
            are truncated so all samples can be stacked into one batch.
            NOTE(review): 32 may need to be lowered -- compare it against
            max(df['text length']).
    """

    def __init__(self, root_dir, df, processor, max_target_length=32):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": ..., "labels": ...}`` for row ``idx``.

        ``pixel_values`` is the processor output with its leading batch
        dimension removed. ``labels`` is a tensor of ``max_target_length`` token
        ids, with padding positions set to -100.
        """
        # Positional access: the sampler asks for 0..len-1 regardless of df's index labels.
        file_name = self.df['file_name'].iloc[idx]
        text = self.df['text'].iloc[idx]

        # Prepare the image (resize + normalize). The context manager closes the
        # underlying file; convert() returns a new, fully loaded image.
        with Image.open(self.root_dir + file_name) as img:
            image = img.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # Encode the text as fixed-length labels. truncation=True keeps texts longer
        # than max_target_length from producing longer tensors, which would break
        # batch collation.
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length,
                                          truncation=True).input_ids
        # Important: make sure PAD tokens are ignored by the loss function.
        labels = self._mask_pad_tokens(labels, self.processor.tokenizer.pad_token_id)

        # squeeze(0) drops only the batch dimension added by return_tensors="pt".
        return {"pixel_values": pixel_values.squeeze(0), "labels": torch.tensor(labels)}

    @staticmethod
    def _mask_pad_tokens(labels, pad_token_id):
        """Return a copy of ``labels`` with every ``pad_token_id`` replaced by -100."""
        return [-100 if label == pad_token_id else label for label in labels]

In [None]:
from transformers import TrOCRProcessor, AutoTokenizer

# Pretrained TrOCR processor: image preprocessing plus the tokenizer used for labels.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-printed')
# Wrap the training split first and then the evaluation split, sharing one image root and processor.
train_dataset, eval_dataset = (
    IAMDataset(root_dir='halfdata/cropped_image_half/', df=split, processor=processor)
    for split in (train_df, test_df)
)

In [None]:
# Report how many samples landed in each split.
print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of validation examples: {len(eval_dataset)}")

In [None]:
import csv

# Decode every training label sequence back to text and store it beside the
# original string, to check how the tokenizer round-trips the dataset's text.
tokenizer = processor.tokenizer
# Open the file once. 'w' mode makes a re-run replace the file instead of appending
# a second header and duplicate rows. newline='' is what the csv module expects,
# and the explicit encoding keeps non-ASCII (e.g. Korean) text platform-independent.
with open('trocrtokenizer2.csv', 'w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=['original_text', 'decoded_text'])
    writer.writeheader()
    # Cover exactly the dataset's rows instead of a hard-coded count.
    for i in range(len(train_dataset)):
        # NOTE: indexing the dataset also loads and preprocesses the image;
        # only the labels are used here.
        labels = train_dataset[i]['labels']
        # Undo the -100 loss mask so decode() sees ordinary pad ids,
        # which skip_special_tokens then drops.
        labels[labels == -100] = tokenizer.pad_token_id
        label_str = tokenizer.decode(labels, skip_special_tokens=True)
        writer.writerow({'original_text': train_df['text'][i], 'decoded_text': label_str})