# DeepSeek OCR Fine-tuning for Invoice Extraction

This notebook fine-tunes the DeepSeek OCR model on Turkish e-invoice documents.

## Key Features:
- **QLoRA** for memory-efficient training on RTX 4090 (24GB VRAM)
- **4-bit quantization** with bitsandbytes
- **Gradient checkpointing** for reduced memory
- **Multi-page invoice support** - DeepSeek OCR handles multiple images per sample

## Requirements:
- RTX 4090 (24GB VRAM) or equivalent
- CUDA 11.8+
- Flash Attention 2

In [None]:
# Install requirements (run this first on Vast.ai)
!pip install -q torch transformers>=4.40.0 accelerate peft>=0.10.0 bitsandbytes>=0.43.0 wandb scikit-learn tqdm pillow torchvision
!pip install -q flash-attn --no-build-isolation

In [None]:
import os
import json
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import (
    Trainer, 
    TrainingArguments, 
    AutoTokenizer, 
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from PIL import Image, ImageOps
from torchvision import transforms
from typing import List, Dict, Optional, Tuple, Union
import numpy as np
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Report the PyTorch build and, when CUDA is usable, the first GPU's name and memory.
print(f"PyTorch version: {torch.__version__}")
_cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ready}")
if _cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; convert to GiB for display.
    _vram_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"VRAM: {_vram_gib:.1f} GB")

## Configuration

Adjust these settings based on your GPU and requirements.

In [None]:
# ============== CONFIGURATION ==============
# All tunables for the notebook live in this cell; later cells only read them.

# Model
MODEL_NAME = "deepseek-ai/DeepSeek-OCR"  # Passed to from_pretrained for both tokenizer and model
OUTPUT_DIR = "./deepseek_ocr_finetuned"  # Trainer output dir; adapter and tokenizer are saved here

# Dataset
DATASET_PATH = "dataset_multi_image.jsonl"  # JSONL file, one sample per line
IMAGES_DIR = "images"  # Fallback directory searched (by basename) when an image path does not exist as given

# Training hyperparameters (optimized for RTX 4090 24GB)
BATCH_SIZE = 1  # Used as per_device_train_batch_size
GRADIENT_ACCUMULATION_STEPS = 8  # Effective batch size = 8
NUM_EPOCHS = 10  # More epochs for small dataset
LEARNING_RATE = 2e-4  # Higher LR for LoRA
WARMUP_RATIO = 0.1
MAX_SEQ_LENGTH = 4096  # The collator truncates input_ids/labels/image mask to this many tokens

# LoRA configuration
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

# Quantization (set to True for 4-bit, False for bf16 full precision)
USE_4BIT_QUANTIZATION = True

# Logging
USE_WANDB = False  # Set to True to enable Weights & Biases logging
WANDB_PROJECT = "deepseek-ocr-invoice"

# Image processing
IMAGE_SIZE = 640  # Side length (pixels) of each square crop tile
BASE_SIZE = 1024  # Side length (pixels) of the padded square global view
# PATCH_SIZE and DOWNSAMPLE_RATIO are only used by the collator to compute how
# many image placeholder tokens each view occupies in the sequence.
PATCH_SIZE = 16
DOWNSAMPLE_RATIO = 4

In [None]:
# ============== IMAGE PROCESSING HELPERS ==============

def load_image(image_path: str) -> Optional[Image.Image]:
    """Load an image as RGB with its EXIF orientation applied.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A new RGB image that no longer depends on the file on disk, or
        ``None`` if the file could not be opened or decoded. Failures are
        reported with ``print`` rather than raised (best-effort loading).
    """
    try:
        # Image.open is lazy and keeps the file handle open. The context
        # manager releases it as soon as convert() has copied the pixels,
        # instead of leaking one handle per page until garbage collection.
        with Image.open(image_path) as image:
            corrected_image = ImageOps.exif_transpose(image)
            return corrected_image.convert("RGB")
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the grid from ``target_ratios`` that best matches ``aspect_ratio``.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of ``(cols, rows)`` candidates, scanned in order.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The candidate whose ``cols / rows`` is closest to ``aspect_ratio``.
        On an exact tie, a later candidate replaces the current choice only
        when the image area exceeds half the area that candidate's grid would
        cover. ``(1, 1)`` is returned when there are no candidates.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)

        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue

        # Tie-break: prefer the later grid only if the image is large enough
        # to fill at least half of it.
        if gap == smallest_gap and image_area > 0.5 * image_size * image_size * cols * rows:
            chosen = candidate

    return chosen

def dynamic_preprocess(image, min_num=1, max_num=9, image_size=640, use_thumbnail=False):
    """Resize an image onto a tile grid matching its aspect ratio and cut it into tiles.

    Args:
        image: Image object exposing ``size``, ``resize`` and ``crop``.
        min_num: Minimum number of tiles a candidate grid may contain.
        max_num: Maximum number of tiles a candidate grid may contain.
        image_size: Side length of each square tile in pixels.
        use_thumbnail: When true and more than one tile was produced, append
            a whole-image resize to ``image_size`` x ``image_size``.

    Returns:
        ``(tiles, (cols, rows))`` where ``tiles`` is in row-major order
        (left-to-right, then top-to-bottom), optionally followed by the
        thumbnail, and ``(cols, rows)`` is the selected grid.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Every (cols, rows) grid whose tile count lies within [min_num, max_num],
    # ordered by tile count.
    candidate_grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    )
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size
    )
    cols, rows = grid[0], grid[1]
    tile_count = cols * rows

    # Stretch the image to exactly cover the grid, then slice it tile by tile.
    canvas = image.resize((image_size * cols, image_size * rows))
    tiles = []
    for row in range(rows):
        top = row * image_size
        for col in range(cols):
            left = col * image_size
            tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))

    assert len(tiles) == tile_count
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles, grid

class BasicImageTransform:
    """Callable that converts an image to a tensor and optionally normalizes it.

    The pipeline is ``ToTensor`` followed, when ``normalize`` is true, by
    ``Normalize(mean, std)``. ``mean`` and ``std`` are kept as attributes so
    callers can reuse them.
    """

    def __init__(
        self,
        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        self.mean = mean
        self.std = std

        steps = [transforms.ToTensor()]
        if normalize:
            steps += [transforms.Normalize(mean=mean, std=std)]
        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Apply the composed transform to ``x`` and return the result."""
        return self.transform(x)

def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """Tokenize ``text`` without special tokens, then add BOS/EOS ids on request.

    Args:
        tokenizer: Object providing ``encode(text, add_special_tokens=...)``
            plus ``bos_token_id`` and ``eos_token_id`` attributes.
        text: Text to tokenize.
        bos: Prepend the BOS id when true.
        eos: Append the EOS id when true.

    Returns:
        The list of token ids. If the tokenizer has no BOS/EOS id, ``1`` and
        ``2`` are used respectively.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)

    # Fall back to fixed ids when the tokenizer does not define them.
    start_id = 1 if tokenizer.bos_token_id is None else tokenizer.bos_token_id
    end_id = 2 if tokenizer.eos_token_id is None else tokenizer.eos_token_id

    if bos:
        token_ids = [start_id] + token_ids
    if eos:
        token_ids = token_ids + [end_id]
    return token_ids

print("‚úÖ Image processing helpers loaded")

## Dataset and Data Collator

The dataset handles multi-page invoices where each sample can have multiple images.

In [None]:
# ============== DATASET ==============

class DeepSeekOCRDataset(Dataset):
    """Dataset for DeepSeek OCR fine-tuning with multi-image support.

    Reads a JSONL file in which every non-blank line is a JSON object with an
    ``images`` list of paths. Items are returned unchanged as dicts; a sample
    is kept only if all of its images can be found on disk.
    """

    def __init__(self, data_path: str, images_dir: str = "images"):
        """Load and validate samples.

        Args:
            data_path: Path to the UTF-8 encoded JSONL file.
            images_dir: Directory searched (by file basename) when an image
                path from the file does not exist as written.
        """
        self.data = []
        self.images_dir = images_dir

        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Blank lines (e.g. a trailing newline at end of file) are
                # not samples; json.loads would raise on them.
                if not line.strip():
                    continue
                item = json.loads(line)
                # Report the first missing image and drop the whole sample.
                missing = next(
                    (p for p in item['images'] if not self._image_exists(p)),
                    None,
                )
                if missing is not None:
                    print(f"‚ö†Ô∏è Missing image: {missing}")
                    continue
                self.data.append(item)

        print(f"‚úÖ Loaded {len(self.data)} samples from {data_path}")

    def _image_exists(self, img_path: str) -> bool:
        """Return True if the path exists as given or by basename under ``images_dir``."""
        if os.path.exists(img_path):
            return True
        return os.path.exists(os.path.join(self.images_dir, os.path.basename(img_path)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class DeepSeekOCRDataCollator:
    """Collator for DeepSeek OCR that handles multi-image inputs.

    Each batch item is a dict with ``prompt``, ``response`` and ``images``
    keys. The prompt is split on the ``<image>`` placeholder; after the i-th
    text segment the i-th image (if any) is inserted as a run of image
    placeholder token ids. Only response tokens contribute to the loss
    (prompt and image positions are labelled ``-100``).
    """

    def __init__(
        self,
        tokenizer,
        image_size: int = IMAGE_SIZE,
        base_size: int = BASE_SIZE,
        patch_size: int = PATCH_SIZE,
        downsample_ratio: int = DOWNSAMPLE_RATIO,
        max_seq_length: int = MAX_SEQ_LENGTH,
        images_dir: str = IMAGES_DIR,
    ):
        """Store preprocessing settings and resolve special token ids.

        Args:
            tokenizer: Tokenizer used for the prompt and response text.
            image_size: Side length of each crop tile in pixels.
            base_size: Side length of the padded global view in pixels.
            patch_size: Used with ``downsample_ratio`` to compute the number
                of placeholder tokens per view.
            downsample_ratio: See ``patch_size``.
            max_seq_length: Sequences longer than this are truncated.
            images_dir: Directory searched (by basename) for images whose
                path does not exist as given. Defaults to the notebook-wide
                ``IMAGES_DIR``, which was previously hard-coded.
        """
        self.tokenizer = tokenizer
        self.image_size = image_size
        self.base_size = base_size
        self.patch_size = patch_size
        self.downsample_ratio = downsample_ratio
        self.max_seq_length = max_seq_length
        self.images_dir = images_dir
        self.image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)

        # Get image token
        self.image_token = '<image>'
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        if self.image_token_id is None or self.image_token_id == tokenizer.unk_token_id:
            self.image_token_id = 128815  # Fallback for DeepSeek OCR

        # Special token ids, with fixed fallbacks when the tokenizer has none
        self.bos_token_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1
        self.eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    def _load_images(self, image_paths):
        """Load every page, keeping exactly one slot per path.

        A page that fails to load is kept as ``None`` so the remaining pages
        stay aligned with their ``<image>`` placeholders instead of shifting
        forward onto the wrong one.
        """
        loaded = []
        for img_path in image_paths:
            # Handle relative paths
            if not os.path.exists(img_path):
                img_path = os.path.join(self.images_dir, os.path.basename(img_path))
            loaded.append(load_image(img_path))
        return loaded

    def _image_token_ids(self, width_crop_num, height_crop_num):
        """Return the placeholder token ids for one image and its crop grid."""
        num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
        num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)

        # Global view: num_queries_base rows of (num_queries_base + 1) ids,
        # followed by a single extra id.
        token_ids = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base
        token_ids += [self.image_token_id]

        # Crop tiles contribute only when the grid is larger than 1x1.
        if width_crop_num > 1 or height_crop_num > 1:
            token_ids += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
                num_queries * height_crop_num
            )
        return token_ids

    def _encode_item(self, item):
        """Tokenize one sample and preprocess its images.

        Returns:
            ``(input_ids, labels, images_seq_mask, global_views, crop_tiles,
            spatial_crops)``. The first three are equal-length Python lists;
            ``global_views`` and ``crop_tiles`` are lists of bfloat16 tensors;
            ``spatial_crops`` holds one ``[width_crop_num, height_crop_num]``
            pair per image that was inserted.
        """
        prompt = item['prompt']
        response = item['response']
        images = self._load_images(item['images'])

        # Split prompt by image tokens
        text_splits = prompt.split(self.image_token)

        tokenized_str = []
        images_seq_mask = []
        images_list = []
        images_crop_list = []
        spatial_crops = []

        for i, text_sep in enumerate(text_splits):
            # Tokenize text segment
            tokenized_sep = self.tokenizer.encode(text_sep, add_special_tokens=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            # Insert the i-th image after the i-th segment, if it loaded.
            if i >= len(images) or images[i] is None:
                continue
            image = images[i]

            # Dynamic preprocessing for tiles
            images_crop_raw, crop_ratio = dynamic_preprocess(image, image_size=self.image_size)

            # Global view, padded to a square with the normalization mean colour
            global_view = ImageOps.pad(
                image,
                (self.base_size, self.base_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean)
            )
            images_list.append(self.image_transform(global_view).to(torch.bfloat16))

            width_crop_num, height_crop_num = crop_ratio
            spatial_crops.append([width_crop_num, height_crop_num])

            # Add crop tiles if needed
            if width_crop_num > 1 or height_crop_num > 1:
                for crop_img in images_crop_raw:
                    images_crop_list.append(self.image_transform(crop_img).to(torch.bfloat16))

            tokenized_image = self._image_token_ids(width_crop_num, height_crop_num)
            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)

        # Add BOS token at the beginning
        tokenized_str = [self.bos_token_id] + tokenized_str
        images_seq_mask = [False] + images_seq_mask

        # Tokenize response
        response_tokens = self.tokenizer.encode(response, add_special_tokens=False) + [self.eos_token_id]

        # Combine input_ids
        input_ids = tokenized_str + response_tokens
        images_seq_mask += [False] * len(response_tokens)

        # Create labels (-100 for prompt tokens, actual tokens for response)
        labels = [-100] * len(tokenized_str) + response_tokens

        # Truncate if necessary.
        # NOTE(review): truncation cuts from the end, so it removes response
        # (label) tokens first and can also cut into a run of image tokens
        # while the image tensors are kept whole — confirm the model tolerates
        # this, or keep samples under max_seq_length.
        if len(input_ids) > self.max_seq_length:
            input_ids = input_ids[:self.max_seq_length]
            labels = labels[:self.max_seq_length]
            images_seq_mask = images_seq_mask[:self.max_seq_length]

        return input_ids, labels, images_seq_mask, images_list, images_crop_list, spatial_crops

    def __call__(self, batch):
        """Collate a list of samples into padded tensors plus per-sample image data.

        Returns:
            Dict with ``input_ids``, ``labels`` and ``images_seq_mask`` padded
            to the longest sequence in the batch, and ``images`` /
            ``images_spatial_crop`` as lists with one entry per sample that
            has at least one loaded image.
        """
        input_ids_batch = []
        labels_batch = []
        images_batch = []
        images_seq_mask_batch = []
        images_spatial_crop_batch = []

        for item in batch:
            (
                input_ids,
                labels,
                images_seq_mask,
                images_list,
                images_crop_list,
                spatial_crops,
            ) = self._encode_item(item)

            input_ids_batch.append(torch.LongTensor(input_ids))
            labels_batch.append(torch.LongTensor(labels))
            images_seq_mask_batch.append(torch.tensor(images_seq_mask, dtype=torch.bool))

            if images_list:
                images_ori = torch.stack(images_list, dim=0)
                images_spatial_crop_tensor = torch.tensor(spatial_crops, dtype=torch.long)
                if images_crop_list:
                    images_crop = torch.stack(images_crop_list, dim=0)
                else:
                    # Placeholder tile tensor when no image needed cropping
                    images_crop = torch.zeros((1, 3, self.image_size, self.image_size), dtype=torch.bfloat16)

                images_batch.append((images_crop, images_ori))
                images_spatial_crop_batch.append(images_spatial_crop_tensor)

        # Pad sequences
        input_ids_padded = torch.nn.utils.rnn.pad_sequence(
            input_ids_batch, batch_first=True, padding_value=self.pad_token_id
        )
        labels_padded = torch.nn.utils.rnn.pad_sequence(
            labels_batch, batch_first=True, padding_value=-100
        )
        images_seq_mask_padded = torch.nn.utils.rnn.pad_sequence(
            images_seq_mask_batch, batch_first=True, padding_value=False
        )

        return {
            "input_ids": input_ids_padded,
            "labels": labels_padded,
            "images": images_batch,
            "images_seq_mask": images_seq_mask_padded,
            "images_spatial_crop": images_spatial_crop_batch
        }

print("‚úÖ Dataset and Collator classes defined")

## Custom Trainer for DeepSeek OCR

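We use a custom trainer whose `compute_loss` passes the image inputs (`images`, `images_seq_mask`, `images_spatial_crop`) to the model explicitly, alongside `input_ids` and `labels`.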

In [None]:
# ============== CUSTOM TRAINER ==============

class DeepSeekOCRTrainer(Trainer):
    """Trainer whose loss computation forwards the collator's image inputs to the model."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run a forward pass with text and image inputs and return the model's loss.

        The three image-related entries are removed from ``inputs`` (missing
        ones become ``None``) and passed to the model as keyword arguments
        together with ``input_ids`` and ``labels``.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, else ``loss``.
        """
        vision_kwargs = {
            key: inputs.pop(key, None)
            for key in ("images", "images_seq_mask", "images_spatial_crop")
        }

        outputs = model(
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
            **vision_kwargs,
        )

        if return_outputs:
            return outputs.loss, outputs
        return outputs.loss

print("‚úÖ Custom Trainer defined")

In [None]:
# ============== LOAD MODEL AND TOKENIZER ==============

print(f"Loading model: {MODEL_NAME}")
print(f"Using 4-bit quantization: {USE_4BIT_QUANTIZATION}")

# Tokenizer (the checkpoint ships custom code, hence trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Reuse EOS as the padding token when the tokenizer defines none
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Arguments shared by the quantized and the bf16 loading paths
load_kwargs = dict(
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',
)

# Quantization config for 4-bit (QLoRA)
if USE_4BIT_QUANTIZATION:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    load_kwargs["quantization_config"] = bnb_config

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)

if USE_4BIT_QUANTIZATION:
    # Prepare model for k-bit training (also turns on gradient checkpointing)
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
else:
    model.gradient_checkpointing_enable()

print(f"‚úÖ Model loaded successfully")
print(f"Model type: {type(model).__name__}")

In [None]:
# ============== APPLY LoRA ==============

# Find all linear layers for LoRA
def find_all_linear_names(model):
    """Collect the leaf names of all ``torch.nn.Linear`` submodules for LoRA targeting.

    Args:
        model: Module whose ``named_modules()`` are scanned.

    Returns:
        Sorted list of unique final name components (e.g. ``"q_proj"`` for
        ``"layers.0.attn.q_proj"``), excluding ``"lm_head"``. Sorting makes
        the result, and anything logged or saved from it, reproducible across
        runs; a plain ``list(set)`` has no stable order.
    """
    lora_module_names = {
        name.split('.')[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    }

    # Remove output layer if present
    lora_module_names.discard('lm_head')

    return sorted(lora_module_names)

# Get target modules: every linear layer name found in the loaded model
# (lm_head excluded by find_all_linear_names)
target_modules = find_all_linear_names(model)
print(f"Target modules for LoRA: {target_modules}")

# LoRA configuration (rank, scaling and dropout come from the CONFIGURATION cell)
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=target_modules,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Apply LoRA: `model` is rebound to the PEFT-wrapped model from here on
model = get_peft_model(model, lora_config)

# Print trainable parameters
def print_trainable_parameters(model):
    """Print the trainable and total parameter counts and the trainable percentage.

    Args:
        model: Module whose ``named_parameters()`` are counted. A parameter
            counts as trainable when ``requires_grad`` is true.

    A model without parameters reports 0.00% instead of raising
    ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"Trainable params: {trainable_params:,} || "
        f"All params: {all_param:,} || "
        f"Trainable %: {trainable_pct:.2f}%"
    )

print_trainable_parameters(model)
print("‚úÖ LoRA applied successfully")

In [None]:
# ============== LOAD DATASET ==============

# Load dataset (samples with missing images are dropped during loading)
dataset = DeepSeekOCRDataset(DATASET_PATH, images_dir=IMAGES_DIR)

# Fail fast with an actionable message: the cells below index dataset[0] and
# would otherwise stop with a bare IndexError.
if len(dataset) == 0:
    raise ValueError(
        f"No usable samples in {DATASET_PATH!r}; check the file and that the "
        f"referenced images exist (fallback directory: {IMAGES_DIR!r})."
    )

# Create data collator
collator = DeepSeekOCRDataCollator(
    tokenizer=tokenizer,
    image_size=IMAGE_SIZE,
    base_size=BASE_SIZE,
    patch_size=PATCH_SIZE,
    downsample_ratio=DOWNSAMPLE_RATIO,
    max_seq_length=MAX_SEQ_LENGTH,
)

# Test collator with one sample
print("\nüìã Testing collator with first sample...")
# Collate the first sample alone and report the resulting tensor shapes.
test_batch = collator([dataset[0]])
print(f"  Input IDs shape: {test_batch['input_ids'].shape}")
print(f"  Labels shape: {test_batch['labels'].shape}")
print(f"  Images seq mask shape: {test_batch['images_seq_mask'].shape}")
_batch_images = test_batch['images']
print(f"  Number of images: {len(_batch_images)}")
if _batch_images:
    # Each entry is a (crop tiles, global views) pair for one sample.
    _crop_tiles, _global_views = _batch_images[0]
    print(f"  Image crops shape: {_crop_tiles.shape}")
    print(f"  Image ori shape: {_global_views.shape}")
print("‚úÖ Collator test passed")

In [None]:
# ============== TRAINING ARGUMENTS ==============
# Hyperparameters come from the CONFIGURATION cell; the remaining values are fixed here.

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,

    # Batch size settings
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,

    # Training settings
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    warmup_ratio=WARMUP_RATIO,
    lr_scheduler_type="cosine",

    # Precision
    bf16=True,

    # Memory optimization
    gradient_checkpointing=True,
    optim="adamw_torch_fused",  # Faster optimizer

    # Logging
    logging_steps=1,
    logging_first_step=True,
    report_to="wandb" if USE_WANDB else "none",

    # Saving
    save_steps=50,
    save_total_limit=3,

    # Other
    dataloader_pin_memory=False,
    # The dataset yields raw dicts (prompt/response/images) that the collator
    # consumes, so the Trainer must not drop columns it does not recognize.
    remove_unused_columns=False,
    dataloader_num_workers=0,  # Avoid multiprocessing issues

    # Disable find_unused_parameters for memory efficiency
    ddp_find_unused_parameters=False,
)

print("‚úÖ Training arguments configured")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Epochs: {NUM_EPOCHS}")
# Rough estimate only (integer division over the whole run)
print(f"  Total steps: ~{len(dataset) * NUM_EPOCHS // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)}")

In [None]:
# ============== INITIALIZE WANDB (Optional) ==============
# wandb is imported only when logging is enabled, so later cells must guard
# any use of it with USE_WANDB as well.

if USE_WANDB:
    import wandb
    wandb.init(
        project=WANDB_PROJECT,
        name=f"deepseek-ocr-invoice-{NUM_EPOCHS}ep",
        # Record the main hyperparameters with the run
        config={
            "model": MODEL_NAME,
            "lora_r": LORA_R,
            "lora_alpha": LORA_ALPHA,
            "learning_rate": LEARNING_RATE,
            "epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            "gradient_accumulation": GRADIENT_ACCUMULATION_STEPS,
            "quantization": "4bit" if USE_4BIT_QUANTIZATION else "bf16",
        }
    )
    print("‚úÖ Weights & Biases initialized")
else:
    print("‚ÑπÔ∏è Weights & Biases disabled. Set USE_WANDB=True to enable.")

In [None]:
# ============== CREATE TRAINER AND START TRAINING ==============

# No eval dataset is passed: this run trains only.
trainer = DeepSeekOCRTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)

print("üöÄ Starting training...")
print(f"  Dataset size: {len(dataset)} samples")
print(f"  Using custom DeepSeekOCRTrainer for proper image handling")
print("-" * 50)

# Start training (checkpoints are written under OUTPUT_DIR per training_args)
trainer.train()

In [None]:
# ============== SAVE MODEL ==============

print("üíæ Saving model...")

# Save the LoRA adapter and the tokenizer side by side
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print(f"‚úÖ Model saved to {OUTPUT_DIR}")
print("\nTo load the model later:")
# The snippet is meant to be pasted into a fresh session, so it must import
# torch itself: it references torch.bfloat16.
print(f"""
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = AutoModelForCausalLM.from_pretrained(
    "{MODEL_NAME}",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, "{OUTPUT_DIR}")
tokenizer = AutoTokenizer.from_pretrained("{OUTPUT_DIR}")
""")

## Inference Test

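Sanity-check the setup on a sample invoice. Note that this cell only loads the sample's page images and prints the expected response; it does not run generation with the fine-tuned model.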

In [None]:
# ============== INFERENCE TEST ==============

# Switch the model to evaluation mode
model.eval()

# Load the page images of the first dataset sample, resolving each path the
# same way the dataset does (as given, else by basename under IMAGES_DIR)
test_sample = dataset[0]
test_images = list(map(
    load_image,
    (
        path if os.path.exists(path) else os.path.join(IMAGES_DIR, os.path.basename(path))
        for path in test_sample['images']
    ),
))

print(f"Testing with invoice: {test_sample['images'][0]}")
print(f"Number of pages: {len(test_images)}")

# Note: Full inference requires proper handling of images through the model's processor
# This is a simplified test - for production use, use the model's built-in chat interface

print("\nüìÑ Expected output (first 500 chars):")
# Ground-truth response from the dataset (not a model prediction), truncated for display
print(test_sample['response'][:500] + "...")

# wandb was imported only in the USE_WANDB branch of the initialization cell
if USE_WANDB:
    wandb.finish()
    
print("\nüéâ Training complete!")