In [None]:
import os
from PIL import Image
from manga_ocr import MangaOcr
from tqdm import tqdm
import glob
import xml.etree.ElementTree as ET
import cv2
import matplotlib.pyplot as plt
import re
import unicodedata

In [None]:
# Create the MangaOcr engine once and keep it in `mocr`; the processing loop
# later in this notebook calls it on each cropped text region.
# NOTE(review): construction presumably loads the model weights, so this cell
# may be slow on first run — confirm.
mocr = MangaOcr()

In [None]:
# --- 1. Configuration Section ---
# Resolves the dataset locations relative to the notebook's working directory
# and collects the annotation files to process into `xml_files`.
END_WITH_LOCAL = 'manga-ocr'

NOTEBOOK_DIR = os.getcwd()
print(f"NOTEBOOK_DIR: {NOTEBOOK_DIR}")

# Simple validation: every relative path below assumes the notebook runs from a
# directory ending in END_WITH_LOCAL or in '/content'
# (presumably the Colab working directory — confirm).
if not (NOTEBOOK_DIR.endswith('/content') or NOTEBOOK_DIR.endswith(END_WITH_LOCAL)):
    raise ValueError(f"Expected to be in .../{END_WITH_LOCAL} or .../content directory, but got: {NOTEBOOK_DIR}")

# The project root is three levels above the notebook directory.
BASE_DIR = os.path.join(NOTEBOOK_DIR, '..', '..', '..')
print(f"BASE_DIR: {BASE_DIR}")

# Define paths to the Manga109 dataset
IMAGE_ROOT_DIR = os.path.join(BASE_DIR, 'data', 'Manga109_released_2023_12_07/images')
ANNOTATIONS_DIR = os.path.join(BASE_DIR, 'data', 'Manga109_released_2023_12_07/annotations')

# Check if the directories exist
if not os.path.isdir(IMAGE_ROOT_DIR):
    print(f"Warning: Manga images root directory not found at '{IMAGE_ROOT_DIR}'")
    print("Please ensure the Manga109 dataset is downloaded and the path is correct.")

if os.path.isdir(ANNOTATIONS_DIR):
    # One annotation XML per book; sorted so the processing order is stable.
    xml_files = sorted(f for f in os.listdir(ANNOTATIONS_DIR) if f.endswith('.xml'))
else:
    # Listing a missing directory would raise FileNotFoundError right after the
    # warning; fall back to an empty list so the later cells can still run and
    # report that there is nothing to process.
    print(f"Warning: Annotations directory not found at '{ANNOTATIONS_DIR}'")
    xml_files = []

print(f"Found {len(xml_files)} manga books with annotations.")

In [None]:
# --- 2. Process each book ---
def clean_manga_ocr_text(text):
    """Normalize an OCR or ground-truth string before comparison.

    Steps, applied in order:
      1. Unicode NFKC normalization.
      2. Removal of every whitespace run (spaces, tabs, newlines).
      3. Collapse of two or more consecutive '.' into a single '…'.
      4. Replacement of '~' with '〜'.

    Args:
        text: The string to clean; may be ``None`` or empty.

    Returns:
        The cleaned string, or ``""`` when ``text`` is falsy.
    """
    if not text:
        return ""

    cleaned = unicodedata.normalize('NFKC', text)

    # Regex rewrites run in this exact order: whitespace is dropped first so
    # that dots separated only by spaces still merge into one ellipsis.
    regex_rewrites = (
        (r'\s+', ''),
        (r'\.{2,}', '…'),
    )
    for pattern, replacement in regex_rewrites:
        cleaned = re.sub(pattern, replacement, cleaned)

    # Tilde normalization comes last, after NFKC has produced the plain '~'.
    return cleaned.replace('~', '〜')

def _find_book_image_dir(book_name):
    """Return the image directory for ``book_name``, or ``None`` if none exists.

    The directory named exactly like the annotation file is preferred. When it
    is missing and the name contains ``_vol``, the part before ``_vol`` is tried
    as a fallback (e.g. ``LoveHina_vol01`` -> ``LoveHina``).
    """
    candidate_names = [book_name]
    if '_vol' in book_name:
        candidate_names.append(book_name.split('_vol')[0])

    for candidate_name in candidate_names:
        candidate_dir = os.path.join(IMAGE_ROOT_DIR, candidate_name)
        if os.path.isdir(candidate_dir):
            return candidate_dir
    return None


def _find_page_image_path(book_image_dir, page_index):
    """Return the path of the image for ``page_index``, or ``None`` if missing.

    Plain names (``12.jpg`` / ``12.png``) are tried first, as before. The
    zero-padded three-digit names (``012.jpg`` / ``012.png``) follow, because
    that is how the Manga109 release names its page files; without them every
    page below index 100 would be skipped on the stock dataset layout.
    """
    stems = [str(page_index)]
    padded_stem = f"{page_index:03d}"
    if padded_stem not in stems:
        stems.append(padded_stem)

    for stem in stems:
        for extension in ('.jpg', '.png'):
            candidate_path = os.path.join(book_image_dir, stem + extension)
            if os.path.exists(candidate_path):
                return candidate_path
    return None


# One dict per annotated text region; consumed by the visualization, JSON
# export and evaluation cells.
results = []

# Books/pages that had to be skipped, reported once after the loop so that an
# empty or partial `results` list is not a silent surprise.
skipped_books = []
skipped_page_count = 0

for xml_file in tqdm(xml_files, desc="Processing manga books"):
    book_name = os.path.splitext(xml_file)[0]
    xml_path = os.path.join(ANNOTATIONS_DIR, xml_file)

    book_image_dir = _find_book_image_dir(book_name)
    if book_image_dir is None:
        skipped_books.append(book_name)
        continue

    try:
        root = ET.parse(xml_path).getroot()
    except ET.ParseError as e:
        print(f"Error parsing XML file {xml_file}: {e}")
        continue

    pages = root.find('pages')
    if pages is None:
        continue

    for page in tqdm(pages.findall('page'), desc=f"Processing pages in {book_name}", leave=False):
        page_index = int(page.get('index'))

        img_path = _find_page_image_path(book_image_dir, page_index)
        if img_path is None:
            skipped_page_count += 1
            continue

        try:
            # cv2.imread returns None (instead of raising) for unreadable files.
            image = cv2.imread(img_path)
            if image is None:
                skipped_page_count += 1
                continue

            # OpenCV loads BGR; PIL (and therefore the OCR call) expects RGB.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            for text_element in page.findall('.//text'):
                text_id = text_element.get('id')
                text = text_element.text
                xmin = int(text_element.get('xmin'))
                ymin = int(text_element.get('ymin'))
                xmax = int(text_element.get('xmax'))
                ymax = int(text_element.get('ymax'))

                # Crop the annotated text region from the page.
                cropped_bubble = image_rgb[ymin:ymax, xmin:xmax]

                # A degenerate box yields an empty array that cannot be OCR'd.
                if cropped_bubble.size == 0:
                    continue

                pil_image = Image.fromarray(cropped_bubble)

                # Run OCR and normalize the output the same way the ground
                # truth is normalized at evaluation time.
                ocr_text = clean_manga_ocr_text(mocr(pil_image))

                # NOTE: the crop itself is kept for the visualization cell, so
                # every crop stays in memory for the lifetime of `results`.
                results.append({
                    'book': book_name,
                    'page': page_index,
                    'text_id': text_id,
                    'bbox': [xmin, ymin, xmax, ymax],
                    'ground_truth': text,
                    'ocr_result': ocr_text,
                    'image': pil_image
                })

        except Exception as e:
            # Best effort: report the failure and move on to the next page.
            print(f"Error processing page {page_index} of {book_name}: {e}")

if skipped_books or skipped_page_count:
    print(
        f"Skipped {len(skipped_books)} book(s) without an image directory and "
        f"{skipped_page_count} page(s) without a readable image."
    )

In [None]:
# --- 3. Visualize some results ---
import random

def visualize_result(result):
    """Print the details of one OCR result and display its cropped image.

    Args:
        result: Mapping with the keys ``'book'``, ``'page'``, ``'text_id'``,
            ``'ground_truth'``, ``'ocr_result'`` and ``'image'``.
    """
    summary_lines = [
        f"Book: {result['book']}",
        f"Page: {result['page']}",
        f"Text ID: {result['text_id']}",
        f"Ground Truth: {result['ground_truth']}",
        f"OCR Result: {result['ocr_result']}",
    ]
    print("\n".join(summary_lines))

    # Show the crop without axis ticks.
    crop = result['image']
    plt.imshow(crop)
    plt.axis('off')
    plt.show()

# Show one randomly chosen result, or say so when there is nothing to show.
if not results:
    print("No results to display.")
else:
    visualize_result(random.choice(results))

In [None]:
# --- 4. Save results to JSON ---
import json

# Build JSON-serializable copies of the results: everything except the
# 'image' entry, which holds a PIL image object.
results_to_save = [
    {field: value for field, value in record.items() if field != 'image'}
    for record in results
]

# Output location (the directory is created if it does not exist yet)
output_dir = os.path.join(BASE_DIR, 'output', 'MangaOCR_inference')
os.makedirs(output_dir, exist_ok=True)
output_json_path = os.path.join(output_dir, 'manga_ocr_results_with_regex.json')

# Write the file only when there is something to write.
if not results_to_save:
    print("No results to save.")
else:
    with open(output_json_path, 'w', encoding='utf-8') as f:
        # ensure_ascii=False keeps the Japanese text readable in the file.
        json.dump(results_to_save, f, ensure_ascii=False, indent=4)
    print(f"Successfully saved {len(results_to_save)} results to {output_json_path}")

In [None]:
# Evaluate OCR performance via CER and WER
from torchmetrics.text import CharErrorRate, WordErrorRate

# How many mismatching (ground truth, prediction) pairs to print for inspection.
MAX_MISMATCHES_SHOWN = 10

preds = []
target = []

for res in results:
    pred = res['ocr_result']
    # Normalize the ground truth exactly like the OCR output so that both
    # sides of the comparison went through the same cleaning.
    gt = clean_manga_ocr_text(res['ground_truth'])

    # Skip regions without ground-truth text (the cleaner returns "" for
    # None/empty input and strips all whitespace).
    if gt:
        preds.append(pred)
        target.append(gt)

print(f"Evaluating on {len(preds)} samples...")

if preds:
    cer_metric = CharErrorRate()
    cer = cer_metric(preds, target)

    # WER tokenizes on whitespace. Every string here had its whitespace removed
    # by clean_manga_ocr_text, so each sample is a single token and the value
    # is the share of samples that are not exact matches, not a word-level
    # error rate. It is kept for continuity and flagged in the output.
    wer_metric = WordErrorRate()
    wer = wer_metric(preds, target)

    print(f"Character Error Rate (CER): {cer.item():.4f}")
    print(f"Word Error Rate (WER):      {wer.item():.4f}")
    print("Note: whitespace is stripped from all samples, so WER treats each sample as one token "
          "(i.e. it is the share of samples that are not exact matches).")
    print("="*40)

    # Print the first few mismatches for a quick qualitative look.
    count = 0
    for p, t in zip(preds, target):
        if p == t:
            continue
        print(f"GT:   {t}")
        print(f"Pred: {p}")
        print("-" * 20)
        count += 1
        if count >= MAX_MISMATCHES_SHOWN:
            break