# In [None]:
using DataStructures
using Dates

"""
A node in a stream computation graph that represents a computation step.
"""
mutable struct StreamNode{TGraph}
    graph::TGraph
    index::Int
    func::Function
    inputs::Vector{StreamNode}
    label::String
    function StreamNode(graph::TGraph, index, func, label) where {TGraph}
        new{TGraph}(graph, index, func, StreamNode[], label)
    end
end

"""
    is_source(node::StreamNode) -> Bool

Return `true` when `node` has no inputs bound to it.
"""
@inline function is_source(node::StreamNode)
    return isempty(node.inputs)
end

"""
    Event(timestamp, trigger_index)

An immutable record pairing a `timestamp` with the integer `trigger_index`
of the node the event refers to.
"""
struct Event{TTime}
    timestamp::TTime
    trigger_index::Int

    function Event(timestamp::TTime, trigger_index::Int) where {TTime}
        return new{TTime}(timestamp, trigger_index)
    end
end

# Events are ordered by timestamp only; `trigger_index` does not take part in the comparison.
@inline function Base.isless(a::Event, b::Event)
    return a.timestamp < b.timestamp
end

"""
A directed acyclic graph (DAG) that represents a stream computation graph.
"""
mutable struct StreamGraph
    nodes::Vector{StreamNode{StreamGraph}}
    deps::Vector{Vector{Int}}
    reverse_deps::Vector{Vector{Int}}
    topo_order::Vector{Int}
    function StreamGraph()
        new(
            StreamNode{StreamGraph}[],
            Vector{Int}[],
            Vector{Int}[],
            Int[]
        )
    end
end

"""
    topological_sort!(graph::StreamGraph)

Recompute `graph.topo_order` with Kahn's algorithm, using `graph.deps` for the
in-degrees and `graph.reverse_deps` for the outgoing edges.

The previous contents of `graph.topo_order` are replaced, so the function can be
called again after the graph has changed. Throws an error ("Graph has a cycle")
when not every node can be ordered; in that case `graph.topo_order` is left
untouched. Returns `nothing`.
"""
function topological_sort!(graph::StreamGraph)
    in_degree = [length(deps) for deps in graph.deps]

    # Start from every node that has no inputs.
    queue = Int[index for (index, degree) in enumerate(in_degree) if degree == 0]

    # Build the order in a local vector so that a failed sort does not leave a
    # half-written result behind and a repeated sort does not append duplicates.
    order = Int[]
    sizehint!(order, length(graph.nodes))

    while !isempty(queue)
        node_index = pop!(queue)
        push!(order, node_index)

        for dependent_index in graph.reverse_deps[node_index]
            in_degree[dependent_index] -= 1
            if in_degree[dependent_index] == 0
                push!(queue, dependent_index)
            end
        end
    end

    if length(order) != length(graph.nodes)
        error("Graph has a cycle")
    end

    # Overwrite in place so existing references to the vector stay valid.
    empty!(graph.topo_order)
    append!(graph.topo_order, order)

    return nothing
end

"""
    node!(graph::StreamGraph, func::Function; label::String="") -> StreamNode

Create a new node wrapping `func`, append it to `graph` and return it.

The node's index is its position in `graph.nodes`. Empty entries are added to
`graph.deps` and `graph.reverse_deps` so that all three vectors stay the same
length. When `label` is empty, the textual representation of `func` (as
produced by `show`) is used as the label.
"""
function node!(graph::StreamGraph, func::Function; label::String="")
    if isempty(label)
        # Public equivalent of the internal `Base.show_unquoted` fallback,
        # which simply forwards to `show` for non-expression values.
        label = sprint(show, func)
    end

    index = length(graph.nodes) + 1
    node = StreamNode(graph, index, func, label)
    push!(graph.nodes, node)
    push!(graph.deps, Int[])
    push!(graph.reverse_deps, Int[])
    return node
end

"""
    bind_inputs!(graph::StreamGraph, node::StreamNode, inputs)

Set `inputs` (an iterable of `StreamNode`s) as the inputs of `node` and update
`graph.deps` / `graph.reverse_deps` accordingly.

Calling this again for the same node replaces the previous binding, including
the reverse edges that the previous binding created. `graph.topo_order` is not
updated here; call `topological_sort!` again after changing bindings.

Throws an `ArgumentError` when `node` or any input was created for a different
graph, because their indices would not refer to entries of `graph`.
Returns `nothing`.
"""
function bind_inputs!(graph::StreamGraph, node::StreamNode, inputs)
    new_inputs = collect(StreamNode, inputs)

    # Validate before mutating anything so a failed call leaves the graph intact.
    node.graph === graph ||
        throw(ArgumentError("node \"$(node.label)\" does not belong to this graph"))
    for input in new_inputs
        input.graph === graph ||
            throw(ArgumentError("input \"$(input.label)\" does not belong to this graph"))
    end

    node_index = node.index

    # Remove the reverse edges of a previous binding; otherwise former inputs
    # would keep triggering this node. One occurrence is removed per old edge.
    for old_input_index in graph.deps[node_index]
        dependents = graph.reverse_deps[old_input_index]
        position = findfirst(==(node_index), dependents)
        position === nothing || deleteat!(dependents, position)
    end

    node.inputs = new_inputs
    input_indices = Int[input.index for input in new_inputs]
    graph.deps[node_index] = input_indices
    for input_index in input_indices
        push!(graph.reverse_deps[input_index], node_index)
    end

    return nothing
end

# Pipe syntax for `bind_inputs!`: `inputs |> node` binds every element of
# `inputs` as an input of `node` (on the graph stored in `node`) and returns
# `node`. Every element must be a `StreamNode`, otherwise an error is raised.
function Base.:|>(inputs, node::StreamNode)
    for (i, candidate) in enumerate(inputs)
        if !(candidate isa StreamNode)
            error("Input #$i is not a StreamNode")
        end
    end

    upstream = collect(StreamNode, inputs)
    bind_inputs!(node.graph, node, upstream)
    return node
end

"""
An executor that runs a stream computation graph.
"""
mutable struct Executor{TTime}
    graph::StreamGraph
    event_queue::BinaryMinHeap{Event{TTime}}
    states::Vector{Any}
    current_time::TTime
    function Executor(graph::StreamGraph, start_time::TTime) where {TTime}
        states = Vector{Any}(nothing, length(graph.nodes))
        new{TTime}(graph, BinaryMinHeap{Event{TTime}}(), states, start_time)
    end
end

"""
    time(executor::Executor)

Return the value currently stored in `executor.current_time`.
"""
@inline function time(executor::Executor{T})::T where {T}
    return executor.current_time
end

"""
    execute_node!(executor, node, is_event_trigger)

Run a single node and store its result in `executor.states[node.index]`.

- Source nodes (no inputs) are only evaluated when `is_event_trigger` is `true`;
  their function is called as `func(time, executor, node)`. Otherwise their
  state is left as it is.
- Other nodes are called as `func(time, input_values...)` with the current
  states of their inputs. If any input state is `nothing`, the node's state is
  set to `nothing` without calling the function.

Returns `nothing`.
"""
function execute_node!(executor::Executor{TTime}, node::StreamNode{StreamGraph}, is_event_trigger::Bool) where {TTime}
    states = executor.states
    slot = node.index

    if is_source(node)
        # A source that did not trigger the current event keeps its last value.
        is_event_trigger || return nothing
        states[slot] = node.func(time(executor), executor, node)
        return nothing
    end

    inputs_ready = all(!isnothing(states[upstream.index]) for upstream in node.inputs)
    if inputs_ready
        input_values = (states[upstream.index] for upstream in node.inputs)
        states[slot] = node.func(time(executor), input_values...)
    else
        states[slot] = nothing
    end

    return nothing
end

"""
    process_subgraph!(executor, start_node_index)

Execute the node at `start_node_index` and every node reachable from it through
`graph.reverse_deps`.

Each affected node is executed exactly once, and only after all of its inputs
that are themselves affected by this trigger have been executed. Only the start
node is executed with `is_event_trigger = true`. Nodes that are not downstream
of the start node are not touched.

The graph is assumed to be acyclic (as checked by `topological_sort!`).
Returns `nothing`.
"""
function process_subgraph!(executor::Executor{TTime}, start_node_index::Int) where {TTime}
    graph = executor.graph
    nodes = graph.nodes
    reverse_deps = graph.reverse_deps

    # Pass 1: discover the affected nodes and count, for each of them, how many
    # incoming edges come from other affected nodes.
    pending = Dict{Int,Int}(start_node_index => 0)
    frontier = Int[start_node_index]
    while !isempty(frontier)
        node_index = pop!(frontier)
        for dependent_index in reverse_deps[node_index]
            if haskey(pending, dependent_index)
                pending[dependent_index] += 1
            else
                pending[dependent_index] = 1
                push!(frontier, dependent_index)
            end
        end
    end

    # Pass 2: Kahn's algorithm restricted to the affected nodes. A node becomes
    # ready once every affected input has run, so it never observes a stale
    # input value and is never executed twice for the same event.
    ready = Int[start_node_index]
    while !isempty(ready)
        node_index = pop!(ready)
        is_event_trigger = (node_index == start_node_index)
        execute_node!(executor, nodes[node_index], is_event_trigger)

        for dependent_index in reverse_deps[node_index]
            remaining = pending[dependent_index] - 1
            pending[dependent_index] = remaining
            if remaining == 0
                push!(ready, dependent_index)
            end
        end
    end

    return nothing
end

"""
    process_event!(executor, event)

Set the executor's current time to the event's timestamp, then process the
subgraph starting at the node given by `event.trigger_index`. Returns `nothing`.
"""
function process_event!(executor::Executor{TTime}, event::Event{TTime}) where {TTime}
    trigger = event.trigger_index
    executor.current_time = event.timestamp
    process_subgraph!(executor, trigger)
    return nothing
end

"""
    enqueue_event!(executor, timestamp, trigger_index)

Add an `Event` for node `trigger_index` at `timestamp` to the executor's event
queue. Returns `nothing`.
"""
function enqueue_event!(executor::Executor{TTime}, timestamp::TTime, trigger_index::Int) where {TTime}
    scheduled = Event(timestamp, trigger_index)
    push!(executor.event_queue, scheduled)
    return nothing
end

"""
    run_simulation!(executor::Executor)

Run the executor until its event queue is empty.

First every source node is executed once as an event trigger (its dependents
are not processed in this step). Afterwards events are popped from the queue in
timestamp order and processed one by one. Returns `nothing`.
"""
function run_simulation!(executor::Executor)
    # Initial pass over the sources only.
    for candidate in executor.graph.nodes
        is_source(candidate) || continue
        execute_node!(executor, candidate, true)
    end

    # Drain the queue; processing an event may enqueue further events.
    pending_events = executor.event_queue
    while !isempty(pending_events)
        next_event = pop!(pending_events)
        process_event!(executor, next_event)
    end

    return nothing
end

"""
    create_list_source(data::AbstractVector)

Build a source function that replays `data`, a vector whose elements unpack
into `(timestamp, value)`.

Each call of the returned function yields the value of the next entry and, if
another entry follows, enqueues an event for the calling node at that entry's
timestamp. Once all entries have been consumed the function returns `nothing`.
"""
function create_list_source(data::D) where {D<:AbstractVector}
    cursor = Ref(1)

    function source_func(time::TTime, executor::Executor{TTime}, node::StreamNode{StreamGraph}) where {TTime}
        # Nothing left to replay.
        cursor[] <= length(data) || return nothing

        _, output = data[cursor[]]
        cursor[] += 1

        # Schedule this source again for the following entry, if there is one.
        if cursor[] <= length(data)
            upcoming_timestamp, _ = data[cursor[]]
            enqueue_event!(executor, upcoming_timestamp, node.index)
        end

        return output
    end

    return source_func
end

# Example usage: build a small graph with three list-backed sources, wire up
# the compute nodes and replay all events through an executor.
g = StreamGraph()

# Create source nodes. Each entry is a (timestamp, value) pair that the list
# source replays in order.
source1 = node!(g, create_list_source([
        (DateTime(2000, 1, 1, 0, 0, 1), 2),
        (DateTime(2000, 1, 1, 0, 0, 3), 4),
        (DateTime(2000, 1, 1, 0, 0, 5), 6)
    ]), label="source1")
source2 = node!(g, create_list_source([
        (DateTime(2000, 1, 1, 0, 0, 2), 10),
        (DateTime(2000, 1, 1, 0, 0, 4), 20),
        (DateTime(2000, 1, 1, 0, 0, 6), 30)
    ]), label="source2")
source3 = node!(g, create_list_source([
        (DateTime(2000, 1, 1, 0, 0, 2), 10),
        (DateTime(2000, 1, 1, 0, 0, 4), 20),
        (DateTime(2000, 1, 1, 0, 0, 6), 30)
    ]), label="source3")

# Create compute nodes. Every compute function receives the executor's current
# time as its first argument, followed by the current values of its inputs.
square = node!(g, (time, x) -> x^2; label="square")
divide_by_2 = node!(g, (time, x) -> x / 2, label="divide_by_2")
negate = node!(g, (time, x) -> -x, label="negate")
combine = node!(g, (time, x, y) -> (x, y), label="combine")
final_multiply = node!(g, (time, tuple, src2_val, src3_val) -> tuple .* src2_val .+ src3_val, label="final_multiply")
# `println` returns `nothing`, so this node's stored state stays `nothing`.
output = node!(g, (time, x) -> println("Final Output at time $time: $x"), label="output")

# Bind inputs: `inputs |> node` forwards to `bind_inputs!`; the order of the
# inputs is the order of the arguments passed to the node's function.
[source1] |> square
[square] |> divide_by_2
[source2] |> negate
(divide_by_2, negate) |> combine
(combine, source2, source3) |> final_multiply
[final_multiply] |> output

# Perform topological sort (fills `g.topo_order`; errors if the graph has a cycle)
topological_sort!(g)

# Create executor (starting at 2000-01-01T00:00:00) and run simulation
executor = Executor(g, DateTime(2000, 1, 1))
run_simulation!(executor)

# In [None]:
using Colors
using CairoMakie
using GraphMakie
using Graphs
using LayeredLayouts

"""
    visualize_graph(graph::StreamGraph)

Draw `graph` as a layered diagram and return the `(figure, axis, plot)` triple
produced by `graphplot`.

Edges run from each input to the node that consumes it. Source nodes get their
label drawn in a different colour and alignment than compute nodes. The layout
computed by `solve_positions(Zarate(), ...)` is rotated by 90° and compressed
vertically; axis decorations and spines are hidden.

The axis limits are padded around the layout, and the figure width is adapted
to the layout's aspect ratio. When the layout has no extent in one direction
(for example a single node, or all nodes in one row or column) a fixed padding
is used and the figure size is left unchanged, since no aspect ratio can be
derived.
"""
function visualize_graph(graph::StreamGraph)
    nodes = graph.nodes
    digraph = SimpleDiGraph(length(nodes))

    # One edge per dependency, pointing from the input to the consumer.
    # Node indices double as vertex ids (`node!` assigns 1, 2, ... in order).
    for node in nodes, dep_id in graph.deps[node.index]
        add_edge!(digraph, nodes[dep_id].index, node.index)
    end

    # Per-node label attributes; comprehensions keep the element types concrete.
    nlabels = String[node.label for node in nodes]
    nlabels_align = [is_source(node) ? (:center, :bottom) : (:center, :top) for node in nodes]
    nlabels_color = [is_source(node) ? colorant"#bf489d" : colorant"#000000" for node in nodes]

    xs, ys, paths = solve_positions(Zarate(), digraph)
    xs, ys = ys, -xs # rotate coordinates by 90°
    ys .*= 0.5 # scale the y coordinates

    f, ax, p = graphplot(digraph;
        layout=Point.(zip(xs, ys)),
        arrow_size=15,
        arrow_shift=:end,
        arrow_marker='>',
        edge_width=0.75,
        edge_color=colorant"#444",
        nlabels=nlabels,
        nlabels_fontsize=14,
        nlabels_align=nlabels_align,
        nlabels_color=nlabels_color,
        node_size=48,
        node_color=:white,
        nlabels_distance=-10,
    )
    hidedecorations!(ax)
    hidespines!(ax)

    # add some padding; fall back to a fixed margin when an extent is zero so
    # the axis limits never collapse to a single value
    x_min, x_max = extrema(xs)
    y_min, y_max = extrema(ys)
    x_range = x_max - x_min
    y_range = y_max - y_min
    x_pad = x_range > 0 ? 0.2x_range : 1.0
    y_pad = y_range > 0 ? 0.1y_range : 1.0
    xlims!(ax, x_min - x_pad, x_max + x_pad)
    ylims!(ax, y_min - y_pad, y_max + y_pad)

    # adjust width to match aspect ratio; skipped for degenerate layouts, where
    # the ratio would be 0 or not finite
    if x_range > 0 && y_range > 0
        height = size(f.scene)[2]
        width = max(1, floor(Int, height * (x_range / y_range)))
        resize!(f, width, height)
    end

    return f, ax, p
end

# Render the example graph `g`; `f`, `ax` and `p` are the three values
# returned by `visualize_graph`.
f, ax, p = visualize_graph(g)

# Save the plot to a file
# save("computation_graph.png", f)

# Show the figure; the trailing `;` suppresses echoing the call's return value.
display(f);