In [None]:
import os
import json
import random
from collections import defaultdict

import cv2
from sklearn.model_selection import train_test_split

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from detectron2.utils.logger import setup_logger
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.structures import BoxMode
from detectron2 import model_zoo
from detectron2.data import build_detection_test_loader
from detectron2.data import MetadataCatalog

from detectron2.evaluation import COCOEvaluator, inference_on_dataset
import pprint

import torch
# Pick the compute device: CUDA first, then Apple-Silicon MPS, then CPU.
# `torch.backends.mps` is only present in newer PyTorch builds, so it is looked
# up defensively instead of being accessed as an attribute unconditionally.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Route detectron2's training/evaluation logging to the notebook output.
setup_logger()

In [None]:
# --- 1. Configuration Section ---
# Name of the local folder this notebook is expected to run from.
END_WITH_LOCAL = 'detectron2'

NOTEBOOK_DIR = os.getcwd()
print(f"NOTEBOOK_DIR: {NOTEBOOK_DIR}")

# Simple validation: BASE_DIR below is derived by walking three levels up from
# the working directory, so it is only meaningful when the notebook runs from
# its own folder (.../detectron2) or from Colab's /content. Compare the last
# path component exactly; a plain string-suffix test would also accept folders
# such as ".../my_detectron2".
_notebook_dir_name = os.path.basename(os.path.normpath(NOTEBOOK_DIR))
if _notebook_dir_name not in ('content', END_WITH_LOCAL):
    raise ValueError(f"Expected to be in .../{END_WITH_LOCAL} or .../content directory, but got: {NOTEBOOK_DIR}")

# Normalised absolute path, so the printed value and every path derived from it
# contain no '..' components.
BASE_DIR = os.path.abspath(os.path.join(NOTEBOOK_DIR, '..', '..', '..'))
print(f"BASE_DIR: {BASE_DIR}")

# Paths to the data and model files. One component per argument, so
# os.path.join uses the platform's separator throughout.
JSON_DIR = os.path.join(BASE_DIR, 'data', 'MangaSegmentation', 'jsons_processed')
IMAGE_ROOT_DIR = os.path.join(BASE_DIR, 'data', 'Manga109_released_2023_12_07', 'images')
PRE_TRAINED_MODEL_DIR = os.path.join(BASE_DIR, 'models', 'bubble-detection', 'detectron2', 'pre-trained_model')

# The category we want to train on: annotations whose category_id equals
# TARGET_CATEGORY_ID are kept and trained as the single class TARGET_CATEGORY_NAME.
TARGET_CATEGORY_ID = 5
TARGET_CATEGORY_NAME = "balloon"

In [None]:
# --- 2. Data Preparation ---

def prepare_manga_balloon_data(json_dir, image_root, target_category_id=None, target_category_name=None):
    """
    Load pre-processed COCO-style JSON files and build detectron2 dataset dicts
    for a single target category.

    Every ``*.json`` file in ``json_dir`` must contain ``images`` and
    ``annotations`` lists. Only annotations whose ``category_id`` equals the
    target category are kept, and they are relabelled to class 0 (single-class
    model). Images without any such annotation are dropped.

    Args:
        json_dir: Directory containing the pre-processed ``*.json`` files.
        image_root: Directory that each image's ``file_name`` is relative to.
        target_category_id: Category id to keep. Defaults to the notebook-level
            ``TARGET_CATEGORY_ID``.
        target_category_name: Name used in progress messages. Defaults to the
            notebook-level ``TARGET_CATEGORY_NAME``.

    Returns:
        A list of dicts with keys ``file_name``, ``image_id``, ``height``,
        ``width`` and ``annotations``. Records are ordered by JSON file name and
        then by the order of images inside each file.
    """
    if target_category_id is None:
        target_category_id = TARGET_CATEGORY_ID
    if target_category_name is None:
        target_category_name = TARGET_CATEGORY_NAME

    # Sorted, so the record order (and everything derived from it, e.g. the
    # seeded train/val split) does not depend on the filesystem's listing order.
    json_files = sorted(name for name in os.listdir(json_dir) if name.endswith('.json'))

    print("Loading and parsing PRE-PROCESSED JSON files...")
    # One (images, annotations-by-image-id) pair per JSON file. Annotations are
    # matched to images only within their own file, so an image id that repeats
    # across files cannot attach one manga's annotations to another manga's page.
    per_file = []
    for json_file in json_files:
        with open(os.path.join(json_dir, json_file), 'r', encoding='utf-8') as f:
            data = json.load(f)
        anns_by_image = defaultdict(list)
        for ann_info in data['annotations']:
            anns_by_image[ann_info['image_id']].append(ann_info)
        per_file.append((data['images'], anns_by_image))

    total_images = sum(len(images) for images, _ in per_file)
    print(f"Loaded data for {total_images} total images.")

    # The emitted image_id values must be unique (detectron2's COCO conversion
    # and evaluation key on them). Each original id is kept where possible. A
    # later duplicate gets a fresh id above every integer id seen in the input.
    int_ids = [img['id'] for images, _ in per_file for img in images if isinstance(img['id'], int)]
    next_free_id = max(int_ids, default=-1) + 1
    used_ids = set()
    num_remapped = 0

    dataset_dicts = []
    for images, anns_by_image in per_file:
        for img_info in images:
            objs = [
                {
                    "bbox": ann['bbox'],
                    "bbox_mode": BoxMode.XYWH_ABS,
                    "segmentation": ann['segmentation'],
                    "category_id": 0,  # single-class model: target category -> class 0
                }
                for ann in anns_by_image[img_info['id']]
                if ann['category_id'] == target_category_id
            ]
            if not objs:
                continue

            image_id = img_info['id']
            if image_id in used_ids:
                image_id = next_free_id
                next_free_id += 1
                num_remapped += 1
            used_ids.add(image_id)

            dataset_dicts.append({
                "file_name": os.path.join(image_root, img_info['file_name']),
                "image_id": image_id,
                "height": img_info['height'],
                "width": img_info['width'],
                "annotations": objs,
            })

    if num_remapped:
        print(f"Warning: {num_remapped} duplicate image ids across JSON files were reassigned unique ids.")
    print(f"Finished data preparation. Found {len(dataset_dicts)} images containing '{target_category_name}'.")
    return dataset_dicts


In [None]:
# --- 3. Register Datasets with Detectron2 (part 1: build the train/val split) ---

# Prepare the data
all_data = prepare_manga_balloon_data(JSON_DIR, IMAGE_ROOT_DIR)

# --- Group data by manga title ---
print("\nGrouping data by manga series for a robust train/val split...")
grouped_data = defaultdict(list)
for record in all_data:
    # Pages are stored as <IMAGE_ROOT_DIR>/<manga title>/<page>.jpg, so the
    # name of the parent folder identifies the series.
    manga_name = os.path.basename(os.path.dirname(record['file_name']))
    grouped_data[manga_name].append(record)

print(f"Found {len(grouped_data)} unique manga series.")

# --- Split manga titles, not individual pages ---
# Sort the titles first. random_state only makes train_test_split reproducible
# for a fixed input order, and the insertion order of grouped_data may depend on
# the order in which the JSON files were read from disk.
manga_titles = sorted(grouped_data)
if len(manga_titles) < 2:
    raise ValueError(f"Need at least 2 manga series for a train/val split, found {len(manga_titles)}.")
train_titles, val_titles = train_test_split(manga_titles, test_size=0.2, random_state=42)

print(f"Splitting into {len(train_titles)} series for training and {len(val_titles)} for validation.")

# --- Reconstruct train and val sets based on the split titles ---
train_data = [record for title in train_titles for record in grouped_data[title]]
val_data = [record for title in val_titles for record in grouped_data[title]]

# Shuffle the datasets to ensure randomness within each set
random.Random(42).shuffle(train_data)
random.Random(42).shuffle(val_data)

print(f"\nFinal training set size: {len(train_data)} images")
print(f"Final validation set size: {len(val_data)} images")
# Verify that no manga series is in both sets and that no page was lost.
assert len(set(train_titles) & set(val_titles)) == 0, "Data leakage detected! Same manga in train and val."
assert len(train_data) + len(val_data) == len(all_data), "Some pages were dropped by the split."
print("Split verified: No data leakage between train and validation sets.")

In [None]:
# Per-title split summary (one dict per manga), handy for exporting the split,
# e.g. to a DataFrame/CSV. Not consumed by the cells below.
train_split_info = [{'manga_title': title, 'dataset_split': 'train'} for title in train_titles]
val_split_info = [{'manga_title': title, 'dataset_split': 'validation'} for title in val_titles]

# (Re-)register the datasets. Removing both the DatasetCatalog and the
# MetadataCatalog entries first makes this cell safe to re-run: stale metadata
# (e.g. a cached COCO `json_file` from an earlier split, or different
# `thing_classes`) would otherwise survive, and Metadata.set() refuses to
# overwrite an attribute with a different value.
DATASET_SPLITS = {
    "manga_balloon_train": train_data,
    "manga_balloon_val": val_data,
}
for dataset_name, split_records in DATASET_SPLITS.items():
    if dataset_name in DatasetCatalog.list():
        DatasetCatalog.remove(dataset_name)
    if dataset_name in MetadataCatalog.list():
        MetadataCatalog.remove(dataset_name)
    # Bind the list now (default argument) so the registered loader does not
    # silently follow later reassignments of the train_data/val_data globals.
    DatasetCatalog.register(dataset_name, lambda records=split_records: records)
    MetadataCatalog.get(dataset_name).set(thing_classes=[TARGET_CATEGORY_NAME])

balloon_metadata = MetadataCatalog.get("manga_balloon_train")

In [None]:

# --- 6. Training ---

print("\nConfiguring the model for training...")
cfg = get_cfg()

# Base architecture and its pretrained weights both come from Detectron2's model
# zoo. Keep the config name in one place so the two cannot drift apart.
MODEL_ZOO_CONFIG = "Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml"
cfg.merge_from_file(model_zoo.get_config_file(MODEL_ZOO_CONFIG))

# Point to the datasets registered above
cfg.DATASETS.TRAIN = ("manga_balloon_train",)
cfg.DATASETS.TEST = ("manga_balloon_val",)
# Up to 8 data-loader workers, but never more than the machine has CPUs.
cfg.DATALOADER.NUM_WORKERS = min(8, os.cpu_count() or 1)

# Load pretrained weights from model zoo (downloaded on first use, then cached)
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_ZOO_CONFIG)
cfg.MODEL.DEVICE = str(device)

# --- Training Hyperparameters ---
cfg.SOLVER.IMS_PER_BATCH = 1   # Batch size. Adjust based on GPU memory
cfg.SOLVER.BASE_LR = 0.00025   # Learning rate
cfg.SOLVER.MAX_ITER = 50       # Number of training iterations: 3000 for a full run, 50 for a quick test

cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
# Warm up for 10% of training, capped at 300 iterations (the value intended for
# the 3000-iteration run). A fixed 300-iteration warmup is longer than the
# 50-iteration quick test, so that whole run would be spent in warmup.
cfg.SOLVER.WARMUP_ITERS = min(300, max(1, cfg.SOLVER.MAX_ITER // 10))
# Start from LR = BASE_LR * WARMUP_FACTOR, i.e. BASE_LR / WARMUP_ITERS.
cfg.SOLVER.WARMUP_FACTOR = 1.0 / cfg.SOLVER.WARMUP_ITERS

# Number of RoIs sampled per image when training the RoI heads.
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128

# --- Crucial Step: Set the Number of Classes ---
# We only have one class (balloon), so we set NUM_CLASSES to 1.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

# Define the output directory for logs and trained models
cfg.OUTPUT_DIR = os.path.join(BASE_DIR, 'models', 'bubble-detection', 'detectron2', 'output_balloon_segmentation_v3')

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

In [None]:
# Whether to continue from the last checkpoint in cfg.OUTPUT_DIR. False starts
# a fresh run from the weights named by cfg.MODEL.WEIGHTS.
RESUME_FROM_OUTPUT_DIR = False

# Build the trainer from the configuration assembled above, then load the
# starting weights.
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=RESUME_FROM_OUTPUT_DIR)

In [None]:
# Run the training loop for cfg.SOLVER.MAX_ITER iterations. The final
# checkpoint (model_final.pth) ends up in cfg.OUTPUT_DIR, where the inference
# cell below looks for it.
start_message = "Starting training..."
print(start_message)
trainer.train()

In [None]:
# =============================================================
# --- 7. Setup for Inference and Evaluation ---
# =============================================================

print("\n--- Setting up for testing ---")

final_checkpoint_path = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
if not os.path.isfile(final_checkpoint_path):
    raise FileNotFoundError(
        f"No trained checkpoint found at {final_checkpoint_path}; run the training cells first."
    )
print(f"Loading final model weights from: {final_checkpoint_path}")

# Work on a copy so the training config in `cfg` stays untouched. Mutating it
# in place would, for example, make a later re-run of the trainer cell start
# from model_final.pth instead of the model-zoo weights.
inference_cfg = cfg.clone()
inference_cfg.MODEL.WEIGHTS = final_checkpoint_path

# Set a threshold for filtering low-confidence predictions during testing
inference_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # Adjust this (0.0 to 1.0) as needed

# Create the Predictor
predictor = DefaultPredictor(inference_cfg)
print("Predictor loaded successfully.")

# Get metadata for the validation set
val_metadata = MetadataCatalog.get("manga_balloon_val")

In [None]:
# ==============================================================
# --- 8. Qualitative Testing (Visualization) ---
# ==============================================================
print("\n--- Starting Qualitative Visualization on Random Validation Samples ---")

# Directory where the rendered predictions are saved
vis_output_dir = os.path.join(BASE_DIR, 'output', 'bubble-detection', 'detectron2', 'balloon_test_visualizations')
os.makedirs(vis_output_dir, exist_ok=True)

# Number of random samples to visualize
num_samples = 20

# Seeded, local RNG: the same validation pages are drawn on every run without
# touching the global `random` state.
local_rng = random.Random(42)
samples = local_rng.sample(val_data, min(num_samples, len(val_data)))

num_saved = 0
for i, d in enumerate(samples):
    img_path = d["file_name"]

    # Expected layout: .../images/<MangaName>/<PageNumber>.jpg
    dir_path, page_filename = os.path.split(img_path)
    manga_name = os.path.basename(dir_path)
    page_name_no_ext = os.path.splitext(page_filename)[0]

    print(f"Processing sample {i+1}/{len(samples)}: Manga='{manga_name}', Page='{page_filename}'")

    # Read image (OpenCV reads as BGR). cv2.imread returns None instead of
    # raising when the file is missing or unreadable, so check explicitly.
    im = cv2.imread(img_path)
    if im is None:
        print(f"  Skipping: could not read image '{img_path}'")
        continue

    # Perform inference
    outputs = predictor(im)

    # Visualize on an RGB copy (Visualizer expects RGB)
    v = Visualizer(im[:, :, ::-1],
                   metadata=val_metadata,
                   instance_mode=ColorMode.IMAGE_BW
    )
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

    # Output: pred_MangaName_PageNum.jpg (converted back to BGR for OpenCV)
    save_path = os.path.join(vis_output_dir, f"pred_{manga_name}_{page_name_no_ext}.jpg")
    if cv2.imwrite(save_path, out.get_image()[:, :, ::-1]):
        num_saved += 1
    else:
        print(f"  Warning: failed to write '{save_path}'")

print(f"Saved {num_saved} visualization results to '{vis_output_dir}'")

In [None]:
# ==============================================================
# --- 9. Quantitative Evaluation (COCO Metrics / AP) ---
# ==============================================================
print("\n--- Starting Quantitative Evaluation on Entire Validation Set ---")

EVAL_DATASET = "manga_balloon_val"
# COCO AP integrates over the whole precision/recall curve, so evaluate with
# detectron2's default low score threshold rather than the 0.5 display threshold
# used for visualization. Thresholding at 0.5 cuts off recall and under-reports AP.
EVAL_SCORE_THRESH = 0.05

eval_checkpoint_path = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
if not os.path.isfile(eval_checkpoint_path):
    raise FileNotFoundError(
        f"No trained checkpoint found at {eval_checkpoint_path}; run the training cells first."
    )
eval_cfg = cfg.clone()
eval_cfg.MODEL.WEIGHTS = eval_checkpoint_path
eval_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = EVAL_SCORE_THRESH
eval_model = DefaultPredictor(eval_cfg).model

# COCOEvaluator converts the registered dataset to COCO ground truth, caches it
# as <output_dir>/<dataset>_coco_format.json and records the path in the
# dataset metadata. Later runs reuse that cache. Drop both so a changed
# train/val split is never scored against stale ground truth.
coco_gt_cache = os.path.join(cfg.OUTPUT_DIR, f"{EVAL_DATASET}_coco_format.json")
eval_metadata = MetadataCatalog.get(EVAL_DATASET)
if getattr(eval_metadata, "json_file", None) == coco_gt_cache:
    delattr(eval_metadata, "json_file")
if os.path.exists(coco_gt_cache):
    os.remove(coco_gt_cache)

# 1. Define the Evaluator
evaluator = COCOEvaluator(EVAL_DATASET, output_dir=cfg.OUTPUT_DIR)

# 2. Create the Test Dataloader
val_loader = build_detection_test_loader(cfg, EVAL_DATASET)

# 3. Run Inference and Evaluation
results = inference_on_dataset(eval_model, val_loader, evaluator)

print("\n--- Evaluation Results (Average Precision) ---")
pprint.pprint(results)

In [None]:
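# Training logs are written to cfg.OUTPUT_DIR (under BASE_DIR/models/...), not to a
# folder next to this notebook. To inspect them: !tensorboard --logdir "{cfg.OUTPUT_DIR}"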