In [4]:
!pip install torch-geometric


Collecting torch-geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.9/63.9 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m26.0 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.4.0


In [5]:
import os.path as osp
import time

import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score

In [6]:
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader


from torch_geometric.nn import GATConv



In [12]:
# First instantiation downloads, extracts and processes the PPI archive under
# ./ppi_data/ (see the output below); 'train' is PPI's default split.
ppi_dataset = PPI(split='train', root='./ppi_data/')

Downloading https://data.dgl.ai/dataset/ppi.zip
Extracting ppi_data/ppi.zip
Processing...
Done!


In [13]:
# Dataset root used by the split-specific PPI(...) constructors; same directory the
# initial download cell wrote to.
path = './ppi_data/'

In [14]:
# One PPI dataset per split, all rooted at `path` (created in 'train', 'val', 'test' order).
train_dataset, val_dataset, test_dataset = (
    PPI(path, split=split_name) for split_name in ('train', 'val', 'test')
)

# Shuffle only the training graphs; evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=2, shuffle=False)
    for split_ds in (val_dataset, test_dataset)
)

In [30]:
# The model sizes its output layer from train_dataset alone, so every split must
# expose the same number of labels for evaluation to be meaningful.
assert train_dataset.num_classes == val_dataset.num_classes == test_dataset.num_classes
train_dataset.num_classes

121

In [23]:
# Per-graph node/edge count statistics for the training split (20 graphs, per the output).
train_dataset.print_summary()

PPI (#graphs=20):
+------------+----------+----------+
|            |   #nodes |   #edges |
|------------+----------+----------|
| mean       |   2245.3 |  61318.4 |
| std        |    766.2 |  28601.2 |
| min        |    591   |   7708   |
| quantile25 |   1806   |  42288   |
| median     |   2326   |  59862   |
| quantile75 |   2799.2 |  85368.5 |
| max        |   3480   | 106754   |
+------------+----------+----------+


In [15]:
class GATPPI(torch.nn.Module):
    """Three-layer GAT with linear skip connections for multi-label node classification.

    Every GATConv layer is paired with a Linear projection of the same input, and the
    two outputs are summed. The first two attention layers concatenate their heads
    (output width ``heads * hidden_channels``) and are followed by ELU. The last
    layer averages ``output_heads`` heads (``concat=False``) and returns raw logits,
    one per label.

    Args:
        in_channels: Input node-feature size. Defaults to
            ``train_dataset.num_features`` (the notebook-level dataset).
        out_channels: Number of output labels. Defaults to
            ``train_dataset.num_classes``.
        hidden_channels: Per-head hidden size of the first two GAT layers.
        heads: Number of concatenated attention heads in the first two layers.
        output_heads: Number of averaged attention heads in the output layer.
    """

    def __init__(self, in_channels=None, out_channels=None,
                 hidden_channels=256, heads=4, output_heads=6):
        super().__init__()
        # Fall back to the notebook-level dataset so `GATPPI()` keeps working.
        if in_channels is None:
            in_channels = train_dataset.num_features
        if out_channels is None:
            out_channels = train_dataset.num_classes

        # Width of a concatenated multi-head GAT output (4 * 256 with the defaults).
        width = heads * hidden_channels

        # Layers are created in the original order so parameter init consumes the
        # RNG identically.
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.lin1 = torch.nn.Linear(in_channels, width)
        self.conv2 = GATConv(width, hidden_channels, heads=heads)
        self.lin2 = torch.nn.Linear(width, width)
        self.conv3 = GATConv(width, out_channels, heads=output_heads,
                             concat=False)
        self.lin3 = torch.nn.Linear(width, out_channels)

    def forward(self, x, edge_index):
        """Return unnormalized per-label logits for every node.

        Args:
            x: Node feature matrix; presumably (num_nodes, in_channels), TODO confirm.
            edge_index: Graph connectivity passed unchanged to every GATConv layer.

        Returns:
            Logits whose last dimension is ``out_channels``. No sigmoid is applied;
            the notebook pairs them with ``BCEWithLogitsLoss``.
        """
        x = F.elu(self.conv1(x, edge_index) + self.lin1(x))
        x = F.elu(self.conv2(x, edge_index) + self.lin2(x))
        return self.conv3(x, edge_index) + self.lin3(x)

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Instantiate the model class defined in this notebook. The previous `Net()` referred
# to an undefined name and raised NameError on a fresh kernel.
model = GATPPI().to(device)
# Independent sigmoid + binary cross-entropy per label (multi-label targets).
loss_op = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

In [17]:
def train():
    """Run one optimization epoch over ``train_loader``.

    Reads the notebook-level ``model``, ``optimizer``, ``loss_op`` and ``device``.

    Returns:
        The sum of each batch's loss multiplied by its number of graphs, divided
        by the number of graphs in ``train_loader.dataset``.
    """
    model.train()

    running_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index)
        batch_loss = loss_op(logits, batch.y)
        # Weight by graph count so batches of different sizes average correctly.
        running_loss += batch_loss.item() * batch.num_graphs
        batch_loss.backward()
        optimizer.step()

    dataset_size = len(train_loader.dataset)
    return running_loss / dataset_size

In [18]:
@torch.no_grad()
def test(loader):
    """Micro-averaged F1 of the model's thresholded predictions over ``loader``.

    A node/label pair is predicted positive when its logit is > 0 (sigmoid > 0.5).
    Targets stay on the CPU; only the model inputs are moved to ``device``.

    Returns:
        The micro F1 score, or ``0`` when no label is predicted positive at all.
    """
    model.eval()

    targets, predictions = [], []
    for batch in loader:
        targets.append(batch.y)
        logits = model(batch.x.to(device), batch.edge_index.to(device))
        predictions.append((logits > 0).float().cpu())

    y_true = torch.cat(targets, dim=0).numpy()
    y_pred = torch.cat(predictions, dim=0).numpy()
    # Short-circuit the all-negative case instead of asking sklearn for an ill-defined F1.
    if y_pred.sum() <= 0:
        return 0
    return f1_score(y_true, y_pred, average='micro')

In [31]:
# NOTE(review): this cell ran as In [31], after the training loop (In [19]); on a fresh
# top-to-bottom run it would score the untrained model instead.
test(loader=test_loader)

0.9868027882209436

In [19]:
# Train for 100 epochs, reporting loss and val/test micro-F1 after each epoch,
# and record the wall-clock duration of every epoch.
times = []
for epoch in range(1, 100 + 1):
    start = time.time()
    loss = train()
    val_f1, test_f1 = test(val_loader), test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Val: {val_f1:.4f}, Test: {test_f1:.4f}')
    times.append(time.time() - start)

print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")

Epoch: 001, Loss: 0.9247, Val: 0.4610, Test: 0.4633
Epoch: 002, Loss: 0.5474, Val: 0.5345, Test: 0.5381
Epoch: 003, Loss: 0.4962, Val: 0.5220, Test: 0.5287
Epoch: 004, Loss: 0.4622, Val: 0.5424, Test: 0.5546
Epoch: 005, Loss: 0.4251, Val: 0.6322, Test: 0.6493
Epoch: 006, Loss: 0.3791, Val: 0.6952, Test: 0.7173
Epoch: 007, Loss: 0.3334, Val: 0.7331, Test: 0.7581
Epoch: 008, Loss: 0.2880, Val: 0.7814, Test: 0.8069
Epoch: 009, Loss: 0.2507, Val: 0.8136, Test: 0.8404
Epoch: 010, Loss: 0.2139, Val: 0.8403, Test: 0.8669
Epoch: 011, Loss: 0.1799, Val: 0.8482, Test: 0.8752
Epoch: 012, Loss: 0.1541, Val: 0.8779, Test: 0.9038
Epoch: 013, Loss: 0.1294, Val: 0.8891, Test: 0.9137
Epoch: 014, Loss: 0.1173, Val: 0.9017, Test: 0.9246
Epoch: 015, Loss: 0.1010, Val: 0.9189, Test: 0.9400
Epoch: 016, Loss: 0.0841, Val: 0.9304, Test: 0.9501
Epoch: 017, Loss: 0.0725, Val: 0.9313, Test: 0.9507
Epoch: 018, Loss: 0.0671, Val: 0.9381, Test: 0.9557
Epoch: 019, Loss: 0.0584, Val: 0.9378, Test: 0.9550
Epoch: 020, 