In [1]:
from rrgcn import RRGCNEmbedder
from torch_geometric.datasets import Entities
import torch
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
dataset = Entities('./', 'aifb')
data = dataset[0].to(device)

In [3]:
embedder = RRGCNEmbedder(num_nodes=data.num_nodes, num_relations=dataset.num_relations, num_layers=1, emb_size=512, device=device)

In [4]:
embs = embedder.embeddings(data.edge_index, data.edge_type, idx=torch.hstack((data.train_idx, data.test_idx)))

100%|██████████| 1/1 [00:00<00:00,  5.65it/s]


In [5]:
train_embs, val_embs, y_train, y_val = train_test_split(embs[:len(data.train_idx)], data.train_y, stratify=data.train_y.cpu().numpy(), test_size=0.1, random_state=42)
test_embs = embs[len(data.train_idx):]

In [6]:
task_type = 'GPU' if torch.cuda.is_available() else 'CPU'
clf = CatBoostClassifier(iterations=1000, early_stopping_rounds=10, task_type=task_type, random_seed=42, auto_class_weights="Balanced")
clf = clf.fit(train_embs.cpu().numpy(), y_train.cpu().numpy(), eval_set=(val_embs.cpu().numpy(), y_val.cpu().numpy()))



Learning rate set to 0.085304
0:	learn: 1.2969302	test: 1.3342397	best: 1.3342397 (0)	total: 18.6ms	remaining: 18.6s
1:	learn: 1.2000661	test: 1.2338934	best: 1.2338934 (1)	total: 30.4ms	remaining: 15.2s
2:	learn: 1.1425615	test: 1.2085750	best: 1.2085750 (2)	total: 43.1ms	remaining: 14.3s
3:	learn: 1.0679692	test: 1.1699182	best: 1.1699182 (3)	total: 56.1ms	remaining: 14s
4:	learn: 0.9919849	test: 1.1372393	best: 1.1372393 (4)	total: 67.9ms	remaining: 13.5s
5:	learn: 0.9336028	test: 1.1186410	best: 1.1186410 (5)	total: 80.1ms	remaining: 13.3s
6:	learn: 0.8787697	test: 1.1008139	best: 1.1008139 (6)	total: 92.1ms	remaining: 13.1s
7:	learn: 0.8352548	test: 1.0676963	best: 1.0676963 (7)	total: 105ms	remaining: 13s
8:	learn: 0.7851750	test: 1.0534535	best: 1.0534535 (8)	total: 116ms	remaining: 12.8s
9:	learn: 0.7387055	test: 1.0352245	best: 1.0352245 (9)	total: 128ms	remaining: 12.7s
10:	learn: 0.6983655	test: 1.0326858	best: 1.0326858 (10)	total: 140ms	remaining: 12.6s
11:	learn: 0.667678

In [7]:
print(classification_report(clf.predict(test_embs.cpu().numpy()), data.test_y.cpu().numpy()))

              precision    recall  f1-score   support

           0       0.93      0.88      0.90        16
           1       0.92      1.00      0.96        11
           2       1.00      1.00      1.00         3
           3       0.83      0.83      0.83         6

    accuracy                           0.92        36
   macro avg       0.92      0.93      0.92        36
weighted avg       0.92      0.92      0.92        36

