In [1]:
!nvidia-smi

Fri Mar 29 19:12:35 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        Off | 00000000:2D:00.0 Off |                  N/A |
|  0%   37C    P3              24W / 320W |     89MiB / 16376MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, TrainingArguments, Trainer
from torch.utils.data.dataset import random_split
class SpanishDocumentsDataset(Dataset):
    def __init__(self, image_dir, text_dir, processor):
        self.image_dir = image_dir
        self.text_dir = text_dir
        self.processor = processor
        self.filenames = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
    
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        image_file = self.filenames[idx]
        # Derive the corresponding text file name by changing the extension
        text_file = image_file.replace('.jpg', '.txt')
        
        image_path = os.path.join(self.image_dir, image_file)
        text_path = os.path.join(self.text_dir, text_file)
        
        image = Image.open(image_path).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze()

        with open(text_path, 'r', encoding='utf-8') as file:
            text = file.read().strip()
        labels = self.processor.tokenizer(text, return_tensors="pt").input_ids.squeeze()
        
        # Treat padding specially for label calculation, if necessary
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        
        return {"pixel_values": pixel_values, "labels": labels}



image_dir = 'line_data/line_images'
text_dir = 'line_data/line_texts'

def collate_fn(batch):
    pixel_values = [item['pixel_values'] for item in batch]
    labels = [item['labels'] for item in batch]
    
    # Padding value for labels should be -100 to ignore tokens during loss calculation
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    pixel_values = torch.stack(pixel_values)
    
    return {"pixel_values": pixel_values, "labels": labels}



# Initialize processor and model with correct configurations
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.eos_token_id
model.config.ignore_pad_token_for_loss = True  # Ensure pad tokens are ignored during loss calculation

# Prepare dataset and data loader
dataset = SpanishDocumentsDataset(image_dir, text_dir, processor=processor)

dataset_size = len(dataset)
eval_size = int(dataset_size * 0.1)  # 10% for evaluation
train_size = dataset_size - eval_size  # Remaining for training

# Split the dataset
train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
eval_loader = DataLoader(eval_dataset, batch_size=8, collate_fn=collate_fn)



Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# # To see how the data is loaded, let's fetch a batch from the DataLoader
for batch in train_loader:
    print(batch['pixel_values'].shape, batch['labels'].shape)
    break  # Just show the first batch for demonstration

torch.Size([8, 3, 384, 384]) torch.Size([8, 16])


In [10]:
import Levenshtein as lev

def compute_cer(decoded_preds, decoded_labels):
    """
    Compute the Character Error Rate (CER).
    CER is defined as the edit distance between the predicted and true sequences
    divided by the length of the true sequence.
    """
    total_edit_distance = 0
    total_length = 0
    
    for pred, label in zip(decoded_preds, decoded_labels):
        total_edit_distance += lev.distance(pred, label)
        total_length += len(label)
    
    cer = total_edit_distance / total_length if total_length > 0 else 0
    return cer

def compute_wer(decoded_preds, decoded_labels):
    """
    Compute the Word Error Rate (WER).
    WER is defined as the edit distance between the predicted and true sequences
    of words divided by the number of words in the true sequence.
    """
    total_edit_distance = 0
    total_words = 0
    
    for pred, label in zip(decoded_preds, decoded_labels):
        pred_words = pred.split()
        label_words = label.split()
        
        total_edit_distance += lev.distance(pred_words, label_words)
        total_words += len(label_words)
    
    wer = total_edit_distance / total_words if total_words > 0 else 0
    return wer


In [11]:
from datasets import load_metric
from transformers import get_linear_schedule_with_warmup
from transformers import AdamW

cer_metric = load_metric("cer")
wer_metric = load_metric("wer")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Extract the logits from the tuple if necessary
    if isinstance(logits, tuple):
        logits = logits[0]

    # Convert logits to the most likely token IDs
    predictions = logits.argmax(-1)
    
    # Decode predictions to text
    decoded_preds = processor.batch_decode(predictions, skip_special_tokens=True)
    
    # Prepare labels for decoding
    decoded_labels = []
    for label in labels:
        # Filter out -100 values which are used for padding/ignored indices
        label_filtered = [token for token in label if token != -100]
        decoded_label = processor.decode(label_filtered, skip_special_tokens=True)
        decoded_labels.append(decoded_label)

    cer = compute_cer(decoded_preds, decoded_labels)  
    wer = compute_wer(decoded_preds, decoded_labels) 

    return {"cer": cer, "wer": wer}


total_train_steps = (len(train_dataset) // 8) * 20

optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=500, 
                                            num_training_steps=total_train_steps)


# TrainingArguments
training_args = TrainingArguments(
    output_dir="./trocr_finetuned",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    num_train_epochs=20,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=50,
    save_steps=50,
    warmup_steps=500,
    weight_decay=0.01,
    save_total_limit=2,
    load_best_model_at_end=True,
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  
    eval_dataset=eval_dataset,  
    optimizers=(optimizer, scheduler),
    data_collator=collate_fn,
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics, 
)

# Train the model
trainer.train()

# Save the fine-tuned model and processor
trainer.save_model("./trocr_finetuned")
processor.save_pretrained("./trocr_finetuned")


  cer_metric = load_metric("cer")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss,Validation Loss,Cer,Wer
50,1.9763,1.183152,0.210676,0.406832
100,0.5319,0.441966,0.080636,0.217391
150,0.3324,0.414173,0.120954,0.23913
200,0.3413,0.365381,0.124929,0.26087
250,0.4334,0.502178,0.089154,0.214286
300,1.2797,2.216505,0.33163,0.649068
350,0.2555,0.56537,0.141965,0.322981
400,0.3079,0.535151,0.097672,0.257764
450,0.5731,0.681593,0.120386,0.31677
500,0.782,0.893538,0.158433,0.357143


There were missing keys in the checkpoint model loaded: ['decoder.output_projection.weight'].


[]

In [5]:
import re
import os
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

model_path = "./trocr_finetuned"
processor_path = "./trocr_finetuned"

# Load the fine-tuned model and processor
processor = TrOCRProcessor.from_pretrained(processor_path)
model = VisionEncoderDecoderModel.from_pretrained(model_path)

# Function to generate text for a single image segment
def generate_text_from_image_segment(image_path):
    image = Image.open(image_path).convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

# Base directory where the page folders are located
base_dir = "test_line_segments" 


def sort_key(filename):
    """
    Custom sort function to extract the segment number from the filename
    and use it as the key for sorting.
    """
    match = re.search(r"line_segment_(\d+)\.jpg", filename)
    if match:
        return int(match.group(1))
    return -1  # Return -1 if the pattern doesn't match


# Iterate through each page's folder
for page_folder in sorted(os.listdir(base_dir)):
    page_path = os.path.join(base_dir, page_folder)
    if os.path.isdir(page_path):
        print(f"Processing {page_folder}:")
        page_texts = []

        # Sort the line segment images numerically based on the segment number
        line_segment_images = sorted(os.listdir(page_path), key=sort_key)

        # Iterate through each sorted line segment in the page folder
        for line_segment_image in line_segment_images:
            if line_segment_image.endswith('.jpg'):
                line_segment_path = os.path.join(page_path, line_segment_image)
                line_text = generate_text_from_image_segment(line_segment_path)
                page_texts.append(line_text)
                print(f"  {line_segment_image}: {line_text}")

        # Compile and display the full page's text
        full_page_text = "\n".join(page_texts)
        print(f"\nFull text for {page_folder}:")
        print(full_page_text)
        print("\n" + "="*50 + "\n")


Processing page_26:




  page_26_line_segment_0.jpg: Nobleza Vrvuosa.
  page_26_line_segment_1.jpg: Si por euitar un pegado mortal
  page_26_line_segment_2.jpg: queys de poner vuestra vida en pe-
  page_26_line_segment_3.jpg: ligro, atriesgalda, que es el mejor
  page_26_line_segment_4.jpg: empleo que della podeys hazer, y
  page_26_line_segment_5.jpg: de vuestra hazienda para este sin en
  page_26_line_segment_6.jpg: redemir cautuos, y sacar mugenes
  page_26_line_segment_7.jpg: de pegado, dotandolas liberalmen-
  page_26_line_segment_8.jpg: Caton dixo, nunca hagas el bien
  page_26_line_segment_9.jpg: porque se sepa, dad pues vos sin
  page_26_line_segment_10.jpg: bueno a cualquiera obra, con que,
  page_26_line_segment_11.jpg: huyreys de la hijocresia, pero tam-
  page_26_line_segment_12.jpg: poco escondays las que han de ser
  page_26_line_segment_13.jpg: de buen exemplo, pues es obliga-
  page_26_line_segment_14.jpg: ion de personas tales el darlo, y
  page_26_line_segment_15.jpg: Ontario tentación en a