In [118]:
using Revise
using Plots
using Hungarian
using Gen
using Statistics

includet("./distribution_utils.jl")
includet("./model.jl")
includet("./visualizations.jl")
includet("./inference.jl")

In [None]:
"""
    visualize_importance_samples(traces, constraint_img, log_weights, scene_size)

Plot the fireflies of every importance sample (left panel) next to the observed
frame (right panel).

# Arguments
- `traces`: Gen traces whose choices hold `:init => :n_fireflies`,
  `:init => :init_x => n`, `:init => :init_y => n` and `:init => :color => n`
  (a colour index into red/green/blue).
- `constraint_img`: the observed frame, converted for display with `mat_to_img`.
- `log_weights`: one weight per trace, used as marker opacity. Despite the name,
  these must be on a linear scale (the commented-out caller passes
  `exp.(log_norm_weight)`). `0.001` is added so zero-weight samples stay faintly
  visible, and the result is clamped to the valid opacity range `[0, 1]`.
- `scene_size`: scene side length; both panels span `0:scene_size + 1`.

# Returns
The two-panel `Plots` figure.
"""
function visualize_importance_samples(traces, constraint_img, log_weights, scene_size)
    lims = [0, scene_size + 1]
    color_opts = ["red", "green", "blue"]

    fig = plot(layout = grid(1, 2), size = (400, 400), legend = false)
    frame = mat_to_img(constraint_img)
    Plots.heatmap!(fig[2], frame, title = "Frame", aspect_ratio = :equal,
                   xlims = lims, ylims = lims)

    for t in eachindex(traces)
        choices = get_choices(traces[t])
        n_fireflies = choices[:init => :n_fireflies]
        xs = [choices[:init => :init_x => n] for n in 1:n_fireflies]
        ys = [choices[:init => :init_y => n] for n in 1:n_fireflies]
        colors = [color_opts[choices[:init => :color => n]] for n in 1:n_fireflies]
        # Opacity outside [0, 1] is invalid; a weight near 1 plus the floor would exceed it.
        alpha = clamp(log_weights[t] + 0.001, 0.0, 1.0)
        scatter!(fig[1], xs, ys, color = colors, markersize = 5, alpha = alpha,
                 aspect_ratio = :equal, xlims = lims, ylims = lims, yflip = true,
                 label = "Importance Samples")
    end

    return fig
end

"""
    visualize_importance_resampling(trace, constraint_img; scene_size = first(get_args(trace)))

Plot the fireflies of a single (resampled) trace on the left, next to the observed
frame on the right.

# Arguments
- `trace`: Gen trace with choices `:init => :n_fireflies`, `:init => :init_x => n`,
  `:init => :init_y => n` and `:init => :color => n` (a colour index into
  red/green/blue).
- `constraint_img`: the observed frame, converted for display with `mat_to_img`.

# Keywords
- `scene_size`: scene side length used for the axis limits `0:scene_size + 1`.
  Defaults to the trace's first model argument, which for traces of `model`
  (called with `(scene_size, max_fireflies, steps)`) is the scene size. This avoids
  depending on a notebook global.

# Returns
The two-panel figure. Also prints the trace's firefly count.
"""
function visualize_importance_resampling(trace, constraint_img;
                                         scene_size = first(get_args(trace)))
    lims = [0, scene_size + 1]
    color_opts = ["red", "green", "blue"]

    fig = plot(layout = grid(1, 2), size = (400, 400), legend = false)
    frame = mat_to_img(constraint_img)
    Plots.heatmap!(fig[2], frame, title = "Frame", aspect_ratio = :equal,
                   xlims = lims, ylims = lims)

    choices = get_choices(trace)
    n_fireflies = choices[:init => :n_fireflies]
    println("n_fireflies: ", n_fireflies)
    xs = [choices[:init => :init_x => n] for n in 1:n_fireflies]
    ys = [choices[:init => :init_y => n] for n in 1:n_fireflies]
    colors = [color_opts[choices[:init => :color => n]] for n in 1:n_fireflies]
    scatter!(fig[1], xs, ys, color = colors, markersize = 5, aspect_ratio = :equal,
             xlims = lims, ylims = lims, yflip = true, label = "Importance Samples")
    return fig
end

In [133]:
# test = "importance_resampling"
# if test == "importance_resampling"
#     println("Running importance resampling")
#     out, lml_est = importance_resampling(model, get_args(gt_trace), obs, 1000)
#     println("Visualizing....")
#     fig = visualize_importance_resampling(out, choices[:observations => 1])
#     scatter!(fig[1], true_x, true_y, color="black", markershape=:x, markersize=5, label="True Location")
#     display(fig)
# elseif test == "importance_sampling"
#     println("Running importance sampling")
#     out, log_norm_weight, lml_est = importance_sampling(model, get_args(gt_trace), obs, 100)
#     println("Visualizing....")
#     fig = visualize_importance_samples(out, choices[:observations => 1], exp.(log_norm_weight))
#     scatter!(fig[1], true_x, true_y, color="black", markershape=:x, markersize=5, label="True Location")
#     display(fig)
# end

In [None]:
"""
    identify_color_patches(image, threshold) -> (labels, num_patches)

Label the 8-connected regions of `image` whose values are strictly greater than
`threshold`.

# Arguments
- `image`: any real-valued matrix (e.g. a summed-intensity map).
- `threshold`: pixels with values `<= threshold` are background.

# Returns
- `labels`: an `Int` matrix the same size as `image`. Background pixels are `0`,
  and every pixel of the k-th patch is `k`.
- `num_patches`: the number of patches found.

Patches are numbered in the order their first pixel is reached when scanning rows
top to bottom and, within a row, columns left to right.
"""
function identify_color_patches(image::AbstractMatrix{<:Real}, threshold::Real)
    height, width = size(image)
    labels = zeros(Int, height, width)
    num_patches = 0
    # Explicit work stack rather than recursion: a large patch would otherwise
    # recurse once per pixel and risk overflowing the call stack.
    stack = Tuple{Int,Int}[]

    for i in 1:height, j in 1:width
        (image[i, j] > threshold && labels[i, j] == 0) || continue
        num_patches += 1
        labels[i, j] = num_patches
        push!(stack, (i, j))
        while !isempty(stack)
            ci, cj = pop!(stack)
            # Visit all 8 neighbours (the (0, 0) offset is already labelled).
            for di in -1:1, dj in -1:1
                ni, nj = ci + di, cj + dj
                (1 <= ni <= height && 1 <= nj <= width) || continue
                if labels[ni, nj] == 0 && image[ni, nj] > threshold
                    labels[ni, nj] = num_patches
                    push!(stack, (ni, nj))
                end
            end
        end
    end

    return labels, num_patches
end


"""
    get_patch_info(image, labels, num_clusters)

Collect the pixels of each labelled patch, together with each pixel's dominant
colour channel.

# Arguments
- `image`: colour array indexed `image[channel, y, x]`.
- `labels`: patch-label matrix indexed `labels[y, x]`, as returned by
  `identify_color_patches`; `0` marks background.
- `num_clusters`: number of patches; only labels in `1:num_clusters` are collected.

# Returns
A `Dict` mapping each label in `1:num_clusters` to a vector of `(x, y, color)` named
tuples. `color` is the index of the brightest channel at that pixel; ties go to the
lowest index, as with `argmax`. Pixels within a patch appear in column-major order.
"""
function get_patch_info(image::AbstractArray{<:Real,3}, labels::AbstractMatrix{<:Integer},
                        num_clusters::Integer)
    PatchPixel = NamedTuple{(:x, :y, :color), Tuple{Int,Int,Int}}
    patch_info = Dict{Int, Vector{PatchPixel}}(
        label => PatchPixel[] for label in 1:num_clusters)

    # A single pass over the label image, instead of one full `findall` scan per cluster.
    for idx in CartesianIndices(labels)
        label = labels[idx]
        1 <= label <= num_clusters || continue
        y, x = Tuple(idx)
        color = argmax(@view image[:, y, x])
        push!(patch_info[label], (x = x, y = y, color = color))
    end

    return patch_info
end


# Data-driven Gen proposal for the `:init` choices. It finds bright 8-connected
# patches in the observed frame and proposes exactly one firefly per patch. Each
# firefly is placed at a pixel drawn uniformly from its patch and coloured by that
# pixel's brightest channel.
#
# Arguments
#   max_fireflies: not read anywhere in this body (presumably kept so the argument
#                  tuple mirrors the caller's — confirm). NOTE(review): the proposed
#                  count equals the number of patches and is NOT capped at
#                  max_fireflies; confirm the model supports that many.
#   observation:   the observed frame. The size of its third dimension is used as
#                  the side length for both x and y (presumably laid out as
#                  (channel, y, x) with a square scene — TODO confirm).
@gen function feature_based_proposal(max_fireflies, observation)
    _, _, scene_size = size(observation)

    # Convert image to a plain Float64 array (get_patch_info indexes it as [channel, y, x])
    img_array = Float64.(channelview(observation))
    
    # Sum across color channels and drop the channel axis, leaving a 2D intensity map
    intensity = dropdims(sum(img_array, dims=1), dims=1)
    
    # Pixels whose summed intensity is strictly above this belong to a patch
    threshold = 0.2
    labels, num_clusters = identify_color_patches(intensity, threshold)
    patch_info = get_patch_info(img_array, labels, num_clusters)

    # Degenerate uniform_discrete(k, k): the firefly count is fixed to the patch count
    n_fireflies = {:init => :n_fireflies} ~ uniform_discrete(num_clusters, num_clusters)
    for n in 1:num_clusters
        patch = patch_info[n]
        # NOTE(review): plain call, not a traced `~` choice, so this draw's probability
        # is not part of the proposal score. Presumably deliberate because the model
        # has no matching address, but importance weights then ignore this randomness.
        patch_index = uniform_discrete(1, length(patch))
        # One-hot probability vectors: each categorical below deterministically picks
        # the chosen pixel's x, y and colour index, recorded at the model's addresses.
        x_opts = zeros(scene_size)
        y_opts = zeros(scene_size)
        color_opts = zeros(3)
        x_opts[patch[patch_index].x] = 1
        y_opts[patch[patch_index].y] = 1
        color_opts[patch[patch_index].color] = 1

        x = {:init => :init_x => n} ~ categorical(x_opts)
        y = {:init => :init_y => n} ~ categorical(y_opts)
        color = {:init => :color => n} ~ categorical(color_opts)
    end
end


In [None]:
"""
    animate_obs(choices)

Build an animation with one heatmap frame per observation
`choices[:observations => t]`, for `t = 1:steps`.

Reads the notebook-global `steps` for the number of frames.

# Returns
The `Plots.Animation`; pass it to `gif`/`mp4` to save it. The previous version
returned an empty placeholder plot and discarded the animation.
"""
function animate_obs(choices)
    frames = [mat_to_img(choices[:observations => t]) for t in 1:steps]
    anim = Plots.@animate for t in 1:steps
        # A fresh figure per frame; @animate captures the current plot.
        fig = plot(layout = grid(1, 1), legend = false, size = (400, 400))
        Plots.heatmap!(fig, frames[t], title = "Frame $t")
    end
    return anim
end

In [None]:
scene_size = 64
max_fireflies = 4
steps = 1

# Run 20 independent trials. Each trial simulates a ground-truth scene, infers its
# initial fireflies from the first frame using the feature-based proposal, and plots
# the inferred fireflies (left) with the true positions overlaid as black crosses.
# NOTE(review): in a notebook (soft scope), assigning `gt_trace` inside the loop only
# updates a global if one already exists. Later cells read `gt_trace`, so confirm it
# is defined globally.
for _ in 1:20
    gt_trace = simulate(model, (scene_size, max_fireflies, steps,))
    choices = get_choices(gt_trace)
    n_true = choices[:init => :n_fireflies]

    # Condition inference on this trial's own observation.
    constraints = choicemap()
    constraints[:observations => 1] = choices[:observations => 1]
    true_x = [choices[:init => :init_x => n] for n in 1:n_true]
    true_y = [choices[:init => :init_y => n] for n in 1:n_true]

    tr, lml_est = importance_resampling(model, get_args(gt_trace), constraints,
                                        feature_based_proposal,
                                        (max_fireflies, choices[:observations => 1]), 10)

    fig = visualize_importance_resampling(tr, choices[:observations => 1])
    scatter!(fig[1], true_x, true_y, color="black", markershape=:x, markersize=5,
             label="True Locations")
    display(fig)
end

In [None]:
# SMC settings. `num_samples` is also read by the reconstruction-error cells, which
# lay particles out in a 4x5 grid (so they assume at most 20).
num_particles = 1000
num_samples = 20

# Run SMC against the ground-truth trace; presumably returns `num_samples` particle
# traces (later cells index them 1:num_samples) — confirm against inference.jl.
inferred_traces = smc(gt_trace, model, num_particles, num_samples; record_json=false, experiment_tag="");
anim = visualize_particles(inferred_traces, gt_trace)
# NOTE(review): `gif` is given an .mp4 filename; Plots also provides `mp4(anim, fn)`.
# Confirm which output format is intended.
gif(anim, "firefly_inference.mp4", fps = 10)

In [164]:
"""
    position_accuracy(gt_trace, inferred_traces)

Compute a per-particle squared position error against the ground truth.

Positions come from `get_retval(trace)[1][:xs]` and `[:ys]`, stored as
`[object, step]` arrays. For each inferred trace, the result is the sum over all
objects and steps of the squared x and y differences, divided by the number of
steps (the last array dimension).

Only valid when each inferred trace has exactly the ground truth's array shape,
i.e. the same object count. Otherwise a `DimensionMismatch` is thrown; plain
broadcasting would silently expand a one-object array against a multi-object one
and return a meaningless score. Use `optimal_position_error` when object counts can
differ.

# Returns
A `Vector{Float64}` with one score per inferred trace.
"""
function position_accuracy(gt_trace, inferred_traces)
    gt_states, _ = get_retval(gt_trace)
    gt_xs = gt_states[:xs]
    gt_ys = gt_states[:ys]
    # Steps are the last dimension; `length` would count objects × steps.
    num_steps = size(gt_xs, ndims(gt_xs))

    scores = Float64[]
    for inferred_trace in inferred_traces
        states, _ = get_retval(inferred_trace)
        inferred_xs = states[:xs]
        inferred_ys = states[:ys]
        if size(inferred_xs) != size(gt_xs) || size(inferred_ys) != size(gt_ys)
            throw(DimensionMismatch(
                "inferred positions have size $(size(inferred_xs)) but ground truth " *
                "has size $(size(gt_xs)); use optimal_position_error instead"))
        end
        sq_err = sum(abs2, gt_xs .- inferred_xs) + sum(abs2, gt_ys .- inferred_ys)
        push!(scores, sq_err / num_steps)
    end
    return scores
end

# Only compute the simple per-particle metric for single-firefly scenes:
# `position_accuracy` requires the ground truth and every inferred trace to have the
# same number of objects.
if max_fireflies == 1
    scores = position_accuracy(gt_trace, inferred_traces)
    println(mean(scores))
end

In [None]:
"""
    optimal_position_error(gt_trace, inferred_traces; max_cost=100.)

Score each inferred trace against the ground truth after optimally matching inferred
objects to ground-truth objects with the Hungarian algorithm.

Positions are read from `get_retval(trace)[1][:xs]` and `[:ys]`, stored as
`[object, step]` arrays. The cost of matching an inferred object to a ground-truth
object is their RMS distance over time, capped at `max_cost`. The cost matrix is
padded to a square with `max_cost` dummy entries so unequal object counts can still
be matched.

# Returns
`(scores, misses, assignments)`, with one entry per inferred trace:
- `scores`: sum over the *inferred* objects of their matched cost. An inferred object
  matched to a dummy column (a false positive) costs `max_cost`. Missed ground-truth
  objects are not added here; they are counted in `misses`.
- `misses`: number of ground-truth objects matched to a dummy row, i.e. with no
  inferred counterpart.
- `assignments`: the raw assignment from `hungarian`. Row = inferred object, value =
  ground-truth column; indices beyond the real counts are dummies. Empty when
  neither side has objects.
"""
function optimal_position_error(gt_trace, inferred_traces; max_cost=100.)
    gt_states, _ = get_retval(gt_trace)
    gt_xs = gt_states[:xs]
    gt_ys = gt_states[:ys]
    num_gt_objects, steps = size(gt_xs)

    all_scores = Float64[]
    all_misses = Int[]
    assignments = Vector{Int}[]

    for inferred_trace in inferred_traces
        states, _ = get_retval(inferred_trace)
        inferred_xs = states[:xs]
        inferred_ys = states[:ys]
        num_inferred_objects = size(inferred_xs, 1)
        max_objects = max(num_gt_objects, num_inferred_objects)

        if max_objects == 0
            # Nothing to match on either side: no error and no misses.
            push!(all_scores, 0.0)
            push!(all_misses, 0)
            push!(assignments, Int[])
            continue
        end

        # `float` so an integer `max_cost` does not make the matrix Int-typed
        # (storing Float distances into it would throw InexactError).
        cost_matrix = fill(float(max_cost), max_objects, max_objects)
        @views for inf_obj in 1:num_inferred_objects, gt_obj in 1:num_gt_objects
            squared_distances = (gt_xs[gt_obj, :] .- inferred_xs[inf_obj, :]) .^ 2 .+
                                (gt_ys[gt_obj, :] .- inferred_ys[inf_obj, :]) .^ 2
            average_distance = sqrt(sum(squared_distances) / steps)
            cost_matrix[inf_obj, gt_obj] = min(average_distance, max_cost)
        end

        assignment, _ = hungarian(cost_matrix)

        # False positives (real rows matched to dummy columns) cost `max_cost`.
        particle_score = 0.0
        for inf_obj in 1:num_inferred_objects
            gt_obj = assignment[inf_obj]
            particle_score += 1 <= gt_obj <= num_gt_objects ?
                              cost_matrix[inf_obj, gt_obj] : float(max_cost)
        end

        # Dummy rows (beyond the real inferred objects) that landed on a real
        # ground-truth column are ground-truth objects nobody matched.
        missed_gt = count(r -> 1 <= assignment[r] <= num_gt_objects,
                          num_inferred_objects+1:max_objects)

        push!(all_scores, particle_score)
        push!(all_misses, missed_gt)
        push!(assignments, assignment)
    end

    return all_scores, all_misses, assignments
end

# Optimal-matching error per particle; costs are capped at 64 (the scene size used above).
scores, misses, assignments = optimal_position_error(gt_trace, inferred_traces; max_cost=64.)

# Mean and standard deviation of per-particle error and of missed objects.
println(mean(scores), " ", std(scores), " ", mean(misses), " ", std(misses))

# Plots sizes are in pixels via `size`; `figsize` is not a Plots attribute and was ignored.
fig = plot(layout=(1, 2), size=(1000, 500))
scatter!(fig[1], scores, xlabel="Particle", ylabel="Error", title="Tracking Error per particle")
bar!(fig[2], misses, xlabel="Particle", ylabel="Missed objects", title="Missed objects per particle")

In [None]:
"""
    pixel_reconstruction_error(gt_trace, particles)

Compare each particle's rendered frames with the ground-truth observations.

For every particle and time step, the frame is rendered with
`render!(states, t, scene_size)` and scored with
`logpdfmap(image_likelihood, observed, rendered, 0.01)`.

# Returns
`(l2_dists, errormaps, inferred_renders)`, with one entry per particle:
- `l2_dists[p]`: sum of squared differences between observed and rendered frames
  over all steps, divided by `scene_size^2 * steps`. NOTE(review): if frames carry a
  channel axis, the divisor does not include it — confirm the intended normalisation.
- `errormaps[p][t]`: the `logpdfmap` result for step `t`.
- `inferred_renders[p][t]`: the rendered frame for step `t`.
"""
function pixel_reconstruction_error(gt_trace, particles)
    scene_size, _, steps = get_args(gt_trace)
    gt_choices = get_choices(gt_trace)
    gt_obs = [gt_choices[:observations => t] for t in 1:steps]
    normaliser = scene_size^2 * steps

    l2_dists = Float64[]
    # Element types depend on render!/logpdfmap, which are defined elsewhere.
    errormaps = []
    inferred_renders = []
    for particle in particles
        states, _ = get_retval(particle)
        particle_errormap = []
        particle_obs = []
        # Render and score each step immediately, in time order.
        for t in 1:steps
            rendered = render!(states, t, scene_size)
            push!(particle_errormap, logpdfmap(image_likelihood, gt_obs[t], rendered, 0.01))
            push!(particle_obs, rendered)
        end
        sq_err = sum(sum(abs2, gt_obs[t] .- particle_obs[t]) for t in 1:steps)
        push!(l2_dists, sq_err / normaliser)
        push!(errormaps, particle_errormap)
        push!(inferred_renders, particle_obs)
    end
    return l2_dists, errormaps, inferred_renders
end

# Per-particle pixel L2 errors, per-step log-density error maps and rendered frames
# (`obses`), used by the animation cells that follow.
l2_dists, errormaps, obses = pixel_reconstruction_error(gt_trace, inferred_traces)

In [None]:
# Animate the per-pixel error maps (`logpdfmap` output) of particles 1:num_samples,
# one frame per time step.
anim = @animate for t in 1:steps
    # 4x5 grid with one subplot per particle, so at most 20 particles fit.
    # `plot_title` is Plots' figure-level title; `suptitle` is not a Plots attribute.
    fig = plot(layout = (4, 5), size = (1000, 800),
               plot_title = "Reconstruction Error [Frame $t]")
    for p in 1:num_samples
        # Named `errmap`, not `error`: at notebook top level (soft scope), assigning
        # `error` may target the imported `Base.error`.
        errmap = errormaps[p][t]
        heatmap!(errmap, color = :grays,
                 axis = nothing, legend = false,
                 aspect_ratio = :equal, subplot = p,
                 colorbar = false,
                 xticks = false, yticks = false,
                 title = "Particle $p")
    end
end

gif(anim, "reconstruction_error.mp4", fps = 10)

In [None]:
# Animate the rendered reconstructions of particles 1:num_samples, one frame per
# time step.
anim = @animate for t in 1:steps
    # 4x5 grid with one subplot per particle, so at most 20 particles fit.
    fig = plot(layout = (4, 5), size = (1000, 800),
               plot_title = "Reconstructions [Frame $t]")
    for p in 1:num_samples
        img = mat_to_img(obses[p][t])  # rendered frame t of particle p
        heatmap!(img, axis = nothing, legend = false,
                 aspect_ratio = :equal, subplot = p,
                 colorbar = false,
                 xticks = false, yticks = false,
                 title = "Particle $p")
    end
end

# Use a separate file so this does not overwrite the error-map animation.
gif(anim, "reconstructions.mp4", fps = 10)