In [None]:
import sys
# Make the parent directory importable (presumably where a local planetengine
# checkout lives). The guard keeps re-runs of this cell from appending
# duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import planetengine
from planetengine import quickShow

In [None]:
import underworld as uw
from underworld import function as fn
import math
import numpy as np

from planetengine.utilities import Grouper

In [None]:
def build(
        res = 64,
        f = 0.9,
        aspect = 4.,
        periodic = False,
        refVisc = 1.,
        strength = 2.,
        slipperiness = 10.,
        Ra = 1e7,
        depth_density = 1.
        ):
    """Build a Stokes-flow model on an annulus with an advected 'lithosphere' field.

    Parameters
    ----------
    res : int
        Requested radial element count; rounded down to a multiple of 4
        (minimum 4). The value actually used is stored in ``inputs['res']``.
    f : float
        Curvature parameter, clamped to [0.00001, 0.99999]. The outer radius
        is ``1 / (1 - f)`` and the layer depth is 1.
    aspect : float
        Sector area divided by depth**2. Capped at the value for a full
        annulus; hitting the cap forces ``periodic = True``.
    periodic : bool
        Whether the mesh is periodic in the angular direction. When False the
        second rotated velocity component is also constrained on the side walls.
    refVisc : float
        Multiplier applied to every material viscosity.
    strength : float
        Relative viscosity of material 2 (lithoField > 0.6).
    slipperiness : float
        Material 1 (0.4 <= lithoField <= 0.6) has relative viscosity
        ``1 / slipperiness``.
    Ra : float
        Scale factor on the buoyancy term (densityFn).
    depth_density : float
        Density grows with depth as ``depth * depth_density + 1``.

    Returns
    -------
    Grouper
        Built from ``locals()``: every local name below (mesh, fields,
        functions, systems, the update_litho/solve/integrate/iterate closures,
        ``inputs`` and ``varsOfState``) is exposed on the returned object.
        Renaming a local therefore changes this function's public interface.
    """

    # Snapshot of the call arguments; some entries are overwritten below with
    # the sanitised values actually used. Must stay the first statement so that
    # only the parameters are captured.
    inputs = locals().copy()

    ### MESH & MESH VARIABLES ###

    # Keep f strictly inside (0, 1) so that outerRad = 1 / (1 - f) is finite.
    f = max(0.00001, min(0.99999, f))
    inputs['f'] = f

    length = 1.
    outerRad = 1. / (1. - f)
    radii = (outerRad - length, outerRad)

    # A full annulus has area pi * (r1**2 - r0**2); with r1 - r0 == length this
    # gives the largest possible area-based aspect ratio below.
    maxAspect = math.pi * sum(radii) / length
    aspect = min(aspect, maxAspect)
    inputs['aspect'] = aspect
    if aspect == maxAspect:
        # A full ring has no side walls, so it must be periodic.
        periodic = True
        inputs['periodic'] = periodic

    # Angular width (radians) of the sector whose area equals
    # aspect * length**2, centred on pi / 2.
    width = length**2 * aspect * 2. / (radii[1]**2 - radii[0]**2)
    midpoint = math.pi / 2.
    angExtentRaw = (midpoint - 0.5 * width, midpoint + 0.5 * width)
    angExtentDeg = [item * 180. / math.pi for item in angExtentRaw]
    # If the extent would start below 0 degrees, shift it up to start at 0;
    # the upper end is capped at 360.
    angularExtent = [
        max(0., angExtentDeg[0]),
        min(360., angExtentDeg[1] + abs(min(0., angExtentDeg[0])))
        ]
    angLen = angExtentRaw[1] - angExtentRaw[0]

    radRes = max(4, int(res / 4) * 4)
    inputs['res'] = radRes
    # Angular element count roughly matches the radial element size along the
    # outer arc, rounded down to a multiple of 4. Floored at 4: without the
    # floor a narrow sector (small aspect or res) rounds down to zero angular
    # elements.
    angRes = max(4, 4 * int(angLen * (int(radRes * radii[1] / length)) / 4))
    elementRes = (radRes, angRes)

    mesh = uw.mesh.FeMesh_Annulus(
        elementRes = elementRes,
        radialLengths = radii,
        angularExtent = angularExtent,
        periodic = [False, periodic]
        )

    ### VARIABLES ###

    pressureField = uw.mesh.MeshVariable(mesh.subMesh, 1)
    velocityField = uw.mesh.MeshVariable(mesh, 2)
    lithoField = uw.mesh.MeshVariable(mesh, 1)
    lithoDotField = uw.mesh.MeshVariable(mesh, 1)
    lithoVelField = uw.mesh.MeshVariable(mesh, 2)

    ### BOUNDARIES ###

    inner = mesh.specialSets["inner"]
    outer = mesh.specialSets["outer"]
    sides = mesh.specialSets["MaxJ_VertexSet"] + mesh.specialSets["MinJ_VertexSet"]

    # Dof 0 (along mesh.bnd_vec_normal) is constrained on the inner and outer
    # walls. Dof 1 (along mesh.bnd_vec_tangent) is constrained on the side walls
    # only when the domain is not periodic (a periodic ring has no side walls).
    velBC = uw.conditions.RotatedDirichletCondition(
        variable = velocityField,
        indexSetsPerDof = (inner + outer, None if periodic else sides),
        basis_vectors = (mesh.bnd_vec_normal, mesh.bnd_vec_tangent)
        )

    # lithoField is held fixed on the inner and outer walls during advection.
    lithoBC = uw.conditions.DirichletCondition(
        variable = lithoField,
        indexSetsPerDof = (inner + outer,)
        )

    ### SPECIAL ###

    # Needed for annulus to work properly: postSolve() writes the rotated
    # velocity solution into vcVec, i.e. into the vc mesh variable.
    vc = uw.mesh.MeshVariable(mesh = mesh, nodeDofCount = 2)
    vc_eqNum = uw.systems.sle.EqNumber(vc, False)
    vcVec = uw.systems.sle.SolutionVector(vc, vc_eqNum)

    ### FUNCTIONS ###

    # Depth below the outer surface: 0 at the outer wall, `length` at the inner.
    depthFn = mesh.radialLengths[1] - mesh.radiusFn

    angExtentsRadians = np.array(mesh.angularExtent) * np.pi / 180.
    radWidth = angExtentsRadians[1] - angExtentsRadians[0]
    xFn, yFn = fn.input()[0], fn.input()[1]
    magnitudeFn = fn.math.sqrt(fn.math.pow(xFn, 2) + fn.math.pow(yFn, 2))
    # Half-angle identity theta = 2 * atan(y / (r + x)); rawAngFn is -theta.
    # NOTE(review): r + x == 0 on the negative x-axis (theta = pi), which a
    # full periodic ring (angularExtent [0, 360]) contains - confirm this is
    # never evaluated there.
    rawAngFn = - 2. * fn.math.atan(yFn / ((magnitudeFn) + xFn))
    # Normalised angle: 0 at the maximum angular extent, 1 at the minimum.
    angFn = (rawAngFn + angExtentsRadians[1]) / radWidth

    # Components of vc along the local theta and r unit vectors.
    angMag = fn.math.dot(mesh.unitvec_theta_Fn, vc)
    radMag = fn.math.dot(mesh.unitvec_r_Fn, vc)

    # Disabled alternative cooling criterion:
    # coolingFn = fn.branching.conditional([
    #     (radMag / angMag > 1e3 / Ra, 0.),
    #     (True, 1.)
    #     ])
    # NOTE(review): placeholder - currently just the normalised angle angFn.
    coolingFn = angFn

    # Velocity used to advect lithoField: the theta component of vc unchanged,
    # the r component of vc reduced by coolingFn.
    lithoVelFn = angMag * mesh.unitvec_theta_Fn + (radMag - coolingFn) * mesh.unitvec_r_Fn

    # Scales lithoField in the deepest part of the layer (depth >= 0.8).
    # NOTE(review): the factor (depthFn - 0.8) / 0.2 is 0 at depth 0.8 and 1
    # at depth 1, so lithoField jumps from full value to zero across depth 0.8
    # and is left intact at the base. If a smooth fade-out with depth was
    # intended, (1. - depthFn) / 0.2 would be continuous at 0.8 - confirm.
    depthCorrection = fn.branching.conditional([
        (depthFn < 0.8, lithoField),
        (True, lithoField * (depthFn - 0.8) / 0.2),
        ])
    # Zero out lithoField values below 0.01.
    limitCorrection = fn.branching.conditional([
        (lithoField < 0.01, 0.),
        (True, lithoField),
        ])

    def update_litho():
        """Clean up lithoField in place and refresh lithoVelField.

        Clips lithoField to [0, 1], applies depthCorrection and then
        limitCorrection, clips again, and finally evaluates lithoVelFn
        onto lithoVelField.
        """
        lithoField.data[:] = np.clip(lithoField.data, 0., 1.)
        lithoField.data[:] = depthCorrection.evaluate(mesh)
        lithoField.data[:] = limitCorrection.evaluate(mesh)
        lithoField.data[:] = np.clip(lithoField.data, 0., 1.)
        lithoVelField.data[:] = lithoVelFn.evaluate(mesh)

    # Material index from lithoField: 0 below 0.4, 2 above 0.6, 1 in between.
    materialFn = fn.branching.conditional([
        (lithoField < 0.4, 0),
        (lithoField > 0.6, 2),
        (True, 1),
        ])

    matDensityFn = fn.branching.map(
        fn_key = materialFn,
        mapping = {
            0: 1.,
            1: 1.,
            2: 2.,
            }
        )

    # Buoyancy: Ra * material density, increasing linearly with depth.
    densityFn = Ra * matDensityFn * (depthFn * depth_density + 1.)

    ### RHEOLOGY ###

    matViscFn = fn.branching.map(
        fn_key = materialFn,
        mapping = {
            0: 1.,
            1: 1. / slipperiness,
            2: strength,
            }
        )

    viscosityFn = refVisc * matViscFn

    ### SYSTEMS ###

    # Body force acts along -unitvec_r (presumably inward gravity).
    # NOTE(review): _removeBCs = False presumably keeps the rotated BCs in the
    # system, as the annulus setup requires - confirm against Underworld docs.
    stokes = uw.systems.Stokes(
        velocityField = velocityField,
        pressureField = pressureField,
        conditions = [velBC,],
        fn_viscosity = viscosityFn,
        fn_bodyforce = -densityFn * mesh.unitvec_r_Fn,
        _removeBCs = False,
        )

    solver = uw.systems.Solver(stokes)

    # Pure advection (zero diffusivity, no source) of lithoField by lithoVelField.
    advDiff = uw.systems.AdvectionDiffusion(
        phiField = lithoField,
        phiDotField = lithoDotField,
        velocityField = lithoVelField,
        fn_diffusivity = 0.,
        fn_sourceTerm = 0.,
        conditions = [lithoBC,]
        )

    # Step counter and model time, advanced by integrate().
    step = fn.misc.constant(0)
    modeltime = fn.misc.constant(0.)

    ### SOLVING ###

    def postSolve():
        """Post-solve callback for the Stokes solver.

        Applies stokes._rot to the velocity solution and writes the result
        into vcVec (hence the vc mesh variable), then removes the
        stokes._vnsVec component (solid body rotation) from the velocity
        solution.
        """
        # realign solution using the rotation matrix on stokes
        uw.libUnderworld.Underworld.AXequalsY(
            stokes._rot._cself,
            stokes._velocitySol._cself,
            vcVec._cself,
            False
            )
        # remove null space - the solid body rotation velocity contribution
        uw.libUnderworld.StgFEM.SolutionVector_RemoveVectorSpace(
            stokes._velocitySol._cself,
            stokes._vnsVec._cself
            )

    def solve():
        """Refresh the lithosphere fields and solve the Stokes system.

        Runs update_litho(), zeroes velocityField, performs a linear solve
        with postSolve() as the post-solve callback, then applies stokes._rot
        to the velocity solution in place (AXequalsX).
        """
        update_litho()
        velocityField.data[:] = 0.
        solver.solve(
            nonLinearIterate = False,
            callback_post_solve = postSolve,
            )
        uw.libUnderworld.Underworld.AXequalsX(
            stokes._rot._cself,
            stokes._velocitySol._cself,
            False
            )

    def integrate():
        """Advect lithoField by advDiff.get_max_dt(); advance modeltime and step."""
        dt = advDiff.get_max_dt()
        advDiff.integrate(dt)
        modeltime.value += dt
        step.value += 1

    def iterate():
        """One model step: advect (integrate) and then re-solve (solve)."""
        integrate()
        solve()

    ### HOUSEKEEPING: IMPORTANT! ###

    # NOTE(review): presumably the fields planetengine treats as the model
    # state (e.g. for checkpointing) - confirm against planetengine.
    varsOfState = {'lithoField': lithoField}

    # Every local name above becomes part of the returned object.
    return Grouper(locals())

In [None]:
# Model parameters for this run; anything not listed keeps build()'s default.
system = build(
    aspect = 2.,
    res = 32,
    Ra = 1e6,
    refVisc = 1.,
    strength = 1e4,
    )

In [None]:
# quick and dirty initialise:
# quick and dirty initialise: evaluate a sinusoidal initial condition on box
# coordinates (presumably the mesh mapped onto the unit box) and use
# (1 - IC)**20 as the starting lithosphere field.
boxDims = tuple((0., 1.) for _ in range(system.mesh.dim))
box = planetengine.mapping.box(system.mesh, system.mesh.data, boxDims)
system.lithoField.data[:] = (
    1. - planetengine.initials.sinusoidal.IC(freq = 1., pert = 0.3).evaluate(box)
    ) ** 20.

In [None]:
# Initial Stokes solve (solve() also refreshes the lithosphere fields first).
system.solve()

In [None]:
def show():
    """Display the model's density, velocity, viscosity and material fields.

    Calls planetengine's quickShow with resolution (2048, 1024) and returns
    whatever quickShow returns. Add ``system.lithoField`` to the argument
    list to also display the raw lithosphere field.
    """
    return quickShow(
        system.densityFn,
        system.velocityField,
        system.viscosityFn,
        system.materialFn,
        resolution = (2048, 1024),
        )

In [None]:
# Display the fields after the initial solve.
show()

In [None]:
# Time-stepping loop: each iteration advects the lithosphere then re-solves
# Stokes. Figures are redrawn every SHOW_EVERY iterations (starting with the
# first) and once more at the end.
N_ITERATIONS = 100
SHOW_EVERY = 10

for i in range(N_ITERATIONS):
    system.iterate()
    print("Iteration " + str(i))
    if i % SHOW_EVERY == 0:
        show()
print("Done!")
show()