# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# Load the per-node time series from the CSV file.
csv_path = "../code/data/N600t500r20lendingrate3.csv"
df = pd.read_csv(csv_path)

# Continuous columns are min-max scaled per node; binary columns are 0/1
# event flags that are passed through unchanged.
candidate_cols = ["wage", "savings", "loan", "a", "beta", "links", "component", "rate", "U_self", "e_self", "e_star"]
binary_cols = ["startup", "thwart", "borrow", "move", "go"]

# Output widget for status messages (displayed under the node selector).
log_area = widgets.Output()


def _normalize_node(node_df):
    """Min-max scale the continuous columns of one node's time series.

    Parameters
    ----------
    node_df : pandas.DataFrame
        Rows belonging to a single node.

    Returns
    -------
    (dict, list)
        The dict maps each present candidate column to a Series scaled to
        [0, 1], or to None when the column has no usable range. It maps each
        present binary column to its raw Series. The list names the
        candidate columns that were set to None.
    """
    normalized = {}
    skipped = []
    for col in candidate_cols:
        if col not in node_df.columns:
            continue
        col_data = node_df[col].astype(float)
        cmin = col_data.min()
        span = col_data.max() - cmin
        # span == 0: constant column (would divide by zero).
        # span not finite: the column is all-NaN or contains inf, so scaling
        # would only produce NaN.
        if not np.isfinite(span) or span == 0:
            normalized[col] = None
            skipped.append(col)
        else:
            normalized[col] = (col_data - cmin) / span
    for col in binary_cols:
        if col in node_df.columns:
            normalized[col] = node_df[col]
    return normalized, skipped


# Split the frame into one time-sorted sub-frame per node. A single groupby
# pass replaces re-filtering the full frame once per node id; sort=False
# keeps the ids in order of first appearance.
node_data = {
    node_id: group.sort_values("t")
    for node_id, group in df.groupby("id", sort=False)
}

# Pre-calculate normalized values for each node.
normalized_data = {}
skipped_counts = {}
for node_id, node_df in node_data.items():
    normalized_data[node_id], skipped_cols = _normalize_node(node_df)
    for col in skipped_cols:
        skipped_counts[col] = skipped_counts.get(col, 0) + 1

# Report to the cell output instead of log_area: update_plot clears log_area
# on every call, which would erase these messages immediately.
for col, count in skipped_counts.items():
    print(f"  - {col}: constant value in {count} of {len(node_data)} nodes, skipping normalization")

# Dedicated output widget that holds only the figure.
plot_area = widgets.Output()

# Figure drawn by the most recent update_plot call. It is closed on the next
# call so pyplot does not accumulate open figures.
previous_figure = None


def _draw_continuous(ax, times, norm_node_data, log):
    """Draw one line per candidate column and return the lines that hold data.

    A line is created for every column in candidate_cols, even one that stays
    empty. This keeps each column at the same colour-cycle position for
    every node.
    """
    plotted = []
    for col in candidate_cols:
        (line,) = ax.plot([], [], label=col)
        if col not in norm_node_data:
            continue
        values = norm_node_data[col]
        if values is None:
            log.append(f"  - {col}: constant value, skipping plot")
        elif np.all(values == 0):
            log.append(f"  - {col}: all zeros, skipping")
        else:
            line.set_data(times, values)
            log.append(f"  - {col}: x={len(times)}, y={len(values)}")
            plotted.append(line)
    return plotted


def _draw_events(ax, times, norm_node_data, log):
    """Mark every time step where a binary column equals 1 with a star at y=0.

    Returns the marker artists that contain at least one event, so they can
    be listed in the legend.
    """
    markers = []
    for col in binary_cols:
        if col not in norm_node_data:
            continue
        event_times = times[norm_node_data[col] == 1]
        (marker,) = ax.plot(event_times, [0] * len(event_times), "*", label=col, markersize=10)
        log.append(f"  - {col}: {len(event_times)} events")
        if len(event_times):
            markers.append(marker)
    return markers


def _add_time_axis(ax, times):
    """Add a twin x-axis with a label at every distinct time step.

    The twin copies the primary axis' x-limits once and does not track later
    changes. Call it only after the primary axis has been autoscaled.
    """
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())
    ticks = times.unique()
    ax2.set_xticks(ticks)
    ax2.set_xticklabels(ticks, rotation=90, fontsize=8)
    ax2.tick_params(axis="x", which="both", length=0)
    # NOTE(review): twiny() places this axis' ticks at the top of the plot;
    # moving the bottom spine does not move the tick labels. Confirm the
    # intended placement.
    ax2.spines["bottom"].set_position(("outward", 30))
    ax2.set_xlabel("Time (t)", labelpad=20)


def update_plot(node_id):
    """Redraw plot_area with the normalized time series of one node.

    Continuous columns are drawn as lines and binary event columns as stars
    at y=0. Diagnostics for each column are written to log_area.

    Parameters
    ----------
    node_id
        Key into normalized_data / node_data, as offered by the dropdown.
    """
    global previous_figure
    log = [f"Updating plot for node ID={node_id}"]

    # Clear the dedicated plot output region.
    plot_area.clear_output(wait=True)

    with plot_area:
        if previous_figure is not None:
            plt.close(previous_figure)
            previous_figure = None

        if node_id not in normalized_data:
            # No figure is created on this path, so nothing can leak.
            log.append(f"No data for node ID {node_id}")
        else:
            norm_node_data = normalized_data[node_id]
            times = node_data[node_id]["t"]

            fig, ax = plt.subplots(figsize=(10, 6))
            ax.set_xlabel("Time (t)")
            ax.set_ylabel("Normalized Value")

            legend_handles = _draw_continuous(ax, times, norm_node_data, log)
            legend_handles += _draw_events(ax, times, norm_node_data, log)

            # Line.set_data() does not update the data limits. Rescale now,
            # before the twin axis copies the x-limits.
            ax.relim()
            ax.autoscale_view()
            _add_time_axis(ax, times)

            # Legend only for artists that actually show data.
            if legend_handles:
                ax.legend(handles=legend_handles, bbox_to_anchor=(1.05, 1), loc="upper left")

            # Lay out once the twin axis and the outside legend exist.
            fig.tight_layout()
            fig.canvas.draw()
            fig.canvas.flush_events()
            plt.show()
            previous_figure = fig

    with log_area:
        clear_output(wait=True)
        print("\n".join(log))

# Node selector: every id in the data, sorted, with the first one preselected.
unique_ids = sorted(df["id"].unique())
dropdown = widgets.Dropdown(
    options=unique_ids,
    description="Node ID:",
    value=unique_ids[0],
    style=dict(description_width="initial"),
)


def on_value_change(change):
    """Redraw the plot for the node id that was just selected."""
    selected_id = change["new"]
    update_plot(selected_id)


dropdown.observe(on_value_change, names="value")

# Show the selector, then the log, then the figure area.
display(dropdown, log_area, plot_area)

# Draw the preselected node so the cell shows a plot right away.
update_plot(dropdown.value)