In [1]:
import cv2

from ditod import add_vit_config

import torch

from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = "icdar19_configs/maskrcnn/maskrcnn_dit_base.yaml"
opts = []
image = "img.png"

In [3]:
def predict(image):
    # Step 1: instantiate config
    cfg = get_cfg()
    add_vit_config(cfg)
    cfg.merge_from_file(config)

    # Step 2: add model weights URL to config
    cfg.merge_from_list(opts)

    # Step 3: set device
    device = "cpu"
    cfg.MODEL.DEVICE = device

    # Step 4: define model
    predictor = DefaultPredictor(cfg)
    # Step 5: run inference
    img = cv2.imread(image)

    md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
    if cfg.DATASETS.TEST[0]=='icdar2019_test':
        md.set(thing_classes=["table"])
    else:
        md.set(thing_classes=["text","title","list","table","figure"])
        
    output = predictor(img)["instances"]
        
    v = Visualizer(img[:, :, ::-1],
                md,
                scale=1.0,
                instance_mode=ColorMode.SEGMENTATION)
    result = v.draw_instance_predictions(output.to("cpu"))
    result_image = result.get_image()[:, :, ::-1]
    
    return img, result_image, output.to("cpu")

In [4]:
img, result_img, output = predict(image)

dit-base-224-p16-500k-62d53a.pth: 1.11GB [12:50, 1.44MB/s]                                                                                                                                                                             
Some model parameters or buffers are not found in the checkpoint:
[34mbackbone.bottom_up.backbone.blocks.0.attn.proj.{bias, weight}[0m
[34mbackbone.bottom_up.backbone.blocks.0.attn.qkv.weight[0m
[34mbackbone.bottom_up.backbone.blocks.0.attn.{q_bias, v_bias}[0m
[34mbackbone.bottom_up.backbone.blocks.0.mlp.fc1.{bias, weight}[0m
[34mbackbone.bottom_up.backbone.blocks.0.mlp.fc2.{bias, weight}[0m
[34mbackbone.bottom_up.backbone.blocks.0.norm1.{bias, weight}[0m
[34mbackbone.bottom_up.backbone.blocks.0.norm2.{bias, weight}[0m
[34mbackbone.bottom_up.backbone.blocks.0.{gamma_1, gamma_2}[0m
[34mbackbone.bottom_up.backbone.blocks.1.attn.proj.{bias, weight}[0m
[34mbackbone.bottom_up.backbone.blocks.1.attn.qkv.weight[0m
[34mbackbone.bottom_up.backbon

The checkpoint state_dict contains keys that are not used by the model:
  [35mcls_token[0m
  [35mmask_token[0m
  [35mpos_embed[0m
  [35mpatch_embed.proj.{bias, weight}[0m
  [35mblocks.0.{gamma_1, gamma_2}[0m
  [35mblocks.0.norm1.{bias, weight}[0m
  [35mblocks.0.attn.{q_bias, v_bias}[0m
  [35mblocks.0.attn.qkv.weight[0m
  [35mblocks.0.attn.proj.{bias, weight}[0m
  [35mblocks.0.norm2.{bias, weight}[0m
  [35mblocks.0.mlp.fc1.{bias, weight}[0m
  [35mblocks.0.mlp.fc2.{bias, weight}[0m
  [35mblocks.1.{gamma_1, gamma_2}[0m
  [35mblocks.1.norm1.{bias, weight}[0m
  [35mblocks.1.attn.{q_bias, v_bias}[0m
  [35mblocks.1.attn.qkv.weight[0m
  [35mblocks.1.attn.proj.{bias, weight}[0m
  [35mblocks.1.norm2.{bias, weight}[0m
  [35mblocks.1.mlp.fc1.{bias, weight}[0m
  [35mblocks.1.mlp.fc2.{bias, weight}[0m
  [35mblocks.2.{gamma_1, gamma_2}[0m
  [35mblocks.2.norm1.{bias, weight}[0m
  [35mblocks.2.attn.{q_bias, v_bias}[0m
  [35mblocks.2.attn.qkv.weight[0m
  

In [5]:
print(output)

Instances(num_instances=100, image_height=793, image_width=1104, fields=[pred_boxes: Boxes(tensor([[7.4608e+02, 7.3757e+02, 8.0659e+02, 7.9300e+02],
        [7.2532e+02, 2.9953e+02, 7.5786e+02, 3.3130e+02],
        [4.2741e+02, 1.3261e+02, 4.8939e+02, 1.9714e+02],
        [3.4244e+02, 1.2849e+02, 4.0282e+02, 1.9047e+02],
        [9.0393e+02, 3.8139e+02, 9.6609e+02, 4.4461e+02],
        [6.4300e+02, 7.4757e+02, 6.7536e+02, 7.7855e+02],
        [4.3586e+02, 7.3809e+02, 4.9816e+02, 7.9300e+02],
        [3.4857e+02, 2.0689e+02, 4.1098e+02, 2.7029e+02],
        [9.7498e+02, 3.8929e+02, 1.0364e+03, 4.5112e+02],
        [6.3454e+02, 2.9999e+02, 6.6781e+02, 3.3129e+02],
        [5.2234e+02, 6.6276e+02, 5.6376e+02, 6.8455e+02],
        [3.5133e+02, 2.1827e+02, 3.9160e+02, 2.4053e+02],
        [7.4872e+02, 7.0543e+02, 7.8796e+02, 7.2760e+02],
        [7.2091e+02, 3.8996e+02, 7.8299e+02, 4.5154e+02],
        [6.8965e+02, 3.8959e+02, 7.5160e+02, 4.5171e+02],
        [4.7707e+02, 7.4711e+02, 5.3883