In [None]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cu113.html

# 第六次作业

在本次作业中，我们利用R-GCN来完整实体分类任务。实体分类和节点分类任务相似，都是对图里的节点进行分类。

## 数据集加载

我们使用AIFB数据集来完成实体分类。R-GCN的论文里面就使用了这个数据集来完成实体分类任务。

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Entities


dataset = Entities('./data', 'AIFB')
data = dataset[0]

node_idx = torch.cat([data.train_idx, data.test_idx], dim=0)

In [3]:
data

Data(edge_index=[2, 58086], edge_type=[58086], train_idx=[140], train_y=[140], test_idx=[36], test_y=[36], num_nodes=8285)

In [4]:
data.edge_type.max().item()+1

90

In [5]:
data.train_idx

tensor([7378, 1749, 6778, 6502, 6436, 2732, 3456, 6266, 6916,  471, 8055,  899,
        4877,  828, 7095, 2016, 3595, 1738, 7727, 2631, 1014, 1011,  928, 6588,
        2201, 3435, 1820, 6030,  256, 6925, 4707, 2773,  637, 1361, 8088,  650,
        2548, 7784, 6652, 3793, 3780, 8079, 4510, 6915,  481, 5608, 6585, 5264,
         795, 5083, 1224, 2411, 1229, 6651, 7643, 6369, 6445, 7340, 7923, 3942,
        8206, 6496, 7230, 2865,  682, 4511, 7394, 4259, 6611, 2587,  871, 1625,
        4672, 3527, 2175,  963,  648, 7178, 6271, 7021, 7428, 5173, 5402, 2147,
         157, 6853, 5438, 2370, 3281, 3967, 6384, 4560, 7379, 7193,  752, 3761,
        7793, 5713, 1192, 6254, 7542, 5602,  129, 7460, 4967, 6756, 2682, 6196,
        1161, 2677,  876, 1171, 6073, 7901,  636, 8122, 4393, 4021, 4850, 5897,
        8164, 3552, 7567, 4337, 7296, 8117, 6685, 2033, 4427, 1577, 5519, 4201,
        6069,  843,  763,  398, 6118, 7777, 7978, 1286])

我们可以看到在AIFB数据集中，有58086条边，90种关系，8285个节点。其中有140个节点是训练集中的节点，36个节点是测试集中的节点。

## 代码填空
完成空缺的代码部分，完成实体分类任务。

In [6]:
# 大家既可以自己写RGCNConv，也可以用如下的代码调用RGCNConv
from torch_geometric.nn import RGCNConv
from torch.nn import Parameter

class RGCN(torch.nn.Module):
    
    def __init__(self, num_nodes, hidden_channels, num_classes, num_relations):
        super().__init__()
        
        ###################
        #### 代码填空 ######
        ###################
        super().__init__()

        self.node_emb = Parameter(torch.Tensor(num_nodes, hidden_channels))
        self.conv1 = RGCNConv(hidden_channels, hidden_channels, num_relations, num_bases=5)
        self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations, num_bases=5)
        self.decode= torch.nn.Linear(hidden_channels, num_classes)

        self.reset_parameters()
        
    def reset_parameters(self):
        """初始化模型参数"""
        ###################
        #### 代码填空 ######
        ###################

        torch.nn.init.xavier_uniform_(self.node_emb)
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
    
    def forward(self, edge_index, edge_type):
        """前向传播"""
        ###################
        #### 代码填空 ######
        ###################

        x = self.node_emb
        x = self.conv1(x, edge_index, edge_type).relu_()
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index, edge_type)
        
        x = self.decode(x)
        return x

In [7]:
def train():
    """训练模型"""
    ###################
    #### 代码填空 ######
    ###################
    model.train()
    optimizer.zero_grad()

    pred = model(data.edge_index, data.edge_type)
    pred = F.log_softmax(pred, dim=-1)
    loss = F.nll_loss(pred[data.train_idx], data.train_y)
    
    loss.backward()
    optimizer.step()

    return loss.item()

In [8]:
@torch.no_grad()
def test():
    model.eval()
    pred = model(data.edge_index, data.edge_type).argmax(dim=-1)
    ###########################################
    #### 代码填空，计算train_acc和test_acc ######
    ##########################################

    train_acc = (pred[data.train_idx] == data.train_y).sum() / len(data.train_idx)
    test_acc = (pred[data.test_idx] == data.test_y).sum() / len(data.test_idx)

    return train_acc.item(), test_acc.item()

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RGCN(data.num_nodes, 16, dataset.num_classes, dataset.num_relations).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)


for epoch in range(1, 100):
    loss = train()
    train_acc, test_acc = test()
    if epoch % 5 == 0:
        print(f'Epoch: {epoch:02d}, TrainLoss: {loss:.4f}, TrainAcc: {train_acc:.4f} '
              f'TestAcc: {test_acc:.4f}')

Epoch: 05, TrainLoss: 1.0421, TrainAcc: 0.7429 TestAcc: 0.6667
Epoch: 10, TrainLoss: 0.3468, TrainAcc: 0.9357 TestAcc: 0.7222
Epoch: 15, TrainLoss: 0.0437, TrainAcc: 1.0000 TestAcc: 0.8889
Epoch: 20, TrainLoss: 0.0059, TrainAcc: 1.0000 TestAcc: 0.9167
Epoch: 25, TrainLoss: 0.0026, TrainAcc: 1.0000 TestAcc: 0.8889
Epoch: 30, TrainLoss: 0.0018, TrainAcc: 1.0000 TestAcc: 0.8611
Epoch: 35, TrainLoss: 0.0006, TrainAcc: 1.0000 TestAcc: 0.8611
Epoch: 40, TrainLoss: 0.0003, TrainAcc: 1.0000 TestAcc: 0.8611
Epoch: 45, TrainLoss: 0.0002, TrainAcc: 1.0000 TestAcc: 0.8611
Epoch: 50, TrainLoss: 0.0003, TrainAcc: 1.0000 TestAcc: 0.8889
Epoch: 55, TrainLoss: 0.0024, TrainAcc: 1.0000 TestAcc: 0.8889
Epoch: 60, TrainLoss: 0.0003, TrainAcc: 1.0000 TestAcc: 0.8889
Epoch: 65, TrainLoss: 0.0013, TrainAcc: 1.0000 TestAcc: 0.8889
Epoch: 70, TrainLoss: 0.0012, TrainAcc: 1.0000 TestAcc: 0.9167
Epoch: 75, TrainLoss: 0.0022, TrainAcc: 1.0000 TestAcc: 0.9167
Epoch: 80, TrainLoss: 0.0017, TrainAcc: 1.0000 TestAcc: