In [130]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeFeatures

In [131]:
dataset = TUDataset( root='~/tmp', name = 'ENZYMES', transform = NormalizeFeatures() )
dlen = len( dataset )

dlen

600

In [132]:
dataset.shuffle()

perc = 0.9
n = int( dlen * perc )
train_dataset = dataset[ :n ]
test_dataset = dataset[ n: ]

train_dataloader = DataLoader( train_dataset, batch_size= 32, shuffle = True )
test_dataloader = DataLoader( test_dataset, batch_size = 32, shuffle = False )

len( test_dataloader.dataset )

60

In [133]:
class GCN( nn.Module ):
    def __init__( self, hidden_channels ):
        super( GCN, self ).__init__()
        self.conv1 = GCNConv( dataset.num_node_features, hidden_channels )
        self.conv2 = GCNConv( hidden_channels, hidden_channels )
        self.conv3 = GCNConv( hidden_channels, hidden_channels )
        # self.conv4 = GCNConv( hidden_channels, hidden_channels )
        self.fc = nn.Linear( hidden_channels, dataset.num_classes )
        
    def forward( self, x, edge_index, batch ):
        x = self.conv1( x, edge_index )
        x = x.relu()
        x = self.conv2( x, edge_index )
        x = x.relu()
        x = self.conv3( x, edge_index )
        # x = x.relu()
        # x = self.conv4( x, edge_index )
        x = global_mean_pool( x, batch )                        # 全局平均池化 -> 需要传入batch!!
        x = F.dropout( x, p = 0.5, training = self.training )
        x = self.fc( x ) 
        return x

In [134]:
model = GCN( hidden_channels = 64 )
loss_f = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam( model.parameters(), lr = 0.01 )

In [135]:
for data in train_dataloader:
    print( data.batch.size() )

torch.Size([1073])
torch.Size([988])
torch.Size([1032])
torch.Size([1036])
torch.Size([1024])
torch.Size([1003])
torch.Size([1068])
torch.Size([965])
torch.Size([1053])
torch.Size([1043])
torch.Size([928])
torch.Size([933])
torch.Size([1126])
torch.Size([1060])
torch.Size([1051])
torch.Size([1096])
torch.Size([841])


In [136]:
def train():
    model.train()
    for data in train_dataloader:
        optimizer.zero_grad()
        out = model( data.x, data.edge_index, data.batch )
        loss = loss_f( out, data.y )
        loss.backward()
        optimizer.step()

   
def test( loader ): 
    model.eval()
    corr = 0
    for data in loader:
        out = model( data.x, data.edge_index, data.batch )
        pred = out.argmax( dim = 1 )
        corr += int( ( pred == data.y ).sum().item() )
    return corr / len( loader.dataset )
        

for epoch in range( 500 ):
    train()
    acc_train = test( train_dataloader ) * 100
    acc_test= test( test_dataloader ) * 100
    print( 'Train accuracy: %.3f, Test accuracy: %.3f' % ( acc_train, acc_test ) )

Train accuracy: 21.111, Test accuracy: 0.000
Train accuracy: 21.852, Test accuracy: 0.000
Train accuracy: 26.296, Test accuracy: 0.000
Train accuracy: 30.000, Test accuracy: 0.000
Train accuracy: 28.704, Test accuracy: 0.000
Train accuracy: 25.185, Test accuracy: 0.000
Train accuracy: 30.185, Test accuracy: 0.000
Train accuracy: 31.667, Test accuracy: 0.000
Train accuracy: 29.630, Test accuracy: 0.000
Train accuracy: 32.222, Test accuracy: 0.000
Train accuracy: 32.222, Test accuracy: 0.000
Train accuracy: 29.444, Test accuracy: 0.000
Train accuracy: 31.111, Test accuracy: 0.000
Train accuracy: 27.037, Test accuracy: 0.000
Train accuracy: 29.815, Test accuracy: 0.000
Train accuracy: 31.852, Test accuracy: 0.000
Train accuracy: 29.444, Test accuracy: 0.000
Train accuracy: 31.296, Test accuracy: 0.000
Train accuracy: 31.111, Test accuracy: 0.000
Train accuracy: 33.333, Test accuracy: 1.667
Train accuracy: 34.259, Test accuracy: 3.333
Train accuracy: 33.148, Test accuracy: 0.000
Train accu