In [None]:
# This magic makes plots appear in the browser
%matplotlib inline
import matplotlib.pyplot as plt

# Load Firedrake on Colab
try:
    import firedrake
except ImportError:
    !wget "https://github.com/thwaitesproject/tutorials/releases/latest/download/firedrake-install-real.sh" -O "/tmp/firedrake-install.sh" && bash "/tmp/firedrake-install.sh"
    import firedrake

try: 
    import thwaites
except:
    !pip install git+https://github.com/thwaitesproject/thwaites
    import thwaites



In [None]:
import os

from thwaites import *
from thwaites.utility import get_top_boundary, cavity_thickness
from firedrake.petsc import PETSc
from firedrake import FacetNormal
import numpy as np

In [None]:
#  Generate mesh
L = 10E3
H1 = 2.
H2 = 102.
dy = 50.0
ny = round(L/dy)
#nz = 50
dz = 1.0


try: 
    # create mesh
    mesh = Mesh("./coarse.msh")
except:
    # load mesh
    !wget https://raw.githubusercontent.com/thwaitesproject/tutorials/main/coarse.msh
    mesh = Mesh("./coarse.msh")
    
PETSc.Sys.Print("Mesh dimension ", mesh.geometric_dimension())


# shift z = 0 to surface of ocean. N.b z = 0 is outside domain.
PETSc.Sys.Print("Length of lhs", assemble(Constant(1.0)*ds(1, domain=mesh)))

PETSc.Sys.Print("Length of rhs", assemble(Constant(1.0)*ds(2, domain=mesh)))

PETSc.Sys.Print("Length of bottom", assemble(Constant(1.0)*ds(3, domain=mesh)))

PETSc.Sys.Print("Length of top", assemble(Constant(1.0)*ds(4, domain=mesh)))


water_depth = 600.0
mesh.coordinates.dat.data[:, 1] -= water_depth


print("You have Comm WORLD size = ", mesh.comm.size)
print("You have Comm WORLD rank = ", mesh.comm.rank)

y, z = SpatialCoordinate(mesh)


In [None]:
# Set up function spaces
V = VectorFunctionSpace(mesh, "DG", 1)  # velocity space: (y, z) components, discontinuous P1
W = FunctionSpace(mesh, "CG", 2)  # pressure space
M = MixedFunctionSpace([V, W])  # coupled velocity-pressure space (index 0: velocity, 1: pressure)

# u velocity function space. Currently unused: the out-of-plane (x) velocity
# equation is commented out further down.
U = FunctionSpace(mesh, "DG", 1)

Q = FunctionSpace(mesh, "DG", 1)  # melt function space
K = FunctionSpace(mesh, "DG", 1)    # temperature space
S = FunctionSpace(mesh, "DG", 1)    # salinity space

##########

# Set up functions
m = Function(M)
# Function handles on the components of m: the 2D (y, z) velocity vector and the
# perturbation pressure. These are what gets written to output and checkpoints.
v_, p_ = m.split()
# UFL expressions for the same two components, used when building forms.
v, p = split(m)
# Display names of the component Functions (set via the private `_name` attribute).
v_._name = "v_velocity"
p_._name = "perturbation pressure"

# Prognostic tracers and diagnostic fields.
rho = Function(K, name="density")
temp = Function(K, name="temperature")
sal = Function(S, name="salinity")
melt = Function(Q, name="melt rate")
Q_mixed = Function(Q, name="ocean heat flux")
Q_ice = Function(Q, name="ice heat flux")
Q_latent = Function(Q, name="latent heat")
Q_s = Function(Q, name="ocean salt flux")
Tb = Function(Q, name="boundary freezing temperature")
Sb = Function(Q, name="boundary salinity")
full_pressure = Function(M.sub(1), name="full pressure")

##########

In [None]:
# Checkpoint to restart from when DUMP = True.
# NOTE(review): absolute path from the original author's machine — edit before enabling DUMP.
dump_file = "/data/2d_mitgcm_comparison/14.04.20_3_eq_param_ufricHJ99_dt30.0_dtOutput30.0_T900.0_ip50.0_tres86400.0_Kh0.001_Kv0.0001_dy50_dz1_closed_iterative/dump_step_30.h5"

DUMP = False

# Far-field (restoring) water properties, from Holland et al. 2008b: constant T below
# 200 m depth. Salinity is uniform here; a linear salinity profile was used in earlier
# experiments. T_restore / S_restore are needed whether or not we restart (sponge
# source terms and the open-boundary stress use them), so they are defined once here.
T_200m_depth = 1.0
S_200m_depth = 34.4
S_surface = S_200m_depth
T_restore = Constant(T_200m_depth)
S_restore = Constant(S_200m_depth)

if DUMP:
    # Restart: load the prognostic fields. FILE_READ (rather than FILE_UPDATE) because we
    # only read here; a wrong path then fails immediately instead of creating an empty file.
    with DumbCheckpoint(dump_file, mode=FILE_READ) as chk:
        chk.load(v_, name="v_velocity")
        chk.load(p_, name="perturbation_pressure")
        chk.load(sal, name="salinity")
        chk.load(temp, name="temperature")
else:
    # Cold start: fluid at rest with uniform temperature and salinity equal to the
    # restoring values.
    v_.assign(zero(mesh.geometric_dimension()))
    temp.interpolate(T_restore)
    sal.interpolate(S_restore)


In [None]:
# Set up equations.
# Velocity and pressure live in the two sub-spaces of the mixed space M; temperature and
# salinity each have their own DG1 space. Each constructor receives the same space twice
# (presumably test and trial space — confirm against thwaites).
velocity_space, pressure_space = M.sub(0), M.sub(1)

mom_eq = MomentumEquation(velocity_space, velocity_space)
cty_eq = ContinuityEquation(pressure_space, pressure_space)

temp_eq = ScalarAdvectionDiffusionEquation(K, K)
sal_eq = ScalarAdvectionDiffusionEquation(S, S)


In [None]:
# Terms for equation fields

# Linear equation of state (Boussinesq approximation). Reference values and expansion
# coefficients follow the MITgcm defaults.
T_ref = Constant(0.0)
S_ref = Constant(35)
beta_temp = Constant(2.0E-4)
beta_sal = Constant(7.4E-4)
g = Constant(9.81)

# Relative density anomaly (rho - rho0) / rho0 of the linear equation of state.
density_anomaly = -beta_temp*(temp - T_ref) + beta_sal * (sal - S_ref)
# Momentum source: buoyancy -g * (rho - rho0) / rho0, acting on the vertical (z) component only.
mom_source = as_vector((0., -g))*density_anomaly

# Full density; only used as a diagnostic written to output.
rho0 = 1030.
rho.interpolate(rho0*(1.0-beta_temp * (temp - T_ref) + beta_sal * (sal - S_ref)))

# Rotation is not included in this 2D set-up. For reference, an f-plane at 75 deg S gives
# f = 2 * 7.2921E-5 * sin(-75 * 2pi / 360) ~= -1.409E-4.

# Anisotropic diffusivity tensor (horizontal, vertical). It is shared by temperature,
# salinity and momentum (as the viscosity).
kappa_h = Constant(0.25)
kappa_v = Constant(kappa_h/250.)
kappa = as_tensor([[kappa_h, 0], [0, kappa_v]])
kappa_temp = kappa
kappa_sal = kappa
mu = kappa

# Scale factor for the interior-penalty parameter of the DG diffusion terms.
ip_factor = Constant(50.)
# Relaxation time scale of the sponge layer, in seconds (one day).
restoring_time = 86400.

In [None]:
# Scalar source/sink terms in the sponge (restoring) layer near y = L.
# Within the last `sponge_fraction` of the domain (y > (1 - sponge_fraction) * L), the
# relaxation rate ramps linearly from 0 at the start of the sponge to 1/restoring_time at
# y = L. All four terms share that single ramp; previously the same conditional was
# written out four times.
absorption_factor = Constant(1.0/restoring_time)
sponge_fraction = 0.06  # fraction of domain where sponge
sponge_start = (1.0-sponge_fraction) * L  # y coordinate where the sponge begins

# Dimensionless ramp: 0 outside the sponge, rising linearly to 1 at y = L.
sponge_ramp = conditional(y > sponge_start,
                          (y - sponge_start)/(L * sponge_fraction),
                          0.0)

# Temperature / salinity absorption terms (relaxation rate).
absorp_temp = sponge_ramp * absorption_factor
absorp_sal = sponge_ramp * absorption_factor

# Temperature / salinity source terms: relaxation rate times restoring value.
# Together with the absorption term this presumably relaxes each tracer towards
# T_restore / S_restore (assumes the equation adds `source - absorption * tracer` —
# confirm in thwaites).
source_temp = absorp_temp * T_restore
source_sal = absorp_sal * S_restore



In [None]:
# Interior-penalty parameter shared by the DG viscosity and diffusion terms.
# NOTE(review): dy/dz scales the penalty by the cell aspect ratio; the factors 3 and 2 are
# kept from the original set-up — confirm their origin.
ip_alpha = Constant(3*dy/dz*2*ip_factor)

# Velocity-pressure coupling. Each dict maps the other field's name to its index in the
# mixed Function m. NOTE(review): inferred from M = [V, W] (0 = velocity, 1 = pressure).
vp_coupling = [{'pressure': 1}, {'velocity': 0}]

# Equation fields: coefficients handed to each equation's terms.
vp_fields = {
    'viscosity': mu,
    'source': mom_source,
    'interior_penalty': ip_alpha,
}

temp_fields = {
    'diffusivity': kappa_temp,
    'velocity': v,
    'interior_penalty': ip_alpha,
    'source': source_temp,
    'absorption coefficient': absorp_temp,
}

sal_fields = {
    'diffusivity': kappa_sal,
    'velocity': v,
    'interior_penalty': ip_alpha,
    'source': source_sal,
    'absorption coefficient': absorp_sal,
}


In [None]:
# Three-equation melt-rate parameterisation at the ice base. It is built from the ocean
# salinity, temperature, perturbation pressure and depth z, plus the flow speed |v|.
# HJ99Gamma=True presumably selects the Holland & Jenkins (1999) transfer coefficients —
# confirm in thwaites.
flow_speed = pow(dot(v, v), 0.5)
mp = ThreeEqMeltRateParam(sal, temp, p, z, velocity=flow_speed, HJ99Gamma=True)


In [None]:
# Boundary conditions (boundary ids as printed in the mesh cell: 1 lhs, 2 rhs, 3 bottom, 4 top)
# top boundary (4): no normal flow, drag flowing over ice
# bottom boundary (3): no normal flow, drag flowing over bedrock
# grounding line wall (LHS, 1): no normal flow
# open ocean (RHS, 2): currently also closed (no normal flow) in vp_bcs; the
#     hydrostatic open-boundary stress below is computed but not applied.

# WEAKLY Enforced BCs
n = FacetNormal(mesh)
# Hydrostatic stress from the far-field (restoring) density: the z-integral of the
# buoyancy anomaly -beta_T (T - T_ref) + beta_S (S - S_ref) for uniform T_restore, S_restore.
# NOTE: stress_open_boundary is not used by any boundary condition in this notebook;
# presumably kept for switching boundary 2 to an open-ocean condition.
Temperature_term = -beta_temp * ((T_restore-T_ref) * z)
Salinity_term = beta_sal * ((S_restore - S_ref) * z) # ((S_bottom - S_surface) * (pow(z, 2) / (-2.0*water_depth)) + (S_surface-S_ref) * z)
stress_open_boundary = -n*-g*(Temperature_term + Salinity_term)
no_normal_flow = 0.
ice_drag = 0.0097  # drag coefficient at the ice base (boundary 4)


In [None]:
# Weak boundary conditions keyed by boundary id (1 lhs, 2 rhs, 3 bottom, 4 top).
# Every wall has no normal flow; the ice base (4) and the bed (3) also feel quadratic drag.
bed_drag = 0.0025
vp_bcs = {
    4: {'un': no_normal_flow, 'drag': ice_drag},
    2: {'un': no_normal_flow},
    3: {'un': no_normal_flow, 'drag': bed_drag},
    1: {'un': no_normal_flow},
}

# Heat and salt fluxes through the ice base come from the melt-rate parameterisation
# (with the sign flipped).
temp_bcs = {4: {'flux': -mp.T_flux_bc}}
sal_bcs = {4: {'flux': -mp.S_flux_bc}}


In [None]:
# STRONGLY Enforced BCs: none at present.
# (An earlier variant imposed zero tangential flow on the open-ocean RHS, because the
#  viscosity of the outside ocean resists vertical flow:
#  DirichletBC(M.sub(0).sub(1), 0, 2).)
strong_bcs = []

##########

# Solver parameters

# Direct LU solve with MUMPS. With snes_type 'ksponly' a single linear solve is done, so
# the snes_* iteration/tolerance settings should have no effect (kept as before).
mumps_solver_parameters = dict(
    snes_monitor=None,
    snes_type='ksponly',
    ksp_type='preonly',
    pc_type='lu',
    pc_factor_mat_solver_type='mumps',
    mat_type='aij',
    snes_max_it=100,
    snes_rtol=1e-8,
    snes_atol=1e-6,
)

# Velocity mass block: GMRES preconditioned with an assembled block-Jacobi/ILU.
velocity_block_parameters = dict(
    ksp_type='gmres',
    pc_type='python',
    pc_python_type='firedrake.AssembledPC',
    ksp_converged_reason=None,
    assembled_ksp_type='preonly',
    assembled_pc_type='bjacobi',
    assembled_sub_pc_type='ilu',
)

# Schur system: explicitly assemble the Schur complement and solve it with CG + GAMG.
# This only works with the pressure-projection Picard scheme if the velocity block is
# just the mass matrix, and if the velocity is DG so that this mass matrix can be
# inverted explicitly.
schur_block_parameters = dict(
    ksp_type='preonly',
    pc_type='python',
    pc_python_type='thwaites.AssembledSchurPC',
    schur_ksp_type='cg',
    schur_ksp_max_it=1000,
    schur_ksp_rtol=1e-7,
    schur_ksp_atol=1e-9,
    schur_ksp_converged_reason=None,
    schur_pc_type='gamg',
    schur_pc_gamg_threshold=0.01,
)

# Matrix-free full Schur-complement factorisation. The Schur complement is solved exactly,
# so no outer Krylov method is needed ('preonly').
pressure_projection_solver_parameters = dict(
    snes_type='ksponly',
    ksp_type='preonly',
    mat_type='matfree',
    pc_type='fieldsplit',
    pc_fieldsplit_type='schur',
    pc_fieldsplit_schur_fact_type='full',
    fieldsplit_0=velocity_block_parameters,
    fieldsplit_1=schur_block_parameters,
)

# Aliases (the same dict objects, not copies) used by the individual time steppers.
vp_solver_parameters = pressure_projection_solver_parameters
u_solver_parameters = temp_solver_parameters = sal_solver_parameters = mumps_solver_parameters


In [None]:
# Time-stepping parameters (seconds).
dt = 300
T = 86400*1  # run for 1 day, steady state by ~40 days
output_dt = 3600*6  # output every 6 hours

# Number of time steps between outputs. It must be an integer: the time loop tests
# `step % output_step == 0`, which would never be true if output_dt were not a whole
# multiple of dt, and output would silently stop.
if output_dt % dt != 0:
    raise ValueError(f"output_dt ({output_dt} s) must be a multiple of dt ({dt} s)")
output_step = output_dt // dt


In [None]:
# Set up time stepping routines

# Coupled velocity-pressure time integrator (pressure-projection scheme) for the momentum
# and continuity equations; it advances the mixed Function m.
# - solver_parameters: the Schur-complement fieldsplit set (vp_solver_parameters).
# - predictor_solver_parameters: the direct MUMPS set (presumably used for the velocity
#   predictor step, as the keyword suggests — confirm in thwaites).
# - pressure_nullspace: pressure is only determined up to a constant here, since every
#   boundary in vp_bcs is a no-normal-flow condition and none prescribes the pressure.
vp_timestepper = PressureProjectionTimeIntegrator([mom_eq, cty_eq], m, vp_fields, vp_coupling, dt, vp_bcs,
                                                          solver_parameters=vp_solver_parameters,
                                                          predictor_solver_parameters=u_solver_parameters,
                                                          picard_iterations=1,
                                                          pressure_nullspace=VectorSpaceBasis(constant=True))

# performs pseudo timestep to get good initial pressure
# this is to avoid inconsistencies in terms (viscosity and advection) that
# are meant to decouple from pressure projection, but won't if pressure is not initialised
# do this here, so we can see the initial pressure in pressure_0.pvtu
if not DUMP:
    # should not be done when picking up from a checkpoint (DUMP = True), where p_ has
    # already been loaded
    vp_timestepper.initialize_pressure()

# Tracer time integrators: DIRK33 (presumably a 3-stage, 3rd-order diagonally implicit
# Runge-Kutta scheme — confirm in thwaites), each using the direct MUMPS parameters.
# The out-of-plane u-velocity integrator is disabled, along with the rest of that option.
#u_timestepper = DIRK33(u_eq, u, u_fields, dt, u_bcs, solver_parameters=u_solver_parameters)
temp_timestepper = DIRK33(temp_eq, temp, temp_fields, dt, temp_bcs, solver_parameters=temp_solver_parameters)
sal_timestepper = DIRK33(sal_eq, sal, sal_fields, dt, sal_bcs, solver_parameters=sal_solver_parameters)


In [None]:
# Output folder (relative to the current working directory) for ParaView files.
folder = "./ice_pump/"


def _open_output(filename, function):
    """Create the ParaView output `folder + filename` and write the current state of `function`."""
    output = File(folder+filename)
    output.write(function)
    return output


# One output series per field; each one starts with the initial state.
v_file = _open_output("vw_velocity.pvd", v_)
p_file = _open_output("pressure.pvd", p_)
t_file = _open_output("temperature.pvd", temp)
s_file = _open_output("salinity.pvd", sal)
rho_file = _open_output("density.pvd", rho)
# NOTE: `melt` is first computed inside the time loop, so this initial melt frame is the
# zero-initialised Function rather than the melt rate of the initial state.
m_file = _open_output("melt.pvd", melt)

In [None]:
# Begin time stepping


def _store_checkpoint(filename, mode):
    """Store the prognostic fields (velocity, pressure, temperature, salinity) in a DumbCheckpoint.

    :arg filename: checkpoint name passed to ``DumbCheckpoint``.
    :arg mode: file mode, e.g. ``FILE_UPDATE`` or ``FILE_CREATE``.

    The dataset names match those loaded by the restart branch (``DUMP = True``).
    """
    with DumbCheckpoint(filename, mode=mode) as chk:
        chk.store(v_, name="v_velocity")
        chk.store(p_, name="perturbation_pressure")
        chk.store(temp, name="temperature")
        chk.store(sal, name="salinity")


t = 0.0
step = 0

# March to t = T. The half-step margin keeps round-off in the accumulated `t`
# from adding an extra step.
while t < T - 0.5*dt:
    vp_timestepper.advance(t)
    temp_timestepper.advance(t)
    sal_timestepper.advance(t)

    step += 1
    t += dt

    if step % output_step == 0:
        # Restart checkpoint: the same dump.h5 is reused (FILE_UPDATE) at every output,
        # so a run can be restarted from the last output reached.
        _store_checkpoint(folder+"dump.h5", FILE_UPDATE)

        # Update melt-rate diagnostics. Only `melt` is written to file below;
        # Q_ice, Q_mixed, Q_latent, Q_s, Tb, Sb and full_pressure are updated but
        # not written anywhere in this notebook.
        Q_ice.interpolate(mp.Q_ice)
        Q_mixed.interpolate(mp.Q_mixed)
        Q_latent.interpolate(mp.Q_latent)
        Q_s.interpolate(mp.S_flux_bc)
        melt.interpolate(mp.wb)
        Tb.interpolate(mp.Tb)
        Sb.interpolate(mp.Sb)
        full_pressure.interpolate(mp.P_full)

        # Update density for plotting
        rho.interpolate(rho0*(1.0-beta_temp * (temp - T_ref) + beta_sal * (sal - S_ref)))

        # Write out files
        v_file.write(v_)
        p_file.write(p_)
        t_file.write(temp)
        s_file.write(sal)
        rho_file.write(rho)
        m_file.write(melt)

        PETSc.Sys.Print("t=", t)
        # Melt rate integrated over the ice base (boundary 4).
        PETSc.Sys.Print("integrated melt =", assemble(melt * ds(4)))

    if step % (output_step * 24) == 0:
        # Additional checkpoint every 24 outputs, kept in its own file tagged with the step number.
        _store_checkpoint(folder+"dump_step_{}.h5".format(step), FILE_CREATE)


In [None]:
import pyvista as pv

# Read one temperature snapshot written by the output cell. Index 0 is the initial
# state; with the default run (1 day, output every 6 h) index 4 is the final one.
temp_data = pv.read("ice_pump/temperature_4.vtu")

# Viridis resampled to 25 discrete colours. `plt.cm.get_cmap` (matplotlib.cm.get_cmap)
# was removed in Matplotlib 3.9; `plt.get_cmap(name, lut)` is the supported equivalent.
boring_cmap = plt.get_cmap("viridis", 25)
plotter = pv.Plotter(notebook=True)
plotter.add_mesh(temp_data, cmap=boring_cmap)
plotter.camera_position = "xy"
# Stretch the vertical (second) axis 20x so the thin cavity is visible.
plotter.set_scale(yscale=20)
plotter.show(jupyter_backend="static", interactive=False)