In [1]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cu113.html

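# Lesson 6: Applications of Graph Deep Learning (Part 1)

In the lectures we saw several practical applications of graph neural networks (GNNs). In this hands-on session we explore how GNNs are applied to knowledge graphs.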

## GNNs for Knowledge Graphs

In this section we use R-GCN (Relational GCN) as an example to show how GNNs can be applied to knowledge graph completion.

Note: R-GCN comes from the paper Modeling Relational Data with Graph Convolutional Networks; see https://arxiv.org/abs/1703.06103 .

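## 1. Knowledge Graph Datasets

### 1.1 The FB15k-237 Dataset

FB15k-237 is a standard knowledge graph benchmark widely used in academia. It is a subgraph built from a small subset of topics in the Freebase knowledge base and contains 237 relation types and about 14.5k entities. Relations can be thought of as the edges of a graph, and entities as its nodes.

Note: Freebase draws mainly on sources such as Wikipedia, the Notable Names Database (NNDB), and the open music database MusicBrainz.

The figure below shows a small part of FB15k-237 (image from https://arxiv.org/pdf/1911.04910.pdf ). Boxes mark entities (nodes) and arrows mark relations (edges). The entities include Piano, USA, and Opera, and the relations include artists and nationality. This subgraph mainly describes the famous Russian composer Sergei Rachmaninoff.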

![snapshot-fb15k](fb15k.png)

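Next we load FB15k-237 with PyG. The task on this dataset is knowledge graph completion: given two entities and a candidate relation, decide whether that relation holds between them.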

In [2]:
from torch_geometric.datasets import RelLinkPredDataset
import torch

dataset = RelLinkPredDataset('./data', 'FB15k-237')
data = dataset[0]

Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/entities.dict
Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/relations.dict
Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/test.txt
Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/train.txt
Downloading https://raw.githubusercontent.com/MichSchli/RelationPrediction/master/data/FB-Toutanova/valid.txt
Processing...
Done!


In [3]:
data

Data(edge_index=[2, 544230], num_nodes=14541, edge_type=[544230], train_edge_index=[2, 272115], train_edge_type=[272115], valid_edge_index=[2, 17535], valid_edge_type=[17535], test_edge_index=[2, 20466], test_edge_type=[20466])

In [4]:
data.train_edge_index

tensor([[3364, 8077, 3776,  ..., 9124, 1318, 7664],
        [8619, 6142, 7385,  ..., 7803, 3679, 6680]])

In [5]:
data.train_edge_type.max()+1, data.train_edge_type.shape

(tensor(237), torch.Size([272115]))

In [6]:
data.edge_type.max()+1

tensor(474)
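
The value 474 is exactly 2 × 237, and `edge_index` has twice as many columns as `train_edge_index` (544230 versus 272115). This is consistent with the message-passing graph containing every training edge plus a reversed copy that carries its own relation ID. The sketch below reproduces that pattern on a toy graph; the tensors and the offset-by-`num_relations` convention are assumptions for illustration and are not read from the dataset's source code.

```python
import torch

# Toy training triples: 3 edges over 2 relation types (hypothetical values).
train_edge_index = torch.tensor([[0, 1, 2],
                                 [1, 2, 0]])
train_edge_type = torch.tensor([0, 1, 0])
num_relations = 2

# Append a reversed copy of every edge and give it a new relation ID.
edge_index = torch.cat([train_edge_index, train_edge_index.flip(0)], dim=1)
edge_type = torch.cat([train_edge_type, train_edge_type + num_relations])

print(edge_index.shape)          # torch.Size([2, 6])
print(int(edge_type.max()) + 1)  # 4
```

With this construction the encoder sees twice as many relation types as the decoder has to score.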

### 1.2 The AIFB Dataset

Because FB15k-237 is fairly large, we switch to a smaller dataset, AIFB, so that the code is easier to test. AIFB has 58,086 edges, 90 relation types, and 8,285 entities.

In [7]:
from torch_geometric.datasets import Entities
dataset = Entities('./data', 'AIFB')
data = dataset[0]

Downloading https://data.dgl.ai/dataset/aifb.tgz
Extracting data/aifb.tgz
Processing...
Done!


In [8]:
data, data.edge_type.max().item()+1

(Data(edge_index=[2, 58086], edge_type=[58086], train_idx=[140], train_y=[140], test_idx=[36], test_y=[36], num_nodes=8285),
 90)

AIFB was originally intended for entity classification, so we need to split its edges into training/validation/test sets ourselves.

In [9]:
n = int(0.1 * data.edge_index.size(1))

data.test_edge_index = data.edge_index[:, :n]
data.test_edge_type = data.edge_type[: n]

data.valid_edge_index = data.edge_index[:, n: 2*n]
data.valid_edge_type = data.edge_type[n: 2*n]

data.train_edge_index = data.edge_index[:, 2*n: ]
data.train_edge_type = data.edge_type[2*n: ]
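
The slices above take the test, validation, and training edges in their stored order. If that order is not random, the three parts may not follow the same distribution. A minimal sketch of a shuffled split is given below; `split_edges` is a hypothetical helper, and the toy graph stands in for the real data.

```python
import torch

def split_edges(edge_index, edge_type, frac=0.1, seed=0):
    """Randomly split edges into test/valid/train parts (hypothetical helper)."""
    num_edges = edge_index.size(1)
    n = int(frac * num_edges)
    gen = torch.Generator().manual_seed(seed)        # fixed seed for reproducibility
    perm = torch.randperm(num_edges, generator=gen)  # random edge order
    parts = {'test': perm[:n], 'valid': perm[n:2 * n], 'train': perm[2 * n:]}
    return {k: (edge_index[:, idx], edge_type[idx]) for k, idx in parts.items()}

# Toy graph with 20 edges and 3 relation types.
edge_index = torch.stack([torch.arange(20), (torch.arange(20) + 1) % 20])
edge_type = torch.arange(20) % 3
splits = split_edges(edge_index, edge_type)
print({k: v[0].size(1) for k, v in splits.items()})
```

The permutation is created on the CPU here; for tensors stored on a GPU one would likely move `perm` to the same device before indexing.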

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
data = data.to(device)

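## 2. Models for Knowledge Graphs

### 2.1 Defining the R-GCN Encoder and the DistMult Decoder

Knowledge graph completion is very similar to link prediction, so we can reuse the autoencoder framework from the previous lesson. \
We instantiate the GAE class and pass in our encoder and decoder: GAE(Encoder, Decoder). The encoder is R-GCN, and the decoder is DistMult, a decoder commonly used for knowledge graph representations.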

In [11]:
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn import GAE, RGCNConv

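#### 2.1.1 The R-GCN Encoder
[R-GCN](https://arxiv.org/abs/1703.06103) extends GCN to knowledge graphs. It modifies GCN so that it can capture both entity and relation information. Specifically, for entity $i$, the aggregation in one R-GCN layer is:
\
$$\mathbf{h}^{\prime}_i = \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
\frac{1}{|\mathcal{N}_r(i)|} \mathbf{W}_r \cdot \mathbf{h}_j + \mathbf{W}_{0} \cdot
\mathbf{h}_i $$

Here $\mathbf{W}_0$ and $\mathbf{W}_r$ are model parameters; $\mathcal{R}$ is the set of all relations in the knowledge graph, i.e. the edge types; and $\mathcal{N}_r(i)$ is the set of neighbors of entity $i$ under relation $r$.\
The first term aggregates information from the neighbors under each relation $r$; the second term preserves the entity's own information.\
Note that every edge type $r$ has its own transformation matrix $W_r$. When the dataset has many edge types, the number of $W_r$ matrices, and with it the number of model parameters, grows accordingly, which makes the model very complex. One remedy is basis decomposition: define a set of bases $\{V_1, V_2, \ldots, V_B\}$ and express each parameter matrix as a weighted sum of them. Specifically,
$$W_r = \sum_{b=1}^{B}a_{rb}V_b$$
where $a_{rb}$ are the weighting coefficients for $W_r$. The parameters then consist only of the $B$ matrices $V_b$ and the scalar coefficients $a_{rb}$.\
The R-GCN paper also proposes block-diagonal decomposition to reduce model complexity, which we do not cover here. Interested readers can refer to Section 2.2 of [the original paper, Modeling Relational Data with Graph Convolutional Networks](https://arxiv.org/abs/1703.06103).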
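
To make the saving from basis decomposition concrete, the sketch below builds all $W_r$ from a small set of bases and compares parameter counts. The sizes (90 relations, 5 bases, 16 channels) are illustrative choices, and the variable names `V` and `a` simply mirror the symbols in the formula above.

```python
import torch

num_relations, num_bases, in_channels, out_channels = 90, 5, 16, 16

V = torch.randn(num_bases, in_channels, out_channels)  # shared bases V_b
a = torch.randn(num_relations, num_bases)              # coefficients a_{rb}

# W_r = sum_b a_{rb} V_b, computed for all relations at once.
W = torch.einsum('rb,bio->rio', a, V)

# Equivalent formulation: flatten each basis, multiply, reshape back.
W_alt = (a @ V.view(num_bases, -1)).view(num_relations, in_channels, out_channels)

full_params = num_relations * in_channels * out_channels
basis_params = num_bases * in_channels * out_channels + num_relations * num_bases
print(W.shape, full_params, basis_params)
```

For these sizes the full parameterization needs 23,040 weights, while the basis form needs 1,730.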

Next, let's implement one R-GCN layer ourselves.

In [12]:
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch.nn import Parameter

class RGCNConv(MessagePassing):

    def __init__(self, in_channels, out_channels, num_relations, 
                 num_bases=None, aggr='mean', **kwargs):
        """
        Parameters
        ----
        in_channels: number of input features
        out_channels: number of output features
        num_relations: number of relation types
        num_bases: number of bases
        aggr: aggregation scheme
        """
        super().__init__(aggr=aggr, node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases
        
        if num_bases is not None: # if the number of bases is given, define the parameters as follows
            self.weight = Parameter(
                torch.Tensor(num_bases, in_channels, out_channels))
            self.comp = Parameter(torch.Tensor(num_relations, num_bases)) # self.comp holds the coefficients $a_{rb}$
        else:
            self.weight = Parameter(
                torch.Tensor(num_relations, in_channels, out_channels))
            self.register_parameter('comp', None)
            
        self.w0 = Parameter(torch.Tensor(in_channels, out_channels))
        self.bias = Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        glorot(self.comp)
        glorot(self.w0)
        zeros(self.bias)

    def forward(self, x, edge_index, edge_type=None):
        """
        Parameters
        ----
        x: entity feature matrix
        edge_index: edge indices
        edge_type: edge types (i.e. relations)
        """

        x_l, x_r = x, x  # features of the source nodes and of the target nodes, respectively
        size = (x_l.size(0), x_r.size(0))

        assert edge_type is not None
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:  # basis decomposition
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels, self.out_channels) 

        for i in range(self.num_relations):
            tmp = edge_index[:, edge_type == i] # select the edges of this relation
            h = self.propagate(tmp, x=x_l, size=size)
            out = out + (h @ weight[i]) # first term of the R-GCN formula

        w0 = self.w0
        out += x_r @ w0 # second term of the R-GCN formula
        out += self.bias
        return out

    def message(self, x_j):
        return x_j
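
As a way to check the intuition behind the layer, the following plain-PyTorch sketch evaluates the same aggregation formula without PyG: a per-relation mean over incoming neighbors, followed by the self term. `rgcn_layer_reference` is a hypothetical helper written for this illustration; it omits the bias and is not meant to be efficient.

```python
import torch

def rgcn_layer_reference(x, edge_index, edge_type, weight, w0):
    """Per-relation mean aggregation plus self term (hypothetical reference)."""
    num_nodes = x.size(0)
    out = x @ w0                                    # second term: W_0 h_i
    src, dst = edge_index
    for r in range(weight.size(0)):
        mask = edge_type == r
        s, d = src[mask], dst[mask]
        # Sum the source features arriving at each target node.
        summed = torch.zeros_like(x).index_add_(0, d, x[s])
        # Count the neighbors of each target node under relation r.
        deg = torch.zeros(num_nodes, device=x.device).index_add_(
            0, d, torch.ones(d.numel(), device=x.device))
        mean = summed / deg.clamp(min=1).unsqueeze(1)  # divide by |N_r(i)|
        out = out + mean @ weight[r]                # first term for relation r
    return out

x = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
edge_index = torch.tensor([[0, 1, 0],    # source nodes
                           [2, 2, 1]])   # target nodes
edge_type = torch.tensor([0, 0, 1])
weight = torch.eye(2).repeat(2, 1, 1)    # identity W_r for both relations
w0 = torch.eye(2)

out = rgcn_layer_reference(x, edge_index, edge_type, weight, w0)
print(out)
```

With identity weights, node 2 receives the mean of nodes 0 and 1 under relation 0, which is `[0.5, 0.5]`, on top of its own features `[1, 1]`, giving `[1.5, 1.5]`. Node 0 has no incoming edges and keeps `[1, 0]`.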

In [13]:
class RGCNEncoder(torch.nn.Module):
    
    def __init__(self, num_nodes, hidden_channels, num_relations):
        """
        Parameters
        ----
        num_nodes: number of nodes
        hidden_channels: number of hidden units
        num_relations: number of relation types
        """
        super().__init__()
        self.node_emb = Parameter(torch.Tensor(num_nodes, hidden_channels))
        self.conv1 = RGCNConv(hidden_channels, hidden_channels, num_relations, num_bases=5)
        self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations, num_bases=5)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.node_emb)
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        
    def forward(self, edge_index, edge_type):
        """
        Parameters
        ----
        edge_index: edge indices
        edge_type: edge types (i.e. relations)
        """
        x = self.node_emb
        x = self.conv1(x, edge_index, edge_type).relu_()
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index, edge_type)
        return x

In [14]:
# The full dataset is fairly large, so we test on the first 50 edges only
toy_edge_index = data.edge_index[:, :50] 
toy_edge_type = data.edge_type[:50]
toy_num_nodes = toy_edge_index.max().item() + 1
toy_num_relations = toy_edge_type.max().item() + 1

In [15]:
toy_num_nodes, toy_num_relations

(7703, 63)

In [16]:
encoder =  RGCNEncoder(toy_num_nodes, hidden_channels=16,
                num_relations=toy_num_relations).to(device)
z = encoder(toy_edge_index, toy_edge_type) 

In [17]:
z, z.shape

(tensor([[-0.0045, -0.0185, -0.0084,  ..., -0.0092,  0.0128, -0.0106],
         [-0.0072, -0.0078,  0.0173,  ..., -0.0039, -0.0153,  0.0071],
         [-0.0015, -0.0109,  0.0037,  ...,  0.0140, -0.0087,  0.0146],
         ...,
         [-0.0055,  0.0028, -0.0044,  ...,  0.0004,  0.0019, -0.0006],
         [ 0.0121, -0.0088, -0.0082,  ...,  0.0020, -0.0087, -0.0017],
         [-0.0025, -0.0102,  0.0167,  ...,  0.0054, -0.0202,  0.0002]],
        device='cuda:0', grad_fn=<AddBackward0>), torch.Size([7703, 16]))

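#### 2.1.2 The DistMult Decoder
DistMult is a decoder commonly used for knowledge graphs. It scores how likely a triple $(z_\text{src}, r, z_\text{dst})$ is to be a true triple with the following function:\
$$f(z_\text{src}, r, z_\text{dst}) = z_\text{src}^T R_r z_\text{dst},$$
where $R_r$ is the diagonal matrix formed from the embedding of relation $r$.\
$f(z_\text{src}, r, z_\text{dst})$ is usually called the score function.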
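
Because $R_r$ is diagonal, the score reduces to an element-wise product followed by a sum, so the matrix never has to be built. The sketch below checks this on random vectors (illustrative values only) and also shows that the score is unchanged when source and destination are swapped.

```python
import torch

torch.manual_seed(0)
z_src, z_dst, r = torch.randn(4), torch.randn(4), torch.randn(4)

score_matrix = z_src @ torch.diag(r) @ z_dst      # z_src^T R_r z_dst
score_elementwise = torch.sum(z_src * r * z_dst)  # same value, no matrix needed
score_swapped = torch.sum(z_dst * r * z_src)      # source and destination swapped

print(torch.allclose(score_matrix, score_elementwise, atol=1e-6))   # True
print(torch.allclose(score_elementwise, score_swapped, atol=1e-6))  # True
```

The second check reflects a known property of DistMult: its score is symmetric in the two entities, so on its own it cannot tell the direction of a relation.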

In [18]:
class DistMultDecoder(torch.nn.Module):
    
    def __init__(self, num_relations, hidden_channels):
        super().__init__()
        self.rel_emb = Parameter(torch.Tensor(num_relations, hidden_channels)) # embedding matrix of the relations
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.rel_emb)

    def forward(self, z, edge_index, edge_type):
        """Forward pass.
        
        Parameters
        ----
        z: encoder output
        edge_index: edge indices
        edge_type: edge types
        """
        z_src, z_dst = z[edge_index[0]], z[edge_index[1]] # look up the entity embeddings
        rel = self.rel_emb[edge_type]
        return torch.sum(z_src * rel * z_dst, dim=1)  # compute the score function

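    def loss(self, z, train_edge_index, train_edge_type, num_nodes):
        """Binary cross-entropy over positive and negative edges.
        
        Parameters
        ----
        z: encoder output
        train_edge_index: indices of the training edges
        train_edge_type: types of the training edges
        num_nodes: number of nodes
        """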
        pos_out = self.forward(z, train_edge_index, train_edge_type)

        neg_edge_index = negative_sampling(train_edge_index, num_nodes)
        neg_out = self.forward(z, neg_edge_index, train_edge_type)

        out = torch.cat([pos_out, neg_out])
        gt = torch.cat([torch.ones_like(pos_out), torch.zeros_like(neg_out)])
        cross_entropy_loss = F.binary_cross_entropy_with_logits(out, gt)
        reg_loss = z.pow(2).mean() + self.rel_emb.pow(2).mean() # regularization term
        loss = cross_entropy_loss + 1e-2 * reg_loss
        return loss
    
def negative_sampling(edge_index, num_nodes):
    """Sample negative edges by randomly corrupting the source or the target node of each edge."""
    mask_1 = torch.rand(edge_index.size(1)) < 0.5 # about half of the edges get a corrupted source node
    mask_2 = ~mask_1 # the remaining edges get a corrupted target node

    neg_edge_index = edge_index.clone()
    neg_edge_index[0, mask_1] = torch.randint(num_nodes, (mask_1.sum(), )).to(edge_index.device)
    neg_edge_index[1, mask_2] = torch.randint(num_nodes, (mask_2.sum(), )).to(edge_index.device)
    return neg_edge_index

In [19]:
decoder = DistMultDecoder(toy_num_relations, hidden_channels=16).to(device)
out = decoder(z, toy_edge_index, toy_edge_type)

In [20]:
out, out.shape

(tensor([-7.5956e-05,  9.9950e-05,  1.0870e-04, -2.7023e-05, -4.8169e-05,
         -9.1931e-05, -3.5832e-05,  1.0504e-04, -2.8472e-05,  1.8635e-05,
         -2.1819e-05, -1.5726e-05, -1.5313e-04,  1.5302e-04,  6.2634e-05,
          9.4934e-05, -2.9663e-05, -4.2502e-05,  1.2604e-04, -2.8465e-05,
          1.3336e-04,  4.1126e-05, -1.1284e-06,  4.7070e-05, -4.2751e-05,
         -1.6520e-04, -9.2914e-06,  1.1865e-04, -3.9190e-05,  2.2280e-04,
         -2.3947e-05,  9.0296e-05,  3.6421e-05,  1.5231e-05,  6.8740e-05,
          4.9193e-05, -1.0122e-04,  1.1152e-04,  2.8308e-05,  1.0912e-04,
          1.0533e-04, -9.2577e-05,  8.1163e-06,  2.1672e-04, -4.2681e-05,
          3.4621e-05, -9.3450e-05,  4.3897e-05, -4.6970e-05,  2.1095e-05],
        device='cuda:0', grad_fn=<SumBackward1>), torch.Size([50]))

In [21]:
decoder.loss(z, toy_edge_index, toy_edge_type, toy_num_nodes)

tensor(0.6934, device='cuda:0', grad_fn=<AddBackward0>)
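
A loss of about 0.6934 is what one would expect from an untrained model. The scores printed earlier are all close to zero, so the sigmoid of each score is close to 0.5 and the binary cross-entropy is close to $\ln 2 \approx 0.6931$; the small remainder comes from the regularization term and the not-exactly-zero scores. The sketch below checks the $\ln 2$ baseline with all-zero logits (an idealized assumption).

```python
import math
import torch
import torch.nn.functional as F

logits = torch.zeros(100)                              # idealized untrained scores
labels = torch.cat([torch.ones(50), torch.zeros(50)])  # 50 positive, 50 negative edges
loss = F.binary_cross_entropy_with_logits(logits, labels)

print(round(loss.item(), 4), round(math.log(2), 4))    # 0.6931 0.6931
```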

#### 2.1.3 R-GCN Encoder + DistMult Decoder
Next, we combine the R-GCN encoder and the DistMult decoder into an autoencoder.

We can either write our own `GAE` class or use `torch_geometric.nn.GAE`.

In [22]:
class GAE(torch.nn.Module):
    """Graph autoencoder.
    """
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, *args, **kwargs): 
        """Run the encoder."""
        return self.encoder(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Run the decoder."""
        return self.decoder(*args, **kwargs)

In [23]:
# Alternatively, use torch_geometric.nn.GAE
# from torch_geometric.nn import GAE 

model = GAE(encoder, decoder)
z = model.encode(toy_edge_index, toy_edge_type)
model.decoder.loss(z, toy_edge_index, toy_edge_type, toy_num_nodes)

tensor(0.6934, device='cuda:0', grad_fn=<AddBackward0>)

## 3. Training R-GCN on Knowledge Graph Completion

In [24]:
def train():
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)    
    for epoch in range(1, 1001):
        model.train()
        optimizer.zero_grad()

        z = model.encode(data.train_edge_index, data.train_edge_type)    
        loss = model.decoder.loss(z, data.train_edge_index, data.train_edge_type, data.num_nodes)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.) # clip the gradient norm so that updates do not become too large
        optimizer.step()
        if (epoch % 20) == 0:
            print(f'Epoch: {epoch:05d}, Loss: {loss:.4f}')
        if (epoch % 100) == 0:
            valid_mrr, test_mrr = test()
            print(f'Val MRR: {valid_mrr:.4f}, Test MRR: {test_mrr:.4f}')

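Performance on knowledge graph completion is commonly measured with MRR (Mean Reciprocal Rank): the mean, over all queries, of the reciprocal of the rank assigned to the correct answer (higher is better). See [this link](https://www.cxybb.com/article/qq_36158230/120254381) for details.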
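
As a small worked example of the metric, the sketch below derives a rank from a score vector and then averages reciprocal ranks. Both helper names are hypothetical, and the convention that the correct candidate sits at index 0 is an assumption of this example.

```python
import torch

def mean_reciprocal_rank(ranks):
    """MRR for a list of 1-based ranks of the correct answers."""
    ranks = torch.tensor(ranks, dtype=torch.float)
    return (1.0 / ranks).mean().item()

def rank_of_first(scores):
    """1-based rank of candidate 0 when scores are sorted in descending order."""
    perm = scores.argsort(descending=True)
    return int((perm == 0).nonzero(as_tuple=False).view(-1)[0]) + 1

scores = torch.tensor([0.7, 0.9, 0.1, 0.8])  # candidate 0 is the correct answer
print(rank_of_first(scores))                 # 3
print(mean_reciprocal_rank([1, 2, 4]))       # approximately 0.5833
```

For ranks 1, 2, and 4 the reciprocal ranks are 1, 0.5, and 0.25, so the MRR is 1.75 / 3 ≈ 0.5833.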

In [26]:
"""
source code: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rgcn_link_pred.py
author: pyg-team
"""

from tqdm import tqdm
import torch

@torch.no_grad()
def compute_mrr(model, data, z, edge_index, edge_type):
    ranks = []
    for i in tqdm(range(edge_type.numel())):
        (src, dst), rel = edge_index[:, i], edge_type[i]

        # Try all nodes as tails, but delete true triplets:
        tail_mask = torch.ones(data.num_nodes, dtype=torch.bool)
        for (heads, tails), types in [
            (data.train_edge_index, data.train_edge_type),
            (data.valid_edge_index, data.valid_edge_type),
            (data.test_edge_index, data.test_edge_type),
        ]:
            tail_mask[tails[(heads == src) & (types == rel)]] = False

        tail = torch.arange(data.num_nodes)[tail_mask]
        tail = torch.cat([torch.tensor([dst]), tail])
        head = torch.full_like(tail, fill_value=src)
        eval_edge_index = torch.stack([head, tail], dim=0)
        eval_edge_type = torch.full_like(tail, fill_value=rel)

        out = model.decode(z, eval_edge_index, eval_edge_type)
        perm = out.argsort(descending=True)
        rank = int((perm == 0).nonzero(as_tuple=False).view(-1)[0])
        ranks.append(rank + 1)

        # Try all nodes as heads, but delete true triplets:
        head_mask = torch.ones(data.num_nodes, dtype=torch.bool)
        for (heads, tails), types in [
            (data.train_edge_index, data.train_edge_type),
            (data.valid_edge_index, data.valid_edge_type),
            (data.test_edge_index, data.test_edge_type),
        ]:
            head_mask[heads[(tails == dst) & (types == rel)]] = False

        head = torch.arange(data.num_nodes)[head_mask]
        head = torch.cat([torch.tensor([src]), head])
        tail = torch.full_like(head, fill_value=dst)
        eval_edge_index = torch.stack([head, tail], dim=0)
        eval_edge_type = torch.full_like(head, fill_value=rel)

        out = model.decode(z, eval_edge_index, eval_edge_type)
        perm = out.argsort(descending=True)
        rank = int((perm == 0).nonzero(as_tuple=False).view(-1)[0])
        ranks.append(rank + 1)

    return (1. / torch.tensor(ranks, dtype=torch.float)).mean()

In [27]:
# from utils import compute_mrr 
@torch.no_grad()
def test():
    model.eval()
    z = model.encode(data.train_edge_index, data.train_edge_type)

    valid_mrr = compute_mrr(model, data, z, data.valid_edge_index, data.valid_edge_type)
    test_mrr = compute_mrr(model, data, z, data.test_edge_index, data.test_edge_type)

    return valid_mrr, test_mrr

In [28]:
model = GAE(
    RGCNEncoder(data.num_nodes, hidden_channels=64,
                num_relations=dataset.num_relations),
    DistMultDecoder(dataset.num_relations, hidden_channels=64),
).to(device)

In [29]:
train()

Epoch: 00020, Loss: 0.2976
Epoch: 00040, Loss: 0.1948
Epoch: 00060, Loss: 0.1619
Epoch: 00080, Loss: 0.1442
Epoch: 00100, Loss: 0.1230


100%|██████████| 5808/5808 [00:21<00:00, 271.47it/s]
100%|██████████| 5808/5808 [00:21<00:00, 272.19it/s]


Val MRR: 0.0917, Test MRR: 0.1404
Epoch: 00120, Loss: 0.1107
Epoch: 00140, Loss: 0.0982
Epoch: 00160, Loss: 0.0920
Epoch: 00180, Loss: 0.0813
Epoch: 00200, Loss: 0.0803


100%|██████████| 5808/5808 [00:21<00:00, 271.11it/s]
100%|██████████| 5808/5808 [00:21<00:00, 272.28it/s]


Val MRR: 0.1668, Test MRR: 0.1875
Epoch: 00220, Loss: 0.0735
Epoch: 00240, Loss: 0.0728
Epoch: 00260, Loss: 0.0663
Epoch: 00280, Loss: 0.0664
Epoch: 00300, Loss: 0.0652


100%|██████████| 5808/5808 [00:21<00:00, 271.68it/s]
100%|██████████| 5808/5808 [00:21<00:00, 273.05it/s]


Val MRR: 0.2643, Test MRR: 0.2674
Epoch: 00320, Loss: 0.0612
Epoch: 00340, Loss: 0.0614
Epoch: 00360, Loss: 0.0578
Epoch: 00380, Loss: 0.0574
Epoch: 00400, Loss: 0.0558


100%|██████████| 5808/5808 [00:21<00:00, 271.87it/s]
100%|██████████| 5808/5808 [00:21<00:00, 272.47it/s]


Val MRR: 0.3138, Test MRR: 0.3254
Epoch: 00420, Loss: 0.0575
Epoch: 00440, Loss: 0.0550
Epoch: 00460, Loss: 0.0549
Epoch: 00480, Loss: 0.0539
Epoch: 00500, Loss: 0.0513


100%|██████████| 5808/5808 [00:21<00:00, 266.50it/s]
100%|██████████| 5808/5808 [00:21<00:00, 269.98it/s]


Val MRR: 0.3575, Test MRR: 0.3755
Epoch: 00520, Loss: 0.0508
Epoch: 00540, Loss: 0.0523
Epoch: 00560, Loss: 0.0487
Epoch: 00580, Loss: 0.0507
Epoch: 00600, Loss: 0.0497


100%|██████████| 5808/5808 [00:21<00:00, 271.48it/s]
100%|██████████| 5808/5808 [00:21<00:00, 272.84it/s]


Val MRR: 0.3890, Test MRR: 0.3905
Epoch: 00620, Loss: 0.0489
Epoch: 00640, Loss: 0.0505
Epoch: 00660, Loss: 0.0469
Epoch: 00680, Loss: 0.0482
Epoch: 00700, Loss: 0.0459


100%|██████████| 5808/5808 [00:21<00:00, 271.45it/s]
100%|██████████| 5808/5808 [00:21<00:00, 273.20it/s]


Val MRR: 0.4029, Test MRR: 0.4149
Epoch: 00720, Loss: 0.0476
Epoch: 00740, Loss: 0.0457
Epoch: 00760, Loss: 0.0453
Epoch: 00780, Loss: 0.0477
Epoch: 00800, Loss: 0.0462


100%|██████████| 5808/5808 [00:21<00:00, 271.79it/s]
100%|██████████| 5808/5808 [00:21<00:00, 272.33it/s]


Val MRR: 0.4298, Test MRR: 0.4374
Epoch: 00820, Loss: 0.0456
Epoch: 00840, Loss: 0.0458
Epoch: 00860, Loss: 0.0444
Epoch: 00880, Loss: 0.0455
Epoch: 00900, Loss: 0.0439


100%|██████████| 5808/5808 [00:21<00:00, 272.44it/s]
100%|██████████| 5808/5808 [00:21<00:00, 272.15it/s]


Val MRR: 0.4397, Test MRR: 0.4414
Epoch: 00920, Loss: 0.0433
Epoch: 00940, Loss: 0.0423
Epoch: 00960, Loss: 0.0446
Epoch: 00980, Loss: 0.0447
Epoch: 01000, Loss: 0.0416


100%|██████████| 5808/5808 [00:21<00:00, 273.28it/s]
100%|██████████| 5808/5808 [00:21<00:00, 273.36it/s]

Val MRR: 0.4378, Test MRR: 0.4443



