Skip to content

Commit

Permalink
Clean up code generation scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
ckhroulev committed Feb 19, 2021
1 parent dd147fe commit 2003cce
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 55 deletions.
16 changes: 13 additions & 3 deletions src/stressbalance/blatter/verification/blatter_codegen.py
Expand Up @@ -8,18 +8,28 @@
{}
}};"""

def code(a, **kwargs):
return sp.ccode(a, standard="c99", **kwargs)

def join(args):
return ", ".join(["double " + x for x in args])

def print_var(var, name):
print(" double " + sp.ccode(var, assign_to=name, standard="c99"))
print(" double " + code(var, assign_to=name))

def print_header(name, args, return_type="Vector2"):
print("")
print((func_template + " {{").format(return_type=return_type,
name=name,
arguments=join(args)))

def print_footer(a, b=None):
if b is not None:
print(return_template.format(code(a), code(b)))
else:
print(" return {};".format(code(a)))
print("}")

def declare(name, args, return_type="Vector2"):
print("")
print((func_template + ";").format(return_type=return_type,
Expand All @@ -38,7 +48,7 @@ def define(f_u, f_v, name, args):
for variable, value in tmps:
print_var(value, variable)

print(return_template.format(sp.ccode(u, standard="c99"),
sp.ccode(v, standard="c99")))
print(return_template.format(code(u),
code(v)))

print("}")
Expand Up @@ -123,7 +123,6 @@ Vector2 blatter_xz_halfar_exact(double x, double z, double H_0, double R_0, doub
double C_2 = (1.0/2.0)*pow(g, 3)*pow(rho_i, 3)/pow(B, 3);
double h0 = C_0*pow(-pow(C_1, 4.0/3.0)*pow(x, 4.0/3.0) + 1, 3.0/7.0);
double h_x = -4.0/7.0*C_0*pow(C_1, 4.0/3.0)*cbrt(x)/pow(-pow(C_1, 4.0/3.0)*pow(x, 4.0/3.0) + 1, 4.0/7.0);

return {
-C_2*pow(h_x, 3)*(pow(h0, 4) - pow(h0 - z, 4)),
0
Expand All @@ -142,7 +141,6 @@ Vector2 blatter_xz_halfar_source(double x, double z, double H_0, double R_0, dou
double u_xx = 0;
double u_xz = -12*C_2*pow(h_x, 4)*pow(h0 - z, 2) - 12*C_2*pow(h_x, 2)*h_xx*pow(h0 - z, 3);
double u_zz = 12*C_2*pow(h_x, 3)*pow(h0 - z, 2);

return {
2*B*u_x*(-2.0/3.0*u_x*u_xx - 1.0/6.0*u_xz*u_z)/pow(pow(u_x, 2) + (1.0/4.0)*pow(u_z, 2), 4.0/3.0) + 2*B*u_xx/cbrt(pow(u_x, 2) + (1.0/4.0)*pow(u_z, 2)) + (1.0/2.0)*B*u_z*(-2.0/3.0*u_x*u_xz - 1.0/6.0*u_z*u_zz)/pow(pow(u_x, 2) + (1.0/4.0)*pow(u_z, 2), 4.0/3.0) + (1.0/2.0)*B*u_zz/cbrt(pow(u_x, 2) + (1.0/4.0)*pow(u_z, 2)),
0.0
Expand All @@ -158,7 +156,6 @@ Vector2 blatter_xz_halfar_source_lateral(double x, double z, double H_0, double
double h_xx = (4.0/147.0)*C_0*pow(C_1, 4.0/3.0)*(16*pow(C_1, 4.0/3.0)*pow(x, 2.0/3.0)/(pow(C_1, 4.0/3.0)*pow(x, 4.0/3.0) - 1) - 7/pow(x, 2.0/3.0))/pow(-pow(C_1, 4.0/3.0)*pow(x, 4.0/3.0) + 1, 4.0/7.0);
double u_x = -C_2*pow(h_x, 3)*(4*pow(h0, 3)*h_x - 4*h_x*pow(h0 - z, 3)) - 3*C_2*pow(h_x, 2)*h_xx*(pow(h0, 4) - pow(h0 - z, 4));
double u_z = -4*C_2*pow(h_x, 3)*pow(h0 - z, 3);

return {
2*pow(2, 2.0/3.0)*B*u_x/cbrt(4*pow(u_x, 2) + pow(u_z, 2)),
0.0
Expand All @@ -174,7 +171,6 @@ Vector2 blatter_xz_halfar_source_surface(double x, double H_0, double R_0, doubl
double h_x = -4.0/7.0*C_0*pow(C_1, 4.0/3.0)*cbrt(x)/pow(-pow(C_1, 4.0/3.0)*pow(x, 4.0/3.0) + 1, 4.0/7.0);
double h_xx = (4.0/147.0)*C_0*pow(C_1, 4.0/3.0)*(16*pow(C_1, 4.0/3.0)*pow(x, 2.0/3.0)/(pow(C_1, 4.0/3.0)*pow(x, 4.0/3.0) - 1) - 7/pow(x, 2.0/3.0))/pow(-pow(C_1, 4.0/3.0)*pow(x, 4.0/3.0) + 1, 4.0/7.0);
double u_x = -C_2*pow(h_x, 3)*(4*pow(h0, 3)*h_x - 4*h_x*pow(h0 - z, 3)) - 3*C_2*pow(h_x, 2)*h_xx*(pow(h0, 4) - pow(h0 - z, 4));

return {
-2*B*h_x*u_x/(sqrt(pow(h_x, 2) + 1)*cbrt(pow(u_x, 2))),
0.0
Expand Down
51 changes: 15 additions & 36 deletions src/stressbalance/blatter/verification/test_xz_halfar.py
Expand Up @@ -2,13 +2,7 @@
from sympy import S

from blatter import x, y, z, B, source_term, eta, M
from blatter_codegen import define, declare

return_template = """
return {{
{},
{}
}};"""
from blatter_codegen import define, declare, print_header, print_var, print_footer

sp.var("R_0 H_0 rho_i g C_0 C_1 C_2", positive=True)
h = sp.Function("h", positive=True)(x)
Expand All @@ -34,7 +28,7 @@
# Glen exponents n
n = 3

def parameters():
def constants():
# s = 1 corresponds to t = t_0
s = 1

Expand Down Expand Up @@ -83,23 +77,11 @@ def print_code(header=False):
declare(name="blatter_xz_halfar_source_surface", args=["x"] + constants)
return

definitions = parameters()

print_exact(coords + constants)
print_source(coords + constants)
print_source_lateral(coords + constants)
print_source_surface(["x"] + constants)

def print_var(var, name):
print(" double " + sp.ccode(var, assign_to=name))

def print_header(suffix, args):
arguments = ", ".join(["double " + x for x in args])

print("")
print("Vector2 blatter_xz_halfar_{suffix}({args}) {{".format(suffix=suffix,
args=arguments))

def print_source_surface(args):
"Print the code computing the extra term at the top surface"

Expand All @@ -113,17 +95,17 @@ def print_source_surface(args):
U_x = u0.diff(x).subs(subs)
U_z = u0.diff(z).subs(subs)

print_header("source_surface", args)
print_header("blatter_xz_halfar_source_surface", args)

for key, value in parameters().items():
for key, value in constants().items():
print_var(value, key)
print_var(H(x), h0)
print_var(h0, z)
print_var(H(x).diff(x), h_x)
print_var(H(x).diff(x, 2), h_xx)
print_var(U_x, u_x)
print(return_template.format(sp.ccode(f_top), 0.0))
print("}")

print_footer(f_top, 0.0)

def print_source_lateral(args):
"Print the code computing the extra term at the right boundary"
Expand All @@ -136,33 +118,31 @@ def print_source_lateral(args):
U_x = u0.diff(x).subs(subs)
U_z = u0.diff(z).subs(subs)

print_header("source_lateral", args)
print_header("blatter_xz_halfar_source_lateral", args)

for key, value in parameters().items():
for key, value in constants().items():
print_var(value, key)
print_var(H(x), h0)
print_var(H(x).diff(x), h_x)
print_var(H(x).diff(x, 2), h_xx)
print_var(U_x, u_x)
print_var(U_z, u_z)

print(return_template.format(sp.ccode(f_lat), 0.0))
print("}")
print_footer(f_lat, 0.0)

def print_exact(args):
u0, v0 = u_exact()

u0 = u0.subs(subs)

print_header("exact", args)
print_header("blatter_xz_halfar_exact", args)

for key, value in parameters().items():
for key, value in constants().items():
print_var(value, key)
print_var(H(x), h0)
print_var(H(x).diff(x), h_x)

print(return_template.format(sp.ccode(u0), v0))
print("}")
print_footer(u0, v0)

def print_source(args):
f, _ = source_term(eta(u, v, 3), u, v)
Expand All @@ -173,9 +153,9 @@ def print_source(args):
U_x = u0.diff(x).subs(subs)
U_z = u0.diff(z).subs(subs)

print_header("source", args)
print_header("blatter_xz_halfar_source", args)

for key, value in parameters().items():
for key, value in constants().items():
print_var(value, key)
print_var(H(x), h0)
print_var(H(x).diff(x), h_x)
Expand All @@ -186,5 +166,4 @@ def print_source(args):
print_var(U_x.diff(z), u_xz)
print_var(U_z.diff(z), u_zz)

print(return_template.format(sp.ccode(f), 0.0))
print("}")
print_footer(f, 0.0)
18 changes: 6 additions & 12 deletions src/stressbalance/blatter/verification/test_xz_vanderveen.py
Expand Up @@ -2,7 +2,7 @@
from sympy import S, Eq, solve

from blatter import x, y, z, B, source_term, eta, M, grad
from blatter_codegen import define, declare, print_header, print_var, return_template
from blatter_codegen import define, declare, print_header, print_var, print_footer

sp.var("alpha H_0 Q_0 rho_i g C", positive=True)

Expand Down Expand Up @@ -122,8 +122,7 @@ def print_exact(args):
print_var(C0(), C)
print_var(H_exact, h0)

print(return_template.format(sp.ccode(u_exact.subs(H, h0)), v_exact))
print("}")
print_footer(u_exact.subs(H, h0), v_exact)

def print_source_lateral(args):
"Print the code computing the extra term at the right boundary"
Expand All @@ -137,9 +136,7 @@ def print_source_lateral(args):
print_var(C0(), C)
print_var(H_exact, h0)

print(return_template.format(sp.ccode(f_lat.subs(H, h0)), 0.0))
print("}")

print_footer(f_lat.subs(H, h0), 0.0)

def print_source_surface(args):
"Print the code computing the extra term at the top surface"
Expand All @@ -153,8 +150,7 @@ def print_source_surface(args):
print_var(C0(), C)
print_var(H_exact, h0)

print(return_template.format(sp.ccode(f_top.subs(H, h0)), 0.0))
print("}")
print_footer(f_top.subs(H, h0), 0.0)

def print_basal_beta(args):
"Print the code computing basal sliding coefficient"
Expand All @@ -168,8 +164,7 @@ def print_basal_beta(args):
print_var(C0(), C)
print_var(H_exact, h0)

print(" return {};".format(sp.ccode(beta.subs(H, h0))))
print("}")
print_footer(beta.subs(H, h0))


def print_thickness(args):
Expand All @@ -179,5 +174,4 @@ def print_thickness(args):

print_var(C0(), C)

print(" return {};".format(sp.ccode(H_exact)))
print("}")
print_footer(H_exact)

0 comments on commit 2003cce

Please sign in to comment.