In [None]:
# Mongolian OCR Training on Google Colab
# Run each cell in order by clicking the play button or pressing Shift+Enter

# ===== CELL 1: Install Dependencies =====
# NOTE(review): this installs whatever kraken release is current (unpinned).
# The two pinned `pip install` cells further down reinstall kraken/torch at
# specific versions, so the version that ends up active depends on which of
# those cells ran last. Consider pinning here and keeping a single install cell.
print("Installing Kraken and dependencies...")
!pip install -q kraken pillow


In [None]:

# ===== CELL 2: Mount Google Drive =====
# Attach the user's Google Drive to this Colab runtime so the dataset and the
# checkpoints kept under MyDrive can be read and written.
from google.colab import drive

drive.mount('/content/drive')
print("\nGoogle Drive mounted!", "Your files should be in /content/drive/MyDrive/", sep="\n")



In [None]:
# --- Fix for torch/torchvision/lightning version mismatch ---
# NOTE(review): pins kraken 4.3.13 with torch 2.3.1, but the next cell reinstalls
# torch 2.4.0 / kraken 6.0.2 with --upgrade, so running both cells in order
# leaves this install overwritten. Keep only the combination you intend to use.
# If torch was already imported in this runtime, the new version only takes
# effect after a runtime restart (already-imported modules are not reloaded).
!pip install --upgrade torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 lightning==2.4.0 torchmetrics==1.4.0 kraken==4.3.13


In [None]:
# Pin torch/torchvision/torchaudio, lightning, torchmetrics and kraken to one
# combination. When run after the previous cell, these pins replace its pins.
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 \
    lightning==2.4.0 torchmetrics==1.4.0 kraken==6.0.2 --upgrade --quiet


In [None]:
# Sanity check after the installs: show the library versions actually loaded
# and whether PyTorch can see a GPU.
import torch
import torchvision
import kraken

# Each value is fetched lazily so lines are printed one at a time, in order.
_version_report = (
    ("Torch:", lambda: torch.__version__),
    ("Torchvision:", lambda: torchvision.__version__),
    ("Kraken:", lambda: kraken.__version__),
    ("CUDA available:", lambda: torch.cuda.is_available()),
)
for _label, _get in _version_report:
    print(_label, _get())
del _version_report, _label, _get


In [None]:
# ===== CELL 5: Verify files =====
# Check that the Drive dataset has a .gt.txt transcription for every .png line
# image (kraken's "path" training format pairs `name.png` with `name.gt.txt`).
import glob
import os

# Later cells use paths relative to /content.
os.chdir('/content')
images_dir = '/content/drive/MyDrive/synthetic_mongolian_large_images/images'
images = glob.glob(os.path.join(images_dir, '*.png'))
gt_files = glob.glob(os.path.join(images_dir, '*.gt.txt'))

print(f"Found {len(images)} PNG images")
print(f"Found {len(gt_files)} ground truth files")

# Pair files by name, not just by count: equal counts can still hide images
# without a transcription and transcriptions without an image.
image_stems = {os.path.basename(p)[:-len('.png')] for p in images}
gt_stems = {os.path.basename(p)[:-len('.gt.txt')] for p in gt_files}
missing_gt = sorted(image_stems - gt_stems)
orphan_gt = sorted(gt_stems - image_stems)

if not images:
    print("WARNING: No PNG images found - is Google Drive mounted and the path correct?")
if len(images) != len(gt_files):
    print("WARNING: Number of images and ground truth files don't match!")
if missing_gt:
    print(f"WARNING: {len(missing_gt)} image(s) without a .gt.txt file, e.g. {missing_gt[:5]}")
if orphan_gt:
    print(f"WARNING: {len(orphan_gt)} .gt.txt file(s) without an image, e.g. {orphan_gt[:5]}")
if images and not (missing_gt or orphan_gt):
    print("✓ All files present and matched!")



In [None]:
# === CELL 6: Kraken OCR Training (Drive-safe, autosave, GPU-aware) ===
# Sets up imports/flags and defines compute_cer() and train_mongolian_model();
# running this cell does not start training by itself.
import os, glob, random, shutil, traceback, subprocess, re, builtins, sys, torch
from packaging import version

# I/O problem: if flushing the current sys.stdout raises, replace it with the
# interpreter's original stdout stream (sys.__stdout__).
if hasattr(sys, "stdout") and hasattr(sys.stdout, "write"):
    try:
        sys.stdout.flush()
    except Exception:
        sys.stdout = sys.__stdout__

# Prevent Kraken exit() crash: replace builtins.exit with a function that raises
# SystemExit(code) (the generator .throw() trick lets a lambda raise). The
# training loop catches SystemExit around trainer.fit(). NOTE(review): this
# patch is process-wide and stays in effect for the rest of the session;
# presumably it exists because the interactive exit() would otherwise end the
# kernel - confirm it is still needed.
builtins.exit = lambda code=0: (_ for _ in ()).throw(SystemExit(code))

# Check Kraken version (falls back to 0.0.0 if kraken cannot be imported or its
# version cannot be parsed). KRAKEN_VERSION is not used elsewhere in the
# visible notebook.
try:
    import kraken
    KRAKEN_VERSION = version.parse(kraken.__version__)
except Exception:
    KRAKEN_VERSION = version.parse("0.0.0")
#print(f"✅ Detected Kraken version: {KRAKEN_VERSION}")

# Kraken training API used by train_mongolian_model().
# NOTE(review): confirm these import paths exist in the pinned kraken version.
from kraken.lib.train import RecognitionModel
from kraken.lib import train as train_lib
KrakenTrainer = train_lib.KrakenTrainer
# Optional in-process evaluator; when it cannot be imported, compute_cer()
# falls back to running the `ketos test` command-line tool.
try:
    from kraken.lib import evaluate
    HAS_EVALUATE = True
except ImportError:
    HAS_EVALUATE = False


def _parse_ketos_cer(output):
    """Extract a character error rate (0..1) from `ketos test` console output.

    Tries, in order: an explicit ``char_error_rate: x`` / ``char_error_rate=x``
    field; a report line of the form ``<pct>% Accuracy``; and an
    ``accuracy: <pct>%`` summary. Percent accuracies are converted with
    ``1 - pct/100``. Returns None when none of these is present.
    NOTE(review): ketos's report layout differs between kraken versions -
    verify against the installed version's output.
    """
    m = re.search(r"char_error_rate[:=]\s*([0-9.]+)", output)
    if m:
        return float(m.group(1))
    m = re.search(r"([0-9]+(?:\.[0-9]+)?)\s*%\s*(?:Character\s+)?Accuracy", output, re.IGNORECASE)
    if m is None:
        m = re.search(r"accuracy:\s*([0-9]+(?:\.[0-9]+)?)\s*%", output, re.IGNORECASE)
    if m:
        return 1.0 - float(m.group(1)) / 100.0
    return None


def compute_cer(model_path, val_path, log):
    """Compute the character error rate of a recognition model on a validation folder.

    Args:
        model_path: path to a kraken ``.mlmodel`` file. May be None or point to a
            missing file, in which case nothing is evaluated.
        val_path: directory holding ``*.png`` line images (with their
            ``.gt.txt`` transcriptions alongside).
        log: callable that receives one message string.

    Returns:
        The CER as a float (0.0 = perfect), or None if it could not be computed.
    """
    if not model_path or not os.path.exists(model_path):
        log(f"No model to evaluate ({model_path}); skipping CER.")
        return None
    val_imgs = sorted(glob.glob(os.path.join(val_path, '*.png')))
    if not val_imgs:
        log("No validation images found.")
        return None
    try:
        if HAS_EVALUATE:
            # NOTE(review): the signature of kraken.lib.evaluate.evaluate is not
            # verified here.
            res = evaluate.evaluate(model=model_path, test_data=val_imgs, device='cuda' if torch.cuda.is_available() else 'cpu')
            return res.get('char_error_rate', None)
        # `ketos test` expects the evaluation files themselves (not a directory),
        # and `-f path` selects the image + .gt.txt format used for training.
        cmd = ["ketos", "test", "-f", "path", "-m", model_path, *val_imgs]
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        if result.returncode != 0:
            log(f"⚠️ ketos test exited with code {result.returncode}: {result.stdout[-500:]}")
            return None
        cer = _parse_ketos_cer(result.stdout)
        if cer is None:
            log("⚠️ Could not find a CER/accuracy figure in the `ketos test` output.")
        return cer
    except Exception as e:
        log(f"⚠️ CER evaluation failed: {e}")
    return None


# Matches the per-epoch checkpoint names written by train_mongolian_model().
_EPOCH_CKPT_RE = re.compile(r"^mongolian_model_epoch_(\d+)\.mlmodel$")


def _epoch_checkpoints(checkpoints_dir):
    """Return the per-epoch checkpoints in ``checkpoints_dir`` sorted by epoch number.

    Sorting numerically (not lexically) keeps ``epoch_100`` after ``epoch_99``.
    """
    found = []
    for path in glob.glob(os.path.join(checkpoints_dir, "mongolian_model_epoch_*.mlmodel")):
        match = _EPOCH_CKPT_RE.match(os.path.basename(path))
        if match:
            found.append((int(match.group(1)), path))
    return [path for _, path in sorted(found)]


def _prune_checkpoints(checkpoints_dir, keep_last_n, log):
    """Delete all but the newest ``keep_last_n`` per-epoch checkpoints (None or < 1 keeps all)."""
    if keep_last_n is None or keep_last_n < 1:
        return
    for old in _epoch_checkpoints(checkpoints_dir)[:-keep_last_n]:
        try:
            os.remove(old)
            log(f"🗑️ Removed old checkpoint {old}")
        except OSError as e:
            log(f"⚠️ Could not remove {old}: {e}")


def _read_best_cer(cer_log_path):
    """Return the lowest ``CER=<x>`` value recorded in ``cer_log_path`` (inf if none)."""
    best = float("inf")
    if os.path.exists(cer_log_path):
        with open(cer_log_path, encoding="utf-8") as f:
            for line in f:
                match = re.search(r"CER=([0-9]*\.?[0-9]+)", line)
                if match:
                    best = min(best, float(match.group(1)))
    return best


def _split_validation(imgs, val_path, val_split, seed, log):
    """Move a ``val_split`` fraction of ``imgs`` (plus their .gt.txt files) into ``val_path``.

    The seeded RNG makes the split reproducible: if the data directory is later
    re-created from a pristine copy, the same images end up in validation again
    instead of images that earlier sessions already trained on.
    ``seed=None`` gives a non-reproducible split.
    """
    os.makedirs(val_path, exist_ok=True)
    n_val = max(1, int(len(imgs) * val_split))
    for img in random.Random(seed).sample(imgs, n_val):
        gt = os.path.splitext(img)[0] + ".gt.txt"
        shutil.move(img, os.path.join(val_path, os.path.basename(img)))
        if os.path.exists(gt):
            shutil.move(gt, os.path.join(val_path, os.path.basename(gt)))
    log(f"Created validation split of {n_val} images.")


def train_mongolian_model(
    data_path,
    checkpoints_dir,
    val_split,
    batch_size,
    images_per_batch,
    epochs,
    learning_rate,
    keep_last_n,
    lag,
    min_epochs,
    quit_mode,
    freq,
    partition,
    load_threads,
    split_seed=0,
):
    """Main Kraken OCR training loop (fast local data + Drive checkpoints).

    Trains on ``<data_path>/images/*.png`` (+ ``.gt.txt``) in chunks of
    ``images_per_batch`` images (one kraken epoch per chunk). After every pass
    over all chunks it:
    - copies the model to ``<checkpoints_dir>/mongolian_model_epoch_XX.mlmodel``;
    - scores it on ``<data_path>/validation``;
    - keeps a copy of the best one as ``mongolian_model_best.mlmodel``.
    Re-running resumes after the newest epoch checkpoint.

    Args:
        data_path: dataset root containing ``images/`` (and, after the first
            run, ``validation/``).
        checkpoints_dir: where checkpoints, ``training_log.txt`` and
            ``cer_log.txt`` are written (created if missing).
        val_split: fraction of images moved to ``validation/`` when that folder
            does not exist yet (at least one image is moved).
        batch_size, learning_rate, lag, min_epochs, quit_mode, freq, partition,
        load_threads: passed to kraken in the ``hyper_params`` dict.
        images_per_batch: images per training chunk; None/0 trains on all
            images as a single chunk.
        epochs: last epoch number to train (inclusive).
        keep_last_n: how many of the newest epoch checkpoints to keep; None/0
            keeps all of them.
        split_seed: RNG seed for the validation split (None = random).

    Returns:
        True when the epoch loop finishes. None if no training data was found or
        an unexpected exception aborted training; details go to ``training_log.txt``.
    """
    os.makedirs(checkpoints_dir, exist_ok=True)
    log_path = os.path.join(checkpoints_dir, "training_log.txt")
    cer_log_path = os.path.join(checkpoints_dir, "cer_log.txt")
    best_model_path = os.path.join(checkpoints_dir, "mongolian_model_best.mlmodel")
    # Passed to kraken as `output`; the loop expects each chunk's trained model
    # at this path (relative to the current directory). NOTE(review): some kraken
    # versions append an epoch suffix when saving - verify the actual file name.
    temp_model = "mongolian_model_temp.mlmodel"

    def log(msg):
        print(msg)
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(msg + "\n")

    def log_cer(ep, cer, is_best):
        # One line per evaluated epoch; "(best)" marks improvements.
        with open(cer_log_path, "a", encoding="utf-8") as f:
            f.write(f"Epoch {ep}: CER={cer:.4f}{' (best)' if is_best else ''}\n")

    try:
        log("=" * 60)
        log("TRAINING START")
        log("=" * 60)

        # === Verify dataset consistency ===
        images_dir = os.path.join(data_path, "images")
        imgs = sorted(glob.glob(os.path.join(images_dir, "*.png")))
        gts = sorted(glob.glob(os.path.join(images_dir, "*.gt.txt")))
        total = len(imgs)
        log(f"Found {total} PNGs and {len(gts)} GTs")
        if total == 0:
            log("❌ No training images found!")
            return None
        if total != len(gts):
            log("⚠️ Mismatch between .png and .gt.txt counts")

        # === Validation split (created once; later runs reuse the folder) ===
        val_path = os.path.join(data_path, "validation")
        if not os.path.exists(val_path):
            _split_validation(imgs, val_path, val_split, split_seed, log)
            imgs = sorted(glob.glob(os.path.join(images_dir, "*.png")))
            total = len(imgs)
            if total == 0:
                log("❌ No training images left after the validation split!")
                return None

        step = images_per_batch if images_per_batch and images_per_batch > 0 else total
        batches = [imgs[i:i + step] for i in range(0, total, step)]
        log(f"Split into {len(batches)} batches (≤{step} each)")

        # === Resume if checkpoints exist ===
        start_epoch = 1
        last_ckpt = None
        existing = _epoch_checkpoints(checkpoints_dir)
        if existing:
            last_ckpt = existing[-1]
            start_epoch = int(_EPOCH_CKPT_RE.match(os.path.basename(last_ckpt)).group(1)) + 1
            log(f"Resuming from {last_ckpt}")
        # Carry the best CER across sessions so "new best" means best overall.
        best_cer = _read_best_cer(cer_log_path)
        if best_cer != float("inf"):
            log(f"Best CER so far: {best_cer:.4f}")

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        log(f"🔧 Using device: {device}")

        hyper_params = {
            "epochs": 1,
            "lag": lag,
            "min_epochs": min_epochs,
            "quit": quit_mode,
            "freq": freq,
            "partition": partition,
            "lrate": learning_rate,
            "load_threads": load_threads,
            "batch_size": batch_size,
        }

        # === Training Loop ===
        for epoch in range(start_epoch, epochs + 1):
            log(f"\n===== EPOCH {epoch}/{epochs} =====")
            # Drop the previous epoch's temp model so that, if every chunk of
            # this epoch fails, it is not saved again under this epoch's name.
            if os.path.exists(temp_model):
                os.remove(temp_model)

            for b, batch_imgs in enumerate(batches, start=1):
                log(f"Batch {b}/{len(batches)} ({len(batch_imgs)} imgs)")

                model = RecognitionModel(
                    training_data=batch_imgs,
                    format_type="path",
                    hyper_params=dict(hyper_params),  # fresh copy per model
                    output=temp_model,
                )

                # Continue from the previous chunk's result when this epoch has
                # produced one, otherwise from the last epoch checkpoint.
                # Reloading the epoch checkpoint for every chunk would discard
                # what was learned on the earlier chunks of this epoch.
                init_from = temp_model if os.path.exists(temp_model) else last_ckpt
                if init_from and os.path.exists(init_from):
                    try:
                        # NOTE(review): assumes RecognitionModel.load() exists in
                        # the installed kraken. Failures are only logged, in which
                        # case this chunk presumably starts from fresh weights.
                        model.load(init_from)
                        log(f"Loaded {init_from}")
                    except Exception as e:
                        log(f"⚠️ load failed: {e}")

                trainer = KrakenTrainer(enable_progress_bar=True, enable_checkpointing=False, accelerator=device)
                try:
                    trainer.fit(model)
                except SystemExit:
                    log("⚠️ Interrupted gracefully (SystemExit).")
                except Exception as e:
                    log(f"⚠️ Training interrupted: {e}")

            # === Save at end of epoch ===
            ep_path = os.path.join(checkpoints_dir, f"mongolian_model_epoch_{epoch:02d}.mlmodel")
            if not os.path.exists(temp_model):
                log("⚠️ No temp model found to save.")
                continue
            try:
                shutil.copy(temp_model, ep_path)
            except Exception as e:
                log(f"⚠️ Could not save model: {e}")
                continue
            log(f"💾 Saved model checkpoint for epoch {epoch}: {ep_path}")
            last_ckpt = ep_path

            # === Evaluate only the model saved for this epoch ===
            cer = compute_cer(ep_path, val_path, log)
            if cer is not None:  # 0.0 is a valid (perfect) CER
                is_best = cer < best_cer
                log_cer(epoch, cer, is_best)
                if is_best:
                    best_cer = cer
                    try:
                        shutil.copy(ep_path, best_model_path)
                        log(f"✨ New best model (CER={cer:.4f}): {best_model_path}")
                    except Exception as e:
                        log(f"⚠️ Could not save best model: {e}")
                else:
                    log(f"Epoch {epoch}: CER={cer:.4f} (best so far {best_cer:.4f})")

            _prune_checkpoints(checkpoints_dir, keep_last_n, log)

        log("✅ Training complete.")
        return True

    except Exception as e:
        log(f"EXCEPTION: {e}")
        log(traceback.format_exc())
        return None


In [None]:
%%writefile train_script.py
# Launch training with optimized parameters for 300-DPI line images

# Function train_mongolian_model() must already be defined by running Cell 6.

# === Adjustable parameters ===
DATA_PATH        = '/content/drive/MyDrive/synthetic_mongolian_large_images'
CHECKPOINTS_DIR  = f"{DATA_PATH}/checkpoints"
VAL_SPLIT        = 0.05
BATCH_SIZE       = 8
IMAGES_PER_BATCH = 1000
EPOCHS           = 60
LEARNING_RATE    = 0.0003
KEEP_LAST_N      = 3
LAG              = 20
MIN_EPOCHS       = 1
QUIT_MODE        = 'never'
FREQ             = 1.0
PARTITION        = 0.9
LOAD_THREADS     = 4
# ===============================

success = train_mongolian_model(
    data_path=DATA_PATH,
    checkpoints_dir=CHECKPOINTS_DIR,
    val_split=VAL_SPLIT,
    batch_size=BATCH_SIZE,
    images_per_batch=IMAGES_PER_BATCH,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    keep_last_n=KEEP_LAST_N,
    lag=LAG,
    min_epochs=MIN_EPOCHS,
    quit_mode=QUIT_MODE,
    freq=FREQ,
    partition=PARTITION,
    load_threads=LOAD_THREADS
)

if success:
    print("‚úÖ Training completed successfully.")
else:
    print("‚ùå Training failed ‚Äî check training_log.txt and cer_log.txt.")


In [None]:
import time, re, matplotlib.pyplot as plt
from IPython.display import clear_output

# The training cell (Cell 7) tees its console output to
# <dataset>/checkpoints/full_training_log.txt, so read the log from there.
log_path = "/content/drive/MyDrive/synthetic_mongolian_large_images/checkpoints/full_training_log.txt"

def live_plot(log_path, refresh=30, max_updates=None):
    """Repeatedly read a training log and plot Train Loss / Val Acc / Word Acc trends.

    Args:
        log_path: text log to scan for metric values.
        refresh: seconds to wait between redraws.
        max_updates: number of redraws before returning. None (default) keeps
            redrawing until the cell is interrupted.

    Values come from ``name: value`` or ``name=value`` occurrences in the log.
    NOTE(review): which separator appears depends on the progress-bar output of
    the installed kraken/lightning version - verify against a real log.
    """
    number = r"([0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)"

    def series(text, metric):
        # All values logged for `metric`, in file order.
        return [float(x) for x in re.findall(rf"{metric}\s*[:=]\s*{number}", text)]

    updates = 0
    try:
        while max_updates is None or updates < max_updates:
            try:
                with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
                    text = f.read()

                loss = series(text, "train_loss_epoch")
                acc = series(text, "val_accuracy")
                wacc = series(text, "val_word_accuracy")

                clear_output(wait=True)
                fig, ax = plt.subplots(figsize=(8, 5))
                if loss: ax.plot(loss, label="Train Loss", color="orange")
                if acc:  ax.plot(acc, label="Val Char Acc (%)", color="blue")
                if wacc: ax.plot(wacc, label="Val Word Acc (%)", color="green")
                ax.set_xlabel("Stage / Epoch")
                ax.set_ylabel("Metric Value")
                ax.set_title("Kraken Training Progress (Live)")
                if loss or acc or wacc:  # avoid the "no labelled artists" warning
                    ax.legend()
                ax.grid(True)
                plt.show()
                plt.close(fig)  # don't accumulate an open figure per refresh
            except Exception as e:
                print("Waiting for log file...", e)
            updates += 1
            if max_updates is None or updates < max_updates:
                time.sleep(refresh)
    except KeyboardInterrupt:
        # Interrupting the cell is the normal way to stop an unbounded loop.
        print("Live plot stopped.")

# NOTE: live_plot() loops until the cell is interrupted, so this call occupies the
# kernel. No other cell (including the training cell below) can run meanwhile,
# and "Run all" stops progressing here. Interrupt the cell to continue.
live_plot(log_path, refresh=30)


In [None]:
# === CELL 7: Run training safely with fast local data and Drive checkpoints ===
from datetime import datetime
import sys, os, shutil, glob

drive_root = "/content/drive/MyDrive/synthetic_mongolian_large_images"
local_data = "/content/data"
checkpoints_dir = f"{drive_root}/checkpoints"
log_path = f"{checkpoints_dir}/full_training_log.txt"

# Ensure checkpoint directory exists
os.makedirs(checkpoints_dir, exist_ok=True)


# ---- Define a safe logger ----
class Tee:
    """Duplicate everything written to it onto several text streams.

    Writes to streams that are already closed are skipped. Attributes Tee does
    not define itself (``isatty``, ``encoding``, ``fileno``, ...) are answered by
    the first stream, because libraries that draw progress bars may probe
    ``sys.stdout`` for them.
    """

    def __init__(self, *files):
        self.files = files

    def write(self, data):
        for f in self.files:
            try:
                f.write(data)
                f.flush()
            except ValueError:
                pass  # file already closed
        return len(data)

    def flush(self):
        for f in self.files:
            try:
                f.flush()
            except ValueError:
                pass

    def __getattr__(self, name):
        # Only reached for attributes missing on Tee; read __dict__ directly to
        # avoid recursion if `files` has not been set.
        files = self.__dict__.get("files")
        if not files:
            raise AttributeError(name)
        return getattr(files[0], name)


# ---- Find the notebook's own output stream ----
# sys.stdout is normally the notebook/Colab output stream. sys.__stdout__ is the
# kernel process's raw stdout, which is not necessarily shown in the notebook.
# If an earlier run of this cell was interrupted while teeing, unwrap the
# leftover Tee(s) instead of nesting them.
notebook_stdout = sys.stdout
while type(notebook_stdout).__name__ == "Tee" and getattr(notebook_stdout, "files", None):
    notebook_stdout = notebook_stdout.files[0]
sys.stdout = notebook_stdout

# ---- Copy data locally for fast I/O ----
if not os.path.exists(local_data):
    print("⏳ Copying image data to /content for fast access (may take several minutes)...")
    # Skip checkpoints/: models and logs are written straight to Drive and are
    # not needed locally. Copy into a temporary folder first so an interrupted
    # copy is never mistaken for a complete local dataset on the next run.
    partial_copy = local_data + ".partial"
    shutil.rmtree(partial_copy, ignore_errors=True)
    shutil.copytree(drive_root, partial_copy, ignore=shutil.ignore_patterns("checkpoints"))
    os.rename(partial_copy, local_data)
else:
    print("✅ Using existing local data directory.")

# ---- Start logging ----
success = None
with open(log_path, "a", encoding="utf-8") as f:
    sys.stdout = Tee(notebook_stdout, f)
    try:
        print(f"\n=== Training session started {datetime.now()} ===\n")

        success = train_mongolian_model(
            data_path=local_data,
            checkpoints_dir=checkpoints_dir,
            val_split=0.05,
            batch_size=8,
            images_per_batch=1000,
            epochs=60,
            learning_rate=0.0003,
            keep_last_n=3,
            lag=20,
            min_epochs=1,
            quit_mode="never",
            freq=1.0,
            partition=0.9,
            load_threads=4
        )

        print("\n✅ Training completed." if success else "\n❌ Training failed.")
        print(f"=== Training session ended {datetime.now()} ===\n")
    finally:
        # ---- Restore normal output ----
        # Done before the log file closes, even if training was interrupted.
        sys.stdout = notebook_stdout

print("✅ Logging closed and stdout restored.")


In [None]:
# Read and display the training log written by train_mongolian_model().
# That function writes training_log.txt inside its checkpoints directory, which
# the training cells set to <dataset>/checkpoints.
log_file_path = '/content/drive/MyDrive/synthetic_mongolian_large_images/checkpoints/training_log.txt'
try:
    # The log is written as UTF-8 (it contains emoji), so read it the same way.
    with open(log_file_path, 'r', encoding='utf-8') as f:
        log_content = f.read()
    print(log_content)
except FileNotFoundError:
    print(f"Error: The file {log_file_path} was not found.")
except Exception as e:
    print(f"An error occurred while reading the file: {e}")

In [None]:
# NEW CELL: Convert an existing Lightning checkpoint (.ckpt) to a usable kraken .mlmodel
import glob
import os
import shutil

# NOTE(review): only *.ckpt files directly in MyDrive are searched. The training
# cells in this notebook write .mlmodel files to .../checkpoints instead.
checkpoints = glob.glob('/content/drive/MyDrive/*.ckpt')
# "Latest" = most recently modified. A plain string sort misorders unpadded
# numbers (e.g. "epoch=9" sorts after "epoch=48").
checkpoints.sort(key=os.path.getmtime)

OUTPUT_NAME = 'mongolian_model_epoch48.mlmodel'  # fixed name, not derived from the checkpoint

if checkpoints:
    latest_ckpt = checkpoints[-1]
    print(f"Found {len(checkpoints)} checkpoints")
    print(f"Latest: {latest_ckpt}")

    # Load the checkpoint and save as .mlmodel
    from kraken.lib.train import RecognitionModel

    print("\nConverting checkpoint to .mlmodel format...")
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only runtime.
    model = RecognitionModel.load_from_checkpoint(latest_ckpt, map_location='cpu')
    # RecognitionModel is a Lightning module, which has no save() of its own.
    # The serialisable kraken network is expected at `model.nn` (a VGSL model
    # with save_model()). NOTE(review): verify for the pinned kraken version.
    # `model.save` is only attempted as a fallback.
    network = getattr(model, 'nn', None)
    if network is not None and hasattr(network, 'save_model'):
        network.save_model(OUTPUT_NAME)
    else:
        model.save(OUTPUT_NAME)

    # Copy to Drive
    drive_copy = os.path.join('/content/drive/MyDrive', OUTPUT_NAME)
    shutil.copy(OUTPUT_NAME, drive_copy)
    print(f"✓ Model saved to: {drive_copy}")
    print("\nYou can download and use this model now!")
else:
    print("No checkpoints found!")

In [None]:
# ===== CELL 8: Test the model =====
# Test on a sample image
test_image = 'synthetic_mongolian_large_images/images/line_0500-1.png'

print(f"Testing model on: {test_image}")
!kraken -i {test_image} output.txt segment ocr -m mongolian_model.mlmodel

print("\nGround truth:")
!cat synthetic_mongolian_large_images/images/line_0500-1.gt.txt

print("\nModel prediction:")
!cat output.txt



In [None]:
# ===== CELL 9: Download the trained model =====
# Send the model file from the runtime's working directory to the browser.
import os
from google.colab import files

MODEL_PATH = 'mongolian_model.mlmodel'

if os.path.exists(MODEL_PATH):
    print("Downloading trained model...")
    files.download(MODEL_PATH)
    print("Model downloaded! You can now use it for OCR.")
else:
    # Report the problem clearly instead of failing inside files.download().
    print(f"{MODEL_PATH} not found in {os.getcwd()} - train or convert a model first, "
          "or set MODEL_PATH to an existing .mlmodel file.")



In [None]:
# ===== CELL 10: (Optional) Save model to Google Drive =====
# Uncomment and run if you want to save to Drive for later use

# import shutil
# shutil.copy('mongolian_model.mlmodel', '/content/drive/MyDrive/mongolian_model.mlmodel')
# print("Model saved to Google Drive!")