In [1]:
import tqdm
import torch
import torch.nn.functional as F
from torch.nn import Parameter

from torch_geometric.nn import GAE
from torch_geometric.datasets import RelLinkPredDataset
from rrgcn import RRGCNEmbedder

In [2]:
dataset = RelLinkPredDataset('/project_scratch/rrgcn_datasets/RLPD', 'FB15k-237')
data = dataset[0]

In [3]:
# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rgcn_link_pred.py
class DistMultDecoder(torch.nn.Module):
    def __init__(self, in_dim, num_relations, hidden_channels):
        super().__init__()
        self.pre_lin = torch.nn.Linear(in_dim, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, hidden_channels)
        self.lin1 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.rel_emb = Parameter(torch.Tensor(num_relations, hidden_channels))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.rel_emb)

    def forward(self, z, edge_index, edge_type):
        z = F.relu(self.pre_lin(z))
        z = F.relu(self.lin(z))
        z = F.relu(self.lin1(z))
        z = F.relu(self.lin2(z))
        z_src, z_dst = z[edge_index[0]], z[edge_index[1]]
        rel = self.rel_emb[edge_type]
        return torch.sum(z_src * rel * z_dst, dim=1)


In [4]:
rrgcn_dim = 32000
pca_dim = 8192
num_layers = 2
lr = 0.0001
h = 2048

model = GAE(
    RRGCNEmbedder(
        data.num_nodes,
        num_layers,
        dataset.num_relations,
        rrgcn_dim,
        device="cuda",
        seed=42,
    ),
    DistMultDecoder(pca_dim, dataset.num_relations // 2, hidden_channels=h),
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


def negative_sampling(edge_index, num_nodes):
    # Sample edges by corrupting either the subject or the object of each edge.
    mask_1 = torch.rand(edge_index.size(1)) < 0.5
    mask_2 = ~mask_1

    neg_edge_index = edge_index.clone()
    neg_edge_index[0, mask_1.to(edge_index.device)] = torch.randint(
        num_nodes, (mask_1.sum(),)
    ).to(edge_index.device)
    neg_edge_index[1, mask_2.to(edge_index.device)] = torch.randint(
        num_nodes, (mask_2.sum(),)
    ).to(edge_index.device)
    return neg_edge_index


def train(z):
    model.train()
    optimizer.zero_grad()

    pos_out = model.decode(z, data.train_edge_index, data.train_edge_type)

    neg_edge_index = negative_sampling(data.train_edge_index, data.num_nodes)
    neg_out = model.decode(z, neg_edge_index, data.train_edge_type)

    out = torch.cat([pos_out, neg_out])

    gt = torch.cat([torch.ones_like(pos_out), torch.zeros_like(neg_out)])
    loss = F.binary_cross_entropy_with_logits(out, gt)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    return float(loss)


@torch.no_grad()
def filter_scores(scores, batch, true_triples, head=True):
    """ Filters a score matrix by setting the scores of known non-target true triples to -infinity """

    device = "cuda" if torch.cuda.is_available() else "cpu"

    indices = []  # indices of triples whose scores should be set to -infinity

    heads, tails = true_triples

    for i, (s, p, o) in enumerate(batch):
        s, p, o = (s.item(), p.item(), o.item())
        if head:
            indices.extend([(i, si) for si in heads[p, o] if si != s])
        else:
            indices.extend([(i, oi) for oi in tails[s, p] if oi != o])
        # -- We add the indices of all know triples except the one corresponding to the target triples.

    indices = torch.tensor(indices, device=device)

    scores[indices[:, 0], indices[:, 1]] = float("-inf")


# score calculation code adapted from
# https://github.com/thiviyanT/torch-rgcn
# instead of torch_geometric (100x speedup)
@torch.no_grad()
def evaluate(
    model,
    z,
    test_edge_index,
    test_edge_type,
    true_triples,
    num_nodes,
    batch_size=16,
    hits_at_k=[1, 3, 10],
    filter_candidates=True,
    verbose=True,
):
    """ Evaluates a triple scoring model. Does the sorting in a single, GPU-accelerated operation. """

    device = "cuda" if torch.cuda.is_available() else "cpu"

    rng = tqdm.trange if verbose else range

    ranks = []
    for head in [True, False]:  # head or tail prediction

        for fr in rng(0, test_edge_type.numel(), batch_size):
            to = min(fr + batch_size, test_edge_type.numel())

            batch = (
                torch.vstack(
                    (
                        test_edge_index[0, fr:to],
                        test_edge_type[fr:to],
                        test_edge_index[1, fr:to],
                    )
                )
                .to(device=device)
                .T
            )
            bn, _ = batch.size()

            # compute the full score matrix (filter later)
            bases = batch[:, 1:] if head else batch[:, :2]
            targets = batch[:, 0] if head else batch[:, 2]

            # collect the triples for which to compute scores
            bexp = bases.view(bn, 1, 2).expand(bn, num_nodes, 2)
            ar = (
                torch.arange(num_nodes, device=device)
                .view(1, num_nodes, 1)
                .expand(bn, num_nodes, 1)
            )
            toscore = torch.cat([ar, bexp] if head else [bexp, ar], dim=2)
            assert toscore.size() == (bn, num_nodes, 3)
            toscore = toscore.reshape((bn * num_nodes, 3))
            to_score_edge_index = torch.vstack((toscore[:, 0], toscore[:, 2]))
            to_score_edge_type = toscore[:, 1]
            scores = model.decode(z, to_score_edge_index, to_score_edge_type)
            scores = scores.reshape((bn, num_nodes))

            # filter out the true triples that aren't the target
            if filter_candidates:
                filter_scores(scores, batch.cpu(), true_triples, head=head)

            # Select the true scores, and count the number of values larger than than
            true_scores = scores[torch.arange(bn, device=device), targets]
            raw_ranks = torch.sum(
                scores > true_scores.view(bn, 1), dim=1, dtype=torch.long
            )
            # -- This is the "optimistic" rank (assuming it's sorted to the front of the ties)
            num_ties = torch.sum(
                scores == true_scores.view(bn, 1), dim=1, dtype=torch.long
            )

            # Account for ties (put the true example halfway down the ties)
            branks = raw_ranks + (num_ties - 1) // 2

            ranks.extend((branks + 1).tolist())

    mrr = sum([1.0 / rank for rank in ranks]) / len(ranks)

    hits = []
    for k in hits_at_k:
        hits.append(sum([1.0 if rank <= k else 0.0 for rank in ranks]) / len(ranks))

    return mrr, tuple(hits)  # , ranks


def generate_true_dict(all_triples):
    """ Generates a pair of dictionaries containing all true tail and head completions """
    heads, tails = (
        {(p, o): [] for _, p, o in all_triples},
        {(s, p): [] for s, p, _ in all_triples},
    )

    for s, p, o in all_triples:
        heads[p, o].append(s)
        tails[s, p].append(o)

    return heads, tails


device = "cuda"

data = data.to(device)
model.to(device)

z = model.encode(data.edge_index, data.edge_type)

m = z.mean(0, keepdim=True)
s = z.std(0, unbiased=False, keepdim=True)
z -= m
z /= s

U, S, V = torch.pca_lowrank(z, pca_dim)
z = torch.matmul(z, V[:, :pca_dim])

train_batch = torch.vstack(
    (data.train_edge_index[0, :], data.train_edge_type, data.train_edge_index[1, :])
).T
valid_batch = torch.vstack(
    (data.valid_edge_index[0, :], data.valid_edge_type, data.valid_edge_index[1, :])
).T
test_batch = torch.vstack(
    (data.test_edge_index[0, :], data.test_edge_type, data.test_edge_index[1, :])
).T

true_triples = generate_true_dict(
    torch.vstack((train_batch, valid_batch, test_batch)).cpu().numpy()
)

z = torch.tensor(z).to(device)

mrr = 0
for epoch in range(1, 8001):
    loss = train(z)
    if (epoch % 100) == 0:
        print(f"Epoch: {epoch:05d}, Loss: {loss:.4f}")

model.eval()
mrr, hits = evaluate(
    model,
    z,
    data.test_edge_index,
    data.test_edge_type,
    true_triples,
    data.num_nodes,
    batch_size=25,
    verbose=False,
)
print(mrr)
print(hits)

  z = torch.tensor(z).to(device)


Epoch: 00100, Loss: 0.1382
Epoch: 00200, Loss: 0.1086
Epoch: 00300, Loss: 0.0872
Epoch: 00400, Loss: 0.0863
Epoch: 00500, Loss: 0.0717
Epoch: 00600, Loss: 0.0661
Epoch: 00700, Loss: 0.0642
Epoch: 00800, Loss: 0.0611
Epoch: 00900, Loss: 0.0576
Epoch: 01000, Loss: 0.0528
Epoch: 01100, Loss: 0.0563
Epoch: 01200, Loss: 0.0547
Epoch: 01300, Loss: 0.0536
Epoch: 01400, Loss: 0.0470
Epoch: 01500, Loss: 0.0449
Epoch: 01600, Loss: 0.0494
Epoch: 01700, Loss: 0.0470
Epoch: 01800, Loss: 0.0424
Epoch: 01900, Loss: 0.0414
Epoch: 02000, Loss: 0.0396
Epoch: 02100, Loss: 0.0452
Epoch: 02200, Loss: 0.0510
Epoch: 02300, Loss: 0.0491
Epoch: 02400, Loss: 0.0363
Epoch: 02500, Loss: 0.0364
Epoch: 02600, Loss: 0.0440
Epoch: 02700, Loss: 0.0423
Epoch: 02800, Loss: 0.0369
Epoch: 02900, Loss: 0.0394
Epoch: 03000, Loss: 0.0481
Epoch: 03100, Loss: 0.0378
Epoch: 03200, Loss: 0.0323
Epoch: 03300, Loss: 0.0337
Epoch: 03400, Loss: 0.0330
Epoch: 03500, Loss: 0.0331
Epoch: 03600, Loss: 0.0308
Epoch: 03700, Loss: 0.0319
E

  branks = raw_ranks + (num_ties - 1) // 2


0.23814451162854694
(0.15665005374767907, 0.2563764291996482, 0.4116339294439558)


In [5]:
print(mrr)
print(hits)

0.23814451162854694
(0.15665005374767907, 0.2563764291996482, 0.4116339294439558)
