<a href="https://colab.research.google.com/github/theviderlab/computer_vision/blob/main/Faster_R_CNN_con_Detectron2_Backbone_CVNet_Entrenamiento.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 1. Instalación de dependencias
!pip install -q torch torchvision torchaudio
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

from google.colab import drive
import json
import sys

In [None]:
# 2. Path definitions and CVNet import
import matplotlib.pyplot as plt
import torch
import os
from detectron2.layers import ShapeSpec
from detectron2.modeling import BACKBONE_REGISTRY
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.backbone.fpn import FPN, LastLevelMaxPool
from detectron2.config import get_cfg
from detectron2.model_zoo import get_config_file
from detectron2.engine import DefaultTrainer
from detectron2.data import DatasetCatalog, MetadataCatalog

# Mount Google Drive BEFORE importing CVNet: the cvnet package is loaded from
# a folder inside Drive, so on a fresh runtime the import below fails if the
# mount has not happened yet.
drive.mount('/content/drive', force_remount=True)

# Project root inside Drive (ends with "/").
base_path = "/content/drive/MyDrive/ViderLab/06 - INFORMACIÓN/Capacitación/Master/TFM/Código/"
cvnet_path = os.path.join(base_path, "image_retrieval/backbones/")
# Guard against piling up duplicate entries when the cell is re-run.
if cvnet_path not in sys.path:
    sys.path.append(cvnet_path)

from cvnet.cvnet_model import CVNet_Rerank

# Checkpoint and dataset paths
pretrained_weights_path = os.path.join(base_path, "assets/weights/CVPR2022_CVNet_R50.pyth")
dataset_path = os.path.join(base_path, "assets/database/open-images")
train_json_path = os.path.join(dataset_path, "annotations_touristic_train.json")
val_json_path = os.path.join(dataset_path, "annotations_touristic_val.json")
# base_path already ends with "/", so the relative part carries no leading
# slash (the previous concatenation produced ".../Código//assets/...").
output_dir = os.path.join(base_path, "assets/weights/faster_r-cnn_cvnet_finetuned_openimages")

In [None]:
# 3. Load the annotations and define the class list
def _read_json(path):
    """Return the parsed content of the JSON file at ``path``."""
    with open(path, 'r') as handle:
        return json.load(handle)


train_data = _read_json(train_json_path)
val_data = _read_json(val_json_path)

# Category names, passed to Detectron2 as ``thing_classes`` metadata.
thing_classes = [
    "Building",
    "Castle",
    "Fountain",
    "Lighthouse",
    "Sculpture",
    "Skyscraper",
    "Tower",
]

In [None]:
# 4. Register the datasets in Detectron2
def _register_split(name, loader):
    """Register ``loader`` under ``name`` and attach the class names.

    Any previous registration of ``name`` (dataset and metadata) is removed
    first, so this cell can be re-run in the same kernel without hitting
    Detectron2's "already registered" assertion.
    """
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)
    DatasetCatalog.register(name, loader)
    MetadataCatalog.get(name).set(thing_classes=thing_classes)


# The loaders return the annotation lists already held in memory.
_register_split("openimages_touristic_train", lambda: train_data)
_register_split("openimages_touristic_val", lambda: val_data)

In [None]:
# 5. CVNet wrapper usable as the bottom-up network of an FPN
class CVNetBottomUp(Backbone):
    """Expose a CVNet model through Detectron2's ``Backbone`` interface.

    Args:
        cvnet: model providing ``extract_backbone_stages(x)``.
            Presumably that method returns a dict keyed by "res2".."res5",
            matching ``output_shape`` — TODO confirm against the CVNet code.
    """

    def __init__(self, cvnet):
        super().__init__()
        self.cvnet = cvnet

    def forward(self, x):
        """Return whatever ``cvnet.extract_backbone_stages`` yields for ``x``."""
        return self.cvnet.extract_backbone_stages(x)

    def train(self, mode=True):
        """Switch train/eval mode, keeping a fully frozen CVNet in eval mode.

        ``nn.Module.train`` recurses into every submodule, so a trainer calling
        ``model.train()`` would otherwise undo an earlier ``cvnet.eval()``.
        When none of the CVNet parameters requires gradients the network is
        treated as frozen and left in eval mode, so any mode-dependent layers
        it contains (e.g. normalization or dropout) keep their inference
        behavior. A CVNet with trainable parameters follows ``mode`` as usual.

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super().train(mode)
        cvnet_params = list(self.cvnet.parameters())
        if cvnet_params and not any(p.requires_grad for p in cvnet_params):
            self.cvnet.eval()
        return self

    def output_shape(self):
        """Return the channels and stride declared for each exposed stage.

        The values are hard-coded (256/512/1024/2048 channels at strides
        4/8/16/32); they are not read from the wrapped model.
        """
        return {
            "res2": ShapeSpec(channels=256, stride=4),
            "res3": ShapeSpec(channels=512, stride=8),
            "res4": ShapeSpec(channels=1024, stride=16),
            "res5": ShapeSpec(channels=2048, stride=32),
        }

In [None]:
# 6. Register the CVNet+FPN backbone builder
@BACKBONE_REGISTRY.register()
def build_cvnet_fpn(cfg, input_shape):
    """Build an FPN whose bottom-up network is a frozen, pretrained CVNet.

    Args:
        cfg: Detectron2 config; ``MODEL.DEVICE`` and ``MODEL.FPN.*`` are read.
        input_shape: required by the backbone-builder signature; not used here.

    Returns:
        An ``FPN`` wrapping ``CVNetBottomUp`` with a ``LastLevelMaxPool`` top block.
    """
    import logging

    # Create CVNet_Rerank and load its pretrained weights.
    cvnet = CVNet_Rerank(RESNET_DEPTH=50, REDUCTION_DIM=2048)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(pretrained_weights_path, map_location='cpu')
    # Use the 'model_state' entry when present, otherwise the checkpoint itself.
    state = checkpoint.get('model_state', checkpoint)
    load_result = cvnet.load_state_dict(state, strict=False)
    # strict=False tolerates key mismatches; report them so that a checkpoint
    # that did not actually populate the network does not go unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        logging.getLogger(__name__).warning(
            "CVNet checkpoint loaded with strict=False: %d missing keys %s, "
            "%d unexpected keys %s",
            len(load_result.missing_keys), load_result.missing_keys,
            len(load_result.unexpected_keys), load_result.unexpected_keys,
        )
    # Follow the configured device instead of forcing CUDA, so the builder
    # also works when cfg.MODEL.DEVICE is "cpu".
    cvnet = cvnet.to(torch.device(cfg.MODEL.DEVICE))
    # Freeze CVNet: no gradients, eval mode.
    for p in cvnet.parameters():
        p.requires_grad = False
    cvnet.eval()
    # Build the FPN on top of CVNet.
    bottom_up = CVNetBottomUp(cvnet)
    return FPN(
        bottom_up=bottom_up,
        in_features=cfg.MODEL.FPN.IN_FEATURES,
        out_channels=cfg.MODEL.FPN.OUT_CHANNELS,
        top_block=LastLevelMaxPool(),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )

In [None]:
# 7. Detectron2 base configuration and model settings
base_config_yaml = get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
cfg = get_cfg()
cfg.merge_from_file(base_config_yaml)

# Model: start without COCO weights and use the registered CVNet+FPN builder.
model_cfg = cfg.MODEL
model_cfg.RESNETS.DEPTH = 50
model_cfg.WEIGHTS = ""
model_cfg.BACKBONE.NAME = "build_cvnet_fpn"

# FPN settings
fpn_cfg = model_cfg.FPN
fpn_cfg.IN_FEATURES = ["res2", "res3", "res4", "res5"]
fpn_cfg.OUT_CHANNELS = 256
fpn_cfg.FUSE_TYPE = "sum"

In [None]:
# 8. Training parameters
# Datasets
cfg.DATASETS.TRAIN = ("openimages_touristic_train",)
cfg.DATASETS.TEST  = ("openimages_touristic_val",)

# Data-loader workers and output directory
cfg.DATALOADER.NUM_WORKERS = 4
cfg.OUTPUT_DIR = output_dir

# Solver and learning rate. The batch size is defined once so that the LR
# scaling and the iteration count below cannot drift apart.
ims_per_batch = 8
cfg.SOLVER.IMS_PER_BATCH = ims_per_batch
cfg.SOLVER.BASE_LR = 5e-5 * (ims_per_batch / 2)  # 5e-5 scaled by batch / 2

# Scheduler: warm-up + MultiStepLR
cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"

# Multiplicative factor applied at each step
cfg.SOLVER.GAMMA = 0.1
# Warm-up parameters
cfg.SOLVER.WARMUP_FACTOR = 0.001    # initial LR = BASE_LR * WARMUP_FACTOR
cfg.SOLVER.WARMUP_ITERS = 1000      # warm-up iterations
cfg.SOLVER.WARMUP_METHOD = "linear" # linear from WARMUP_FACTOR up to 1

# Iteration budget: 5 epochs ≈ 5 * (160000 / batch).
# 160000 is taken as the approximate training-set size — TODO confirm.
num_train_images = 160000
num_epochs = 5
# Fixed: this line previously called an undefined `batch_size()`.
steps_per_epoch = num_train_images // cfg.SOLVER.IMS_PER_BATCH
cfg.SOLVER.MAX_ITER = steps_per_epoch * num_epochs  # ≈ 100k iterations
# Iterations at which the LR decay is applied (60% and 80% of the run)
cfg.SOLVER.STEPS = (int(cfg.SOLVER.MAX_ITER * 0.6), int(cfg.SOLVER.MAX_ITER * 0.8))

# ROI heads
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(thing_classes)

# Periodic evaluation
eval_period = 1000  # iterations
cfg.TEST.EVAL_PERIOD = eval_period

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

In [None]:
# 9. Build the trainer and train
# Fixed: the config object is named `cfg`; `gcfg` is never defined.
# NOTE(review): DefaultTrainer does not appear to provide an evaluator by
# default — confirm that the periodic evaluation requested through
# cfg.TEST.EVAL_PERIOD actually produces metrics.
trainer = DefaultTrainer(cfg)
# resume=False: do not resume from a checkpoint in the output directory.
trainer.resume_or_load(resume=False)
trainer.train()

In [None]:
# 10. Read and plot the training metrics
# Fixed: the config object is named `cfg`; `gcfg` is never defined.
metrics_path = os.path.join(cfg.OUTPUT_DIR, "metrics.json")
metrics = []
with open(metrics_path) as f:
    # One JSON record per line; skip blank lines and lines that literally
    # start with an iteration-0 record.
    for line in f:
        if line.strip() and not line.startswith("{\"iteration\": 0"):
            metrics.append(json.loads(line))
iterations = [m["iteration"] for m in metrics]
keys = ["total_loss", "loss_cls", "loss_box_reg", "loss_rpn_cls", "loss_rpn_loc"]
plt.figure(figsize=(12, 6))
for key in keys:
    # Plot only the records that contain this loss: records without it
    # would otherwise be plotted as None and break the curve.
    points = [(m["iteration"], m[key]) for m in metrics if key in m]
    if not points:
        continue
    xs, ys = zip(*points)
    plt.plot(xs, ys, label=key)
plt.xlabel("Iteración")
plt.ylabel("Pérdida")
plt.title("Curvas de entrenamiento")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()