In [1]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, ViTImageProcessor, XLMRobertaTokenizerFast

# processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

# model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# Checkpoints for the two halves of the encoder-decoder model, named once so the
# processor/tokenizer always match the weights they are paired with.
ENCODER_CHECKPOINT = 'google/vit-base-patch16-224'
DECODER_CHECKPOINT = 'xlm-roberta-base'

# Image preprocessing comes from the encoder checkpoint, text tokenization from
# the decoder checkpoint.
processor = ViTImageProcessor.from_pretrained(ENCODER_CHECKPOINT)
tokenizer = XLMRobertaTokenizerFast.from_pretrained(DECODER_CHECKPOINT)

# Combine the pretrained ViT encoder with an XLM-RoBERTa decoder.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    ENCODER_CHECKPOINT,
    DECODER_CHECKPOINT,
)


Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of XLMRobertaForCausalLM were not initialized from the model checkpoint at xlm-roberta-base

In [2]:
# Count parameters once, then report how many of them will receive gradients.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable_params}")
print(f"Total: {total_params}")

Trainable: 393051282
Total: 393051282


In [None]:
# Report the number of parameters that currently require gradients.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable: {trainable_params}")

In [2]:
from torch.utils.data import Dataset, random_split, DataLoader
import os
from PIL import Image
import torch

from lightning import LightningDataModule
from typing import Any, Dict, Optional, Tuple



### Custom Dataset

In [3]:
# from torch.utils.data import Dataset, random_split, DataLoader
# import os
# from PIL import Image
# import torch

class ORCDataset(Dataset):
    """Image/label pairs for OCR fine-tuning, read from an annotation file.

    Every non-blank line of ``map_file`` has the form
    ``<image path relative to root_dir><whitespace><label>``. Only the first
    run of whitespace separates the two fields, so a label may itself contain
    spaces.

    Args:
        root_dir: Directory containing ``map_file`` and the image files.
        map_file: Name of the annotation file inside ``root_dir`` (UTF-8).
        processor: Callable applied to a PIL image as
            ``processor(image, return_tensors="pt")``; its ``pixel_values``
            attribute is used as the model input.
        tokenizer: Callable applied to the label text; must expose
            ``pad_token_id`` and return an object with ``input_ids``.
        max_target_length: Fixed length every label sequence is padded or
            truncated to.

    Raises:
        ValueError: If a non-blank annotation line has no label field.
    """

    def __init__(self, root_dir: str, map_file: str, processor, tokenizer, max_target_length: int = 128):
        self.root_dir = root_dir
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length

        self.paths = []
        self.labels = []
        with open(os.path.join(root_dir, map_file), encoding='utf8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue
                # Split only once so labels containing spaces stay intact.
                fields = line.split(maxsplit=1)
                if len(fields) != 2:
                    raise ValueError(
                        f"{map_file}:{line_no}: expected '<path> <label>', got {line!r}"
                    )
                path, label = fields
                self.paths.append(path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``{'pixel_values': tensor, 'labels': tensor}`` for sample ``idx``."""
        # Use a context manager so the underlying file handle is released
        # promptly; convert() returns an independent, fully loaded image.
        with Image.open(os.path.join(self.root_dir, self.paths[idx])) as img:
            image = img.convert('RGB')
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # truncation=True guarantees a fixed-length sequence; without it an
        # over-long label produces a longer tensor and batching fails.
        labels = self.tokenizer(
            self.labels[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_target_length,
        ).input_ids
        # Replace padding with -100, the value this notebook uses to mark
        # label positions that should not contribute to the loss.
        pad_id = self.tokenizer.pad_token_id
        labels = [token if token != pad_id else -100 for token in labels]

        return {
            # Drop only the leading batch dimension added by return_tensors="pt".
            'pixel_values': pixel_values.squeeze(0),
            'labels': torch.tensor(labels),
        }

### Custom Lightning Data Module

In [29]:
# import torch
# from lightning import LightningDataModule
# from typing import Any, Dict, Optional, Tuple

class ORCDataModule(LightningDataModule):
    """Lightning data module wrapping ``ORCDataset``.

    Args:
        data_dir: Directory holding the annotation file and the images.
        train_val_test_split: Lengths handed to ``random_split`` when
            ``setup(..., split=True)`` is used; ``None`` disables splitting.
        batch_size: Batch size for every data loader.
        num_workers: Worker processes for every data loader.
        pin_memory: Forwarded to every data loader.
    """

    def __init__(self, data_dir: str, train_val_test_split: Optional[Tuple[int, int, int]] = None, batch_size: int = 32, num_workers: int = 0, pin_memory: bool = False) -> None:
        super().__init__()
        # Makes the constructor arguments available as self.hparams.<name>.
        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    def num_classes(self) -> int:
        """Placeholder; not implemented (currently returns ``None``)."""
        pass

    def prepare_data(self) -> None:
        """No download or preprocessing step is performed."""
        pass

    def setup(self, processor=None, map_file: str = 'train_annotation.txt', max_target_length: int = 128, split: bool = False, stage: Optional[str] = None, tokenizer=None) -> None:
        """Build the dataset(s) once; later calls are no-ops.

        Args:
            processor: Image processor passed to ``ORCDataset``. Required on
                the first call. It is optional in the signature so that a
                later ``setup(stage=...)`` call is harmless once the data
                has been prepared.
            map_file: Annotation file name inside ``data_dir``.
            max_target_length: Fixed label length passed to ``ORCDataset``.
            split: When true and ``train_val_test_split`` was given, split the
                dataset into train/val/test with a fixed seed.
            stage: Accepted for hook compatibility; unused.
            tokenizer: Text tokenizer passed to ``ORCDataset``. When omitted,
                ``processor.tokenizer`` is used if the processor has one.

        Raises:
            ValueError: On the first call, if no processor or no tokenizer
                can be determined.
        """
        if self.data_train is not None:
            return

        if processor is None:
            raise ValueError("setup() needs an image `processor` the first time it is called")
        if tokenizer is None:
            # Combined processors may bundle a tokenizer under `.tokenizer`.
            tokenizer = getattr(processor, 'tokenizer', None)
        if tokenizer is None:
            raise ValueError("setup() needs a `tokenizer` (none given and processor has no `.tokenizer`)")

        # Pass the tokenizer explicitly; max_target_length is the fifth
        # positional parameter of ORCDataset, not the fourth.
        dataset = ORCDataset(self.hparams.data_dir, map_file, processor, tokenizer, max_target_length)

        lengths = self.hparams.train_val_test_split
        # Splitting is only possible when split lengths were configured.
        if split and lengths and self.data_val is None and self.data_test is None:
            self.data_train, self.data_val, self.data_test = random_split(
                dataset=dataset,
                lengths=lengths,
                generator=torch.Generator().manual_seed(42),  # reproducible split
            )
        else:
            self.data_train = dataset

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader[Any]:
        """Create a DataLoader for ``dataset`` using the stored hyperparameters."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def train_loader(self, shuffle: bool = True) -> DataLoader[Any]:
        """Return a loader over the training set."""
        return self._make_loader(self.data_train, shuffle)

    def val_loader(self, shuffle: bool = False) -> DataLoader[Any]:
        """Return a loader over the validation set (only set when split was used)."""
        return self._make_loader(self.data_val, shuffle)

    def test_loader(self, shuffle: bool = False) -> DataLoader[Any]:
        """Return a loader over the test set (only set when split was used)."""
        return self._make_loader(self.data_test, shuffle)

    # Lightning looks these hooks up by name; delegate to the loaders above so
    # the module can be handed to a Trainer without changing existing callers.
    def train_dataloader(self) -> DataLoader[Any]:
        """Lightning hook: training loader with default shuffling."""
        return self.train_loader()

    def val_dataloader(self) -> DataLoader[Any]:
        """Lightning hook: validation loader."""
        return self.val_loader()

    def test_dataloader(self) -> DataLoader[Any]:
        """Lightning hook: test loader."""
        return self.test_loader()

    def teardown(self, stage: Optional[str] = None) -> None:
        """Nothing to clean up."""
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """No state is checkpointed for this data module."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Nothing to restore; ``state_dict`` is ignored."""
        pass

### Load Data

In [2]:
# Both annotation files live in the same directory; name it once.
data_root = './../data/'

train_dataset = ORCDataset(data_root, 'train_annotation.txt', processor, tokenizer)
valid_dataset = ORCDataset(data_root, 'valid_annotation.txt', processor, tokenizer)

NameError: name 'ORCDataset' is not defined

### Test

In [8]:
# Sanity check: tokenize a short string, padded to a fixed length of 32, and
# display the resulting input ids and attention mask.
test_token = tokenizer('alo', return_tensors= 'pt', padding= 'max_length', max_length= 32)
test_token

{'input_ids': tensor([[  0,  10, 365,   2,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1]]), 'attention_mask': tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])}

In [9]:
# Decode the ids back to text (special tokens included) to inspect the padding layout.
tokenizer.decode(test_token['input_ids'][0])

'<s> alo</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

In [13]:
# Smoke test: run generation on the first training image, with a batch
# dimension of 1 added in front.
sample_pixels = train_dataset[0]['pixel_values'].unsqueeze(0)
ids = model.generate(sample_pixels)



In [14]:
# Display the raw generated token ids.
ids

tensor([[    0, 24658, 24658, 24658, 24658, 24658, 24658, 24658, 24658, 24658,
         24658, 24658, 24658, 24658, 24658, 24658, 24658, 24658, 24658, 24658]])

In [15]:
# Decode the generated ids to text, dropping special tokens, and show the first result.
tokenizer.batch_decode(ids, skip_special_tokens= True)[0]

'underunderunderunderunderunderunderunderunderunderunderunderunderunderunderunderunderunderunder'

In [111]:
# Inspect one raw label string as read from the annotation file.
train_dataset.labels[4]

'Quấy'

In [19]:
# Fetch one encoded label and turn the -100 "ignore" markers back into the
# tokenizer's pad id so the sequence can be decoded.
label = train_dataset[2]['labels']
label = label.masked_fill(label == -100, tokenizer.pad_token_id)

In [21]:
# Decode the restored label ids (special tokens included) to check the round trip.
tokenizer.decode(label)

'<s> nhẹn</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

In [68]:
# Display the label token ids after the pad ids were restored.
label

tensor([   0,  282,  298, 1376, 3070, 9253,  282,    2,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1])

### Train

In [5]:
# Set the special tokens used for creating the decoder_input_ids from the
# labels and for generation. The tokenizer wraps every label as
# <s> ... </s>, so <s> (cls) starts decoding and </s> (sep) must be registered
# as the end-of-sequence token; otherwise generation cannot stop early and
# keeps emitting tokens up to the length limit.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.sep_token_id

# Mirror the same ids into the generation config when the installed
# transformers version provides one, so generate() sees them directly.
generation_config = getattr(model, "generation_config", None)
if generation_config is not None:
    generation_config.decoder_start_token_id = tokenizer.cls_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    generation_config.eos_token_id = tokenizer.sep_token_id

# model.config.vocab_size = model.config.decoder.vocab_size


In [6]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from evaluate import load

# Load the "cer" metric through evaluate.load; compute_metrics below calls
# cer_metric.compute(predictions=..., references=...).
cer_metric = load("cer")

def compute_metrics(pred):
    """Compute the character error rate of generated text against the labels.

    Args:
        pred: Evaluation output exposing ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids with -100 marking padding).

    Returns:
        dict: ``{"cer": <value returned by cer_metric.compute>}``.
    """
    pad_id = tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Defensive: presumably the first element holds the token ids when a
        # tuple is returned; verify against the trainer configuration.
        pred_ids = pred_ids[0]

    # Work on copies so the arrays owned by the trainer are not modified.
    pred_ids = pred_ids.copy()
    labels_ids = pred.label_ids.copy()

    # -100 is not a valid token id; map it back to padding before decoding.
    pred_ids[pred_ids == -100] = pad_id
    labels_ids[labels_ids == -100] = pad_id

    # Decode with the text tokenizer: the image processor used in this
    # notebook has no decoding methods and no `.tokenizer` attribute.
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

# Training configuration: evaluate once per epoch using generate() so that
# compute_metrics receives generated token ids.
# NOTE(review): newer transformers releases may spell `evaluation_strategy`
# as `eval_strategy` - confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Mixed precision needs a CUDA device; fall back to full precision on
    # CPU-only machines instead of failing when the arguments are built.
    fp16=torch.cuda.is_available(),
    output_dir="./",
    report_to='none'
)

In [7]:
from transformers import default_data_collator

# Collect everything the trainer needs in one place, then build it and fine-tune.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer = Seq2SeqTrainer(**trainer_kwargs)
trainer.train()



Epoch,Training Loss,Validation Loss
