In [3]:
import dgl.data

dataset = dgl.data.GINDataset('PROTEINS', self_loop=True)
print(f'Node feature dimensionality: {dataset.dim_nfeats}')
print(f'Number of graph categories: {dataset.gclasses}')

Downloading /root/.dgl/GINDataset.zip from https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip...
Extracting file to /root/.dgl/GINDataset


In [10]:
import torch
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)

In [11]:
train_dataloader

<dgl.dataloading.pytorch.GraphDataLoader at 0x7fdd8bd55520>

In [18]:
it = iter(train_dataloader)
batch = next(it)
print(batch)

[Graph(num_nodes=260, num_edges=1158,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), tensor([1, 0, 0, 0, 0])]


## A Batched Graph in DGL

In [19]:
batched_graph ,labels = batch

In [20]:
batched_graph, labels

(Graph(num_nodes=260, num_edges=1158,
       ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
       edata_schemes={}),
 tensor([1, 0, 0, 0, 0]))

In [21]:
graphs = dgl.unbatch(batched_graph)
print(graphs)

[Graph(num_nodes=27, num_edges=127,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=136, num_edges=560,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=22, num_edges=104,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=50, num_edges=224,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=25, num_edges=143,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})]


## Modeling

In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [23]:
from dgl.nn import GraphConv, GATConv, GINConv, APPNPConv

### GCN line

In [31]:
class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)
        
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')

### GINConv line

## Training

### original version ( not using GPU )

In [None]:
GCN_model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
GCN_optimizer = torch.optim.Adam(GCN_model.parameters(), lr=0.01)

for epoch in tqdm(range(100)):
    for batched_graph, labels in train_dataloader:
        
        pred = GCN_model(batched_graph, batched_graph.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        GCN_optimizer.zero_grad()
        loss.backward()
        GCN_optimizer.step()
        
num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = GCN_model(batched_graph, batched_graph.ndata['attr'].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)
    
print(f'Test accuracy: {num_correct / num_tests}')

### GPU version

In [78]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
GCN_model = GCN_model.to(device)
GCN_optimizer = torch.optim.Adam(GCN_model.parameters(), lr=0.01)

In [79]:
GCN_model.train()
for i in tqdm(range(10000)):
    loss_list = []
    true_samples = 0
    num_samples = 0
    for batch_id, batch_data in enumerate(train_dataloader):
        bg, labels = batch_data
        graph_feats = bg.ndata.pop('attr').float()
        bg = bg.to(device)
        graph_feats, labels = graph_feats.to(device), labels.to(device)
        logits = GCN_model(bg, graph_feats) ##
        loss = F.cross_entropy(logits, labels)
        true_samples += (logits.argmax(1) == labels.long()).float().sum().item()
        num_samples += len(labels)
        loss_list.append(loss.item())
        GCN_optimizer.zero_grad()
        loss.backward()
        GCN_optimizer.step()
    if i % 100 == 0:
        print(f"Epoch {i:05d} | Loss : {np.mean(loss_list):.4f} | Accuracy : {true_samples/num_samples:.4f}")

  0%|          | 1/10000 [00:01<2:48:12,  1.01s/it]

Epoch 00000 | Loss : 0.5793 | Accuracy : 0.7404


  1%|          | 101/10000 [01:32<2:32:00,  1.09it/s]

Epoch 00100 | Loss : 0.5071 | Accuracy : 0.7618


  2%|▏         | 201/10000 [03:03<2:28:29,  1.10it/s]

Epoch 00200 | Loss : 0.5035 | Accuracy : 0.7596


  3%|▎         | 301/10000 [04:36<2:26:44,  1.10it/s]

Epoch 00300 | Loss : 0.5047 | Accuracy : 0.7663


  4%|▍         | 401/10000 [06:08<2:27:48,  1.08it/s]

Epoch 00400 | Loss : 0.5011 | Accuracy : 0.7629


  5%|▌         | 501/10000 [07:41<2:24:03,  1.10it/s]

Epoch 00500 | Loss : 0.5051 | Accuracy : 0.7663


  6%|▌         | 601/10000 [09:15<2:23:57,  1.09it/s]

Epoch 00600 | Loss : 0.5012 | Accuracy : 0.7640


  7%|▋         | 701/10000 [10:45<2:22:16,  1.09it/s]

Epoch 00700 | Loss : 0.5024 | Accuracy : 0.7685


  8%|▊         | 801/10000 [12:19<2:21:50,  1.08it/s]

Epoch 00800 | Loss : 0.5035 | Accuracy : 0.7596


  9%|▉         | 901/10000 [13:52<2:20:39,  1.08it/s]

Epoch 00900 | Loss : 0.5056 | Accuracy : 0.7517


 10%|█         | 1001/10000 [15:26<2:25:02,  1.03it/s]

Epoch 01000 | Loss : 0.5056 | Accuracy : 0.7652


 11%|█         | 1101/10000 [16:59<2:13:43,  1.11it/s]

Epoch 01100 | Loss : 0.5061 | Accuracy : 0.7607


 12%|█▏        | 1201/10000 [18:31<2:15:20,  1.08it/s]

Epoch 01200 | Loss : 0.5044 | Accuracy : 0.7551


 13%|█▎        | 1301/10000 [20:04<2:13:52,  1.08it/s]

Epoch 01300 | Loss : 0.5016 | Accuracy : 0.7596


 14%|█▍        | 1401/10000 [21:36<2:13:06,  1.08it/s]

Epoch 01400 | Loss : 0.5024 | Accuracy : 0.7629


 15%|█▌        | 1501/10000 [23:11<2:07:40,  1.11it/s]

Epoch 01500 | Loss : 0.5043 | Accuracy : 0.7573


 16%|█▌        | 1601/10000 [24:44<2:05:03,  1.12it/s]

Epoch 01600 | Loss : 0.5035 | Accuracy : 0.7674


 17%|█▋        | 1701/10000 [26:16<2:08:56,  1.07it/s]

Epoch 01700 | Loss : 0.5042 | Accuracy : 0.7629


 18%|█▊        | 1801/10000 [27:49<2:04:33,  1.10it/s]

Epoch 01800 | Loss : 0.5021 | Accuracy : 0.7573


 19%|█▉        | 1901/10000 [29:21<2:06:20,  1.07it/s]

Epoch 01900 | Loss : 0.5002 | Accuracy : 0.7618


 20%|██        | 2001/10000 [30:53<1:58:09,  1.13it/s]

Epoch 02000 | Loss : 0.5090 | Accuracy : 0.7697


 21%|██        | 2101/10000 [32:25<1:59:23,  1.10it/s]

Epoch 02100 | Loss : 0.5036 | Accuracy : 0.7674


 22%|██▏       | 2201/10000 [33:58<1:59:47,  1.09it/s]

Epoch 02200 | Loss : 0.4968 | Accuracy : 0.7674


 23%|██▎       | 2301/10000 [35:31<2:00:44,  1.06it/s]

Epoch 02300 | Loss : 0.4982 | Accuracy : 0.7674


 24%|██▍       | 2401/10000 [37:03<1:58:27,  1.07it/s]

Epoch 02400 | Loss : 0.5009 | Accuracy : 0.7618


 25%|██▌       | 2501/10000 [38:34<1:53:54,  1.10it/s]

Epoch 02500 | Loss : 0.5005 | Accuracy : 0.7719


 26%|██▌       | 2601/10000 [40:06<1:53:56,  1.08it/s]

Epoch 02600 | Loss : 0.4988 | Accuracy : 0.7629


 27%|██▋       | 2701/10000 [41:40<1:54:23,  1.06it/s]

Epoch 02700 | Loss : 0.4961 | Accuracy : 0.7742


 28%|██▊       | 2801/10000 [43:17<1:55:44,  1.04it/s]

Epoch 02800 | Loss : 0.5053 | Accuracy : 0.7640


 29%|██▉       | 2901/10000 [44:49<1:50:46,  1.07it/s]

Epoch 02900 | Loss : 0.4931 | Accuracy : 0.7708


 30%|███       | 3001/10000 [46:22<1:48:13,  1.08it/s]

Epoch 03000 | Loss : 0.4930 | Accuracy : 0.7719


 31%|███       | 3101/10000 [47:55<1:45:37,  1.09it/s]

Epoch 03100 | Loss : 0.4950 | Accuracy : 0.7685


 32%|███▏      | 3201/10000 [49:27<1:45:43,  1.07it/s]

Epoch 03200 | Loss : 0.4979 | Accuracy : 0.7652


 33%|███▎      | 3301/10000 [51:00<1:43:17,  1.08it/s]

Epoch 03300 | Loss : 0.4952 | Accuracy : 0.7697


 34%|███▍      | 3401/10000 [52:32<1:42:03,  1.08it/s]

Epoch 03400 | Loss : 0.4957 | Accuracy : 0.7697


 35%|███▌      | 3501/10000 [54:04<1:39:59,  1.08it/s]

Epoch 03500 | Loss : 0.4875 | Accuracy : 0.7910


 36%|███▌      | 3601/10000 [55:36<1:36:06,  1.11it/s]

Epoch 03600 | Loss : 0.4896 | Accuracy : 0.7719


 37%|███▋      | 3701/10000 [57:09<1:37:02,  1.08it/s]

Epoch 03700 | Loss : 0.5048 | Accuracy : 0.7719


 38%|███▊      | 3801/10000 [58:44<1:36:26,  1.07it/s]

Epoch 03800 | Loss : 0.4885 | Accuracy : 0.7764


 39%|███▉      | 3901/10000 [1:00:16<1:34:21,  1.08it/s]

Epoch 03900 | Loss : 0.4889 | Accuracy : 0.7629


 40%|████      | 4001/10000 [1:01:47<1:29:05,  1.12it/s]

Epoch 04000 | Loss : 0.4951 | Accuracy : 0.7607


 41%|████      | 4101/10000 [1:03:20<1:31:15,  1.08it/s]

Epoch 04100 | Loss : 0.4900 | Accuracy : 0.7787


 42%|████▏     | 4201/10000 [1:04:51<1:26:05,  1.12it/s]

Epoch 04200 | Loss : 0.4972 | Accuracy : 0.7618


 43%|████▎     | 4301/10000 [1:06:25<1:26:27,  1.10it/s]

Epoch 04300 | Loss : 0.4939 | Accuracy : 0.7652


 44%|████▍     | 4401/10000 [1:07:59<1:34:53,  1.02s/it]

Epoch 04400 | Loss : 0.4896 | Accuracy : 0.7753


 45%|████▌     | 4501/10000 [1:09:31<1:23:54,  1.09it/s]

Epoch 04500 | Loss : 0.4921 | Accuracy : 0.7742


 46%|████▌     | 4601/10000 [1:11:04<1:21:10,  1.11it/s]

Epoch 04600 | Loss : 0.4931 | Accuracy : 0.7820


 47%|████▋     | 4701/10000 [1:12:36<1:18:58,  1.12it/s]

Epoch 04700 | Loss : 0.4907 | Accuracy : 0.7809


 48%|████▊     | 4801/10000 [1:14:06<1:24:02,  1.03it/s]

Epoch 04800 | Loss : 0.4885 | Accuracy : 0.7888


 49%|████▉     | 4901/10000 [1:15:38<1:18:57,  1.08it/s]

Epoch 04900 | Loss : 0.4924 | Accuracy : 0.7764


 50%|█████     | 5001/10000 [1:17:11<1:17:23,  1.08it/s]

Epoch 05000 | Loss : 0.4920 | Accuracy : 0.7798


 51%|█████     | 5101/10000 [1:18:41<1:12:09,  1.13it/s]

Epoch 05100 | Loss : 0.4870 | Accuracy : 0.7775


 52%|█████▏    | 5201/10000 [1:20:12<1:14:00,  1.08it/s]

Epoch 05200 | Loss : 0.4907 | Accuracy : 0.7764


 53%|█████▎    | 5301/10000 [1:21:43<1:11:36,  1.09it/s]

Epoch 05300 | Loss : 0.4883 | Accuracy : 0.7854


 54%|█████▍    | 5401/10000 [1:23:16<1:11:47,  1.07it/s]

Epoch 05400 | Loss : 0.5130 | Accuracy : 0.7685


 55%|█████▌    | 5501/10000 [1:24:48<1:09:47,  1.07it/s]

Epoch 05500 | Loss : 0.4891 | Accuracy : 0.7730


 56%|█████▌    | 5601/10000 [1:26:20<1:08:20,  1.07it/s]

Epoch 05600 | Loss : 0.4925 | Accuracy : 0.7764


 57%|█████▋    | 5701/10000 [1:27:51<1:04:06,  1.12it/s]

Epoch 05700 | Loss : 0.4937 | Accuracy : 0.7685


 58%|█████▊    | 5801/10000 [1:29:23<1:07:06,  1.04it/s]

Epoch 05800 | Loss : 0.4905 | Accuracy : 0.7775


 59%|█████▉    | 5901/10000 [1:30:54<1:02:17,  1.10it/s]

Epoch 05900 | Loss : 0.4913 | Accuracy : 0.7708


 60%|██████    | 6001/10000 [1:32:27<1:03:11,  1.05it/s]

Epoch 06000 | Loss : 0.4910 | Accuracy : 0.7640


 61%|██████    | 6101/10000 [1:34:00<59:17,  1.10it/s]  

Epoch 06100 | Loss : 0.4939 | Accuracy : 0.7640


 62%|██████▏   | 6201/10000 [1:35:33<57:46,  1.10it/s]  

Epoch 06200 | Loss : 0.4920 | Accuracy : 0.7753


 63%|██████▎   | 6301/10000 [1:37:04<56:51,  1.08it/s]

Epoch 06300 | Loss : 0.4904 | Accuracy : 0.7764


 64%|██████▍   | 6401/10000 [1:38:36<57:38,  1.04it/s]  

Epoch 06400 | Loss : 0.4889 | Accuracy : 0.7697


 65%|██████▌   | 6501/10000 [1:40:09<54:47,  1.06it/s]

Epoch 06500 | Loss : 0.4889 | Accuracy : 0.7742


 66%|██████▌   | 6601/10000 [1:41:41<53:00,  1.07it/s]

Epoch 06600 | Loss : 0.4850 | Accuracy : 0.7820


 67%|██████▋   | 6701/10000 [1:43:17<50:59,  1.08it/s]  

Epoch 06700 | Loss : 0.4898 | Accuracy : 0.7787


 68%|██████▊   | 6801/10000 [1:44:48<49:32,  1.08it/s]

Epoch 06800 | Loss : 0.4857 | Accuracy : 0.7798


 69%|██████▉   | 6901/10000 [1:46:21<47:31,  1.09it/s]

Epoch 06900 | Loss : 0.4958 | Accuracy : 0.7798


 70%|███████   | 7001/10000 [1:47:53<46:23,  1.08it/s]

Epoch 07000 | Loss : 0.4891 | Accuracy : 0.7753


 71%|███████   | 7101/10000 [1:49:25<44:19,  1.09it/s]

Epoch 07100 | Loss : 0.4888 | Accuracy : 0.7831


 72%|███████▏  | 7201/10000 [1:50:57<43:30,  1.07it/s]

Epoch 07200 | Loss : 0.4903 | Accuracy : 0.7719


 73%|███████▎  | 7301/10000 [1:52:30<41:39,  1.08it/s]

Epoch 07300 | Loss : 0.4892 | Accuracy : 0.7697


 74%|███████▍  | 7401/10000 [1:54:02<40:02,  1.08it/s]

Epoch 07400 | Loss : 0.4921 | Accuracy : 0.7742


 75%|███████▌  | 7501/10000 [1:55:35<38:08,  1.09it/s]

Epoch 07500 | Loss : 0.4916 | Accuracy : 0.7674


 76%|███████▌  | 7601/10000 [1:57:06<35:36,  1.12it/s]

Epoch 07600 | Loss : 0.4924 | Accuracy : 0.7640


 77%|███████▋  | 7701/10000 [1:58:39<34:06,  1.12it/s]

Epoch 07700 | Loss : 0.4969 | Accuracy : 0.7719


 78%|███████▊  | 7801/10000 [2:00:11<33:56,  1.08it/s]

Epoch 07800 | Loss : 0.4907 | Accuracy : 0.7764


 79%|███████▉  | 7901/10000 [2:01:43<32:13,  1.09it/s]

Epoch 07900 | Loss : 0.4923 | Accuracy : 0.7708


 80%|████████  | 8001/10000 [2:03:15<30:45,  1.08it/s]

Epoch 08000 | Loss : 0.4906 | Accuracy : 0.7674


 81%|████████  | 8101/10000 [2:04:48<29:01,  1.09it/s]

Epoch 08100 | Loss : 0.4856 | Accuracy : 0.7753


 82%|████████▏ | 8201/10000 [2:06:22<28:00,  1.07it/s]

Epoch 08200 | Loss : 0.4899 | Accuracy : 0.7730


 83%|████████▎ | 8301/10000 [2:07:54<26:17,  1.08it/s]

Epoch 08300 | Loss : 0.4920 | Accuracy : 0.7640


 84%|████████▍ | 8401/10000 [2:09:25<24:35,  1.08it/s]

Epoch 08400 | Loss : 0.4924 | Accuracy : 0.7719


 85%|████████▌ | 8501/10000 [2:10:57<22:53,  1.09it/s]

Epoch 08500 | Loss : 0.4908 | Accuracy : 0.7764


 86%|████████▌ | 8601/10000 [2:12:30<21:13,  1.10it/s]

Epoch 08600 | Loss : 0.4915 | Accuracy : 0.7854


 87%|████████▋ | 8701/10000 [2:14:03<20:08,  1.08it/s]

Epoch 08700 | Loss : 0.4913 | Accuracy : 0.7652


 88%|████████▊ | 8801/10000 [2:15:34<17:51,  1.12it/s]

Epoch 08800 | Loss : 0.4899 | Accuracy : 0.7753


 89%|████████▉ | 8901/10000 [2:17:06<17:02,  1.08it/s]

Epoch 08900 | Loss : 0.4914 | Accuracy : 0.7775


 90%|█████████ | 9001/10000 [2:18:39<14:53,  1.12it/s]

Epoch 09000 | Loss : 0.4893 | Accuracy : 0.7730


 91%|█████████ | 9101/10000 [2:20:10<13:51,  1.08it/s]

Epoch 09100 | Loss : 0.4881 | Accuracy : 0.7742


 92%|█████████▏| 9201/10000 [2:21:41<11:54,  1.12it/s]

Epoch 09200 | Loss : 0.4863 | Accuracy : 0.7865


 93%|█████████▎| 9301/10000 [2:23:11<10:45,  1.08it/s]

Epoch 09300 | Loss : 0.4908 | Accuracy : 0.7663


 94%|█████████▍| 9401/10000 [2:24:44<09:13,  1.08it/s]

Epoch 09400 | Loss : 0.4931 | Accuracy : 0.7730


 95%|█████████▌| 9501/10000 [2:26:15<07:45,  1.07it/s]

Epoch 09500 | Loss : 0.5198 | Accuracy : 0.7753


 96%|█████████▌| 9601/10000 [2:27:48<06:10,  1.08it/s]

Epoch 09600 | Loss : 0.4885 | Accuracy : 0.7719


 97%|█████████▋| 9701/10000 [2:29:19<04:37,  1.08it/s]

Epoch 09700 | Loss : 0.4919 | Accuracy : 0.7674


 98%|█████████▊| 9801/10000 [2:30:51<03:01,  1.10it/s]

Epoch 09800 | Loss : 0.4925 | Accuracy : 0.7708


 99%|█████████▉| 9901/10000 [2:32:22<01:27,  1.13it/s]

Epoch 09900 | Loss : 0.4912 | Accuracy : 0.7798


100%|██████████| 10000/10000 [2:33:54<00:00,  1.08it/s]


Epoch 100 

| Loss : 0.5351 | Accuracy : 0.7517

Epoch 10000


### Hyperparameter + Ray tune
Fail

In [68]:
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [71]:
config = {
    "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16])
}

In [76]:
GCN_model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
GCN_optimizer = torch.optim.Adam(GCN_model.parameters(), lr=0.01)

In [65]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
    if torch.cuda.device_count() > 1:
        GCN_model = nn.DataParallel(GCN_model)
GCN_model.to(device)

DataParallel(
  (module): DataParallel(
    (module): GCN(
      (conv1): GraphConv(in=3, out=16, normalization=both, activation=None)
      (conv2): GraphConv(in=16, out=2, normalization=both, activation=None)
    )
  )
)

In [None]:
GCN_model.train()
for i in tqdm(range(10000)):
    loss_list = []
    true_samples = 0
    num_samples = 0
    for batch_id, batch_data in enumerate(train_dataloader):
        bg, labels = batch_data
        graph_feats = bg.ndata.pop('attr').float()
        bg = bg.to(device)
        graph_feats, labels = graph_feats.to(device), labels.to(device)
        logits = GCN_model(bg, graph_feats) ##
        loss = F.cross_entropy(logits, labels)
        true_samples += (logits.argmax(1) == labels.long()).float().sum().item()
        num_samples += len(labels)
        loss_list.append(loss.item())
        GCN_optimizer.zero_grad()
        loss.backward()
        GCN_optimizer.step()
    if i % 100 == 0:
        print(f"Epoch {i:05d} | Loss : {np.mean(loss_list):.4f} | Accuracy : {true_samples/num_samples:.4f}")