In [12]:
import json
import copy
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_agraph import graphviz_layout
import ipywidgets as widgets
from IPython.display import display

# Load the snapshot files DAGs/search_thread1.json ... DAGs/search_thread50.json,
# in order, into graph_snapshots (one parsed JSON object per file).
NUM_SNAPSHOTS = 50
graph_snapshots = []
for i in range(1, NUM_SNAPSHOTS + 1):
    # JSON is UTF-8 by spec; pass it explicitly so decoding does not depend
    # on the platform's locale encoding.
    with open(f'DAGs/search_thread{i}.json', 'r', encoding='utf-8') as f:
        graph_snapshots.append(json.load(f))
        
# Plot one DAG snapshot with a hierarchical layout; edge labels sit nearer
# each edge's "from" node.
def _build_snapshot_graph(graph_data):
    """Build a directed graph from one snapshot's ``nodes`` and ``edges`` lists.

    Each node is keyed by its ``'index'`` and carries the ``N``, ``Q`` and
    ``state`` attributes. Each edge runs ``'from'`` -> ``'to'`` and carries
    the ``E``, ``action`` and ``index`` attributes.
    """
    G = nx.DiGraph()
    for node in graph_data['nodes']:
        G.add_node(
            node['index'],
            N=node['N'],
            Q=node['Q'],
            state=node['state'],
        )
    for edge in graph_data['edges']:
        G.add_edge(
            edge['from'],
            edge['to'],
            E=edge['E'],
            action=edge['action'],
            index=edge['index'],
        )
    return G


def _node_label(data):
    """Return a node's label: its state, then ``N``, then each ``Q`` value to 2 decimals."""
    # Formatting a list of strings keeps the original look, e.g. "Q: ['0.12', '0.34']".
    q_values = [f'{float(q):.2f}' for q in data['Q']]
    return f"{data['state']}\nN: {data['N']}\nQ: {q_values}"


def plot_dag_snapshot(snapshot_index):
    """Draw entry ``snapshot_index`` (0-based) of ``graph_snapshots``.

    Nodes are placed with Graphviz ``dot``. Each node is labelled with its
    state, N and Q values, and each edge with its E value and action. The
    plot title shows the 1-based snapshot number.

    Args:
        snapshot_index: Position in ``graph_snapshots`` to render.
    """
    G = _build_snapshot_graph(graph_snapshots[snapshot_index])

    # Hierarchical (top-down) placement.
    pos = graphviz_layout(G, prog="dot")

    node_labels = {node: _node_label(data) for node, data in G.nodes(data=True)}
    edge_labels = {
        (u, v): f"E: {data['E']}, A: {data['action']}"
        for u, v, data in G.edges(data=True)
    }

    plt.figure(figsize=(12, 8))
    # Node text is drawn separately below, so nx.draw renders shapes/edges only.
    nx.draw(G, pos, with_labels=False, node_size=3000, node_color="lightblue")
    nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10, font_color="black")

    # label_pos runs 0 = head ("to" node) .. 1 = tail ("from" node), so 0.7
    # puts each label closer to the "from" node. Only node entries of ``pos``
    # are consulted for edge-label placement, so no per-edge positions are needed.
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, label_pos=0.7, font_size=8)

    plt.title(f"DAG Snapshot {snapshot_index + 1}")
    plt.show()

# Slider for choosing which snapshot to show. With continuous_update=False,
# the plot is redrawn only when the slider is released, not on every tick.
snapshot_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(graph_snapshots) - 1,
    step=1,
    description='Snapshot:',
    continuous_update=False,
)


def _show_snapshot(snapshot_index):
    """Slider callback: render the snapshot at the selected index."""
    # A thin wrapper keeps ``snapshot_index`` as the only parameter that
    # ``interactive`` turns into a control.
    plot_dag_snapshot(snapshot_index)


# Kept as the cell's final expression so Jupyter displays the returned widget.
widgets.interactive(_show_snapshot, snapshot_index=snapshot_slider)

interactive(children=(IntSlider(value=0, continuous_update=False, description='Snapshot:', max=49), Output()),…