In [1]:
import sys
print(sys.executable)
# Import some useful modules.
import jax
import jax.numpy as np
import os
from functools import partial
from jax import grad, hessian


# Import JAX-FEM specific modules.
from jax_fem.problem import Problem
from jax_fem.solver import solver
from jax_fem.utils import save_sol
from jax_fem.generate_mesh import box_mesh_gmsh, get_meshio_cell_type, Mesh



# Define constitutive relationship.
class HyperElasticity(Problem):
    # The function 'get_tensor_map' overrides base class method. Generally, JAX-FEM
    # solves -div(f(u_grad)) = b. Here, we define f(u_grad) = P. Notice how we first
    # define 'psi' (representing W), and then use automatic differentiation (jax.grad)
    # to obtain the 'P_fn' function.
    def get_tensor_map(self):

        def psi(F):
            E = 10.
            nu = 0.3
            mu = E / (2. * (1. + nu))
            kappa = E / (3. * (1. - 2. * nu))
            J = np.linalg.det(F)
            Jinv = J**(-2. / 3.)
            I1 = np.trace(F.T @ F)
            energy = (mu / 2.) * (Jinv * I1 - 3.) + (kappa / 2.) * (J - 1.)**2.
            return energy
        
        
        self.psi = psi

        P_fn = jax.grad(psi)
        self.P_fn = P_fn
        def first_PK_stress(u_grad):
            print(u_grad)
            I = np.eye(self.dim)
            F = u_grad + I
            P = P_fn(F)
            return P
        
        return first_PK_stress

    def total_strain_energyTEMP(self, u):
        energy = 0.0
        u_grad_all = self.fes[0].sol_to_grad(u)  # shape: (num_cells, num_quads, dim, dim)
        
        print("Gradients calculated")
        weights = self.fes[0].JxW                # shape: (num_cells, num_quads)
    
        for cell_idx in range(u_grad_all.shape[0]):
            for q in range(u_grad_all.shape[1]):
                F = u_grad_all[cell_idx, q] + np.eye(self.dim)
                W = self.psi(F)
                energy += W * weights[cell_idx, q]
    
        return energy
    
    
    def total_strain_energy(self, u):
        # Get shape: (num_elements, num_quadrature_points, dim, dim)
        u_grad_all = self.fes[0].sol_to_grad(u)
        JxW = self.fes[0].JxW  # shape: (num_elements, num_quadrature_points)
    
        # Compute F = I + ∇u
        F = u_grad_all + np.eye(self.dim)
    
        # Vectorize psi over quadrature points
        psi_q = jax.vmap(jax.vmap(self.psi))(F)  # shape: (num_elements, num_quadrature_points)
    
        # Integrate energy
        energy = np.sum(psi_q * JxW)
        return energy
    
    
    # Specify mesh-related information (first-order hexahedron element).
ele_type = 'HEX8'
cell_type = get_meshio_cell_type(ele_type)
data_dir = os.path.join(os.path.dirname(__file__), 'data')
Lx, Ly, Lz = 1., 1., 1.
meshio_mesh = box_mesh_gmsh(Nx=20,
                       Ny=20,
                       Nz=20,
                       Lx=Lx,
                       Ly=Ly,
                       Lz=Lz,
                       data_dir=data_dir,
                       ele_type=ele_type)
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])


# Define boundary locations.
def left(point):
    return np.isclose(point[0], 0., atol=1e-5)

def right(point):
    return np.isclose(point[0], Lx, atol=1e-5)


# Define Dirichlet boundary values.
def zero_dirichlet_val(point):
    return 0.


def dirichlet_val_x2(point):
    return (0.5 + (point[1] - 0.5) * np.cos(np.pi / 3.) -
            (point[2] - 0.5) * np.sin(np.pi / 3.) - point[1]) / 2.


def dirichlet_val_x3(point):
    return (0.5 + (point[1] - 0.5) * np.sin(np.pi / 3.) +
            (point[2] - 0.5) * np.cos(np.pi / 3.) - point[2]) / 2.


dirichlet_bc_info = [[left] * 3 + [right] * 3, [0, 1, 2] * 2,
                     [zero_dirichlet_val, dirichlet_val_x2, dirichlet_val_x3] +
                     [zero_dirichlet_val] * 3]


# Create an instance of the problem.
problem = HyperElasticity(mesh,
                          vec=3,
                          dim=3,
                          ele_type=ele_type,
                          dirichlet_bc_info=dirichlet_bc_info)

# Solve the defined problem.
sol_list = solver(problem, solver_options={'petsc_solver': {}})


# Store the solution to local file.
vtk_path = os.path.join(data_dir, f'vtk/u.vtu')
save_sol(problem.fes[0], sol_list[0], vtk_path)

u = sol_list[0]
energy = problem.total_strain_energy(u)

@jax.jit
def energy_fn(u):
    return problem.total_strain_energy(u)
energy_grad = grad(energy_fn)(u)

# Store the solution to local file.
vtk_path = os.path.join(data_dir, f'vtk/Jac.vtu')
save_sol(problem.fes[0], energy_grad, vtk_path)

energy_hess = hessian(energy_fn)(u)


/home/samuel/miniconda3/envs/jax_FEM_env_WSL/bin/python
       __       ___      ___   ___                _______  _______ .___  ___. 
      |  |     /   \     \  \ /  /               |   ____||   ____||   \/   | 
      |  |    /  ^  \     \  V  /      ______    |  |__   |  |__   |  \  /  | 
.--.  |  |   /  /_\  \     >   <      |______|   |   __|  |   __|  |  |\/|  | 
|  `--'  |  /  _____  \   /  .  \                |  |     |  |____ |  |  |  | 
 \______/  /__/     \__\ /__/ \__\               |__|     |_______||__|  |__| 
                                                                              



NameError: name '__file__' is not defined