In [None]:
import random
from datasets import load_dataset
from transformers import AutoProcessor, VisionEncoderDecoderModel, AutoTokenizer
from PIL import Image
from datasets import Dataset
import torch
from torch.utils.tensorboard import SummaryWriter

# Fixed seed so the train/eval split and the training run are reproducible
# across "Restart Kernel -> Run All".
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Load the preference dataset and hold out 5% of its "train" split for evaluation.
dataset = load_dataset("sylvain471/trocr-3ch-fr-imc-2-dpo")
dataset = dataset["train"].train_test_split(test_size=0.05, seed=SEED)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

# Load model and processor
model_name = "sylvain471/troc-medieval-fr-3ch-imc"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Maximum target length in tokens; also read by collate_fn for padding/truncation.
max_length=32
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
# NOTE(review): depending on the installed transformers version, generate() may
# read these settings from model.generation_config rather than model.config --
# TODO confirm they are actually picked up.
# model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.eos_token_id = processor.tokenizer.eos_token_id
# model.config.forced_eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = max_length
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4




In [None]:
# Freeze the vision encoder for training, leaving only its pooler trainable.
# Parameters whose name does not start with "encoder" are not touched.
for param_name, parameter in model.named_parameters():
    if not param_name.startswith("encoder"):
        continue
    parameter.requires_grad = param_name.startswith("encoder.pooler")

In [None]:
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from jiwer import wer, cer


# Define collate function
def collate_fn(batch):
    """Collate dataset rows into tensors placed on ``device``.

    Every row is read through its ``"image"``, ``"chosen"`` and ``"rejected"``
    keys. Images go through ``processor`` one at a time and are concatenated
    along the first dimension; both text columns get the tokenizer's EOS token
    appended and are padded/truncated to the module-level ``max_length``.

    Args:
        batch: list of dataset rows as handed over by the DataLoader.

    Returns:
        Tuple ``(pixel_values, chosen_inputs, rejected_inputs)`` where the last
        two are the tokenizer outputs for the chosen and rejected texts.
    """

    def _encode_column(column):
        # Ensure every target text ends with an explicit EOS token.
        texts = [row[column] + " " + tokenizer.eos_token for row in batch]
        encoded = tokenizer(
            texts,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        )
        return encoded.to(device)

    per_image_pixels = []
    for row in batch:
        processed = processor(row["image"], return_tensors="pt")
        per_image_pixels.append(processed["pixel_values"])
    pixel_values = torch.cat(per_image_pixels).to(device)

    chosen_inputs = _encode_column("chosen")
    rejected_inputs = _encode_column("rejected")
    return pixel_values, chosen_inputs, rejected_inputs


# DPO reference free Loss Function
def _sequence_log_prob(logits, inputs):
    """Sum the log-probabilities of the target tokens of each sequence.

    Positions where ``inputs["attention_mask"]`` is 0 (padding) are excluded
    from the sum. If no attention mask is provided, every position counts.
    """
    token_ids = inputs["input_ids"]
    log_probs = F.log_softmax(logits, dim=-1)
    token_log_probs = torch.gather(log_probs, 2, token_ids.unsqueeze(-1)).squeeze(-1)
    if "attention_mask" in inputs:
        # masked_fill (rather than multiplying by the mask) keeps a -inf
        # log-probability on a padded position from turning into NaN.
        token_log_probs = token_log_probs.masked_fill(inputs["attention_mask"] == 0, 0.0)
    return token_log_probs.sum(dim=1)


def dpo_rf_loss(model, pixel_values, chosen_inputs, rejected_inputs, beta=0.1):
    """Reference-free DPO loss for a batch of (image, chosen, rejected) triples.

    The loss is ``-mean(logsigmoid(beta * (log p(chosen) - log p(rejected))))``
    where each sequence log-probability is the sum of its token
    log-probabilities under ``model``. Only ``model`` is used; there is no
    frozen reference model.

    Args:
        model: callable accepting ``pixel_values`` and ``labels`` keyword
            arguments and returning an object with a ``.logits`` attribute.
        pixel_values: image tensor forwarded to the model.
        chosen_inputs: mapping with ``"input_ids"`` (and optionally
            ``"attention_mask"``) for the preferred transcriptions.
        rejected_inputs: same, for the rejected transcriptions.
        beta: scale applied to the log-probability margin.

    Returns:
        Scalar tensor holding the mean loss over the batch.
    """
    # The targets are passed as `labels`; the logits returned for each position
    # are then gathered at the corresponding target token id.
    chosen_logits = model(
        pixel_values=pixel_values, labels=chosen_inputs["input_ids"]
    ).logits
    rejected_logits = model(
        pixel_values=pixel_values, labels=rejected_inputs["input_ids"]
    ).logits

    # Padding must not contribute: the targets are padded to a fixed length,
    # so unmasked sums would depend on pad-token logits and sequence length.
    chosen_log_prob = _sequence_log_prob(chosen_logits, chosen_inputs)
    rejected_log_prob = _sequence_log_prob(rejected_logits, rejected_inputs)

    # Log-probability margin of the policy between chosen and rejected.
    policy_ratio = chosen_log_prob - rejected_log_prob
    loss = -torch.mean(F.logsigmoid(beta * policy_ratio))
    return loss



@torch.no_grad()
def evaluate_model(model, processor, dataloader, num_samples=10):
    """Evaluate transcription quality and the chosen/rejected reward margin.

    For each batch, predictions from ``model.generate`` are decoded and
    compared with the decoded chosen texts using ``cer`` and ``wer``. The
    reward margin is the batch mean of ``log p(chosen) - log p(rejected)``,
    with padding positions excluded from the sequence log-probabilities.

    Args:
        model: model to evaluate; it is switched to eval mode and left there.
        processor: processor whose ``tokenizer`` decodes token ids to text.
        dataloader: yields ``(pixel_values, chosen_inputs, rejected_inputs)``.
        num_samples: maximum number of *batches* to evaluate (the counter
            advances once per batch, not once per sample).

    Returns:
        Tuple ``(avg_cer, avg_wer, avg_reward_margin)``, each the mean of the
        per-batch values, or zeros when the dataloader yields no batch.
    """

    def _masked_sequence_log_prob(pixel_values, inputs):
        # Sum of target-token log-probabilities, ignoring padded positions.
        token_ids = inputs["input_ids"]
        logits = model(pixel_values=pixel_values, labels=token_ids).logits
        log_probs = torch.log_softmax(logits, dim=-1)
        token_log_probs = torch.gather(log_probs, 2, token_ids.unsqueeze(-1)).squeeze(-1)
        if "attention_mask" in inputs:
            token_log_probs = token_log_probs.masked_fill(inputs["attention_mask"] == 0, 0.0)
        return token_log_probs.sum(dim=1)

    model.eval()
    total_cer = 0
    total_wer = 0
    total_reward_margin = 0
    count = 0

    for pixel_values, chosen_inputs, rejected_inputs in dataloader:
        if count >= num_samples:
            break

        # Generate model predictions
        generated_ids = model.generate(pixel_values)
        predictions = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # Ground truth (chosen texts)
        references = processor.tokenizer.batch_decode(chosen_inputs["input_ids"], skip_special_tokens=True)

        # Compute CER/WER, averaged over the sequences of this batch
        batch_cer = sum(cer(ref, pred) for ref, pred in zip(references, predictions)) / len(references)
        batch_wer = sum(wer(ref, pred) for ref, pred in zip(references, predictions)) / len(references)

        # Sequence log-probabilities of the chosen and rejected responses
        chosen_log_prob = _masked_sequence_log_prob(pixel_values, chosen_inputs)
        rejected_log_prob = _masked_sequence_log_prob(pixel_values, rejected_inputs)

        # Compute reward margin
        reward_margin = (chosen_log_prob - rejected_log_prob).mean().item()

        total_cer += batch_cer
        total_wer += batch_wer
        total_reward_margin += reward_margin
        count += 1

    avg_cer = total_cer / count if count > 0 else 0
    avg_wer = total_wer / count if count > 0 else 0
    avg_reward_margin = total_reward_margin / count if count > 0 else 0
    return avg_cer, avg_wer, avg_reward_margin

In [None]:
### test evaluate_model
# (pixel_values, chosen_inputs, rejected_inputs)=next(iter(dataloader))
# print(len(pixel_values))

# generated_ids = model.generate(pixel_values)
# predictions = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
grad_accum_steps = 4  # Number of micro-batches accumulated before each optimizer update
eval_every = 15  # Evaluate every N micro-batches (`step` counts batches, not optimizer updates)
eval_samples = 8  # Eval batch size; also passed to evaluate_model as its `num_samples` limit

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
epochs = 10
step = 0  # Micro-batches processed so far, across epochs

# Set up TensorBoard
writer = SummaryWriter(log_dir="./runs/trocr_dpo")

# Create DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
eval_dataloader = DataLoader(eval_dataset, batch_size=eval_samples, shuffle=True, collate_fn=collate_fn)

global_step=0  # Optimizer updates performed so far; used as the TensorBoard x-axis

# Log model structure to TensorBoard
example_batch = next(iter(train_dataloader))
example_pixel_values, example_chosen_inputs, _ = example_batch
# writer.add_graph(model, example_pixel_values)  # Log model architecture !! DOES NOT WORK

for epoch in range(epochs):
    model.train()
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    total_loss = 0
    pending_micro_batches = 0  # Micro-batches accumulated since the last optimizer update

    optimizer.zero_grad()  # Start with zeroed gradients

    for i, (pixel_values, chosen_inputs, rejected_inputs) in enumerate(progress_bar):
        loss = dpo_rf_loss(model, pixel_values, chosen_inputs, rejected_inputs,beta=0.2)  # Reference-free DPO
        # Keep the unscaled value for reporting so the progress bar, the epoch
        # summary and TensorBoard all show the loss on the same scale.
        batch_loss = loss.item()
        (loss / grad_accum_steps).backward()  # Scale only the gradient contribution
        pending_micro_batches += 1

        if pending_micro_batches == grad_accum_steps:  # Update only every `grad_accum_steps` steps
            optimizer.step()
            optimizer.zero_grad()  # Reset gradients
            pending_micro_batches = 0
            global_step += 1

            # Log to TensorBoard
            writer.add_scalar("Train/Loss", batch_loss, global_step)
            # writer.add_scalar("Train/Reward Margin", reward_margin, global_step)


        total_loss += batch_loss
        progress_bar.set_postfix(loss=total_loss / (i + 1))
        step += 1

        # **Evaluation Step**
        if step % eval_every == 0:
            model.eval()
            avg_cer, avg_wer, avg_reward_margin = evaluate_model(model, processor, eval_dataloader, eval_samples)
            writer.add_scalar("Eval/CER", avg_cer, global_step)
            writer.add_scalar("Eval/WER", avg_wer, global_step)
            writer.add_scalar("Eval/Reward Margin", avg_reward_margin, global_step)
            print(f"\n[Step {step}] CER Loss: {avg_cer:.4f} | Reward margin: {avg_reward_margin:.4f}")
            model.train()  # Back to training

    # Apply the gradients of a trailing, incomplete accumulation window instead
    # of letting the next epoch's zero_grad() throw them away.
    if pending_micro_batches > 0:
        # Each contribution was divided by grad_accum_steps; rescale so the
        # update is the mean over the micro-batches actually accumulated.
        rescale = grad_accum_steps / pending_micro_batches
        for parameter in model.parameters():
            if parameter.grad is not None:
                parameter.grad.mul_(rescale)
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
        writer.add_scalar("Train/Loss", batch_loss, global_step)

    print(f"Epoch {epoch+1} Loss: {total_loss / len(train_dataloader)}")

writer.close()

## Test trained model

In [None]:
# Index of the held-out example to transcribe.
idx=17

image = eval_dataset[idx]['image'].convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values

# The training loop leaves the model in train mode; switch to eval mode so
# dropout is disabled during generation.
model.eval()
# Send the input to wherever the model lives instead of assuming a GPU.
generated_ids = model.generate(pixel_values.to(model.device))
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Expected (chosen) transcription, then the model's prediction.
print(eval_dataset[idx]['chosen'])
print(generated_text)
image

## Push to hub

In [None]:
from dotenv import load_dotenv
import os

# Read HF_TOKEN from the environment after loading any local .env file,
# so the token never has to be written in the notebook.
load_dotenv()
token = os.getenv("HF_TOKEN")

# Destination repository on the Hub; replace the placeholder before running.
repo_name = "<your_hf_repo>"

# Upload the model, then the processor, then the tokenizer to the same repo.
for artifact in (model, processor, tokenizer):
    artifact.push_to_hub(repo_name, token=token)