In [None]:

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import pytorch_lightning as pl
from PIL import Image
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import evaluate
from clearml import Task
import shutil

In [None]:
# Optional: point the ClearML SDK at a local (Docker) ClearML server so the metrics can be logged there.
# NOTE(review): the API access/secret keys below are hardcoded in the notebook. Prefer supplying them
# through the shell environment or clearml.conf, and rotate them if this file has been shared.
%env CLEARML_WEB_HOST=http://localhost:8080
%env CLEARML_API_HOST=http://localhost:8008
%env CLEARML_FILES_HOST=http://localhost:8081
%env CLEARML_API_ACCESS_KEY=AEBY191O3R1U4SGBDPLA
%env CLEARML_API_SECRET_KEY=OVvAzcKHtSfqP95jjMHgmgAvzDcSKIKRt5wv1hE1PerO5D3uiT
%env CLEARML_LOG_MODEL=False

In [3]:
# Local Hugging Face cache folder and the local snapshot of microsoft/trocr-large-stage1.
# NOTE(review): these are absolute, machine-specific paths; adjust them for your environment.
HF_CACHE = "/home/ralvarez22/Documentos/llm_data/llm_cache"
TROCR_MODEL = "/home/ralvarez22/Documentos/llm_data/llm_cache/models--microsoft--trocr-large-stage1/snapshots/3c8ead8dfda428d914334169380bb546f770a300"
DEVICE = "cuda"


# Dataset folder and the metadata files (read with pandas.read_json) for each split.
DATASETS_PATH = "../hand-cursive-trocr"
TRAIN_FILE = "train_metadata.json"
VALID_FILE = "valid_metadata.json"

BATCH_SIZE = 8 # Lower this value in case of CUDA out-of-memory errors
# Passed to pl.Trainer(accumulate_grad_batches=...), i.e. the number of batches accumulated per optimizer step.
# NOTE(review): with BATCH_SIZE = 8 this is 32 accumulated batches (256 samples per step) - confirm this is intended.
ACC_BATCH =  BATCH_SIZE * 4
# Used by the Trainer both as log_every_n_steps and as val_check_interval.
LOGGING_STEPS = 1000

CKP_PATH = "../finetuned/trocr"
# NOTE(review): FINAL_MODEL_PATH is not referenced in the visible cells; the final export path is built from CKP_PATH.
FINAL_MODEL_PATH = "../finetuned/trocr/final"
MODEL_CODENAME = "Terminus" # Codename used to version the model
MODEL_VERSION = 1 

LOG_DIR = "../clearml_logs_trocr"

EPOCHS = 5 # Kept low because this was only a proof-of-concept test. With more epochs the accuracy should (in theory) be better
LR = 1e-5 # Tutorials usually recommend 4e-5 or 5e-5, but with those values I could not get a good model: it stopped learning around epoch 20-25 and the loss curve began to rise instead of going down

In [4]:
# Character Error Rate metric from Hugging Face `evaluate`; used by the validation step.
cer_metric = evaluate.load("cer")

# Make sure the checkpoint and log folders exist before training starts.
for _required_dir in (CKP_PATH, LOG_DIR):
    os.makedirs(_required_dir, exist_ok=True)

## Data loading and preparation

In [5]:
class OCRDataset(Dataset):
    """Map-style dataset yielding ``(RGB image, label)`` pairs described by a DataFrame.

    Each row of ``df`` is read through its ``"image"`` column (a path joined onto
    ``root_dir``) and its ``"label"`` column (returned unchanged).
    """

    def __init__(self, root_dir, df, processor):
        """Keep references to the image root folder, the metadata frame and the processor.

        Args:
            root_dir: Directory that the ``"image"`` entries are joined onto.
            df: DataFrame with one row per sample, exposing ``"image"`` and ``"label"``.
            processor: Stored on the instance; not used by the methods of this class.
        """
        self.root_dir, self.df, self.processor = root_dir, df, processor

    def __len__(self):
        """Return the number of rows in the metadata frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Open the image of row ``idx`` and return it (as RGB) together with its label."""
        row = self.df.iloc[idx]
        image_path = os.path.join(self.root_dir, row["image"])
        image = Image.open(image_path).convert("RGB")
        return image, row["label"]

In [6]:
# Read the train/validation metadata JSON files from the dataset folder into DataFrames.
TRAIN_DF, VALIDATION_DF = (
    pd.read_json(os.path.join(DATASETS_PATH, metadata_file))
    for metadata_file in (TRAIN_FILE, VALID_FILE)
)

In [7]:

# Load the TrOCR processor (called on images and exposing `.tokenizer`) from the local snapshot.
processor = TrOCRProcessor.from_pretrained(
    TROCR_MODEL,
    cache_dir=HF_CACHE,
    device_map=DEVICE,
)

In [8]:
def collate_function(batch):
    """Turn a list of ``(image, text)`` samples into one teacher-forcing batch.

    Args:
        batch: Sequence of ``(image, label string)`` pairs, as produced by
            ``OCRDataset.__getitem__``.

    Returns:
        dict with the keys:
            ``pixel_values``: processor output for the images. The leading batch
                dimension is always kept, also for a single-sample batch.
            ``labels``: the raw label strings (used to compute the CER).
            ``decoder_input``: token ids fed to the decoder, with EOS replaced by PAD.
            ``shifted_labels``: the token ids shifted one position to the left and
                right-padded with one PAD column; these are the loss targets.
            ``shift_mask``: 1 where ``shifted_labels`` holds a real token, 0 where it is PAD.
    """
    tokenizer = processor.tokenizer
    pad_id = tokenizer.pad_token_id

    batch_images = [sample[0] for sample in batch]
    batch_labels = [sample[1] for sample in batch]

    # No squeeze() here: it would drop the batch axis whenever the batch holds a
    # single sample (e.g. the last, incomplete batch) and break the model call.
    pixel_values = processor(batch_images, return_tensors="pt").pixel_values.to(DEVICE)
    labels = tokenizer(
        batch_labels, add_special_tokens=True, padding=True, return_tensors="pt"
    ).input_ids.to(DEVICE)

    # Decoder input: same ids with EOS turned into PAD. torch.where builds a new
    # tensor, so `labels` itself is left untouched.
    input_labels = torch.where(labels == tokenizer.eos_token_id, pad_id, labels)

    # Targets: drop the first token and append one PAD column so the width is preserved.
    pad_column = torch.full(
        (labels.shape[0], 1), pad_id, dtype=labels.dtype, device=labels.device
    )
    shifted_labels = torch.cat((labels[:, 1:], pad_column), dim=1)

    # Decoder attention mask: 0 for PAD (positions to ignore), 1 for every other token.
    shifted_mask = torch.where(shifted_labels == pad_id, 0, 1)

    return {
        "pixel_values": pixel_values,
        "labels": batch_labels,
        "decoder_input": input_labels,
        "shift_mask": shifted_mask,
        "shifted_labels": shifted_labels,
    }

In [9]:
# Wrap the metadata frames in datasets, then in batched loaders.
# NOTE: the names are reused, so TRAIN_DATASET / VALIDATION_DATASET end up holding DataLoaders.
TRAIN_DATASET = OCRDataset(DATASETS_PATH, TRAIN_DF, processor)
VALIDATION_DATASET = OCRDataset(DATASETS_PATH, VALIDATION_DF, processor)

# Shuffle the training samples so each epoch sees them in a different order;
# validation keeps a fixed order. pin_memory stays False because collate_function
# already returns tensors placed on DEVICE.
TRAIN_DATASET = DataLoader(
    dataset=TRAIN_DATASET,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_function,
    pin_memory=False,
)
VALIDATION_DATASET = DataLoader(
    dataset=VALIDATION_DATASET,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_function,
    pin_memory=False,
)

In [None]:
# Smoke test: pull one batch from the training loader (runs collate_function once) and display it.
next(iter(TRAIN_DATASET))

In [None]:
# TRAIN_DATASET / VALIDATION_DATASET are DataLoaders at this point, so len() on them
# is the number of batches; the sample count lives on the wrapped `.dataset`.
print("Number of training examples:", len(TRAIN_DATASET.dataset))
print("Number of validation examples:", len(VALIDATION_DATASET.dataset))

 ## Model loading

In [12]:
class HandCursiveTrOCR(pl.LightningModule):
    """LightningModule that fine-tunes a ``VisionEncoderDecoderModel`` (TrOCR).

    The module owns the model, the loss, the optimizer configuration and the
    train/validation dataloaders it was given.
    """

    def __init__(
        self,
        model_path,
        image_processor,
        train_dataset,
        eval_dataset=None,
        learning_rate=4e-5,
        weight_decay=0.1,
        cache_dir="",
    ):
        """Load the pretrained model and align its special token ids with the tokenizer.

        Args:
            model_path: Passed to ``VisionEncoderDecoderModel.from_pretrained``.
            image_processor: Processor exposing ``.tokenizer``; its BOS/EOS/PAD ids
                and vocabulary size are copied into the model configuration.
            train_dataset: Object returned unchanged by ``train_dataloader()``.
            eval_dataset: Object returned unchanged by ``val_dataloader()``.
            learning_rate: Learning rate handed to AdamW.
            weight_decay: Weight decay handed to AdamW.
            cache_dir: Forwarded to ``from_pretrained``.
        """
        super().__init__()

        self.model = VisionEncoderDecoderModel.from_pretrained(
            model_path, cache_dir=cache_dir
        )
        self.image_processor = image_processor
        tokenizer = self.image_processor.tokenizer

        # Generation settings used by `generate()` in compute_cer_metric.
        generation_config = self.model.generation_config
        generation_config.decoder_start_token_id = tokenizer.bos_token_id
        generation_config.temperature = 0.4
        generation_config.max_length = 200
        generation_config.do_sample = True

        # Keep the decoder / encoder configs consistent with the tokenizer's special tokens.
        decoder_config = self.model.config.decoder
        decoder_config.bos_token_id = tokenizer.bos_token_id
        decoder_config.decoder_start_token_id = tokenizer.bos_token_id
        decoder_config.eos_token_id = tokenizer.eos_token_id
        decoder_config.pad_token_id = tokenizer.pad_token_id

        encoder_config = self.model.config.encoder
        encoder_config.bos_token_id = tokenizer.bos_token_id
        encoder_config.decoder_start_token_id = tokenizer.bos_token_id
        encoder_config.eos_token_id = tokenizer.eos_token_id

        self.model.config.vocab_size = tokenizer.vocab_size

        # PAD targets do not contribute to the loss.
        self.criterion = torch.nn.CrossEntropyLoss(
            ignore_index=tokenizer.pad_token_id
        )
        self.train_dataset = train_dataset
        self.evaluation_dataset = eval_dataset
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, pixel_values, decoder_input_ids, decoder_mask=None):
        """Run the wrapped model and return its output object.

        Args:
            pixel_values: Image tensor for the encoder.
            decoder_input_ids: Token ids fed to the decoder.
            decoder_mask: Optional attention mask for the decoder inputs.
        """
        # Call the module itself (not `.forward`) so registered hooks run, and use
        # keywords so the call does not depend on the positional order of the
        # underlying signature.
        return self.model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_mask,
        )

    def common_step(self, batch):
        """Compute the cross-entropy loss for one collated batch.

        Returns:
            tuple ``(loss tensor, loss as a Python float)``.
        """
        pixel_values = batch["pixel_values"]
        decoder_inputs = batch["decoder_input"]
        shifted_mask = batch["shift_mask"]
        shifted_labels = batch["shifted_labels"]
        model_output = self.forward(pixel_values, decoder_inputs, shifted_mask)
        logits = model_output.logits
        # Flatten (batch, sequence) so every token position becomes one classification sample.
        loss = self.criterion(
            logits.contiguous().view(-1, self.model.config.decoder.vocab_size),
            shifted_labels.contiguous().view(-1),
        )
        return loss, loss.item()

    def compute_cer_metric(self, batch):
        """Generate text for the batch images and score it against the label strings with CER.

        Generation uses the sampling settings configured in ``__init__``
        (``do_sample=True``), so the value can differ between runs.
        """
        pixel_values = batch["pixel_values"]
        gt_labels = batch["labels"]
        model_predictions = self.model.generate(pixel_values)
        predicted_strings = self.image_processor.tokenizer.batch_decode(
            model_predictions, skip_special_tokens=True
        )
        return cer_metric.compute(predictions=predicted_strings, references=gt_labels)

    def training_step(self, batch):
        """Log ``train_loss`` and return the loss tensor for backpropagation."""
        loss, loss_value = self.common_step(batch)
        self.log("train_loss", loss_value)
        return loss

    def validation_step(self, batch):
        """Log ``validation_loss`` and ``validation_cer`` for the batch and return the loss."""
        loss, loss_value = self.common_step(batch)
        self.log("validation_loss", loss_value)
        cer_value = self.compute_cer_metric(batch)
        self.log("validation_cer", cer_value)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the wrapped model's parameters."""
        return torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

    def train_dataloader(self):
        """Return the training dataloader received at construction time."""
        return self.train_dataset

    def val_dataloader(self):
        """Return the validation dataloader received at construction time (may be None)."""
        return self.evaluation_dataset

In [None]:
# Build the Lightning module around the pretrained checkpoint, the processor and both loaders.
model = HandCursiveTrOCR(
    model_path=TROCR_MODEL,
    image_processor=processor,
    train_dataset=TRAIN_DATASET,
    eval_dataset=VALIDATION_DATASET,
    learning_rate=LR,
)

## TRAINING

In [14]:
# Run description handed to the loggers (ClearML task parameters and TensorBoard hparams).
hyperparams = dict(
    model_type="TrOCR",
    model_name="microsoft/trocr-large-stage1",
    codename=MODEL_CODENAME,
    version=MODEL_VERSION,
    model_learning_rate=LR,
    epochs=EPOCHS,
    acc_grad_batches=ACC_BATCH,
    batch_size=BATCH_SIZE,
)

In [None]:
# Count the parameters, report both figures in millions (each labelled with its unit)
# and keep the raw counts with the logged hyperparameters.
trocr_total_params = sum(p.numel() for p in model.parameters())
trocr_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total params: {} M\nTrainable params: {} M".format(trocr_total_params / 1e6, trocr_train_params / 1e6))
hyperparams["total_params"] = trocr_total_params
hyperparams["trainable_params"] = trocr_train_params

In [None]:
# Register the run in ClearML as "<codename>_V_<version>" and attach the hyperparameters to it.
tsk_name = f"{MODEL_CODENAME}_V_{MODEL_VERSION}"
task = Task.init(task_name=tsk_name, project_name="HandCursive-I")
task.set_parameters(hyperparams)

In [17]:
# Start from a clean log folder for this codename/version: remove any previous run's files.
log_path = os.path.join(LOG_DIR, MODEL_CODENAME, f"version_{MODEL_VERSION}")
shutil.rmtree(log_path, ignore_errors=True)
# Metrics go through the TensorBoard logger; if you use it too, remember to start a TensorBoard instance.
logger = pl.loggers.TensorBoardLogger(
    save_dir=LOG_DIR,
    name=MODEL_CODENAME,
    version=MODEL_VERSION,
)
logger.log_hyperparams(hyperparams)

In [None]:
# Single-GPU trainer with bf16 mixed precision and gradient accumulation.
# LOGGING_STEPS drives both the logging frequency and the validation interval.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision="bf16-mixed",
    max_epochs=EPOCHS,
    accumulate_grad_batches=ACC_BATCH,
    log_every_n_steps=LOGGING_STEPS,
    val_check_interval=LOGGING_STEPS,
    logger=logger,
)
# Run the training cycle; the dataloaders come from the module itself and the metrics go to the logger.
trainer.fit(model)

In [None]:
# Export the fine-tuned model and the processor side by side in a versioned folder.
FINAL_CKP_PATH = os.path.join(CKP_PATH, MODEL_CODENAME, f"V_{MODEL_VERSION}")
# The weights are written in the safetensors format by default.
model.model.save_pretrained(FINAL_CKP_PATH)
processor.save_pretrained(FINAL_CKP_PATH)