In [None]:
# import os
# from accelerate.utils import write_basic_config

# write_basic_config()
# os._exit(00)

In [None]:
!pip install  dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html

In [None]:
import pickle
import os
from PIL import Image
import numpy as np
from torchvision import transforms
from tqdm import tqdm
def extract_component(pickle_file, image_path_root, save_dir):
    """Crop every annotated component out of its page image and save it as a PNG.

    Parameters
    ----------
    pickle_file : str
        Pickle whose top-level object maps a document key to a dict with a
        ``'components'`` list. Each component is read for ``'object_id'``,
        ``'page'`` and ``'bbox'``.
    image_path_root : str
        Prefix of the page images; a page is opened from
        ``{image_path_root}/{doc_key}_page-{page}.png``.
    save_dir : str
        Directory that receives one ``{object_id}.png`` per component.
        Created if it does not exist.

    Notes
    -----
    ``bbox`` is passed to the crop as ``left=bbox[0]``, ``top=bbox[1]``,
    ``width=bbox[2]``, ``height=bbox[3]``.
    Crops that already exist in ``save_dir`` are skipped, so an interrupted
    run can be resumed. A failure on one component is printed and skipped;
    it does not stop the run.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(pickle_file, 'rb') as file:
        data = pickle.load(file)
    for k in tqdm(data.keys()):
        components = data[k]['components']
        for comp in components:
            out_path = os.path.join(save_dir, f"{comp['object_id']}.png")
            if os.path.exists(out_path):
                continue
            try:
                page_path = f"{os.path.join(image_path_root, k)}_page-{comp['page']}.png"
                # The context manager closes the page file even if decoding fails.
                with Image.open(page_path) as page:
                    img = page.convert("RGB")
                bbox = comp['bbox']
                cropped_img = transforms.functional.crop(
                    img, top=bbox[1], left=bbox[0], height=bbox[3], width=bbox[2]
                )
                cropped_img.save(out_path)
            except Exception as e:
                # Best effort: report the failing component and continue.
                print(comp)
                print(e)
                print(k)

In [None]:
# Crops from both splits are written into the same "components" directory.
extract_component(pickle_file="val_data.pkl", image_path_root="val/val", save_dir="components")
extract_component(pickle_file="test_data.pkl", image_path_root="test/test", save_dir="components")

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision import transforms
import pickle
from transformers import AutoImageProcessor
import torch
class CompVisualDataset(Dataset):
    """Dataset of cropped component images listed in an annotation pickle.

    Components whose ``'bbox'`` equals ``[0.0, 0.0, 0.0, 0.0]`` are left out.
    Each item is ``(RGB PIL image, object_id)``; the image is read from
    ``{image_path_root}/{object_id}.png``.
    """

    def __init__(self, pickle_file,image_path_root):
        super().__init__()
        self.root_path = image_path_root
        with open(pickle_file, 'rb') as handle:
            annotations = pickle.load(handle)
        empty_box = [0.0, 0.0, 0.0, 0.0]
        # Flatten the components of every document, dropping all-zero boxes.
        self.components = [
            comp
            for doc in annotations.values()
            for comp in doc['components']
            if comp['bbox'] != empty_box
        ]

    def __len__(self):
        """Number of components kept after filtering."""
        return len(self.components)

    def __getitem__(self, index):
        """Return the crop of component ``index`` together with its object id."""
        entry = self.components[index]
        crop_path = os.path.join(self.root_path, f"{entry['object_id']}.png")
        # NOTE(review): raises if the crop file was never written for this component.
        return Image.open(crop_path).convert("RGB"), entry['object_id']

def collate_fn(batch):
    """Split a batch of ``(image, object_id)`` items into two parallel lists."""
    imgs, object_ids = [], []
    for item in batch:
        imgs.append(item[0])
        object_ids.append(item[1])
    return imgs, object_ids

In [None]:
# Both splits read their crops from the shared 'components' directory.
visual_val_dataset = CompVisualDataset(pickle_file='val_data.pkl', image_path_root='components')
visual_test_dataset = CompVisualDataset(pickle_file='test_data.pkl', image_path_root='components')

In [None]:
from transformers import AutoModel
class VisualEncoder(torch.nn.Module):
    """Wraps the 'facebook/dinov2-base' backbone and returns one vector per image.

    The vector is the first token of the backbone's sequence output
    concatenated with the mean of the remaining tokens.
    """

    def __init__(self,):
        super().__init__()
        # Attribute spelling kept as-is: renaming it would change state-dict keys.
        self.dinvov2 = AutoModel.from_pretrained('facebook/dinov2-base')
        # Tuple outputs, so forward() can index the result positionally.
        self.dinvov2.config.return_dict=False

    def forward(self, pixel_values):
        """Return ``cat(first token, mean of remaining tokens)`` along dim 1."""
        tokens = self.dinvov2(pixel_values)[0]
        first_token = tokens[:,0]
        mean_of_rest = tokens[:,1:].mean(dim=1)
        return torch.cat([first_token, mean_of_rest], dim=1)
# NOTE(review): `encoder` is not referenced by any later cell visible here, and the
# following cell builds a second VisualEncoder as `model`; confirm this load is needed.
encoder = VisualEncoder()

In [None]:
model = VisualEncoder()
# Prefer the GPU when present (to force CPU: device = torch.device("cpu")).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using: {device}")
model = model.to(device)
model.eval()
with torch.no_grad():
    # Trace with a single random 3x224x224 example input.
    example_input = torch.rand(1,3,224,224).to(device)
    traced_model = torch.jit.trace(model, example_input)

In [None]:
# Preprocessor loaded from the same 'facebook/dinov2-base' checkpoint as the encoder;
# extract_features() reads it as a module-level global.
image_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')

In [None]:
from tqdm import tqdm
import os
def extract_features(dataloader, feature_path):
    """Encode every image yielded by ``dataloader`` and save one tensor per object.

    Parameters
    ----------
    dataloader : iterable
        Yields ``(imgs, object_ids)`` batches, as produced by ``collate_fn``.
    feature_path : str
        Directory that receives ``{object_id}.pt`` files. Created if missing.

    Notes
    -----
    Reads the module-level globals ``image_processor``, ``model`` and
    ``device``. NOTE(review): a later cell rebinds ``model`` to the edge
    predictor, so this function must run before that cell or it will call
    the wrong model.
    """
    os.makedirs(feature_path, exist_ok=True)
    with torch.no_grad():
        for imgs, object_ids in tqdm(dataloader):
            image_inputs = image_processor(imgs, return_tensors="pt").to(device)
            # Move to CPU so the saved files load on machines without a GPU.
            features = model(image_inputs.pixel_values).cpu()
            for idx, obj_id in enumerate(object_ids):
                # features[idx] is a view of the batch tensor. torch.save stores
                # the whole underlying storage of a view, so clone the row to
                # write only this object's vector.
                torch.save(features[idx].clone(), os.path.join(feature_path, f"{obj_id}.pt"))

In [None]:
# Same loader settings for both splits; collate_fn keeps the PIL images in plain lists.
visual_val_dataloader = DataLoader(
    dataset=visual_val_dataset,
    batch_size=32,
    collate_fn=collate_fn,
    num_workers=4,
)
visual_test_dataloader = DataLoader(
    dataset=visual_test_dataset,
    batch_size=32,
    collate_fn=collate_fn,
    num_workers=4,
)

In [None]:
# Features of both splits are written to the same 'visual_features' directory.
extract_features(dataloader=visual_val_dataloader, feature_path='visual_features')
print("Extraction completed for val set!")
extract_features(dataloader=visual_test_dataloader, feature_path='visual_features')
print("Extraction completed for test set!")

In [None]:
import torch
import dgl
# Accumulator filled by generate_graphs(); reset each time this cell runs.
graphs = []

# Components with one of these categories are excluded from the graph nodes.
cat_no_rel = [
    'other',
    'report_title',
    'title',
    'table_of_contents',
    'cross',
    'list_of_tables',
    'appendix_list',
    'references',
    'list_of_figures',
]

def generate_graphs(pkl_file, graphs):
    """Append one fully connected, bidirectional graph per document to ``graphs``.

    Parameters
    ----------
    pkl_file : str
        Pickle mapping a document key to a dict with a ``'components'`` list.
        Each component is read for ``'object_id'`` and ``'category'``.
    graphs : list
        Accumulator list. It is extended in place and also returned.

    Returns
    -------
    list
        The same ``graphs`` list, with one graph appended per document.

    Notes
    -----
    Nodes are the sorted object ids of the components whose category is not
    in ``cat_no_rel``; node ``i`` carries ``ndata['obj_id'][i]``. For every
    unordered node pair both directions are added: all ``(u, v)`` edges
    first, then all ``(v, u)`` edges.
    """
    with open(pkl_file,"rb") as f:
        data = pickle.load(f)
    for doc in tqdm(data.keys()):
        components = data[doc]['components']
        nodes = sorted(comp['object_id'] for comp in components if comp['category'] not in cat_no_rel)
        pairs = torch.combinations(torch.arange(len(nodes)), r=2)  # all nC2 unordered pairs
        # Forward edges followed by reverse edges, matching the original edge order.
        src = torch.cat([pairs[:,0], pairs[:,1]])
        dst = torch.cat([pairs[:,1], pairs[:,0]])
        # dgl.graph is the supported constructor; dgl.DGLGraph() is deprecated.
        g = dgl.graph((src, dst), num_nodes=len(nodes))
        g.ndata['obj_id'] = torch.tensor(nodes)
        graphs.append(g)
    return graphs

# Validation graphs come first, then test graphs, all in one list.
graphs = generate_graphs(pkl_file="val_data.pkl", graphs=graphs)
graphs = generate_graphs(pkl_file="test_data.pkl", graphs=graphs)
dgl.save_graphs("graphs.bin", graphs)

In [None]:
from torch.utils.data import Dataset, DataLoader
import pickle
import dgl
import os
import torch
import numpy as np
class GraphDataset(Dataset):
    """Serves saved DGL graphs together with per-node feature vectors.

    Parameters
    ----------
    graph_file : str
        File read with ``dgl.load_graphs``.
    feature_dir : str
        Directory holding one ``{obj_id}.pt`` tensor per node.
    feat_dim : int, optional
        Length of the zero vector used for a node whose feature file cannot
        be loaded. Defaults to 1536.
    """

    def __init__(self, graph_file, feature_dir, feat_dim=1536):
        super().__init__()
        self.graphs,_ = dgl.load_graphs(graph_file)
        self.feature_dir = feature_dir
        self.feat_dim = feat_dim

    def __len__(self):
        """Number of graphs loaded from ``graph_file``."""
        return len(self.graphs)

    def load_feat(self,nodes):
        """Stack the feature vectors of ``nodes`` into a ``(len(nodes), feat_dim)`` tensor.

        A node without a feature file gets a zero row. Any other load failure
        also falls back to a zero row, but emits a warning so that corrupt
        files are not hidden. ``nodes`` must be non-empty.
        """
        import warnings

        tensors = []
        for node in nodes:
            feat_file = f"{self.feature_dir}/{node}.pt"
            try:
                # NOTE(review): weights_only=False unpickles arbitrary objects;
                # only point feature_dir at files this pipeline produced.
                feat = torch.load(feat_file,map_location=torch.device("cpu"),weights_only=False).unsqueeze(0)
            except FileNotFoundError:
                # Expected case: no feature was extracted for this object.
                feat = torch.zeros((1,self.feat_dim))
            except Exception as err:
                warnings.warn(f"Could not load {feat_file} ({err!r}); using a zero vector instead.")
                feat = torch.zeros((1,self.feat_dim))
            tensors.append(feat)
        return torch.cat(tensors,dim=0)

    def __getitem__(self, index):
        """Return ``(graph, features)``; ``features`` is ``None`` for a graph without nodes."""
        g = self.graphs[index]
        nodes = g.ndata['obj_id']
        feats = self.load_feat(nodes) if len(nodes) != 0 else None
        return g, feats

In [None]:
import torch
import torch.nn as nn
import dgl
import torch.nn.functional as F
class MLPPredictor(nn.Module):
    """Two-layer MLP that scores an edge from its two endpoint feature vectors."""

    def __init__(self, h_feats):
        super().__init__()
        # Input is the concatenation of source and destination features.
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        """
        Edge UDF: compute one scalar score per edge.

        Parameters
        ----------
        edges :
            Edge batch whose ``src`` and ``dst`` members are dictionaries
            holding the ``'h'`` features of the source and destination nodes.

        Returns
        -------
        dict
            ``{'score': tensor}`` with one entry per edge.
        """
        endpoint_feats = torch.cat([edges.src['h'], edges.dst['h']], 1)
        hidden = F.relu(self.W1(endpoint_feats))
        return {'score': self.W2(hidden).squeeze(1)}

    def forward(self, g, h):
        """Score every edge of ``g`` using node features ``h``.

        The features and scores are written inside ``g.local_scope()``.
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            edge_scores = g.edata['score']
        return edge_scores

In [None]:
# Pick the device first so the checkpoint can be mapped onto it while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this rebinds `model`, replacing the VisualEncoder bound earlier.
# extract_features() reads the global `model`, so do not rerun it after this cell.
model = MLPPredictor(1536)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load("predictor.pth", map_location=device, weights_only=True))
model.eval()
model.to(device)

In [None]:
dataset = GraphDataset(graph_file="graphs.bin", feature_dir="visual_features")
# batch_size=1 plus a collate_fn that unwraps the one-item batch:
# every iteration yields a single (graph, feats) pair.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=1,
    collate_fn=lambda batch: batch[0],
)

In [None]:
# Minimum edge score for an edge to be kept as a prediction.
# NOTE(review): MLPPredictor returns the raw output of its last linear layer (no sigmoid),
# so 0.5 is compared against unnormalised scores; confirm this matches how it was trained.
threshold = 0.5

In [None]:
# For each graph, collect (source obj_id, destination obj_id) for every edge whose
# score reaches `threshold`. Graphs with no nodes or no edges contribute [].
all_predicted_edges = []
with torch.no_grad():  # inference only
    for g, feats in tqdm(dataloader):
        # `feats` is None for an empty graph. Test identity: the truth value of a
        # multi-element tensor raises.
        if feats is None or g.num_edges() == 0:
            all_predicted_edges.append([])
            continue
        # Bring scores back to the CPU so they can index the CPU edge tensors.
        scores = model(g.to(device), feats.to(device)).cpu()
        src, dst = g.edges()
        # The node data key is 'obj_id', as written by generate_graphs.
        obj_ids = g.ndata['obj_id'].tolist()
        keep = scores >= threshold
        # Map node indices to object ids (a direct lookup, not list.index).
        predicted_edges = [
            (obj_ids[u], obj_ids[v])
            for u, v in zip(src[keep].tolist(), dst[keep].tolist())
        ]
        all_predicted_edges.append(predicted_edges)

In [None]:
import pandas as pd
# Empty frame with 'ID' and 'Parent' columns; it is not populated in the cells visible here.
df = pd.DataFrame(columns=['ID','Parent'])

In [None]:
# Category pairs; presumably (parent category, child category) -- TODO confirm against the consumer.
all_relations = [
    ('summary', 'paragraph'),
    ('figure', 'figure_caption'),
    ('table', 'table_caption'),
    ('form_title', 'form_body'),
    ('section', 'subsection'),
    ('subsection', 'subsubsection'),
    ('section', 'paragraph'),
    ('subsubsection', 'paragraph'),
    ('paragraph', 'list'),
    ('subsubsection', 'subsubsubsection'),
    ('subsubsubsection', 'paragraph'),
    ('subsection', 'paragraph'),
    ('subsection', 'list'),
    ('summary', 'form_body'),
    ('summary', 'form'),
    ('abstract', 'form'),
    ('abstract', 'form_body'),
    ('subsection', 'form_body'),
    ('subsubsection', 'form_body'),
    ('section', 'list'),
    ('section', 'form_body'),
    ('abstract', 'paragraph'),
    ('section', 'form'),
    ('subsubsection', 'list'),
    ('subsubsubsection', 'list'),
    ('subsection', 'form'),
    ('subsubsubsection', 'subsubsubsubsection'),
    ('subsubsubsubsection', 'paragraph'),
    ('subsubsubsection', 'form_body'),
    ('subsubsection', 'form'),
    ('subsubsubsection', 'form'),
]