### Please run this notebook on Google Colab with a GPU runtime
<a href="https://colab.research.google.com/github/Ichikawa-Satoshi/SI-Org-chart/blob/main/test_deeplearning/cross_valid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Standard library
import json
import os
import random

# Third-party
import numpy as np
from sklearn.model_selection import KFold

# Colab-only: mount Google Drive so files under /content/drive are reachable.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Detectron2 has not released pre-built binaries for the latest pytorch (https://github.com/facebookresearch/detectron2/issues/4053)
# so we install from source instead. This takes a few minutes.
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

Collecting git+https://github.com/facebookresearch/detectron2.git
  Cloning https://github.com/facebookresearch/detectron2.git to /tmp/pip-req-build-k9gvw1x1
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/detectron2.git /tmp/pip-req-build-k9gvw1x1
  Resolved https://github.com/facebookresearch/detectron2.git to commit 9604f5995cc628619f0e4fd913453b4d7d61db3f
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.8 (from detectron2==0.6)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting fvcore<0.1.6,>=0.1.5 (from detectron2==0.6)
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting iopath<0.1.10,>=0.1.7 (from detectron2==0.6)
  Downloading iopath-0.1.9-py3-none-any.whl.metadata (370 bytes)
Collecting omegaconf<2.

In [3]:
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader
from detectron2.data import DatasetCatalog, MetadataCatalog

In [7]:
# K-fold cross-validation of a Mask R-CNN (R50-FPN) model with Detectron2.
# For every fold: split the COCO annotation file by image, write per-fold train/val
# JSON files, fine-tune from the COCO-pretrained checkpoint, then evaluate on the
# validation split and collect the bbox AP / AP50 / AP75 metrics.

# path
path = "/content/drive/MyDrive/SI-Org-Chart/data/Org_chart/learning/train"
path_coco = "/content/drive/MyDrive/SI-Org-Chart/data/Org_chart/learning/Organization_annotation.json"
# Per-fold annotation files are written next to the full annotation file.
learning_dir = os.path.dirname(path_coco)

# load data
with open(path_coco) as f:
    coco_data = json.load(f)

annotations = coco_data["annotations"]
images = coco_data["images"]

# setting for K-fold cross validation
K = 5  # num of fold
kf = KFold(n_splits=K, shuffle=True, random_state=42)


def _subset_coco(image_subset):
    """Return a COCO dict holding `image_subset` and only the annotations that reference it."""
    keep_ids = {img["id"] for img in image_subset}
    return {
        "images": image_subset,
        "annotations": [ann for ann in annotations if ann["image_id"] in keep_ids],
        "categories": coco_data["categories"],
    }


def _register_fold_dataset(name, json_path, image_root):
    """Register a COCO dataset, replacing any earlier registration under the same name.

    Detectron2 refuses to register a name twice, so without this the cell could not be
    re-run in the same kernel.
    """
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)
    register_coco_instances(name, {}, json_path, image_root)


# Cross validation
ap_scores = []
for fold, (train_idx, val_idx) in enumerate(kf.split(images)):
    print(f"Fold {fold + 1} / {K}")

    # split data (train and validation); the split is done per image
    train_images = [images[i] for i in train_idx]
    val_images = [images[i] for i in val_idx]
    train_coco = _subset_coco(train_images)
    val_coco = _subset_coco(val_images)

    # annotation paths
    train_coco_path = os.path.join(learning_dir, f"train_fold{fold}.json")
    val_coco_path = os.path.join(learning_dir, f"val_fold{fold}.json")

    with open(train_coco_path, "w") as f:
        json.dump(train_coco, f)
    with open(val_coco_path, "w") as f:
        json.dump(val_coco, f)

    # Detectron2
    _register_fold_dataset(f"org_train_{fold}", train_coco_path, path)
    _register_fold_dataset(f"org_val_{fold}", val_coco_path, path)
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.DATASETS.TRAIN = (f"org_train_{fold}",)
    cfg.DATASETS.TEST = (f"org_val_{fold}",)
    cfg.DATALOADER.NUM_WORKERS = 2
    cfg.SOLVER.IMS_PER_BATCH = 1
    cfg.SOLVER.BASE_LR = 0.0004
    cfg.SOLVER.MAX_ITER = 500
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    # NOTE(review): presumably matches the number of entries in coco_data["categories"] - confirm.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
    # Give every fold its own output directory so checkpoints, metrics and evaluation
    # files of one fold are not overwritten by the next one.
    cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR, f"fold{fold}")

    # train
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)  # always start from cfg.MODEL.WEIGHTS, never from a previous fold
    trainer.train()

    # evaluation
    evaluator = COCOEvaluator(f"org_val_{fold}", cfg, False, output_dir=cfg.OUTPUT_DIR)
    val_loader = build_detection_test_loader(cfg, f"org_val_{fold}")
    eval_results = inference_on_dataset(trainer.model, val_loader, evaluator)

    # AP
    ap = eval_results["bbox"]["AP"]      # IoU 50-95: mAP
    ap50 = eval_results["bbox"]["AP50"]  # IoU 50: AP
    ap75 = eval_results["bbox"]["AP75"]  # IoU 75: AP

    print(f"Fold {fold + 1}: AP={ap:.2f}, AP50={ap50:.2f}, AP75={ap75:.2f}")
    ap_scores.append((ap, ap50, ap75))

# results: average each metric over the K folds
mean_ap = np.mean([score[0] for score in ap_scores])
mean_ap50 = np.mean([score[1] for score in ap_scores])
mean_ap75 = np.mean([score[2] for score in ap_scores])

print(f"\nFinal Cross-validation Results:")
print(f"Mean AP: {mean_ap:.2f}")
print(f"Mean AP50: {mean_ap50:.2f}")
print(f"Mean AP75: {mean_ap75:.2f}")


Fold 1 / 5
[03/10 03:54:52 d2.engine.defaults]: Model:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )

roi_heads.box_predictor.bbox_pred.{bias, weight}
roi_heads.box_predictor.cls_score.{bias, weight}
roi_heads.mask_head.predictor.{bias, weight}


[03/10 03:54:52 d2.engine.train_loop]: Starting training from iteration 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[03/10 03:55:03 d2.utils.events]:  eta: 0:01:49  iter: 19  total_loss: 5.898  loss_cls: 1.096  loss_box_reg: 0.5258  loss_mask: 0.6853  loss_rpn_cls: 3.042  loss_rpn_loc: 0.4501    time: 0.3658  last_time: 0.7255  data_time: 0.1918  last_data_time: 0.4941   lr: 1.5585e-05  max_mem: 1304M
[03/10 03:55:12 d2.utils.events]:  eta: 0:01:51  iter: 39  total_loss: 2.741  loss_cls: 0.9235  loss_box_reg: 0.6741  loss_mask: 0.6399  loss_rpn_cls: 0.2083  loss_rpn_loc: 0.2843    time: 0.3398  last_time: 0.3009  data_time: 0.1219  last_data_time: 0.0949   lr: 3.1569e-05  max_mem: 1427M
[03/10 03:55:20 d2.utils.events]:  eta: 0:02:19  iter: 59  total_loss: 2.31  loss_cls: 0.703  loss_box_reg: 0.612  loss_mask: 0.5651  loss_rpn_cls: 0.129  loss_rpn_loc: 0.2197    time: 0.3544  last_time: 0.3787  data_time: 0.1785  last_data_time: 0.1656   lr: 4.7553e-05  max_mem: 1458M
[03/10 03:55:27 d2.utils.events]:  eta: 0:02:12  iter: 79  total_loss: 2.149  loss_cls: 0.572  loss_box_reg: 0.7768  loss_mask: 0.471

roi_heads.box_predictor.bbox_pred.{bias, weight}
roi_heads.box_predictor.cls_score.{bias, weight}
roi_heads.mask_head.predictor.{bias, weight}


[03/10 03:57:09 d2.engine.train_loop]: Starting training from iteration 0
[03/10 03:57:14 d2.utils.events]:  eta: 0:01:29  iter: 19  total_loss: 6.058  loss_cls: 1.028  loss_box_reg: 0.6013  loss_mask: 0.685  loss_rpn_cls: 3.519  loss_rpn_loc: 0.4249    time: 0.1907  last_time: 0.2197  data_time: 0.0188  last_data_time: 0.0023   lr: 1.5585e-05  max_mem: 4314M
[03/10 03:57:18 d2.utils.events]:  eta: 0:01:27  iter: 39  total_loss: 2.856  loss_cls: 0.8958  loss_box_reg: 0.6752  loss_mask: 0.636  loss_rpn_cls: 0.2428  loss_rpn_loc: 0.3226    time: 0.1988  last_time: 0.2204  data_time: 0.0025  last_data_time: 0.0023   lr: 3.1569e-05  max_mem: 4314M
[03/10 03:57:22 d2.utils.events]:  eta: 0:01:25  iter: 59  total_loss: 2.267  loss_cls: 0.6907  loss_box_reg: 0.675  loss_mask: 0.5587  loss_rpn_cls: 0.08727  loss_rpn_loc: 0.2409    time: 0.2003  last_time: 0.1981  data_time: 0.0025  last_data_time: 0.0024   lr: 4.7553e-05  max_mem: 4314M
[03/10 03:57:26 d2.utils.events]:  eta: 0:01:22  iter: 79

roi_heads.box_predictor.bbox_pred.{bias, weight}
roi_heads.box_predictor.cls_score.{bias, weight}
roi_heads.mask_head.predictor.{bias, weight}


[03/10 03:59:09 d2.engine.train_loop]: Starting training from iteration 0
[03/10 03:59:13 d2.utils.events]:  eta: 0:01:41  iter: 19  total_loss: 5.787  loss_cls: 1.031  loss_box_reg: 0.5272  loss_mask: 0.6921  loss_rpn_cls: 3.017  loss_rpn_loc: 0.3833    time: 0.2032  last_time: 0.1483  data_time: 0.0184  last_data_time: 0.0024   lr: 1.5585e-05  max_mem: 4314M
[03/10 03:59:17 d2.utils.events]:  eta: 0:01:29  iter: 39  total_loss: 2.761  loss_cls: 0.8781  loss_box_reg: 0.6924  loss_mask: 0.6446  loss_rpn_cls: 0.2999  loss_rpn_loc: 0.2575    time: 0.1972  last_time: 0.2081  data_time: 0.0025  last_data_time: 0.0024   lr: 3.1569e-05  max_mem: 4314M
[03/10 03:59:21 d2.utils.events]:  eta: 0:01:28  iter: 59  total_loss: 2.194  loss_cls: 0.6994  loss_box_reg: 0.6428  loss_mask: 0.5658  loss_rpn_cls: 0.1191  loss_rpn_loc: 0.215    time: 0.1986  last_time: 0.1766  data_time: 0.0026  last_data_time: 0.0027   lr: 4.7553e-05  max_mem: 4314M
[03/10 03:59:25 d2.utils.events]:  eta: 0:01:24  iter: 7

roi_heads.box_predictor.bbox_pred.{bias, weight}
roi_heads.box_predictor.cls_score.{bias, weight}
roi_heads.mask_head.predictor.{bias, weight}


[03/10 04:01:06 d2.engine.train_loop]: Starting training from iteration 0
[03/10 04:01:10 d2.utils.events]:  eta: 0:01:30  iter: 19  total_loss: 5.308  loss_cls: 0.8743  loss_box_reg: 0.481  loss_mask: 0.6841  loss_rpn_cls: 3.043  loss_rpn_loc: 0.3688    time: 0.1928  last_time: 0.1727  data_time: 0.0191  last_data_time: 0.0024   lr: 1.5585e-05  max_mem: 4314M
[03/10 04:01:14 d2.utils.events]:  eta: 0:01:29  iter: 39  total_loss: 2.778  loss_cls: 0.8157  loss_box_reg: 0.6799  loss_mask: 0.6386  loss_rpn_cls: 0.2907  loss_rpn_loc: 0.3193    time: 0.1952  last_time: 0.2166  data_time: 0.0025  last_data_time: 0.0026   lr: 3.1569e-05  max_mem: 4314M
[03/10 04:01:18 d2.utils.events]:  eta: 0:01:25  iter: 59  total_loss: 2.474  loss_cls: 0.6824  loss_box_reg: 0.6282  loss_mask: 0.5674  loss_rpn_cls: 0.1656  loss_rpn_loc: 0.3173    time: 0.1947  last_time: 0.2273  data_time: 0.0024  last_data_time: 0.0024   lr: 4.7553e-05  max_mem: 4314M
[03/10 04:01:22 d2.utils.events]:  eta: 0:01:21  iter: 

roi_heads.box_predictor.bbox_pred.{bias, weight}
roi_heads.box_predictor.cls_score.{bias, weight}
roi_heads.mask_head.predictor.{bias, weight}


[03/10 04:03:05 d2.engine.train_loop]: Starting training from iteration 0
[03/10 04:03:10 d2.utils.events]:  eta: 0:01:40  iter: 19  total_loss: 5.874  loss_cls: 0.8572  loss_box_reg: 0.5211  loss_mask: 0.687  loss_rpn_cls: 3.327  loss_rpn_loc: 0.4621    time: 0.2074  last_time: 0.2403  data_time: 0.0171  last_data_time: 0.0025   lr: 1.5585e-05  max_mem: 4326M
[03/10 04:03:14 d2.utils.events]:  eta: 0:01:36  iter: 39  total_loss: 2.715  loss_cls: 0.7883  loss_box_reg: 0.666  loss_mask: 0.6432  loss_rpn_cls: 0.3204  loss_rpn_loc: 0.2898    time: 0.2038  last_time: 0.2111  data_time: 0.0025  last_data_time: 0.0024   lr: 3.1569e-05  max_mem: 4326M
[03/10 04:03:18 d2.utils.events]:  eta: 0:01:31  iter: 59  total_loss: 2.292  loss_cls: 0.6694  loss_box_reg: 0.6206  loss_mask: 0.5694  loss_rpn_cls: 0.1092  loss_rpn_loc: 0.2244    time: 0.2036  last_time: 0.1755  data_time: 0.0025  last_data_time: 0.0028   lr: 4.7553e-05  max_mem: 4326M
[03/10 04:03:22 d2.utils.events]:  eta: 0:01:23  iter: 7