In [1]:
import datasets
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AdamW
from torchmetrics.text import cer
from PIL import Image
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = 'cuda:0'
EPOCHS = 10

In [3]:
synthetic_cyrillic_dataset = datasets.load_dataset('nastyboget/synthetic_cyrillic', split='train').train_test_split(test_size=0.1)
synthetic_cyrillic_dataset

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['image', 'text', 'path', 'name'],
        num_rows: 270000
    })
    test: Dataset({
        features: ['image', 'text', 'path', 'name'],
        num_rows: 30000
    })
})

In [4]:
class OCRDataset(Dataset):
    def __init__(self, dataset, processor, max_target_length=128):
        self.dataset = dataset
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        image = self.dataset[idx]['image'].convert("L").convert('RGB')
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        labels = self.processor.tokenizer(text, 
                                          padding="max_length", 
                                          max_length=self.max_target_length).input_ids
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

In [5]:
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

model.to(DEVICE)
print('model loaded')

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-stage1 and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model loaded


In [6]:
train_dataset = OCRDataset(dataset=synthetic_cyrillic_dataset['train'], processor=processor)
test_dataset = OCRDataset(dataset=synthetic_cyrillic_dataset['test'], processor=processor)

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16)

In [8]:
cer_metric = cer.CharErrorRate()
def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True,)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    cer = cer_metric(pred_str, label_str).item()
    return cer

In [9]:
best_cer = float('inf')

In [10]:
optimizer = AdamW(model.parameters(), lr=5e-5)

for epoch in range(EPOCHS):
  model.train()
  train_loss = 0.0
  train_loader = tqdm(train_dataloader)
  for i, batch in enumerate(train_loader):
    for k,v in batch.items():
      batch[k] = v.to(DEVICE)

    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    train_loss += loss.item()
    train_loader.set_postfix({'loss': (train_loss/(i+1))})
    train_loader.update()

  print(f"Epoch: {epoch}, Train Loss: {train_loss/len(train_dataloader)}")
    
  model.eval()
  test_cer = 0.0
  with torch.no_grad():
    test_loader = tqdm(test_dataloader)
    for i, batch in enumerate(test_loader):
      outputs = model.generate(batch["pixel_values"].to(DEVICE))
      cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])
      test_cer += cer 
      test_loader.set_postfix({'cer': (test_cer/(i+1))})
      test_loader.update()
  test_cer = test_cer / len(test_dataloader)
  print("Validation CER:", test_cer)
  if test_cer < best_cer:
    best_cer = test_cer 
    model.save_pretrained('./tr_ocr/')
    processor.save_pretrained('./tr_ocr/')

  0%|          | 0/16875 [00:00<?, ?it/s]

100%|██████████| 16875/16875 [5:28:59<00:00,  1.17s/it, loss=0.0848]  


Epoch: 0, Train Loss: 0.08479003954233119


100%|██████████| 1875/1875 [3:47:43<00:00,  7.29s/it, cer=0.0852]  
Non-default generation parameters: {'max_length': 64, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Validation CER: 0.0852381637925903


 13%|█▎        | 2112/16875 [40:03<4:39:58,  1.14s/it, loss=0.088] 


KeyboardInterrupt: 