In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from src.utils import *
import tensorflow as tf
import numpy as np
import time

from graph_nets.demos import models

  from ._conv import register_converters as _register_converters


In [2]:
def create_loss_ops(target_op, output_ops, use_nodes=False, use_edges=True):
    """Build one softmax cross-entropy loss per processing step.

    Args:
      target_op: graph tuple with one-hot targets; its ``.nodes`` / ``.edges``
        are read only when the matching flag is set.
      output_ops: sequence of model outputs, one per message-passing step.
      use_nodes: include the node-label loss. Defaults to False, preserving
        the previous edge-only behaviour.
      use_edges: include the edge-label loss. Defaults to True.

    Returns:
      A list with one loss per element of ``output_ops``. When both flags are
      set, each element is the node loss plus the edge loss.

    Raises:
      ValueError: if neither ``use_nodes`` nor ``use_edges`` is set.
    """
    if not use_nodes and not use_edges:
        raise ValueError("Nodes or edges (or both) must be used")

    loss_ops = []
    for output_op in output_ops:
        terms = []
        if use_nodes:
            terms.append(
                tf.losses.softmax_cross_entropy(target_op.nodes, output_op.nodes))
        if use_edges:
            terms.append(
                tf.losses.softmax_cross_entropy(target_op.edges, output_op.edges))
        # Start from the first term rather than 0 so the single-term case adds
        # no extra op to the graph.
        loss = terms[0]
        for term in terms[1:]:
            loss = loss + term
        loss_ops.append(loss)
    return loss_ops

def compute_accuracy(target, output, use_nodes=True, use_edges=False):
    """Score predicted node/edge labels against one-hot targets.

    ``target`` and ``output`` are split into per-graph dicts with
    ``utils_np.graphs_tuple_to_data_dicts``. The predicted/target label of each
    element is the argmax over the last axis of its feature array.

    Args:
      target: batch of target graphs.
      output: batch of predicted graphs, aligned graph-by-graph with ``target``.
      use_nodes: score node labels.
      use_edges: score edge labels.

    Returns:
      ``(correct, solved)``: the fraction of all scored elements whose label
      matches, and the fraction of graphs in which every scored element matches.

    Raises:
      ValueError: if neither flag is set, or if the batch has no graphs.
    """
    if not use_nodes and not use_edges:
        raise ValueError("Nodes or edges (or both) must be used")
    # Only the requested fields are read, so graphs whose unused features are
    # None (or otherwise not argmax-able) no longer crash the scoring.
    fields = [name for name, wanted in (("nodes", use_nodes), ("edges", use_edges))
              if wanted]

    tdds = utils_np.graphs_tuple_to_data_dicts(target)
    odds = utils_np.graphs_tuple_to_data_dicts(output)
    cs = []
    ss = []
    for td, od in zip(tdds, odds):
        c = np.concatenate(
            [np.argmax(td[f], axis=-1) == np.argmax(od[f], axis=-1) for f in fields],
            axis=0)
        cs.append(c)
        # A graph counts as solved only if every scored element is correct.
        ss.append(np.all(c))
    if not cs:
        raise ValueError("compute_accuracy needs at least one graph")
    correct = np.mean(np.concatenate(cs, axis=0))
    solved = np.mean(np.stack(ss))
    return correct, solved

In [3]:
# Model Setup
tf.reset_default_graph()

seed = 2
# NumPy RNG; not referenced by the cells visible here -- presumably consumed
# elsewhere (TODO confirm it is actually wired into data generation).
rand = np.random.RandomState(seed=seed)
# Seed TensorFlow's graph-level RNG as well. Without it, variable
# initialisation differs between runs despite the fixed `seed`. This must come
# after reset_default_graph(), because the seed is stored on the default graph.
tf.set_random_seed(seed)

# Model parameters.
# Number of processing (message-passing) steps.
num_processing_steps_tr = 10
num_processing_steps_ge = 10

# Data / training parameters.
num_training_iterations = 150000  # previously 10000
batch_size_tr = 400  # previously 32
batch_size_ge = 100  # previously 100

# Data.
# Input and target placeholders.
# NOTE(review): these placeholders are built from batch_size_tr but are also
# fed batch_size_ge graphs at evaluation time. That presumes create_placeholders
# produces a dynamic number of graphs -- confirm.
input_ph, target_ph = create_placeholders(batch_size_tr)

# Connect the data to the model.
# Instantiate the model with 2 output classes per edge and per node.
model = models.EncodeProcessDecode(edge_output_size=2, node_output_size=2)
# A list of outputs, one per processing step. The same `model` object builds
# both the training and the evaluation graph (presumably sharing weights -- confirm).
output_ops_tr = model(input_ph, num_processing_steps_tr)
output_ops_ge = model(input_ph, num_processing_steps_ge)

# Training loss: one loss per processing step, averaged across steps.
loss_ops_tr = create_loss_ops(target_ph, output_ops_tr)
loss_op_tr = sum(loss_ops_tr) / num_processing_steps_tr
# Test/generalization loss: only the final processing step counts.
loss_ops_ge = create_loss_ops(target_ph, output_ops_ge)
loss_op_ge = loss_ops_ge[-1]

# Optimizer.
learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)

# Lets an iterable of TF graphs be output from a session as NP graphs.
input_ph, target_ph = make_all_runnable_in_session(input_ph, target_ph)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


In [4]:
# Reset Session
#@title Reset session  { form-width: "30%" }

# Start a fresh TensorFlow session on the existing computational graph.
# A session left over from an earlier run of this cell is closed first; then
# every variable is (re)initialised and the metric history is cleared.

if "sess" in globals():
    sess.close()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Training-loop bookkeeping, emptied so a new run starts from iteration 0.
last_iteration = 0
logged_iterations, losses_tr, corrects_tr, solveds_tr = [], [], [], []
losses_ge, corrects_ge, solveds_ge = [], [], []

In [None]:
#@title Run training  { form-width: "30%" }

# You can interrupt this cell's training loop at any time, and visualize the
# intermediate results by running the next cell (below). You can then resume
# training by simply executing this cell again.
#
# Each iteration samples a fresh training batch and takes one optimiser step.
# Roughly every `log_every_seconds`, a fresh evaluation batch is also run, and
# the losses/accuracies for both batches are appended to the history lists
# created in the reset-session cell.

# How much time between logging and printing the current results.
log_every_seconds = 20

print("# (iteration number), T (elapsed seconds), "
      "Ltr (training loss), Lge (test/generalization loss), "
      "Ctr (training fraction nodes/edges labeled correctly), "
      "Str (training fraction examples solved correctly), "
      "Cge (test/generalization fraction nodes/edges labeled correctly), "
      "Sge (test/generalization fraction examples solved correctly)")

start_time = time.time()
last_log_time = start_time
for iteration in range(last_iteration, num_training_iterations):
    feed_dict, _ = create_feed_dict(batch_size_tr, input_ph, target_ph)
    train_values = sess.run({
        "step": step_op,
        "target": target_ph,
        "loss": loss_op_tr,
        "outputs": output_ops_tr
    }, feed_dict=feed_dict)
    # The optimiser step for `iteration` has been applied, so a resumed run
    # must continue with the next iteration rather than repeat this one.
    last_iteration = iteration + 1
    the_time = time.time()
    elapsed_since_last_log = the_time - last_log_time
    if elapsed_since_last_log > log_every_seconds:
        last_log_time = the_time
        # `raw_graphs` stays at module level, presumably for the visualisation
        # cell that follows -- confirm before dropping it.
        feed_dict, raw_graphs = create_feed_dict(batch_size_ge, input_ph, target_ph)
        test_values = sess.run({
            "target": target_ph,
            "loss": loss_op_ge,
            "outputs": output_ops_ge
        }, feed_dict=feed_dict)
        # Accuracy is scored on edge labels only, using the final processing step.
        correct_tr, solved_tr = compute_accuracy(
            train_values["target"], train_values["outputs"][-1],
            use_nodes=False, use_edges=True)
        correct_ge, solved_ge = compute_accuracy(
            test_values["target"], test_values["outputs"][-1],
            use_nodes=False, use_edges=True)
        elapsed = time.time() - start_time
        losses_tr.append(train_values["loss"])
        corrects_tr.append(correct_tr)
        solveds_tr.append(solved_tr)
        losses_ge.append(test_values["loss"])
        corrects_ge.append(correct_ge)
        solveds_ge.append(solved_ge)
        logged_iterations.append(iteration)
        print("# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, Str"
              " {:.4f}, Cge {:.4f}, Sge {:.4f}".format(
                  iteration, elapsed, train_values["loss"], test_values["loss"],
                  correct_tr, solved_tr, correct_ge, solved_ge))

# (iteration number), T (elapsed seconds), Ltr (training loss), Lge (test/generalization loss), Ctr (training fraction nodes/edges labeled correctly), Str (training fraction examples solved correctly), Cge (test/generalization fraction nodes/edges labeled correctly), Sge (test/generalization fraction examples solved correctly)
# 00009, T 22.0, Ltr 0.6274, Lge 0.6205, Ctr 0.6819, Str 0.0000, Cge 0.6936, Sge 0.0000
# 00023, T 42.3, Ltr 0.6243, Lge 0.6270, Ctr 0.6815, Str 0.0000, Cge 0.6781, Sge 0.0000
# 00038, T 63.5, Ltr 0.6223, Lge 0.6251, Ctr 0.6801, Str 0.0000, Cge 0.6750, Sge 0.0000
# 00051, T 83.8, Ltr 0.6189, Lge 0.6187, Ctr 0.6782, Str 0.0000, Cge 0.6759, Sge 0.0000
# 00065, T 104.0, Ltr 0.6019, Lge 0.5947, Ctr 0.6841, Str 0.0000, Cge 0.6850, Sge 0.0000
# 00080, T 124.9, Ltr 0.5455, Lge 0.5712, Ctr 0.7538, Str 0.0000, Cge 0.7380, Sge 0.0000
# 00094, T 145.0, Ltr 0.4658, Lge 0.4483, Ctr 0.8108, Str 0.0025, Cge 0.7991, Sge 0.0000
# 00109, T 166.3, Ltr 0.4375, Lge 0.4320, Ctr 0.8138

# 01265, T 1839.2, Ltr 0.3735, Lge 0.3491, Ctr 0.8455, Str 0.0200, Cge 0.8542, Sge 0.0100
# 01279, T 1859.3, Ltr 0.3898, Lge 0.3747, Ctr 0.8378, Str 0.0125, Cge 0.8333, Sge 0.0100
# 01294, T 1880.7, Ltr 0.3843, Lge 0.3543, Ctr 0.8386, Str 0.0125, Cge 0.8474, Sge 0.0500
# 01308, T 1900.7, Ltr 0.3877, Lge 0.3499, Ctr 0.8365, Str 0.0200, Cge 0.8549, Sge 0.0400
# 01322, T 1921.8, Ltr 0.3740, Lge 0.3464, Ctr 0.8447, Str 0.0175, Cge 0.8509, Sge 0.0100
# 01336, T 1942.0, Ltr 0.3782, Lge 0.3557, Ctr 0.8452, Str 0.0300, Cge 0.8488, Sge 0.0400
# 01351, T 1962.9, Ltr 0.3769, Lge 0.3630, Ctr 0.8399, Str 0.0150, Cge 0.8358, Sge 0.0200
# 01365, T 1983.0, Ltr 0.3785, Lge 0.3718, Ctr 0.8417, Str 0.0125, Cge 0.8391, Sge 0.0000
# 01379, T 2003.4, Ltr 0.3717, Lge 0.3455, Ctr 0.8471, Str 0.0175, Cge 0.8475, Sge 0.0000
# 01393, T 2023.5, Ltr 0.3803, Lge 0.3605, Ctr 0.8396, Str 0.0150, Cge 0.8418, Sge 0.0100
# 01408, T 2044.7, Ltr 0.3787, Lge 0.3639, Ctr 0.8381, Str 0.0150, Cge 0.8356, Sge 0.0200
# 01422, T

# 02583, T 3731.7, Ltr 0.3505, Lge 0.3522, Ctr 0.8667, Str 0.0800, Cge 0.8474, Sge 0.0300
# 02597, T 3751.9, Ltr 0.3643, Lge 0.3285, Ctr 0.8530, Str 0.0650, Cge 0.8562, Sge 0.0600
# 02611, T 3771.9, Ltr 0.3629, Lge 0.3368, Ctr 0.8502, Str 0.0575, Cge 0.8549, Sge 0.0400
# 02626, T 3793.3, Ltr 0.3561, Lge 0.2953, Ctr 0.8574, Str 0.0750, Cge 0.8786, Sge 0.0800
# 02641, T 3814.5, Ltr 0.3477, Lge 0.3301, Ctr 0.8614, Str 0.0700, Cge 0.8573, Sge 0.0700
# 02655, T 3834.6, Ltr 0.3544, Lge 0.3246, Ctr 0.8596, Str 0.0750, Cge 0.8594, Sge 0.0800
# 02669, T 3854.7, Ltr 0.3622, Lge 0.3412, Ctr 0.8530, Str 0.0875, Cge 0.8611, Sge 0.1000
# 02683, T 3874.9, Ltr 0.3468, Lge 0.3494, Ctr 0.8641, Str 0.0825, Cge 0.8412, Sge 0.0400
# 02698, T 3895.9, Ltr 0.3505, Lge 0.3259, Ctr 0.8572, Str 0.0800, Cge 0.8526, Sge 0.0700
# 02712, T 3916.2, Ltr 0.3570, Lge 0.3122, Ctr 0.8551, Str 0.0625, Cge 0.8692, Sge 0.0600
# 02726, T 3936.4, Ltr 0.3629, Lge 0.3248, Ctr 0.8517, Str 0.0750, Cge 0.8591, Sge 0.0300
# 02740, T