## Activate Package
Activate the project environment in this directory, which provides all the dependencies (other packages) this notebook relies on.

In [None]:
import Pkg; Pkg.activate(".")

In [None]:
using Revise  # Development tool - must come first, otherwise don't worry about it!

## Custom Mesh Package
This package provides a simple wrapper for some of the functionality of p4est, a very powerful, low-level library for adaptive mesh refinement (AMR).

In [None]:
using SimpleTreeMeshes

## Helpers
Other packages for plotting, formatted printing, and loading images.

In [None]:
using Plots
using Printf
using Images
using FileIO

## An image used to drive the simulation

In [None]:
# Source image for the simulation (alternative: "images/prism_square.png").
image_name = "images/scribble2.png"
img = Gray.(load(image_name))
sz = size(img)

# Crop to the largest top-left square so the image maps onto the unit square.
mask_size = min(sz[1], sz[2])
img2 = img[1:mask_size, 1:mask_size]

# Gray levels as Float64 (a Float16 intermediate would needlessly quantize them).
mask_r = Float64.(gray.(img2))

# Image arrays are indexed [row, column] with row 1 at the top. Rotate clockwise so that
# mask[i, j] == mask_r[mask_size - j + 1, i]: i runs left to right, j bottom to top.
mask = rotr90(mask_r)
display(img2)

## Defining a Physical Problem

In [None]:
"""
    Problem

Description of a 2D Stokes problem: boundary conditions, viscosity, body forces,
reference solution fields, and a coordinate transformation. Constructed positionally,
so the field order below is part of its interface.

NOTE(review): the abstract `Function` fields prevent the compiler from specializing on
the stored closures; a parametric struct would be faster if this becomes a bottleneck.
"""
struct Problem
  name::String                                   # human-readable label, e.g. "SolKz"
  boundary_conditions::StokesBoundaryConditions  # e.g. stokes_boundary_free_slip
  eta::Function                                  # viscosity eta(x, y)
  eta_min::Float64                               # smallest viscosity on the domain
  eta_max::Float64                               # largest viscosity on the domain
  eta_ref::Float64                               # reference viscosity; callers in this file pass sqrt(eta_min * eta_max)
  f_x::Function                                  # x-component of the body force, f_x(x, y)
  f_y::Function                                  # y-component of the body force, f_y(x, y)
  f_p::Function                                  # pressure/continuity source term; zero physically, non-zero presumably only for manufactured solutions
  vx_ref::Function                               # reference x-velocity vx_ref(x, y) (placeholder 0.0 when no analytic solution)
  vy_ref::Function                               # reference y-velocity vy_ref(x, y)
  p_ref::Function                                # reference pressure p_ref(x, y)
  coordinate_scale::Float64                      # presumably maps mesh coordinates to physical ones — confirm in assemble_system
  coordinate_offset_x::Float64                   # presumably x-offset of that mapping — confirm
  coordinate_offset_y::Float64                   # presumably y-offset of that mapping — confirm
end

In [None]:
include("solKz.jl");

In [None]:
"""
    problem_SolKz(; B = 1.0)

Build the SolKz benchmark `Problem` on the unit square: viscosity `exp(2B * y)`,
a sinusoidal vertical body force, free-slip boundaries, and reference velocity and
pressure fields obtained from `evaluate_solKz`.

# Keywords
- `B`: viscosity parameter; the viscosity ranges between `1` and `exp(2B)`.
"""
function problem_SolKz(; B = 1.0)
    sigma = 1.0
    k_m = 1.6 * pi
    n = Int32(3)  # kept as Int32; presumably the argument type evaluate_solKz expects — confirm

    eta(x, y) = exp(2.0 * B * y)

    # eta depends only on y and is monotone in it, so its extremes are at y = 0 and y = 1.
    # minmax keeps the bounds correctly ordered when B < 0 as well.
    eta_min, eta_max = minmax(eta(0, 0), eta(1, 1))
    eta_ref = sqrt(eta_min * eta_max)

    f_vy(x, y) = -sigma * sin(k_m * y) * cos(n * pi * x)
    f_vx(x, y) = 0.0
    f_p(x, y) = 0.0  # physically, zero, but could be non-zero for MMS

    # Component order used here: [1] pressure, [2] vx, [3] vy.
    # Each reference field evaluates the full solution and keeps one component (wasteful).
    solkz(x, y) = evaluate_solKz(x, y, sigma, k_m, n, B)
    vx_ref(x, y) = solkz(x, y)[2]
    vy_ref(x, y) = solkz(x, y)[3]
    p_ref(x, y) = solkz(x, y)[1]

    return Problem("SolKz", stokes_boundary_free_slip, eta, eta_min, eta_max, eta_ref,
                   f_vx, f_vy, f_p, vx_ref, vy_ref, p_ref, 1.0, 0.0, 0.0)
end;

In [None]:
"""
    image_value(x, y)

Return the value of the global `mask` at the pixel containing the point `(x, y)` of the
unit square.

Pixel `k` (1-based) covers `[(k - 1) / mask_size, k / mask_size)` in each direction, and
the edge `x == 1` (or `y == 1`) is assigned to the last pixel. Reads the globals `mask`
and `mask_size`.

Throws `DomainError` if `x` or `y` lies outside `[0, 1]`.
"""
function image_value(x, y)
    0 <= x <= 1 || throw(DomainError(x, "x must lie in [0, 1]"))
    0 <= y <= 1 || throw(DomainError(y, "y must lie in [0, 1]"))
    # floor(x * mask_size) is a 0-based pixel index: shift it to 1-based and clamp, so that
    # points with x < 1/mask_size no longer index mask[0, …] and x == 1 maps to the last pixel.
    i = clamp(floor(Int, x * mask_size) + 1, 1, mask_size)
    j = clamp(floor(Int, y * mask_size) + 1, 1, mask_size)
    return mask[i, j]
end;

"""
    problem_Image(; amplitude = 10.0)

Build a `Problem` on the unit square with constant unit viscosity, free-slip boundaries,
and a y-direction body force driven by the image: `f_y(x, y) = -amplitude * |image_value(x, y)|`.

No analytic solution is available, so `vx_ref`, `vy_ref`, and `p_ref` are placeholders
returning `0.0`.

# Keywords
- `amplitude`: scale of the image-driven forcing (default `10.0`).
"""
function problem_Image(; amplitude = 10.0)
    eta(x, y) = 1.0
    eta_min = 1.0
    eta_max = 1.0
    eta_ref = sqrt(eta_min * eta_max)

    f_vy(x, y) = -amplitude * abs(image_value(x, y))
    f_vx(x, y) = 0.0
    f_p(x, y) = 0.0

    # Placeholder reference fields (no analytic solution for an arbitrary image).
    vx_ref(x, y) = 0.0
    vy_ref(x, y) = 0.0
    p_ref(x, y) = 0.0

    return Problem("Image", stokes_boundary_free_slip, eta, eta_min, eta_max, eta_ref,
                   f_vx, f_vy, f_p, vx_ref, vy_ref, p_ref, 1.0, 0.0, 0.0)
end;

## Refinement functions

In [None]:
max_level = 6  # deepest refinement level; read by the rf_* callbacks and the solve cell

"""
    rf_image(x, y, d, level) -> Bool

Refinement callback driven by the global image `mask`. It decides whether to refine the
element `[x, x + d] × [y, y + d]` at refinement `level`.

Elements at or below level 1 are always refined, and elements at or above the global
`max_level` never are. Otherwise, the element is refined when the pixels it overlaps span
an intensity range greater than `0.2`.
"""
function rf_image(x, y, d, level)::Bool
    min_level = 1
    level <= min_level && return true
    level >= max_level && return false

    # Pixel k (1-based) covers [(k - 1) / mask_size, k / mask_size) in each direction.
    # Gather every pixel overlapping the element; the former (mask_size - 1) scaling
    # never reached the last row/column of the image.
    first_pixel(a) = clamp(floor(Int, a * mask_size) + 1, 1, mask_size)
    last_pixel(b) = clamp(ceil(Int, b * mask_size), 1, mask_size)
    irange = first_pixel(x):last_pixel(x + d)
    jrange = first_pixel(y):last_pixel(y + d)

    # Attempt to refine if the local pixels aren't all in a narrow band of values.
    lo, hi = extrema(@view mask[irange, jrange])
    return hi - lo > 0.2  # this threshold could be a function of max level?
end

"""
    rf_demo(x, y, d, level) -> Bool

Demo refinement callback. While `level` is below the global `max_level`, it refines the
element `[x, x + d] × [y, y + d]` when the element's center and four corners do not all
lie on the same side of the circle of radius 0.25 centered at (0.7, 0.4).
"""
function rf_demo(x, y, d, level)::Bool
    level >= max_level && return false

    radius = 0.25
    center_x, center_y = 0.7, 0.4
    inside(px, py) = (px - center_x)^2 + (py - center_y)^2 < radius^2

    mid_inside = inside(x + d/2, y + d/2)
    corners = ((x, y), (x + d, y), (x, y + d), (x + d, y + d))
    # Any corner on the opposite side from the center means the circle cuts the element.
    return any(corner -> inside(corner...) != mid_inside, corners)
end;

"""
    rf_uniform(x, y, d, level) -> Bool

Uniform refinement callback: refine every element until the global `max_level` is
reached, regardless of its position `(x, y)` or size `d`.
"""
function rf_uniform(x, y, d, level)::Bool
    # Refine while below max_level, matching rf_image/rf_demo, which return false at
    # level >= max_level. (The former `level >= max_level` had this test inverted.)
    return level < max_level
end;


In [None]:
# Refinement callback handed to CreateSimpleTreeMesh (alternatives: rf_demo, rf_uniform).
refinement_function = rf_image;

## Create our tree object

In [None]:
# Build the adaptive mesh, presumably refining elements for which refinement_function returns true.
tree = CreateSimpleTreeMesh(refinement_function);

## Plot the grid

In [None]:
# Plot the mesh with element, face, and corner numbers, then report the elapsed time.
start = time()
plot_size = 600
p_grid = plot(axis=nothing, xaxis=false, yaxis=false, size=(plot_size, plot_size), aspect_ratio=1)
plot_grid!(p_grid, tree, RGB(0.5, 0.5, 0.5))
plot_element_numbers!(p_grid, tree) 
plot_face_numbers!(p_grid, tree, offset=tree.ne) # Number faces after elements, as in our Stokes system
plot_corner_numbers!(p_grid, tree) # A separate numbering system
display(p_grid)
@printf "⏲️ %2.2f s\n" time()- start

## Solve the Stokes system

In [None]:
start = time()
tree = CreateSimpleTreeMesh(refinement_function)
@printf "⚙️ %d elements, %d faces, %d corners " tree.ne tree.nf tree.nc
@printf "⏲️ %2.2f s\n" time() - start

In [None]:
# Assemble and solve the Stokes linear system, then split the solution vector into the
# velocity `v` and pressure `p` used by the plotting cell.
start = time()
h_min = 2.0^(-(max_level))  # presumably the smallest element size: unit square halved max_level times
problem = problem_Image()
A, b, Kcont = assemble_system(tree, problem, h_min)
x = A\b  # note: rebinds the global `x`
v, p = sol2vp(tree, x, Kcont)
@printf "⏲️ %2.2f s\n" time() - start

In [None]:
# Plot pressure as a per-element field, overlay the grid and the averaged velocity
# field in dark gray, then report the elapsed time.
start = time()
plot_size = 800
p_stokes = plot(axis=nothing, xaxis=false, yaxis=false, size=(plot_size, plot_size), aspect_ratio=1)
plot_element_field!(p_stokes, tree, p)
c = RGB(0.25, 0.25, 0.25)
plot_grid!(p_stokes, tree, c)
plot_averaged_velocity_field!(p_stokes, tree, v, color=c)
display(p_stokes)
@printf "⏲️ %2.2f s\n" time() - start