In [None]:
import os
import time

import tensorflow as tf
from graph_nets import utils_tf
from lib.gns import obses_to_cgraphs
from lib.data_utils import extract_target_windows
import numpy as np

from experience import load_experience
from lib.action_space import is_do_nothing_action
from lib.constants import Constants as Const
from lib.data_utils import make_dir, env_pf
from lib.gns import (
    tf_batched_graph_dataset,
    get_graph_feature_dimensions,
    GraphNetworkSwitching,
)
from lib.run_utils import create_logger
from lib.visualizer import Visualizer, pprint

# Global setup: Visualizer is instantiated only for its side effects; NumPy gets a fixed seed.
Visualizer()
np.random.seed(0)

# Input experience and output results locations.
experience_dir = make_dir(os.path.join(Const.EXPERIENCE_DIR, "data-aug"))
results_dir = make_dir(os.path.join(Const.RESULTS_DIR, "binary-linear"))

# Experiment selection.
agent_name = "agent-mip"
case_name = "l2rpn_2019_art"
env_dc = True
verbose = False

# The per-case results directory and the logger share the same "<case>-<env suffix>" tag.
case_tag = f"{case_name}-{env_pf(env_dc)}"
case_results_dir = make_dir(os.path.join(results_dir, case_tag))
create_logger(logger_name=case_tag, save_dir=case_results_dir)

# Load the collected agent experience and aggregate it into aligned sequences.
case, collector = load_experience(case_name, agent_name, experience_dir, env_dc=env_dc)
obses, actions, rewards, dones = collector.aggregate_data()

pprint("    - Number of chronics:", dones.sum())
pprint("    - Observations:", len(obses))

In [None]:
"""
    Parameters
"""
n_window_targets = 0
n_window_history = 1
n_batch = 32
n_epochs = 20

"""
    Datasets
"""

labels = is_do_nothing_action(actions, case.env).astype(float)
pprint("    - Labels:", f"{int(labels.sum())}/{labels.shape[0]}", "{:.2f} %".format(100 * labels.mean()))

mask_positive = extract_target_windows(labels, mask=~dones, n_window=n_window_targets)
mask_negative = np.logical_and(np.random.binomial(1, 0.05, len(labels)), ~mask_positive)
mask_targets = np.logical_or(mask_positive, mask_negative)

pprint("    - Mask (0):", mask_negative.sum(), "{:.2f} %".format(100 * mask_negative.sum() / mask_targets.sum()))
pprint("    - Mask (1):", mask_positive.sum(), "{:.2f} %".format(100 * mask_positive.sum() / mask_targets.sum()))
pprint("    - Mask:", mask_targets.sum())

In [1]:
# Build graph inputs for the masked observations and collect the feature/size dimensions
# the model is constructed with (node/edge counts come from the environment topology).
cgraphs = obses_to_cgraphs(obses, dones, case, mask=mask_targets, n_window=n_window_history)
graph_dims = get_graph_feature_dimensions(cgraphs=cgraphs)
cgraph_dims = dict(graph_dims, n_nodes=case.env.n_sub, n_edges=2 * case.env.n_line)

pprint("    - Cgraphs:", len(cgraphs["globals"]))
for field_name in cgraphs:
    first_value = cgraphs[field_name][0]
    pprint(f"        - {field_name}:", first_value.shape)

for dim_name, dim_value in cgraph_dims.items():
    pprint(f"        - {dim_name}:", dim_value)

NameError: name 'obses_to_cgraphs' is not defined

In [None]:
# Pair batched graphs with the batched labels of the same (masked) samples.
labels_targets = labels[mask_targets]
graph_dataset = tf_batched_graph_dataset(cgraphs, n_batch=n_batch, **graph_dims)
label_dataset = tf.data.Dataset.from_tensor_slices(labels_targets).batch(n_batch)
dataset = tf.data.Dataset.zip((graph_dataset, label_dataset))

# Input signatures, derived from one sample batch; the number of graphs per batch stays dynamic.
first_graph_batch = next(iter(graph_dataset))
graphs_sig = utils_tf.specs_from_graphs_tuple(first_graph_batch, dynamic_num_graphs=True)
labels_sig = tf.TensorSpec(shape=[None], dtype=tf.float64)

In [None]:
"""
    Model
"""
tf.random.set_seed(0)
model = GraphNetworkSwitching(
    pos_class_weight=1 / labels.mean(),
    n_hidden=(512, 512, 512, 512, 512),
    graphs_signature=graphs_sig,
    labels_signature=labels_sig,
    **cgraph_dims,
)

model_dir = os.path.join(case_results_dir, "model-10")
checkpoint_path = os.path.join(model_dir, "ckpts")

ckpt = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=2)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    pprint(f"Restoring checkpoint from:", ckpt_manager.latest_checkpoint)

In [None]:
"""
    Training
"""
import time
n_epochs = 10

# Epoch Metrics
metric_recall_e = tf.keras.metrics.Recall()
metric_precision_e = tf.keras.metrics.Precision()
metric_accuracy_e = tf.keras.metrics.Accuracy()
metric_bce_e = tf.keras.metrics.BinaryCrossentropy(from_logits=False)
recall_e = []
precision_e = []
accuracy_e = []
bce_e = []

# Batch Metrics
metric_accuracy_b = tf.metrics.binary_accuracy
metric_bce_b = tf.metrics.binary_crossentropy
losses = []
accuracy = []

for epoch in range(n_epochs):
    start = time.time()

    # Reset epoch metrics
    metric_recall_e.reset_states()
    metric_precision_e.reset_states()
    metric_accuracy_e.reset_states()
    metric_bce_e.reset_states()

    for batch, (graph_batch, label_batch) in enumerate(dataset):
        (
            output_graphs,
            loss,
            probabilities,
            predicted_labels,
            gradients,
        ) = model.train_step(graph_batch, label_batch)

        # Batch Metric
        bce = metric_bce_b(label_batch, probabilities)  # Control
        acc = metric_accuracy_b(
            label_batch, tf.cast(predicted_labels, dtype=tf.float64)
        )

        # Epoch Metrics
        metric_recall_e(label_batch, predicted_labels)
        metric_precision_e(label_batch, predicted_labels)
        metric_accuracy_e(label_batch, predicted_labels)
        metric_bce_e(label_batch, probabilities)

        losses.append(loss.numpy())
        accuracy.append(acc.numpy())

        if batch % 5 == 0:
            pprint(
                "        - Batch/Epoch:",
                f"{batch}/{epoch}",
                "loss = {:.4f}".format(loss.numpy()),
                "bce = {:.4f}".format(bce.numpy()),
                "acc = {} %".format(int(100 * acc.numpy())),
            )

    recall_e.append(metric_recall_e.result().numpy())
    precision_e.append(metric_precision_e.result().numpy())
    accuracy_e.append(metric_accuracy_e.result().numpy())
    bce_e.append(metric_bce_e.result().numpy())

    if (epoch + 1) % 10 == 0:
        ckpt_save_path = ckpt_manager.save()
        pprint(f"            - Saving checkpoint to:", ckpt_save_path)
        pprint(f"            - Time taken for epoch:", f"{time.time() - start} secs")

ckpt_save_path = ckpt_manager.save()
pprint(f"    - Saving checkpoint to:", ckpt_save_path)

In [None]:
# Inspect the trained model's variables. print_variables is imported here so this cell also
# runs on a fresh kernel (elsewhere it is only imported by later cells).
from lib.tf_utils import print_variables

print_variables(model.trainable_variables)

In [None]:
import matplotlib.pyplot as plt
from lib.tf_utils import print_variables
import numpy as np

# Turn the metric histories collected during training into arrays (enables the `* 100` scaling).
losses = np.array(losses)
accuracy = np.array(accuracy)
accuracy_e = np.array(accuracy_e)
bce_e = np.array(bce_e)
recall_e = np.array(recall_e)
precision_e = np.array(precision_e)

make_dir(model_dir)


def _plot_metric(values, title, xlabel, ylabel, save_path):
    """Draw one metric history as a line plot and save it to ``save_path``.

    Args:
        values: 1-D sequence of metric values, one per batch or per epoch.
        title: figure title.
        xlabel: x-axis label.
        ylabel: y-axis label.
        save_path: destination passed to ``Figure.savefig``.

    Returns:
        The ``(fig, ax)`` pair. The figure is not closed, so the notebook can still display it.
    """
    fig, ax = plt.subplots(figsize=Const.FIG_SIZE)
    ax.plot(values, lw=1.0)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.savefig(save_path)
    return fig, ax


# Rows are (values, title, x label, y label, file name). Percent labels are raw strings because
# "\%" is an invalid escape sequence in a regular string literal (Deprecation/SyntaxWarning);
# the resulting text is unchanged.
metric_plots = [
    # Batch metrics
    (losses, "Batch training loss", "Batch", "Loss", "training-loss"),
    (accuracy * 100, "Batch accuracy", "Batch", r"Accuracy [\%]", "training-accuracy"),
    # Epoch metrics
    (accuracy_e * 100, "Epoch accuracy", "Epoch", r"Accuracy [\%]", "epoch-accuracy"),
    (bce_e, "Epoch cross-entropy", "Epoch", "Cross-entropy", "epoch-bce"),
    (recall_e * 100, "Epoch recall", "Epoch", r"Recall [\%]", "epoch-recall"),
    (precision_e * 100, "Epoch precision", "Epoch", r"Precision [\%]", "epoch-precision"),
]
for values, title, xlabel, ylabel, filename in metric_plots:
    _plot_metric(values, title, xlabel, ylabel, os.path.join(model_dir, filename))

In [None]:
import sonnet as snt
from lib.tf_utils import print_variables

# Nodes per graph = substations, matching cgraph_dims["n_nodes"] (was hardcoded to 14).
n_nodes = case.env.n_sub

class OutputModel(snt.Module):
    """Linear readout mapping a batch of processed graphs to one logit per graph.

    Edge and node features are max-pooled within each graph, concatenated, and passed through
    a single-output linear layer.

    Args:
        n_edge_features: number of features per edge.
        n_node_features: number of features per node.
        n_edges: edge count parameter; each graph is reshaped as holding ``2 * n_edges`` edges.
        n_nodes: number of nodes per graph.
    """

    def __init__(self, n_edge_features, n_node_features, n_edges, n_nodes):
        super().__init__(name="output_model")

        self.n_edge_features = n_edge_features
        self.n_node_features = n_node_features
        self.n_edges = n_edges
        self.n_nodes = n_nodes

        self.layer_concat = tf.keras.layers.Concatenate(
            axis=-1, name="layer_concat", dtype=tf.float64
        )

        self.layer_linear = snt.Linear(output_size=1, with_bias=True, name="layer_linear")

    def __call__(self, input_graphs):
        """Return a 1-D tensor with one logit per graph in ``input_graphs``.

        Assumes every graph in the batch has exactly ``2 * n_edges`` edges and ``n_nodes``
        nodes, stored graph after graph — the reshapes below rely on this fixed topology.
        """
        # Group edges per graph, then max-pool over the edge axis -> (n_graphs, n_edge_features)
        edges = tf.reshape(input_graphs.edges, shape=[-1, 2 * self.n_edges, self.n_edge_features])
        edges = tf.math.reduce_max(edges, axis=1)

        # Group nodes per graph, then max-pool over the node axis -> (n_graphs, n_node_features)
        nodes = tf.reshape(input_graphs.nodes, shape=[-1, self.n_nodes, self.n_node_features])
        nodes = tf.math.reduce_max(nodes, axis=1)

        x = self.layer_concat([nodes, edges])
        x = self.layer_linear(x)

        # (n_graphs, 1) -> (n_graphs,)
        return tf.reshape(x, shape=[-1])
        
# Sanity-check OutputModel on one batch of graphs processed by the trained graph network.
# n_nodes / n_edges were previously undefined or hardcoded here; derive them from the
# environment, consistently with cgraph_dims (which stores 2 * n_line directed edges).
n_nodes = case.env.n_sub
n_edges = case.env.n_line  # OutputModel expects 2 * n_edges edges per graph

output_model = OutputModel(
    n_edge_features=graph_dims["n_edge_features"],
    n_node_features=graph_dims["n_node_features"],
    n_edges=n_edges,
    n_nodes=n_nodes,
)

# NOTE(review): `g_stacked` was never defined in this notebook; it is assumed to be one batch
# from graph_dataset — confirm this is the intended input.
g_stacked = next(iter(graph_dataset))
processed_graphs = model.graph_network(g_stacked)  # evaluate the graph network once

print(tf.reshape(processed_graphs.edges, shape=[-1, 2 * n_edges, graph_dims["n_edge_features"]]).shape)
print(tf.reshape(processed_graphs.nodes, shape=[-1, n_nodes, graph_dims["n_node_features"]]).shape)

logits = output_model(processed_graphs)

print(output_model.submodules)
print_variables(output_model.trainable_variables)


In [None]:
"""
    Chronic datasets.
"""

dataset_by_chronic = dict()
for chronic_idx, chronic_len in zip(collector.chronic_ids, collector.chronic_lengths):
    dataset_by_chronic[chronic_idx] = dataset.take(chronic_len)
    print(tf.data.experimental.cardinality(dataset_by_chronic[chronic_idx]))
    dataset = dataset.skip(chronic_len)