In [1]:
from rrgcn import RRGCNEmbedder
from torch_geometric.datasets import Entities
import torch
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import rdflib
import re
from tqdm import tqdm
import math

# !pip install transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [2]:
## edited entities.py from PyG that also saves node_features
import logging
import os
import os.path as osp
from collections import Counter
from typing import Callable, List, Optional

import numpy as np
import torch

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_tar,
)


class Entities(InMemoryDataset):
    url = 'https://data.dgl.ai/dataset/{}.tgz'

    def __init__(self, root: str, name: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        self.name = name.lower()
        assert self.name in ['aifb', 'am', 'mutag', 'bgs']
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, self.name, 'processed')

    @property
    def num_relations(self) -> int:
        return self.data.edge_type.max().item() + 1

    @property
    def num_classes(self) -> int:
        return self.data.train_y.max().item() + 1

    @property
    def raw_file_names(self) -> List[str]:
        return [
            f'{self.name}_stripped.nt.gz',
            'completeDataset.tsv',
            'trainingSet.tsv',
            'testSet.tsv',
        ]

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self):
        path = download_url(self.url.format(self.name), self.root)
        extract_tar(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        import gzip

        import pandas as pd
        import rdflib as rdf

        graph_file, task_file, train_file, test_file = self.raw_paths

        with hide_stdout():
            g = rdf.Graph()
            with gzip.open(graph_file, 'rb') as f:
                g.parse(file=f, format='nt')

        freq = Counter(g.predicates())

        relations = sorted(set(g.predicates()), key=lambda p: -freq.get(p, 0))
        subjects = set(g.subjects())
        objects = set(g.objects())
        nodes = list(subjects.union(objects))

        N = len(nodes)
        R = 2 * len(relations)

        relations_dict = {rel: i for i, rel in enumerate(relations)}
        nodes_dict = {node: i for i, node in enumerate(nodes)}

        edges = []
        node_features = {}
        for s, p, o in g.triples((None, None, None)):
            src, dst, rel = nodes_dict[s], nodes_dict[o], relations_dict[p]
            edges.append([src, dst, 2 * rel])
            edges.append([dst, src, 2 * rel + 1])

            # SAVE LITERAL VALUES
            if isinstance(o, rdf.Literal):
                literal_type = p#type(o.value)
                if literal_type not in node_features:
                    node_features[literal_type] = [[dst], [o.value]]
                else:
                    node_features[literal_type][0].append(dst)
                    node_features[literal_type][1].append(o.value)

        edges = torch.tensor(edges, dtype=torch.long).t().contiguous()
        perm = (N * R * edges[0] + R * edges[1] + edges[2]).argsort()
        edges = edges[:, perm]

        edge_index, edge_type = edges[:2], edges[2]

        if self.name == 'am':
            label_header = 'label_cateogory'
            nodes_header = 'proxy'
        elif self.name == 'aifb':
            label_header = 'label_affiliation'
            nodes_header = 'person'
        elif self.name == 'mutag':
            label_header = 'label_mutagenic'
            nodes_header = 'bond'
        elif self.name == 'bgs':
            label_header = 'label_lithogenesis'
            nodes_header = 'rock'

        labels_df = pd.read_csv(task_file, sep='\t')
        labels_set = set(labels_df[label_header].values.tolist())
        labels_dict = {lab: i for i, lab in enumerate(list(labels_set))}
        nodes_dict = {np.unicode(key): val for key, val in nodes_dict.items()}

        train_labels_df = pd.read_csv(train_file, sep='\t')
        train_indices, train_labels = [], []
        for nod, lab in zip(train_labels_df[nodes_header].values,
                            train_labels_df[label_header].values):
            train_indices.append(nodes_dict[nod])
            train_labels.append(labels_dict[lab])

        train_idx = torch.tensor(train_indices, dtype=torch.long)
        train_y = torch.tensor(train_labels, dtype=torch.long)

        test_labels_df = pd.read_csv(test_file, sep='\t')
        test_indices, test_labels = [], []
        for nod, lab in zip(test_labels_df[nodes_header].values,
                            test_labels_df[label_header].values):
            test_indices.append(nodes_dict[nod])
            test_labels.append(labels_dict[lab])

        test_idx = torch.tensor(test_indices, dtype=torch.long)
        test_y = torch.tensor(test_labels, dtype=torch.long)

        data = Data(edge_index=edge_index, edge_type=edge_type,
                    train_idx=train_idx, train_y=train_y, test_idx=test_idx,
                    test_y=test_y, num_nodes=N, node_features=node_features)

        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name.upper()}{self.__class__.__name__}()'


class hide_stdout(object):
    def __enter__(self):
        self.level = logging.getLogger().level
        logging.getLogger().setLevel(logging.ERROR)

    def __exit__(self, *args):
        logging.getLogger().setLevel(self.level)

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
dataset = Entities('./', 'am')
data = dataset[0].to(device)

In [4]:
for k, v in data.node_features.items():
    print(k)
    print(v[1][0:10])
    print(f'num entries {len(v[1])} - unique entries {len(set(v[1]))}')
    print('--')

http://purl.org/collections/nl/am/dimensionPart
['passe-partout', 'geheel', 'b-maat', 'a-maat', 'c-maat', 'c-maat', 'a-maat', 'a-maat', 'KA 12834.1', 'a']
num entries 6028 - unique entries 577
--
http://purl.org/collections/nl/am/reproductionReference
['KA_6989.JPG', 'A_10069_000.jpg', 'KA_14018.JPG', 'A_12399_000.jpg', 'O 2577.38', 'A_19506', 'O 11540.09', 'O 1320.34', 'A_50466', 'A_44633']
num entries 110861 - unique entries 104142
--
http://www.w3.org/2000/01/rdf-schema#label
['WH 1ste verdieping 2003-01-23 2003-08-22 intern 15111 Transportnummer:2003-0066 intern ', 'CBD 3 kast 84-3 1975-01-01   17471   ', ' Portretten van Nederlandse beeldende kunstenaars : repertorium [R] p. 132 ', 'Een en ander over glasgravure p. 276, t.o. p. 276 afb. ', 'AHM dep 3 blok 2-4 1975-01-01   16378   ', 'AHM textielrestauratie 2006-10-24 2008-09-03  16283   ', 'WH Palet Willet boudoir 2001-11-04 2001-08-28  26292   ', 'AHM PK stelling 27-14 II 2009-03-21   14631 doos nr 251 intern ', ' Geschiedenis va

In [5]:
# Divide literals in: 
# categoricals: will be modelled as if they are regular nodes
# numericals: will be modelled as univariate numbers
# dates: will be modelled as dates
# strings: will be modelled as textual strings
# all other literal types will be removed!
am = rdflib.Namespace("http://purl.org/collections/nl/am/")
categoricals = [
    am["dimensionPart"],
    am["dimensionUnit"],
    am["currentLocationLref"],
    am["currentLocationType"],
    am["priref"],
    am["AHMTextsAuthorLref"],
    am["exhibitionLref"],
    am["alternativeNumberType"],
    am["documentationShelfmark"],
    am["documentationTitleLref"],
    am["documentationTitleArticle"],
    am["alternativeNumberInstitution"],
    am["exhibitionVenue"],
    am["relatedObjectAssociation"]
]
numericals = [am["dimensionValue"]]
dates = [
    am["reproductionDate"],
    am["currentLocationDateStart"],
    am["documentationSortyear"],
    am["productionDateStart"],
    am["exhibitionDateEnd"],
    am["productionDateEnd"],
    am["currentLocationDateEnd"],
    am["AHMTextsPubl"],
    am["acquisitionDate"],
    am["AHMTextsDate"],
    am["birthDateEnd"],
    am["exhibitionDateStart"],
    am["deathDateEnd"],
    am["deathDateStart"],
]
strings = [
    am["creatorRole"],
    am["name"],
    am["documentationTitle"],
    am["title"],
    rdflib.URIRef("http://www.w3.org/2004/02/skos/core#prefLabel"),
    am["AHMTextsType"],
    am["exhibitionTitle"],
    am["usedFor"],
    am["birthPlace"],
    am["currentLocationFitness"],
    am["creatorQualifier"],
    am["deathPlace"],
    am["AHMTextsTekst"],
    am["biography"],
    am["relatedObjectNotes"],
    am["creditLine"],
    am["relatedObjectTitle"],
    rdflib.URIRef("http://www.w3.org/2004/02/skos/core#altLabel"),
    am["occupation"],
    am["source"],
    am["equivalentName"],
    am["nationality"],
]

In [6]:
# Store literal types remove literals unaccounted for in previous step
types_to_remove = set(data.node_features.keys()).difference(
    set(strings).union(set(categoricals)).union(set(dates)).union(set(categoricals))
)
removed_literal_idxs = []
for t in types_to_remove:
    removed_literal_idxs.append(torch.tensor(data.node_features[t][0], dtype=torch.long))

In [7]:
processed_node_features = {}

In [8]:
def date_to_number(date: str):
    if date is None:
        return None

    date = date.replace("?", "").strip()
    date = date.replace("(", "").strip()
    date = date.replace(")", "").strip()
    date = re.sub(r"(\D*)$", "", date).strip()
    date = re.sub(r"^(\D*)", "", date).strip()

    date = date.replace("/", "-")
    date = date.replace("\\", "-")
    date = date.replace(".", "-")
    date = date.replace(" ", "-")

    if "-" in date:
        parts = date.split("-")
    else:
        parts = [date]

    parts = [p for p in parts if len(p.strip())]
    values = [1, 1 / 12, 1 / 365]
    val = 0
    for v, p in zip(values, parts):
        try:
            val += v * int(p)
        except:
            return torch.nan
    return val


for l in dates:
    indices = torch.tensor(data.node_features[l][0], dtype=torch.long)
    float_dates = torch.tensor(
        [date_to_number(d) for d in data.node_features[l][1]], dtype=torch.float32
    )

    # remove invalid nodes from graph
    removed_literal_idxs.append(indices[torch.isnan(float_dates)])

    indices = indices[~torch.isnan(float_dates)]
    float_dates = float_dates[~torch.isnan(float_dates)]

    processed_node_features[len(processed_node_features)] = [
        indices,
        float_dates.reshape(-1, 1),
    ]


In [9]:
for l in numericals:
    processed_numericals = []
    for n in data.node_features[l][1]:
        try:
            processed_numericals.append(float(n.replace(",", ".")))
        except:
            processed_numericals.append(torch.nan)

    indices = torch.tensor(data.node_features[l][0], dtype=torch.long)
    numerical_features = torch.tensor(processed_numericals, dtype=torch.float32)


    # remove invalid nodes from graph
    removed_literal_idxs.append(indices[torch.isnan(numerical_features)])

    indices = indices[~torch.isnan(numerical_features)]
    numerical_features = numerical_features[~torch.isnan(numerical_features)]

    processed_node_features[len(processed_node_features)] = [
        indices,
        numerical_features.reshape(-1, 1),
    ]

In [10]:
tokenizer = AutoTokenizer.from_pretrained("GroNLP/bert-base-dutch-cased")
model = AutoModelForMaskedLM.from_pretrained("GroNLP/bert-base-dutch-cased",)
model = model.to(device)

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[
        0
    ]  # First element of model_output contains all token embeddings
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


batch_size = 64
for l in tqdm(strings):
    model.eval()
    indices = torch.tensor(data.node_features[l][0], dtype=torch.long)

    sents = data.node_features[l][1]
    num_batches = math.ceil(len(sents) / batch_size)

    s_embs = []
    for b in range(num_batches):
        inputs = tokenizer(
            sents[b * batch_size : (b + 1) * batch_size],
            padding="max_length",
            return_tensors="pt",
            max_length=256,
            truncation=True
        ).to(device)
        s_emb = mean_pooling(
            model._modules["bert"](
                **inputs
            ),
            inputs["attention_mask"]
        ).detach().cpu()
        s_embs.append(s_emb)

    s_emb = torch.vstack(s_embs)
    
    processed_node_features[len(processed_node_features)] = [
        indices,
        s_emb,
    ]


100%|██████████| 22/22 [22:21<00:00, 60.99s/it]  


In [11]:
literals = torch.hstack(removed_literal_idxs).squeeze().to(device)
mask = torch.isin(data.edge_index[0], literals) | torch.isin(
    data.edge_index[1], literals
)
edge_index = data.edge_index[:, ~mask]
edge_type = data.edge_type[~mask]

In [12]:
train_idx, val_idx, y_train, y_val = train_test_split(
    data.train_idx,
    data.train_y,
    stratify=data.train_y.cpu().numpy(),
    test_size=0.1,
    random_state=42,
)

In [13]:
embedder = RRGCNEmbedder(
    num_nodes=data.num_nodes,
    num_relations=dataset.num_relations,
    num_layers=5,
    emb_size=512,
    device=device,
)

In [19]:
# for node features to work well, they have to be normalized
# you can choose "standard" for StandardScaler, "robust" for RobustScaler, "quantile"
# for QuantileTransformer and "power" for PowerTransformer
#
# you could also pass sklearn compatible scalers by passing a dict keyed by
# literal type, e.g.:
# {0: StandardScaler(), 1: RobustScaler()}
from sklearn.preprocessing import StandardScaler, QuantileTransformer

train_embs = embedder.embeddings(
    edge_index,
    edge_type,
    node_features=processed_node_features,
    node_features_scalers={
        k: None
        if v[1].shape[1] > 1
        else QuantileTransformer(output_distribution="normal")
        for k, v in processed_node_features.items()
    },
    idx=train_idx,
)


100%|██████████| 1/1 [03:00<00:00, 180.20s/it]


In [20]:
# only fit node feature scalers on nodes reachable from train nodes,
# for val and test nodes, reuse the fit scalers using embedder.get_last_fit_scalers()
val_embs = embedder.embeddings(
    edge_index,
    edge_type,
    node_features=processed_node_features,
    node_features_scalers=embedder.get_last_fit_scalers(),
    idx=val_idx,
)
test_embs = embedder.embeddings(
    edge_index,
    edge_type,
    node_features=processed_node_features,
    node_features_scalers=embedder.get_last_fit_scalers(),
    idx=data.test_idx,
)

100%|██████████| 1/1 [03:01<00:00, 181.36s/it]
100%|██████████| 1/1 [03:04<00:00, 184.37s/it]


In [21]:
task_type = "GPU" if torch.cuda.is_available() else "CPU"
clf = CatBoostClassifier(
    iterations=10_000,
    learning_rate=0.10,
    early_stopping_rounds=100,
    task_type=task_type,
    random_seed=42,
    use_best_model=True,
    auto_class_weights="Balanced",
    verbose=False
)
clf = clf.fit(
    train_embs.cpu().numpy(),
    y_train.cpu().numpy(),
    eval_set=(val_embs.cpu().numpy(), y_val.cpu().numpy()),
)



In [22]:
print(classification_report(data.test_y.cpu().numpy(), clf.predict(test_embs.cpu().numpy())))

              precision    recall  f1-score   support

           0       0.88      0.88      0.88         8
           1       0.97      1.00      0.99        69
           2       1.00      0.80      0.89         5
           3       0.93      0.81      0.87        16
           4       0.89      0.73      0.80        11
           5       1.00      1.00      1.00        11
           6       1.00      1.00      1.00         3
           7       0.70      1.00      0.82        23
           8       0.94      0.88      0.91        17
           9       0.83      0.60      0.70        25
          10       0.55      0.60      0.57        10

    accuracy                           0.88       198
   macro avg       0.88      0.85      0.86       198
weighted avg       0.89      0.88      0.88       198

