In [8]:
from mmdet.models.detectors.grounding_dino import GroundingDINO
from mmdet.models.dense_heads.grounding_dino_head import GroundingDINOHead
from mmdet.structures.det_data_sample import DetDataSample
from mmdet.apis.inference import init_detector, inference_detector

from pathlib import Path
import os

In [11]:
# Scratch directory for this notebook's artifacts, placed under the current
# working directory. Created on the first run and reused on later runs.
workdir = Path.cwd() / "workdir"
workdir.mkdir(exist_ok=True)

workdir

PosixPath('/Users/shreyas/Developer/Research/GroundingDINO/agent/workdir')

In [6]:
import copy
import warnings
from pathlib import Path
from typing import Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.ops import RoIPool
from mmcv.transforms import Compose
from mmengine.config import Config
from mmengine.dataset import default_collate
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint

from mmdet.registry import DATASETS
from mmdet.utils import ConfigType
from mmdet.evaluation import get_classes
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample, SampleList
from mmdet.utils import get_test_pipeline_cfg

# Image inputs accepted by inference_tensors: a single image path or loaded
# array, or a sequence of paths, or a sequence of arrays.
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]

def inference_tensors(
    model: GroundingDINO,
    imgs: ImagesType,
    test_pipeline: Optional[Compose] = None,
    text_prompt: Optional[str] = None,
    custom_entities: bool = False,
) -> Union[DetDataSample, SampleList]:
    """Run the detector on image(s) by calling its preprocessor and ``predict``.

    Each image is pushed through the test pipeline, then through
    ``model.data_preprocessor`` and ``model.predict``, one image at a time.
    No ``torch.no_grad()`` context is entered here. Whether autograd records
    the forward pass therefore depends on the caller's grad mode.

    Args:
        model (GroundingDINO): The loaded detector. Its ``cfg`` attribute is
            used to build the test pipeline when ``test_pipeline`` is None.
        imgs (str, ndarray, Sequence[str/ndarray]):
           Either image files or loaded images. When ``test_pipeline`` is
           None, the loader is chosen from the type of the FIRST image, so a
           batch should not mix paths and arrays.
        test_pipeline (:obj:`Compose`, optional): Test pipeline. Built from a
            copy of ``model.cfg`` when None; ``model.cfg`` is not modified.
        text_prompt (str, optional): Stored in each sample as ``'text'``.
            Skipped when None or empty.
        custom_entities (bool): Stored as ``'custom_entities'``. Only
            attached when ``text_prompt`` is set.

    Returns:
        :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
        If imgs is a list or tuple, a list of the same length is returned
        (an empty list for an empty batch); otherwise the single result.
    """
    if isinstance(imgs, (list, tuple)):
        is_batch = True
        if len(imgs) == 0:
            # Nothing to run; also avoids indexing imgs[0] below.
            return []
    else:
        imgs = [imgs]
        is_batch = False

    if test_pipeline is None:
        # Deep copy: Config.copy() is shallow, so editing the pipeline entry
        # below would otherwise rewrite model.cfg and affect later calls.
        cfg = copy.deepcopy(model.cfg)
        test_pipeline = get_test_pipeline_cfg(cfg)
        if isinstance(imgs[0], np.ndarray):
            # Calling this method across libraries will result
            # in module unregistered error if not prefixed with mmdet.
            test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'

        test_pipeline = Compose(test_pipeline)

    if model.data_preprocessor.device.type == 'cpu':
        for module in model.modules():
            assert not isinstance(
                module, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    result_list = []
    for img in imgs:
        # Arrays are passed in-memory; anything else is treated as a path.
        if isinstance(img, np.ndarray):
            # TODO: remove img_id.
            data_ = dict(img=img, img_id=0)
        else:
            # TODO: remove img_id.
            data_ = dict(img_path=img, img_id=0)

        if text_prompt:
            data_['text'] = text_prompt
            data_['custom_entities'] = custom_entities

        data_ = test_pipeline(data_)

        # Wrap the single sample as a batch of one.
        data_['inputs'] = [data_['inputs']]
        data_['data_samples'] = [data_['data_samples']]

        # Second positional argument is the preprocessor's training flag.
        data = model.data_preprocessor(data_, False)
        # Batch of one, so keep only the first (and only) prediction.
        results = model.predict(data["inputs"], data["data_samples"])[0]

        result_list.append(results)

    return result_list if is_batch else result_list[0]

In [7]:
# Config and checkpoint locations, relative to the notebook's working directory.
CONFIG_PATH = "../libraries/mmdet-configs/Swin-B/rrsc-grounding_dino_swin-b_finetune_16xb2_1x_coco.py"
CHECKPOINT_PATH = "../checkpoints/swin_b_epoch_16.pth"

# Prefer CUDA, then Apple-silicon MPS, then CPU, so the notebook is not tied
# to a machine with an MPS-capable GPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

m: GroundingDINO = init_detector(
    config=CONFIG_PATH,
    checkpoint=CHECKPOINT_PATH,
    device=DEVICE,
)

# The box/classification head, kept in its own name for direct inspection.
bbox_head: GroundingDINOHead = m.bbox_head

Loads checkpoint by local backend from path: ../checkpoints/swin_b_epoch_16.pth
