In [27]:
using DataStructures
using Dates
import Random: randstring

"""
A node in a stream computation graph that represents a computation step.
"""
mutable struct StreamNode{TGraph}
    graph::TGraph
    index::Int
    func::Function
    inputs::Vector{StreamNode}
    output_type::Type
    init_value::Any
    label::String
    field_name::Symbol
    function StreamNode(graph::TGraph, index, func, output_type, init_value, label) where {TGraph}
        field_name = Symbol("node$(index)_$(label)")
        new{TGraph}(graph, index, func, StreamNode[], output_type, init_value, label, field_name)
    end
end

@inline is_source(node::StreamNode) = isempty(node.inputs)

"""
A directed acyclic graph (DAG) that represents a stream computation graph.
"""
mutable struct StreamGraph
    nodes::Vector{StreamNode{StreamGraph}}
    deps::Vector{Vector{Int}}
    reverse_deps::Vector{Vector{Int}}
    topo_order::Vector{Int}
    function StreamGraph()
        new(
            StreamNode{StreamGraph}[],
            Vector{Int}[],
            Vector{Int}[],
            Int[]
        )
    end
end

function topological_sort!(graph::StreamGraph)
    in_degree = [length(deps) for deps in graph.deps]
    queue = Int[]

    for (index, degree) in enumerate(in_degree)
        if degree == 0
            push!(queue, index)
        end
    end

    while !isempty(queue)
        node_index = pop!(queue)
        push!(graph.topo_order, node_index)

        for dependent_index in graph.reverse_deps[node_index]
            in_degree[dependent_index] -= 1
            if in_degree[dependent_index] == 0
                push!(queue, dependent_index)
            end
        end
    end

    if length(graph.topo_order) != length(graph.nodes)
        error("Graph has a cycle")
    end

    nothing
end

function node!(::Type{TOutput}, graph::StreamGraph, func::Function; init=nothing, label::String="") where {TOutput}
    if isempty(label)
        # label = string(func)
        buffer = IOBuffer()
        Base.show_unquoted(buffer, func)
        label = String(take!(buffer))
    end
    index = length(graph.nodes) + 1
    node = StreamNode(graph, index, func, TOutput, init, label)
    push!(graph.nodes, node)
    push!(graph.deps, Int[])
    push!(graph.reverse_deps, Int[])
    node
end

function bind_inputs!(graph::StreamGraph, node::StreamNode, inputs)
    node.inputs = collect(inputs)
    input_indices = [input.index for input in node.inputs]
    graph.deps[node.index] = input_indices
    for input_index in input_indices
        push!(graph.reverse_deps[input_index], node.index)
    end
end

# syntactic sugar for binding inputs (bind_inputs!)
function Base.:|>(inputs, node::StreamNode{G}) where {G}
    for (i, input) in enumerate(inputs)
        input isa StreamNode || error("Input #$i is not a StreamNode")
    end
    bind_inputs!(node.graph, node, collect(StreamNode, inputs))
    node
end

"""
Compile the states struct for the given graph to store intermediate results
of computation steps.
"""
function compile_states(graph::StreamGraph)
    # Generate a unique name for the struct
    struct_name = Symbol("GraphStates_" * randstring(8))

    field_defs = []
    constructor_args = []
    for (i, node) in enumerate(graph.nodes)
        field_type = :($(Union{node.output_type,typeof(node.init_value)}))
        push!(field_defs, Expr(:(::), node.field_name, field_type))
        push!(constructor_args, node.init_value)
    end

    struct_def = Expr(:struct, true, struct_name, Expr(:block, field_defs...))
    Core.eval(@__MODULE__, struct_def)

    constructor_def = :($struct_name() = $struct_name($(constructor_args...)))
    Core.eval(@__MODULE__, constructor_def)

    # println("Generated struct definition:")
    # println(struct_def)
    # println("Generated constructor:")
    # println(constructor_def)

    getfield(@__MODULE__, struct_name)
end

"""
An executor that runs a stream computation graph.
"""
mutable struct Executor{TStates,TTime}
    graph::StreamGraph
    states::TStates
    current_time::TTime
    function Executor(graph::StreamGraph, states::TStates, start_time::TTime) where {TStates,TTime}
        new{TStates,TTime}(graph, states, start_time)
    end
end

@inline function time(executor::Executor{TStates,TTime})::TTime where {TStates,TTime}
    executor.current_time
end

@inline function time!(executor::Executor{TStates,TTime}, new_time::TTime) where {TStates,TTime}
    executor.current_time = new_time
    nothing
end

function compile_source!(executor::Executor{TStates,TTime}, source_node::StreamNode{StreamGraph}; debug=false) where {TStates,TTime}
    graph = executor.graph
    nodes = graph.nodes

    # Find the subgraph starting from the source node
    subgraph_indices = Int[]
    queue = [source_node.index]
    visited = falses(length(nodes))

    while !isempty(queue)
        node_index = popfirst!(queue)
        if !visited[node_index]
            push!(subgraph_indices, node_index)
            visited[node_index] = true
            append!(queue, graph.reverse_deps[node_index])
        end
    end

    # Sort the subgraph indices according to the topological order
    sort!(subgraph_indices, by=i -> graph.topo_order[i])

    # Generate code for each node in the subgraph
    node_expressions = Expr[]
    for node_index in subgraph_indices
        node = nodes[node_index]
        field_name = node.field_name
        res_name = Symbol("res_$(field_name)")
        if is_source(node)
            push!(node_expressions, :(states.$field_name = input_value))
        else
            input_exprs = [:(states.$(input.field_name)) for input in node.inputs]
            if debug
                println("Node $(node.label) inputs: ", input_exprs)
                result_expr = quote
                    try
                        result = $(node.func)(executor, $(input_exprs...))
                        println("Node $($(node.label)) output: ", result)
                        result::$(node.output_type)  # Type assertion
                    catch e
                        error("Error in node $($(node.label)): $e")
                    end
                end
            else
                result_expr = :($(node.func)(executor, $(input_exprs...)))
            end
            update_expr = quote
                # for input in $input_exprs
                #     # print value of input
                #     println("Input value: ", input)
                # end
                # abort if any input nothing
                # println("Inputs $($(node.label)): ", ($([:(states.$(input.field_name)) for input in node.inputs]...),))
                if any(isnothing, ($([:(states.$(input.field_name)) for input in node.inputs]...),))
                    return
                end
                $res_name = $result_expr
                if !isnothing($res_name)
                    states.$field_name = $res_name
                end
            end
            push!(node_expressions, update_expr)
        end
    end

    # println("Generated code for source node $(source_node.label):")
    # for expr in node_expressions
    #     println(expr)
    # end

    # Create the compiled function
    compiled_func = @eval begin
        (executor::Executor{$TStates,$TTime}, time::$TTime, input_value::$(source_node.output_type)) -> begin
            executor.current_time = time
            states = executor.states
            $(node_expressions...)
            nothing
        end
    end

    compiled_func
end

# Example usage
g = StreamGraph()

# Create source nodes
source1 = node!(Float64, g, (exe, x) -> x; init=0.0, label="source1")
source2 = node!(Float64, g, (exe, x) -> x; init=0.0, label="source2")
source3 = node!(Float64, g, (exe, x) -> x; init=0.0, label="source3")

# Create compute nodes
square = node!(Float64, g, (exe, x) -> x^2; label="square")
divide_by_2 = node!(Float64, g, (exe, x) -> x / 2; label="divide_by_2")
negate = node!(Float64, g, (exe, x) -> -x; label="negate")
combine = node!(Tuple{Float64,Float64}, g, (exe, x, y) -> (x, y); label="combine")
final_multiply = node!(Tuple{Float64,Float64}, g, (exe, tuple, src2, src3) -> tuple .* src2 .+ src3; label="final_multiply")
output = node!(Nothing, g, (exe, x) -> println("Final Output at time $(time(exe)): $x"); label="output")

# Create edges between nodes (define the computation graph)
[source1] |> square
[square] |> divide_by_2
[source2] |> negate
(divide_by_2, negate) |> combine
(combine, source2, source3) |> final_multiply
[final_multiply] |> output

# Perform topological sort
topological_sort!(g)

# Create executor
states = compile_states(g)();
executor = Executor(g, states, DateTime(2000, 1, 1))

source1_func = compile_source!(executor, source1; debug=!true)
source2_func = compile_source!(executor, source2; debug=!true)
source3_func = compile_source!(executor, source3; debug=!true)

# call the compiled source functions with test data
source1_func(executor, DateTime(2000, 1, 1, 0, 0, 1), 2.0)
source2_func(executor, DateTime(2000, 1, 1, 0, 0, 2), 10.0)
source3_func(executor, DateTime(2000, 1, 1, 0, 0, 2), 10.0)
source1_func(executor, DateTime(2000, 1, 1, 0, 0, 3), 4.0)
source3_func(executor, DateTime(2000, 1, 1, 0, 0, 4), 20.0)
source2_func(executor, DateTime(2000, 1, 1, 0, 0, 4), 20.0)
source1_func(executor, DateTime(2000, 1, 1, 0, 0, 5), 6.0)
source2_func(executor, DateTime(2000, 1, 1, 0, 0, 6), 30.0)
source3_func(executor, DateTime(2000, 1, 1, 0, 0, 6), 30.0)

Final Output at time 2000-01-01T00:00:02: (0.0, -100.0)
Final Output at time 2000-01-01T00:00:02: (10.0, -90.0)
Final Output at time 2000-01-01T00:00:03: (30.0, -90.0)
Final Output at time 2000-01-01T00:00:04: (40.0, -80.0)
Final Output at time 2000-01-01T00:00:04: (60.0, -380.0)
Final Output at time 2000-01-01T00:00:05: (180.0, -380.0)
Final Output at time 2000-01-01T00:00:06: (260.0, -880.0)
Final Output at time 2000-01-01T00:00:06: (270.0, -870.0)


In [None]:
@code_warntype source1_func(executor, DateTime(2000, 1, 1, 0, 0, 1), 2.0)

In [None]:
using BenchmarkTools
@benchmark compile_source!(executor, source1) samples=500 evals=2

In [None]:
using BenchmarkTools
@benchmark compile_source2!(executor, source1) samples=500 evals=2