# 如何将 PyG（PyTorch Geometric）与 fastai 结合起来？

### 引入库区

In [6]:
from fastai import *
from fastai.basic_data import *
from fastai.torch_core import *
from fastai.data_block import *
from fastai.data_block import PreProcessors
import torch_geometric
import torch

import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_mean_pool, JumpingKnowledge
from torch_geometric.datasets import TUDataset

### 自己写的函数区

In [5]:
# Define our own network model: a GCN for graph-level classification.
class GCN(torch.nn.Module):
    """Graph-classification network: stacked GCNConv layers, mean pooling, MLP head.

    Args:
        dataset: object exposing ``num_features`` (input node-feature size) and
            ``num_classes`` (number of output classes), e.g. a ``TUDataset``.
        num_layers: total number of GCNConv layers, including the first one.
            Must be >= 1.
        hidden: width of every GCN layer output and of the first linear layer.
        dropout: dropout probability applied before the final linear layer
            (default 0.5, the value previously hard-coded).

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """

    def __init__(self, dataset, num_layers, hidden, dropout=0.5):
        super().__init__()
        # range(num_layers - 1) is empty for num_layers <= 0, which would
        # silently build a 1-layer model; reject such values explicitly.
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')
        self.dropout = dropout
        self.conv1 = GCNConv(dataset.num_features, hidden)
        # conv1 is the first layer; the remaining num_layers - 1 map hidden -> hidden.
        self.convs = torch.nn.ModuleList(
            [GCNConv(hidden, hidden) for _ in range(num_layers - 1)])
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise the weights of every conv and linear layer."""
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        """Compute per-graph class log-probabilities.

        Args:
            data: batch object with ``x``, ``edge_index`` and ``batch``
                attributes (presumably a torch_geometric ``Batch`` — confirm
                against the data loader used).

        Returns:
            Tensor of ``log_softmax`` scores over the last dimension, one row
            per pooled graph.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        # Average node embeddings per graph -> one row per graph in the batch.
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    


### 试验区

数据下载

In [7]:
# Load the MUTAG dataset from the TU collection rooted at ./ (the first run downloads and processes it).
dataset = TUDataset(root='./', name='MUTAG')

Downloading https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/MUTAG.zip
Extracting ./MUTAG.zip
Processing...
Done!


In [8]:
# Show the dataset's summary repr: its name and the number of graphs it holds.
dataset

MUTAG(188)

In [9]:
# Inspect the first graph: tensor shapes of its node features (x), edges (edge_index), edge features (edge_attr) and label (y).
dataset[0]

Data(edge_attr=[38, 4], edge_index=[2, 38], x=[17, 7], y=[1])