In [1]:
import os
import argparse
import json
from datetime import datetime
import sys
import random
import numpy as np
import torch
from torch.utils.data import DataLoader

import sys
sys.path.append("../detr")
from engine import evaluate, train_one_epoch
from models import build_model
import util.misc as utils
import datasets.transforms as R
from models import build_model

In [2]:
from table_datasets import PDFTablesDataset, TightAnnotationCrop, RandomPercentageCrop, RandomErasingWithTarget, ToPILImageWithTarget, RandomMaxResize, RandomCrop
from eval import eval_coco, eval_tsr

In [3]:
from main import get_class_map, get_transform, get_data, get_model

In [4]:
from PIL import Image, ImageDraw

In [5]:
from tqdm import tqdm

## Detection Visualization

In [6]:
# model_load_path = "../detection_train_output/20220520164328/model_20.pth"
model_load_path = "../pubtables1m_detection_detr_r18.pth"
data_type = "detection"


In [7]:
config_args = json.load(open(f"{data_type}_config.json", 'rb'))
args = type('Args', (object,), config_args)
args.model_load_path = model_load_path
args.data_type = data_type
args.mode = "eval"
args.data_root_dir = f"/home/shiki/hdd/WikiTableExtraction/{data_type}"
print(args.__dict__)
print('-' * 100)

# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

print("loading model")
device = torch.device(args.device)
model, criterion, postprocessors = get_model(args, device)

data_loader_test, dataset_test = get_data(args)
model.eval()
criterion.eval()

{'lr': 5e-05, 'lr_backbone': 1e-05, 'batch_size': 2, 'weight_decay': 0.0001, 'epochs': 20, 'lr_drop': 1, 'lr_gamma': 0.9, 'clip_max_norm': 0.1, 'backbone': 'resnet18', 'num_classes': 2, 'dilation': False, 'position_embedding': 'sine', 'emphasized_weights': {}, 'enc_layers': 6, 'dec_layers': 6, 'dim_feedforward': 2048, 'hidden_dim': 256, 'dropout': 0.1, 'nheads': 8, 'num_queries': 15, 'pre_norm': True, 'masks': False, 'aux_loss': False, 'mask_loss_coef': 1, 'dice_loss_coef': 1, 'ce_loss_coef': 1, 'bbox_loss_coef': 5, 'giou_loss_coef': 2, 'eos_coef': 0.4, 'set_cost_class': 1, 'set_cost_bbox': 5, 'set_cost_giou': 2, 'device': 'cuda', 'seed': 42, 'start_epoch': 0, 'num_workers': 1, '__module__': '__main__', '__dict__': <attribute '__dict__' of 'Args' objects>, '__weakref__': <attribute '__weakref__' of 'Args' objects>, '__doc__': None, 'model_load_path': '../pubtables1m_detection_detr_r18.pth', 'data_type': 'detection', 'mode': 'eval', 'data_root_dir': '/home/shiki/hdd/WikiTableExtraction/

SetCriterion(
  (matcher): HungarianMatcher()
)

In [10]:
OUT_PATH = "../detection_train_output/pretrained_vis"

In [11]:
@torch.no_grad()
def save_detection(model, postprocessors, base_ds, device, thresh=0.5):
    idx = np.random.choice(range(len(base_ds)))
    for idx in tqdm(range(len(base_ds))):
        page_id = base_ds.page_ids[idx]
        img_path = os.path.join(base_ds.root, "..", "images", page_id + base_ds.image_extension)
        annot_path = os.path.join(base_ds.root, page_id + ".xml")
        img = Image.open(img_path).convert("RGB")
        draw = ImageDraw.Draw(img)
        samples = [base_ds[idx][0].to(device)]
        targets = [{k: v.to(device) for k, v in base_ds[idx][1].items()}]
        outputs = model(samples)

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)

        for score, bbox in zip(results[0]['scores'], results[0]['boxes']):
            if score.cpu().item() < thresh:
                continue
            bbox = bbox.cpu().numpy().flatten()
            draw.rectangle(bbox, outline='red', width=2)
#         for bbox in targets[0]['boxes']:
#             bbox = bbox.cpu().numpy().flatten()
#             print(bbox)
#             draw.rectangle(bbox, outline='green', width=2)
#         print(page_id)
        img.save(f"{OUT_PATH}/{page_id}.png")

In [12]:
save_detection(model, postprocessors, dataset_test, device)

100%|██████████| 3809/3809 [28:31<00:00,  2.23it/s]  


## Structure

In [22]:
model_load_path = "../structure_train_output/20220521084016/model_20.pth"
# model_load_path = "../pubtables1m_structure_detr_r18.pth"
data_type = "structure"
config_args = json.load(open(f"{data_type}_config.json", 'rb'))
args = type('Args', (object,), config_args)
args.model_load_path = model_load_path
args.data_type = data_type
args.mode = "eval"
args.data_root_dir = f"/home/shiki/hdd/WikiTableExtraction/{data_type}"
print(args.__dict__)
print('-' * 100)

# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

print("loading model")
device = torch.device(args.device)
model, criterion, postprocessors = get_model(args, device)

data_loader_test, dataset_test = get_data(args)
model.eval()
criterion.eval()

{'lr': 5e-05, 'lr_backbone': 1e-05, 'batch_size': 2, 'weight_decay': 0.0001, 'epochs': 20, 'lr_drop': 1, 'lr_gamma': 0.9, 'clip_max_norm': 0.1, 'backbone': 'resnet18', 'num_classes': 6, 'dilation': False, 'position_embedding': 'sine', 'emphasized_weights': {}, 'enc_layers': 6, 'dec_layers': 6, 'dim_feedforward': 2048, 'hidden_dim': 256, 'dropout': 0.1, 'nheads': 8, 'num_queries': 125, 'pre_norm': True, 'masks': False, 'aux_loss': False, 'mask_loss_coef': 1, 'dice_loss_coef': 1, 'ce_loss_coef': 1, 'bbox_loss_coef': 5, 'giou_loss_coef': 2, 'eos_coef': 0.4, 'set_cost_class': 1, 'set_cost_bbox': 5, 'set_cost_giou': 2, 'device': 'cuda', 'seed': 42, 'start_epoch': 0, 'num_workers': 1, '__module__': '__main__', '__dict__': <attribute '__dict__' of 'Args' objects>, '__weakref__': <attribute '__weakref__' of 'Args' objects>, '__doc__': None, 'model_load_path': '../structure_train_output/20220521084016/model_20.pth', 'data_type': 'structure', 'mode': 'eval', 'data_root_dir': '/home/shiki/hdd/Wik

SetCriterion(
  (matcher): HungarianMatcher()
)

In [16]:
# ! mkdir ../structure_train_output/pretrained_vis

mkdir: cannot create directory ‘../structure_train_output/pretrained_vis’: File exists


In [33]:
class_color = {
    0: 'red', 
    1: 'blue', 
    2: 'green', 
    3: 'purple', 
    4: 'gray',
    5: 'orange'
}

OUT_PATH = "/home/shiki/hdd/TableRecOut/detr/trained/structure"

@torch.no_grad()
def random_show_structure(model, postprocessors, base_ds, device, thresh=0.9):
    # idx = np.random.choice(range(len(base_ds)))
    for idx in tqdm(range(len(base_ds))):
        page_id = base_ds.page_ids[idx]
        img_path = os.path.join(base_ds.root, "..", "images", page_id + base_ds.image_extension)
        annot_path = os.path.join(base_ds.root, page_id + ".xml")
        img = Image.open(img_path).convert("RGB")
        draw = ImageDraw.Draw(img)
        samples = [base_ds[idx][0].to(device)]
        targets = [{k: v.to(device) for k, v in base_ds[idx][1].items()}]
        outputs = model(samples)

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        for i in range(len(results[0]['boxes'])):
            if results[0]['scores'][i] < thresh:
                continue

            bbox = results[0]['boxes'][i].cpu().numpy().flatten()
            color = class_color[results[0]['labels'][i].cpu().item()]
            draw.rectangle(bbox, outline=color, width=2)
        img.save(f"{OUT_PATH}/{page_id}.png")

In [34]:
random_show_structure(model, postprocessors, dataset_test, device)

100%|██████████| 10272/10272 [24:38<00:00,  6.95it/s] 
