In the "Cell" menu, select "Run All" to execute every cell in order.

## Activate Package
This takes care of all the dependencies (other pieces of software) that we rely on.

In [None]:
import Pkg; Pkg.activate("."); Pkg.instantiate()

In [None]:
using Revise  # Development tool - must come first, otherwise don't worry about it!

## Custom Mesh Package
This package provides a simple wrapper for some of the functionality of [p4est](https://p4est.org), a powerful, low-level library for massively-parallel adaptive mesh refinement (AMR).

In [None]:
using SimpleTreeMeshes

## Helpers
Other packages to help with plotting, printing, etc.

In [None]:
using Plots
using Printf
using Images, FileIO

## An image used to drive the simulation

In [None]:
# Choose the driving image (swap in the commented line to use the alternative).
#image_name = "images/prism_square.png"
image_name = "images/scribble.png"

# Load as grayscale and keep the largest square anchored at the first pixel.
img = Gray.(load(image_name))
sz = size(img)
mask_size = minimum(sz[1:2])
img2 = img[1:mask_size, 1:mask_size]

# Numeric copy of the cropped image, then a clockwise quarter turn so that
# mask[i, j] == mask_r[mask_size - j + 1, i]. The result is stored as Float64.
mask_r = convert(Array{Float16}, img2)
mask = Float64.(rotr90(mask_r))
display(img2)

## Defining a Physical Problem

In [None]:
# Bundle of everything that describes one physical problem.
# Fields are filled positionally (see `problem_Image`), so their order matters.
struct Problem
  # Label for the problem (`problem_Image` uses "Image").
  name::String
  # Boundary-condition selector (`problem_Image` passes `stokes_boundary_free_slip`).
  boundary_conditions::StokesBoundaryConditions
  # Coefficient field as a function of (x, y); presumably the viscosity — TODO confirm.
  eta::Function
  # Bounds on `eta` and a reference value (`problem_Image` uses sqrt(eta_min * eta_max)).
  eta_min::Float64
  eta_max::Float64
  eta_ref::Float64
  # Right-hand-side functions of (x, y); `problem_Image` passes its f_vx, f_vy and f_p here.
  f_x::Function
  f_y::Function
  f_p::Function
  # Reference fields as functions of (x, y); presumably used for comparison — TODO confirm.
  vx_ref::Function
  vy_ref::Function
  p_ref::Function
  # Scale and offsets of the coordinate system (`problem_Image` uses 1.0, 0.0, 0.0).
  coordinate_scale::Float64
  coordinate_offset_x::Float64
  coordinate_offset_y::Float64
end

In [None]:
"""
    image_value(x, y)

Return the entry of the global `mask` that contains the point `(x, y)`.

Both coordinates must lie in `[0, 1]`. Along each axis, pixel `k` (1-based) covers
`[(k - 1) / mask_size, k / mask_size)`; a coordinate of exactly `1` maps to the last
pixel.
"""
function image_value(x, y)
    @assert 0 <= x <= 1
    @assert 0 <= y <= 1
    # floor(...) is a 0-based pixel number, so shift by one for Julia's 1-based arrays.
    # The clamp keeps the upper edge (coordinate == 1) inside the array.
    i = clamp(floor(Int, x * mask_size) + 1, 1, mask_size)
    j = clamp(floor(Int, y * mask_size) + 1, 1, mask_size)
    return mask[i, j]
end;

"""
    problem_Image()

Build the `Problem` named "Image": a constant unit `eta`, free-slip boundary
conditions, a y-forcing of `-10 * (1 - image_value(x, y))`, and zero for every other
forcing and reference field. The coordinate scale is 1 with zero offsets.
"""
function problem_Image()
    # eta is constant, so both bounds and their geometric mean coincide.
    eta_min, eta_max = 1.0, 1.0
    eta_ref = sqrt(eta_min * eta_max)
    eta = (x, y) -> 1.0

    # Only the y-forcing is non-zero; it is driven by the image.
    f_vx = (x, y) -> 0.0
    f_vy = (x, y) -> -10.0 * (1.0 - image_value(x, y))
    f_p = (x, y) -> 0.0

    # Reference fields are identically zero.
    vx_ref = (x, y) -> 0.0
    vy_ref = (x, y) -> 0.0
    p_ref = (x, y) -> 0.0

    return Problem(
        "Image", stokes_boundary_free_slip,
        eta, eta_min, eta_max, eta_ref,
        f_vx, f_vy, f_p,
        vx_ref, vy_ref, p_ref,
        1.0, 0.0, 0.0,
    )
end;

## Refinement function

In [None]:
# Level limits read by `refinement_function`.
min_level = 1;
max_level = 6

"""
    refinement_function(x, y, d, level)::Bool

Decide whether the square with corner `(x, y)` and side `d` at `level` is refined.

Squares at or below `min_level` are always refined and squares at or above
`max_level` never are. In between, the square is refined when the values of the
global `mask` inside its pixel window spread over more than 0.2.
"""
function refinement_function(x, y, d, level)::Bool
    # The level limits take precedence over the image-based criterion.
    level <= min_level && return true
    level >= max_level && return false

    # Map a coordinate to a pixel index, never below 1.
    pixel(t) = max(1, Int(ceil(t * (mask_size - 1))))
    rows = pixel(x):pixel(x + d)
    cols = pixel(y):pixel(y + d)

    # Running extrema of the window, seeded with 0 and 1.0.
    max_local = 0
    min_local = 1.0
    for j in cols, i in rows
        val = mask[i, j]
        max_local = val > max_local ? val : max_local
        min_local = val < min_local ? val : min_local
    end

    return max_local - min_local > 0.2
end;

## Create our tree object

In [None]:
# Build the mesh by handing `refinement_function` to `CreateSimpleTreeMesh`.
# The trailing semicolon suppresses the cell output.
tree = CreateSimpleTreeMesh(refinement_function);

## Plot to see the grid

In [None]:
# Time the whole cell.
start = time()
plot_size = 600
# Square canvas with the axes hidden.
p_grid = plot(axis=nothing, xaxis=false, yaxis=false, size=(plot_size, plot_size), aspect_ratio=1)
# Draw the grid in mid-gray, then overlay the face, element and corner numbers.
plot_grid!(p_grid, tree, RGB(0.5, 0.5, 0.5))
plot_face_numbers!(p_grid, tree)
plot_element_numbers!(p_grid, tree, offset=tree.nf) # Number after faces, as in our Stokes system
plot_corner_numbers!(p_grid, tree) # A separate numbering system
display(p_grid)
# Report the elapsed wall-clock time in seconds.
@printf "⏲️ %2.2f s\n" time()- start

## Solve the Stokes system

In [None]:
# Build the mesh again, this time measuring how long construction takes.
start = time()
tree = CreateSimpleTreeMesh(refinement_function)
# Report the element, face and corner counts of the mesh, then the elapsed time.
@printf "⚙️ %d elements, %d faces, %d corners " tree.ne tree.nf tree.nc
@printf "⏲️ %2.2f s\n" time() - start

In [None]:
# Time assembly plus solve.
start = time()
# 2^(-max_level); presumably the side length of a cell at the finest level — TODO confirm.
h_min = 2.0^(-(max_level))
problem = problem_Image()
# Assemble the linear system for this mesh and problem.
A, b, Kcont = assemble_system(tree, problem, h_min)
# Solve A * x = b with the backslash operator.
x = A\b
# Split the solution vector into `v` and `p`; presumably velocity and pressure — TODO confirm.
v, p = sol2vp(tree, x, Kcont)
@printf "⏲️ %2.2f s\n" time() - start

In [None]:
# Time the whole cell.
start = time()
plot_size = 800
# Square canvas with the axes hidden.
p_stokes = plot(axis=nothing, xaxis=false, yaxis=false, size=(plot_size, plot_size), aspect_ratio=1)
# Per-element field `p` first, then the grid and the averaged velocity field in dark gray on top.
plot_element_field!(p_stokes, tree, p)
c = RGB(0.25, 0.25, 0.25)
plot_grid!(p_stokes, tree, c)
plot_averaged_velocity_field!(p_stokes, tree, v, color=c)
display(p_stokes)
# Report the elapsed wall-clock time in seconds.
@printf "⏲️ %2.2f s\n" time() - start

## Push a particle

In [None]:
"""
    test_element(tree, e_id, xp, yp)

Return `true` when the point `(xp, yp)` lies inside element `e_id` of `tree`, edges
included. The element is the square spanned by its corner coordinates and its size.
"""
function test_element(tree, e_id, xp, yp)
    @assert 1 <= e_id <= tree.ne
    x0, y0 = get_element_corner_coordinates(tree, e_id)
    side = get_element_size(tree, e_id)
    inside_x = x0 <= xp <= x0 + side
    inside_y = y0 <= yp <= y0 + side
    return inside_x && inside_y
end

"""
    locate_point(tree, x, y, e_id_guess)

Return the id of an element of `tree` that contains the point `(x, y)`.

The guess is tested first. Otherwise the element ids are scanned outwards from the
guess, trying the lower id before the higher one at each distance. An error is raised
when a coordinate lies outside the domain described by the tree's coordinate offset
and scale.
"""
function locate_point(tree, x, y, e_id_guess)
    # Fast path: the guess already contains the point.
    test_element(tree, e_id_guess, x, y) && return e_id_guess

    # Reject points outside the domain before scanning.
    x_lo = tree.coordinate_offset_x
    x_hi = tree.coordinate_offset_x + tree.coordinate_scale
    (x < x_lo || x > x_hi) && error("x coordinate out of bounds")

    y_lo = tree.coordinate_offset_y
    y_hi = tree.coordinate_offset_y + tree.coordinate_scale
    (y < y_lo || y > y_hi) && error("y coordinate out of bounds")

    # Scan outwards until both ends of the id range have been passed.
    farthest = max(e_id_guess - 1, tree.ne - e_id_guess)
    for k in 1:farthest
        for candidate in (e_id_guess - k, e_id_guess + k)
            if 1 <= candidate <= tree.ne && test_element(tree, candidate, x, y)
                return candidate
            end
        end
    end
    # Every id was tried and none contained the point.
    @assert false
end;

In [None]:
"""
    get_point_velocity(tree::SimpleTreeMesh, xp, yp, e_id, v)

Return `(vx, vy)` at the point `(xp, yp)`, which must lie inside element `e_id`.

`v` holds one value per face of `tree`. `vx` blends the LEFT and RIGHT face values
linearly in x, and `vy` blends the DOWN and UP face values linearly in y.
"""
function get_point_velocity(tree::SimpleTreeMesh, xp, yp, e_id, v)
    @assert 1 <= e_id <= tree.ne
    @assert length(v) == tree.nf
    x0, y0 = get_element_corner_coordinates(tree, e_id)
    side = get_element_size(tree, e_id)
    @assert (x0 <= xp <= x0 + side) && (y0 <= yp <= y0 + side)

    # Linear blend of the values of `v` on two opposite faces; t is the local coordinate.
    blend(first_face, second_face, t) =
        (1.0 - t) * v[tree.e2f[first_face, e_id]] + t * v[tree.e2f[second_face, e_id]]

    vx = blend(LEFT, RIGHT, (xp - x0) / side)
    vy = blend(DOWN, UP, (yp - y0) / side)

    return vx, vy
end;

In [None]:
"""
    create_trace(tree, x0, y0, v, dt=1.0, nsteps=70)

Advect a particle from `(x0, y0)` through the face field `v` and return the visited
positions as `(xs, ys)`, two `Vector{Float64}` of length `nsteps`.

Each step records the current position, locates the containing element, evaluates
the velocity there, and advances by `dt` times that velocity. The new position is
clamped to the domain described by the tree's coordinate offset and scale. The first
recorded position is the start point; the position after the last step is not
recorded.

# Arguments
- `dt`: step size. The default is 1.0, the value that was previously always used
  regardless of what was passed.
- `nsteps`: number of recorded positions.
"""
function create_trace(tree, x0, y0, v, dt=1.0, nsteps=70)
    x_lo = tree.coordinate_offset_x
    x_hi = tree.coordinate_offset_x + tree.coordinate_scale
    y_lo = tree.coordinate_offset_y
    y_hi = tree.coordinate_offset_y + tree.coordinate_scale

    # Start from floating-point values so the position never changes type mid-loop.
    particle_x, particle_y = float(x0), float(y0)
    particle_xs = Float64[]
    particle_ys = Float64[]

    # Initial guess for the search; at least 1 so that a one-element mesh is valid.
    e_id = max(1, tree.ne ÷ 2)
    for step in 1:nsteps
        push!(particle_xs, particle_x)
        push!(particle_ys, particle_y)
        # The previous element is the guess for the next search.
        e_id = locate_point(tree, particle_x, particle_y, e_id)
        vx, vy = get_point_velocity(tree, particle_x, particle_y, e_id, v)
        # Keep the particle inside the domain so the next search cannot fail.
        particle_x = clamp(particle_x + dt * vx, x_lo, x_hi)
        particle_y = clamp(particle_y + dt * vy, y_lo, y_hi)
    end
    return particle_xs, particle_ys
end;

In [None]:
# Time the whole cell.
start = time()
plot_size = 800
# Square canvas with the axes hidden; grid and averaged velocity field in dark gray.
p_stokes = plot(axis=nothing, xaxis=false, yaxis=false, size=(plot_size, plot_size), aspect_ratio=1)
c = RGB(0.25, 0.25, 0.25)
plot_grid!(p_stokes, tree, c)
plot_averaged_velocity_field!(p_stokes, tree, v, color=c)
# One trace per start height, all released at x = 0.5.
for y0 in (0.5, 0.6, 0.7, 0.8, 0.9)
    plot!(p_stokes, create_trace(tree, 0.5, y0, v)..., markershape=:circle, label=nothing)
end
display(p_stokes)
# Report the elapsed wall-clock time in seconds.
@printf "⏲️ %2.2f s\n" time() - start