In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Authenticates this session with Kaggle before any dataset is requested.
# NOTE(review): login() is called without arguments, so it presumably prompts
# for credentials interactively and would stall an unattended "Run All" --
# confirm against the kagglehub documentation.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Fetch the 'sajinamatya/ocr-donut-data' dataset and keep whatever
# dataset_download() returns in a module-level name.
# NOTE(review): the return value is presumably the local directory holding the
# dataset files -- confirm against the kagglehub documentation.
sajinamatya_ocr_donut_data_path = kagglehub.dataset_download('sajinamatya/ocr-donut-data')

print('Data source import complete.')


In [None]:
import os
import json
import torch
import gc
from PIL import Image
from datasets import Dataset
from transformers import DonutProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments

In [None]:
# Set device
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

# The CUDA query functions raise when no GPU is present, so only call them
# when the CUDA device was actually selected.
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is the card's total capacity, not the currently free amount.
    print(f"Available memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Using cuda device
GPU: Tesla P100-PCIE-16GB
Available memory: 17.06 GB


In [None]:
# Define paths
# Prefer the directory returned by kagglehub.dataset_download() (stored in
# `sajinamatya_ocr_donut_data_path` by the data-source import cell) so the
# notebook also runs where /kaggle/input is not mounted. If that cell was
# deleted or not run, fall back to the Kaggle mount point used originally.
# NOTE(review): assumes the downloaded directory has the same layout as
# /kaggle/input/ocr-donut-data (a top-level `data/` folder) -- confirm.
try:
    _dataset_root = sajinamatya_ocr_donut_data_path
except NameError:
    _dataset_root = "/kaggle/input/ocr-donut-data"

train_metadata_path = os.path.join(_dataset_root, "data", "train", "metadata.jsonl")
val_metadata_path = os.path.join(_dataset_root, "data", "val", "metadata.jsonl")
test_metadata_path = os.path.join(_dataset_root, "data", "test", "metadata.jsonl")

In [None]:
def resize_image(image, max_size=(384, 384)):
    """Shrink ``image`` so it fits inside ``max_size``; never enlarge it.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``thumbnail(size, resample)`` method (a PIL image in this notebook).
        max_size: ``(max_width, max_height)`` bounding box.

    Returns:
        The same ``image`` object that was passed in. When it exceeded the
        bounding box it has been shrunk in place via ``thumbnail``.
    """
    current_width, current_height = image.size

    # Already within bounds on both axes: hand it back untouched.
    if current_width <= max_size[0] and current_height <= max_size[1]:
        return image

    # thumbnail() returns nothing; it modifies the image object itself.
    image.thumbnail(max_size, Image.LANCZOS)
    return image


In [None]:
# Checkpoint identifier shared by the processor here and the model load below.
model_name = "naver-clova-ix/donut-base"
# The processor is called on images and its `.tokenizer` attribute is used on
# text in later cells.
# NOTE(review): the processor presumably applies its own image resizing from
# the checkpoint's configuration, independent of resize_image() -- confirm the
# resulting pixel_values size is what you expect for memory use.
processor = DonutProcessor.from_pretrained(model_name)


In [None]:
# Load the encoder-decoder weights from the same checkpoint as the processor.
# low_cpu_mem_usage=True is passed through to from_pretrained.
# NOTE(review): presumably this lowers peak host RAM while loading -- confirm
# against the transformers documentation for the installed version.
model = VisionEncoderDecoderModel.from_pretrained(
    model_name,

    low_cpu_mem_usage=True
)


In [None]:
# Copy the tokenizer's BOS and PAD ids onto the top-level model config.
# NOTE(review): presumably these are what the model uses to build decoder
# inputs from labels and to start generation -- confirm against the
# VisionEncoderDecoderModel documentation.
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Mirror the decoder sub-config's vocabulary size onto the top-level config.
model.config.vocab_size = model.config.decoder.vocab_size

In [None]:
# Move the model onto the device chosen in the device-setup cell.
model = model.to(device)

In [None]:
def _format_citizenship_text(parsed):
    """Render one parsed ground-truth record as the multi-line target text.

    Raises KeyError when a field is missing and TypeError when a nested field
    is not a mapping (e.g. ``date_of_birth`` stored as a plain string).
    """
    # NOTE: 'permament_address' (sic) is the key this loader has always read;
    # the spelling is kept because it must match the metadata files.
    return "\n".join([
        f"citizenship_no: {parsed['citizenship_certificate_no']}",
        f"name: {parsed['full_name']}",
        f"sex: {parsed['sex']}",
        f"dob: {parsed['date_of_birth']['year']}-{parsed['date_of_birth']['month']}-{parsed['date_of_birth']['day']}",
        f"birth_place: {parsed['birth_place']['district']}, {parsed['birth_place']['vdc']}-{parsed['birth_place']['ward']}",
        f"address: {parsed['permament_address']['district']}, {parsed['permament_address']['vdc']}-{parsed['permament_address']['ward']}",
    ])


def load_data_from_metadata(metadata_path, max_samples=None):
    """Read a ``metadata.jsonl`` file into a list of image/text records.

    Each non-blank line must be a JSON object with a ``file_name`` (relative
    to the metadata file's directory) and a ``ground_truth`` that is either a
    JSON string or an already-decoded object containing ``gt_parse``.

    Lines that cannot be used are reported and skipped rather than aborting
    the load: invalid JSON, missing keys, structurally malformed records, and
    records whose image file does not exist.

    Args:
        metadata_path: Path to the ``.jsonl`` metadata file.
        max_samples: Stop after this many valid records. ``None`` (or any
            falsy value such as 0) means no limit.

    Returns:
        List of ``{"image_path": str, "text": str}`` dicts, in file order.

    Raises:
        ValueError: If no valid record could be loaded.
    """
    # Image paths in the metadata are relative to the metadata file itself.
    base_dir = os.path.dirname(metadata_path)
    metadata = []

    print(f"Reading metadata from {metadata_path}...")
    with open(metadata_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            if max_samples and len(metadata) >= max_samples:
                print(f"Reached max samples limit ({max_samples})")
                break

            if not line.strip():
                continue

            try:
                item = json.loads(line)
                gt = item["ground_truth"]
                # Usually a JSON-encoded string; tolerate a pre-decoded object too.
                if isinstance(gt, str):
                    gt = json.loads(gt)
                text = _format_citizenship_text(gt["gt_parse"])
                image_path = os.path.join(base_dir, item["file_name"])
            except json.JSONDecodeError as e:
                print(f"Error parsing line {line_num} in {metadata_path}: {e}")
                continue
            except KeyError as e:
                print(f"Missing key in data at line {line_num}: {e}")
                continue
            except TypeError as e:
                # Wrong shape (e.g. a string where a dict is expected); skip
                # the record instead of crashing the whole load.
                print(f"Malformed record at line {line_num} in {metadata_path}: {e}")
                continue

            # Only keep records whose image is actually on disk.
            if not os.path.exists(image_path):
                print(f"Warning: Image not found at {image_path}")
                continue

            metadata.append({
                "image_path": image_path,
                "text": text
            })
            if len(metadata) % 100 == 0:
                print(f"Processed {len(metadata)} valid entries")

    if not metadata:
        raise ValueError(f"No valid data loaded from {metadata_path}")

    print(f"Successfully loaded {len(metadata)} entries from {metadata_path}")
    return metadata


In [None]:
def prepare_dataset(data, max_samples=None):
    """Load the images referenced by ``data`` and build a ``Dataset``.

    Args:
        data: Sequence of dicts with ``"image_path"`` and ``"text"`` keys.
        max_samples: Optional cap; when truthy and smaller than ``len(data)``
            only the first ``max_samples`` items are used.

    Returns:
        ``Dataset.from_dict`` result with parallel ``"image"`` and
        ``"ground_truth"`` columns. Items that fail to load are reported and
        skipped, so the result may be shorter than ``data``.

    Raises:
        ValueError: If no image could be loaded at all.
    """
    images = []
    ground_truths = []

    if max_samples and len(data) > max_samples:
        print(f"Limiting dataset to {max_samples} samples (from {len(data)} available)")
        data = data[:max_samples]

    for i, item in enumerate(data):
        try:
            # Image.open() keeps the underlying file open. convert() yields a
            # separate in-memory image, so the file can be closed right away
            # instead of leaking one handle per sample.
            with Image.open(item["image_path"]) as source:
                image = source.convert("RGB")
            # Resize images to save memory
            image = resize_image(image)
            ground_truth = item["text"]
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort the run.
            print(f"Error processing {item['image_path']}: {e}")
            continue

        # Append both columns together so they stay aligned.
        images.append(image)
        ground_truths.append(ground_truth)

        if (i+1) % 100 == 0:
            print(f"Prepared {i+1}/{len(data)} images")

    if not images:
        raise ValueError("No valid images processed")

    return Dataset.from_dict({
        "image": images,
        "ground_truth": ground_truths
    })


In [None]:
def preprocess_function(examples, max_length=256, task_prompt="<s_ocr>"):
    """Turn a batch of images and target strings into model inputs.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch dict with an ``"image"`` list and a parallel
            ``"ground_truth"`` list of target strings.
        max_length: Token length every target sequence is padded/truncated to.
        task_prompt: Text prepended to every target before tokenization.

    Returns:
        Dict with ``"pixel_values"`` and ``"labels"``. Padding positions in
        ``labels`` are set to -100.

        ``decoder_input_ids`` is intentionally NOT returned. Supplying decoder
        inputs identical to the labels lets the decoder see, at every
        position, the very token it is asked to predict, so the loss collapses
        without the model learning anything. With only ``labels`` present the
        model derives right-shifted decoder inputs itself (see the
        VisionEncoderDecoderModel documentation).

    Raises:
        Re-raises whatever the processor/tokenizer raises, after printing it.
    """
    try:
        # Process images. Only image arguments are passed here: padding and
        # max_length are tokenizer options and do not belong to this call.
        pixel_values = processor(
            examples["image"],
            return_tensors="pt",
        ).pixel_values

        # Process text: prompt + target, padded/truncated to a fixed length.
        target_ids = processor.tokenizer(
            [task_prompt + gt for gt in examples["ground_truth"]],
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids

        # Mask padding with -100 in the labels.
        labels = target_ids.clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }
    except Exception as e:
        print(f"Error in preprocessing: {e}")
        # Bare raise keeps the original traceback intact.
        raise


In [None]:
def _build_split(split_label, metadata_path, max_samples):
    """Load, prepare and preprocess one data split.

    Runs the three steps shared by the train and validation splits:
    read the metadata, load the images into a Dataset, then map
    ``preprocess_function`` over it.

    Args:
        split_label: Word used in the progress messages ("training"/"validation").
        metadata_path: Path to that split's ``metadata.jsonl``.
        max_samples: Cap on the number of records read from the metadata.

    Returns:
        The mapped Dataset with the raw image/text columns removed.
    """
    print(f"Loading {split_label} data...")
    records = load_data_from_metadata(metadata_path, max_samples=max_samples)

    print(f"Preparing {split_label} dataset...")
    raw_dataset = prepare_dataset(records)
    # The record list is no longer needed once the Dataset exists.
    del records

    print(f"Processing {split_label} dataset...")
    return raw_dataset.map(
        preprocess_function,
        batched=True,
        batch_size=2,  # Very small batch for preprocessing
        remove_columns=raw_dataset.column_names
    )


# Load and prepare train data with memory limits
train_max_samples = 200  # Adjust based on your dataset size and memory constraints
train_dataset = _build_split("training", train_metadata_path, train_max_samples)

# Load and prepare validation data with memory limits
val_max_samples = 200  # Adjust based on your dataset size and memory constraints
val_dataset = _build_split("validation", val_metadata_path, val_max_samples)



In [None]:

# fp16 mixed precision needs a CUDA device; requesting it on a CPU-only runtime
# is rejected when the training arguments are built. Enable it only when a GPU
# is present so the notebook still runs (slowly) on CPU.
use_fp16 = torch.cuda.is_available()

training_args = Seq2SeqTrainingArguments(
    output_dir="/kaggle/working/model",
    # NOTE(review): newer transformers releases spell this `eval_strategy` --
    # confirm against the installed version.
    evaluation_strategy="steps",
    # NOTE(review): with the 200-sample cap, batch size 1, 4 accumulation steps
    # and 3 epochs on a single device there are only about 150 optimizer steps,
    # so a 200-step eval/save interval may never trigger during training --
    # confirm this is intended.
    eval_steps=200,
    learning_rate=5e-5,
    per_device_train_batch_size=1, # Reduced batch size
    per_device_eval_batch_size=1, # Reduced batch size
    gradient_accumulation_steps=4, # Reduced gradient accumulation steps
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=use_fp16, # Mixed precision training, only when CUDA is available
    save_strategy="steps",
    save_steps=200,
    logging_dir="./logs",
    logging_steps=50,
    dataloader_num_workers=0,
    gradient_checkpointing=True,
    ddp_find_unused_parameters=False,
    optim="adamw_torch",
    max_grad_norm=0.5,  # Add gradient clipping
    report_to="none"  # Disable reporting to save memory
)

# Initialize trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Train model, then write the final weights to training_args.output_dir.
print("Starting training...")
trainer.train()
trainer.save_model()


In [None]:
# Save the model and the processor into the same directory so both can later
# be restored from that single path.
model.save_pretrained("/kaggle/working/pretrain")
processor.save_pretrained("/kaggle/working/pretrain")
