In [8]:
import os

import torch
import torch.nn.functional as F

from torch_geometric.nn import GCNConv
from torch_geometric.data import DenseDataLoader

from graf.data.docs import DocumentGraphs

In [6]:
# Configuration for the data-loading step, named so it can be tuned in one
# place. The dataset root can be overridden with the DOCS_DATA_DIR environment
# variable; when it is unset or empty the original local path is used.
DATA_ROOT = os.environ.get('DOCS_DATA_DIR') or '/Users/ygx/data/docs'
NUM_VOCAB = 501
BATCH_SIZE = 32

dataset = DocumentGraphs(DATA_ROOT, num_vocab=NUM_VOCAB)
loader = DenseDataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Print every batch the loader yields as a quick sanity check of its layout.
for data in loader:
    print(data)

Processing...
Done!
Batch(edge_index=[1, 2, 14475391], x=[1, 501, 30])


In [7]:
# Show the dataset's repr (presumably the number in parentheses is the graph count -- confirm).
dataset

DocumentGraphs(1)

In [43]:
# Width of each node's feature vector; Net uses this as the input size of its first layer.
dataset.num_node_features

30

In [16]:
class Net(torch.nn.Module):
    """Two-layer graph convolutional network that emits per-node log-probabilities.

    Args:
        in_channels: Size of each input node feature vector. When ``None``
            (the default) it is read from the notebook-level
            ``dataset.num_node_features`` at construction time, which is what
            the original hard-wired constructor did.
        hidden_channels: Output width of the first ``GCNConv`` layer.
        out_channels: Output width of the second ``GCNConv`` layer, i.e. the
            number of columns the log-softmax is taken over.
    """

    def __init__(self, in_channels=None, hidden_channels=16, out_channels=2):
        super().__init__()
        # Fall back to the global dataset only when the caller gives no size,
        # so ``Net()`` keeps working while the class is reusable elsewhere.
        if in_channels is None:
            in_channels = dataset.num_node_features
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Run both convolutions over one graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                (graph connectivity), as the dataset items in this notebook do.

        Returns:
            Log-softmax over dimension 1 of the second layer's output.
        """
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

In [17]:
# Build the two-layer GCN; its input width comes from the loaded dataset.
model = Net()

In [18]:
# Display the module tree to confirm the layer sizes.
model

Net(
  (conv1): GCNConv(30, 16)
  (conv2): GCNConv(16, 2)
)

In [24]:
# Inspect the first graph in the dataset (its edge_index and node-feature tensor sizes).
dataset[0]

Data(edge_index=[2, 14475391], x=[501, 30])

In [26]:
# Smoke-test the forward pass on the first graph and check the output shape.
model(dataset[0]).shape

torch.Size([501, 2])

In [37]:
# Smoke-test the optimisation loop by fitting the model to the first graph.
NUM_EPOCHS = 5

# Fetch the graph once instead of on every epoch.
# NOTE(review): assumes dataset[0] yields the same graph on each access
# (no random transform on the dataset) -- confirm.
graph = dataset[0]

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Placeholder labels: every node gets class 1. The length and device are taken
# from the node-feature tensor rather than a hard-coded 501, so this keeps
# working if the dataset is rebuilt with a different number of nodes.
target = torch.ones(graph.x.size(0), dtype=torch.long, device=graph.x.device)

model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    out = model(graph)
    loss = F.nll_loss(out, target)
    # .item() extracts the Python float explicitly; the printed text is unchanged.
    print(f'Loss: {loss.item()}')
    loss.backward()
    optimizer.step()

Loss: 0.003156759776175022
Loss: 0.003033219138160348
Loss: 0.0033191584516316652
Loss: 0.0016019244212657213
Loss: 0.001493251882493496
