In [48]:
import torch_geometric
from torch_geometric.datasets import Planetoid

## 加载 Data

In [49]:
# Load the Cora citation dataset through the Planetoid loader, using ./data as the root directory.
dataset = Planetoid(root="data", name="Cora")

### 查看data内容

In [50]:
# Show the dataset repr and its Python type, then a small table of basic statistics.
print(dataset)
print(type(dataset))
for label, value in (
    ("number of graphs:\t\t", len(dataset)),
    ("number of classes:\t\t", dataset.num_classes),
    ("number of node features:\t", dataset.num_node_features),
    ("number of edge features:\t", dataset.num_edge_features),
):
    print(label, value)

Cora()
torch_geometric.datasets.planetoid.Planetoid
number of graphs:		 1
number of classes:		 7
number of node features:	 1433
number of edge features:	 0


#### 查看Data类

In [51]:
# The dataset holds exactly one graph, so dataset[0] is that graph's Data object.
# Indexing is preferred over dataset.data, which exposes the internal collated
# storage (recent PyG releases emit a warning when it is accessed directly).
print(type(dataset[0]))

<class 'torch_geometric.data.data.Data'>


In [52]:
# edge_index is a 2-row integer tensor: one column per edge.
# Read it from dataset[0] (the single graph) rather than from dataset.data, which
# is the dataset's internal storage and triggers a warning in recent PyG releases.
print("点边关系 = ", dataset[0].edge_index)
print("点边关系的shape =", dataset[0].edge_index.shape)

点边关系 =  tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])
点边关系的shape = torch.Size([2, 10556])


点边关系（edge_index）的第一行是每条边的起点（源节点）编号，第二行是每条边的终点（目标节点）编号；同一列上下两个元素构成一条边，共 10556 列，即 10556 条边。通过这两行信息，点边关系就描述清楚了，图的连接结构也就描述完了。

In [53]:
# Node feature matrix of the single graph; accessed via dataset[0] instead of the
# internal dataset.data storage (which recent PyG releases warn about).
print("点的具体属性 = ",dataset[0].x)

点的具体属性 =  tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [40]:
# Shape of the node feature matrix: (number of nodes, features per node).
# dataset[0] is used instead of the internal dataset.data storage.
print("点的具体属性的shape = ", dataset[0].x.shape)

点的具体属性的shape =  torch.Size([2708, 1433])


说明共有 2708 个点，每个点有一个 1433 维的属性向量（即输入特征，可以看作该点的初始 embedding）。

In [54]:
# Per-node class labels of the single graph; accessed via dataset[0] instead of
# the internal dataset.data storage (which recent PyG releases warn about).
print("y标 = ", dataset[0].y)

y标 =  tensor([3, 4, 4,  ..., 3, 3, 3])


In [55]:
# One label per node, so this is a 1-D shape whose length equals the node count.
# dataset[0] is used instead of the internal dataset.data storage.
print("y标的shape = ", dataset[0].y.shape)

y标的shape =  torch.Size([2708])


每个点都有一个 y 标（类别编号），总共有 2708 个，对应前面输出的 7 个类别。

### Graph可视化

In [56]:
from matplotlib import pylab
from torch_geometric.utils.convert import to_networkx
import matplotlib.pyplot as plt
import networkx as nx

def save_graph(graph, file_name, seed=None):
    """Render a PyG graph with NetworkX and save the drawing to ``file_name``.

    Parameters
    ----------
    graph
        Graph object accepted by ``to_networkx`` (the notebook passes a
        torch_geometric ``Data`` instance).
    file_name
        Output path handed to ``Figure.savefig``; the extension selects the
        format (the notebook uses ``.pdf``).
    seed
        Optional seed forwarded to ``nx.spring_layout``. The default ``None``
        keeps the original behaviour (a different random layout on every
        call); pass an int to get a reproducible picture.

    Returns
    -------
    None
        The function is called for its side effect of writing the file.
    """
    nx_graph = to_networkx(graph)

    # Keep explicit handles to the figure and axes. The previous version created
    # the large figure with plt.figure(num=None, ...) and then re-selected
    # figure number 1, which is only the same figure when no other figure is
    # open; otherwise the drawing landed on an unrelated figure.
    fig, ax = plt.subplots(figsize=(40, 40), dpi=80)
    try:
        ax.set_axis_off()

        pos = nx.spring_layout(nx_graph, seed=seed)
        nx.draw_networkx_nodes(nx_graph, pos, ax=ax)
        nx.draw_networkx_edges(nx_graph, pos, ax=ax)
        nx.draw_networkx_labels(nx_graph, pos, ax=ax)

        fig.savefig(file_name, bbox_inches="tight")
    finally:
        # Close exactly the figure created here, even if drawing or saving
        # fails, so large figures do not pile up in kernel memory.
        plt.close(fig)

In [15]:
# Draw the single Cora graph to a PDF. dataset[0] is passed instead of
# dataset.data, which is the dataset's internal storage and triggers a warning
# in recent PyG releases.
save_graph(dataset[0], "my_graph.pdf")

### 模型构建及训练

In [44]:
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

In [45]:
# Take the first graph of the dataset as the working Data object used by the
# model, the training step and the evaluation below.
data = dataset[0]

In [46]:
# Bare expression: the notebook displays the Data object's repr as the cell output.
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [19]:
class Net(torch.nn.Module):
    """Single-layer GraphSAGE node classifier.

    One ``SAGEConv`` layer maps the input node features directly to class
    scores, and ``forward`` returns their log-softmax over dimension 1.

    Parameters
    ----------
    in_channels : int, optional
        Size of each input node feature vector. Defaults to
        ``dataset.num_features`` from the notebook namespace.
    out_channels : int, optional
        Number of output classes. Defaults to ``dataset.num_classes`` from the
        notebook namespace.
    aggr : str, optional
        Neighbourhood aggregation passed to ``SAGEConv`` (e.g. ``"max"``,
        ``"mean"``, ``"add"``). Defaults to ``"max"``.
    """

    def __init__(self, in_channels=None, out_channels=None, aggr="max"):
        super().__init__()

        # Fall back to the notebook-level dataset only when sizes are not given,
        # so the existing ``Net()`` call keeps working unchanged.
        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.conv = SAGEConv(in_channels, out_channels, aggr=aggr)

    def forward(self, x=None, edge_index=None):
        """Return per-node log-probabilities.

        Parameters
        ----------
        x : optional
            Node feature tensor. Defaults to ``data.x`` from the notebook
            namespace.
        edge_index : optional
            Edge index tensor. Defaults to ``data.edge_index`` from the
            notebook namespace.
        """
        # Defaults preserve the original zero-argument ``model()`` call style.
        if x is None:
            x = data.x
        if edge_index is None:
            edge_index = data.edge_index

        x = self.conv(x, edge_index)
        return F.log_softmax(x, dim=1)

In [20]:
# Prefer the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move both the freshly built model and the graph onto the chosen device.
model = Net().to(device)
data = data.to(device)

# Adam with L2 regularisation (weight decay) over all model parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)

In [21]:
def train():
    """Run one full-batch optimisation step on the training nodes.

    Uses the notebook-level ``model``, ``optimizer`` and ``data``. The loss is
    the negative log-likelihood over the nodes selected by ``data.train_mask``.
    Returns nothing.
    """
    model.train()
    optimizer.zero_grad()

    # Forward pass over the whole graph, then keep only the training nodes.
    log_probs = model()[data.train_mask]
    targets = data.y[data.train_mask]

    loss = F.nll_loss(log_probs, targets)
    loss.backward()
    optimizer.step()


def test():
    """Evaluate the notebook-level ``model`` on the train/val/test node splits.

    Returns
    -------
    list of float
        Accuracies in the order ``[train, val, test]``, each computed as the
        number of correct predictions divided by the number of nodes selected
        by the corresponding mask.
    """
    model.eval()
    accs = []
    # Evaluation needs no gradients; disabling autograd avoids building a
    # computation graph for the forward pass.
    with torch.no_grad():
        logits = model()
        # Calling ``data`` with attribute names yields (name, value) pairs.
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            # Predicted class = index of the largest log-probability per node.
            pred = logits[mask].argmax(dim=1)
            acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
            accs.append(acc)
    return accs

In [47]:
# Progress-line template; it never changes, so it is defined once up front.
log = 'Epoch: {:03d}, Val: {:.4f}, Test: {:.4f}'

# Model selection on validation accuracy: remember the test accuracy observed
# at the epoch with the best validation accuracy so far.
best_val_acc = test_acc = 0

# NOTE(review): range(1, 200) runs 199 epochs — confirm whether 200 was intended.
for epoch in range(1,200):
    train()
    _, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc, test_acc = val_acc, tmp_test_acc

    # Report only every 10th epoch.
    if epoch % 10 != 0:
        continue

    # NOTE(review): the message says "updated", but the value printed is
    # param.requires_grad, which does not show whether a parameter changed.
    for name, param in model.named_parameters():
        print(f"参数 {name} 已更新 {param.requires_grad}")
    print(log.format(epoch, best_val_acc, test_acc),"\n")


参数 conv.lin_l.weight 已更新 True
参数 conv.lin_l.bias 已更新 True
参数 conv.lin_r.weight 已更新 True
Epoch: 010, Val: 0.7460, Test: 0.7660 

参数 conv.lin_l.weight 已更新 True
参数 conv.lin_l.bias 已更新 True
参数 conv.lin_r.weight 已更新 True
Epoch: 020, Val: 0.7460, Test: 0.7660 

参数 conv.lin_l.weight 已更新 True
参数 conv.lin_l.bias 已更新 True
参数 conv.lin_r.weight 已更新 True
Epoch: 030, Val: 0.7460, Test: 0.7660 

参数 conv.lin_l.weight 已更新 True
参数 conv.lin_l.bias 已更新 True
参数 conv.lin_r.weight 已更新 True
Epoch: 040, Val: 0.7460, Test: 0.7660 

参数 conv.lin_l.weight 已更新 True
参数 conv.lin_l.bias 已更新 True
参数 conv.lin_r.weight 已更新 True
Epoch: 050, Val: 0.7460, Test: 0.7660 

参数 conv.lin_l.weight 已更新 True
参数 conv.lin_l.bias 已更新 True
参数 conv.lin_r.weight 已更新 True
Epoch: 060, Val: 0.7460, Test: 0.7660 

参数 conv.lin_l.weight 已更新 True
参数 conv.lin_l.bias 已更新 True
参数 conv.lin_r.weight 已更新 True
Epoch: 070, Val: 0.7460, Test: 0.7660 

参数 conv.lin_l.weight 已更新 True
参数 conv.lin_l.bias 已更新 True
参数 conv.lin_r.weight 已更新 True
Epoch: 080, Val:

KeyboardInterrupt: 