## 1️⃣ Environment Setup & Installation

In [None]:
# Report the interpreter version and, when the NVIDIA tools are present, the GPU.
import subprocess
import sys

print(f"Python version: {sys.version}")

# Check CUDA availability through nvidia-smi. The executable does not exist on
# CPU-only machines, so a missing binary is reported instead of aborting the cell.
try:
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=name,memory.total,driver_version', '--format=csv'],
        capture_output=True,
        text=True,
    )
except FileNotFoundError:
    print("\nGPU Info:\nnvidia-smi not found (no NVIDIA driver tools on PATH)")
else:
    # Fall back to stderr so a failing nvidia-smi still explains itself.
    print(f"\nGPU Info:\n{result.stdout or result.stderr}")

In [None]:
# Install all dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q transformers>=4.40.0 accelerate>=0.27.0 datasets>=2.18.0
!pip install -q peft>=0.10.0 bitsandbytes>=0.43.0
!pip install -q Pillow opencv-python tqdm numpy scipy scikit-learn
!pip install -q wandb tensorboard

# Install Flash Attention 2 (optional but recommended for speed)
!pip install -q flash-attn --no-build-isolation

print("\n‚úÖ All dependencies installed!")

## 2️⃣ Import Libraries

In [None]:
import os
import sys
import json
import torch
import warnings
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
from typing import Dict, List, Any, Optional
from dataclasses import dataclass

# Transformers
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
)
from transformers.trainer_callback import TrainerCallback

# PEFT for LoRA
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

# Dataset
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Report the PyTorch build and, when a GPU is visible, its name and total memory.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {gpu_props.total_memory / 1e9:.1f} GB")

## 3️⃣ Configuration

In [None]:
@dataclass
class TrainingConfig:
    """Settings shared by every cell of the DeepSeek OCR fine-tuning notebook.

    One instance (``config``) is created right after this class and read by the
    dataset loader, the model/LoRA loaders, ``TrainingArguments`` and the
    optional Weights & Biases logging.
    """
    
    # Model: Hugging Face identifier passed to the ``from_pretrained`` calls.
    model_name: str = "deepseek-ai/DeepSeek-OCR"
    
    # Paths
    dataset_path: str = "dataset_multi_image.jsonl"  # JSONL file, one sample per line
    images_dir: str = "images"  # passed to load_and_validate_dataset with dataset_path
    output_dir: str = "./deepseek_ocr_finetuned"  # checkpoints, logs, adapters, metrics
    
    # Training hyperparameters (optimized for RTX 4090); forwarded to TrainingArguments.
    num_epochs: int = 10
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: int = 8  # Effective batch size = 8
    learning_rate: float = 2e-4
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    
    # LoRA config: rank, alpha and dropout handed to LoraConfig.
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    
    # Quantization (4-bit for RTX 4090); these feed BitsAndBytesConfig.
    use_4bit: bool = True
    bnb_4bit_compute_dtype: str = "bfloat16"  # resolved with getattr(torch, ...)
    bnb_4bit_quant_type: str = "nf4"
    use_nested_quant: bool = False  # becomes bnb_4bit_use_double_quant
    
    # Other settings
    use_flash_attention: bool = True  # requests attn_implementation="flash_attention_2"
    gradient_checkpointing: bool = True
    fp16: bool = False
    bf16: bool = True  # Use bf16 for RTX 4090
    max_seq_length: int = 4096  # max_length given to InvoiceOCRDataset
    seed: int = 42  # used for the train/val split and for TrainingArguments
    
    # Logging: step intervals for logging, checkpointing and evaluation.
    logging_steps: int = 1
    save_steps: int = 50
    eval_steps: int = 50
    use_wandb: bool = False  # Set to True if you want to use Weights & Biases
    wandb_project: str = "deepseek-ocr-invoice"

# Instantiate the default configuration and echo the settings that matter most.
config = TrainingConfig()
effective_batch_size = config.per_device_train_batch_size * config.gradient_accumulation_steps
for summary_line in (
    "‚úÖ Configuration loaded:",
    f"   Model: {config.model_name}",
    f"   Dataset: {config.dataset_path}",
    f"   Output: {config.output_dir}",
    f"   Epochs: {config.num_epochs}",
    f"   Effective batch size: {effective_batch_size}",
    f"   LoRA rank: {config.lora_r}, alpha: {config.lora_alpha}",
):
    print(summary_line)

## 4️⃣ Load and Validate Dataset

In [None]:
def _resolve_image_path(img_path: str, images_dir: str) -> Optional[Path]:
    """Return the first existing location for *img_path*, or None if there is none.

    The path is tried exactly as written, then relative to *images_dir*, and
    finally by bare file name inside *images_dir*.
    """
    as_written = Path(img_path)
    candidates = [as_written]
    if images_dir:
        base = Path(images_dir)
        candidates.extend([base / as_written, base / as_written.name])
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def load_and_validate_dataset(dataset_path: str, images_dir: str) -> List[Dict]:
    """Load a JSONL dataset and keep only the samples whose images all exist.

    Each non-blank line must be a JSON object with the keys ``images``,
    ``prompt`` and ``response``. Lines that fail to parse, lack a key, or
    reference an image that cannot be found are reported and skipped.

    Args:
        dataset_path: Path of the UTF-8 JSONL file.
        images_dir: Directory searched for images whose path does not exist as
            written in the dataset.

    Returns:
        The accepted samples. Their ``images`` entry is replaced by the list of
        resolved path strings.

    Raises:
        OSError: If *dataset_path* cannot be opened.
    """
    
    print(f"üìÇ Loading dataset from: {dataset_path}")
    
    samples = []
    missing_images = []
    required_fields = ('images', 'prompt', 'response')
    
    with open(dataset_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            
            try:
                sample = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"‚ö†Ô∏è Line {line_num}: JSON parse error - {e}")
                continue
            
            # A valid JSON line that is not an object (e.g. a bare number)
            # cannot carry the required fields either.
            if not isinstance(sample, dict) or any(k not in sample for k in required_fields):
                print(f"‚ö†Ô∏è Line {line_num}: Missing required fields")
                continue
            
            # Check all images exist
            resolved = [_resolve_image_path(p, images_dir) for p in sample['images']]
            missing_images.extend(
                p for p, found in zip(sample['images'], resolved) if found is None
            )
            
            if all(found is not None for found in resolved):
                sample['images'] = [str(found) for found in resolved]
                samples.append(sample)
            else:
                print(f"‚ö†Ô∏è Line {line_num}: Some images missing")
    
    print(f"\n‚úÖ Dataset loaded successfully!")
    print(f"   Total samples: {len(samples)}")
    
    # Statistics
    total_images = sum(len(s['images']) for s in samples)
    total_invoices = 0
    for s in samples:
        try:
            resp = json.loads(s['response'])
            total_invoices += len(resp.get('data', []))
        except (ValueError, TypeError, AttributeError):
            # The response is not a JSON object with a sized "data" entry;
            # it contributes nothing to this purely informational count.
            continue
    
    print(f"   Total images: {total_images}")
    print(f"   Total invoices: {total_invoices}")
    
    if missing_images:
        print(f"\n‚ö†Ô∏è Missing images ({len(missing_images)}):")
        for img in missing_images[:5]:
            print(f"   - {img}")
        if len(missing_images) > 5:
            print(f"   ... and {len(missing_images) - 5} more")
    
    return samples

# Load dataset
dataset = load_and_validate_dataset(config.dataset_path, config.images_dir)

# Nothing later in the notebook can run without samples, so stop here with a
# clear message rather than an IndexError on dataset[0].
if not dataset:
    raise ValueError(
        f"No valid samples were loaded from {config.dataset_path}; "
        "check the dataset file and the image paths it references."
    )

# Show sample structure
print("\nüìã Sample entry structure:")
sample = dataset[0]
print(f"   Images: {len(sample['images'])} files")
print(f"   Prompt length: {len(sample['prompt'])} chars")
print(f"   Response length: {len(sample['response'])} chars")

In [None]:
# Display dataset summary: one row per sample with its page and invoice counts.
print("üìä Dataset Summary:")
print("=" * 80)
for i, sample in enumerate(dataset, 1):
    num_images = len(sample['images'])
    try:
        resp = json.loads(sample['response'])
        num_invoices = len(resp.get('data', []))
    except (ValueError, TypeError, AttributeError):
        # The response is not a JSON object with a sized "data" entry.
        num_invoices = "?"
    
    # Get PDF name from first image; a sample without images has no name to show.
    if sample['images']:
        pdf_name = Path(sample['images'][0]).stem.replace('_page_1', '')[:45]
    else:
        pdf_name = "(no images)"
    print(f"{i:2}. {pdf_name:45} | {num_images:2} pages | {num_invoices:2} invoices")

print("=" * 80)
print(f"Total: {len(dataset)} samples ready for training")

## 5️⃣ Create PyTorch Dataset

In [None]:
class InvoiceOCRDataset(Dataset):
    """Dataset that turns invoice samples into model inputs with training labels.

    Each sample is a dict with ``images`` (list of image paths), ``prompt`` and
    ``response``. ``__getitem__`` runs the processor on the images and on
    ``"{prompt}\\n\\n{response}"`` and adds ``labels`` in which the prompt
    prefix and all padding positions are set to -100, so neither contributes
    to the loss.
    """
    
    def __init__(
        self, 
        samples: List[Dict], 
        processor: Any,
        tokenizer: Any,
        max_length: int = 4096
    ):
        """Store the samples and the processor/tokenizer used to encode them.

        Args:
            samples: Sample dicts as described in the class docstring.
            processor: Callable invoked with ``images=``, ``text=`` and padding
                arguments; must return a mapping of tensors.
            tokenizer: Used to measure the prompt length and as a text-only
                fallback when the processor raises.
            max_length: Length every sequence is padded or truncated to.
        """
        self.samples = samples
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Encode sample *idx* and return its tensors plus ``labels``."""
        sample = self.samples[idx]
        
        # Load images
        images = []
        for img_path in sample['images']:
            try:
                # The context manager closes the file handle; convert()
                # returns an in-memory image that outlives it.
                with Image.open(img_path) as img:
                    images.append(img.convert('RGB'))
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                # Create a blank image as fallback
                images.append(Image.new('RGB', (224, 224), color='white'))
        
        # Prepare prompt and response
        prompt = sample['prompt']
        response = sample['response']
        
        # Create full conversation
        full_text = f"{prompt}\n\n{response}"
        
        # Process with processor
        try:
            inputs = self.processor(
                images=images,
                text=full_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length
            )
        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            # Best-effort fallback: encode the text alone (no image features).
            inputs = self.tokenizer(
                full_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length
            )
        
        # Squeeze batch dimension and prepare labels
        result = {}
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                result[key] = value.squeeze(0)
            else:
                result[key] = value
        
        # Set labels (same as input_ids for causal LM)
        if 'input_ids' in result:
            labels = result['input_ids'].clone()
            
            # Mask prompt tokens (only train on response).
            # NOTE(review): this assumes the prompt occupies the first
            # prompt_len positions of input_ids; special or image tokens
            # inserted by the processor would shift that boundary - confirm.
            prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
            labels[:len(prompt_tokens)] = -100
            
            # Padding must not be trained on. The attention mask is used
            # rather than the pad id because the pad token may be the EOS token.
            attention_mask = result.get('attention_mask')
            if isinstance(attention_mask, torch.Tensor) and attention_mask.shape == labels.shape:
                labels[attention_mask == 0] = -100
            
            result['labels'] = labels
        
        return result

print("‚úÖ Dataset class defined")

## 6️⃣ Load Model with QLoRA

In [None]:
def load_model_and_processor(config: TrainingConfig):
    """Load the model named in *config* together with its processor and tokenizer.

    The model is loaded with a BitsAndBytes quantization config built from
    *config*, placed with ``device_map="auto"`` and prepared for k-bit
    training. Flash Attention 2 is requested only when ``config.use_flash_attention``
    is set and the ``flash_attn`` package is importable.

    Args:
        config: Training configuration (model name, quantization and attention
            settings, gradient-checkpointing flag).

    Returns:
        A ``(model, processor, tokenizer)`` tuple.
    """
    
    print(f"üì• Loading model: {config.model_name}")
    print("   This may take a few minutes...")
    
    # BitsAndBytes config for 4-bit quantization
    # NOTE(review): the config is passed to from_pretrained even when
    # config.use_4bit is False - confirm that this is the intended behavior.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=config.use_4bit,
        bnb_4bit_quant_type=config.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=getattr(torch, config.bnb_4bit_compute_dtype),
        bnb_4bit_use_double_quant=config.use_nested_quant,
    )
    
    # Load processor
    print("   Loading processor...")
    processor = AutoProcessor.from_pretrained(
        config.model_name,
        trust_remote_code=True
    )
    
    # Load tokenizer
    print("   Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        config.model_name,
        trust_remote_code=True
    )
    
    # Ensure special tokens are set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token
    
    # Load model with quantization
    print("   Loading model with 4-bit quantization...")
    model_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
    }
    
    # Add flash attention if available. Availability is decided by whether the
    # optional flash_attn package can be found, since its installation above
    # may have failed.
    if config.use_flash_attention:
        import importlib.util

        if importlib.util.find_spec("flash_attn") is not None:
            model_kwargs["attn_implementation"] = "flash_attention_2"
            print("   Using Flash Attention 2")
        else:
            print("   Flash Attention not available, using default")
    
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        **model_kwargs
    )
    
    # Prepare for k-bit training
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=config.gradient_checkpointing
    )
    
    print(f"\n‚úÖ Model loaded successfully!")
    print(f"   Model dtype: {model.dtype}")
    print(f"   Model device: {model.device}")
    
    # Memory usage
    if torch.cuda.is_available():
        memory_used = torch.cuda.memory_allocated() / 1e9
        memory_total = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"   GPU memory: {memory_used:.1f}GB / {memory_total:.1f}GB")
    
    return model, processor, tokenizer

# Load model
model, processor, tokenizer = load_model_and_processor(config)

## 7️⃣ Apply LoRA Adapters

In [None]:
def _linear_leaf_names(model) -> set:
    """Return the last name component of every ``torch.nn.Linear`` submodule outside ``lm_head``."""
    return {
        name.split('.')[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and 'lm_head' not in name
    }


def apply_lora(model, config: TrainingConfig):
    """Wrap *model* with LoRA adapters and report the trainable parameter share.

    Adapters are attached to those of the usual attention/MLP projection
    layers that actually exist in the model as linear layers; if none is
    found, ``q_proj`` and ``v_proj`` are used as a fallback.

    Args:
        model: The model to adapt.
        config: Supplies ``lora_r``, ``lora_alpha`` and ``lora_dropout``.

    Returns:
        The model returned by ``get_peft_model``.
    """
    
    print("üîß Applying LoRA adapters...")
    
    # Common target modules for transformer models, kept in this order.
    common_targets = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
    
    # One pass over the model collects the linear layer names; candidates are
    # then matched by exact name instead of by substring of any module path.
    linear_names = _linear_leaf_names(model)
    target_modules = [m for m in common_targets if m in linear_names]
    
    if not target_modules:
        target_modules = ['q_proj', 'v_proj']  # Fallback
    
    print(f"   Target modules: {target_modules}")
    
    # LoRA configuration
    lora_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=target_modules,
    )
    
    # Apply LoRA
    model = get_peft_model(model, lora_config)
    
    # Print trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    # Guard the percentage against a model that reports no parameters.
    trainable_percent = 100 * trainable_params / total_params if total_params else 0.0
    
    print(f"\n‚úÖ LoRA applied successfully!")
    print(f"   Trainable parameters: {trainable_params:,} ({trainable_percent:.2f}%)")
    print(f"   Total parameters: {total_params:,}")
    
    return model

# Apply LoRA
model = apply_lora(model, config)

## 8️⃣ Create Data Collator

In [None]:
class DataCollatorForOCR:
    """Collate per-sample feature dicts into one batch dict.

    Tensor features are stacked; when their shapes differ they are first
    padded along dimension 0 to the longest one. Non-tensor features are
    returned as a plain list of the per-sample values.
    """
    
    def __init__(self, tokenizer, pad_token_id: Optional[int] = None):
        """Remember the tokenizer and the id used to pad ``input_ids``.

        Args:
            tokenizer: Object exposing ``pad_token_id``.
            pad_token_id: Explicit pad id; ``None`` means "use the tokenizer's".
                An explicit 0 is honored.
        """
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id
        
    def _pad_value(self, key: str) -> int:
        """Return the fill value used when padding the feature named *key*."""
        if key == 'labels':
            return -100  # ignored by the loss
        if key == 'input_ids':
            return 0 if self.pad_token_id is None else self.pad_token_id
        # attention_mask and any other tensor (e.g. image features) pad with
        # zeros; a token id is meaningless there.
        return 0
        
    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Build the batch for *features*.

        Args:
            features: One dict per sample, as produced by the dataset.

        Returns:
            Dict mapping each feature name to a stacked tensor or, for
            non-tensor values, to the list of per-sample values.
        """
        # Union of keys in first-seen order, so the batch layout is deterministic.
        all_keys = dict.fromkeys(key for f in features for key in f)
        
        batch = {}
        
        for key in all_keys:
            values = [f[key] for f in features if key in f]
            first = values[0]
            
            if not isinstance(first, torch.Tensor):
                batch[key] = values
                continue
            
            if all(v.shape == first.shape for v in values):
                batch[key] = torch.stack(values)
                continue
            
            # Pad tensors to same length along dimension 0
            max_len = max(v.shape[0] for v in values)
            pad_value = self._pad_value(key)
            padded = []
            for v in values:
                pad_size = max_len - v.shape[0]
                if pad_size > 0:
                    filler = torch.full(
                        (pad_size,) + tuple(v.shape[1:]),
                        pad_value,
                        dtype=v.dtype,
                        device=v.device,  # keep filler and data on the same device
                    )
                    v = torch.cat([v, filler])
                padded.append(v)
            batch[key] = torch.stack(padded)
        
        return batch

# Build the collator instance that is handed to the trainer for batching.
data_collator = DataCollatorForOCR(tokenizer=tokenizer)
print("‚úÖ Data collator created")

## 9️⃣ Prepare Training Data

In [None]:
# Split dataset into train/val.
# train_test_split cannot split fewer than two samples, so a tiny dataset is
# used entirely for training and validation is skipped; the training setup
# below already branches on an empty val_samples.
if len(dataset) >= 2:
    train_samples, val_samples = train_test_split(
        dataset, 
        test_size=0.1,  # 10% for validation
        random_state=config.seed
    )
else:
    train_samples, val_samples = list(dataset), []

print(f"üìä Dataset split:")
print(f"   Training samples: {len(train_samples)}")
print(f"   Validation samples: {len(val_samples)}")

# Create datasets
train_dataset = InvoiceOCRDataset(
    samples=train_samples,
    processor=processor,
    tokenizer=tokenizer,
    max_length=config.max_seq_length
)

val_dataset = InvoiceOCRDataset(
    samples=val_samples,
    processor=processor,
    tokenizer=tokenizer,
    max_length=config.max_seq_length
)

print(f"\n‚úÖ Datasets created:")
print(f"   Train dataset size: {len(train_dataset)}")
print(f"   Val dataset size: {len(val_dataset)}")

## 🔟 Setup Training

In [None]:
# Training arguments
import inspect

# The evaluation-schedule argument is spelled "eval_strategy" in newer
# transformers releases and "evaluation_strategy" in older ones. The install
# cell does not cap the version, so use whichever name the installed
# TrainingArguments accepts.
eval_strategy_arg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

# Evaluation and best-checkpoint reloading only make sense with validation data.
has_validation = len(val_samples) > 0

training_args = TrainingArguments(
    output_dir=config.output_dir,
    num_train_epochs=config.num_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    learning_rate=config.learning_rate,
    warmup_ratio=config.warmup_ratio,
    weight_decay=config.weight_decay,
    max_grad_norm=config.max_grad_norm,
    
    # Precision
    fp16=config.fp16,
    bf16=config.bf16,
    
    # Logging
    logging_dir=f"{config.output_dir}/logs",
    logging_steps=config.logging_steps,
    logging_first_step=True,
    
    # Evaluation & Saving
    eval_steps=config.eval_steps,
    save_strategy="steps",
    save_steps=config.save_steps,
    save_total_limit=3,
    load_best_model_at_end=has_validation,
    
    # Optimization
    optim="paged_adamw_32bit",
    lr_scheduler_type="cosine",
    
    # Other
    gradient_checkpointing=config.gradient_checkpointing,
    dataloader_pin_memory=True,
    remove_unused_columns=False,  # keep every key the dataset produces
    report_to="wandb" if config.use_wandb else "tensorboard",
    run_name=f"deepseek-ocr-invoice-{config.seed}",
    seed=config.seed,
    
    # Evaluation schedule, under the argument name resolved above.
    **{eval_strategy_arg: "steps" if has_validation else "no"},
)

print("‚úÖ Training arguments configured")
print(f"   Output directory: {training_args.output_dir}")
print(f"   Learning rate: {training_args.learning_rate}")
print(f"   Epochs: {training_args.num_train_epochs}")
print(f"   Batch size: {training_args.per_device_train_batch_size}")
print(f"   Gradient accumulation: {training_args.gradient_accumulation_steps}")

In [None]:
# Custom trainer for multi-image OCR
import inspect


class DeepSeekOCRTrainer(Trainer):
    """Trainer that computes the causal-LM loss itself from the model's logits."""
    
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the next-token cross-entropy loss for one batch.

        ``labels`` is removed from *inputs* before the forward pass, so the
        model is called without labels. When labels are present the loss is
        computed here on shifted logits/labels with -100 positions ignored;
        otherwise the loss reported by the model output is used.

        Args:
            model: The model being trained.
            inputs: Batch dict from the data collator; mutated (``labels`` popped).
            return_outputs: When true, return ``(loss, outputs)``.
            **kwargs: Extra arguments passed by the base Trainer; unused here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if *return_outputs* is true.
        """
        labels = inputs.pop("labels", None)
        
        # Forward pass
        outputs = model(**inputs)
        
        if labels is not None:
            # Get logits
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
            
            # Shift for causal LM: position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            # With device_map="auto" the logits may live on a different device
            # than the batch, so move the labels next to them.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            
            # Compute loss
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
        else:
            loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0]
        
        return (loss, outputs) if return_outputs else loss

# Newer transformers releases take the tokenizer as "processing_class"; older
# ones as "tokenizer". Use the name the installed Trainer accepts.
tokenizer_arg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

# Create trainer
trainer = DeepSeekOCRTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset if len(val_samples) > 0 else None,
    data_collator=data_collator,
    **{tokenizer_arg: tokenizer},
)

print("‚úÖ Trainer created and ready!")

## 1️⃣1️⃣ Train the Model! 🚀

In [None]:
# Start a Weights & Biases run first when the configuration asks for one.
if config.use_wandb:
    import wandb

    wandb_settings = {
        "project": config.wandb_project,
        "name": f"deepseek-ocr-{config.seed}",
        "config": vars(config),
    }
    wandb.init(**wandb_settings)

divider = "=" * 50
print("üöÄ Starting training...")
print(divider)

# Release cached GPU memory and report what is still allocated before training.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f"GPU memory before training: {allocated_gb:.2f} GB")

# Run the training loop; the result carries the step count and mean loss.
train_result = trainer.train()

print("\n" + divider)
print("‚úÖ Training completed!")
print(f"   Total steps: {train_result.global_step}")
print(f"   Training loss: {train_result.training_loss:.4f}")

## 1️⃣2️⃣ Save the Model

In [None]:
# Persist the trained adapters, the tokenizer and a small metrics file.
print("üíæ Saving model...")

# The LoRA adapters and the tokenizer share one sub-directory of the output dir.
lora_output_dir = f"{config.output_dir}/lora_adapters"
for saveable in (trainer.model, tokenizer):
    saveable.save_pretrained(lora_output_dir)

print(f"\n‚úÖ Model saved to: {lora_output_dir}")
print("\nüìÅ Saved files:")
for saved_path in Path(lora_output_dir).iterdir():
    print(f"   - {saved_path.name}")

# Record the headline numbers of the run next to the checkpoints.
training_metrics = {
    "training_loss": train_result.training_loss,
    "global_step": train_result.global_step,
    "epochs": config.num_epochs,
    "samples": len(train_samples),
}
metrics_file = f"{config.output_dir}/training_metrics.json"
with open(metrics_file, 'w') as metrics_fh:
    json.dump(training_metrics, metrics_fh, indent=2)
print(f"   - training_metrics.json")

## 1️⃣3️⃣ Test Inference

In [None]:
# Test inference on a sample
print("üß™ Testing inference on a sample...")

# Get a test sample
# NOTE(review): dataset[0] may have been part of the training split, so this
# is a smoke test of the pipeline rather than a held-out evaluation.
test_sample = dataset[0]

# Load images (limit to 3 for testing). The context manager closes each file;
# convert() returns an in-memory image that outlives it.
test_images = []
for img_path in test_sample['images'][:3]:
    with Image.open(img_path) as img:
        test_images.append(img.convert('RGB'))

# The prompt is used exactly as stored in the sample.
prompt = test_sample['prompt']

print(f"\nTest sample info:")
print(f"   Number of images: {len(test_images)}")
print(f"   Prompt length: {len(prompt)} chars")

# Process inputs
model.eval()
with torch.no_grad():
    inputs = processor(
        images=test_images,
        text=prompt,
        return_tensors="pt"
    ).to(model.device)
    
    # Generate
    print("\n‚è≥ Generating output...")
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        do_sample=False,
        num_beams=1,
    )
    
    # Decode
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract just the response part. The generated sequence starts with the
    # input tokens, so slicing them off is exact; searching the decoded text
    # for the prompt is kept only as a fallback when input_ids is unavailable.
    if 'input_ids' in inputs:
        new_tokens = outputs[0][inputs['input_ids'].shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    elif prompt in generated_text:
        response = generated_text.split(prompt)[-1].strip()
    else:
        response = generated_text

print("\nüìÑ Generated output (first 1000 chars):")
print("-" * 50)
print(response[:1000])
print("-" * 50)

# Try to parse as JSON
try:
    parsed = json.loads(response)
    print("\n‚úÖ Output is valid JSON!")
except json.JSONDecodeError:
    print("\n‚ö†Ô∏è Output is not valid JSON (may need more training)")

## 1️⃣4️⃣ Cleanup and Summary

In [None]:
# Final summary: recap the run, list the output locations and show how to
# reload the adapters.
print("\n" + "=" * 60)
print("üéâ TRAINING COMPLETE!")
print("=" * 60)

# The reload snippet passes trust_remote_code=True, matching how this notebook
# loads the base model.
print(f"""
üìä Training Summary:
   - Model: {config.model_name}
   - Training samples: {len(train_samples)}
   - Validation samples: {len(val_samples)}
   - Epochs: {config.num_epochs}
   - Final loss: {train_result.training_loss:.4f}
   - Total steps: {train_result.global_step}

üìÅ Output files:
   - LoRA adapters: {lora_output_dir}
   - Checkpoints: {config.output_dir}
   - Logs: {config.output_dir}/logs

üöÄ To load the fine-tuned model:
   from peft import PeftModel
   base_model = AutoModelForCausalLM.from_pretrained("{config.model_name}", trust_remote_code=True)
   model = PeftModel.from_pretrained(base_model, "{lora_output_dir}")
""")

# Clean up
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"\nüíæ Final GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

# Close wandb if used (wandb was imported by the training cell in that case)
if config.use_wandb:
    wandb.finish()
    print("üìä WandB run finished")