In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import Dataset, DatasetDict
from PIL import Image
import pandas as pd
import os


In [None]:
# Dataset locations: the labels CSV and the folder of image crops it references
dataset_root = "./dataset"
csv_path = f"{dataset_root}/labels_comma.csv"
images_folder = f"{dataset_root}/crops"

In [None]:


# Read the file name and transcription columns as strings so that a label
# which merely looks numeric is not coerced into an int/float (the tokenizer
# in `preprocess` only accepts text).
df = pd.read_csv(
    csv_path,
    engine="python",
    encoding="utf-8",
    dtype={"image_file": str, "text": str},
)

# Empty cells are still read as NaN; a NaN label would crash the tokenizer and
# a NaN file name would crash os.path.join, so drop such rows up front.
incomplete = df[["image_file", "text"]].isna().any(axis=1)
if incomplete.any():
    print(f"Dropping {int(incomplete.sum())} row(s) with a missing image_file or text")
df = df[~incomplete].reset_index(drop=True)

# Adding full image path for loading
df["image_path"] = df["image_file"].apply(lambda name: os.path.join(images_folder, name))

# Fail fast here instead of partway through the (slow) preprocessing map.
missing_images = [path for path in df["image_path"] if not os.path.isfile(path)]
if missing_images:
    raise FileNotFoundError(
        f"{len(missing_images)} image file(s) listed in {csv_path} were not found, "
        f"e.g. {missing_images[:5]}"
    )

# Splitting data to train and test (80/20, seeded for reproducibility)
train_df = df.sample(frac=0.8, random_state=42)
test_df = df.drop(train_df.index)

# Converting to HuggingFace datasets.
# NOTE: the split key is "Validation" (capital V); the trainer cell reads it by that name.
train_ds = Dataset.from_pandas(train_df.reset_index(drop=True))
test_ds = Dataset.from_pandas(test_df.reset_index(drop=True))
dataset = DatasetDict({"train": train_ds, "Validation": test_ds})

In [None]:
# Loading Processor and model
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")

# Special-token ids the encoder-decoder uses to build decoder inputs from
# `labels` during training (see the VisionEncoderDecoderModel docs).
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.eos_token_id

# Decoding settings belong on `generation_config`, which is what `generate()`
# (and therefore evaluation with predict_with_generate) reads. Putting them on
# `model.config` is deprecated: recent transformers releases warn about it and
# some refuse to `save_pretrained` such a model. The token ids are mirrored too
# so generation starts from the same token the decoder was trained with.
model.generation_config.decoder_start_token_id = model.config.decoder_start_token_id
model.generation_config.pad_token_id = model.config.pad_token_id
model.generation_config.eos_token_id = model.config.eos_token_id
model.generation_config.max_length = 32  # same length the labels are padded/truncated to in `preprocess`
model.generation_config.early_stopping = True
model.generation_config.num_beams = 4



In [None]:
# Preprocessing Function
def preprocess(example, max_target_length=32):
    """Turn one dataset row into TrOCR training inputs.

    Args:
        example: a row mapping with at least "image_path" (path to the image
            crop) and "text" (its transcription).
        max_target_length: token length the labels are padded/truncated to.
            Keep in sync with the model's generation max_length.

    Returns:
        dict with "pixel_values" (the processor output for this single image)
        and "labels" (token ids with padding replaced by -100 so the loss
        ignores those positions).
    """
    # The context manager closes the file handle; convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(example["image_path"]) as raw_image:
        image = raw_image.convert("RGB")

    # Processing image into model format (drop the batch dimension of 1)
    pixel_values = processor(images=image, return_tensors="pt").pixel_values[0]

    # Tokenizing target text to a fixed length
    token_ids = processor.tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=max_target_length,
    ).input_ids
    pad_id = processor.tokenizer.pad_token_id
    labels = [-100 if token == pad_id else token for token in token_ids]
    return {"pixel_values": pixel_values, "labels": labels}

# Run `preprocess` over every row of every split (train and Validation);
# the result replaces the raw DatasetDict.
dataset = dataset.map(function=preprocess)

In [None]:


# === 4. Training Arguments ===
training_args = Seq2SeqTrainingArguments(
    output_dir="./ocr-nepali-trocr",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Evaluation decodes with generate(), so `compute_metrics` below receives
    # generated token ids rather than logits.
    predict_with_generate=True,
    num_train_epochs=5,
    learning_rate=5e-5,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    use_cpu=True,  # forces CPU training even if an accelerator is present
)


def _edit_distance(hypothesis: str, reference: str) -> int:
    """Levenshtein distance between two strings, counted per Unicode code point."""
    previous = list(range(len(reference) + 1))
    for i, hyp_char in enumerate(hypothesis, start=1):
        current = [i]
        for j, ref_char in enumerate(reference, start=1):
            current.append(
                min(
                    previous[j] + 1,  # deletion
                    current[j - 1] + 1,  # insertion
                    previous[j - 1] + (hyp_char != ref_char),  # substitution
                )
            )
        previous = current
    return previous[-1]


def compute_metrics(eval_pred):
    """Character error rate (CER) of generated text against the reference labels.

    Without a metric, the generations produced by predict_with_generate are
    computed every epoch and then discarded.
    """
    pad_id = processor.tokenizer.pad_token_id

    def _decode(batch):
        # -100 marks ignored label positions (see `preprocess`) and may also pad
        # generated sequences of different lengths; the tokenizer cannot decode it.
        ids = [[pad_id if token == -100 else int(token) for token in seq] for seq in batch]
        return processor.tokenizer.batch_decode(ids, skip_special_tokens=True)

    predictions = eval_pred.predictions
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_text = _decode(predictions)
    reference_text = _decode(eval_pred.label_ids)

    total_edits = sum(_edit_distance(p, r) for p, r in zip(predicted_text, reference_text))
    total_chars = sum(len(r) for r in reference_text)
    return {"cer": total_edits / max(total_chars, 1)}


# === 5. Trainer ===
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["Validation"],
    compute_metrics=compute_metrics,
)

# === 6. Train the model ===
trainer.train()

In [None]:
# === 7. Save the fine-tuned model ===
# Persist the trained model and then its processor into the same directory.
export_dir = "./ocr-nepali-trocr"
for artifact in (model, processor):
    artifact.save_pretrained(export_dir)