In [1]:
import cv2

from ditod import add_vit_config

import torch

from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Detectron2 YAML config for the DiT-base Mask R-CNN (PubLayNet) model.
config = "publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml"
# Key/value overrides merged on top of the YAML: point MODEL.WEIGHTS at the .pth weights URL.
opts = [
    'MODEL.WEIGHTS',
    'https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_mrcnn.pth',
]
# Path of the input image.
image = "img.png"

In [3]:
def predict(image, device="cpu"):
    """Run the DiT layout model on one image file and draw its predictions.

    The detectron2 config is built from the module-level ``config`` YAML path
    plus the ``opts`` overrides, a ``DefaultPredictor`` is created from it, and
    the image at ``image`` is run through the predictor.

    Args:
        image: Path of the image to read with ``cv2.imread``.
        device: Value stored in ``cfg.MODEL.DEVICE`` (default ``"cpu"``).

    Returns:
        Tuple ``(img, result_image, instances)``:
        ``img`` is the array returned by ``cv2.imread``; ``result_image`` is the
        visualization with predicted instances drawn, with channels flipped back
        to the same order as ``img``; ``instances`` are the predicted
        ``Instances`` moved to CPU.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``image``.
    """
    # Step 1: instantiate config
    cfg = get_cfg()
    add_vit_config(cfg)
    cfg.merge_from_file(config)

    # Step 2: add model weights URL to config
    cfg.merge_from_list(opts)

    # Step 3: set device
    cfg.MODEL.DEVICE = device

    # Step 4: read the image before building the model, so a bad path fails
    # fast instead of after the (slow) weight download/load.
    img = cv2.imread(image)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {image!r}")

    # Step 5: define model
    predictor = DefaultPredictor(cfg)

    # Step 6: set the class names the Visualizer uses for labels
    md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
    if cfg.DATASETS.TEST[0] == 'icdar2019_test':
        md.set(thing_classes=["table"])
    else:
        md.set(thing_classes=["text", "title", "list", "table", "figure"])

    # Step 7: run inference; move results to CPU once, used for drawing and returning
    output = predictor(img)["instances"].to("cpu")

    # The channel axis is reversed going into the Visualizer and reversed again
    # on the way out, so result_image has the same channel order as img
    # (OpenCV loads BGR; detectron2's Visualizer expects RGB).
    v = Visualizer(img[:, :, ::-1],
                   md,
                   scale=1.0,
                   instance_mode=ColorMode.SEGMENTATION)
    result = v.draw_instance_predictions(output)
    result_image = result.get_image()[:, :, ::-1]

    return img, result_image, output

In [5]:
# Use the `image` path from the config cell rather than repeating the literal.
img, result_img, output = predict(image)

  "See the documentation of nn.Upsample for details.".format(mode)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [6]:
# cv2.imwrite returns a bool success flag; check it instead of discarding it,
# so a failed write does not go unnoticed.
if not cv2.imwrite("out.jpg", result_img):
    raise OSError("cv2.imwrite failed to write out.jpg")

True

In [7]:
# Show the predictions in the order the predictor returned them (before sorting).
print(f"{output}")

Instances(num_instances=19, image_height=1638, image_width=2186, fields=[pred_boxes: Boxes(tensor([[ 317.9785,  388.1022,  774.4274, 1315.1630],
        [ 317.3665,  399.4422,  711.4623,  620.7697],
        [ 321.0168,  346.7388,  768.9996, 1359.7750],
        [ 325.9033,  400.3381,  728.4852,  614.8550],
        [ 323.7854,  761.2457,  777.2447,  928.8315],
        [ 323.5898,  760.9687,  777.3021,  930.0524],
        [ 324.2019,  486.9088,  630.0565,  548.1574],
        [ 312.9458,  394.8235,  771.6876, 1315.3929],
        [ 323.9550,  761.5464,  777.0757,  931.2034],
        [ 318.8689, 1088.4080,  765.8888, 1291.1981],
        [ 323.9459,  763.7854,  776.8008,  932.8073],
        [ 313.3763,  397.3802,  770.4563, 1324.1418],
        [ 325.0838, 1178.7592,  770.7908, 1318.9058],
        [ 316.9553,  480.3191,  633.3759,  628.1420],
        [ 322.5407,  762.5354,  777.0079,  934.0746],
        [ 316.3450,  394.0755,  768.6317, 1314.1211],
        [ 333.1350,  872.7140,  650.6919,  92

In [8]:
def sort_index(instance):
    """Return the permutation that orders detections by their box's second coordinate.

    For detectron2 boxes in (x0, y0, x1, y1) layout this is the top edge, so
    the result orders detections top-to-bottom. ``sorted`` is stable: boxes
    with equal keys keep their original relative order.

    Args:
        instance: object exposing ``pred_boxes.tensor`` as an (N, 4) tensor,
            e.g. detectron2 ``Instances``.

    Returns:
        list[int]: indices of the N detections in ascending key order
        (empty list when there are no detections).
    """
    # Pull the whole column out once. .tolist() works for CPU and CUDA tensors
    # and for tensors that require grad; per-element .numpy() accepts only
    # CPU tensors without grad.
    tops = instance.pred_boxes.tensor[:, 1].tolist()
    return sorted(range(len(tops)), key=tops.__getitem__)


In [9]:
def sort_tensor(tensor, sort_mask):
    """Gather entries of ``tensor`` along its first dimension.

    Args:
        tensor: indexable object (a torch tensor in this notebook).
        sort_mask: index sequence, e.g. the permutation from ``sort_index``.

    Returns:
        ``tensor[sort_mask]``; ``tensor`` itself is not modified.
    """
    gathered = tensor[sort_mask]
    return gathered

In [10]:
def sort(instance):
    """Reorder every per-detection field of ``instance`` with ``sort_index``, in place.

    The permutation from ``sort_index`` is applied to *all* fields registered
    on ``instance``, not a fixed list. This means predictions without
    ``pred_masks`` no longer raise ``AttributeError``, and extra fields
    (e.g. keypoints) stay aligned with their boxes.

    Args:
        instance: detectron2 ``Instances`` (needs ``get_fields()``,
            ``set(name, value)`` and ``pred_boxes``).

    Returns:
        The same ``instance`` object, now reordered (returned for chaining).
    """
    sort_mask = sort_index(instance)
    # Snapshot the items: we replace fields on the instance while iterating.
    for name, field in list(instance.get_fields().items()):
        instance.set(name, sort_tensor(field, sort_mask))
    return instance

In [11]:
# sort() reorders `output` in place and also returns it, so later cells
# see the sorted predictions through either name.
sorted_output = sort(output)
print(sorted_output)

Instances(num_instances=19, image_height=1638, image_width=2186, fields=[pred_boxes: Boxes(tensor([[ 321.0168,  346.7388,  768.9996, 1359.7750],
        [ 317.9785,  388.1022,  774.4274, 1315.1630],
        [ 316.3450,  394.0755,  768.6317, 1314.1211],
        [ 312.9458,  394.8235,  771.6876, 1315.3929],
        [ 313.3763,  397.3802,  770.4563, 1324.1418],
        [ 317.3665,  399.4422,  711.4623,  620.7697],
        [ 325.9033,  400.3381,  728.4852,  614.8550],
        [ 317.6023,  480.2979,  633.4324,  628.1290],
        [ 316.9553,  480.3191,  633.3759,  628.1420],
        [ 324.2019,  486.9088,  630.0565,  548.1574],
        [ 320.5102,  563.7816,  613.2209,  619.7298],
        [ 323.5898,  760.9687,  777.3021,  930.0524],
        [ 323.7854,  761.2457,  777.2447,  928.8315],
        [ 323.9550,  761.5464,  777.0757,  931.2034],
        [ 322.5407,  762.5354,  777.0079,  934.0746],
        [ 323.9459,  763.7854,  776.8008,  932.8073],
        [ 333.1350,  872.7140,  650.6919,  92