In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

In [None]:
# Load the ENZYMES graph-classification dataset from the TU collection;
# files are stored under ~/tmp.
dataset = TUDataset( name = 'ENZYMES', root = '~/tmp' )

# Total number of graphs in the dataset.
dlen = len( dataset )

In [None]:
# Seed torch's global RNG so the shuffle below (and any later weight
# initialisation) is reproducible after Restart Kernel -> Run All.
SEED = 0
torch.manual_seed( SEED )

# `shuffle()` returns a shuffled copy and leaves `dataset` itself untouched,
# so the result has to be captured. Binding it to a new name (instead of
# overwriting `dataset`) keeps this cell idempotent when it is re-run.
shuffled_dataset = dataset.shuffle()

# Fraction of graphs used for training; the remainder is held out for testing.
perc = 0.9
n = int( len( shuffled_dataset ) * perc )
train_dataset = shuffled_dataset[ :n ]
test_dataset = shuffled_dataset[ n: ]

# Only the training loader reshuffles between epochs; evaluation order is fixed.
train_dataloader = DataLoader( train_dataset, batch_size = 32, shuffle = True )
test_dataloader = DataLoader( test_dataset, batch_size = 32, shuffle = False )

train_dataset

In [None]:
class GCN( nn.Module ):
    """Graph classifier: three GCNConv layers, mean pooling, dropout, linear head.

    Input and output widths are read from the notebook-level `dataset`
    (`num_node_features` and `num_classes`), so `dataset` must exist before
    the model is instantiated.
    """

    def __init__( self, hidden_channels ):
        """Build the layers; `hidden_channels` is the width of every hidden layer."""
        super().__init__()
        in_features = dataset.num_node_features
        n_classes = dataset.num_classes
        self.conv1 = GCNConv( in_features, hidden_channels )
        self.conv2 = GCNConv( hidden_channels, hidden_channels )
        self.conv3 = GCNConv( hidden_channels, hidden_channels )
        self.fc = nn.Linear( hidden_channels, n_classes )

    def forward( self, x, edge_index, batch ):
        """Return the class logits produced for the given (batched) graphs.

        `x` and `edge_index` are passed to every convolution; `batch` is
        handed to the pooling step.
        """
        # Node-level embeddings: ReLU after the first two convolutions only.
        h = F.relu( self.conv1( x, edge_index ) )
        h = F.relu( self.conv2( h, edge_index ) )
        h = self.conv3( h, edge_index )

        # Global mean pooling -> `batch` must be passed in here.
        pooled = global_mean_pool( h, batch )

        # Dropout is active only while the module is in training mode.
        pooled = F.dropout( pooled, p = 0.5, training = self.training )
        return self.fc( pooled )

In [None]:
# Hyperparameters for the model and the optimiser.
HIDDEN_CHANNELS = 64
LEARNING_RATE = 0.01

# Classifier, loss function and optimiser used by the training cells below.
model = GCN( hidden_channels = HIDDEN_CHANNELS )
loss_f = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam( model.parameters(), lr = LEARNING_RATE )

In [None]:
# Sanity check: print the size of the `batch` vector of every training
# mini-batch.
for data in train_dataloader:
    batch_vector = data.batch
    print( batch_vector.shape )

In [None]:
def train():
    """Run one optimisation pass over `train_dataloader`.

    Uses the notebook-level `model`, `loss_f` and `optimizer`; updates the
    model parameters in place and returns nothing.
    """
    model.train()
    for batch in train_dataloader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        logits = model( batch.x, batch.edge_index, batch.batch )
        loss_f( logits, batch.y ).backward()
        optimizer.step()

   
def test( loader ): 
    model.eval()
    r = 0
    for data in loader:
        out = model( data.x, data.edge_index, data.batch )
        pred = out.argmax( dim = 1 )
        r += int( ( pred == data.y ).sum() )
        print( r )
    return r / len( loader.dataset )
        

# Train for 500 epochs, reporting held-out accuracy (as a percentage) after
# each one.
for epoch in range( 500 ):
    train()
    acc = 100 * test( test_dataloader )
    print( f'{acc}%' )