In [15]:
using Revise
using Plots
using Hungarian
using Gen
using Statistics

includet("./distribution_utils.jl")
includet("./model.jl")
includet("./visualizations.jl")
includet("./inference.jl")

In [None]:
"""
    visualize_importance_samples(traces, constraint_img, log_weights, scene_size)

Build a 1x2 figure: subplot 1 scatters the initial firefly positions of every
trace in `traces` (coloured by each firefly's colour index), subplot 2 shows
`constraint_img` as a heatmap. Both axes span `[0, scene_size + 1]`.

The marker opacity of trace `t` is `log_weights[t] + 0.001`.
NOTE(review): a genuine log-weight is <= 0, which would not be a valid opacity;
presumably callers pass normalised weights here — confirm against the caller.

Returns the figure.
"""
function visualize_importance_samples(traces, constraint_img, log_weights, scene_size)
    palette = ["red", "green", "blue"]
    axis_limits = [0, scene_size + 1]

    fig = plot(layout = grid(1,2),  size=(400, 400), legend=false)

    # Right panel: the observed image.
    frame = channelview(RGB.(constraint_img))
    heatmap!(fig[2], frame, aspect_ratio=:equal, xlims=axis_limits, ylims=axis_limits)

    # Left panel: one scatter layer per importance sample.
    for (t, trace) in enumerate(traces)
        choices = get_choices(trace)
        firefly_ids = 1:choices[:init => :n_fireflies]
        xs = map(n -> choices[:init => :init_x => n], firefly_ids)
        ys = map(n -> choices[:init => :init_y => n], firefly_ids)
        colors = map(n -> palette[choices[:init => :color => n]], firefly_ids)
        scatter!(fig[1], xs, ys;
                 color=colors, markersize=5, alpha=log_weights[t] + 0.001,
                 aspect_ratio=:equal, xlims=axis_limits, ylims=axis_limits,
                 yflip=true, label="Importance Samples")
    end

    return fig
end

"""
    visualize_importance_resampling(trace, constraint_img[, scene_size])

Build a 1x2 figure for a single resampled `trace`: subplot 1 scatters the
initial firefly positions (coloured by each firefly's colour index), subplot 2
shows `constraint_img` rendered through `colorview(RGB, ...)`.

# Arguments
- `trace`: trace whose `:init => :n_fireflies`, `:init => :init_x => n`,
  `:init => :init_y => n` and `:init => :color => n` choices are plotted.
- `constraint_img`: channel-first image array handed to `colorview(RGB, ...)`.
- `scene_size`: axis extent; both axes span `[0, scene_size + 1]`. Defaults to
  the last dimension of `constraint_img`, so the function no longer depends on
  a global variable named `scene_size` being defined.
  NOTE(review): the default assumes a square image — confirm.

Returns the figure.
"""
function visualize_importance_resampling(trace, constraint_img,
                                         scene_size=size(constraint_img)[end])
    axis_limits = [0, scene_size + 1]
    fig = plot(layout = grid(1,2), size=(400, 400), legend=false)

    # Right panel: the observed frame.
    frame = colorview(RGB, constraint_img)
    Plots.plot!(fig[2], frame, title="Frame", aspect_ratio=:equal,
                xlims=axis_limits, ylims=axis_limits)

    # Left panel: inferred initial positions, one marker per firefly.
    choices = get_choices(trace)
    n_fireflies = choices[:init => :n_fireflies]
    xs = [choices[:init => :init_x => n] for n in 1:n_fireflies]
    ys = [choices[:init => :init_y => n] for n in 1:n_fireflies]
    color_opts = ["red", "green", "blue"]
    colors = [color_opts[choices[:init => :color => n]] for n in 1:n_fireflies]
    scatter!(fig[1], xs, ys, color=colors, markersize=5, aspect_ratio=:equal,
             xlims=axis_limits, ylims=axis_limits, yflip=true,
             label="Importance Samples")
    return fig
end

In [None]:
"""
    find_color_patches(image, threshold, size_prior, color_threshold)

Segment `image` (indexed as `image[channel, y, x]`) into connected patches of
bright, similarly coloured pixels using an 8-connected flood fill.

A pixel joins a patch when its brightest channel exceeds `threshold` and the
Euclidean distance between its channel vector and that of the patch's seed
pixel is at most `color_threshold`. A patch with more than `size_prior` pixels
is split once into two clusters around its mean position; a cluster that ends
up empty is discarded so that every returned patch has at least one pixel.

# Returns
`(patches, n)` where `patches` is a vector of patches, each a vector of
`(x, y, dominant_channel)` tuples, and `n == length(patches)`.
"""
function find_color_patches(image::AbstractArray{<:Real, 3}, threshold::Real,
                            size_prior::Integer, color_threshold::Real)
    height, width = size(image, 2), size(image, 3)
    visited = falses(height, width)
    patches = Vector{Vector{NTuple{3, Int}}}()

    # Euclidean distance between two channel vectors (avoids needing LinearAlgebra.norm).
    color_distance(a, b) = sqrt(sum(abs2, a .- b))

    # Flood fill from (y, x); returns the patch and whether it exceeds `size_prior`.
    function flood_fill(y, x)
        stack = [(y, x)]
        patch = NTuple{3, Int}[]
        seed_color = image[:, y, x]

        while !isempty(stack)
            cy, cx = pop!(stack)
            # Neighbours are pushed unconditionally, so bounds are checked on pop.
            if cy < 1 || cy > height || cx < 1 || cx > width || visited[cy, cx]
                continue
            end

            pixel_color = @view image[:, cy, cx]

            # Luminance filter and colour-similarity filter (relative to the seed).
            if maximum(pixel_color) <= threshold ||
               color_distance(pixel_color, seed_color) > color_threshold
                continue
            end

            visited[cy, cx] = true
            push!(patch, (cx, cy, argmax(pixel_color)))

            for dy in -1:1, dx in -1:1
                if dy == 0 && dx == 0
                    continue
                end
                push!(stack, (cy + dy, cx + dx))
            end
        end

        return patch, length(patch) > size_prior
    end

    # Split a patch in two around its mean position: pixels left of OR above the
    # mean go to the first cluster, the rest to the second. Either may be empty.
    function split_patch(patch)
        mean_x = mean(p[1] for p in patch)
        mean_y = mean(p[2] for p in patch)
        in_first = [p[1] < mean_x || p[2] < mean_y for p in patch]
        return patch[in_first], patch[.!in_first]
    end

    # Scan every pixel; each bright, unvisited pixel seeds a new patch.
    for y in 1:height, x in 1:width
        if maximum(@view image[:, y, x]) > threshold && !visited[y, x]
            patch, too_large = flood_fill(y, x)
            if too_large
                for cluster in split_patch(patch)
                    # An empty cluster carries no pixel to propose from; drop it.
                    isempty(cluster) || push!(patches, cluster)
                end
            else
                push!(patches, patch)
            end
        end
    end

    return patches, length(patches)
end


# Data-driven proposal over the initial firefly configuration.
#
# Segments `img_array` (indexed [channel, y, x]) into colour patches with
# `find_color_patches` and proposes one firefly per non-empty patch:
#   * `:init => :n_fireflies` is fixed to the number of patches;
#   * for patch `n`, a pixel of the patch is chosen uniformly and its
#     coordinates / dominant channel are written to `:init => :init_x => n`,
#     `:init => :init_y => n` and `:init => :color => n`.
#
# The uniform pixel choice is expressed through traced choices only
# (x from its marginal over the patch, y given x, colour given x and y), so the
# proposal's score accounts for the 1/|patch| density of the selected pixel
# instead of hiding it in an untraced random draw.
#
# `max_fireflies` is accepted for interface compatibility but not used.
# NOTE(review): the number of patches may exceed `max_fireflies` or be zero;
# whether the model supports those values is not visible here — confirm.
@gen function feature_based_proposal(max_fireflies, img_array)
    _, height, width = size(img_array)

    threshold = 0.2        # luminance threshold
    size_prior = 30        # patches larger than this are split
    color_threshold = 0.7  # colour matching threshold
    patch_info, _ = find_color_patches(img_array, threshold, size_prior, color_threshold)

    # An empty patch has no pixel to propose from; skip it rather than crash.
    patch_info = filter(!isempty, patch_info)
    num_clusters = length(patch_info)

    {:init => :n_fireflies} ~ uniform_discrete(num_clusters, num_clusters)

    for n in 1:num_clusters
        patch = patch_info[n]

        # Marginal over x: proportional to the number of patch pixels in each column.
        x_opts = zeros(width)
        for p in patch
            x_opts[p[1]] += 1
        end
        x_opts ./= sum(x_opts)
        x = {:init => :init_x => n} ~ categorical(x_opts)

        # Conditional over y: uniform over the patch pixels that share the chosen x.
        y_opts = zeros(height)
        for p in patch
            if p[1] == x
                y_opts[p[2]] += 1
            end
        end
        y_opts ./= sum(y_opts)
        y = {:init => :init_y => n} ~ categorical(y_opts)

        # Colour: dominant channel of the chosen pixel.
        color_opts = zeros(3)
        for p in patch
            if p[1] == x && p[2] == y
                color_opts[p[3]] += 1
            end
        end
        color_opts ./= sum(color_opts)
        {:init => :color => n} ~ categorical(color_opts)
    end

    return nothing
end

In [None]:
# Experiment configuration. These globals are also read by later cells.
scene_size = 64
max_fireflies = 4
steps = 1

constraints = choicemap()

# Test case 1 - overlap different color
# Two fireflies placed two pixels apart diagonally, with colour indices 1 and 2.
gt_constraints = choicemap()
gt_constraints[:init => :n_fireflies] = 2
gt_constraints[:states => 1 => :x => 1] = 10
gt_constraints[:states => 1 => :y => 1] = 10
gt_constraints[:init => :color => 1] = 1

gt_constraints[:states => 1 => :x => 2] = 12
gt_constraints[:states => 1 => :y => 2] = 12
gt_constraints[:init => :color => 2] = 2

# Sample a ground-truth trace consistent with the constraints above.
gt_trace, w = generate(model, (scene_size, max_fireflies, steps,), gt_constraints)
choices = get_choices(gt_trace)

# Only the first observation is handed to inference; positions are kept for plotting.
constraints = choicemap()
constraints[:observations => 1] = choices[:observations => 1]
true_x = [choices[:states => 1 => :x => n] for n in 1:choices[:init => :n_fireflies]]
true_y = [choices[:states => 1 => :y => n] for n in 1:choices[:init => :n_fireflies]]

# Importance resampling with the feature-based proposal, using 10 samples.
tr, lml_est = importance_resampling(model, get_args(gt_trace), constraints, 
                                    feature_based_proposal, 
                                    (max_fireflies, choices[:observations => 1]), 10)

# Extract the inferred first-step positions and colours (not used further in this cell).
inferred_choices = get_choices(tr)
inf_xs = [inferred_choices[:states => 1 => :x => n] for n in 1:inferred_choices[:init => :n_fireflies]]
inf_ys = [inferred_choices[:states => 1 => :y => n] for n in 1:inferred_choices[:init => :n_fireflies]]
inf_color = [inferred_choices[:init => :color => n] for n in 1:inferred_choices[:init => :n_fireflies]]

println("Inferred n_fireflies: ", inferred_choices[:init => :n_fireflies])
println("GT n_fireflies: ", choices[:init => :n_fireflies])
# Plot the resampled trace and overlay the true positions as black crosses.
fig = visualize_importance_resampling(tr, choices[:observations => 1])
scatter!(fig[1], true_x, true_y, color="black", markershape=:x, markersize=5, label="True Locations")
display(fig)


# Test case 2 - overlap same color
# Same geometry as test case 1, but both fireflies share colour index 1.
gt_constraints = choicemap()
gt_constraints[:init => :n_fireflies] = 2
gt_constraints[:states => 1 => :x => 1] = 10
gt_constraints[:states => 1 => :y => 1] = 10
gt_constraints[:init => :color => 1] = 1
gt_constraints[:states => 1 => :x => 2] = 12
gt_constraints[:states => 1 => :y => 2] = 12
gt_constraints[:init => :color => 2] = 1

# `gt_trace` from this second case is the one left in scope for later cells.
gt_trace, w = generate(model, (scene_size, max_fireflies, steps,), gt_constraints)
choices = get_choices(gt_trace)

constraints = choicemap()
constraints[:observations => 1] = choices[:observations => 1]
true_x = [choices[:states => 1 => :x => n] for n in 1:choices[:init => :n_fireflies]]
true_y = [choices[:states => 1 => :y => n] for n in 1:choices[:init => :n_fireflies]]

# Importance resampling with the feature-based proposal, using 10 samples.
tr, lml_est = importance_resampling(model, get_args(gt_trace), constraints, 
                                    feature_based_proposal, 
                                    (max_fireflies, choices[:observations => 1]), 10)

inferred_choices = get_choices(tr)
println("Inferred n_fireflies: ", inferred_choices[:init => :n_fireflies])
println("GT n_fireflies: ", choices[:init => :n_fireflies])
# Extract the inferred first-step positions and colours (not used further in this cell).
inf_xs = [inferred_choices[:states => 1 => :x => n] for n in 1:inferred_choices[:init => :n_fireflies]]
inf_ys = [inferred_choices[:states => 1 => :y => n] for n in 1:inferred_choices[:init => :n_fireflies]]
inf_color = [inferred_choices[:init => :color => n] for n in 1:inferred_choices[:init => :n_fireflies]]

# Plot the resampled trace and overlay the true positions as black crosses.
fig = visualize_importance_resampling(tr, choices[:observations => 1])
scatter!(fig[1], true_x, true_y, color="black", markershape=:x, markersize=5, label="True Locations")
display(fig)

In [None]:
# SMC settings passed to `smc` below.
# NOTE(review): presumably `num_samples` is the number of traces `smc` returns — confirm in inference.jl.
num_particles = 1000
num_samples = 20

# Run SMC against the ground-truth trace; JSON recording is disabled and no experiment tag is set.
inferred_traces = smc(gt_trace, model, num_particles, num_samples; record_json=false, experiment_tag="");
# Animate the inferred traces against the ground truth and save at 10 fps.
anim = visualize_particles(inferred_traces, gt_trace)
gif(anim, "firefly_inference.mp4", fps = 10)

In [164]:
"""
    position_accuracy(gt_trace, inferred_traces)

Score each inferred trace by its summed squared position difference from the
ground truth, divided by `length` of the ground-truth `:xs` array.

Positions are read from the `:xs` / `:ys` entries of the first element of each
trace's return value and subtracted element-wise, with no object matching.

WARNING: this does not work when the ground truth and a prediction contain
different numbers of objects.

Returns a vector with one score per inferred trace.
"""
function position_accuracy(gt_trace, inferred_traces)
    gt_states, _ = get_retval(gt_trace)
    ref_xs, ref_ys = gt_states[:xs], gt_states[:ys]
    normaliser = length(ref_xs)

    results = []
    for particle in inferred_traces
        particle_states, _ = get_retval(particle)
        est_xs, est_ys = particle_states[:xs], particle_states[:ys]
        squared_x = sum((ref_xs .- est_xs).^2)
        squared_y = sum((ref_ys .- est_ys).^2)
        push!(results, (squared_x + squared_y) / normaliser)
    end
    return results
end

# `position_accuracy` does no object matching, so it is only evaluated in the
# single-firefly configuration; prints the mean score over all inferred traces.
if max_fireflies == 1
    scores = position_accuracy(gt_trace, inferred_traces)
    println(mean(scores))
end

In [None]:
"""
    optimal_position_error(gt_trace, inferred_traces; max_cost=100.)

Score each inferred trace against the ground truth after optimally matching
inferred objects to ground-truth objects with the Hungarian algorithm.

Positions are read from the `:xs` / `:ys` entries of the first element of each
trace's return value and indexed as `[object, step]`. The cost of pairing an
inferred object with a ground-truth object is the root-mean-square distance
over all steps, capped at `max_cost`. The cost matrix is padded to a square
with `max_cost` entries so unmatched objects pair with dummy slots.

# Returns
`(all_scores, all_misses, assignments)`, one entry per inferred trace:
- `all_scores`: sum over inferred objects of their matched cost (`max_cost`
  when matched to a dummy ground-truth slot);
- `all_misses`: count of entries among the first `num_gt_objects` assignment
  values that exceed the number of inferred objects;
- `assignments`: the assignment vector returned by `hungarian`.
"""
function optimal_position_error(gt_trace, inferred_traces; max_cost=100.)
    gt_states, _ = get_retval(gt_trace)
    truth_xs, truth_ys = gt_states[:xs], gt_states[:ys]
    num_gt_objects, steps = size(truth_xs)

    all_scores = Float64[]
    all_misses = Int[]
    assignments = []

    for particle in inferred_traces
        states, _ = get_retval(particle)
        est_xs, est_ys = states[:xs], states[:ys]
        num_inferred = size(est_xs, 1)

        # Square cost matrix (rows: inferred objects, columns: ground-truth
        # objects); padding entries keep the default `max_cost`.
        dim = max(num_gt_objects, num_inferred)
        cost_matrix = fill(max_cost, (dim, dim))
        for row in 1:num_inferred, col in 1:num_gt_objects
            dx = truth_xs[col, :] .- est_xs[row, :]
            dy = truth_ys[col, :] .- est_ys[row, :]
            rms_distance = sqrt(sum(dx.^2 .+ dy.^2) / steps)
            cost_matrix[row, col] = min(rms_distance, max_cost)
        end

        # assignment[row] is the column matched to `row`.
        assignment, _ = hungarian(cost_matrix)

        # Per-object score: matched cost, or `max_cost` for a dummy column.
        particle_scores = Float64[
            assignment[row] <= num_gt_objects ? cost_matrix[row, assignment[row]] : max_cost
            for row in 1:num_inferred
        ]

        missed_gt = count(>(num_inferred), assignment[1:num_gt_objects])

        push!(all_scores, sum(particle_scores))
        push!(all_misses, missed_gt)
        push!(assignments, assignment)
    end

    return all_scores, all_misses, assignments
end

# Score every inferred trace against the ground truth, capping pair costs at 64.
scores, misses, assignments = optimal_position_error(gt_trace, inferred_traces; max_cost=64.)

# Mean / standard deviation of the tracking error and of the missed-object count.
println(mean(scores), " ", std(scores), " ", mean(misses), " ", std(misses))

# Plots.jl sizes figures with `size=(width, height)` in pixels; `figsize` is a
# matplotlib keyword that Plots.jl does not recognise.
fig = plot(layout=(1, 2), size=(1000, 500))
scatter!(fig[1], scores, xlabel="Particle", ylabel="Error", title="Tracking Error per particle")
bar!(fig[2], misses, xlabel="Particle", ylabel="Missed objects", title="Missed objects per particle")

In [None]:
"""
    pixel_reconstruction_error(gt_trace, particles)

Compare the frames re-rendered from each particle's states with the
ground-truth observations stored in `gt_trace`.

The scene size and number of steps are taken from the first and third
arguments of `gt_trace`. For every particle and step the states are rendered
with `render!`, and an error map is computed with
`logpdfmap(image_likelihood, gt_frame, rendered, 0.01)`.

# Returns
`(l2_dists, errormaps, inferred_renders)`, one entry per particle:
- `l2_dists`: summed squared pixel difference over all steps, divided by
  `scene_size^2 * steps`;
- `errormaps`: per-step error maps;
- `inferred_renders`: per-step rendered frames.
"""
function pixel_reconstruction_error(gt_trace, particles)
    scene_size, _, steps = get_args(gt_trace)
    gt_choices = get_choices(gt_trace)
    gt_obs = map(t -> gt_choices[:observations => t], 1:steps)

    l2_dists, errormaps, inferred_renders = [], [], []
    for particle in particles
        states, _ = get_retval(particle)

        frame_errors = []
        frames = []
        for t in 1:steps
            rendered = render!(states, t, scene_size)
            push!(frame_errors, logpdfmap(image_likelihood, gt_obs[t], rendered, 0.01))
            push!(frames, rendered)
        end

        squared_error = sum([sum((gt_obs[t] .- frames[t]).^2) for t in 1:steps])

        push!(errormaps, frame_errors)
        push!(l2_dists, squared_error / (scene_size^2 * steps))
        push!(inferred_renders, frames)
    end

    return l2_dists, errormaps, inferred_renders
end

# Per-particle pixel errors, per-step error maps and re-rendered frames (`obses`).
l2_dists, errormaps, obses = pixel_reconstruction_error(gt_trace, inferred_traces)

In [None]:
# Animate the per-particle error maps, one animation frame per time step.
anim = @animate for t in 1:steps
    # 4x5 grid of subplots, one per particle (assumes num_samples <= 20).
    # `plot_title` is the Plots.jl attribute for a whole-figure title;
    # `suptitle` is a matplotlib keyword that Plots.jl does not recognise.
    fig = plot(layout = (4, 5), size=(1000, 800),
        plot_title="Reconstruction Error [Frame $t]")

    for p in 1:num_samples
        # Error map of particle p at step t, drawn into subplot p.
        error_map = errormaps[p][t]

        heatmap!(fig, error_map, color=:grays,
            axis=nothing, legend=false,
            aspect_ratio=:equal, subplot=p,
            colorbar=false,
            xticks=false, yticks=false,
            title="Particle $p")
    end
end

gif(anim, "reconstruction_error.mp4", fps = 10)

In [None]:
# Animate the frames re-rendered from each particle, one animation frame per time step.
anim = @animate for t in 1:steps
    # 4x5 grid of subplots, one per particle (assumes num_samples <= 20).
    # This cell shows the rendered frames, not the error maps, so it carries its
    # own title; `plot_title` is the Plots.jl whole-figure title attribute.
    fig = plot(layout = (4, 5), size=(1000, 800),
        plot_title="Inferred Render [Frame $t]")

    for p in 1:num_samples
        # Rendered frame of particle p at step t, converted for display.
        img = mat_to_img(obses[p][t])
        heatmap!(fig, img, axis=nothing, legend=false,
            aspect_ratio=:equal, subplot=p,
            colorbar=false,
            xticks=false, yticks=false,
            title="Particle $p")
    end
end

# Written to its own file so it does not overwrite the error-map animation
# saved as "reconstruction_error.mp4".
gif(anim, "inferred_renders.mp4", fps = 10)