# Fine-tune DeepSeek OCR

## Install Dependencies

In [None]:
%%capture
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2
!pip install jiwer
!pip install einops addict easydict

# Install the fast transfer library
!pip install -q hf_transfer

## Download DeepSeek-OCR Model

In [None]:
# Opt in to the hf_transfer download backend (installed in the setup cell).
# The variable is set before huggingface_hub is imported on the next line.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
from huggingface_hub import snapshot_download
# Download the model repository into ./deepseek_ocr. Later cells import
# deepseek_ocr.modeling_deepseekocr and load the model from this directory.
snapshot_download("unsloth/DeepSeek-OCR", local_dir = "deepseek_ocr")

## Prepare Dataset

In [None]:
from google.colab import drive
drive.mount('/content/drive')


In [None]:
!unzip /content/drive/MyDrive/Deepseek_OCR_FT/UIT_HWDB_line.zip -d /content/
!unzip /content/drive/MyDrive/Deepseek_OCR_FT/UIT_HWDB_paragraph.zip -d /content/
!unzip /content/drive/MyDrive/Deepseek_OCR_FT/UIT_HWDB_word.zip -d /content/

## Imports

In [None]:
import random
import unicodedata
import gc
import torch
import traceback
from peft import PeftModel
from unsloth import FastVisionModel
from transformers import AutoModel
from jiwer import cer


## Generate Training Data JSON

In [None]:
import os
import json
import random
from pathlib import Path

def _collect_split_samples(data_dir):
    """Return ``{"image", "label"}`` records for every labelled image under *data_dir*.

    *data_dir* is expected to contain one sub-folder per writer, each holding a
    ``label.json`` that maps image file names to their transcription. Folders
    without a label file, and labels whose image file is missing, are skipped.
    A missing *data_dir* yields an empty list.
    """
    samples = []
    if not data_dir.exists():
        return samples

    # Sorted so the pre-shuffle order (and therefore the seeded shuffle) is
    # independent of the file system's directory listing order.
    for writer_folder in sorted(data_dir.iterdir()):
        if not writer_folder.is_dir():
            continue

        label_file = writer_folder / "label.json"
        if not label_file.exists():
            continue

        with open(label_file, "r", encoding="utf-8") as f:
            labels = json.load(f)

        for image_name, text in labels.items():
            image_path = writer_folder / image_name
            if image_path.exists():
                samples.append({"image": str(image_path), "label": text})

    return samples


def _apply_limit(limit, total):
    """Translate a sampling rule into a slice cutoff for a list of *total* items.

    ``None`` keeps everything, a ``float`` is a fraction of *total* (truncated),
    and an ``int`` is an absolute cap.
    """
    if limit is None:
        return total
    if isinstance(limit, float):
        return int(total * limit)
    return min(limit, total)


def generate_json(base_path: str, output_dir: str = "."):
    """
    Generate train.json and test.json from UIT_HWDB datasets

    For each dataset the train and test samples are collected, shuffled with a
    fixed seed and truncated according to a per-dataset sampling rule. The
    combined lists are shuffled once more and written as UTF-8 JSON.

    Args:
        base_path: Path containing UIT_HWDB_line, UIT_HWDB_paragraph, UIT_HWDB_word folders
        output_dir: Directory to save train.json and test.json (created,
            including missing parent directories, if it does not exist)

    Returns:
        Tuple ``(train_samples, test_samples)``; each sample is a dict with the
        keys ``"image"`` (path string) and ``"label"`` (transcription).
    """
    base_path = Path(base_path)
    output_dir = Path(output_dir)
    # parents=True: a nested output directory must not fail just because an
    # intermediate directory is missing.
    output_dir.mkdir(parents=True, exist_ok=True)

    # dataset folder -> (train rule, test rule); see _apply_limit for the rules.
    sampling_plan = {
        "UIT_HWDB_line": (0.5, None),        # 50% of train, all of test
        "UIT_HWDB_paragraph": (None, 20),    # all of train, at most 20 test
        "UIT_HWDB_word": (1000, 200),        # at most 1000 train / 200 test
    }

    final_train_data = []
    final_test_data = []

    for dataset_name, (train_rule, test_rule) in sampling_plan.items():
        dataset_path = base_path / dataset_name
        if not dataset_path.exists():
            print(f"Warning: {dataset_path} not found, skipping...")
            continue

        ds_train_data = _collect_split_samples(dataset_path / "train_data")
        ds_test_data = _collect_split_samples(dataset_path / "test_data")

        # A private generator seeded with 42 yields the same shuffle order as
        # random.seed(42) would, without reseeding the process-wide RNG.
        rng = random.Random(42)
        rng.shuffle(ds_train_data)
        rng.shuffle(ds_test_data)

        train_cutoff = _apply_limit(train_rule, len(ds_train_data))
        test_cutoff = _apply_limit(test_rule, len(ds_test_data))

        final_train_data.extend(ds_train_data[:train_cutoff])
        final_test_data.extend(ds_test_data[:test_cutoff])

        print(f"{dataset_name}: Used {train_cutoff}/{len(ds_train_data)} train, {test_cutoff}/{len(ds_test_data)} test samples.")

    # Final shuffle of the combined dataset so the three sources are interleaved
    rng = random.Random(42)
    rng.shuffle(final_train_data)
    rng.shuffle(final_test_data)

    # Save to JSON files (ensure_ascii=False keeps the labels human-readable)
    with open(output_dir / "train.json", "w", encoding="utf-8") as f:
        json.dump(final_train_data, f, ensure_ascii=False, indent=2)

    with open(output_dir / "test.json", "w", encoding="utf-8") as f:
        json.dump(final_test_data, f, ensure_ascii=False, indent=2)

    print(f"train.json: {len(final_train_data)} samples")
    print(f"test.json: {len(final_test_data)} samples")

    return final_train_data, final_test_data

# Generate the JSON files
# Using current directory as base_path since /content does not exist locally
train_data, test_data = generate_json(
    base_path=".",
    output_dir="."
)

## Convert to Conversation Format

We need to convert the dataset into a conversation format that the model understands. Each sample will be a conversation between a user (providing the image) and an assistant (providing the text label).

In [None]:
from PIL import Image
# Prompt sent with every image; "<image>" marks where the picture is spliced in.
instruction = "<image>\nFree OCR. "


def convert_to_conversation(sample):
    """Wrap one ``{"image": ..., "label": ...}`` record as a two-turn chat.

    The user turn carries the shared ``instruction`` prompt plus the image,
    the assistant turn carries the expected transcription. Returns
    ``{"messages": [user_turn, assistant_turn]}``.
    """
    user_turn = {
        "role": "<|User|>",
        "content": instruction,
        # Only the path string is stored here; the image file itself is not
        # opened by this function.
        "images": [sample['image']],
    }
    assistant_turn = {
        "role": "<|Assistant|>",
        "content": sample["label"],
    }
    return {"messages": [user_turn, assistant_turn]}

In [None]:
import json

# Read the training samples from train.json into `dataset`.
with open('train.json', 'r', encoding='utf-8') as f:
    dataset = json.load(f)

# The message names the variable that actually holds the data.
print(f"Loaded {len(dataset)} samples into dataset.")

In [None]:
# Wrap every raw sample in the conversation layout used for training.
converted_dataset = [convert_to_conversation(record) for record in dataset]

print(f"Converted {len(converted_dataset)} samples to conversation format.")
# The raw list is not needed once the converted copy exists.
del dataset

In [None]:
converted_dataset[0]

## Data Collator

We need a custom data collator to handle the image and text inputs for the DeepSeek-OCR model. This collator processes the images (resizing, cropping) and tokenizes the text.

In [None]:
# @title Create datacollator

import torch
import math
from dataclasses import dataclass
from typing import Dict, List, Any, Tuple
from PIL import Image, ImageOps
from torch.nn.utils.rnn import pad_sequence
import io

from deepseek_ocr.modeling_deepseekocr import (
    format_messages,
    text_encode,
    BasicImageTransform,
    dynamic_preprocess,
)

@dataclass
class DeepSeekOCRDataCollator:
    """
    Collate conversation samples into padded training batches for DeepSeek-OCR.

    Each feature passed to ``__call__`` must be ``{"messages": [...]}`` where a
    message has a ``role``, a ``content`` string that may contain ``<image>``
    tags and, optionally, an ``images`` list (file path, PIL image, or a dict
    with a ``bytes`` entry).

    Args:
        tokenizer: Tokenizer
        model: Model (its ``dtype`` is used for the image tensors)
        image_size: Size for image patches (default: 640)
        base_size: Size for global view (default: 1024)
        crop_mode: Whether to use dynamic cropping for large images
        train_on_responses_only: If True, only train on assistant responses (mask user prompts)
        image_token_id: Token id written at every image placeholder position
            (default: 128815)
    """
    tokenizer: Any
    model: Any
    image_size: int = 640
    base_size: int = 1024
    crop_mode: bool = True
    image_token_id: int = 128815
    train_on_responses_only: bool = True

    # NOTE: @dataclass leaves an __init__ defined in the class body untouched,
    # so this hand-written constructor is the one that runs. It is needed
    # because it derives attributes that are not dataclass fields.
    def __init__(
        self,
        tokenizer,
        model,
        image_size: int = 640,
        base_size: int = 1024,
        crop_mode: bool = True,
        train_on_responses_only: bool = True,
        image_token_id: int = 128815,
    ):
        self.tokenizer = tokenizer
        self.model = model
        self.image_size = image_size
        self.base_size = base_size
        self.crop_mode = crop_mode
        self.image_token_id = image_token_id
        self.dtype = model.dtype  # Get dtype from model
        self.train_on_responses_only = train_on_responses_only

        self.image_transform = BasicImageTransform(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            normalize=True
        )
        self.patch_size = 16
        self.downsample_ratio = 4

        # Get BOS token ID from tokenizer
        bos_id = getattr(tokenizer, 'bos_token_id', None)
        if bos_id is not None:
            self.bos_id = bos_id
        else:
            self.bos_id = 0
            print(f"Warning: tokenizer has no bos_token_id, using default: {self.bos_id}")

    def deserialize_image(self, image_data) -> "Image.Image":
        """Convert image data (bytes dict, PIL Image, or str path) to PIL Image in RGB mode

        Raises:
            ValueError: If *image_data* is none of the supported forms.
        """
        if isinstance(image_data, str):
            return Image.open(image_data).convert("RGB")
        elif isinstance(image_data, Image.Image):
            return image_data.convert("RGB")
        elif isinstance(image_data, dict) and 'bytes' in image_data:
            image_bytes = image_data['bytes']
            image = Image.open(io.BytesIO(image_bytes))
            return image.convert("RGB")
        else:
            raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _num_queries(self, size: int) -> int:
        """Return the number of placeholder positions per side for a square view of *size* pixels."""
        return math.ceil((size // self.patch_size) / self.downsample_ratio)

    def calculate_image_token_count(self, image: "Image.Image", crop_ratio: Tuple[int, int]) -> int:
        """Calculate the number of tokens this image will generate

        The count matches the placeholder sequence built by ``process_image``:
        the global view contributes ``q`` rows of ``q + 1`` positions plus one
        trailing position (``q`` derived from ``base_size``), and, in crop mode
        with more than one tile, the local views contribute
        ``(q_local * width_crop_num + 1) * (q_local * height_crop_num)``
        positions (``q_local`` derived from ``image_size``).

        Args:
            image: Unused; kept for interface compatibility.
            crop_ratio: ``(width_crop_num, height_crop_num)`` tile grid.
        """
        base_queries = self._num_queries(self.base_size)
        # Each row carries one extra position and the view ends with one more
        # (presumably row-break / view-separator slots - confirm against the
        # model code); all of them use image_token_id.
        img_tokens = (base_queries + 1) * base_queries + 1

        if self.crop_mode:
            width_crop_num, height_crop_num = crop_ratio
            if width_crop_num > 1 or height_crop_num > 1:
                local_queries = self._num_queries(self.image_size)
                img_tokens += (local_queries * width_crop_num + 1) * (local_queries * height_crop_num)

        return img_tokens

    def process_image(self, image: "Image.Image") -> Tuple[List, List, List, List, Tuple[int, int]]:
        """
        Process a single image based on crop_mode and size thresholds

        Returns:
            Tuple of (images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio)
        """
        images_list = []
        images_crop_list = []
        images_spatial_crop = []
        pad_color = tuple(int(x * 255) for x in self.image_transform.mean)

        if self.crop_mode:
            # Determine crop ratio based on image size
            # NOTE(review): the 640 threshold is a literal rather than
            # self.image_size; presumably it mirrors the model's own inference
            # path - confirm before changing it.
            if image.size[0] <= 640 and image.size[1] <= 640:
                crop_ratio = (1, 1)
                images_crop_raw = []
            else:
                images_crop_raw, crop_ratio = dynamic_preprocess(
                    image, min_num=2, max_num=9,
                    image_size=self.image_size, use_thumbnail=False
                )

            # Process global view with padding
            global_view = ImageOps.pad(
                image, (self.base_size, self.base_size), color=pad_color
            )
            images_list.append(self.image_transform(global_view).to(self.dtype))

            width_crop_num, height_crop_num = crop_ratio
            images_spatial_crop.append([width_crop_num, height_crop_num])

            # Process local views (crops) if applicable
            if width_crop_num > 1 or height_crop_num > 1:
                for crop_img in images_crop_raw:
                    images_crop_list.append(
                        self.image_transform(crop_img).to(self.dtype)
                    )

        else:  # crop_mode = False
            crop_ratio = (1, 1)
            images_spatial_crop.append([1, 1])

            # For smaller base sizes, resize; for larger, pad
            if self.base_size <= 640:
                resized_image = image.resize((self.base_size, self.base_size), Image.LANCZOS)
                images_list.append(self.image_transform(resized_image).to(self.dtype))
            else:
                global_view = ImageOps.pad(
                    image, (self.base_size, self.base_size), color=pad_color
                )
                images_list.append(self.image_transform(global_view).to(self.dtype))

        # Every placeholder position holds the same id, so the sequence is
        # fully described by its length; using the shared counter keeps the
        # two methods in agreement.
        tokenized_image = [self.image_token_id] * self.calculate_image_token_count(image, crop_ratio)

        return images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio

    def process_single_sample(self, messages: List[Dict]) -> Dict[str, Any]:
        """
        Process a single conversation into model inputs.

        Returns:
            Dict with ``input_ids``, ``images_seq_mask`` (True at image
            placeholder positions), ``images_ori``, ``images_crop``,
            ``images_spatial_crop`` and ``prompt_token_count`` (number of
            leading tokens that precede the first assistant message).

        Raises:
            ValueError: If the sample has no images, or the number of
                ``<image>`` tags does not match the number of images.
        """
        # --- 1. Setup: gather every image referenced by the conversation ---
        images = []
        for message in messages:
            if "images" in message and message["images"]:
                for img_data in message["images"]:
                    if img_data is not None:
                        images.append(self.deserialize_image(img_data))

        if not images:
            raise ValueError("No images found in sample. Please ensure all samples contain images.")

        images_list, images_crop_list, images_spatial_crop = [], [], []

        prompt_token_count = -1  # Index to start training
        assistant_started = False
        image_idx = 0

        # Add BOS token at the very beginning
        tokenized_str = [self.bos_id]
        images_seq_mask = [False]

        # --- 2. Tokenize messages, splicing image placeholders at <image> ---
        for message in messages:
            role = message["role"]
            content = message["content"]

            # Check if this is the assistant's turn
            if role == "<|Assistant|>":
                if not assistant_started:
                    # This is the split point. All tokens added *so far*
                    # are part of the prompt.
                    prompt_token_count = len(tokenized_str)
                    assistant_started = True

                # Append the EOS token string to the *end* of assistant content
                content = f"{content.strip()} {self.tokenizer.eos_token}"

            # Split this message's content by the image token
            text_splits = content.split('<image>')

            for i, text_sep in enumerate(text_splits):
                # Tokenize the text part
                tokenized_sep = text_encode(self.tokenizer, text_sep, bos=False, eos=False)
                tokenized_str.extend(tokenized_sep)
                images_seq_mask.extend([False] * len(tokenized_sep))

                # If this text is followed by an <image> tag
                if i < len(text_splits) - 1:
                    if image_idx >= len(images):
                        raise ValueError(
                            f"Data mismatch: Found '<image>' token but no corresponding image."
                        )

                    # Process the image
                    img_list, crop_list, spatial_crop, tok_img, _ = self.process_image(images[image_idx])

                    images_list.extend(img_list)
                    images_crop_list.extend(crop_list)
                    images_spatial_crop.extend(spatial_crop)

                    # Add image placeholder tokens
                    tokenized_str.extend(tok_img)
                    images_seq_mask.extend([True] * len(tok_img))

                    image_idx += 1  # Move to the next image

        # --- 3. Validation and Final Prep ---
        if image_idx != len(images):
            raise ValueError(
                f"Data mismatch: Found {len(images)} images but only {image_idx} '<image>' tokens were used."
            )

        # If we never found an assistant message, we're in a weird state
        # (e.g., user-only prompt). We mask everything.
        if not assistant_started:
            print("Warning: No assistant message found in sample. Masking all tokens.")
            prompt_token_count = len(tokenized_str)

        # Prepare image tensors
        images_ori = torch.stack(images_list, dim=0)
        images_spatial_crop_tensor = torch.tensor(images_spatial_crop, dtype=torch.long)

        if images_crop_list:
            images_crop = torch.stack(images_crop_list, dim=0)
        else:
            # All-zero stand-in used when the sample produced no local crops.
            images_crop = torch.zeros((1, 3, self.base_size, self.base_size), dtype=self.dtype)

        return {
            "input_ids": torch.tensor(tokenized_str, dtype=torch.long),
            "images_seq_mask": torch.tensor(images_seq_mask, dtype=torch.bool),
            "images_ori": images_ori,
            "images_crop": images_crop,
            "images_spatial_crop": images_spatial_crop_tensor,
            "prompt_token_count": prompt_token_count,
        }

    def _pad_token_id(self) -> int:
        """Return the id used to fill padding positions.

        Falls back to the tokenizer's EOS id, then to 0, when no pad id is
        set. The value never affects the loss or the attention mask because
        both are derived from sequence lengths.
        """
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        return 0 if pad_id is None else pad_id

    def _collate_processed(self, batch_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pad already-processed samples and build labels and attention mask."""
        input_ids_list = [item['input_ids'] for item in batch_data]
        images_seq_mask_list = [item['images_seq_mask'] for item in batch_data]

        # Pad sequences
        input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=self._pad_token_id())
        images_seq_mask = pad_sequence(images_seq_mask_list, batch_first=True, padding_value=False)

        # Real-token mask derived from the sequence lengths. Comparing
        # input_ids with the pad id instead would also hide genuine tokens
        # whenever the pad id equals a real id (e.g. pad == EOS), so the model
        # would never be trained to emit that token.
        lengths = torch.tensor([ids.size(0) for ids in input_ids_list], device=input_ids.device)
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        is_real_token = positions < lengths.unsqueeze(1)

        # Create labels: -100 marks positions excluded from training
        labels = input_ids.clone()
        labels[~is_real_token] = -100

        # Mask image tokens (model shouldn't predict these)
        labels[images_seq_mask] = -100

        # Mask user prompt tokens when train_on_responses_only=True (only train on assistant responses)
        if self.train_on_responses_only:
            for idx, item in enumerate(batch_data):
                prompt_count = item['prompt_token_count']
                if prompt_count > 0:
                    labels[idx, :prompt_count] = -100

        # Prepare images batch (list of (crops, global view) tuples)
        images_batch = [(item['images_crop'], item['images_ori']) for item in batch_data]

        # Stack spatial crop info
        images_spatial_crop = torch.cat([item['images_spatial_crop'] for item in batch_data], dim=0)

        return {
            "input_ids": input_ids,
            "attention_mask": is_real_token.long(),
            "labels": labels,
            "images": images_batch,
            "images_seq_mask": images_seq_mask,
            "images_spatial_crop": images_spatial_crop,
        }

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate batch of samples

        Samples that fail to process are reported and dropped (best effort).

        Raises:
            ValueError: If no sample in the batch could be processed.
        """
        batch_data = []

        # Process each sample
        for feature in features:
            try:
                batch_data.append(self.process_single_sample(feature['messages']))
            except Exception as e:
                print(f"Error processing sample: {e}")
                continue

        if not batch_data:
            raise ValueError("No valid samples in batch")

        return self._collate_processed(batch_data)

## Evaluation Helper Functions

We define these functions early so we can use them during hyperparameter tuning.

In [None]:
from tqdm.auto import tqdm

def normalize_text(text):
    """
    Bring a transcription into a canonical form so CER comparisons are fair.

    - Non-string input is treated as an empty transcription.
    - Characters are composed to Unicode NFC, so a base letter followed by
      combining marks compares equal to its precomposed form.
    - Leading/trailing whitespace is dropped and every internal run of
      whitespace (spaces, newlines, tabs) becomes a single space.
    """
    if isinstance(text, str):
        composed = unicodedata.normalize("NFC", text)
        # split() with no argument discards all whitespace runs, including
        # newlines and the leading/trailing ones.
        return " ".join(composed.split())
    return ""

def evaluate_by_category(model, tokenizer, test_data, num_samples_per_category=None, num_to_show = 3, instruction="<image>\nFree OCR. "):
    """
    Compute the character error rate (CER) separately for word, line and
    paragraph samples.

    A sample belongs to a category when the category's dataset folder name
    occurs in its image path.

    Args:
        model: Model passed through to ``generate_prediction``.
        tokenizer: Tokenizer passed through to ``generate_prediction``.
        test_data: List of ``{"image": path, "label": text}`` dicts.
        num_samples_per_category: If given, evaluate at most this many
            samples per category, drawn with a fixed seed; ``None`` evaluates
            every sample.
        num_to_show: Number of leading samples per category to print.
        instruction: Prompt passed through to ``generate_prediction``.

    Returns:
        Dict mapping category name to its CER, or to ``None`` when no
        prediction succeeded. Categories without any sample are omitted.
    """
    categories = {
        "Word": "UIT_HWDB_word",
        "Line": "UIT_HWDB_line",
        "Paragraph": "UIT_HWDB_paragraph"
    }

    results = {}

    for category_name, folder_identifier in categories.items():
        print(f"\n{'='*20} Evaluating {category_name} Level {'='*20}")

        # Filter test data for this category based on the folder name in the path
        category_samples = [s for s in test_data if folder_identifier in s["image"]]

        if not category_samples:
            print(f"No samples found for {category_name} (looking for '{folder_identifier}' in paths).")
            continue

        if num_samples_per_category is None or len(category_samples) <= num_samples_per_category:
            eval_samples = category_samples
        else:
            # Reproducible random sampling. A private generator seeded with 42
            # draws the same subset as reseeding the global RNG would, without
            # disturbing other users of the `random` module.
            eval_samples = random.Random(42).sample(category_samples, num_samples_per_category)

        print(f"Found {len(category_samples)} total. Evaluating random {len(eval_samples)}...")

        cat_predictions = []
        cat_references = []

        for i, sample in tqdm(enumerate(eval_samples), total=len(eval_samples), desc=f"Processing {category_name}"):
            image_path = sample["image"]
            ground_truth = normalize_text(sample["label"])

            try:
                # Call generate_prediction
                raw_prediction = generate_prediction(model, tokenizer, image_path, instruction)

                predicted_text = normalize_text(raw_prediction)

                # Appended together so predictions and references stay aligned
                # even when some samples fail.
                cat_predictions.append(predicted_text)
                cat_references.append(ground_truth)

                # Print first few examples to verify
                # use tqdm.write to prevent the print from breaking the progress bar layout
                if i < num_to_show:
                    tqdm.write(f"\nSample {i+1}:")
                    tqdm.write(f"  Image: {image_path}")
                    tqdm.write(f"  Ground Truth: {ground_truth}")
                    tqdm.write(f"  Prediction: {predicted_text}")
            except Exception as e:
                # Best effort: a failing sample is reported and skipped so one
                # bad image does not abort the whole evaluation.
                tqdm.write(f"Error processing sample {i+1} ({image_path}): {e}")
                if "out of memory" in str(e).lower():
                    torch.cuda.empty_cache()

        # Calculate CER for this category
        if cat_references and cat_predictions:
            category_cer = cer(cat_references, cat_predictions)
            print(f"\n>>> {category_name} CER: {category_cer:.4f}")
            results[category_name] = category_cer
        else:
            print(f"\n>>> {category_name}: No successful predictions to calculate CER.")
            results[category_name] = None

    return results

def load_and_evaluate_model(model_path, run_name):
    """
    Load a saved model from *model_path* and evaluate it on the global
    ``test_data`` with ``evaluate_by_category``.

    Args:
        model_path: Directory (or identifier) handed to
            ``FastVisionModel.from_pretrained``.
        run_name: Label used in the log output only.

    Returns:
        The per-category result dict, or ``None`` when loading or evaluation
        raised (the error and its traceback are printed).
    """
    print(f"\n{'#'*40}")
    print(f"Evaluating: {run_name}")
    print(f"Path: {model_path}")
    print(f"{'#'*40}\n")

    # Clear cache before loading new model. Collect garbage first so that
    # tensors freed by the collector can be released by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()

    model = tokenizer = None
    try:
        # Load Model and Adapter together using Unsloth
        print(f"Loading model from {model_path}...")
        model, tokenizer = FastVisionModel.from_pretrained(
            model_path,
            load_in_4bit = True,
            auto_model = AutoModel,
            trust_remote_code=True,
            unsloth_force_compile=True,
            use_gradient_checkpointing = "unsloth",
        )

        FastVisionModel.for_inference(model)

        # Run evaluation
        return evaluate_by_category(model, tokenizer, test_data)
    except Exception as e:
        print(f"Failed to load or evaluate {run_name}: {e}")
        traceback.print_exc()
        return None
    finally:
        # Cleanup to free VRAM for the next model. In `finally` so the memory
        # is released on the failure path too.
        del model, tokenizer
        gc.collect()
        torch.cuda.empty_cache()

## Prediction Generation

Define a function to generate OCR predictions from the model. This function handles image processing and inference.

In [None]:
import io
from contextlib import redirect_stdout

# Load test data (an empty list is used when test.json has not been generated)
test_json_path = 'test.json'
if not os.path.exists(test_json_path):
    print("Warning: test.json not found. Please generate it first.")
    test_data = []
else:
    with open(test_json_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)
    print(f"Loaded {len(test_data)} test samples.")

# Lines the model prints to stdout during inference that are not OCR output.
_DEBUG_LINE_PREFIXES = ("BASE:  torch.Size", "PATCHES:  torch.Size")
_DEBUG_SEPARATOR = "====================="


def _strip_debug_lines(captured):
    """Remove tensor-size debug lines and separator lines from captured stdout."""
    kept = [
        line for line in captured.strip().split('\n')
        if not line.strip().startswith(_DEBUG_LINE_PREFIXES)
        and line.strip() != _DEBUG_SEPARATOR
    ]
    return "\n".join(kept).strip()


def generate_prediction(model, tokenizer, image_path, instruction="<image>\nFree OCR. "):
    """
    Generate OCR prediction using the model.infer method.

    Args:
        model: Object exposing ``infer(tokenizer, prompt=..., image_file=..., ...)``.
        tokenizer: Tokenizer forwarded to ``model.infer``.
        image_path: Path of the image to transcribe.
        instruction: Prompt forwarded to ``model.infer``.

    Returns:
        The value returned by ``model.infer`` (its first element when it is a
        non-empty list). When ``infer`` returns ``None`` or an empty list, the
        text it printed to stdout, with debug lines removed, is returned
        instead.
    """
    # Capture stdout because model.infer might print the result instead of returning it
    buffer = io.StringIO()

    # Use inference_mode to reduce memory usage
    with torch.inference_mode(), redirect_stdout(buffer):
        result = model.infer(
            tokenizer,
            prompt=instruction,
            image_file=image_path,
            output_path='.',
            base_size=1024,
            image_size=640,
            crop_mode=True,
            save_results=False, # Don't save to disk for evaluation loop
            test_compress=False
        )

    # If result is a list, take the first element; an empty list carries no
    # prediction, so it is treated like None and the printed text is used.
    if isinstance(result, list):
        result = result[0] if result else None

    if result is not None:
        return result

    return _strip_debug_lines(buffer.getvalue())

## Load Base Model

Load the base model to ensure a clean state before starting the final fine-tuning process with the selected hyperparameters.

In [None]:
# Read by Unsloth from the environment; presumably silences its warning about
# uninitialized weights - TODO confirm against the Unsloth documentation.
os.environ["UNSLOTH_WARN_UNINITIALIZED"] = '0'

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
# NOTE(review): this list is not referenced anywhere below; the model that is
# actually loaded is the local ./deepseek_ocr snapshot.
fourbit_models = [
    "unsloth/Qwen3-VL-8B-Instruct-bnb-4bit", # Qwen 3 vision support
    "unsloth/Qwen3-VL-8B-Thinking-bnb-4bit",
    "unsloth/Qwen3-VL-32B-Instruct-bnb-4bit",
    "unsloth/Qwen3-VL-32B-Thinking-bnb-4bit",
] # More models at https://huggingface.co/unsloth

import torch
import gc

# Cleanup previous model if it exists (the cell may be re-run in a session
# that already holds a model), so both copies are not kept alive at once.
if 'model' in locals():
    del model
if 'tokenizer' in locals():
    del tokenizer
torch.cuda.empty_cache()
gc.collect()

print("Reloading base model for final training...")
# Load the downloaded DeepSeek-OCR snapshot in 4-bit. trust_remote_code is
# required because the model code lives inside the ./deepseek_ocr directory.
model, tokenizer = FastVisionModel.from_pretrained(
    "./deepseek_ocr",
    load_in_4bit = True,
    auto_model = AutoModel,
    trust_remote_code=True,
    unsloth_force_compile=True,
    use_gradient_checkpointing = "unsloth",
)

## Inference Example

```python
from PIL import Image
from IPython.display import display

prompt = "<image>\nFree OCR "
image_file = '/content/UIT_HWDB_word/train_data/1/1.jpg'
output_path ='.'

image = Image.open(image_file)
display(image)

res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = False)
```

## Setup LoRA for Fine-tuning

In [None]:
# Attach LoRA adapters to the loaded model; `model` is rebound to the wrapped
# model. Adapters are added to the modules named in target_modules (by name,
# presumably the attention q/k/v/o and MLP gate/up/down projections).
model = FastVisionModel.get_peft_model(
    model,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],

    r = 16,           # The larger, the higher the accuracy, but might overfit
    lora_alpha = 16,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
    # target_modules = "all-linear", # Optional now! Can specify a list if needed
)

## Save Initialized Model

Save the model with the initialized LoRA adapters before training starts. This serves as a baseline.

In [None]:
# Save the model before training (initialized LoRA adapters)
# The final evaluation cell loads this directory as the "Before Fine-tuning"
# baseline.
os.makedirs("saved_models/before_finetune", exist_ok=True)
model.save_pretrained("saved_models/before_finetune")
tokenizer.save_pretrained("saved_models/before_finetune")
print("Saved initialized model to 'saved_models/before_finetune'")

## Training

In [None]:
from transformers import Trainer, TrainingArguments
from unsloth import is_bf16_supported

# Checkpoints are written to Google Drive when it is mounted (Colab) so they
# persist; otherwise they go to the local "outputs" directory.
if os.path.exists("/content/drive/MyDrive"):
    output_dir = "/content/drive/MyDrive/deepseek_ocr_checkpoints"
    print(f"Saving checkpoints to Google Drive: {output_dir}")
else:
    output_dir = "outputs"

FastVisionModel.for_training(model) # Enable for training!

# The collator turns conversation samples into model inputs on the fly.
data_collator = DeepSeekOCRDataCollator(
    tokenizer=tokenizer,
    model=model,
    image_size=640,
    base_size=1024,
    crop_mode=True,
    train_on_responses_only=True,
)

# Pick the mixed-precision mode once: bf16 when supported, fp16 otherwise.
use_bf16 = is_bf16_supported()

training_args = TrainingArguments(
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    warmup_steps = 5,
    num_train_epochs = 1, # One full pass; use max_steps instead for a short run
    learning_rate = 2e-4,
    logging_steps = 1,
    optim = "adamw_8bit",
    weight_decay = 0.001,
    lr_scheduler_type = "linear",
    seed = 3407,
    fp16 = not use_bf16,
    bf16 = use_bf16,
    output_dir = output_dir,
    report_to = "none",     # For Weights and Biases
    dataloader_num_workers=2,
    save_steps = 20, # 0 for not saving
    save_total_limit = 2,
    # Required for vision finetuning: the collator reads the "messages"
    # column, which must not be dropped as "unused".
    remove_unused_columns = False,
)

trainer = Trainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = data_collator, # Must use!
    train_dataset = converted_dataset,
    args = training_args,
)

## Start Training

Begin the fine-tuning process using the configured trainer.

In [None]:
# Resume from the newest checkpoint in the trainer's output directory when one
# exists; otherwise start a fresh run. Passing resume_from_checkpoint=True
# unconditionally raises on a first run because there is no checkpoint to load.
from transformers.trainer_utils import get_last_checkpoint

checkpoint_root = trainer.args.output_dir
# get_last_checkpoint lists the directory, so it must exist before the call.
last_checkpoint = get_last_checkpoint(checkpoint_root) if os.path.isdir(checkpoint_root) else None

if last_checkpoint is None:
    print(f"No checkpoint found in '{checkpoint_root}'; starting training from scratch.")
else:
    print(f"Resuming training from checkpoint: {last_checkpoint}")

trainer_stats = trainer.train(resume_from_checkpoint=last_checkpoint)

## Save Fine-Tuned Model

Save the model and tokenizer after fine-tuning is complete.

In [None]:
# Save the fine-tuned model locally, and to Google Drive when it is mounted.
# The Drive copy is guarded the same way as the checkpoint directory so this
# cell also works outside Colab.
save_targets = ["saved_models/after_finetune"]
if os.path.exists("/content/drive/MyDrive"):
    save_targets.append("/content/drive/MyDrive/deepseek_ocr_checkpoints/after_finetune")

for save_target in save_targets:
    os.makedirs(save_target, exist_ok=True)
    model.save_pretrained(save_target)
    tokenizer.save_pretrained(save_target)
    print(f"Saved fine-tuned model to '{save_target}'")

## Final Evaluation

Compare the performance of the model before and after fine-tuning using the Character Error Rate (CER) metric.

In [None]:
# 1. Evaluate Before Fine-tuning
results_before = load_and_evaluate_model("saved_models/before_finetune", "Before Fine-tuning")

# 2. Evaluate After Fine-tuning
# Use the first location that actually holds the fine-tuned model: the local
# save, then the Drive save, then the previously hard-coded Drive folder
# (which the save cell does not write to).
after_finetune_candidates = [
    "saved_models/after_finetune",
    "/content/drive/MyDrive/deepseek_ocr_checkpoints/after_finetune",
    "/content/drive/MyDrive/Deepseek_OCR_FT/after_finetune",
]
after_finetune_path = next(
    (path for path in after_finetune_candidates if os.path.isdir(path)),
    after_finetune_candidates[0],
)
results_after = load_and_evaluate_model(after_finetune_path, "After Fine-tuning")


def _format_cer(value):
    """Render a CER value with four decimals; non-floats (None, "N/A") as-is."""
    return f"{value:.4f}" if isinstance(value, float) else str(value)


# 3. Print Comparison
print("\n" + "="*40)
print("FINAL COMPARISON (CER - Lower is Better)")
print("="*40)
categories = ["Word", "Line", "Paragraph"]
print(f"{'Category':<15} | {'Before':<10} | {'After':<10} | {'Improvement':<10}")
print("-" * 55)

for cat in categories:
    # A run that failed returns None; a category that was not evaluated is
    # absent from the dict. Both are shown as "N/A".
    val_before = results_before.get(cat, "N/A") if results_before else "N/A"
    val_after = results_after.get(cat, "N/A") if results_after else "N/A"

    # Positive improvement means the CER went down after fine-tuning.
    if isinstance(val_before, float) and isinstance(val_after, float):
        diff_str = f"{val_before - val_after:+.4f}"
    else:
        diff_str = "N/A"

    print(f"{cat:<15} | {_format_cer(val_before):<10} | {_format_cer(val_after):<10} | {diff_str:<10}")