Augment the COCO-format labels with DETR hidden states for stage 3 of the project.

input: output of ptn2gcn

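Runs DETR on each image to get the predicted boxes and their decoder hidden states, keeping only the predictions that match a ground-truth box.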

In [3]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import os
import json
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F
import torchvision.ops as ops  # For NMS
from utils import get_iou

In [None]:
def gcn2stage3(
        input_labels: str,
        output_labels: str,
        image_dir: str,
        detr_model: str,
        checkpoint: str,
        batch_size: int,
        iou_threshold: float = 0.5,
        confidence_threshold: float = 0.5,
        iou_match_threshold: float = 0.5) -> None:
    """Augment JSONL labels with DETR boxes and hidden states.

    Every non-blank line of ``input_labels`` is parsed as a JSON object that
    must provide ``'filename'``, ``'gt_bboxes'`` and ``'gt_bbox_indices'``.
    The image is run through DETR, the predictions are filtered (confidence,
    then NMS, then best-IoU match against the ground-truth boxes) and the
    record is written to ``output_labels`` with three extra keys:
    ``'bboxes'``, ``'hidden_states'`` and ``'bbox_indices'``.

    Args:
        input_labels: Path of the JSONL file to read.
        output_labels: Path of the JSONL file to write (overwritten).
        image_dir: Directory that ``label['filename']`` is joined onto.
        detr_model: Name or path handed to ``from_pretrained`` for both the
            image processor and the model.
        checkpoint: Path of a state dict loaded into the model.
        batch_size: Number of images per forward pass; must be >= 1.
        iou_threshold: IoU threshold used by NMS.
        confidence_threshold: Minimum class probability (excluding the last,
            "no object" class) for a query to be kept.
        iou_match_threshold: Minimum IoU with the best ground-truth box for a
            prediction to be kept.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail fast: a non-positive batch size could never trigger a flush.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # Set up the image processor and the DETR model
    image_processor = DetrImageProcessor.from_pretrained(detr_model)
    model = DetrForObjectDetection.from_pretrained(detr_model)
    # The model is never moved off the CPU, so map the checkpoint there; this
    # lets checkpoints saved from a GPU load on CPU-only machines.
    state_dict = torch.load(checkpoint, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # inference only

    # Count total lines in the input file for progress tracking
    with open(input_labels, 'r', encoding="utf8") as infile:
        total_lines = sum(1 for _ in infile)

    with open(input_labels, 'r', encoding="utf8") as infile, open(output_labels, 'w', encoding="utf8") as outfile:
        # The context manager closes the bar even if a batch raises.
        with tqdm(total=total_lines, desc="Processing Batches", unit="samples") as progress:
            labels = []

            for line in infile:
                progress.update(1)

                # Tolerate blank lines (e.g. a trailing newline at EOF)
                if not line.strip():
                    continue
                labels.append(json.loads(line))

                if len(labels) >= batch_size:
                    _augment_and_write_batch(
                        labels, outfile, image_dir, image_processor, model,
                        iou_threshold, confidence_threshold, iou_match_threshold)
                    labels = []

            # Flush the final, possibly smaller, batch so no record is dropped
            if labels:
                _augment_and_write_batch(
                    labels, outfile, image_dir, image_processor, model,
                    iou_threshold, confidence_threshold, iou_match_threshold)


def _augment_and_write_batch(labels, outfile, image_dir, image_processor, model,
                             iou_threshold, confidence_threshold, iou_match_threshold):
    """Run DETR on one batch of labels, augment them in place and write them out."""
    images = []
    for label in labels:
        # `convert` returns a new image, so the file handle can be released at once
        with Image.open(os.path.join(image_dir, label['filename'])) as img:
            images.append(img.convert("RGB"))
    inputs = image_processor(images=images, return_tensors="pt")

    # No gradients are needed; skipping autograd keeps memory use flat
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits  # Shape: (batch_size, num_queries, num_classes)
    pred_boxes = outputs.pred_boxes  # Shape: (batch_size, num_queries, 4)
    hidden_states = outputs.last_hidden_state  # Shape: (batch_size, num_queries, hidden_dim)

    for idx, label in enumerate(labels):
        boxes_img, hidden_states_img = _select_queries(
            logits[idx], pred_boxes[idx], hidden_states[idx],
            confidence_threshold, iou_threshold)
        _attach_matches(label, boxes_img, hidden_states_img, iou_match_threshold)

    outfile.writelines(json.dumps(label) + '\n' for label in labels)


def _select_queries(logits_img, boxes_img, hidden_states_img,
                    confidence_threshold, iou_threshold):
    """Return the boxes and hidden states of one image that survive confidence filtering and NMS."""
    # as per the docs: the last class is "no object" and is excluded from the score
    prob = F.softmax(logits_img, dim=-1)
    scores, _ = prob[..., :-1].max(-1)

    # Filter by confidence threshold
    keep = scores > confidence_threshold
    boxes_img = boxes_img[keep]
    hidden_states_img = hidden_states_img[keep]
    scores = scores[keep]

    # Apply NMS
    if boxes_img.shape[0] > 0:
        # DETR emits (center_x, center_y, width, height) while `ops.nms`
        # expects (x1, y1, x2, y2); convert for the overlap test only and keep
        # the stored boxes in their original format.
        corner_boxes = ops.box_convert(boxes_img, in_fmt="cxcywh", out_fmt="xyxy")
        keep_nms = ops.nms(corner_boxes, scores, iou_threshold)
        boxes_img = boxes_img[keep_nms]
        hidden_states_img = hidden_states_img[keep_nms]

    return boxes_img, hidden_states_img


def _attach_matches(label, boxes_img, hidden_states_img, iou_match_threshold):
    """Match predictions to ground-truth boxes and store the kept ones on ``label``."""
    bbox_indices = []
    filtered_boxes = []
    filtered_hidden_states = []
    gt_bboxes = label['gt_bboxes']
    gt_bbox_indices = label['gt_bbox_indices']

    for detr_bbox, hidden_state in zip(boxes_img, hidden_states_img):
        # NOTE(review): the box is passed on exactly as DETR emits it
        # (normalized center_x, center_y, width, height) -- confirm that
        # `gt_bboxes` and `get_iou` use the same convention.
        detr_bbox_list = detr_bbox.tolist()

        # Find the ground truth box with the highest IoU
        max_iou = 0
        best_gt_index = -1
        for i, gt_bbox in enumerate(gt_bboxes):
            iou = get_iou(detr_bbox_list, gt_bbox)
            if iou > max_iou:
                max_iou = iou
                best_gt_index = i

        # Keep the box only if it overlaps some ground-truth box; the index
        # guard avoids using gt_bbox_indices[-1] when the threshold is <= 0.
        if best_gt_index >= 0 and max_iou >= iou_match_threshold:
            filtered_boxes.append(detr_bbox_list)
            filtered_hidden_states.append(hidden_state.tolist())
            bbox_indices.append(gt_bbox_indices[best_gt_index])

    # Add fields to the label
    label['bboxes'] = filtered_boxes
    label['hidden_states'] = filtered_hidden_states
    label['bbox_indices'] = bbox_indices


In [None]:
# Build the stage-3 labels for the PubTabNet validation examples.
# NOTE(review): `image_dir` and `checkpoint` are empty strings here -- they
# look like placeholders to fill in before running this cell.
gcn2stage3(
    # Model: base DETR weights plus the checkpoint loaded on top of them
    detr_model="facebook/detr-resnet-50",
    checkpoint="",
    # Data: input labels, image folder, and where the augmented labels go
    input_labels=r"C:\Users\tangy\Downloads\DETR-GFTE\datasets\ptn_examples_val\PubTabNet_Examples-val.jsonl",
    image_dir=r"",
    output_labels=r"C:\Users\tangy\Downloads\DETR-GFTE\datasets\ptn_examples_val\ptn_examples_val.jsonl",
    # Images per forward pass
    batch_size=4,
)