Step-by-step testing and commenting of the GCN model

In [8]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from utils import depad, depad_2d

# Define a simple model class
class SimpleModel(torch.nn.Module):
    """Two GCN layers over a fixed graph, followed by a per-edge classifier.

    The same ``edge_index`` is applied to every sample in a batch. After
    message passing, the features of the two endpoints of each edge are
    concatenated and classified, and the matching pair of predicted bounding
    boxes is returned next to the class scores.

    Args:
        edge_index: Integer tensor of shape ``(2, num_edges)``; row 0 holds
            the first endpoint of each edge, row 1 the second endpoint.
        num_nodes: Nominal number of nodes per sample. Stored for reference;
            ``forward`` derives the node count from its input instead.
        d_model: Size of the node feature vectors (input and hidden).
        num_classes: Number of edge classes to predict.
    """

    def __init__(self, edge_index, num_nodes, d_model, num_classes):
        super().__init__()
        # Kept as a plain attribute (not a registered buffer), so it does NOT
        # follow the module on ``.to(device)``; ``forward`` moves it on demand.
        self.edge_index = edge_index
        self.num_nodes = num_nodes

        # Define GCN layers
        self.conv1 = GCNConv(d_model, d_model)
        self.conv2 = GCNConv(d_model, d_model)

        # Define linear layers
        # lin1 consumes the concatenated features of both edge endpoints.
        self.lin1 = Linear(d_model * 2, d_model)
        self.lin_final = Linear(d_model, num_classes)

    def forward(self, hidden_state, pred_bboxes):
        """Classify every edge of every sample in the batch.

        Args:
            hidden_state: Node features of shape
                ``(batch_size, num_nodes, d_model)``; need not be contiguous.
            pred_bboxes: Per-node boxes indexed as ``[batch, node, ...]``
                (e.g. ``(batch_size, num_nodes, 4)``).

        Returns:
            A tuple ``(probs, bbox_pairs)``:

            * ``probs`` -- log-probabilities (``log_softmax`` output) of
              shape ``(batch_size, num_edges, num_classes)``.
            * ``bbox_pairs`` -- boxes of both endpoints of each edge,
              concatenated along the last dimension.
        """
        # TODO move the explanations of forward propagation from actual model to here

        batch_size, num_nodes, d_model = hidden_state.shape

        # Flatten all samples into one big node list. ``reshape`` (unlike
        # ``view``) also accepts non-contiguous inputs such as transposes.
        nodes_flat = hidden_state.reshape(-1, d_model)

        # Build the batched graph: one copy of the edges per sample, with the
        # node ids of sample ``b`` shifted by ``b * num_nodes``. The offset
        # uses the node count of the actual input so the ids always match the
        # flattened layout, and everything is created on the input's device.
        edge_index = self.edge_index.to(nodes_flat.device)
        num_edges = edge_index.size(1)
        batch_offsets = (
            torch.arange(batch_size, device=edge_index.device)
            .repeat_interleave(num_edges)
            * num_nodes
        )
        edge_index_batch = edge_index.repeat(1, batch_size) + batch_offsets

        x = self.conv1(nodes_flat, edge_index_batch)
        x = F.relu(x)
        x = self.conv2(x, edge_index_batch)
        x = F.relu(x)

        # Back to one row of nodes per sample.
        x = x.reshape(batch_size, num_nodes, d_model)

        # Gather both endpoints of every edge and classify the pair.
        x1 = x[:, edge_index[0]]
        x2 = x[:, edge_index[1]]
        xpair = torch.cat((x1, x2), dim=-1)
        xpair = F.relu(self.lin1(xpair))

        xfin = self.lin_final(xpair)
        probs = F.log_softmax(xfin, dim=-1)

        # The boxes may live on a different device than the features, so the
        # index is moved to wherever ``pred_bboxes`` is.
        bbox_index = self.edge_index.to(pred_bboxes.device)
        bbox_pairs = torch.cat((
            pred_bboxes[:, bbox_index[0]],
            pred_bboxes[:, bbox_index[1]]
        ), dim=-1)

        return probs, bbox_pairs

# Minimal test setup: a batch of 2 samples with 2 nodes each (4 bboxes overall)
batch_size, num_nodes = 2, 2
d_model, num_classes = 5, 3
# a single pair of nodes connected in both directions
edge_index = torch.tensor([[0, 1], [1, 0]])
model = SimpleModel(edge_index, num_nodes, d_model, num_classes)
hidden_state = torch.randn(batch_size, num_nodes, d_model)

# raw ground truth used for testing:
# a table made of two horizontally adjacent cells
gt_html = ["<thead>", "<td>", "</td>", "<td>", "</td>", "</thead>"]

# ground-truth bboxes, COCO style: (xmin, ymin, width, height)
# the four boxes look like this on a 10x10 grid
#   0 1 2 3 4 5 6 7 8 9
# 0 +-----+
# 1 |  1  |     +-----+
# 2 +-----+     |  2  |
# 3             |     |
# 4   + - +     +-----+
# 5   | 3 |
# 6   + - +
# 7     +-----------+
# 8     |     4     |
# 9     +-----------+
gt_bboxes = [
    [[0, 0, 3, 2], [6, 1, 3, 3]],
    [[1, 4, 2, 2], [2, 7, 6, 2]],
]
# for this test the "predictions" are simply the (still unpadded) gt boxes
pred_bboxes = torch.tensor(gt_bboxes)

# forward pass -> per-edge scores plus the bbox pair belonging to each edge
probs, bbox_pairs = model(hidden_state, pred_bboxes)
# at this point print(bbox_pairs) shows
# tensor([[[0, 0, 3, 2, 6, 1, 3, 3],
#          [6, 1, 3, 3, 0, 0, 3, 2]],
#
#         [[1, 4, 2, 2, 2, 7, 6, 2],
#          [2, 7, 6, 2, 1, 4, 2, 2]]])

# mimic the gt layout produced by data.py:
# every sample gets one extra all -1 padding box before becoming a float tensor
gt_bboxes = torch.tensor(
    [cells + [[-1, -1, -1, -1]] for cells in gt_bboxes],
    dtype=torch.float32,
)
table_grid = torch.tensor(
    [
        [[0, 1, -1], [-1, -1, -1]],
        [[1, 0, -1], [-1, -1, -1]],
    ],
    dtype=torch.float32,
)

# undo the padding again (judging by the printed output, depad / depad_2d
# drop the -1 entries and hand back nested lists)
gt_bboxes = depad(gt_bboxes)
print(gt_bboxes)
table_grid = depad_2d(table_grid)
print(table_grid)

# the predicted bbox pairs can now be matched against the gt bboxes
bbox_pairs = bbox_pairs.tolist()
for batch in bbox_pairs:
    for edge in batch:
        # each edge row is two 4-value boxes back to back:
        # first endpoint in [:4], second endpoint in [4:]
        # (the gt data structure this will be compared with is settled above)
        bbox1, bbox2 = edge[:4], edge[4:]
        print(bbox1, bbox2)

[[[0.0, 0.0, 3.0, 2.0], [6.0, 1.0, 3.0, 3.0]], [[1.0, 4.0, 2.0, 2.0], [2.0, 7.0, 6.0, 2.0]]]
[[[0.0, 1.0]], [[1.0, 0.0]]]
[0, 0, 3, 2] [6, 1, 3, 3]
[6, 1, 3, 3] [0, 0, 3, 2]
[1, 4, 2, 2] [2, 7, 6, 2]
[2, 7, 6, 2] [1, 4, 2, 2]
