In [1]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Create a random graph with 100 nodes and 1000 edges.
# The seed is fixed so that Restart & Run All reproduces the same graph.
G = nx.gnm_random_graph(100, 1000, seed=123)

In [7]:
# Pick one node of G uniformly at random (nodes materialised as a list first).
np.random.choice(list(G.nodes()))

72

In [None]:
# Build a Transformer model using PyTorch's default hyperparameters.
tr = torch.nn.Transformer()

In [None]:
def train_transformer_epoch(epoch, args, rnn, output, data_loader,
                    optimizer_rnn, optimizer_output,
                    scheduler_rnn, scheduler_output):
    """Train ``rnn`` and ``output`` for one pass over ``data_loader``.

    For every batch the function:
      1. runs ``rnn`` on ``x`` and packs the result with ``y_len``;
      2. reverses the packed rows and installs them as the first layer of
         ``output.hidden`` (the other ``args.num_layers - 1`` layers are zeros);
      3. runs ``output`` on ``output_x``, applies a sigmoid, and scores the
         prediction against ``output_y`` with ``binary_cross_entropy_weight``;
      4. back-propagates and steps both optimizers and both schedulers.

    Args:
        epoch: Current epoch number; only used to decide whether to log.
        args: Configuration object. Reads ``num_layers``, ``epochs_log``,
            ``epochs``, ``graph_type`` and ``hidden_size_rnn``.
        rnn: Model exposing ``train()``, ``zero_grad()``,
            ``init_hidden(batch_size=...)``, a writable ``hidden`` attribute,
            and ``__call__(x, pack=True, input_len=...)``.
        output: Second model with the same calling convention as ``rnn``.
        data_loader: Iterable of batches; each batch is passed through
            ``reformat_data`` and must then provide the keys ``'x'``, ``'y'``,
            ``'y_len'``, ``'output_x'``, ``'output_y'``, ``'output_y_len'``.
        optimizer_rnn, optimizer_output: Optimizers for the two models.
        scheduler_rnn, scheduler_output: LR schedulers for the two optimizers.

    Returns:
        The per-batch loss scaled by ``y.size(1) * y.size(2)``, averaged over
        the batches seen. Returns ``0.0`` when ``data_loader`` yields no
        batches (previously this raised ``UnboundLocalError``).
    """
    rnn.train()
    output.train()
    loss_sum = 0
    num_batches = 0
    for batch_idx, data in enumerate(data_loader):
        num_batches = batch_idx + 1
        data = reformat_data(data)
        rnn.zero_grad()
        output.zero_grad()

        y_len = data['y_len']
        output_y_len = data['output_y_len']
        x = data['x'].to(device)
        y = data['y'].to(device)
        output_x = data['output_x'].to(device)
        output_y = data['output_y'].to(device)

        rnn.hidden = rnn.init_hidden(batch_size=x.size(0))

        # Teacher forcing: feed the ground-truth sequence through the first model.
        h = rnn(x, pack=True, input_len=y_len)

        # Keep only the valid (non-padded) steps as a flat (rows, features) tensor.
        h = pack_padded_sequence(h, y_len, batch_first=True).data
        # Reverse the row order of the packed tensor.
        # NOTE(review): presumably this matches the ordering of
        # output_x / output_y produced upstream -- confirm against the sampler.
        h = torch.flip(h, dims=[0])
        # Layout of output.hidden: (num_layers, batch_size, hidden_size).
        # Layer 0 comes from h; the remaining layers start at zero.
        hidden_null = torch.zeros(
            args.num_layers - 1, h.size(0), h.size(1), device=device)
        output.hidden = torch.cat(
            (h.view(1, h.size(0), h.size(1)), hidden_null), dim=0)

        y_pred = output(output_x, pack=True, input_len=output_y_len)
        y_pred = torch.sigmoid(y_pred)

        # Round-trip prediction and target through pack/pad so both are padded
        # consistently according to output_y_len before computing the loss.
        y_pred = pack_padded_sequence(y_pred, output_y_len, batch_first=True)
        y_pred = pad_packed_sequence(y_pred, batch_first=True)[0]
        output_y = pack_padded_sequence(
            output_y, output_y_len, batch_first=True)
        output_y = pad_packed_sequence(output_y, batch_first=True)[0]

        loss = binary_cross_entropy_weight(y_pred, output_y)
        loss.backward()

        optimizer_output.step()
        optimizer_rnn.step()
        # NOTE(review): schedulers are stepped once per batch, so the
        # MultiStepLR milestones count batches rather than epochs -- confirm
        # this is intended.
        scheduler_output.step()
        scheduler_rnn.step()

        # Only report the first batch of every ``epochs_log``-th epoch.
        if epoch % args.epochs_log == 0 and batch_idx == 0:
            print('Epoch: {}/{}, train loss: {:.6f}, graph type: {}, num_layer: {}, hidden: {}'.format(
                epoch, args.epochs, loss.item(), args.graph_type, args.num_layers, args.hidden_size_rnn))

        feature_dim = y.size(1) * y.size(2)
        loss_sum += loss.detach() * feature_dim

    if num_batches == 0:
        # Nothing was trained on; avoid referencing the unbound loop variable.
        return 0.0
    return loss_sum / num_batches

def train_transformer_Dev():
    """Development driver: train two Transformer models on the ENZYMES graphs.

    The function seeds Python's ``random`` module, loads and shuffles the
    dataset, splits it 60/20/20, wraps the training split in
    ``Graph_sequence_sampler_pytorch``, and trains ``tr_first`` / ``tr_out``
    with ``train_transformer_epoch`` for ``args.epochs`` epochs. Every
    ``args.epochs_test`` epochs (from ``args.epochs_test_start`` on) it
    generates graphs with ``test_rnn_epoch`` and saves them with
    ``save_graph_list``.

    Relies on names defined outside this function (``Args``, ``device``,
    ``Transformer``, ``optim``, ``MultiStepLR``, dataset helpers, ...).

    Returns:
        None.
    """
    # Function-scope import so the per-epoch timer does not depend on a
    # module-level ``tm`` alias being present in the notebook.
    import time as tm

    # set seed and basic args
    random.seed(123)
    args = Args()
    create_save_path(args)
    args.max_prev_node = 3

    # load dataset, shuffle
    graphs = load_graph_dataset(min_num_nodes=10, name='ENZYMES')
    shuffle(graphs)

    # get graph statistics
    args.max_num_node = max(graph.number_of_nodes() for graph in graphs)
    edge_counts = [graph.number_of_edges() for graph in graphs]
    max_num_edge = max(edge_counts)
    min_num_edge = min(edge_counts)

    # 60-20-20 split; only the training part is used below
    train, _valid, _test = split_dataset(graphs, len(graphs), 0.6, 0.2, 0.2)

    # print dataset split and graph statistics
    print('total graph num: {}, training set: {}'.format(
        len(graphs), len(train)))
    print('max number node: {}'.format(args.max_num_node))
    print('max/min number edge: {}; {}'.format(max_num_edge, min_num_edge))
    print('max previous node: {}'.format(args.max_prev_node))

    # sample permutations of bfs-order graph adjacency matrices
    train_sampled = Graph_sequence_sampler_pytorch(
        train,
        max_prev_node=args.max_prev_node,
        max_num_node=args.max_num_node)

    # draw batch_size * batch_ratio items per epoch, uniformly, with replacement
    num_items = len(train_sampled)
    sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
        [1.0 / num_items] * num_items,
        num_samples=args.batch_size * args.batch_ratio,
        replacement=True)

    # create data loader
    dataset_loader = torch.utils.data.DataLoader(
        train_sampled,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        sampler=sample_strategy)

    # initialize transformers with defaults
    tr_first = Transformer(batch_first=True, device=device)
    tr_out = Transformer(batch_first=True, device=device)

    # initialize optimizers
    optimizer_tr_first = optim.Adam(list(tr_first.parameters()), lr=args.lr)
    optimizer_tr_out = optim.Adam(list(tr_out.parameters()), lr=args.lr)

    # initialize schedulers
    scheduler_tr_first = MultiStepLR(
        optimizer_tr_first, milestones=args.milestones, gamma=args.lr_rate)
    scheduler_tr_out = MultiStepLR(
        optimizer_tr_out, milestones=args.milestones, gamma=args.lr_rate)

    # per-epoch wall-clock training time
    # NOTE(review): time_all is filled in but never saved or returned.
    time_all = np.zeros(args.epochs)

    for epoch in range(1, args.epochs + 1):
        time_start = tm.time()

        # Train the models built above. The previous code passed undefined
        # names (rnn, output, optimizer_rnn, ...) to train_rnn_epoch.
        # NOTE(review): train_transformer_epoch calls init_hidden(...) and
        # passes pack=/input_len= to the models -- confirm that `Transformer`
        # provides that interface.
        train_transformer_epoch(epoch, args, tr_first, tr_out, dataset_loader,
                                optimizer_tr_first, optimizer_tr_out,
                                scheduler_tr_first, scheduler_tr_out)

        time_all[epoch - 1] = tm.time() - time_start

        # test
        if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
            for sample_time in range(1, 4):
                G_pred = []
                while len(G_pred) < args.test_total_size:
                    G_pred_step = test_rnn_epoch(
                        epoch, args, tr_first, tr_out,
                        test_batch_size=args.test_batch_size)
                    G_pred.extend(G_pred_step)
                # save graphs
                fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
                save_graph_list(G_pred, fname)
                if 'GraphRNN_RNN' in args.note:
                    break

            print('test done, graphs saved')