# Args

In [None]:
# Run configuration shared by every cell below (read as a module-level global).
args = dict(
    device=0,                        # CUDA device index, used only when CUDA is available
    num_layer=3,                     # number of layers in the final GCN (GNN's gnn_num_layers)
    emb_dim=256,                     # node/graph embedding width
    drop_ratio=0.2,                  # dropout probability inside every GCN
    batch_size=64,                   # graphs per mini-batch for all three loaders
    lr=0.01,                         # NOTE(review): main() passes a hard-coded lr=0.001 to Adam, so this value is not read
    epochs=50,                       # number of training epochs
    num_vocab=5000,                  # size of the target vocabulary built from the training split
    max_seq_len=5,                   # number of target tokens predicted per graph
    diff_pool_layers=[(2, 5)],       # one (number_of_nodes_out, gcn_layer_num) pair per DiffPool layer
    max_num_nodes=750,               # node budget used when densifying a batch
    random_split=False,              # True: replace the dataset's own split by a random permutation
    dataset="ogbg-code2",            # dataset name handed to the OGB Evaluator
    num_workers=0,                   # DataLoader worker processes
    model_save_path="best_model_params",  # checkpoint file for the best validation epoch
    eval_results_path='eval_results',     # file for the validation/test curves ('' disables saving)
)

# Set up the dependencies and download the dataset

In [None]:
!pip install ogb
!pip install torch_geometric
!python -c "import ogb; print(ogb.__version__)"

import os
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.loader import DataLoader
import torch
import pandas as pd
import torch.nn.functional as F
from tqdm.notebook import tqdm
print(torch.__version__)

# The PyG built-in GCNConv
from torch_geometric.nn import GCNConv

import torch_geometric.transforms as T
from torch_geometric.nn import global_add_pool, global_mean_pool

# Load the dataset named in the config (downloaded and preprocessed on first use),
# so it cannot drift from the Evaluator that main() builds from args['dataset'].
dataset = PygGraphPropPredDataset(name = args['dataset'])

Collecting ogb
  Downloading ogb-1.3.6-py3-none-any.whl (78 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/78.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Collecting outdated>=0.2.0 (from ogb)
  Downloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Collecting littleutils (from outdated>=0.2.0->ogb)
  Downloading littleutils-0.2.2.tar.gz (6.6 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: littleutils
  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
  Created wheel for littleutils: filename=littleutils-0.2.2-py3-none-any.whl size=7026 sha256=2d2f69b38ef6566edd9cf8cb933e8844cc5cef8fe8799b3ef41b709ca7c28622
  Stored in directory: /root/.cache/pip/wheels/3d/fe/b0/27a9892da57472e538c7452a721a9cf463cc03cf7379889266
Successfully built littleutils
Installing collected packages: littleut

Downloaded 0.91 GB: 100%|██████████| 934/934 [00:36<00:00, 25.59it/s]


Extracting dataset/code2.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 452741/452741 [00:01<00:00, 356232.98it/s]


Converting graphs into PyG objects...


100%|██████████| 452741/452741 [00:21<00:00, 20990.71it/s]


Saving...


Done!


# Utils

In [None]:
from collections import Counter
import numpy as np
import torch

class ASTNodeEncoder(torch.nn.Module):
    '''
        Embed AST nodes as the sum of three learned embeddings.

        Input:
            x: default node feature. the first and second column represents node type and node attributes.
            depth: The depth of the node in the AST.

        Output:
            emb_dim-dimensional vector

    '''
    def __init__(self, emb_dim, num_nodetypes, num_nodeattributes, max_depth):
        '''
            emb_dim: width of every embedding table (and of the output).
            num_nodetypes: number of rows in the node-type embedding table.
            num_nodeattributes: number of rows in the node-attribute embedding table.
            max_depth: largest depth with its own embedding; deeper nodes share it.
        '''
        super(ASTNodeEncoder, self).__init__()

        self.max_depth = max_depth

        self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim)
        self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim)
        # One extra row so that index ``max_depth`` itself is valid.
        self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim)


    def forward(self, x, depth):
        '''
            Return type + attribute + depth embeddings for every node.

            Depths above ``max_depth`` are clamped to ``max_depth``. The clamp is
            done out of place so the caller's ``depth`` tensor (typically a view of
            the batch's ``node_depth``) is left untouched.
        '''
        clamped_depth = depth.clamp(max=self.max_depth)
        return self.type_encoder(x[:,0]) + self.attribute_encoder(x[:,1]) + self.depth_encoder(clamped_depth)



def get_vocab_mapping(seq_list, num_vocab):
    '''
        Build the target vocabulary from the most frequent words.

        Input:
            seq_list: a list of sequences
            num_vocab: vocabulary size (upper bound on the number of kept words)
        Output:
            vocab2idx:
                A dictionary that maps vocabulary into integer index.
                Additionally, we also index '__UNK__' and '__EOS__'
                '__UNK__' : out-of-vocabulary term
                '__EOS__' : end-of-sentence

            idx2vocab:
                A list that maps idx to actual vocabulary.

        Words are ranked by descending count; ties keep first-seen order.
        The two special tokens are always the last two entries, so the mapping
        stays consistent even when fewer than num_vocab distinct words exist.
        Also prints the fraction of word occurrences covered by the kept words.
    '''

    # Counter keeps first-seen insertion order, which the stable sort below
    # relies on for deterministic tie-breaking.
    vocab_cnt = Counter(w for seq in seq_list for w in seq)
    ranked = sorted(vocab_cnt.items(), key=lambda item: -item[1])
    top = ranked[:num_vocab]

    total_cnt = sum(vocab_cnt.values())
    covered_cnt = sum(cnt for _, cnt in top)

    print('Coverage of top {} vocabulary:'.format(num_vocab))
    # Coverage is undefined for an empty corpus; report nan instead of dividing by zero.
    print(covered_cnt / total_cnt if total_cnt else float('nan'))

    idx2vocab = [w for w, _ in top]
    vocab2idx = {w: idx for idx, w in enumerate(idx2vocab)}

    # Index the special tokens by the actual list length rather than num_vocab:
    # the two agree when num_vocab words were kept, and only the former is
    # correct when the corpus has fewer distinct words.
    for special in ('__UNK__', '__EOS__'):
        vocab2idx[special] = len(idx2vocab)
        idx2vocab.append(special)

    # test the correspondence between vocab2idx and idx2vocab
    for idx, vocab in enumerate(idx2vocab):
        assert(idx == vocab2idx[vocab])

    # test that the idx of '__EOS__' is len(idx2vocab) - 1.
    # This fact will be used in decode_arr_to_seq, when finding __EOS__
    assert(vocab2idx['__EOS__'] == len(idx2vocab) - 1)

    return vocab2idx, idx2vocab

def augment_edge(data):
    '''
        Add inverse AST edges and (forward + inverse) next-token edges to a graph.

        Input:
            data: PyG data object; reads data.edge_index and data.node_is_attributed
        Output:
            the same data object, modified in place:
                data.edge_index: AST edges, inverse AST edges, next-token edges,
                                 inverse next-token edges (concatenated in that order).
                data.edge_attr (float tensor, one row per edge):
                    data.edge_attr[:,0]: AST edge (0) or next-token edge (1)
                    data.edge_attr[:,1]: original direction (0) or inverse direction (1)
    '''

    def _edge_flags(num_edges, is_next_token, is_inverse):
        # (num_edges, 2) tensor whose two columns hold the given constant flags.
        flags = torch.zeros((num_edges, 2))
        flags[:, 0] = is_next_token
        flags[:, 1] = is_inverse
        return flags

    ##### AST edges and their inverses
    ast_edges = data.edge_index
    ast_edges_inv = ast_edges.flip(0)
    num_ast = ast_edges.size(1)

    ##### Next-token edges
    # Nodes are assumed to be stored in DFS order already, so the attributed
    # nodes can be taken in index order without an explicit sort.
    attributed_mask = data.node_is_attributed.view(-1,) == 1
    attributed_nodes = attributed_mask.nonzero(as_tuple=True)[0]

    # Link every attributed node to the next one, e.g.
    #   [1, 3, 4, 5, 8, 9, 12] -> [[1, 3, 4, 5, 8, 9],
    #                              [3, 4, 5, 8, 9, 12]]
    next_edges = torch.stack([attributed_nodes[:-1], attributed_nodes[1:]], dim = 0)
    next_edges_inv = next_edges.flip(0)
    num_next = next_edges.size(1)

    data.edge_index = torch.cat([ast_edges, ast_edges_inv, next_edges, next_edges_inv], dim = 1)
    data.edge_attr = torch.cat(
        [
            _edge_flags(num_ast, 0, 0),
            _edge_flags(num_ast, 0, 1),
            _edge_flags(num_next, 1, 0),
            _edge_flags(num_next, 1, 1),
        ],
        dim = 0,
    )

    return data

def encode_y_to_arr(data, vocab2idx, max_seq_len):
    '''
    Attach the integer-encoded target sequence to a graph.

    Input:
        data: PyG graph object whose ``y`` holds the target word sequence
        vocab2idx: word -> index mapping
        max_seq_len: length the encoded sequence is cut/padded to
    Output:
        the same data object with ``y_arr`` added
    '''

    # PyG >= 1.5.0 keeps the sequence directly in data.y
    # (PyG 1.4.3 needed data.y[0] instead).
    data.y_arr = encode_seq_to_arr(data.y, vocab2idx, max_seq_len)
    return data

def encode_seq_to_arr(seq, vocab2idx, max_seq_len):
    '''
    Encode a word sequence as a fixed-length row of vocabulary indices.

    Input:
        seq: A list of words
        vocab2idx: word -> index mapping containing '__UNK__' and '__EOS__'
        max_seq_len: output length; longer sequences are truncated, shorter
                     ones are padded with '__EOS__'
    Output:
        torch.long tensor of shape (1, max_seq_len)
    '''

    kept_words = seq[:max_seq_len]
    num_padding = max(0, max_seq_len - len(seq))
    padded_words = kept_words + ['__EOS__'] * num_padding

    indices = []
    for word in padded_words:
        if word in vocab2idx:
            indices.append(vocab2idx[word])
        else:
            # out-of-vocabulary words share the '__UNK__' index
            indices.append(vocab2idx['__UNK__'])

    return torch.tensor([indices], dtype = torch.long)


def decode_arr_to_seq(arr, idx2vocab):
    '''
        Turn a row of predicted indices back into words, stopping at '__EOS__'.

        Input: torch 1d array: y_arr
        Output: a sequence of words.
    '''

    # '__EOS__' is the last entry of idx2vocab (asserted by get_vocab_mapping).
    eos_idx = len(idx2vocab) - 1
    eos_positions = (arr == eos_idx).nonzero(as_tuple=False)

    # Keep only what precedes the first '__EOS__', if there is one.
    if len(eos_positions) > 0:
        arr = arr[: int(eos_positions.min())]

    return [idx2vocab[i] for i in arr.cpu().tolist()]


# Define GNN and DiffPool

In [None]:
from torch_geometric.nn import dense_diff_pool, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set, DenseGCNConv
from torch_geometric.utils import to_dense_batch, to_dense_adj
import torch.nn.functional as F

# The generic GCN used by different layers of DiffPool
class GCN(torch.nn.Module):
    """Stack of dense GCN layers operating on batched dense graphs.

    Layer widths are ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``
    with ``num_layers`` convolutions in total. Every convolution except the last
    is followed by batch norm, ReLU and dropout; the last one returns raw outputs.

    Args:
        input_dim: feature width of the incoming node features.
        hidden_dim: feature width between convolutions.
        output_dim: feature width produced by the last convolution.
        num_layers: number of convolutions.
        dropout: probability of an element getting zeroed during training.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout):
        super(GCN, self).__init__()

        # The first layer must consume input_dim (previously hidden_dim was used
        # for it, which only worked because callers pass input_dim == hidden_dim).
        num_hidden = max(num_layers - 1, 0)
        dims = [input_dim] + [hidden_dim] * num_hidden + [output_dim]

        # GCN layers
        self.convs = torch.nn.ModuleList(
            [DenseGCNConv(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]
        )

        # Batch norm, one per non-final convolution
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(hidden_dim) for _ in range(num_hidden)]
        )

        # Log Softmax (not applied in forward; kept so the attribute remains available)
        self.softmax = torch.nn.LogSoftmax(dim=1)

        # Probability of an element getting zeroed
        self.dropout = dropout

    def forward(self, x, adj):
        """Apply the stack to node features ``x`` with dense adjacency ``adj``.

        ``x`` is indexed as (batch, node, feature): it is transposed to
        (batch, feature, node) around BatchNorm1d so that normalisation runs
        over the feature axis. Returns the output of the last convolution.
        """
        for conv, bn in zip(self.convs[:-1], self.bns):
            x = conv(x, adj)
            x = bn(x.transpose(1,2)).transpose(1,2)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        out = self.convs[-1](x, adj)

        return out

# A DiffPool layer that takes in a graph and pools its nodes to produce a
# smaller graph
class DiffPool(torch.nn.Module):
    """One differentiable-pooling step built from two GNNs.

    ``gnn_embed`` produces node embeddings and ``gnn_pool`` produces the
    assignment scores; both are fed to ``dense_diff_pool`` to obtain the pooled
    node features and adjacency.
    """

    def __init__(self, gnn_embed, gnn_pool, number_of_nodes_out):
        super().__init__()
        self.gnn_embed = gnn_embed
        self.gnn_pool = gnn_pool
        # Stored for reference only; forward() does not read it.
        self.number_of_nodes_out = number_of_nodes_out

    def forward(self, x, adj, mask=None):
        """Return ``(pooled_x, pooled_adj)`` for node features ``x`` and adjacency ``adj``.

        ``mask`` is forwarded unchanged to ``dense_diff_pool``. The two extra
        values returned by ``dense_diff_pool`` are discarded.
        """
        # Embedding network first, then assignment network, both on the same input.
        embeddings = self.gnn_embed(x, adj)
        assignment = self.gnn_pool(x, adj)

        pooled_x, pooled_adj, _aux_a, _aux_b = dense_diff_pool(embeddings, adj, assignment, mask)
        return pooled_x, pooled_adj



# The top-level abstraction for a composition of networks that takes in a batch
# of data and produces a prediction
class GNN(torch.nn.Module):
    """DiffPool-based graph model that predicts a fixed-length token sequence.

    Pipeline: node encoder -> dense batch -> DiffPool layers -> final GCN ->
    graph-level pooling -> one linear classifier per output position.

    Args:
        num_vocab: number of classes for each output position.
        max_seq_len: number of output positions (one linear head each).
        node_encoder: module called as ``node_encoder(x, depth)``.
        gnn_num_layers: depth of the final GCN; must be at least 2.
        diff_pool_node_number_list: one ``(number_of_nodes_out, gcn_layer_num)``
            pair per DiffPool layer.
        emb_dim: embedding width used throughout.
        drop_ratio: dropout probability inside every GCN.
        graph_pooling: "sum", "mean", "max", "attention" or "set2set".
        max_num_nodes: node budget passed to ``to_dense_batch`` / ``to_dense_adj``.
    """

    def __init__(self, num_vocab, max_seq_len, node_encoder, gnn_num_layers = 3, diff_pool_node_number_list=[(50, 3), (10, 3), (3, 3)], emb_dim = 300, drop_ratio = 0.5, graph_pooling = "mean", max_num_nodes=1000):
        super().__init__()

        # Plain hyper-parameters (drop_ratio is stored under two names).
        self.drop_ratio = drop_ratio
        self.emb_dim = emb_dim
        self.num_vocab = num_vocab
        self.max_seq_len = max_seq_len
        self.graph_pooling = graph_pooling
        self.dropout_ratio = drop_ratio
        self.max_num_nodes = max_num_nodes

        if gnn_num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.node_encoder = node_encoder

        def build_gcn(out_dim, depth):
            # Every GCN in this model shares emb_dim as input/hidden width.
            return GCN(input_dim=emb_dim, hidden_dim=emb_dim, output_dim=out_dim, num_layers=depth, dropout=drop_ratio)

        # DiffPool layers: an embedding GCN plus an assignment GCN each.
        self.diff_pool_layers = torch.nn.ModuleList()
        for number_of_nodes_out, gcn_layer_num in diff_pool_node_number_list:
            embed_net = build_gcn(emb_dim, gcn_layer_num)
            assign_net = build_gcn(number_of_nodes_out, gcn_layer_num)
            self.diff_pool_layers.append(
                DiffPool(gnn_embed=embed_net, gnn_pool=assign_net, number_of_nodes_out=number_of_nodes_out)
            )

        # Final GCN, applied to the output of the last DiffPool layer.
        self.final_gnn = build_gcn(emb_dim, gnn_num_layers)

        # Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            gate = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2*emb_dim),
                torch.nn.BatchNorm1d(2*emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2*emb_dim, 1),
            )
            self.pool = GlobalAttention(gate_nn = gate)
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # One classifier per output position; set2set doubles the embedding width.
        head_in_dim = 2*emb_dim if graph_pooling == "set2set" else emb_dim
        self.graph_pred_linear_list = torch.nn.ModuleList(
            [torch.nn.Linear(head_in_dim, self.num_vocab) for _ in range(max_seq_len)]
        )

    def forward(self, batched_data):
        """Return a list of ``max_seq_len`` prediction tensors, one per output position.

        Reads ``x``, ``edge_index``, ``node_depth`` and ``batch`` from ``batched_data``.
        """
        node_feat = batched_data.x
        edge_index = batched_data.edge_index
        node_depth = batched_data.node_depth
        batch = batched_data.batch

        # Embed the nodes of the AST
        node_emb = self.node_encoder(node_feat, node_depth.view(-1,))

        # Densify, then run the DiffPool layers. Only the first layer receives
        # the node mask; later layers are called with mask=None.
        dense_x, mask = to_dense_batch(node_emb, batch, max_num_nodes=self.max_num_nodes)
        dense_adj = to_dense_adj(edge_index, batch, max_num_nodes=self.max_num_nodes)
        for pool_layer in self.diff_pool_layers:
            dense_x, dense_adj = pool_layer(dense_x, dense_adj, mask)
            mask = None

        # Final GNN and pooling
        pooled_nodes = self.final_gnn(dense_x, dense_adj)
        graph_emb = self.pool(pooled_nodes, batch=None)

        return [head(graph_emb) for head in self.graph_pred_linear_list]




# Main Training Loop

In [None]:
import numpy as np
import argparse
from torchvision import transforms
import torch.optim as optim

# Per-position classification loss shared by train().
multicls_criterion = torch.nn.CrossEntropyLoss()

def train(model, device, loader, optimizer):
    """Run one training epoch and return the average loss.

    Args:
        model: module returning a list of per-position prediction tensors.
        device: device every batch is moved to.
        loader: iterable of batches exposing ``x``, ``batch`` and ``y_arr``.
        optimizer: optimizer stepping the model parameters.

    Returns:
        Mean loss over the batches that were actually trained on
        (0.0 when every batch was skipped or the loader was empty).
    """
    model.train()

    loss_accum = 0.0
    num_trained = 0
    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        # Skip batches holding a single node or a single graph.
        # NOTE(review): presumably inherited to avoid degenerate batch statistics — confirm.
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        optimizer.zero_grad()
        pred_list = model(batch)

        # Average the cross-entropy over the output positions.
        loss = 0
        for i, pred in enumerate(pred_list):
            loss += multicls_criterion(pred.to(torch.float32), batch.y_arr[:,i])
        loss = loss / len(pred_list)

        loss.backward()
        optimizer.step()

        # float() detaches the value so the graph can be freed.
        loss_accum += float(loss)
        num_trained += 1
        del pred_list
        del loss

    # Divide by the number of trained batches: skipped batches contribute no
    # loss, and an empty loader must not crash on an undefined loop variable.
    avg_loss = loss_accum / num_trained if num_trained else 0.0
    print('Average training loss: {}'.format(avg_loss))
    return avg_loss


def eval(model, device, loader, evaluator, arr_to_seq):
    """Score the model on ``loader`` with ``evaluator``.

    For every batch the per-position predictions are arg-maxed, decoded with
    ``arr_to_seq`` and collected next to the reference sequences in ``batch.y``.
    Batches whose ``x`` has exactly one row are skipped.

    Returns:
        Whatever ``evaluator.eval`` returns for
        ``{"seq_ref": references, "seq_pred": predictions}``.
    """
    model.eval()
    references = []
    predictions = []

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            pred_list = model(batch)

        # (num_graphs, num_positions) matrix of predicted indices
        best_idx = torch.cat([torch.argmax(pred, dim = 1).view(-1,1) for pred in pred_list], dim = 1)

        # PyG >= 1.5.0 stores each reference directly in batch.y[i]
        # (PyG 1.4.3 needed batch.y[i][0]).
        references.extend([batch.y[i] for i in range(len(batch.y))])
        predictions.extend([arr_to_seq(row) for row in best_idx])

    return evaluator.eval({"seq_ref": references, "seq_pred": predictions})

def main():
    """Train and evaluate the DiffPool GNN using the module-level ``args`` and ``dataset``.

    Steps: build the target vocabulary from the training split, install the
    dataset transform (edge augmentation + target encoding), create the data
    loaders and the model, then train for ``args['epochs']`` epochs while
    evaluating on the validation and test splits after every epoch.

    Side effects:
        - sets ``dataset.transform``;
        - saves the checkpoint of the best validation epoch to ``args['model_save_path']``;
        - saves the validation/test curves to ``args['eval_results_path']`` unless it is empty.
    """
    print ("Starting Training")
    # Training settings
    print(args)

    device = torch.device("cuda:" + str(args['device'])) if torch.cuda.is_available() else torch.device("cpu")

    target_seqs = dataset.data.y
    seq_len_list = np.array([len(seq) for seq in target_seqs])
    # The message is labelled with '%', so report a percentage rather than a fraction.
    pct_within_limit = 100.0 * np.sum(seq_len_list <= args['max_seq_len']) / len(seq_len_list)
    print('Target seqence less or equal to {} is {}%.'.format(args['max_seq_len'], pct_within_limit))

    split_idx = dataset.get_idx_split()

    if args['random_split']:
        print('Using random split')
        perm = torch.randperm(len(dataset))
        num_train, num_valid, num_test = len(split_idx['train']), len(split_idx['valid']), len(split_idx['test'])
        split_idx['train'] = perm[:num_train]
        split_idx['valid'] = perm[num_train:num_train+num_valid]
        split_idx['test'] = perm[num_train+num_valid:]

        assert(len(split_idx['train']) == num_train)
        assert(len(split_idx['valid']) == num_valid)
        assert(len(split_idx['test']) == num_test)


    print(split_idx['train'])
    print(split_idx['valid'])
    print(split_idx['test'])

    print(len(split_idx['train']))

    ### building vocabulary for sequence prediction. Only use training data.
    vocab2idx, idx2vocab = get_vocab_mapping([target_seqs[i] for i in split_idx['train']], args['num_vocab'])

    ### set the transform function
    # augment_edge: add next-token edge as well as inverse edges. add edge attributes.
    # encode_y_to_arr: add y_arr to PyG data object, indicating the array representation of a sequence.
    dataset.transform = transforms.Compose([augment_edge, lambda data: encode_y_to_arr(data, vocab2idx, args['max_seq_len'])])

    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator(args['dataset'])

    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args['batch_size'], shuffle=True, num_workers = args['num_workers'])
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args['batch_size'], shuffle=False, num_workers = args['num_workers'])
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args['batch_size'], shuffle=False, num_workers = args['num_workers'])

    # The mapping tables are only used for their lengths (embedding table sizes).
    nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz'))
    nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz'))

    print(nodeattributes_mapping)

    ### Encoding node features into emb_dim vectors.
    ### The following three node features are used.
    # 1. node type
    # 2. node attribute
    # 3. node depth
    node_encoder = ASTNodeEncoder(args['emb_dim'], num_nodetypes = len(nodetypes_mapping['type']), num_nodeattributes = len(nodeattributes_mapping['attr']), max_depth = 20)

    # len(vocab2idx) includes the two special tokens added by get_vocab_mapping.
    model = GNN(num_vocab = len(vocab2idx), max_seq_len = args['max_seq_len'],
                    node_encoder = node_encoder,
                    gnn_num_layers = args['num_layer'],
                    diff_pool_node_number_list = args['diff_pool_layers'],
                    emb_dim = args['emb_dim'],
                    drop_ratio = args['drop_ratio'],
                    graph_pooling = "sum",
                    max_num_nodes=args['max_num_nodes']).to(device)

    # Report the in-memory size of parameters plus buffers.
    param_size = sum(param.nelement() * param.element_size() for param in model.parameters())
    buffer_size = sum(buffer.nelement() * buffer.element_size() for buffer in model.buffers())
    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))

    # NOTE(review): the learning rate is hard-coded; args['lr'] is not read here.
    # Kept at 0.001 so that runs stay comparable with the results logged below.
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print(f'#Params: {sum(p.numel() for p in model.parameters())}')

    def arr_to_seq(arr):
        # Decode one row of predicted indices with this run's vocabulary.
        return decode_arr_to_seq(arr, idx2vocab)

    metric = dataset.eval_metric
    valid_curve = []
    test_curve = []

    best_val_f1 = -1

    for epoch in range(1, args['epochs'] + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer)

        print('Evaluating...')
        valid_perf = eval(model, device, valid_loader, evaluator, arr_to_seq = arr_to_seq)
        test_perf = eval(model, device, test_loader, evaluator, arr_to_seq = arr_to_seq)

        print({'Validation': valid_perf, 'Test': test_perf})

        # Checkpoint whenever the validation score improves.
        val_f1 = valid_perf[metric]
        if (val_f1 > best_val_f1):
            best_val_f1 = val_f1
            torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, args['model_save_path'])

        valid_curve.append(valid_perf[metric])
        test_curve.append(test_perf[metric])

    print('F1')
    print('Finished training!')
    # With zero epochs there is no curve to summarise (argmax of an empty array raises).
    if valid_curve:
        best_val_epoch = int(np.argmax(np.array(valid_curve)))
        print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
        print('Test score: {}'.format(test_curve[best_val_epoch]))

    if args['eval_results_path']:
        result_dict = {'Val': valid_curve, 'Test': test_curve}
        torch.save(result_dict, args['eval_results_path'])




In [None]:
# Show the metric name of this dataset; main() uses it as the key into the
# dictionaries returned by the evaluator.
dataset.eval_metric

'F1'

# Run

In [None]:
 # Run the full train/evaluate loop defined above with the settings in `args`.
 main()

Starting Training
{'device': 0, 'num_layer': 3, 'emb_dim': 256, 'drop_ratio': 0.2, 'batch_size': 64, 'lr': 0.01, 'epochs': 50, 'num_vocab': 5000, 'max_seq_len': 5, 'diff_pool_layers': [(2, 5)], 'max_num_nodes': 750, 'random_split': False, 'dataset': 'ogbg-code2', 'num_workers': 0, 'model_save_path': 'best_model_params', 'eval_results_path': 'eval_results'}




Target seqence less or equal to 5 is 0.9874166466036873%.
tensor([     0,      1,      2,  ..., 407973, 407974, 407975])
tensor([407976, 407977, 407978,  ..., 430790, 430791, 430792])
tensor([430793, 430794, 430795,  ..., 452738, 452739, 452740])
407976




Coverage of top 5000 vocabulary:
0.9025832389087423
       attr idx      attr
0             0       NaN
1             1       NaN
2             2        \t
3             3        \n
4             4      \n\t
...         ...       ...
10025     10025         |
10026     10026         }
10027     10027         ~
10028     10028  __NONE__
10029     10029   __UNK__

[10030 rows x 2 columns]
model size: 37.483MB
#Params: 9820852
=====Epoch 1
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 3.0443002535128127
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.09762969131203343, 'recall': 0.048561101406567504, 'F1': 0.061913008711448475}, 'Test': {'precision': 0.09936364740902738, 'recall': 0.05007329734639739, 'F1': 0.0635355678685367}}
=====Epoch 2
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.81608464457942
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.11962571766665206, 'recall': 0.061190156184020406, 'F1': 0.07716128154617}, 'Test': {'precision': 0.12297248040823765, 'recall': 0.0640669793827257, 'F1': 0.08027835240109707}}
=====Epoch 3
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.7146203128216313
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1396473389724036, 'recall': 0.07305954346209419, 'F1': 0.09120726925653075}, 'Test': {'precision': 0.1478570560719276, 'recall': 0.07849805890401844, 'F1': 0.09765383731348738}}
=====Epoch 4
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.647097803396337
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.13319746972286742, 'recall': 0.07304712582027742, 'F1': 0.08953951703436736}, 'Test': {'precision': 0.1366943077577304, 'recall': 0.07616375357627408, 'F1': 0.0930203089238081}}
=====Epoch 5
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.594843675893896
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.14702853135819782, 'recall': 0.07934328762110679, 'F1': 0.09788333437367051}, 'Test': {'precision': 0.15286890225381203, 'recall': 0.08400425536838876, 'F1': 0.10321344942039311}}
=====Epoch 6
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.554425692389993
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1501146805744255, 'recall': 0.08466242101579788, 'F1': 0.10263522786641513}, 'Test': {'precision': 0.15747068829354233, 'recall': 0.09048016813189037, 'F1': 0.10917136293024758}}
=====Epoch 7
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.5196167141596475
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.14523527778995193, 'recall': 0.08112234959870494, 'F1': 0.09875263886746553}, 'Test': {'precision': 0.14999848125873277, 'recall': 0.08576588675686543, 'F1': 0.10378684656978807}}
=====Epoch 8
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.4909312384362314
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15478225299849524, 'recall': 0.08763441140725622, 'F1': 0.10615404035660861}, 'Test': {'precision': 0.16021581313407446, 'recall': 0.0932143363389946, 'F1': 0.1118010272739633}}
=====Epoch 9
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.464895400907479
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15103285561934812, 'recall': 0.08793476529265634, 'F1': 0.10527813335807376}, 'Test': {'precision': 0.1563468197557864, 'recall': 0.09264469988226139, 'F1': 0.11042833829192496}}
=====Epoch 10
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.443738839112076
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.149260054637624, 'recall': 0.08714111449900555, 'F1': 0.10447724763585745}, 'Test': {'precision': 0.1580174351497479, 'recall': 0.09374654667164237, 'F1': 0.11170393631437915}}
=====Epoch 11
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.424007752493316
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15704664650625993, 'recall': 0.09155381221603776, 'F1': 0.10959536511756349}, 'Test': {'precision': 0.1667729785553733, 'recall': 0.10071236197534722, 'F1': 0.11902087013486685}}
=====Epoch 12
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.4061021509357525
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15540021329125944, 'recall': 0.08942488231066922, 'F1': 0.10764819223862919}, 'Test': {'precision': 0.1624233035660045, 'recall': 0.09545285248538392, 'F1': 0.11413196967461592}}
=====Epoch 13
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.3871905597424976
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15685234693430336, 'recall': 0.09451775027113592, 'F1': 0.1119328351983828}, 'Test': {'precision': 0.16820439219974487, 'recall': 0.10374528828602092, 'F1': 0.12171070884848906}}
=====Epoch 14
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.3725471771651625
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16279893646550087, 'recall': 0.09141594856522321, 'F1': 0.11092651608275923}, 'Test': {'precision': 0.17136868963003463, 'recall': 0.10003229133218197, 'F1': 0.11985878138830572}}
=====Epoch 15
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.3558131156622193
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15993192210486332, 'recall': 0.09073036951701982, 'F1': 0.10982225032713704}, 'Test': {'precision': 0.16900932507138083, 'recall': 0.0978384695716571, 'F1': 0.11752697113550038}}
=====Epoch 16
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.3411832121680765
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15964996859067068, 'recall': 0.09459152567251791, 'F1': 0.11270808878294525}, 'Test': {'precision': 0.17034353927464915, 'recall': 0.10382941570407399, 'F1': 0.12249701577202943}}
=====Epoch 17
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.326752140344358
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16428101853880878, 'recall': 0.09892164781063409, 'F1': 0.11717537890566773}, 'Test': {'precision': 0.1711940343843023, 'recall': 0.10660952583451215, 'F1': 0.12448595539082144}}
=====Epoch 18
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.3145668806375244
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16809834772318885, 'recall': 0.09878028843712326, 'F1': 0.11788477819726448}, 'Test': {'precision': 0.1734098778932021, 'recall': 0.10494693807296322, 'F1': 0.1241353431211277}}
=====Epoch 19
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.30206296228895
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1668799579261077, 'recall': 0.09667083718558497, 'F1': 0.11614631466715437}, 'Test': {'precision': 0.1776016037907782, 'recall': 0.10701546367894099, 'F1': 0.12673831772382893}}
=====Epoch 20
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.2900223913566737
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16456078070444552, 'recall': 0.0986795385785174, 'F1': 0.11693544046833261}, 'Test': {'precision': 0.16959707794180184, 'recall': 0.1056106280067461, 'F1': 0.12356386247856996}}
=====Epoch 21
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.2790226786183374
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16270836072519027, 'recall': 0.10033442413794781, 'F1': 0.11769392533515775}, 'Test': {'precision': 0.1716709191422149, 'recall': 0.10869671026177312, 'F1': 0.1264044872164063}}
=====Epoch 22
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.2679928804285385
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16622839695548638, 'recall': 0.10341949507155275, 'F1': 0.12082317423510393}, 'Test': {'precision': 0.17607147196403622, 'recall': 0.1126987019824636, 'F1': 0.13035401004236433}}
=====Epoch 23
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.2575897645950316
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16657974317394925, 'recall': 0.09874385291946759, 'F1': 0.11762612950104732}, 'Test': {'precision': 0.17659923455440132, 'recall': 0.10755309808754313, 'F1': 0.12672345904303256}}
=====Epoch 24
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.248090444284327
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1695811602460154, 'recall': 0.10404893385675257, 'F1': 0.1220884550967384}, 'Test': {'precision': 0.17802305449243666, 'recall': 0.11301264749870546, 'F1': 0.131175757549459}}
=====Epoch 25
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.2374792951883054
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.17095805758864005, 'recall': 0.10314415038633823, 'F1': 0.12180430074763247}, 'Test': {'precision': 0.17627270518194518, 'recall': 0.1095840890879163, 'F1': 0.12836919600861646}}
=====Epoch 26
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.2279215227201874
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1679887802953938, 'recall': 0.10255090363562117, 'F1': 0.120665495164772}, 'Test': {'precision': 0.17660075329566852, 'recall': 0.11233029874363935, 'F1': 0.13034562080488818}}
=====Epoch 27
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.220476165790184
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16547968619888678, 'recall': 0.10235547360607622, 'F1': 0.12003421602623952}, 'Test': {'precision': 0.1758778324524634, 'recall': 0.11297457048550542, 'F1': 0.1308052931617721}}
=====Epoch 28
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.2112158299427405
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16935837314283209, 'recall': 0.10317368141687729, 'F1': 0.12136808325176647}, 'Test': {'precision': 0.1783913492497418, 'recall': 0.11198456814230459, 'F1': 0.13066358013733628}}
=====Epoch 29
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.2029683200237797
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16966881418825144, 'recall': 0.10513368617523412, 'F1': 0.12307184568656278}, 'Test': {'precision': 0.1814948970293421, 'recall': 0.11659720233394372, 'F1': 0.13470296182160588}}
=====Epoch 30
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.1954701824562224
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16905085389548727, 'recall': 0.1046882163187988, 'F1': 0.12250730939647098}, 'Test': {'precision': 0.17719230301925762, 'recall': 0.11395920299351135, 'F1': 0.13152470643586006}}
=====Epoch 31
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.187588256181455
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16117368628654075, 'recall': 0.10377106738153327, 'F1': 0.11993654449060505}, 'Test': {'precision': 0.16971174290747826, 'recall': 0.11255279433928969, 'F1': 0.12859132999920314}}
=====Epoch 32
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.1800117934170893
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.17026924369256838, 'recall': 0.10466452888440879, 'F1': 0.1228660644600514}, 'Test': {'precision': 0.1809124597533564, 'recall': 0.11729728781738076, 'F1': 0.13520747320637969}}
=====Epoch 33
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.172794066111247
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1714072840425998, 'recall': 0.10740773205163791, 'F1': 0.1252987775690147}, 'Test': {'precision': 0.17768513456047627, 'recall': 0.11492821416276856, 'F1': 0.13267121145059907}}
=====Epoch 34
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.1648760828130387
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.16621378796511374, 'recall': 0.10438403626925342, 'F1': 0.12162068597621035}, 'Test': {'precision': 0.17576392685742057, 'recall': 0.11365165788689571, 'F1': 0.1312005998173303}}
=====Epoch 35
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.159146926692888
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1615973470073483, 'recall': 0.10576375106002137, 'F1': 0.1213914291772906}, 'Test': {'precision': 0.1722100722920843, 'recall': 0.11620232960446192, 'F1': 0.13200191663832997}}
=====Epoch 36
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.151690513666938
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1651707060525047, 'recall': 0.10581498687625698, 'F1': 0.12236626307275386}, 'Test': {'precision': 0.17728722434846, 'recall': 0.11639792178337831, 'F1': 0.13360183495503342}}
=====Epoch 37
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

# Dataset exploration: loader sizes and a sample graph

In [None]:
import networkx as nx
import torch_geometric

In [None]:
# Number of mini-batches the training loader yields per pass.
# NOTE(review): presumably this is the 6375 shown in the training progress
# bar of the log above; train_loader is defined in an earlier cell -- confirm.
print(len(train_loader))

In [None]:
# Number of mini-batches the validation loader yields per pass.
# NOTE(review): presumably the first of the two evaluation progress bars in
# the log above (357); valid_loader is defined in an earlier cell -- confirm.
print(len(valid_loader))

In [None]:
# Number of mini-batches the test loader yields per pass.
# NOTE(review): presumably the second of the two evaluation progress bars in
# the log above (343); test_loader is defined in an earlier cell -- confirm.
print(len(test_loader))

In [None]:
# Display the first graph of the dataset; as the last expression of the cell
# its repr is rendered as the cell output.
dataset[0]

In [None]:
# Turn the first graph of the dataset into a NetworkX graph and plot it,
# labelling every node with its id. `G` is reused by the next cell.
G = torch_geometric.utils.to_networkx(dataset[0])
nx.draw(G, with_labels=True)



In [None]:
# Size of the sample graph: print (edges, nodes), then show the
# edges-per-node ratio as the cell's output value.
edge_count, node_count = G.number_of_edges(), G.number_of_nodes()
print(edge_count, node_count)
edge_count / node_count

In [None]:
# Display the first graph again (same expression as the earlier `dataset[0]`
# cell) so its summary sits next to the per-node dump in the following cell.
dataset[0]

In [None]:
# For the first graph only, walk `x` and `node_is_attributed` in lockstep and
# print each aligned pair as a tuple.
# NOTE(review): assumes both tensors are indexed by node along dim 0 -- confirm.
for graph in dataset[:1]:
    for node_feature, is_attributed in zip(graph.x, graph.node_is_attributed):
        print((node_feature, is_attributed))
