In [1]:
import pyvista

pyvista.set_jupyter_backend("client")
pyvista.start_xvfb()

In [2]:
import petsc4py
petsc4py.init(['-saws_port_auto_select'])
from petsc4py import PETSc

import os

os.environ["JAX_PLATFORM"] = "cpu"
os.environ["JAX_CACHE_DIR"] = "/cluster/scratch/mpundir/jax-cache-gpu"
os.environ["ROOT_LIB_PATH"] = "/cluster/home/mpundir/studies/corrosion-fracture/"
os.environ["SPECTRAL_LIB_PATH"] = "/cluster/home/mpundir/dev/spectralsolvers/"
os.environ["PLOT_LIB_PATH"] = "/cluster/home/mpundir/dev"
os.environ["PYVISTA_CACHE_DIR"] = "/cluster/scratch/mpundir/"


In [3]:
import jax

jax.config.update("jax_compilation_cache_dir", os.environ["JAX_CACHE_DIR"])
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double-precision
jax.config.update("jax_platforms", os.environ["JAX_PLATFORM"])
from functools import partial

import cmcrameri.cm as cmc
import numpy as np

import sys


sys.path.append(os.environ["PLOT_LIB_PATH"])
from plottwist.decorators import plot, subplots, imshow, imshow_grid

sys.path.append(os.environ["SPECTRAL_LIB_PATH"])
from functools import partial
from spectralsolvers.operators import tensor, fourier_galerkin
from spectralsolvers.fft.transform import _fft, _ifft
#from spectralsolvers.solvers.linear import conjugate_gradient
#from spectralsolvers.solvers.nonlinear import newton_krylov_solver


sys.path.append(os.environ["ROOT_LIB_PATH"])
import structure_helpers as _structure
import math

In [4]:
jax.devices()

[CpuDevice(id=0)]

In [5]:
def generateGaussian(shape, length, correlation, random_seed=2021):
    """Generate Gaussian random field for pore structure generation"""

    size = shape[0]
    dim = len(shape)

    x = np.linspace(-length / 2, length / 2, num=size)

    if dim == 3:
        pos_grid = np.meshgrid(x, x, x)
    else:
        pos_grid = np.meshgrid(x, x)

    positions = np.zeros((size**dim, dim))
    for i in range(dim):
        positions[:, i] = pos_grid[i].flatten()

    kernel = np.zeros((size**dim, 1))

    print("Computing Gaussian Kernel")
    lambda_mat = np.eye(dim) * correlation
    _structure.pore_helper.computeGaussianKernel(lambda_mat, kernel, positions)
    kernel = kernel.reshape((size,) * dim)

    print("Performing FFT")
    np.random.seed(random_seed)
    random_field = np.random.normal(size=(size,) * dim)

    T = jnp.multiply(fft(random_field), fft(kernel))

    print("Generating Gaussian Field")
    T = ifft(T)
    T = jnp.real(T)
    return T

In [6]:
N = 63
ndim = 3
length = 1
dx = length / N

correlation = 0.1
porosity = 0.2

grid_size = (N,) * ndim
elasticity_dof_shape = (ndim, ndim) + grid_size


Spatial and FFT operators that we need for solving of frature problem. 

In [7]:
fft = jax.jit(partial(_fft, N=N, ndim=ndim))
ifft = jax.jit(partial(_ifft, N=N, ndim=ndim))

In [8]:
field = generateGaussian(grid_size, length, correlation, random_seed=30)
field -= np.mean(field)  # to make mean 0
# determining the threshold value
std = np.std(field)
threshold_value = np.sqrt(2) * std * math.erfc(1 - porosity)

structure = np.where(np.abs(field) < threshold_value, 1., 0.)
# pore = np.ones_like(pore)

Computing Gaussian Kernel
Performing FFT
Generating Gaussian Field


In [9]:
# Create the spatial reference
grid = pyvista.ImageData(origin=(0, 0, 0))

# Set the grid dimensions: shape + 1 because we want to inject our values on
#   the CELL data
grid.dimensions = np.array(structure.shape) + 1

# Add the data values to the cell data
grid.cell_data["values"] = structure.flatten(order="F")  # Flatten the array

# grid = grid.cast_to_structured_grid()


pl = pyvista.Plotter(shape=(1, 1))

_ = pl.add_volume(grid, scalars="values", cmap=cmc.berlin, opacity_unit_distance=0.1)
'''_ = pl.add_mesh(
    grid,
    color="black",
    scalars="values",
    cmap=cmc.berlin,
    show_edges=False,
)'''

pl.export_html(os.environ["PYVISTA_CACHE_DIR"] + "/pv.html")


In [10]:
# material parameters + function to convert to grid of scalars
@partial(jax.jit, static_argnames=['inclusion', 'solid'])
def param(X, inclusion, solid):
    props = inclusion*jnp.ones_like(X)*(1-X) + solid*jnp.ones_like(X)*(X)
    return props


In [11]:
# material parameters
phase_contrast = 1/1e3

# lames constant
lambda_modulus = {"solid": 1.0, "inclusion": phase_contrast}
shear_modulus = {"solid": 1.0, "inclusion": phase_contrast}

bulk_modulus = {}
bulk_modulus["solid"] = lambda_modulus["solid"] + 2 * shear_modulus["solid"] / 3
bulk_modulus["inclusion"] = lambda_modulus["inclusion"] + 2 * shear_modulus["inclusion"] / 3

In [12]:
λ0 = param(
    structure, inclusion=lambda_modulus["inclusion"], solid=lambda_modulus["solid"]
)  # lame parameter
μ0 = param(
    structure, inclusion=shear_modulus["inclusion"], solid=shear_modulus["solid"]
)  # lame parameter
K0 = param(structure, inclusion=bulk_modulus["inclusion"], solid=bulk_modulus["solid"])

In [13]:
# Create the spatial reference
grid = pyvista.ImageData(origin=(0, 0, 0))

# Set the grid dimensions: shape + 1 because we want to inject our values on
#   the CELL data
grid.dimensions = np.array(λ0.shape) + 1

# Add the data values to the cell data
grid.cell_data["values"] = λ0.flatten(order="F")  # Flatten the array

# grid = grid.cast_to_structured_grid()


pl = pyvista.Plotter(shape=(1, 1))

_ = pl.add_volume(grid, scalars="values", cmap=cmc.berlin)
'''_ = pl.add_mesh(
    grid,
    color="black",
    scalars="values",
    cmap=cmc.berlin,
    show_edges=False,
)'''

pl.export_html(os.environ["PYVISTA_CACHE_DIR"] + "/pv.html")


In [14]:
length = 1.  # m

In [15]:
@jax.jit
def strain_energy(eps):
    eps_sym = 0.5*(eps + tensor.trans2(eps))
    energy = 0.5 * jnp.multiply(λ0, tensor.trace2(eps_sym)**2) + jnp.multiply(μ0, tensor.trace2(tensor.dot22(eps_sym, eps_sym)))
    return energy.sum()


sigma = jax.jit(jax.jacrev(strain_energy))

The $\sigma$ is calculated using automatic differentiation.

We also define functions to projection operator $\mathbb{G}$ that is used to project test function $\delta \varepsilon$ to a compatible space. 

In [16]:
Ghat = fourier_galerkin.compute_projection_operator(
    grid_size=grid_size, operator="rotated-difference", length=length
)

# functions for the projection 'G', and the product 'G : K : eps'
@jax.jit
def G(A2):
    return jnp.real(ifft(tensor.ddot42(Ghat, fft(A2)))).reshape(-1)



@jax.jit
def G_K_deps(depsm, additionals=None):
    depsm = depsm.reshape(elasticity_dof_shape)
    return G(sigma(depsm))



In [None]:
@partial(jax.jit, static_argnames=["A", "max_iter"])
def conjugate_gradient_solver(b, additionals, A, atol, max_iter):
    x = jnp.full_like(b, fill_value=0.0)

    r = b - A(x, additionals)
    p = r
    rsold = jnp.vdot(r, r)

    state = (b, p, r, rsold, x)

    def conjugate_gradient(state, n):
        b, p, r, rsold, x = state

        def true_fun(state):
            b, p, r, rsold, x = state
            Ap = A(p, additionals)
            alpha = rsold / jnp.vdot(p, Ap)
            x = x + jnp.dot(alpha, p)
            r = r - jnp.dot(alpha, Ap)
            rsnew = jnp.vdot(r, r)
            p = r + (rsnew / rsold) * p
            rsold = rsnew
            return (b, p, r, rsold, x)

        def false_fun(state):
            return state

        return (
            jax.lax.cond(
                jnp.sqrt(rsold) > atol,
                true_fun,
                false_fun,
                state,
            ),
            n,
        )

    final_state, xs = jax.lax.scan(conjugate_gradient, init=state, xs=jnp.arange(0, max_iter))

    return final_state[-1], xs

In [None]:
@partial(
    jax.jit,
    static_argnames=[
        "A",
        "krylov_solver",
        "tol",
        "max_iter",
        "krylov_tol",
        "krylov_max_iter",
    ],
)
def newton_krylov_solver(
    state, A, additionals, tol, max_iter, krylov_solver, krylov_tol, krylov_max_iter
):

    def newton_raphson(state, n):
        dF, b, F = state
        error = jnp.linalg.norm(b)
        # jnp.linalg.norm(dF) / jnp.linalg.norm(F)
        # jax.debug.print("residual={}", error)

        def true_fun(state):
            dF, b, F = state

            dF, iiter = krylov_solver(
                A=A,
                b=b,
                atol=krylov_tol,
                max_iter=krylov_max_iter,
                additionals=additionals,
            )  # solve linear system
            
            dF = dF.reshape(F.shape)
            F = jax.lax.add(F, dF)
            b = -A(F, additionals)  # compute residual

            return (dF, b, F)

        def false_fun(state):
            return state

        return jax.lax.cond(error > tol, true_fun, false_fun, state), n

    final_state, xs = jax.lax.scan(
        newton_raphson, init=state, xs=jnp.arange(0, max_iter)
    )

    def not_converged(residual):
        jax.debug.print("Didnot converge, Residual value : {}", residual)
        return residual

    def converged(residual):
        jax.debug.print("Converged, Residual value : {}", residual)
        return residual

    residual = jnp.linalg.norm(final_state[1])
    jax.lax.cond(residual > tol, not_converged, converged, residual)

    return final_state


In [19]:
eps = jnp.zeros(elasticity_dof_shape)
deps = jnp.zeros(elasticity_dof_shape)


In [20]:
deps = deps.at[0, 0].set(0.001)
b = -G_K_deps(deps, None)

eps = jax.lax.add(eps, deps)

In [22]:
trial  = conjugate_gradient_solver(
    A=G_K_deps,
    b=b,
    atol=1e-8,
    max_iter=2,
    additionals=None,
)

In [23]:
trial

Array([ 2.71383408e-04,  2.79761686e-04,  2.12242318e-04, ...,
       -1.05219674e-05, -1.46095719e-05, -2.16487658e-05], dtype=float64)

In [25]:
final_state = newton_krylov_solver(
    state=(deps, b, eps),
    A=G_K_deps,
    tol=1e-6,
    max_iter=20,
    krylov_solver=conjugate_gradient_solver,
    krylov_tol=1e-6,
    krylov_max_iter=20,
    additionals=None,
)

Converged, Residual value : 9.993733746837954e-07


In [26]:
stress00 = sigma(final_state[-1])[0, 0]

In [28]:
# Create the spatial reference
grid = pyvista.ImageData(origin=(0, 0, 0))

# Set the grid dimensions: shape + 1 because we want to inject our values on
#   the CELL data
grid.dimensions = np.array(stress00.shape) + 1

# Add the data values to the cell data
grid.cell_data["stress"] = stress00.flatten(order="F")  # Flatten the array
grid.cell_data["structure"] = structure.flatten(order="F")  # Flatten the array

threshed = grid.threshold(value=(0.5), scalars='structure')


pl = pyvista.Plotter(shape=(1, 1))

_ = pl.add_volume(threshed, scalars="stress", cmap=cmc.roma)
'''_ = pl.add_mesh(
    threshed,
    color="black",
    scalars="stress",
    cmap=cmc.managua,
    show_edges=True,
    show_scalar_bar=True,
)'''

pl.export_html(os.environ["PYVISTA_CACHE_DIR"] + "/pv.html")