In [None]:
import sys, os
# Put the parent directory first on the import path; presumably that is where the
# project-local `utils` and `model` modules imported below live — confirm.
sys.path.insert(0, os.path.abspath('..'))

import os  # NOTE(review): redundant — os is already imported on the first line
import json
from PIL import Image
from tqdm.notebook import tqdm
import torch
from utils import get_iou, coco2xyxy, xyxy2coco, get_psuedo_knn, process_target, get_table_grid

In [None]:
from model import GNLightning

# Release unoccupied cached GPU memory (useful when re-running this cell in the same kernel).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Paths are built with os.path.join so the notebook also runs outside Windows:
# backslash-separated literals such as r"..\checkpoints\x.ckpt" only resolve there.
CHECKPOINT_PATH = os.path.join("..", "checkpoints", "gnet_stage1.ckpt")
PTN2GCN_DIR = os.path.join("..", "datasets", "gcn", "stage3", "ptn2gcn")

# Extra keyword arguments are forwarded to load_from_checkpoint; presumably they
# supply/override the LightningModule's saved hyperparameters — confirm against GNLightning.
gnet = GNLightning.load_from_checkpoint(
    checkpoint_path=CHECKPOINT_PATH,
    d_model=128,
    lr=1e-3,
    batch_size=2,
    num_workers=0,
    train_path=os.path.join(PTN2GCN_DIR, "train.jsonl"),
    val_path=os.path.join(PTN2GCN_DIR, "val.jsonl"),
)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
gnet.to(device)
gnet.eval()  # switch to evaluation mode; the returned module is the cell's displayed output

GNLightning(
  (gnet): GraphNetwork(
    (conv1): GCNConv(8, 128)
    (conv2): GCNConv(128, 128)
    (lin1): Linear(in_features=256, out_features=128, bias=True)
    (lin_final): Linear(in_features=128, out_features=3, bias=True)
  )
  (criterion): NLLLoss()
)

In [28]:
def _match_boxes_to_gt(pred_bboxes_xyxy, gt_bboxes_xyxy, gt_bbox_indices, iou_match_threshold):
    """Keep predicted boxes whose best-IoU GT box reaches ``iou_match_threshold``.

    Returns:
        (matched_boxes, matched_indices): the kept boxes (xyxy) and, aligned with
        them, the GT bbox index (as a Python int) of each box's best match.
    """
    matched_boxes, matched_indices = [], []
    for pred_box in pred_bboxes_xyxy:
        # Find the GT box with the highest IoU (first one wins on ties).
        max_iou, best_gt_index = 0, -1
        for gt_index, gt_box in enumerate(gt_bboxes_xyxy):
            iou = get_iou(pred_box, gt_box)
            if iou > max_iou:
                max_iou, best_gt_index = iou, gt_index
        # best_gt_index stays -1 when no GT box overlaps at all; without this check a
        # threshold <= 0 would silently match index -1 (i.e. the last GT box).
        if best_gt_index >= 0 and max_iou >= iou_match_threshold:
            matched_boxes.append(pred_box)
            matched_indices.append(int(gt_bbox_indices[best_gt_index]))
    return matched_boxes, matched_indices


def _build_gt_edgeset(label, gt_bboxes_coco, gt_bbox_indices):
    """Map (start GT bbox index, end GT bbox index) -> GT edge class for one table.

    Candidate edges come from ``get_psuedo_knn`` over the GT boxes; their classes
    come from ``process_target`` applied to the table grid parsed from ``label['html']``.
    """
    thead_grid, tbody_grid = get_table_grid(''.join(label['html']))
    table_grid = thead_grid + tbody_grid

    gt_edge_index = get_psuedo_knn(torch.tensor(gt_bboxes_coco))
    gt_bbox_index_pairs = torch.stack((
        gt_bbox_indices[gt_edge_index[0]],  # start bounding boxes
        gt_bbox_indices[gt_edge_index[1]],  # end bounding boxes
    ), dim=1)

    gt_classes = process_target(gt_bbox_index_pairs, table_grid)  # dtype long
    return {
        (pair[0].item(), pair[1].item()): gt_class.item()
        for pair, gt_class in zip(gt_bbox_index_pairs, gt_classes)
    }


def _count_correct_edges(probs, edge_index, bbox_indices, gt_edgeset):
    """Count distinct predicted edges whose non-zero class equals the GT class.

    ``edge_index`` node ids index into ``bbox_indices`` (node -> GT bbox index).
    Class 0 means "no relationship" and is never counted.
    """
    # Assumes probs is (num_edges, num_classes), one row per edge_index column.
    predicted_classes = probs.argmax(dim=-1).tolist()
    correct_edges = set()  # a set, so each GT edge is counted at most once
    for (start, end), predicted_class in zip(edge_index.t().tolist(), predicted_classes):
        if predicted_class == 0:
            continue
        edge = (bbox_indices[start], bbox_indices[end])
        if gt_edgeset.get(edge) == predicted_class:
            correct_edges.add(edge)
    return len(correct_edges)


def _evaluate_sample(label, image_dir, iou_match_threshold, model, verbose):
    """Return the edge recall for one table, or None when it is undefined."""
    gt_bboxes_coco = label['gt_bboxes']
    if not gt_bboxes_coco:
        return None  # no GT boxes -> no GT edges -> recall undefined
    gt_bboxes_xyxy = [coco2xyxy(box) for box in gt_bboxes_coco]
    gt_bbox_indices = torch.tensor(label['gt_bbox_indices'], dtype=torch.int)

    # === DETECTOR FORWARD PASS (not implemented) ===
    # The context manager closes the underlying file; convert() returns a loaded copy.
    with Image.open(os.path.join(image_dir, label['filename'])) as raw_image:
        image = raw_image.convert("RGB")
    # Placeholder: GT boxes stand in for the detector's predicted boxes.
    pred_bboxes_xyxy = gt_bboxes_xyxy

    matched_boxes, bbox_indices = _match_boxes_to_gt(
        pred_bboxes_xyxy, gt_bboxes_xyxy, gt_bbox_indices, iou_match_threshold)

    # === GROUND-TRUTH EDGE SET: (START BBOX INDEX, END BBOX INDEX) -> CLASS ===
    gt_edgeset = _build_gt_edgeset(label, gt_bboxes_coco, gt_bbox_indices)
    denominator = sum(1 for gt_class in gt_edgeset.values() if gt_class != 0)

    # === COUNT PREDICTED EDGES THAT MATCH THE GT EDGE SET ===
    if matched_boxes:
        boxes_coco = torch.tensor([xyxy2coco(box) for box in matched_boxes])
        image_size = torch.tensor([image.size[0], image.size[1]])
        with torch.no_grad():  # inference only: don't build an autograd graph
            probs, edge_index = model(boxes_coco, image_size)
        numerator = _count_correct_edges(
            probs.to('cpu'), edge_index.to('cpu'), bbox_indices, gt_edgeset)
    else:
        numerator = 0  # no matched boxes -> no predicted edges

    if verbose:
        print('numerator, denominator:', numerator, denominator)
    if denominator == 0:
        return None  # avoid ZeroDivisionError; recall is undefined here
    return numerator / denominator


def evaluate(
        input_labels,
        image_dir,
        iou_match_threshold=0.5,
        device="cuda",
        max_samples=50,
        model=None,
        verbose=True):
    """Compute per-table edge recall of the graph network on a JSONL label file.

    For each table the score is

        (# distinct predicted non-zero edges whose class equals the GT class)
        / (# GT edges with a non-zero class).

    NOTE: the detector forward pass is not implemented; the GT boxes are used as
    stand-in predictions, so the score reflects the graph network given perfect boxes.

    Args:
        input_labels: Path to a JSONL file; each line is one table providing at least
            'filename', 'gt_bboxes' (COCO boxes), 'gt_bbox_indices' and 'html'.
        image_dir: Directory containing the images named by 'filename'.
        iou_match_threshold: Minimum IoU for a predicted box to match a GT box.
        device: Currently unused; kept for backward compatibility.
            NOTE(review): inputs are passed to the model on CPU — confirm the model
            moves them to its own device.
        max_samples: Stop after this many tables (previously hard-coded to 50);
            ``None`` evaluates the whole file.
        model: Graph network to evaluate; defaults to the notebook-global ``gnet``.
        verbose: Print each table's numerator and denominator.

    Returns:
        A list with one recall value per evaluated table. Tables without GT boxes or
        without non-zero GT edges are skipped, because recall is undefined for them.
    """
    if model is None:
        model = gnet  # notebook-global model loaded in an earlier cell

    # Count lines up front so the progress bar has a total.
    with open(input_labels, 'r', encoding="utf8") as infile:
        total_lines = sum(1 for _ in infile)
    total = total_lines if max_samples is None else min(total_lines, max_samples)

    eval_values = []
    # Context managers close the file and the progress bar even if a sample raises.
    with open(input_labels, 'r', encoding="utf8") as infile:
        with tqdm(total=total, desc="Processing Batches", unit="samples") as progress:
            for line in infile:
                if max_samples is not None and progress.n >= max_samples:
                    break
                if not line.strip():
                    continue  # tolerate blank lines such as a trailing newline
                label = json.loads(line)
                progress.update(1)

                recall = _evaluate_sample(label, image_dir, iou_match_threshold, model, verbose)
                if recall is not None:
                    eval_values.append(recall)

    return eval_values

In [None]:
# PubTabNet validation images. Set the PUBTABNET_VAL_DIR environment variable to point
# at your copy instead of editing the machine-specific default below.
PUBTABNET_VAL_DIR = os.environ.get(
    "PUBTABNET_VAL_DIR", r"C:\Users\remote desktop\Downloads\pubtabnet\val")

eval_values = evaluate(
    # os.path.join keeps the relative labels path portable across operating systems.
    input_labels=os.path.join("..", "datasets", "gcn", "stage3", "ptn2gcn", "val.jsonl"),
    image_dir=PUBTABNET_VAL_DIR,
    device=device,
)

Processing Batches:   0%|          | 0/9081 [00:00<?, ?samples/s]

numerator, denominator: 24 52
numerator, denominator: 44 104
numerator, denominator: 20 20
numerator, denominator: 28 44
numerator, denominator: 74 112
numerator, denominator: 264 656
numerator, denominator: 12 30
numerator, denominator: 110 202
numerator, denominator: 46 80
numerator, denominator: 247 614
numerator, denominator: 98 162
numerator, denominator: 61 110
numerator, denominator: 56 104
numerator, denominator: 126 230
numerator, denominator: 90 134
numerator, denominator: 264 454
numerator, denominator: 263 650
numerator, denominator: 172 440
numerator, denominator: 405 666
numerator, denominator: 220 448
numerator, denominator: 71 162
numerator, denominator: 85 154
numerator, denominator: 43 80
numerator, denominator: 548 1072
numerator, denominator: 108 232
numerator, denominator: 88 146
numerator, denominator: 440 870
numerator, denominator: 82 168
numerator, denominator: 168 290
numerator, denominator: 32 82
numerator, denominator: 42 78
numerator, denominator: 138 362
n

In [31]:
# Mean per-table edge recall; NaN instead of ZeroDivisionError when no table was scored.
mean_edge_recall = sum(eval_values) / len(eval_values) if eval_values else float("nan")
mean_edge_recall

0.5288683135277303
