# Knowledge Graph Embedding
Based on material from [PyG](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.TransE.html?highlight=transe).

In [None]:
import os
import os.path as osp
import torch
import torch.optim as optim

# Install required packages.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${TORCH}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

2.5.0+cu121
Looking in links: https://data.pyg.org/whl/nightly/torch-2.5.0+cu121.html
Collecting git+https://github.com/pyg-team/pytorch_geometric.git
  Cloning https://github.com/pyg-team/pytorch_geometric.git to /tmp/pip-req-build-ubor8wni
  Running command git clone --filter=blob:none --quiet https://github.com/pyg-team/pytorch_geometric.git /tmp/pip-req-build-ubor8wni
  Resolved https://github.com/pyg-team/pytorch_geometric.git to commit f5c829344517c823c24abb08ce2fc7cf00ff29f7
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


We use `FB15k_237`, a popular full-size knowledge graph (KG) benchmark, and load its train, validation, and test splits.

In [None]:
from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE

device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = "/content/data/FB15k/"

train_data = FB15k_237(path, split='train')[0].to(device)
val_data = FB15k_237(path, split='val')[0].to(device)
test_data = FB15k_237(path, split='test')[0].to(device)
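Each split is a single graph object in which `edge_index[0]` holds head entity ids, `edge_index[1]` holds tail entity ids, and `edge_type` holds the relation id of each edge; together they form (head, relation, tail) triples. A toy plain-Python sketch of that layout follows. The ids are made up for illustration, and the counting assumes ids are contiguous and start at 0.

```python
# Toy KG in the same layout: edge_index[0] = heads, edge_index[1] = tails,
# edge_type[i] = relation id of edge i. Ids are illustrative, not from FB15k-237.
edge_index = [
    [0, 0, 1, 2],  # head entity ids
    [1, 2, 3, 3],  # tail entity ids
]
edge_type = [0, 1, 0, 2]  # relation id per edge


def to_triples(edge_index, edge_type):
    """Zip the (edge_index, edge_type) layout into (head, relation, tail) triples."""
    heads, tails = edge_index
    return list(zip(heads, edge_type, tails))


triples = to_triples(edge_index, edge_type)
print(triples)  # [(0, 0, 1), (0, 1, 2), (1, 0, 3), (2, 2, 3)]

# Assuming contiguous ids starting at 0, the counts are max id + 1.
num_nodes = 1 + max(max(row) for row in edge_index)
num_relations = 1 + max(edge_type)
print(num_nodes, num_relations)  # 4 3
```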

We pick `TransE` to learn the embeddings. Other models (`DistMult`, `ComplEx`, `RotatE`, already imported above) can be tried later.

In [None]:
model = TransE(
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels=50,  # embedding dimension
).to(device)
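TransE treats each relation as a translation in embedding space: for a true triple, the head embedding plus the relation embedding should land close to the tail embedding, so the score is the negative distance $-\lVert h + r - t \rVert_p$. The plain-Python sketch below assumes the L1 norm ($p = 1$) by default and uses hypothetical 2-d vectors; the library implementation may add details such as embedding normalization, so check its documentation.

```python
def transe_score(h, r, t, p=1.0):
    """TransE plausibility score -||h + r - t||_p; higher means more plausible."""
    diff = [hi + ri - ti for hi, ri, ti in zip(h, r, t)]
    return -sum(abs(d) ** p for d in diff) ** (1.0 / p)


# Hypothetical 2-d embeddings (the notebook uses hidden_channels=50 dimensions).
h = [1.0, 0.0]         # head entity
r = [0.0, 1.0]         # relation, used as a translation vector
t_true = [1.0, 1.0]    # h + r lands exactly here
t_false = [-1.0, 0.0]  # far from h + r

print(transe_score(h, r, t_true) > transe_score(h, r, t_false))  # True
print(transe_score(h, r, t_false))  # -3.0
```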

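Initialize the loader, which yields shuffled mini-batches of (head, relation, tail) index triples from the training split.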

In [None]:
loader = model.loader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=1000,
    shuffle=True,
)
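Roughly, such a loader shuffles the triple indices once per epoch and slices them into batches. The plain-Python sketch below is illustrative only (the generator name `triplet_batches` is hypothetical, not the PyG loader). The last batch can be smaller than `batch_size`, which is why the `train` function below weights each batch loss by `head_index.numel()` before averaging.

```python
import random


def triplet_batches(heads, rels, tails, batch_size, shuffle=True, seed=None):
    """Yield (heads, rels, tails) mini-batches, optionally in random order."""
    order = list(range(len(heads)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield ([heads[i] for i in idx],
               [rels[i] for i in idx],
               [tails[i] for i in idx])


# Illustrative: 5 toy triples with batch size 2.
batches = list(triplet_batches([0, 1, 2, 3, 4], [0, 0, 1, 1, 2], [1, 2, 3, 4, 0],
                               batch_size=2, seed=0))
print([len(b[0]) for b in batches])  # [2, 2, 1]
```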

We use the recommended optimizer settings for each model. They can also be tuned on the validation split. When switching models, select the matching key.

In [None]:
# Each entry builds the recommended optimizer for one model;
# only the selected optimizer is actually created.
optimizer_map = {
    'transe': lambda params: optim.Adam(params, lr=0.01),
    'complex': lambda params: optim.Adagrad(params, lr=0.001, weight_decay=1e-6),
    'distmult': lambda params: optim.Adam(params, lr=0.0001, weight_decay=1e-6),
    'rotate': lambda params: optim.Adam(params, lr=1e-3),
}
optimizer = optimizer_map['transe'](model.parameters())

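As usual, we define the training and evaluation routines: `train` runs one epoch and returns the average loss per triple, and `test` returns the mean rank, MRR, and Hits@10 on the given split.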

In [None]:
def train():
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in loader:
        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
    return total_loss / total_examples

@torch.no_grad()
def test(data):
    model.eval()
    return model.test(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
        batch_size=20000,
        k=10,
    )
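What `model.loss` optimizes for TransE can be sketched as a margin ranking loss that pushes the score of each true triple above the score of a corrupted one. This plain-Python sketch rests on assumptions, not on the library source: one negative per positive, made by replacing either the head or the tail with a uniformly random entity, the L1 distance, and a margin of 1.0. A corrupted triple can occasionally coincide with a true one; the sketch does not filter those out.

```python
import random


def l1_transe_score(h, r, t):
    """-||h + r - t||_1, the TransE score with the L1 norm."""
    return -sum(abs(a + b - c) for a, b, c in zip(h, r, t))


def corrupt(triple, num_nodes, rng):
    """Negative sample: replace the head or the tail with a random entity."""
    h, r, t = triple
    if rng.random() < 0.5:
        return rng.randrange(num_nodes), r, t
    return h, r, rng.randrange(num_nodes)


def margin_ranking_loss(pos_scores, neg_scores, margin=1.0):
    """mean(max(0, margin - (pos - neg))): positives should beat negatives by `margin`."""
    losses = [max(0.0, margin - (p - n)) for p, n in zip(pos_scores, neg_scores)]
    return sum(losses) / len(losses)


def transe_loss(batch, node_emb, rel_emb, margin=1.0, seed=None):
    """Score each true triple and one corrupted copy, then apply the margin loss."""
    rng = random.Random(seed)
    pos, neg = [], []
    for h, r, t in batch:
        nh, nr, nt = corrupt((h, r, t), len(node_emb), rng)
        pos.append(l1_transe_score(node_emb[h], rel_emb[r], node_emb[t]))
        neg.append(l1_transe_score(node_emb[nh], rel_emb[nr], node_emb[nt]))
    return margin_ranking_loss(pos, neg, margin)


# The second pair violates the margin: 1 - (-3 - 0) = 4, so the mean is (0 + 4) / 2.
print(margin_ranking_loss([0.0, -3.0], [-3.0, 0.0]))  # 2.0
```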

We train for 200 epochs and evaluate on the validation split every 50 epochs.

In [None]:
for epoch in range(1, 201):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    if epoch % 50 == 0:
        rank, mrr, hits = test(val_data)
        print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '
              f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')

Epoch: 001, Loss: 0.7586
Epoch: 002, Loss: 0.5482
Epoch: 003, Loss: 0.4294
Epoch: 004, Loss: 0.3470
Epoch: 005, Loss: 0.2975
Epoch: 006, Loss: 0.2642
Epoch: 007, Loss: 0.2436
Epoch: 008, Loss: 0.2273
Epoch: 009, Loss: 0.2149
Epoch: 010, Loss: 0.2042
Epoch: 011, Loss: 0.1971
Epoch: 012, Loss: 0.1904
Epoch: 013, Loss: 0.1842
Epoch: 014, Loss: 0.1802
Epoch: 015, Loss: 0.1759
Epoch: 016, Loss: 0.1712
Epoch: 017, Loss: 0.1685
Epoch: 018, Loss: 0.1637
Epoch: 019, Loss: 0.1610
Epoch: 020, Loss: 0.1598
Epoch: 021, Loss: 0.1560
Epoch: 022, Loss: 0.1534
Epoch: 023, Loss: 0.1513
Epoch: 024, Loss: 0.1482
Epoch: 025, Loss: 0.1459
Epoch: 026, Loss: 0.1452
Epoch: 027, Loss: 0.1432
Epoch: 028, Loss: 0.1413
Epoch: 029, Loss: 0.1397
Epoch: 030, Loss: 0.1369
Epoch: 031, Loss: 0.1364
Epoch: 032, Loss: 0.1354
Epoch: 033, Loss: 0.1334
Epoch: 034, Loss: 0.1318
Epoch: 035, Loss: 0.1309
Epoch: 036, Loss: 0.1296
Epoch: 037, Loss: 0.1289
Epoch: 038, Loss: 0.1266
Epoch: 039, Loss: 0.1260
Epoch: 040, Loss: 0.1256


100%|██████████| 17535/17535 [06:16<00:00, 46.59it/s]


Epoch: 050, Val Mean Rank: 338.95, Val MRR: 0.2216, Val Hits@10: 0.3701
Epoch: 051, Loss: 0.1146
Epoch: 052, Loss: 0.1132
Epoch: 053, Loss: 0.1137
Epoch: 054, Loss: 0.1131
Epoch: 055, Loss: 0.1118
Epoch: 056, Loss: 0.1106
Epoch: 057, Loss: 0.1106
Epoch: 058, Loss: 0.1106
Epoch: 059, Loss: 0.1086
Epoch: 060, Loss: 0.1096
Epoch: 061, Loss: 0.1079
Epoch: 062, Loss: 0.1075
Epoch: 063, Loss: 0.1064
Epoch: 064, Loss: 0.1065
Epoch: 065, Loss: 0.1062
Epoch: 066, Loss: 0.1064
Epoch: 067, Loss: 0.1069
Epoch: 068, Loss: 0.1060
Epoch: 069, Loss: 0.1050
Epoch: 070, Loss: 0.1039
Epoch: 071, Loss: 0.1035
Epoch: 072, Loss: 0.1034
Epoch: 073, Loss: 0.1030
Epoch: 074, Loss: 0.1024
Epoch: 075, Loss: 0.1020
Epoch: 076, Loss: 0.1016
Epoch: 077, Loss: 0.1018
Epoch: 078, Loss: 0.1019
Epoch: 079, Loss: 0.1016
Epoch: 080, Loss: 0.1011
Epoch: 081, Loss: 0.1005
Epoch: 082, Loss: 0.1002
Epoch: 083, Loss: 0.0997
Epoch: 084, Loss: 0.0998
Epoch: 085, Loss: 0.1004
Epoch: 086, Loss: 0.0991
Epoch: 087, Loss: 0.0989
...

100%|██████████| 17535/17535 [06:54<00:00, 42.31it/s]


Epoch: 100, Val Mean Rank: 295.50, Val MRR: 0.2289, Val Hits@10: 0.3717
Epoch: 101, Loss: 0.0951
Epoch: 102, Loss: 0.0962
Epoch: 103, Loss: 0.0954
Epoch: 104, Loss: 0.0953
Epoch: 105, Loss: 0.0944
Epoch: 106, Loss: 0.0947
Epoch: 107, Loss: 0.0944
Epoch: 108, Loss: 0.0941
Epoch: 109, Loss: 0.0945
Epoch: 110, Loss: 0.0937
Epoch: 111, Loss: 0.0934
Epoch: 112, Loss: 0.0938
Epoch: 113, Loss: 0.0935
Epoch: 114, Loss: 0.0932
Epoch: 115, Loss: 0.0933
Epoch: 116, Loss: 0.0931
Epoch: 117, Loss: 0.0930
Epoch: 118, Loss: 0.0933
Epoch: 119, Loss: 0.0940
Epoch: 120, Loss: 0.0917
Epoch: 121, Loss: 0.0929
Epoch: 122, Loss: 0.0925
Epoch: 123, Loss: 0.0929
Epoch: 124, Loss: 0.0927
Epoch: 125, Loss: 0.0927
Epoch: 126, Loss: 0.0929
Epoch: 127, Loss: 0.0922
Epoch: 128, Loss: 0.0910
Epoch: 129, Loss: 0.0919
Epoch: 130, Loss: 0.0910
Epoch: 131, Loss: 0.0914
Epoch: 132, Loss: 0.0911
Epoch: 133, Loss: 0.0918
Epoch: 134, Loss: 0.0923
Epoch: 135, Loss: 0.0910
Epoch: 136, Loss: 0.0911
Epoch: 137, Loss: 0.0907
...

100%|██████████| 17535/17535 [07:08<00:00, 40.89it/s]


Epoch: 150, Val Mean Rank: 281.93, Val MRR: 0.2259, Val Hits@10: 0.3673
Epoch: 151, Loss: 0.0892
Epoch: 152, Loss: 0.0896
Epoch: 153, Loss: 0.0894
Epoch: 154, Loss: 0.0902
Epoch: 155, Loss: 0.0895
Epoch: 156, Loss: 0.0894
Epoch: 157, Loss: 0.0906
Epoch: 158, Loss: 0.0900
Epoch: 159, Loss: 0.0896
Epoch: 160, Loss: 0.0887
Epoch: 161, Loss: 0.0897
Epoch: 162, Loss: 0.0885
Epoch: 163, Loss: 0.0896
Epoch: 164, Loss: 0.0882
Epoch: 165, Loss: 0.0889
Epoch: 166, Loss: 0.0881
Epoch: 167, Loss: 0.0884
Epoch: 168, Loss: 0.0886
Epoch: 169, Loss: 0.0880
Epoch: 170, Loss: 0.0882
Epoch: 171, Loss: 0.0873
Epoch: 172, Loss: 0.0877
Epoch: 173, Loss: 0.0890
Epoch: 174, Loss: 0.0880
Epoch: 175, Loss: 0.0885
Epoch: 176, Loss: 0.0872
Epoch: 177, Loss: 0.0878
Epoch: 178, Loss: 0.0883
Epoch: 179, Loss: 0.0879
Epoch: 180, Loss: 0.0877
Epoch: 181, Loss: 0.0878
Epoch: 182, Loss: 0.0877
Epoch: 183, Loss: 0.0866
Epoch: 184, Loss: 0.0880
Epoch: 185, Loss: 0.0870
Epoch: 186, Loss: 0.0873
Epoch: 187, Loss: 0.0871
...

100%|██████████| 17535/17535 [06:53<00:00, 42.39it/s]

Epoch: 200, Val Mean Rank: 276.36, Val MRR: 0.2234, Val Hits@10: 0.3697
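The validation metrics above do not move together: mean rank keeps falling through epoch 200, while MRR and Hits@10 peak at epoch 100. If the validation split is used for model selection, the choice of metric therefore matters. A small sketch using the logged numbers (checkpoint selection itself is not part of this notebook):

```python
# Validation metrics copied from the training log above:
# epoch -> (mean rank, MRR, Hits@10)
val_log = {
    50: (338.95, 0.2216, 0.3701),
    100: (295.50, 0.2289, 0.3717),
    150: (281.93, 0.2259, 0.3673),
    200: (276.36, 0.2234, 0.3697),
}

best_by_mrr = max(val_log, key=lambda e: val_log[e][1])        # higher is better
best_by_hits = max(val_log, key=lambda e: val_log[e][2])       # higher is better
best_by_mean_rank = min(val_log, key=lambda e: val_log[e][0])  # lower is better
print(best_by_mrr, best_by_hits, best_by_mean_rank)  # 100 100 200
```

Mean rank is sensitive to a few very badly ranked triples, while MRR is dominated by the top of the ranking, which is one reason the two can disagree.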





Evaluate the final model on the test split.

In [None]:
rank, mrr, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {rank:.2f}, Test MRR: {mrr:.4f}, '
      f'Test Hits@10: {hits_at_10:.4f}')

100%|██████████| 20466/20466 [07:58<00:00, 42.74it/s]

Test Mean Rank: 282.98, Test MRR: 0.2199, Test Hits@10: 0.3624
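All three metrics are derived from the rank of the true entity among all candidate entities. The plain-Python sketch below ranks tails only, uses hypothetical scores, and breaks ties in favor of the true entity; whether `model.test` filters out other known true triples or treats ties the same way is not modeled here.

```python
def rank_of_true(scores, true_index):
    """1-based rank of the true entity: 1 + number of candidates scoring strictly higher."""
    true_score = scores[true_index]
    return 1 + sum(1 for s in scores if s > true_score)


def ranking_metrics(ranks, k=10):
    """Mean rank, mean reciprocal rank (MRR), and Hits@k from a list of ranks."""
    n = len(ranks)
    mean_rank = sum(ranks) / n
    mrr = sum(1.0 / r for r in ranks) / n
    hits_at_k = sum(1 for r in ranks if r <= k) / n
    return mean_rank, mrr, hits_at_k


# Hypothetical scores for 3 queries over 4 candidate tails (higher = more plausible).
queries = [
    ([0.9, 0.1, 0.3, 0.2], 0),  # true tail ranked 1st
    ([0.5, 0.8, 0.1, 0.7], 3),  # ranked 2nd
    ([0.4, 0.6, 0.9, 0.2], 0),  # ranked 3rd
]
ranks = [rank_of_true(scores, i) for scores, i in queries]
print(ranks)  # [1, 2, 3]
print(ranking_metrics(ranks, k=2))
```

Since 1/MRR is the harmonic mean of the ranks, the test MRR of 0.2199 corresponds to a harmonic-mean rank of about 4.5, while the arithmetic mean rank is 282.98: most true tails rank near the top, but a minority rank very low.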





Next, one could modify `TransE` or use other methods (for example, `Query2box`) to learn embeddings from example queries over the knowledge graph.