In [None]:
import jax.numpy as np  # NOTE: jax.numpy is aliased as `np` throughout this notebook
import matplotlib.pyplot as plt
import matplotlib
font = {'size'   : 16}
matplotlib.rc('font', **font)
from functools import partial  # `jax.partial` only re-exported functools.partial and is gone from newer JAX
from NODE_fns import NODE
from jax import grad, random, jit
try:
    from jax.example_libraries import optimizers  # current location
except ImportError:
    from jax.experimental import optimizers  # older JAX releases
from jax.lax import while_loop

# Derivative of the NODE with respect to its first (scalar) argument.
dNODE = grad(NODE)
key = random.PRNGKey(0)

In [None]:
# --- Viscous (non-equilibrium) branch: Ogden parameters --------------------
_MU_M = (51.4, -18.0, 3.86)
_ALPHA_M = (1.8, -2.0, 7.0)
_K_M = 10000.0
_TAU = 17.5
# 1/2 * sum(mu_m * alpha_m) = 77.77
_SHEAR_MOD = 0.5 * sum(m * a for m, a in zip(_MU_M, _ALPHA_M))
_ETA_D = _TAU * _SHEAR_MOD  # viscosity dividing the deviatoric stress in (60)
_ETA_V = _TAU * _K_M        # viscosity dividing the volumetric (trace) stress in (60)

# --- Equilibrium branch (neo-Hookean) ---------------------------------------
_MU_EQ = 77.77  # equals _SHEAR_MOD
_K_EQ = 10000.0


def _elastic_kinematics(eps_e):
    """Return (bbar_e, Je) for the elastic principal log strains eps_e (shape (3,)).

    Je is the product of the principal stretches exp(eps_e); bbar_e holds the
    isochoric squared stretches Je**(-2/3) * exp(eps_e)**2, eq. (54).
    """
    lamb_e = np.exp(eps_e)
    Je = lamb_e[0] * lamb_e[1] * lamb_e[2]
    return Je ** (-2 / 3) * lamb_e ** 2, Je


def _principal_kirchhoff(eps_e):
    """Principal Kirchhoff stresses of the viscous Ogden branch, eq. (B8).

    Returns (tau_A, devtau, tau_vol, bbar_e, Je):
      devtau  -- deviatoric principal stresses,
      tau_vol -- 3*K_m/2*(Je**2 - 1), the trace tau_NEQ : I,
      tau_A   -- devtau + tau_vol/3.
    """
    bbar_e, Je = _elastic_kinematics(eps_e)
    devtau = np.zeros_like(bbar_e)
    for mu_r, alpha_r in zip(_MU_M, _ALPHA_M):
        bpow = bbar_e ** (alpha_r / 2)
        # 2/3*b_A**e - 1/3*(sum of the other two) == b_A**e - mean(b**e)
        devtau = devtau + mu_r * (bpow - np.sum(bpow) / 3)
    tau_vol = 3 * _K_M / 2 * (Je ** 2 - 1)
    return devtau + tau_vol / 3, devtau, tau_vol, bbar_e, Je


def _ddev(bbar_e):
    """(3, 3) matrix d devtau_A / d eps_e_B, eqs. (B12)-(B13).

    With e = alpha_r/2 and S = sum_C b_C**e:
      diagonal:     sum_r mu_r*alpha_r*( 4/9*b_A**e + 1/9*(b_B**e + b_C**e))
                  = sum_r mu_r*alpha_r*( S/9 + b_A**e/3)
      off-diagonal: sum_r mu_r*alpha_r*(-2/9*(b_A**e + b_B**e) + 1/9*b_C**e)
                  = sum_r mu_r*alpha_r*( S/9 - (b_A**e + b_B**e)/3)
    Both cases equal S/9 - (b_A**e + b_B**e)/3 + delta_AB*b_A**e.
    """
    ddev = np.zeros((3, 3), dtype=bbar_e.dtype)
    for mu_r, alpha_r in zip(_MU_M, _ALPHA_M):
        bpow = bbar_e ** (alpha_r / 2)
        pair_sum = bpow[:, None] + bpow[None, :]
        ddev = ddev + mu_r * alpha_r * (np.sum(bpow) / 9 - pair_sum / 3 + np.diag(bpow))
    return ddev


def _newton_step(eps_e, eps_e_trial, params, dt):
    """One Newton-Raphson update of the elastic log strains for the residual (60).

    Returns (eps_e_new, normres), where normres is the norm of the residual
    evaluated at the incoming eps_e (i.e. before the update).
    """
    NODE1_params, NODE2_params, NODE3_params, NODE4_params, NODE5_params = params
    tau_A, devtau, tau_vol, bbar_e, Je = _principal_kirchhoff(eps_e)

    # Principal stresses sorted in descending order: tau_1 >= tau_2 >= tau_3.
    order = np.argsort(-tau_A)
    tau_1, tau_2, tau_3 = tau_A[order]

    I1 = tau_1 + tau_2 + tau_3
    dN1 = dNODE(tau_1, NODE1_params)
    dN2 = dNODE(tau_1 + tau_2, NODE2_params)
    dN3 = dNODE(I1, NODE3_params)
    dN4 = dNODE(I1 ** 2, NODE4_params)
    dN5 = dNODE(tau_1**2 + tau_2**2 + tau_3**2 - tau_1*tau_2 - tau_1*tau_3 - tau_2*tau_3, NODE5_params)

    # d^2 phi / d tau_i d tau_j, in the sorted (descending) stress ordering.
    d11 = dN1 + dN2 + dN3 + 2*dN4 + 2*dN5
    d22 =       dN2 + dN3 + 2*dN4 + 2*dN5
    d33 =             dN3 + 2*dN4 + 2*dN5
    d12 =       dN2 + dN3 + 2*dN4 - dN5
    d13 =             dN3 + 2*dN4 - dN5
    d23 =             dN3 + 2*dN4 - dN5
    d2phi = np.array([[d11, d12, d13], [d12, d22, d23], [d13, d23, d33]])

    # d tau_A / d eps_e_B: deviatoric tangent plus d(tau_vol/3)/d eps_e_B = K_m*Je**2.
    dtau_deps = _ddev(bbar_e) + _K_M * Je ** 2
    # Chain rule in the sorted frame (rows of dtau_deps permuted to match d2phi),
    # then undo the sort so that row A of the Jacobian lines up with res[A].
    jac = np.dot(d2phi, dtau_deps[order])[np.argsort(order)]
    K_AB = np.eye(3) + dt * jac

    # NOTE(review): the residual below uses a linear viscous law (devtau/eta_D,
    # tau_vol/eta_V) while K_AB is built from the NODE potential, so Newton only
    # converges as a quasi-Newton method. Confirm which flow rule is intended.
    res = eps_e + dt * (devtau / (2 * _ETA_D) + tau_vol / (9 * _ETA_V)) - eps_e_trial  # (60)
    deps_e = np.linalg.solve(K_AB, -res)
    return eps_e + deps_e, np.linalg.norm(res)


@partial(jit, static_argnums=(3,))
def nvisco(F, C_i_inv, params, dt, tol=1.e-6, itermax=20):
    """Visco-elastic stress update with an NODE-based viscous branch.

    Performs an elastic-predictor / Newton-Raphson return mapping in principal
    elastic log-strain space, then evaluates the total stress.

    Args:
        F: (3, 3) deformation gradient of the current step.
        C_i_inv: (3, 3) internal state combined as F @ C_i_inv @ F.T into the
            trial elastic tensor (presumably the inverse viscous right
            Cauchy-Green tensor of the previous step -- confirm with caller).
        params: 5-tuple of NODE parameter sets used by the flow potential.
        dt: time increment (static under jit).
        tol: Newton stops once the residual norm is <= tol.
        itermax: maximum number of Newton iterations.

    Returns:
        sigma: (3, 3) stress tau_NEQ / Je + sigma_EQ, eq. (7).
        C_i_inv_new: (3, 3) updated internal state F^-1 @ be @ F^-T.
        lamb_e: (3,) converged elastic principal stretches, ordered like the
            eigenvector columns returned by eigh for the trial tensor.
    """
    # Elastic predictor: be_trial is symmetric, so eigh gives real eigenvalues
    # and real orthonormal principal directions (eig would return complex ones).
    be_trial = np.dot(F, np.dot(C_i_inv, F.T))
    lamb2_trial, n_A = np.linalg.eigh(be_trial)
    eps_e_trial = np.log(np.sqrt(lamb2_trial))

    # Newton-Raphson corrector. eps_e must live in the while_loop carry: the
    # body is traced once, so state kept outside the carry never updates.
    def cond_fun(state):
        _, normres, itr = state
        return np.logical_and(normres > tol, itr < itermax)

    def body_fun(state):
        eps_e, _, itr = state
        eps_new, normres = _newton_step(eps_e, eps_e_trial, params, dt)
        # Keep carry dtypes fixed across iterations, as while_loop requires.
        return eps_new.astype(eps_e.dtype), normres.astype(eps_e.dtype), itr + 1

    init = (eps_e_trial,
            np.asarray(1.0, dtype=eps_e_trial.dtype),
            np.asarray(0, dtype=np.int32))
    eps_e, _, _ = while_loop(cond_fun, body_fun, init)

    # Stress from the converged elastic strains.
    tau_A, _, _, _, Je = _principal_kirchhoff(eps_e)
    tau_NEQ = np.einsum('i,ji,ki->jk', tau_A, n_A, n_A)  # (58): sum_A tau_A n_A (x) n_A
    b = np.dot(F, F.T)
    J = np.linalg.det(F)
    sigma_EQ = _MU_EQ / J * (b - np.eye(3)) + 2 * _K_EQ * (J - 1) * np.eye(3)  # neo-Hookean
    sigma = tau_NEQ / Je + sigma_EQ  # (7)

    # Post-processing: rebuild be from the converged stretches and pull back.
    lamb_e = np.exp(eps_e)
    be = np.einsum('i,ji,ki->jk', lamb_e ** 2, n_A, n_A)
    F_inv = np.linalg.inv(F)
    C_i_inv_new = np.dot(F_inv, np.dot(be, F_inv.T))
    return sigma, C_i_inv_new, lamb_e