In [1]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
from tqdm.auto import tqdm
from urllib.request import urlretrieve
from zipfile import ZipFile
from torch.utils.data import Dataset
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import glob
import torch.optim as optim
import evaluate
from tqdm.notebook import tqdm

In [2]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
torch.cuda.empty_cache()

In [4]:
def seed_everything(seed_value):
   np.random.seed(seed_value)
   torch.manual_seed(seed_value)
   torch.cuda.manual_seed_all(seed_value)
   torch.backends.cudnn.deterministic = True
   torch.backends.cudnn.benchmark = False
 
seed_everything(42)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
#device = torch.device('cpu')

In [6]:
def download_and_unzip(url, save_path):
    print(f"Downloading and extracting assets....", end="")
    urlretrieve(url, save_path)

    try:
        with ZipFile(save_path) as z:
            z.extractall(os.path.split(save_path)[0])
        print("Done")
    except Exception as e:
        print("\nInvalid file.", e)
 
URL = r"https://storage.teklia.com/public/rimes2011/RIMES-2011-Lines.zip"
asset_zip_path = os.path.join(os.getcwd(), "Datasets/RIMES-2011-Lines.zip")

if not os.path.exists(asset_zip_path):
    download_and_unzip(URL, asset_zip_path)

In [7]:
torch.cuda.empty_cache()

In [8]:
#processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-stage1')
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
model = VisionEncoderDecoderModel.from_pretrained(
    'microsoft/trocr-small-stage1'
).to(device)

VisionEncoderDecoderModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-stage1 and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3 # 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [10]:
class RIMESDataset(Dataset):
   def __init__(self, root_dir, df, processor, max_target_length=128):
       self.root_dir = root_dir
       self.df = df
       self.processor = processor
       self.max_target_length = max_target_length
       self.batch_size = 4
 
       self.df['text'] = self.df['text'].fillna('')
 
 
   def __len__(self):
       return len(self.df)
 
 
   def __getitem__(self, idx):
       file_name = self.df['file_name'][idx]
       text = self.df['text'][idx]

       image = Image.open(file_name).convert('RGB')
       pixel_values = self.processor(image, return_tensors='pt').pixel_values
       labels = self.processor.tokenizer(
           text,
           padding='max_length',
           max_length=self.max_target_length
       ).input_ids

       labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
       encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
       return encoding

In [11]:
def get_text(transcription_path):
    with open(transcription_path) as f_transcription:
        transcription = f_transcription.readline().strip()
    return transcription

In [12]:
def prepare_RIMES(dataset_path, file):
    set_path = os.path.join(dataset_path, 'Sets')
    images_path = os.path.join(dataset_path, 'Images')
    text_path = os.path.join(dataset_path, 'Transcriptions')
    dataset ={'file_name': [], 'text': []} 

    i = 0
    with open(os.path.join(set_path, file)) as f:
        for val in f:
            dataset['file_name'].append(os.path.join(images_path, val.strip() + '.jpg'))
            text = get_text(os.path.join(text_path, val.strip() + '.txt'))
            dataset['text'].append(text) 
            
    return pd.DataFrame(dataset)

In [13]:
dataset_path=os.path.join('Datasets', 'RIMES-2011-Lines')

In [14]:
train_df = prepare_RIMES(dataset_path, 'TrainLines.txt')
test_df = prepare_RIMES(dataset_path, 'TestLines.txt')

In [15]:
train_dataset = RIMESDataset(root_dir=dataset_path,
                           df=train_df,
                           processor=processor)
test_dataset = RIMESDataset(root_dir=dataset_path,
                           df=test_df,
                           processor=processor)

In [16]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
eval_dataloader = DataLoader(test_dataset, batch_size=4)

In [17]:
cer_metric = evaluate.load('cer')
def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return cer

In [18]:
wer_metric = evaluate.load('wer')
def compute_wer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return wer

In [19]:
accumulation_steps = 4
num_epochs = 10

In [20]:
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
cer_vals = []
wer_vals = []

In [21]:
for epoch in range(num_epochs):  # loop over the dataset multiple times
   # train
    model.train()
    batch_idx = 0
    for batch in tqdm(train_dataloader):
        for k,v in batch.items():
            batch[k] = v.to(device)

            ### FORWARD AND BACK PROP   
        outputs = model(**batch) 

        outputs["loss"] = outputs["loss"] / accumulation_steps
        outputs["loss"].backward()

            ### UPDATE MODEL PARAMETERS
        if not batch_idx % accumulation_steps:
            optimizer.step()
            optimizer.zero_grad()

            ### LOGGING
        if not batch_idx % 300:
            print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} "
                    f"| Batch {batch_idx:04d}/{len(train_dataloader):04d} "
                    f"| Loss: {outputs['loss']:.4f}")

        batch_idx += 1
        
    model.eval()
    valid_cer = 0.0
    valid_wer = 0.0 
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
       # run batch generation
            outputs = model.generate(batch["pixel_values"].to(device))
       # compute metrics
            cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])
            wer = compute_wer(pred_ids=outputs, label_ids=batch["labels"])
         
            valid_cer += cer
            valid_wer += wer
    
    cer_vals.append(valid_cer / len(eval_dataloader))
    wer_vals.append(valid_wer / len(eval_dataloader)) 
    print("Validation CER:", valid_cer / len(eval_dataloader))
    print("Validation WER:", valid_wer / len(eval_dataloader))


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0001/0010 | Batch 0000/2547 | Loss: 2.5380
Epoch: 0001/0010 | Batch 0300/2547 | Loss: 0.4976
Epoch: 0001/0010 | Batch 0600/2547 | Loss: 0.2476
Epoch: 0001/0010 | Batch 0900/2547 | Loss: 0.3102
Epoch: 0001/0010 | Batch 1200/2547 | Loss: 0.2772
Epoch: 0001/0010 | Batch 1500/2547 | Loss: 0.3139
Epoch: 0001/0010 | Batch 1800/2547 | Loss: 0.2017
Epoch: 0001/0010 | Batch 2100/2547 | Loss: 0.1906
Epoch: 0001/0010 | Batch 2400/2547 | Loss: 0.2093


  0%|          | 0/195 [00:00<?, ?it/s]



Validation CER: 0.16317000041528312
Validation WER: 0.3498530799381593


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0002/0010 | Batch 0000/2547 | Loss: 0.1621
Epoch: 0002/0010 | Batch 0300/2547 | Loss: 0.6846
Epoch: 0002/0010 | Batch 0600/2547 | Loss: 0.2419
Epoch: 0002/0010 | Batch 0900/2547 | Loss: 0.1515
Epoch: 0002/0010 | Batch 1200/2547 | Loss: 0.1236
Epoch: 0002/0010 | Batch 1500/2547 | Loss: 0.1711
Epoch: 0002/0010 | Batch 1800/2547 | Loss: 0.1747
Epoch: 0002/0010 | Batch 2100/2547 | Loss: 0.2262
Epoch: 0002/0010 | Batch 2400/2547 | Loss: 0.0903


  0%|          | 0/195 [00:00<?, ?it/s]

Validation CER: 0.195517271957147
Validation WER: 0.3644511242234529


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0003/0010 | Batch 0000/2547 | Loss: 0.0991
Epoch: 0003/0010 | Batch 0300/2547 | Loss: 0.0647
Epoch: 0003/0010 | Batch 0600/2547 | Loss: 0.1203
Epoch: 0003/0010 | Batch 0900/2547 | Loss: 0.0525
Epoch: 0003/0010 | Batch 1200/2547 | Loss: 0.1275
Epoch: 0003/0010 | Batch 1500/2547 | Loss: 0.0980
Epoch: 0003/0010 | Batch 1800/2547 | Loss: 0.1632
Epoch: 0003/0010 | Batch 2100/2547 | Loss: 0.1936
Epoch: 0003/0010 | Batch 2400/2547 | Loss: 0.0746


  0%|          | 0/195 [00:00<?, ?it/s]

Validation CER: 0.12319982827631544
Validation WER: 0.2550179774994266


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0004/0010 | Batch 0000/2547 | Loss: 0.2410
Epoch: 0004/0010 | Batch 0300/2547 | Loss: 0.1787
Epoch: 0004/0010 | Batch 0600/2547 | Loss: 0.1110
Epoch: 0004/0010 | Batch 0900/2547 | Loss: 0.2283
Epoch: 0004/0010 | Batch 1200/2547 | Loss: 0.0670
Epoch: 0004/0010 | Batch 1500/2547 | Loss: 0.1231
Epoch: 0004/0010 | Batch 1800/2547 | Loss: 0.0593
Epoch: 0004/0010 | Batch 2100/2547 | Loss: 0.1999
Epoch: 0004/0010 | Batch 2400/2547 | Loss: 0.1107


  0%|          | 0/195 [00:00<?, ?it/s]

Validation CER: 0.14404253751637805
Validation WER: 0.27346333680977225


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0005/0010 | Batch 0000/2547 | Loss: 0.1128
Epoch: 0005/0010 | Batch 0300/2547 | Loss: 0.0881
Epoch: 0005/0010 | Batch 0600/2547 | Loss: 0.1085
Epoch: 0005/0010 | Batch 0900/2547 | Loss: 0.1401
Epoch: 0005/0010 | Batch 1200/2547 | Loss: 0.1049
Epoch: 0005/0010 | Batch 1500/2547 | Loss: 0.0399
Epoch: 0005/0010 | Batch 1800/2547 | Loss: 0.1907
Epoch: 0005/0010 | Batch 2100/2547 | Loss: 0.0548
Epoch: 0005/0010 | Batch 2400/2547 | Loss: 0.0564


  0%|          | 0/195 [00:00<?, ?it/s]

Validation CER: 0.6725324172267129
Validation WER: 0.8832013633924654


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0006/0010 | Batch 0000/2547 | Loss: 0.9332
Epoch: 0006/0010 | Batch 0300/2547 | Loss: 0.1173
Epoch: 0006/0010 | Batch 0600/2547 | Loss: 0.0791
Epoch: 0006/0010 | Batch 0900/2547 | Loss: 0.1022
Epoch: 0006/0010 | Batch 1200/2547 | Loss: 0.0475
Epoch: 0006/0010 | Batch 1500/2547 | Loss: 0.0740
Epoch: 0006/0010 | Batch 1800/2547 | Loss: 0.0641
Epoch: 0006/0010 | Batch 2100/2547 | Loss: 0.0603
Epoch: 0006/0010 | Batch 2400/2547 | Loss: 0.1208


  0%|          | 0/195 [00:00<?, ?it/s]

Validation CER: 0.07287212245441983
Validation WER: 0.18346381166577205


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0007/0010 | Batch 0000/2547 | Loss: 0.0278
Epoch: 0007/0010 | Batch 0300/2547 | Loss: 0.0353
Epoch: 0007/0010 | Batch 0600/2547 | Loss: 0.0176
Epoch: 0007/0010 | Batch 0900/2547 | Loss: 0.0644
Epoch: 0007/0010 | Batch 1200/2547 | Loss: 0.0124
Epoch: 0007/0010 | Batch 1500/2547 | Loss: 0.0336
Epoch: 0007/0010 | Batch 1800/2547 | Loss: 0.0128
Epoch: 0007/0010 | Batch 2100/2547 | Loss: 0.0248
Epoch: 0007/0010 | Batch 2400/2547 | Loss: 0.0272


  0%|          | 0/195 [00:00<?, ?it/s]

Validation CER: 0.07107539426098575
Validation WER: 0.18086406036459762


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0008/0010 | Batch 0000/2547 | Loss: 0.0109
Epoch: 0008/0010 | Batch 0300/2547 | Loss: 0.0139
Epoch: 0008/0010 | Batch 0600/2547 | Loss: 0.0185
Epoch: 0008/0010 | Batch 0900/2547 | Loss: 0.1641
Epoch: 0008/0010 | Batch 1200/2547 | Loss: 0.1202
Epoch: 0008/0010 | Batch 1500/2547 | Loss: 0.0746
Epoch: 0008/0010 | Batch 1800/2547 | Loss: 0.0087
Epoch: 0008/0010 | Batch 2100/2547 | Loss: 0.0666
Epoch: 0008/0010 | Batch 2400/2547 | Loss: 0.0568


  0%|          | 0/195 [00:00<?, ?it/s]

Validation CER: 0.06591462965335806
Validation WER: 0.1695512273816187


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0009/0010 | Batch 0000/2547 | Loss: 0.0285
Epoch: 0009/0010 | Batch 0300/2547 | Loss: 0.0172
Epoch: 0009/0010 | Batch 0600/2547 | Loss: 0.0092
Epoch: 0009/0010 | Batch 0900/2547 | Loss: 0.0591
Epoch: 0009/0010 | Batch 1200/2547 | Loss: 0.0186
Epoch: 0009/0010 | Batch 1500/2547 | Loss: 0.0397
Epoch: 0009/0010 | Batch 1800/2547 | Loss: 0.0654
Epoch: 0009/0010 | Batch 2100/2547 | Loss: 0.0073
Epoch: 0009/0010 | Batch 2400/2547 | Loss: 0.0291


  0%|          | 0/195 [00:00<?, ?it/s]

Validation CER: 0.0691755150620948
Validation WER: 0.17318745022847917


  0%|          | 0/2547 [00:00<?, ?it/s]

Epoch: 0010/0010 | Batch 0000/2547 | Loss: 0.0051
Epoch: 0010/0010 | Batch 0300/2547 | Loss: 0.0162
Epoch: 0010/0010 | Batch 0600/2547 | Loss: 0.0472
Epoch: 0010/0010 | Batch 0900/2547 | Loss: 0.0255
Epoch: 0010/0010 | Batch 1200/2547 | Loss: 0.0771
Epoch: 0010/0010 | Batch 1500/2547 | Loss: 0.1258
Epoch: 0010/0010 | Batch 1800/2547 | Loss: 0.0287
Epoch: 0010/0010 | Batch 2100/2547 | Loss: 0.1226
Epoch: 0010/0010 | Batch 2400/2547 | Loss: 0.0044


  0%|          | 0/195 [00:00<?, ?it/s]

Validation CER: 0.0691366601482624
Validation WER: 0.17338457942174382


In [None]:
from tqdm.notebook import tqdm

optimizer = optim.AdamW(model.parameters(), lr=5e-5)
cer_vals = []
wer_vals = []

for epoch in range(10):  # loop over the dataset multiple times
   # train
   model.train()
   train_loss = 0.0
   for batch in tqdm(train_dataloader):
      # get the inputs
      #print (batch)
      for k,v in batch.items():
        batch[k] = v.to(device)

      # forward + backward + optimize
      outputs = model(**batch)
      loss = outputs.loss
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      train_loss += loss.item()

   print(f"Loss after epoch {epoch}:", train_loss/len(train_dataloader))
    
   # evaluate
   model.eval()
   valid_cer = 0.0
   valid_wer = 0.0 
   with torch.no_grad():
     for batch in tqdm(eval_dataloader):
       # run batch generation
       outputs = model.generate(batch["pixel_values"].to(device))
       # compute metrics
       cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])
       wer = compute_wer(pred_ids=outputs, label_ids=batch["labels"])
         
       valid_cer += cer
       valid_wer += wer
    
   cer_vals.append(valid_cer / len(eval_dataloader))
   wer_vals.append(valid_wer / len(eval_dataloader)) 
   print("Validation CER:", valid_cer / len(eval_dataloader))
   print("Validation WER:", valid_wer / len(eval_dataloader)) 

#model.save_pretrained(".")

  0%|          | 0/637 [00:00<?, ?it/s]

In [22]:
model.push_to_hub("trocr-base-stage1-42-batch-16", private=True)



model.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/skvoretss/trocr-base-stage1-42-batch-16/commit/e5d756444294f15933747fcfc76150c386118cfd', commit_message='Upload model', commit_description='', oid='e5d756444294f15933747fcfc76150c386118cfd', pr_url=None, repo_url=RepoUrl('https://huggingface.co/skvoretss/trocr-base-stage1-42-batch-16', endpoint='https://huggingface.co', repo_type='model', repo_id='skvoretss/trocr-base-stage1-42-batch-16'), pr_revision=None, pr_num=None)

In [23]:
torch.cpu.memory_allocated(device)

AttributeError: module 'torch.cpu' has no attribute 'memory_allocated'

In [24]:
torch.cuda.max_memory_allocated(device)

2789782528

In [25]:
torch.cuda.max_memory_reserved(device)

7608467456

In [26]:
torch.cuda.memory_reserved(device)

3235905536