In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import os
import json
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F
import torchvision.ops as ops  # For NMS
from utils import get_iou, state_dict_cleaner, detr2xyxy, coco2xyxy, xyxy2coco, get_psuedo_knn, process_target, get_table_grid

In [None]:
from model import GNLightning

torch.cuda.empty_cache()

gnet = GNLightning.load_from_checkpoint(
                                    checkpoint_path=r"..\checkpoints\gnet_stage3.ckpt",
                                    d_model=1024,
                                    lr=1e-3,
                                    batch_size=2,
                                    num_workers=0,
                                    train_path=r'..\misc\placeholder.jsonl',
                                    val_path=r'..\misc\placeholder.jsonl')

device = torch.device('cuda:0')
gnet.to(device)
gnet.eval()

c:\Users\remote desktop\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning_fabric\utilities\cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


GNLightning(
  (gnet): GraphNetwork(
    (conv1): GCNConv(256, 1024)
    (conv2): GCNConv(1024, 1024)
    (conv3): GCNConv(1024, 1024)
    (lin1): Linear(in_features=2048, out_features=1024, bias=True)
    (lin_final): Linear(in_features=1024, out_features=3, bias=True)
  )
  (criterion): NLLLoss()
)

In [3]:
# this DETR model is only used for forward pass with the weights we trained from 

from pytorch_lightning import LightningModule

class DetrLightning(LightningModule):
    def __init__(self, model_name, checkpoint, num_labels):
        super().__init__()
        
        self.model = DetrForObjectDetection.from_pretrained(model_name, num_labels=num_labels,ignore_mismatched_sizes=True)
        state_dict = torch.load(checkpoint)['state_dict']
        state_dict = state_dict_cleaner(state_dict)
        self.model.load_state_dict(state_dict)
    
    def forward(self, **inputs):
        return self.model(**inputs)

In [38]:
def evaluate(
        input_labels,  
        image_dir, 
        detr_model, 
        checkpoint,
        num_labels=64,
        iou_threshold=0.8, 
        confidence_threshold=0.5, 
        iou_match_threshold=0.5,
        device = "cuda"):
    
    eval_values = []

    # Set up the image processor and the DETR model
    image_processor = DetrImageProcessor.from_pretrained(detr_model)
    detr = DetrLightning(detr_model, checkpoint,num_labels)
    detr.to(device)
    detr.eval()

    # Count total lines in the input file for progress tracking
    with open(input_labels, 'r', encoding="utf8") as infile:
        total_lines = sum(1 for _ in infile)

    # Progress bar setup
    progress = tqdm(total=total_lines, desc="Processing Batches", unit="samples")

    with open(input_labels, 'r', encoding="utf8") as infile:

        for line in infile:
            # Parse the JSON line
            label = json.loads(line)

            # Update progress bar 
            progress.update(1)

            # === FORWARD PASS THROUGH THE DETR MODEL TO GET THE PRED BBOXES AND HIDDEN STATES ===

            image = Image.open(os.path.join(image_dir, label['filename'])).convert("RGB")
            inputs = image_processor(images=image, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Forward pass through the model
            outputs = detr(**inputs)

            logits = outputs.logits[0]  # Shape: (batch_size, num_queries, num_classes)
            pred_boxes = outputs.pred_boxes[0]  # Shape: (num_queries (100) , 4)
            image_width, image_height = image.size  # Get original image dimensions
            pred_boxes = torch.tensor(
                [detr2xyxy(box.cpu().tolist(), image_width, image_height) for box in pred_boxes],
                device=device
            )
            hidden_states = outputs.last_hidden_state[0]  # Shape: (batch_size, num_queries, hidden_dim)

            # as per the docs
            prob = F.softmax(logits, dim=-1)
            scores, _ = prob[..., :-1].max(-1)

            # Filter by confidence threshold
            keep = scores > confidence_threshold
            pred_boxes = pred_boxes[keep]
            hidden_states = hidden_states[keep]
            scores = scores[keep]

            # Apply NMS
            if pred_boxes.shape[0] > 0:
                keep_nms = ops.nms(pred_boxes, scores, iou_threshold)
                pred_boxes = pred_boxes[keep_nms]
                hidden_states = hidden_states[keep_nms]

            bbox_indices = []
            filtered_boxes = []
            filtered_hidden_states = []
            gt_bboxes_coco = label['gt_bboxes']
            gt_bboxes_xyxy = [coco2xyxy(box) for box in gt_bboxes_coco]
            gt_bbox_indices = torch.tensor(label['gt_bbox_indices'], dtype=torch.int) 

            for pred_box, hidden_state in zip(pred_boxes, hidden_states):

                pred_box = pred_box.detach().cpu().numpy().tolist()

                # Find the ground truth box with the highest IoU
                max_iou = 0
                best_gt_index = -1
                for i, gt_bbox in enumerate(gt_bboxes_xyxy):
                    iou = get_iou(pred_box, gt_bbox)
                    if iou > max_iou:
                        max_iou = iou
                        best_gt_index = i

                # Keep the DETR bbox if the highest IoU exceeds the threshold
                if max_iou >= iou_match_threshold:
                    filtered_boxes.append(pred_box)
                    filtered_hidden_states.append(hidden_state.detach().cpu().numpy().tolist())
                    bbox_indices.append(gt_bbox_indices[best_gt_index])

            filtered_boxes = torch.tensor([xyxy2coco(bbox) for bbox in filtered_boxes])
            filtered_hidden_states = torch.tensor(filtered_hidden_states)

            probs, edge_index = gnet(filtered_boxes, filtered_hidden_states)

            probs, edge_index = probs.to('cpu'), edge_index.to('cpu')

            # === GET THE GROUNTRUTH EDGE SET AS A DICT OF (START BBOX INDEX, END BBOX INDEX): CLASS ===

            thead_grid, tbody_grid = get_table_grid(''.join(label['html']))
            table_grid = thead_grid + tbody_grid

            gt_edge_index = get_psuedo_knn(torch.tensor(gt_bboxes_coco))

            gt_bbox_index_pairs = torch.stack((
                gt_bbox_indices[gt_edge_index[0]],  # Start bounding boxes
                gt_bbox_indices[gt_edge_index[1]]   # End bounding boxes
            ), dim=1)

            gt_classes = process_target(gt_bbox_index_pairs, table_grid) # dtype long

            gt_edgeset = {}

            # Iterate over the pairs and classes
            for pair, gt_class in zip(gt_bbox_index_pairs, gt_classes):
                start_bbox_index = pair[0].item()  # Convert to a standard Python integer
                end_bbox_index = pair[1].item()   # Convert to a standard Python integer
                gt_edgeset[(start_bbox_index, end_bbox_index)] = gt_class.item()  # Map to the class value 

            # === PARSE THROUGH EACH PRED EDGE AND CHECK IF IT MATCHES WITH THE GT EDGESET ===
            numerator = 0
            for i, (start, end) in enumerate(edge_index.t()):
                
                # Get the predicted class for the edge
                predicted_class = torch.argmax(probs[i]).item()

                # no relationship edge, skip
                if predicted_class == 0: continue

                if (bbox_indices[start].item(), bbox_indices[end].item()) in gt_edgeset and gt_edgeset[(bbox_indices[start].item(), bbox_indices[end].item())] == predicted_class:
                    numerator+=1

            denominator = sum(1 for value in gt_edgeset.values() if value != 0)

            eval_values.append(numerator/denominator)

            if progress.n == 50: break

        # Close the progress bar
        progress.close()

    return eval_values

In [None]:
eval_values = evaluate(
    input_labels=r'path to the output of running ptn2gcn with split=val',
    image_dir=r"path to pubtabnet val images directory",
    detr_model="facebook/detr-resnet-50",
    checkpoint=r"..\checkpoint\detr.ckpt",
)

Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DetrForObjectDetection were not initialized from the model checkpoin

Processing Batches:   0%|          | 0/9081 [00:00<?, ?samples/s]

In [40]:
print(sum(eval_values)/len(eval_values))

0.22622420393111212
