In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
from models import DeepOHeat_v1
import cupy as cp
from cupyx.scipy.sparse.linalg import gmres
from cupyx.scipy import sparse
from pyamg import smoothed_aggregation_solver
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class gmres_counter(object):
    def __init__(self, disp=True):
        self._disp = disp
        self.niter = 0
    def __call__(self, rk=None):
        self.niter += 1
        if self._disp:
            print('iter %3i\trk = %s' % (self.niter, str(rk)))
            

class Hybrid_solver:
    def __init__(self, nx=101, ny=101, nz=56, model_path=None, precondition=False):
        self.nx, self.ny, self.nz = nx, ny, nz
        self.dx = 1.0 / (nx-1)
        self.dy = 1.0 / (ny-1)
        self.dz = 0.55 / (nz-1)
        
        # Initialize OKAN model if path provided
        self.model = None
        if model_path is not None:
            self.model = self._init_model(model_path)
            
        # Create evaluation grids for surrogate
        self.x_eval = jnp.linspace(0, 1, nx).reshape(-1,1)
        self.y_eval = jnp.linspace(0, 1, ny).reshape(-1,1)
        self.z_eval = jnp.linspace(0, 0.55, nz).reshape(-1,1)
        
        # Build sparse matrix
        print("Building system matrix...")
        t0 = time.time()
        self.A = self._build_matrix()
        print(f"Matrix built in {time.time() - t0:.2f}s")
        
        # Setup AMG preconditioner
        print("Setting up AMG preconditioner...")
        t0 = time.time()
        A_cpu = self.A.get()
        
        self.ml = smoothed_aggregation_solver(
            A_cpu,
            max_levels=10,
            max_coarse=200,
            aggregate='standard',
            strength='symmetric',
            smooth=('jacobi', {'omega': 4/3.0}),
            presmoother=None,
            postsmoother=('gauss_seidel', {'sweep': 'symmetric'}),
            keep=False
        )
        
        print(f"AMG hierarchy:\n{self.ml}")
        self.M = sparse.linalg.LinearOperator(
            self.A.shape, 
            matvec=self._amg_preconditioner
        )
        if not precondition:
            self.M = None
        print(f"AMG setup completed in {time.time() - t0:.2f}s")
        
    def _get_k(self, z):
        """Get thermal conductivity at given z coordinate"""
      
        
        if abs(z - 0.1) < 1e-10:  # At z = 0.1 interface
            return 2 * 0.1 * 2 / (0.1 + 2)  # Harmonic average at interface
        elif z < 0.1:
            return 2
        else:
            return 0.1
        
    def _init_model(self, model_path):
        """Initialize and load the DeepOHeat model"""
        model = eqx.filter_jit(DeepOHeat_v1(
            dim=3, 
            branch_dim=101**2, 
            field_dim=1,
            branch_depth=8, 
            branch_hidden=256, 
            trunk_depth=3,
            trunk_hidden=64,
            rank=128, 
            key=jax.random.PRNGKey(42)
        ))
        return eqx.tree_deserialise_leaves(model_path, model)
    
    def _build_matrix(self):
        """Build sparse matrix for finite difference scheme"""
        n = self.nx * self.ny * self.nz
        row_indices = []
        col_indices = []
        values = []
        
        for k in range(self.nz):
            z = k * self.dz
            
            for j in range(self.ny):
                for i in range(self.nx):
                    row = k*self.nx*self.ny + j*self.nx + i
                    
                    if k == 0:  # Bottom boundary with -40 coefficient
                        row_indices.append(row)
                        col_indices.append(row)
                        values.append(1.0 + 40/self.dz)
                        
                        if k < self.nz-1:
                            row_indices.append(row)
                            col_indices.append(row + self.nx*self.ny)
                            values.append(-40/self.dz)
                            
                    elif k == self.nz-1:  # Top boundary with 2 coefficient
                        row_indices.append(row)
                        col_indices.append(row)
                        values.append(1.0 + 2/self.dz)
                        
                        if k > 0:
                            row_indices.append(row)
                            col_indices.append(row - self.nx*self.ny)
                            values.append(-2/self.dz)
                            
                    else:  # Interior points
                        k_curr = self._get_k(z)
                        
                        # Initialize diagonal value
                        diag_value = 0.0
                        
                        # x-direction with Neumann BC
                        if i == 0:  # Left wall
                            diag_value += -k_curr/self.dx**2
                            row_indices.append(row)
                            col_indices.append(row+1)
                            values.append(k_curr/self.dx**2)
                        elif i == self.nx-1:  # Right wall
                            diag_value += -k_curr/self.dx**2
                            row_indices.append(row)
                            col_indices.append(row-1)
                            values.append(k_curr/self.dx**2)
                        else:  # Interior x
                            diag_value += -2.0*k_curr/self.dx**2
                            row_indices.append(row)
                            col_indices.append(row-1)
                            values.append(k_curr/self.dx**2)
                            row_indices.append(row)
                            col_indices.append(row+1)
                            values.append(k_curr/self.dx**2)
                        
                        # y-direction with Neumann BC
                        if j == 0:  # Front wall
                            diag_value += -k_curr/self.dy**2
                            row_indices.append(row)
                            col_indices.append(row+self.nx)
                            values.append(k_curr/self.dy**2)
                        elif j == self.ny-1:  # Back wall
                            diag_value += -k_curr/self.dy**2
                            row_indices.append(row)
                            col_indices.append(row-self.nx)
                            values.append(k_curr/self.dy**2)
                        else:  # Interior y
                            diag_value += -2.0*k_curr/self.dy**2
                            row_indices.append(row)
                            col_indices.append(row-self.nx)
                            values.append(k_curr/self.dy**2)
                            row_indices.append(row)
                            col_indices.append(row+self.nx)
                            values.append(k_curr/self.dy**2)
                        
                        # z-direction
                        diag_value += -2.0*k_curr/self.dz**2
                        if k > 0:
                            row_indices.append(row)
                            col_indices.append(row-self.nx*self.ny)
                            values.append(k_curr/self.dz**2)
                        if k < self.nz-1:
                            row_indices.append(row)
                            col_indices.append(row+self.nx*self.ny)
                            values.append(k_curr/self.dz**2)
                        
                        # Add accumulated diagonal value
                        row_indices.append(row)
                        col_indices.append(row)
                        values.append(diag_value)
        
        # Convert lists to arrays
        values = cp.array(values, dtype=float)
        row_indices = cp.array(row_indices, dtype=int)
        col_indices = cp.array(col_indices, dtype=int)
        
        return sparse.csr_matrix((values, (row_indices, col_indices)), 
                               shape=(n, n))
    
    def build_rhs(self, q_v):
        """Build right-hand side vector with 2D power distribution"""
        b = cp.zeros(self.nx * self.ny * self.nz)
        if isinstance(q_v, (np.ndarray, jnp.ndarray)):
            q_v = cp.asarray(q_v)
        
        for k in range(self.nz):
            z_coord = k * self.dz
            start_idx = k * self.nx * self.ny
            
            if k == 0 or k == self.nz-1:
                # Boundary conditions
                b[start_idx:start_idx + self.nx * self.ny] = 0.2
            else:
                # Power distribution
                if abs(z_coord - 0.1) < 1e-10 or abs(z_coord - 0.15) < 1e-10:
                    b[start_idx:start_idx + self.nx * self.ny] = -1 * q_v.ravel()  # Half power at interfaces
                elif 0.1 < z_coord < 0.15:
                    b[start_idx:start_idx + self.nx * self.ny] = -2 * q_v.ravel()  # Full power in active region
                else:
                    b[start_idx:start_idx + self.nx * self.ny] = 0.0
        
        return b
    
    def _amg_preconditioner(self, b):
        """Simplified AMG preconditioner function"""
        b_cpu = cp.asnumpy(b)
        x = self.ml.solve(
            b_cpu, 
            x0=None,
            tol=1e-1, 
            maxiter=1,
        )
        return cp.array(x)
    
    def get_surrogate_prediction(self, model, x_eval, y_eval, z_eval, q_v):
        """JIT-compiled surrogate prediction using SepOKAN format"""
        return model(((x_eval, y_eval, z_eval), q_v.reshape(1,-1)))
    
    def compute_all_residuals(self, x, b, q_v, is_initial=False):
        """Compute both system and PDE residuals"""
        # System residual
        system_residual = cp.linalg.norm(b - self.A @ x.ravel())
        relative_system_residual = system_residual / cp.linalg.norm(b)
        
        # Reshape solution for PDE residuals
        T = x.reshape((self.nz, self.nx, self.ny)).transpose((1,2,0))
        
        # Interior PDE residual
        d2x = (T[2:-1,1:-2,1:-2] - 2*T[1:-2,1:-2,1:-2] + T[:-3,1:-2,1:-2])/(self.dx**2)
        d2y = (T[1:-2,2:-1,1:-2] - 2*T[1:-2,1:-2,1:-2] + T[1:-2,:-3,1:-2])/(self.dy**2)
        d2z = (T[1:-2,1:-2,2:-1] - 2*T[1:-2,1:-2,1:-2] + T[1:-2,1:-2,:-3])/(self.dz**2)
        
        # Get z coordinates for interior points
        z_coords = cp.arange(1, self.nz-2) * self.dz
        
        # Initialize k array for interior points
        k_interior = cp.zeros_like(z_coords)
        
        # Set k values including harmonic averages at interfaces
        for i, z in enumerate(z_coords):
                
            if abs(z - 0.1) < 1e-10:
                k_interior[i] = 2 * 0.1 * 2 / (0.1 + 2)
            elif z < 0.1:
                k_interior[i] = 2.0
            else:
                k_interior[i] = 0.1
        
        # Apply k to Laplacian
        laplacian = k_interior.reshape(1, 1, -1) * (d2x + d2y + d2z)
        
        # Power term
        power_mask = cp.zeros_like(z_coords)
        power_mask[(z_coords > 0.1) & (z_coords < 0.15)] = 2
        power_mask[cp.abs(z_coords - 0.1) < 1e-10] = 1  # Half power at interfaces
        power_mask[cp.abs(z_coords - 0.15) < 1e-10] = 1  # Half power at interfaces
        
        power_term = cp.einsum('ij,k->ijk', q_v[1:-2,1:-2], power_mask)
        
        interior_residual = cp.linalg.norm(laplacian + power_term)
        
        # Boundary residuals
        # Top BC: T - 0.2 + 2∂T/∂z = 0
        dz_top = (T[:,:,-1] - T[:,:,-2])/self.dz
        top_residual = cp.linalg.norm(T[:,:,-1] - 0.2 + 2*dz_top)
        
        # Bottom BC: T - 0.2 - 40∂T/∂z = 0
        dz_bottom = (T[:,:,1] - T[:,:,0])/self.dz
        bottom_residual = cp.linalg.norm(T[:,:,0] - 0.2 - 40*dz_bottom)
        
        # Side walls: ∂T/∂x = 0, ∂T/∂y = 0
        left_residual = cp.linalg.norm((T[1,:,:] - T[0,:,:])/self.dx)
        right_residual = cp.linalg.norm((T[-1,:,:] - T[-2,:,:])/self.dx)
        front_residual = cp.linalg.norm((T[:,1,:] - T[:,0,:])/self.dy)
        back_residual = cp.linalg.norm((T[:,-1,:] - T[:,-2,:])/self.dy)
        
        total_pde_residual = cp.sqrt(interior_residual**2 + 
                                    top_residual**2 + 
                                    bottom_residual**2 +
                                    left_residual**2 +
                                    right_residual**2 +
                                    front_residual**2 +
                                    back_residual**2)
        
        prefix = "Initial" if is_initial else "Final"
        print(f"\n{prefix} Residuals:")
        print(f"System residual: {float(system_residual)}")
        print(f"Relative system residual: {float(relative_system_residual)}")
        print(f"Interior PDE residual: {float(interior_residual)}")
        print(f"BC residuals - Top: {float(top_residual)}, Bottom: {float(bottom_residual)}")
        print(f"BC residuals - Left: {float(left_residual)}, Right: {float(right_residual)}")
        print(f"BC residuals - Front: {float(front_residual)}, Back: {float(back_residual)}")
        print(f"Total PDE residual: {float(total_pde_residual)}")
        
        return {
            'system_residuals': {
                'absolute': float(system_residual),
                'relative': float(relative_system_residual)
            },
            'pde_residuals': {
                'interior': float(interior_residual),
                'boundaries': {
                    'top': float(top_residual),
                    'bottom': float(bottom_residual),
                    'left': float(left_residual),
                    'right': float(right_residual),
                    'front': float(front_residual),
                    'back': float(back_residual)
                },
                'total': float(total_pde_residual)
            }
        }
    
    def solve(self, q_v, use_surrogate=True, tol=5e-2, maxiter=20000, restart=200):
        """Solve with optional surrogate initial guess"""
        if q_v.shape != (self.nx, self.ny):
            raise ValueError(f"Expected q_v shape ({self.nx}, {self.ny}), got {q_v.shape}")
        
        t0 = time.time()
        b = self.build_rhs(q_v)
        M = self.M
        
        # Get and verify initial guess
        x0 = None
        initial_residuals = None
        if use_surrogate and self.model is not None:
            t1 = time.time()
            x0 = self.get_surrogate_prediction(
                self.model,
                self.x_eval,
                self.y_eval,
                self.z_eval,
                q_v
            )
            print(f"Surrogate inference in {time.time() - t1:.5f}s")
            
            x0 = x0.reshape(self.nx, self.ny, self.nz).transpose((2,0,1))
            x0 = cp.asarray(x0.reshape(-1))
            initial_residuals = self.compute_all_residuals(x0, b, q_v, is_initial=True)
        
        counter = gmres_counter(disp=True)
        
        x, info = gmres(
            self.A, b, 
            x0=x0,
            M=M, 
            tol=tol,
            maxiter=maxiter,
            restart=restart,
            callback=counter,
        )
        
        # Verify final solution
        solve_time = time.time() - t0
        final_residuals = self.compute_all_residuals(x, b, q_v, is_initial=False)
        
        print(f"\nSolution time: {solve_time:.2f}s")
        print(f"GMRES iterations: {counter.niter}")
        
        if use_surrogate:
            return x0.reshape((self.nz, self.nx, self.ny)).transpose((1,2,0)), x.reshape((self.nz, self.nx, self.ny)).transpose((1,2,0)), {
                'initial_residuals': initial_residuals,
                'final_residuals': final_residuals,
                'iterations': counter.niter,
                'time': solve_time
            }
        else:
            return x.reshape((self.nz, self.nx, self.ny)).transpose((1,2,0)), {
                'final_residuals': final_residuals,
                'iterations': counter.niter,
                'time': solve_time
            }
            
def convert_interval_to_grid(power_map):
    interval_shape = power_map.shape
    grid_shape = tuple([int(i + 1) for i in interval_shape])

    grid = np.zeros(grid_shape)

    for j in [0, -1]:
        grid[:-1, j] += power_map[:, j]
        grid[1:, j] += power_map[:, j]
        grid[1:-1, j] /= 2

    grid[:-1, 1:-1] += power_map[:, :-1] + power_map[:, 1:]
    grid[1:, 1:-1] += power_map[:, :-1] + power_map[:, 1:]
    grid[1:-1, 1:-1] /= 4
    grid[0, 1:-1] /= 2
    grid[-1, 1:-1] /= 2

    return grid

In [3]:
model_path = "results/results_volume/DeepOHeat_v1/nf50_nc101_branch_8_256_trunk_3_64_r128/DeepOHeat_v1_trained_model.eqx"
solver_with_surrogate = Hybrid_solver(model_path=model_path)
solver_no_surrogate_pc = Hybrid_solver(precondition=True)
solver_no_surrogate = Hybrid_solver(precondition=False)

iyz,jyz,kyz,byz->bijky
Building system matrix...
Matrix built in 1.81s
Setting up AMG preconditioner...
AMG hierarchy:
MultilevelSolver
Number of Levels:     4
Operator Complexity:   1.657
Grid Complexity:       1.145
Coarse Solver:        'pinv'
  level   unknowns     nonzeros
     0      571256      3874966 [60.36%]
     1       81077      2437525 [37.97%]
     2        2003       106837 [1.66%]
     3          31          681 [0.01%]

AMG setup completed in 2.09s
Building system matrix...
Matrix built in 1.66s
Setting up AMG preconditioner...
AMG hierarchy:
MultilevelSolver
Number of Levels:     4
Operator Complexity:   1.657
Grid Complexity:       1.145
Coarse Solver:        'pinv'
  level   unknowns     nonzeros
     0      571256      3874966 [60.36%]
     1       81077      2437525 [37.97%]
     2        2003       106837 [1.66%]
     3          31          681 [0.01%]

AMG setup completed in 2.06s
Building system matrix...
Matrix built in 1.72s
Setting up AMG preconditioner...


In [4]:
fs_test_volume = np.load('data/fs_test_volume.npy')
fs_test_random = fs_test_volume[27]
u_test_volume = np.load('data/u_test_volume.npy')
u_test_random = u_test_volume[27]

In [5]:
# gmres w/o precondition
_, info = solver_no_surrogate.solve(fs_test_random, use_surrogate=False)

iter   1	rk = 0.5105381405932757
iter   2	rk = 0.4851694538159846
iter   3	rk = 0.478037470662677
iter   4	rk = 0.4741485603633125
iter   5	rk = 0.47105191103203187
iter   6	rk = 0.4683764275412546
iter   7	rk = 0.46528557820465943
iter   8	rk = 0.462680916431091
iter   9	rk = 0.45970638494472427
iter  10	rk = 0.4571379091899718
iter  11	rk = 0.4534113368438143
iter  12	rk = 0.44266379143190765
iter  13	rk = 0.4397595971000476
iter  14	rk = 0.436400954169568
iter  15	rk = 0.42618630685903813
iter  16	rk = 0.4215172700046839
iter  17	rk = 0.41911184615529345
iter  18	rk = 0.417305661791354
iter  19	rk = 0.4114098717840103
iter  20	rk = 0.40625328558425716
iter  21	rk = 0.40204965494559913
iter  22	rk = 0.3985474629611959
iter  23	rk = 0.39703143582637956
iter  24	rk = 0.3959478585477376
iter  25	rk = 0.3945181148167034
iter  26	rk = 0.39225774438922106
iter  27	rk = 0.390120028923519
iter  28	rk = 0.38161602679724116
iter  29	rk = 0.3786937279653176
iter  30	rk = 0.3775406270826822
iter

In [9]:
# gmres w/ operator model initial guess
x0, model_output, info = solver_with_surrogate.solve(fs_test_random, use_surrogate=True)

Surrogate inference in 0.00153s

Initial Residuals:
System residual: 4228.712937143562
Relative system residual: 9.964201705170948
Interior PDE residual: 4133.398973619138
BC residuals - Top: 5.009191513061523, Bottom: 72.57321166992188
BC residuals - Left: 1.5793908834457397, Right: 1.1930056810379028
BC residuals - Front: 0.6000903248786926, Back: 1.3689662218093872
Total PDE residual: 4134.039814772427
iter   1	rk = 0.13796102042465327
iter   2	rk = 0.0723577179899582
iter   3	rk = 0.06240495522051169
iter   4	rk = 0.05656476590091868
iter   5	rk = 0.04923293140244488

Final Residuals:
System residual: 20.893990318058144
Relative system residual: 0.04923293140244489
Interior PDE residual: 17.935409184553674
BC residuals - Top: 5.707863058170653, Bottom: 4.193470445006232
BC residuals - Left: 2.506287585611033, Right: 0.9137134589626688
BC residuals - Front: 0.9637976834715744, Back: 1.3064034920810639
Total PDE residual: 19.534475766884725

Solution time: 1.39s
GMRES iterations: 5


In [7]:
# gmres w/ AMG preconditioner
_, info = solver_no_surrogate_pc.solve(fs_test_random, use_surrogate=False)

iter   1	rk = 3.3376473911146054e-11

Final Residuals:
System residual: 1.4164659769086488e-08
Relative system residual: 3.3376473911146054e-11
Interior PDE residual: 1.3807855028748004e-08
BC residuals - Top: 1.1263358951647526e-11, Bottom: 8.204601329837885e-11
BC residuals - Left: 2.4184578124853147, Right: 0.8448976573723951
BC residuals - Front: 0.9997853493749437, Back: 1.2875432268974027
Total PDE residual: 3.036466457604755

Solution time: 17.75s
GMRES iterations: 1
