In [1]:
import argparse
import json
import logging
import os
import time

import torch

import detectron2.model_zoo
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader
from detectron2.engine import DefaultPredictor
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.modeling import build_model

# Show INFO-level log messages from all loggers (including detectron2's).
logging.basicConfig(level=logging.INFO)

# Short model name -> detectron2 model-zoo config file path.
# Iteration order follows the listing below.
# Keys such as "R50_FPN" are what get_model_summary() expects.
MODEL_ZOO_CONFIGS = dict([
    # Faster R-CNN variants
    ("R50_C4", "COCO-Detection/faster_rcnn_R_50_C4_3x.yaml"),
    ("R50_DC5", "COCO-Detection/faster_rcnn_R_50_DC5_3x.yaml"),
    ("R50_FPN", "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"),
    ("R101_C4", "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml"),
    ("R101_DC5", "COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml"),
    ("R101_FPN", "COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml"),
    ("X101", "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"),
    # RetinaNet variants -- careful: the bare "R50"/"R101" keys are RetinaNet, not Faster R-CNN
    ("R50", "COCO-Detection/retinanet_R_50_FPN_3x.yaml"),
    ("R101", "COCO-Detection/retinanet_R_101_FPN_3x.yaml"),
])

In [2]:
def get_model_summary(model, load_weights=False):
    """Return the printed module tree of a detectron2 model-zoo architecture.

    Parameters
    ----------
    model : str
        Key into ``MODEL_ZOO_CONFIGS`` (e.g. ``"R50_FPN"``).
    load_weights : bool, optional
        False (default): build the network with ``build_model`` only. No
        checkpoint is loaded, so nothing needs to be downloaded.
        True: go through ``DefaultPredictor``, which also loads the checkpoint
        named by ``cfg.MODEL.WEIGHTS`` as set by the config file. This may
        trigger a download.

    Returns
    -------
    str
        ``str()`` of the built ``torch.nn.Module``.

    Raises
    ------
    KeyError
        If ``model`` is not a key of ``MODEL_ZOO_CONFIGS``.
    """
    config_path = MODEL_ZOO_CONFIGS[model]
    cfg = get_cfg()
    cfg.merge_from_file(detectron2.model_zoo.get_config_file(config_path))
    cfg.MODEL.DEVICE = 'cpu'  # architecture inspection only; no GPU needed

    if load_weights:
        # Only consulted by ROI-head models; RetinaNet configs use
        # MODEL.RETINANET.SCORE_THRESH_TEST instead.
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
        # To load the trained model-zoo checkpoint instead of the config's default, set:
        # cfg.MODEL.WEIGHTS = detectron2.model_zoo.get_checkpoint_url(config_path)
        net = DefaultPredictor(cfg).model
    else:
        net = build_model(cfg)

    # `net` is kept distinct from the `model` argument (a string key) to avoid rebinding it.
    return str(net)


In [3]:
# Create the output directory for the architecture dumps; no error if it already exists.
# Pure-Python replacement for `!mkdir -p`, which requires a POSIX shell.
os.makedirs('model_dumps', exist_ok=True)

In [4]:
# Write each model's printed architecture to model_dumps/<key>.txt, echoing the key as progress.
for model in MODEL_ZOO_CONFIGS.keys():
    print(model)
    with open(f'model_dumps/{model}.txt', 'w') as dump_file:
        # print() appends the trailing newline, matching write(summary + '\n').
        print(get_model_summary(model), file=dump_file)

R50_C4
R50_DC5
R50_FPN
R101_C4
R101_DC5
R101_FPN
X101




R50




R101


In [5]:
# import torchsummary
# Note: torchsummary did not work here. Its summary() requires an explicit
# input size, and it is unclear which input shape to use for these detection models.