# Args

In [None]:
# Run configuration shared by the cells below (read as a module-level global).
args = {
    # Index of the CUDA device to use; main() falls back to CPU when CUDA is unavailable.
    'device': 0,
    # Number of GCN layers in the final GNN applied after the DiffPool stages.
    'num_layer': 3,
    # Width of the node embeddings and of every hidden GCN layer.
    'emb_dim': 256,
    # Dropout probability used inside the GCN stacks.
    'drop_ratio': 0.5,
    # Number of graphs per mini-batch for the train/valid/test loaders.
    'batch_size': 64,
    # Adam learning rate; 0.001 matches the optimizer setting of the recorded run.
    'lr': 0.001,
    # Number of training epochs.
    'epochs': 50,
    # Number of most frequent target tokens kept (before adding __UNK__ / __EOS__).
    'num_vocab': 5000,
    # Number of target-token positions that are encoded and predicted.
    'max_seq_len': 5,
    # One (number_of_nodes_out, gcn_layer_num) pair per DiffPool stage.
    'diff_pool_layers': [(128, 5), (32, 3)],
    # Node dimension used when the sparse batch is converted to dense tensors.
    'max_num_nodes': 512,
    # When True, main() replaces the official split by a random permutation.
    'random_split': False,
    # OGB dataset name handed to the Evaluator.
    'dataset': "ogbg-code2",
    # Number of DataLoader worker processes.
    'num_workers': 0,
    # File that receives the best checkpoint (model + optimizer state).
    'model_save_path': "best_model_params",
    # File that receives the validation/test curves; '' disables saving.
    'eval_results_path': 'eval_results',
}

# Set up the dependencies and download the dataset

In [None]:
!pip install ogb
!pip install torch_geometric
!python -c "import ogb; print(ogb.__version__)"

import os
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.loader import DataLoader
import torch
import pandas as pd
import torch.nn.functional as F
from tqdm.notebook import tqdm
print(torch.__version__)

# The PyG built-in GCNConv
from torch_geometric.nn import GCNConv

import torch_geometric.transforms as T
from torch_geometric.nn import global_add_pool, global_mean_pool

# Load the dataset named in the shared config (the same name main() passes to
# the Evaluator) instead of repeating the literal here.
dataset = PygGraphPropPredDataset(name = args['dataset'])

Collecting ogb
  Downloading ogb-1.3.6-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Collecting outdated>=0.2.0 (from ogb)
  Downloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Collecting littleutils (from outdated>=0.2.0->ogb)
  Downloading littleutils-0.2.2.tar.gz (6.6 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: littleutils
  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
  Created wheel for littleutils: filename=littleutils-0.2.2-py3-none-any.whl size=7026 sha256=d3cc7728008ec198d0cff9c78fd3f615018ba1daf92e9ea8a30200a640adfb35
  Stored in directory: /root/.cache/pip/wheels/3d/fe/b0/27a9892da57472e538c7452a721a9cf463cc03cf7379889266
Successfully built littleutils
Installing collected packages: littleutils, outdated, ogb
Successfully installed littleutils-0.2.2 ogb-1.3.6 outdated-0.2.2
Collecting torch_geometric
  

Downloaded 0.91 GB: 100%|██████████| 934/934 [01:01<00:00, 15.14it/s]


Extracting dataset/code2.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 452741/452741 [00:01<00:00, 341403.81it/s]


Converting graphs into PyG objects...


100%|██████████| 452741/452741 [00:21<00:00, 21136.65it/s]


Saving...


Done!


# Utils

In [None]:
from collections import Counter
import numpy as np
import torch

class ASTNodeEncoder(torch.nn.Module):
    '''
        Embeds AST nodes as the sum of three learned embeddings.

        Input:
            x: default node feature. the first and second column represents node type and node attributes.
            depth: The depth of the node in the AST.

        Output:
            emb_dim-dimensional vector

    '''
    def __init__(self, emb_dim, num_nodetypes, num_nodeattributes, max_depth):
        '''
            emb_dim: size of every embedding vector.
            num_nodetypes: number of rows of the node-type embedding table.
            num_nodeattributes: number of rows of the node-attribute embedding table.
            max_depth: largest depth with its own embedding; deeper nodes share it.
        '''
        super(ASTNodeEncoder, self).__init__()

        self.max_depth = max_depth

        self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim)
        self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim)
        # max_depth + 1 rows so that depth == max_depth is a valid index.
        self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim)


    def forward(self, x, depth):
        # Cap the depth without writing into the caller's tensor: the argument
        # is typically a view of the batch's node_depth, so an in-place clamp
        # would silently alter the input data.
        depth = depth.clamp(max=self.max_depth)
        return self.type_encoder(x[:,0]) + self.attribute_encoder(x[:,1]) + self.depth_encoder(depth)



def get_vocab_mapping(seq_list, num_vocab):
    '''
        Build the target vocabulary from the most frequent tokens.

        Input:
            seq_list: a list of sequences
            num_vocab: vocabulary size (maximum number of regular tokens kept)
        Output:
            vocab2idx:
                A dictionary that maps vocabulary into integer index.
                Additionally, we also index '__UNK__' and '__EOS__'
                '__UNK__' : out-of-vocabulary term
                '__EOS__' : end-of-sentence

            idx2vocab:
                A list that maps idx to actual vocabulary.

        Tokens are ranked by descending count; ties keep first-seen order.
        The two special tokens always take the two indices directly after the
        kept tokens, so the mapping stays consistent even when fewer than
        num_vocab distinct tokens exist.
    '''

    # Counter is a dict, so iteration follows first-seen order.
    vocab_cnt = Counter()
    for seq in seq_list:
        for w in seq:
            vocab_cnt[w] += 1
    vocab_list = list(vocab_cnt)

    cnt_list = np.array([vocab_cnt[w] for w in vocab_list])
    # Stable sort on the negated counts: most frequent first, ties by first appearance.
    topvocab = np.argsort(-cnt_list, kind = 'stable')[:num_vocab]

    total_cnt = np.sum(cnt_list)
    # Avoid a 0/0 when no token was seen at all.
    coverage = float(np.sum(cnt_list[topvocab])) / total_cnt if total_cnt > 0 else 0.0
    print('Coverage of top {} vocabulary:'.format(num_vocab))
    print(coverage)

    idx2vocab = [vocab_list[vocab_idx] for vocab_idx in topvocab]
    vocab2idx = {vocab: idx for idx, vocab in enumerate(idx2vocab)}

    # Append the special tokens at the next free indices. Using len(idx2vocab)
    # rather than num_vocab keeps both structures in sync when the corpus has
    # fewer than num_vocab distinct tokens.
    for special in ('__UNK__', '__EOS__'):
        vocab2idx[special] = len(idx2vocab)
        idx2vocab.append(special)

    # test the correspondence between vocab2idx and idx2vocab
    for idx, vocab in enumerate(idx2vocab):
        assert(idx == vocab2idx[vocab])

    # test that the idx of '__EOS__' is len(idx2vocab) - 1.
    # This fact will be used in decode_arr_to_seq, when finding __EOS__
    assert(vocab2idx['__EOS__'] == len(idx2vocab) - 1)

    return vocab2idx, idx2vocab

def augment_edge(data):
    '''
        Add inverse AST edges and next-token edges (both directions) to a graph.

        Input:
            data: PyG data object; reads data.edge_index and data.node_is_attributed
        Output:
            the same data object, modified in place:
                data.edge_index: [AST, inverse AST, next-token, inverse next-token]
                    edges concatenated in that order.
                data.edge_attr (float tensor built with torch.zeros / torch.ones), one row per edge:
                    data.edge_attr[:,0]: whether it is AST edge (0) or next-token edge (1)
                    data.edge_attr[:,1]: whether it is original direction (0) or inverse direction (1)
    '''

    def _edge_flags(num_edges, is_next_token, is_inverse):
        # (num_edges, 2) tensor whose rows all equal [is_next_token, is_inverse].
        flags = torch.zeros((num_edges, 2))
        flags[:, 0] = is_next_token
        flags[:, 1] = is_inverse
        return flags

    ast_edges = data.edge_index

    # Attributed nodes, taken in index order. This relies on the nodes already
    # being stored in DFS order, so no explicit sort is performed.
    token_nodes = torch.where(data.node_is_attributed.view(-1,) == 1)[0]

    # Link each attributed node to the following one, e.g.
    #   [1, 3, 4, 5]  ->  [[1, 3, 4],
    #                      [3, 4, 5]]
    next_token_edges = torch.stack([token_nodes[:-1], token_nodes[1:]], dim = 0)

    num_ast = ast_edges.size(1)
    num_next = next_token_edges.size(1)

    # Row order [1, 0] swaps source and target, giving the inverse edges.
    data.edge_index = torch.cat(
        [ast_edges, ast_edges[[1, 0]], next_token_edges, next_token_edges[[1, 0]]],
        dim = 1,
    )
    data.edge_attr = torch.cat(
        [
            _edge_flags(num_ast, 0, 0),
            _edge_flags(num_ast, 0, 1),
            _edge_flags(num_next, 1, 0),
            _edge_flags(num_next, 1, 1),
        ],
        dim = 0,
    )

    return data

def encode_y_to_arr(data, vocab2idx, max_seq_len):
    '''
    Attach the index-encoded target sequence to a graph.

    Input:
        data: PyG graph object whose `y` holds the target token sequence
        vocab2idx: token -> index mapping (must contain '__UNK__' and '__EOS__')
        max_seq_len: length of the encoded sequence
    Output:
        the same data object with `y_arr` added (see encode_seq_to_arr)
    '''

    # PyG >= 1.5.0 stores the sequence directly in data.y
    # (PyG 1.4.3 needed data.y[0] instead).
    data.y_arr = encode_seq_to_arr(data.y, vocab2idx, max_seq_len)
    return data

def encode_seq_to_arr(seq, vocab2idx, max_seq_len):
    '''
    Convert a token sequence into a fixed-length row of vocabulary indices.

    Input:
        seq: A list of words
        vocab2idx: token -> index mapping containing '__UNK__' and '__EOS__'
        max_seq_len: number of positions in the result
    Output:
        torch.long tensor of shape (1, max_seq_len); sequences longer than
        max_seq_len are truncated, shorter ones are padded with '__EOS__',
        and unknown words are mapped to '__UNK__'.
    '''

    num_padding = max(0, max_seq_len - len(seq))
    padded_seq = seq[:max_seq_len] + ['__EOS__'] * num_padding

    token_ids = []
    for word in padded_seq:
        if word in vocab2idx:
            token_ids.append(vocab2idx[word])
        else:
            token_ids.append(vocab2idx['__UNK__'])

    return torch.tensor([token_ids], dtype = torch.long)


def decode_arr_to_seq(arr, idx2vocab):
    '''
        Turn an array of vocabulary indices back into words.

        Input: torch 1d array: y_arr
        Output: a sequence of words, cut off right before the first '__EOS__'.
    '''

    # '__EOS__' is by construction the last entry of idx2vocab.
    eos_id = len(idx2vocab) - 1
    eos_positions = torch.nonzero(arr == eos_id, as_tuple=False)

    # Keep only what precedes the earliest __EOS__, if there is one.
    if len(eos_positions) > 0:
        arr = arr[: torch.min(eos_positions)]

    return [idx2vocab[token_id] for token_id in arr.cpu()]


# Define GNN and DiffPool

In [None]:
from torch_geometric.nn import dense_diff_pool, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set, DenseGCNConv
from torch_geometric.utils import to_dense_batch, to_dense_adj
import torch.nn.functional as F

# The generic GCN used by different layers of DiffPool
class GCN(torch.nn.Module):
    '''
    Stack of DenseGCNConv layers operating on dense (batched) graphs.

    All but the last convolution are followed by BatchNorm, ReLU and dropout;
    the last convolution's output is returned as is.

    Args:
        input_dim: feature size expected by the first convolution.
        hidden_dim: feature size between convolutions.
        output_dim: feature size produced by the last convolution.
        num_layers: number of convolutions (values below 1 still build one).
        dropout: probability of an element getting zeroed.
    '''
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout):
        super(GCN, self).__init__()

        # Channel sizes along the stack. The first convolution now really
        # consumes input_dim (it used to be hard-wired to hidden_dim, which
        # only worked because callers pass input_dim == hidden_dim).
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]

        # GCN layers
        self.convs = torch.nn.ModuleList(
            [DenseGCNConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

        # Batch norm, one per hidden convolution
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)]
        )

        # Log Softmax (kept as an attribute for compatibility; forward() does not apply it)
        self.softmax = torch.nn.LogSoftmax(dim=1)

        # Probability of an element getting zeroed
        self.dropout = dropout

    def forward(self, x, adj):
        '''
        Args:
            x: dense node features; assumed (batch, nodes, channels) — TODO confirm against callers.
            adj: dense adjacency passed unchanged to every convolution.
        Returns:
            Output of the last convolution.
        '''
        for conv, bn in zip(self.convs[:-1], self.bns):
            x = conv(x, adj)
            # BatchNorm1d normalises dim 1, so move the channel axis there and back.
            x = bn(x.transpose(1, 2)).transpose(1, 2)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        out = self.convs[-1](x, adj)

        return out

# A DiffPool layer that takes in a graph and pools the nodes to produce a
# smaller graph
class DiffPool(torch.nn.Module):
    '''
    One differentiable-pooling stage built from two GNNs.

    Args:
        gnn_embed: module producing the node embeddings fed to the pooling.
        gnn_pool: module producing the soft cluster-assignment scores.
        number_of_nodes_out: number of nodes of the pooled graph (stored only).
    '''
    def __init__(self, gnn_embed, gnn_pool, number_of_nodes_out):
        super(DiffPool, self).__init__()
        self.number_of_nodes_out = number_of_nodes_out
        self.gnn_embed = gnn_embed
        self.gnn_pool = gnn_pool

    def forward(self, x, adj, mask=None):
        '''
        Pool the dense graph (x, adj) and return the coarsened (x, adj).

        `mask` is forwarded to dense_diff_pool. The two auxiliary losses
        returned by dense_diff_pool are discarded.
        '''
        node_embeddings = self.gnn_embed(x, adj)
        assignment_scores = self.gnn_pool(x, adj)

        pooled_x, pooled_adj, _link_loss, _entropy_loss = dense_diff_pool(
            node_embeddings, adj, assignment_scores, mask
        )
        return pooled_x, pooled_adj



# The top-level abstraction for a composition of networks that takes in a batch
# of data and produces a prediction
class GNN(torch.nn.Module):
    '''
    Node encoder -> DiffPool stages -> final GCN -> graph pooling -> one
    linear classifier per target-token position.

    Args:
        num_vocab: number of classes of every token classifier.
        max_seq_len: number of token positions (classifiers).
        node_encoder: module called as node_encoder(x, node_depth).
        gnn_num_layers: depth of the final GCN; must be at least 2.
        diff_pool_node_number_list: (number_of_nodes_out, gcn_layer_num) per DiffPool stage.
        emb_dim: embedding width used throughout.
        drop_ratio: dropout probability of the GCN stacks.
        graph_pooling: "sum", "mean", "max", "attention" or "set2set".
        max_num_nodes: node dimension of the dense batch representation.
    '''
    def __init__(self, num_vocab, max_seq_len, node_encoder, gnn_num_layers = 3, diff_pool_node_number_list=[(50, 3), (10, 3), (3, 3)], emb_dim = 300, drop_ratio = 0.5, graph_pooling = "mean", max_num_nodes=1000):
        super(GNN, self).__init__()
        self.num_vocab = num_vocab
        self.max_seq_len = max_seq_len
        self.emb_dim = emb_dim
        self.drop_ratio = drop_ratio
        self.dropout_ratio = drop_ratio
        self.graph_pooling = graph_pooling
        self.max_num_nodes = max_num_nodes

        if gnn_num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.node_encoder = node_encoder

        # DiffPool stages: each owns an embedding GCN and an assignment GCN
        # whose output width is the number of clusters.
        self.diff_pool_layers = torch.nn.ModuleList()
        for number_of_nodes_out, gcn_layer_num in diff_pool_node_number_list:
            embed_gcn = GCN(input_dim=emb_dim, hidden_dim=emb_dim, output_dim=emb_dim, num_layers=gcn_layer_num, dropout=drop_ratio)
            assign_gcn = GCN(input_dim=emb_dim, hidden_dim=emb_dim, output_dim=number_of_nodes_out, num_layers=gcn_layer_num, dropout=drop_ratio)
            self.diff_pool_layers.append(
                DiffPool(gnn_embed=embed_gcn, gnn_pool=assign_gcn, number_of_nodes_out=number_of_nodes_out)
            )

        # GCN applied to the output of the last DiffPool stage
        self.final_gnn = GCN(input_dim=emb_dim, hidden_dim=emb_dim, output_dim=emb_dim, num_layers=gnn_num_layers, dropout=drop_ratio)

        # Readout producing whole-graph embeddings. Factories are used so that
        # only the selected (possibly parametrised) readout is instantiated.
        pool_factories = {
            "sum": lambda: global_add_pool,
            "mean": lambda: global_mean_pool,
            "max": lambda: global_max_pool,
            "attention": lambda: GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1))),
            "set2set": lambda: Set2Set(emb_dim, processing_steps = 2),
        }
        if self.graph_pooling not in pool_factories:
            raise ValueError("Invalid graph pooling type.")
        self.pool = pool_factories[self.graph_pooling]()

        # One classifier per output position; the heads take 2*emb_dim inputs for set2set.
        head_in_dim = 2*emb_dim if graph_pooling == "set2set" else emb_dim
        self.graph_pred_linear_list = torch.nn.ModuleList(
            [torch.nn.Linear(head_in_dim, self.num_vocab) for _ in range(max_seq_len)]
        )

    def forward(self, batched_data):
        '''
        Args:
            batched_data: batch exposing x, edge_index, edge_attr, node_depth and batch.
        Returns:
            List of max_seq_len tensors, one per token position (classifier outputs).
        '''
        # edge_attr is read but not consumed by this architecture.
        x, edge_index, _edge_attr, node_depth, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.node_depth, batched_data.batch

        # Embed the nodes of the AST
        node_emb = self.node_encoder(x, node_depth.view(-1,))

        # Switch to the dense representation required by DiffPool.
        dense_x, mask = to_dense_batch(node_emb, batch, max_num_nodes=self.max_num_nodes)
        dense_adj = to_dense_adj(edge_index, batch, max_num_nodes=self.max_num_nodes)

        # Only the first stage gets the padding mask; later stages run unmasked.
        for pool_layer in self.diff_pool_layers:
            dense_x, dense_adj = pool_layer(dense_x, dense_adj, mask)
            mask = None

        # Final GNN and pooling
        pooled_nodes = self.final_gnn(dense_x, dense_adj)
        del dense_x, dense_adj, mask
        graph_emb = self.pool(pooled_nodes, batch=None)

        return [self.graph_pred_linear_list[i](graph_emb) for i in range(self.max_seq_len)]




# Main Training Loop

In [None]:
import numpy as np
import argparse
from torchvision import transforms
import torch.optim as optim

# Cross-entropy criterion that train() applies to every predicted token position.
multicls_criterion = torch.nn.CrossEntropyLoss()

def train(model, device, loader, optimizer):
    '''
    Run one training epoch.

    Args:
        model: network returning a list of per-position predictions for a batch.
        device: device every batch is moved to.
        loader: iterable of batches exposing x, batch and y_arr.
        optimizer: optimizer stepping the model parameters.

    Returns:
        Mean loss over the batches that were actually trained on
        (nan when every batch was skipped or the loader was empty).
        The same value is printed.
    '''
    model.train()

    loss_accum = 0.0
    num_updates = 0  # batches that produced an optimizer step
    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        # Skip batches holding a single node or a single graph
        # (batch.batch[-1] == 0 means the last node belongs to graph 0).
        # NOTE(review): presumably to avoid degenerate BatchNorm statistics — confirm.
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        optimizer.zero_grad()
        pred_list = model(batch)

        # Average the cross-entropy over all token positions.
        loss = 0
        for i, pred in enumerate(pred_list):
            loss += multicls_criterion(pred.to(torch.float32), batch.y_arr[:,i])
        loss = loss / len(pred_list)

        loss.backward()
        optimizer.step()

        # float() detaches the value so the graph can be released.
        loss_accum += float(loss)
        num_updates += 1
        del pred_list
        del loss

    # Divide by the number of trained batches: skipped batches contribute no
    # loss, and an empty loader must not fail on an undefined loop variable.
    avg_loss = loss_accum / num_updates if num_updates > 0 else float('nan')
    print('Average training loss: {}'.format(avg_loss))
    return avg_loss


def eval(model, device, loader, evaluator, arr_to_seq):
    '''
    Predict a token sequence for every graph in `loader` and score it.

    Args:
        model: network returning a list of per-position predictions for a batch.
        device: device every batch is moved to.
        loader: iterable of batches exposing x and y.
        evaluator: object whose eval() receives {"seq_ref": ..., "seq_pred": ...}.
        arr_to_seq: callable turning one row of predicted indices into words.

    Returns:
        Whatever evaluator.eval returns for the collected sequences.
    '''
    model.eval()
    references = []
    predictions = []

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        # Batches consisting of a single node are not evaluated.
        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            pred_list = model(batch)

        # Most likely token per position, arranged as (graphs, positions).
        token_ids = torch.stack([torch.argmax(logits, dim = 1) for logits in pred_list], dim = 1)

        # PyG >= 1.5.0 keeps one target sequence per graph in batch.y
        # (PyG 1.4.3 needed batch.y[i][0]).
        references.extend(batch.y[i] for i in range(len(batch.y)))
        predictions.extend(arr_to_seq(row) for row in token_ids)

    return evaluator.eval({"seq_ref": references, "seq_pred": predictions})

def main():
    '''
    Train and evaluate the DiffPool GNN end to end.

    Reads the module-level `args` (configuration) and `dataset` (already
    loaded). It builds the target vocabulary from the training split, installs
    the per-graph transforms, then trains for args['epochs'] epochs while
    evaluating on the validation and test splits after each epoch. The
    checkpoint with the best validation score is written to
    args['model_save_path'], and the score curves to
    args['eval_results_path'] unless that is ''.
    '''
    print ("Starting Training")
    # Training settings
    print(args)

    device = torch.device("cuda:" + str(args['device'])) if torch.cuda.is_available() else torch.device("cpu")

    # Report as a real percentage (the ratio used to be printed next to a '%').
    seq_len_list = np.array([len(seq) for seq in dataset.data.y])
    print('Target seqence less or equal to {} is {}%.'.format(args['max_seq_len'], 100 * np.sum(seq_len_list <= args['max_seq_len']) / len(seq_len_list)))

    split_idx = dataset.get_idx_split()

    if args['random_split']:
        print('Using random split')
        perm = torch.randperm(len(dataset))
        num_train, num_valid, num_test = len(split_idx['train']), len(split_idx['valid']), len(split_idx['test'])
        split_idx['train'] = perm[:num_train]
        split_idx['valid'] = perm[num_train:num_train+num_valid]
        split_idx['test'] = perm[num_train+num_valid:]

        assert(len(split_idx['train']) == num_train)
        assert(len(split_idx['valid']) == num_valid)
        assert(len(split_idx['test']) == num_test)

    print(split_idx['train'])
    print(split_idx['valid'])
    print(split_idx['test'])

    print(len(split_idx['train']))

    ### building vocabulary for sequence prediction. Only use training data.
    vocab2idx, idx2vocab = get_vocab_mapping([dataset.data.y[i] for i in split_idx['train']], args['num_vocab'])

    ### set the transform function
    # augment_edge: add next-token edge as well as inverse edges. add edge attributes.
    # encode_y_to_arr: add y_arr to PyG data object, indicating the array representation of a sequence.
    dataset.transform = transforms.Compose([augment_edge, lambda data: encode_y_to_arr(data, vocab2idx, args['max_seq_len'])])

    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator(args['dataset'])

    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args['batch_size'], shuffle=True, num_workers = args['num_workers'])
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args['batch_size'], shuffle=False, num_workers = args['num_workers'])
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args['batch_size'], shuffle=False, num_workers = args['num_workers'])

    # The mapping tables give the number of node types / attributes to embed.
    nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz'))
    nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz'))

    print(nodeattributes_mapping)

    ### Encoding node features into emb_dim vectors.
    ### The following three node features are used.
    # 1. node type
    # 2. node attribute
    # 3. node depth
    node_encoder = ASTNodeEncoder(args['emb_dim'], num_nodetypes = len(nodetypes_mapping['type']), num_nodeattributes = len(nodeattributes_mapping['attr']), max_depth = 20)

    model = GNN(num_vocab = len(vocab2idx), max_seq_len = args['max_seq_len'],
                    node_encoder = node_encoder,
                    gnn_num_layers = args['num_layer'],
                    diff_pool_node_number_list = args['diff_pool_layers'],
                    emb_dim = args['emb_dim'],
                    drop_ratio = args['drop_ratio'],
                    graph_pooling = "sum",
                    max_num_nodes=args['max_num_nodes']).to(device)

    # Memory footprint of parameters plus buffers, in MiB.
    param_size = sum(param.nelement() * param.element_size() for param in model.parameters())
    buffer_size = sum(buffer.nelement() * buffer.element_size() for buffer in model.buffers())
    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))

    # Honour the configured learning rate instead of a hard-coded constant.
    optimizer = optim.Adam(model.parameters(), lr=args['lr'])

    print(f'#Params: {sum(p.numel() for p in model.parameters())}')

    def arr_to_seq(arr):
        # Decode one row of predicted indices with the vocabulary built above.
        return decode_arr_to_seq(arr, idx2vocab)

    valid_curve = []
    test_curve = []

    best_val_f1 = -1

    for epoch in range(1, args['epochs'] + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer)

        print('Evaluating...')
        valid_perf = eval(model, device, valid_loader, evaluator, arr_to_seq = arr_to_seq)
        test_perf = eval(model, device, test_loader, evaluator, arr_to_seq = arr_to_seq)

        print({'Validation': valid_perf, 'Test': test_perf})

        # Checkpoint whenever the validation metric improves.
        val_f1 = valid_perf[dataset.eval_metric]
        if (val_f1 > best_val_f1):
            best_val_f1 = val_f1
            torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, args['model_save_path'])

        valid_curve.append(valid_perf[dataset.eval_metric])
        test_curve.append(test_perf[dataset.eval_metric])

    print('F1')
    # Report the test score of the epoch with the best validation score.
    best_val_epoch = np.argmax(np.array(valid_curve))
    print('Finished training!')
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))
    if not args['eval_results_path'] == '':
        result_dict = {'Val': valid_curve, 'Test': test_curve}
        torch.save(result_dict, args['eval_results_path'])




In [None]:
dataset.eval_metric

'F1'

# Run

In [None]:
 main()

Starting Training
{'device': 0, 'num_layer': 3, 'emb_dim': 256, 'drop_ratio': 0.5, 'batch_size': 64, 'lr': 0.01, 'epochs': 50, 'num_vocab': 5000, 'max_seq_len': 5, 'diff_pool_layers': [(128, 5), (32, 3)], 'max_num_nodes': 512, 'random_split': False, 'dataset': 'ogbg-code2', 'num_workers': 0, 'model_save_path': 'best_model_params', 'eval_results_path': 'eval_results'}
Target seqence less or equal to 5 is 0.9874166466036873%.
tensor([     0,      1,      2,  ..., 407973, 407974, 407975])
tensor([407976, 407977, 407978,  ..., 430790, 430791, 430792])
tensor([430793, 430794, 430795,  ..., 452738, 452739, 452740])
407976




Coverage of top 5000 vocabulary:
0.9025832389087423
       attr idx      attr
0             0       NaN
1             1       NaN
2             2        \t
3             3        \n
4             4      \n\t
...         ...       ...
10025     10025         |
10026     10026         }
10027     10027         ~
10028     10028  __NONE__
10029     10029   __UNK__

[10030 rows x 2 columns]
model size: 38.909MB
#Params: 10192466
=====Epoch 1
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 3.184906141991709
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.08590086339133103, 'recall': 0.04292326593023441, 'F1': 0.05477109071797243}, 'Test': {'precision': 0.08449517040277019, 'recall': 0.041386476983114484, 'F1': 0.05332192765216275}}
=====Epoch 2
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.9284559424531227
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.12361397203839243, 'recall': 0.05710840427389491, 'F1': 0.0748175441452384}, 'Test': {'precision': 0.12381538181155459, 'recall': 0.05872860382838513, 'F1': 0.07632004257209286}}
=====Epoch 3
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.833876733742508
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.12400111028326831, 'recall': 0.06352947294241065, 'F1': 0.0799531639563195}, 'Test': {'precision': 0.12525059230909422, 'recall': 0.06449266086362859, 'F1': 0.08104596763016665}}
=====Epoch 4
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.772654317706239
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.12614132737286524, 'recall': 0.06376535596196374, 'F1': 0.08062229049869844}, 'Test': {'precision': 0.12507973391653, 'recall': 0.06394851756388115, 'F1': 0.08068183133109322}}
=====Epoch 5
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.726030705321069
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.12131305605469606, 'recall': 0.06374161635260817, 'F1': 0.07933843534930445}, 'Test': {'precision': 0.12402041188263167, 'recall': 0.0660424278453256, 'F1': 0.08198758721585402}}
=====Epoch 6
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.6932570677364573
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1267549049685176, 'recall': 0.0665761778846321, 'F1': 0.08296428557199034}, 'Test': {'precision': 0.13056618674442622, 'recall': 0.07079500319658877, 'F1': 0.08713575521940749}}
=====Epoch 7
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.666482454823513
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.11847525967480388, 'recall': 0.06458966824374196, 'F1': 0.07944215918095043}, 'Test': {'precision': 0.12405838041431261, 'recall': 0.0685653825347647, 'F1': 0.0840627367691337}}
=====Epoch 8
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.6441842813678815
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.13609004981665718, 'recall': 0.07045128214760506, 'F1': 0.08801144510045766}, 'Test': {'precision': 0.1391470749043193, 'recall': 0.0736179092863073, 'F1': 0.0916643885547658}}
=====Epoch 9
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.62604159525329
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.13202509824546027, 'recall': 0.07224884884634195, 'F1': 0.08869733719845478}, 'Test': {'precision': 0.13385426158799588, 'recall': 0.07397145055701643, 'F1': 0.09064926452159913}}
=====Epoch 10
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.6090960788352815
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1367255408978685, 'recall': 0.07644660309668856, 'F1': 0.09303088549264646}, 'Test': {'precision': 0.1421997448514671, 'recall': 0.08087868584041356, 'F1': 0.09801490654415752}}
=====Epoch 11
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.5917793906529742
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1444427400622343, 'recall': 0.07287710499898781, 'F1': 0.09195698556701318}, 'Test': {'precision': 0.14493347913249496, 'recall': 0.07523276517945737, 'F1': 0.09414181716669415}}
=====Epoch 12
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.5797641641018436
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.13886575798746548, 'recall': 0.07606002138477923, 'F1': 0.09300893722377704}, 'Test': {'precision': 0.14469427738290505, 'recall': 0.07977597119887989, 'F1': 0.09770504716513466}}
=====Epoch 13
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.563658599479526
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.14288323033995118, 'recall': 0.08008685044776555, 'F1': 0.09725931862277569}, 'Test': {'precision': 0.15053004070226594, 'recall': 0.08648631252513155, 'F1': 0.10435860663444042}}
=====Epoch 14
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.553486994014067
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.13682780383047727, 'recall': 0.07358111920170911, 'F1': 0.09045229434193802}, 'Test': {'precision': 0.14106068890103882, 'recall': 0.07916055554109136, 'F1': 0.09603142058718876}}
=====Epoch 15
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.5397970588160494
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1452754525134768, 'recall': 0.07790398901960456, 'F1': 0.09619731458930302}, 'Test': {'precision': 0.1499832938460604, 'recall': 0.08289448942811435, 'F1': 0.10139565090357872}}
=====Epoch 16
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.5302603231691845
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.13917254678529167, 'recall': 0.07753302501406985, 'F1': 0.09431699839509806}, 'Test': {'precision': 0.14946692181519955, 'recall': 0.08414435925029144, 'F1': 0.10222622157422596}}
=====Epoch 17
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.518110799434138
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.14615564418343047, 'recall': 0.08240178410555761, 'F1': 0.09981338753964428}, 'Test': {'precision': 0.15136534839924667, 'recall': 0.0866935122265904, 'F1': 0.10467379146985488}}
=====Epoch 18
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.5102627125908348
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.14376707425749805, 'recall': 0.08400689474778969, 'F1': 0.10049611435639398}, 'Test': {'precision': 0.15201081343782272, 'recall': 0.09241656324761846, 'F1': 0.10907210234028165}}
=====Epoch 19
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.500724945629344
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.147291493184906, 'recall': 0.0840935051907134, 'F1': 0.10138183973482216}, 'Test': {'precision': 0.15311949456290624, 'recall': 0.09084699838868783, 'F1': 0.10808893433278342}}
=====Epoch 20
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.494167191505432
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.14823742531153672, 'recall': 0.08382949986469292, 'F1': 0.1015854928551602}, 'Test': {'precision': 0.1553520442257457, 'recall': 0.09089294031202177, 'F1': 0.10881021809124597}}
=====Epoch 21
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.483883600571576
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.14627251610641187, 'recall': 0.08483539841847243, 'F1': 0.10178198693779182}, 'Test': {'precision': 0.15421298827531743, 'recall': 0.0924456905354937, 'F1': 0.10965197222031557}}
=====Epoch 22
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.4755208687501797
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1556660969160421, 'recall': 0.0873857107378166, 'F1': 0.10592655750651982}, 'Test': {'precision': 0.15954377012332177, 'recall': 0.09234247037008109, 'F1': 0.11095848747571001}}
=====Epoch 23
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.468990633945839
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15020233451666154, 'recall': 0.08575712136105701, 'F1': 0.10330643508318049}, 'Test': {'precision': 0.15906916347731, 'recall': 0.0926842413959691, 'F1': 0.11111226824731472}}
=====Epoch 24
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.462267612475975
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15034111992520197, 'recall': 0.0864520223364506, 'F1': 0.10372219683040562}, 'Test': {'precision': 0.15881477431504767, 'recall': 0.09389918016899974, 'F1': 0.11178264787477472}}
=====Epoch 25
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.456142645106596
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15533739463265692, 'recall': 0.08727283889553249, 'F1': 0.10571218480233688}, 'Test': {'precision': 0.16178391349249743, 'recall': 0.09469163766594056, 'F1': 0.11328649101339096}}
=====Epoch 26
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.4505090857674094
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15500869234927173, 'recall': 0.09114870839133452, 'F1': 0.10873165757462554}, 'Test': {'precision': 0.16088405929165905, 'recall': 0.09749859697235312, 'F1': 0.11506700804814528}}
=====Epoch 27
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.4430982690885954
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15512556427225316, 'recall': 0.0878459287178663, 'F1': 0.10614631846169734}, 'Test': {'precision': 0.16396330721098354, 'recall': 0.09460810689624248, 'F1': 0.1139203583913808}}
=====Epoch 28
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.4375856330161003
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.1436940293056347, 'recall': 0.08608495406168193, 'F1': 0.10189259470278009}, 'Test': {'precision': 0.14904167426037299, 'recall': 0.09280449315987861, 'F1': 0.10848149278985801}}
=====Epoch 29
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.4341017691668343
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15012928956479818, 'recall': 0.09080724063302842, 'F1': 0.1072342708447367}, 'Test': {'precision': 0.15759598444808942, 'recall': 0.09757757151824947, 'F1': 0.114298922732493}}
=====Epoch 30
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

Average training loss: 2.4275282937779146
Evaluating...


Iteration:   0%|          | 0/357 [00:00<?, ?it/s]

Iteration:   0%|          | 0/343 [00:00<?, ?it/s]

{'Validation': {'precision': 0.15714890943886867, 'recall': 0.09389248548318538, 'F1': 0.11149807543943496}, 'Test': {'precision': 0.16388357329445352, 'recall': 0.1018130154680182, 'F1': 0.11908899652885437}}
=====Epoch 31
Training...


Iteration:   0%|          | 0/6375 [00:00<?, ?it/s]

# Some random exploration

In [None]:
import networkx as nx
import torch_geometric

In [None]:
# Print `len(train_loader)`.
# NOTE(review): `train_loader` comes from an earlier cell; presumably this is
# the number of mini-batches per training epoch — confirm.
train_batch_count = len(train_loader)
print(train_batch_count)

In [None]:
# Print `len(valid_loader)`.
# NOTE(review): `valid_loader` comes from an earlier cell; presumably this is
# the number of mini-batches in one validation pass — confirm.
valid_batch_count = len(valid_loader)
print(valid_batch_count)

In [None]:
# Print `len(test_loader)`.
# NOTE(review): `test_loader` comes from an earlier cell; presumably this is
# the number of mini-batches in one test pass — confirm.
test_batch_count = len(test_loader)
print(test_batch_count)

In [None]:
# Show the first entry of `dataset`; the bare name on the last line is what
# the cell displays.
first_sample = dataset[0]
first_sample

In [None]:
# Convert the first dataset entry to a NetworkX graph and draw it with node
# labels. `G` stays a notebook-level name because the next cell reads it.
first_entry = dataset[0]
G = torch_geometric.utils.to_networkx(first_entry)
nx.draw(G, with_labels=True)



In [None]:
# Print the edge and node counts of `G`, then display the edges-per-node
# ratio (the bare last expression is the cell's output).
sample_edge_count = G.number_of_edges()
sample_node_count = G.number_of_nodes()
print(sample_edge_count, sample_node_count)
sample_edge_count / sample_node_count

In [None]:
# NOTE(review): this cell evaluates the same `dataset[0]` expression as an
# earlier cell in this section; it is a candidate for removal.
graph_zero = dataset[0]
graph_zero

In [None]:
# For the first dataset entry only, walk `d.x` and `d.node_is_attributed` in
# lockstep and print each (x entry, node_is_attributed entry) tuple on its
# own line.
for d in dataset[:1]:
    node_rows = zip(d.x, d.node_is_attributed)
    for pair in node_rows:
        print(pair)
