In [24]:
import json
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_agraph import graphviz_layout
import ipywidgets as widgets
from IPython.display import display

# Load the combined JSON file that holds every DAG snapshot. JSON text is
# UTF-8, so decode explicitly rather than relying on the platform default
# encoding (which differs on e.g. Windows).
with open('graph_viz_test.json', 'r', encoding='utf-8') as f:
    graph_data = json.load(f)

# The top-level "graphs" entry is the sequence of snapshots; each snapshot is
# read below for its "nodes" and "edges" lists.
graph_snapshots = graph_data['graphs']

def _build_snapshot_graph(snapshot):
    """Build a DiGraph from one snapshot dict's "nodes" and "edges" lists."""
    G = nx.DiGraph()

    # Nodes are keyed by their "index"; N, Q and state are kept as attributes.
    for node in snapshot['nodes']:
        G.add_node(
            node['index'],
            N=node['N'],
            Q=node['Q'],
            state=node['state'],
        )

    # Edges reference node indices via "from"/"to".
    for edge in snapshot['edges']:
        G.add_edge(
            edge['from'],
            edge['to'],
            E=edge['E'],
            action=edge['action'],
            index=edge['index'],
        )
    return G


def _format_node_label(data):
    """Return the multi-line label text (state, N, Q) for one node's attributes."""
    # Join the formatted numbers directly so the label reads "[0.50, 0.25]"
    # instead of the repr of a list of strings ("['0.50', '0.25']").
    q_text = ", ".join(f"{float(q):.2f}" for q in data['Q'])
    return f"{data['state']}\nN: {data['N']}\nQ: [{q_text}]"


def plot_dag_snapshot(snapshot_index, snapshots=None):
    """Draw one DAG snapshot as a hierarchical graph in a new figure.

    Each node is labelled with its ``state``, ``N`` and ``Q`` values (``Q``
    entries formatted to two decimals); each edge with its ``action`` (as
    "A") and ``E`` values.

    Args:
        snapshot_index: Zero-based position of the snapshot to draw. The plot
            title shows it one-based.
        snapshots: Sequence of snapshot dicts to index into. Defaults to the
            module-level ``graph_snapshots``.

    Raises:
        KeyError: If a node or edge dict lacks one of the fields read here.
    """
    if snapshots is None:
        snapshots = graph_snapshots
    G = _build_snapshot_graph(snapshots[snapshot_index])

    # Graphviz "dot" produces a ranked, top-down placement suited to DAGs.
    pos = graphviz_layout(G, prog="dot")

    node_labels = {node: _format_node_label(data) for node, data in G.nodes(data=True)}
    edge_labels = {
        (u, v): f"A: {data['action']}, E: {data['E']}"
        for u, v, data in G.edges(data=True)
    }

    # A fresh figure is created on every call; nothing is cleared or reused.
    fig, ax = plt.subplots(figsize=(12, 8))
    nx.draw(G, pos, ax=ax, with_labels=False, node_size=3000, node_color="lightblue")
    nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10, font_color="black", ax=ax)
    # label_pos=0.7 places each edge label closer to the edge's source node.
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, label_pos=0.7, font_size=8, ax=ax)

    ax.set_title(f"DAG Snapshot {snapshot_index + 1}")
    plt.show()

# With no snapshots the slider's max would be -1, below its min of 0, which
# IntSlider rejects with an unhelpful trait error. Fail early with a clear one.
if not graph_snapshots:
    raise ValueError("No graph snapshots to display: 'graphs' is empty")

# Slider over the zero-based snapshot indices.
snapshot_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(graph_snapshots) - 1,
    step=1,
    description='Snapshot:',
    continuous_update=False,  # re-render on release, not while dragging
)

# Re-plot the selected snapshot whenever the slider value changes. The lambda
# exposes only `snapshot_index`, because interactive() builds controls from the
# wrapped callable's parameters. As the cell's last expression, the returned
# widget is displayed by Jupyter.
widgets.interactive(lambda snapshot_index: plot_dag_snapshot(snapshot_index), snapshot_index=snapshot_slider)



interactive(children=(IntSlider(value=0, continuous_update=False, description='Snapshot:', max=19), Output()),…