In [None]:
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from numba import njit, cuda

########################################
# Numba Implementations for Interface Update
########################################

@njit
def fix_interface_numba_cpu(grid, seg_indices, alphas, k_interface, c_dt_dx2):
    """
    CPU version of the interface update using Numba.
    Args:
        grid (np.ndarray): 1D array of grid values.
        seg_indices (np.ndarray): Interface indices (last index of left segment).
        alphas (np.ndarray): Conductivities. Length = n_interfaces+1.
        k_interface (np.ndarray): Computed interface conductivities.
        c_dt_dx2 (float): Precomputed factor.
    """
    for i in range(seg_indices.shape[0]):
        idx = seg_indices[i]
        i0 = grid[idx - 1]
        i1 = grid[idx]
        i2 = grid[idx + 1]
        i3 = grid[idx + 2]
        change1 = c_dt_dx2 * (k_interface[i] * (i2 - i1) + alphas[i] * (i0 - i1))
        change2 = c_dt_dx2 * (alphas[i+1] * (i3 - i2) + alphas[i] * (i1 - i2))
        grid[idx] = i1 + change1
        grid[idx+1] = i2 + change2
    return grid

@cuda.jit
def fix_interface_numba_gpu(grid, seg_indices, alphas, k_interface, c_dt_dx2):
    """
    GPU version of the interface update using Numba CUDA.
    This kernel is launched with 1D grid.
    """
    i = cuda.grid(1)
    if i < seg_indices.shape[0]:
        idx = seg_indices[i]
        # Direct indexing on device arrays
        i0 = grid[idx - 1]
        i1 = grid[idx]
        i2 = grid[idx + 1]
        i3 = grid[idx + 2]
        change1 = c_dt_dx2 * (k_interface[i] * (i2 - i1) + alphas[i] * (i0 - i1))
        change2 = c_dt_dx2 * (alphas[i+1] * (i3 - i2) + alphas[i] * (i1 - i2))
        grid[idx] = i1 + change1
        grid[idx+1] = i2 + change2

########################################
# Original Classes with Merged Numba Update
########################################

class BC_1D:
    def __init__(self, left, right):
        """
        Args:
            left, right: (alpha, beta, f(t))
        """
        self.left_alpha, self.left_beta, self.left_func = left
        self.right_alpha, self.right_beta, self.right_func = right
        
    def apply(self, grid, dx, cur_time):
        gamma_left = self.left_beta / dx
        gamma_right = self.right_beta / dx
        # Left boundary
        grid[0] = (self.left_func(cur_time) - gamma_left * grid[1]) / (self.left_alpha - gamma_left)
        # Right boundary
        grid[-1] = (self.right_func(cur_time) + gamma_right * grid[-2]) / (self.right_alpha + gamma_right)

class SegConduct:
    def __init__(self, alphas, segs=[]):
        """
        Args:
            alphas (list[float]): a list of conductivities
            segs (list[float]): boundary of different conductivities
        """
        self.alphas = alphas
        self.segs = segs
        self.conduct_map = None
    
    def sanity_check(self, simu):
        # Check the validility of input conditions
        
        # Check if each segement can be assigned with a value or not
        assert len(self.alphas) == len(self.segs) + 1, 'Number of conductivities does not match with number of segment bars.'
        
        # Check if seg bars are in legal order or not
        jud = True
        if len(self.segs) > 1:
            for i in range(len(self.segs)-1):
                if self.segs[i] >= self.segs[i+1]:
                    jud = False
                    break
        assert jud, 'Value inside segment bars is not strictly monotonically increasing.'
        
        # Check if the seg bar range is still within the simulation grid or not
        assert self.segs[-1] < simu.L - simu.dx and self.segs[0] > 0, 'Position of segment bars is out of simulation bound or too close to the boundary.'
        
        simu.if_seg = True
    
    def make_seg_index_and_calc_k_interface(self, simu):
        # Location of the boundary index is the last index of the left segment
        self.seg_index = []
        self.k_interface = []
        append1 = self.seg_index.append
        append2 = self.k_interface.append
        for i in range(len(self.segs)):
            append1(math.floor(self.segs[i] / simu.dx))
            append2( 2 * self.alphas[i] * self.alphas[i+1] / (self.alphas[i] + self.alphas[i+1]))
  
    def make_conduct_map(self, simu):
        self.conduct_map = torch.zeros(simu.xstep)
        # Assigning values
        start = 0
        for i in range(len(self.seg_index)):
            end = self.seg_index[i] + 1
            self.conduct_map[start:end] = self.alphas[i]
            start = end
        self.conduct_map[start:-1] = self.alphas[-1]
        # print(simu.device)
        self.conduct_map = self.conduct_map.to(simu.device)
        # print(self.conduct_map.device)
class Heat1dSimu:
    def __init__(self, L, xstep, total_time, tstep, bc, ic, c, plot_step, device='cpu'):
        """
        Args:
            L (float): Length of the 1D domain.
            xstep (int): Number of interior points.
            total_time (float): End time for simulation.
            tstep (int): Number of time steps.
            bc (BC_1D): Boundary condition object.
            ic (callable): Function for initial condition.
            c (float or SegConduct): Diffusion coefficient (uniform or piecewise).
            plot_step (int): Steps between plots.
            device (str): 'cpu' or 'cuda'.
        """
        self.L = L
        self.xstep = xstep
        self.total_time = total_time
        self.tstep = tstep
        self.bc = bc
        self.ic = ic
        self.c = c
        self.device = device
        self.cur_time = 0.0
        self.plot_step = plot_step
        self.if_seg = False

        # Discretization
        self.dx = L / (xstep + 0)  # note: ensure correct division for your case
        self.dt = total_time / tstep
        if isinstance(self.c, SegConduct):
            self.c.sanity_check(self)
            self.c.make_seg_index_and_calc_k_interface(self)
            self.c.make_conduct_map(self)
            self.c.conduct_map = self.c.conduct_map.to(self.device)
            mul = 1
        else:
            mul = c
            print('Uniform conductivity')
        self.c_dt_dx2 = mul * self.dt / (self.dx ** 2)

        # Solution array (including boundaries)
        self.grid = torch.zeros(xstep + 2, device=self.device)

        # Set initial condition
        self.set_ic()

        # Define convolution kernel (for the interior update)
        self.conv = nn.Conv1d(1, 1, kernel_size=3, bias=False, device=self.device)
        with torch.no_grad():
            self.conv.weight[:] = self.c_dt_dx2 * torch.tensor([[[1, -2, 1]]], dtype=torch.float, device=self.device)

    def conduct_seg_sanity_check(self):
        if self.if_seg:
            self.c.sanity_check(self)
        elif not isinstance(self.c, int):
            raise TypeError('Expected int or SegConduct object.')

    def set_bc(self):
        """Apply boundary conditions."""
        self.bc.apply(self.grid, self.dx, self.cur_time)

    def set_ic(self):
        """Initialize interior points using the initial condition function."""
        x_interior = torch.linspace(self.dx, self.dx * self.xstep, self.xstep, device=self.device)
        self.grid[1:-1] = self.ic(x_interior)

    def update(self):
        """Perform one time step update."""
        self.set_bc()
        
        interface_list = self.record_interface()
        
        with torch.no_grad():
            input_for_conv = self.grid.view(1, 1, -1)
            second_diff = self.conv(input_for_conv).view(-1)
            if self.if_seg:
                second_diff *= self.c.conduct_map
            self.grid[1:-1] += second_diff
        
        if self.if_seg:
            # Update the interfaces using Numba
            self.update_interface_numba()
        
        self.cur_time += self.dt

    def update_interface_numba(self):
        """
        Use Numba to update the grid at the interface points.
        Dispatch to CPU or GPU version based on self.device.
        """
        # Prepare data as numpy arrays
        seg_indices_np = np.array(self.c.seg_index, dtype=np.int64)
        alphas_np = np.array(self.c.alphas, dtype=np.float32)
        k_interface_np = np.array(self.c.k_interface, dtype=np.float32)
        
        if self.device == 'cpu':
            # For CPU, move grid to numpy, update, then convert back.
            grid_np = self.grid.cpu().numpy().astype(np.float32)
            grid_np = fix_interface_numba_cpu(grid_np, seg_indices_np, alphas_np, k_interface_np, self.c_dt_dx2)
            self.grid = torch.from_numpy(grid_np).to(self.device)
        else:
            # For GPU, wrap the grid via __cuda_array_interface__
            grid_dev = cuda.as_cuda_array(self.grid)
            seg_indices_dev = cuda.to_device(seg_indices_np)
            alphas_dev = cuda.to_device(alphas_np)
            k_interface_dev = cuda.to_device(k_interface_np)
            threads_per_block = 32
            blocks_per_grid = (len(seg_indices_np) + (threads_per_block - 1)) // threads_per_block
            fix_interface_numba_gpu[blocks_per_grid, threads_per_block](
                grid_dev, seg_indices_dev, alphas_dev, k_interface_dev, self.c_dt_dx2
            )
            # grid_dev is a view of self.grid, so self.grid is updated in place.

    def record_interface(self):
        """Optional: record interface neighborhoods as a list of tensors."""
        return [self.grid[seg-1:seg+3] for seg in self.c.seg_index]
            
    def start(self, do_plot=True):
        """Run the simulation and optionally plot the solution."""
        if do_plot:
            fig, ax = plt.subplots(figsize=(8, 5))
            cmap = plt.cm.plasma
            norm = plt.Normalize(vmin=0, vmax=self.total_time)
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])

        for step in tqdm(range(self.tstep),disable=True):
            self.update()
            if do_plot and step % self.plot_step == 0:
                current_time = step * self.dt
                color = cmap(norm(current_time))
                ax.plot(self.grid.cpu().numpy(), color=color, label=f't={current_time:.2f}')

        if do_plot:
            cbar = fig.colorbar(sm, ax=ax, label='Time')
            ax.set_xlabel("Grid Index")
            ax.set_ylabel("Temperature")
            ax.set_title("1D Heat Equation Evolution")
            ax.grid(True)
            plt.show()


In [20]:
def ic(x):
    return torch.sin(2*x)

def left(t):
    return 0

def right(t):
    return 0

bc = BC_1D((1,0,left), (1,0,right)) 
L = math.pi
xstep = 100
total_time = 0.5
tstep = 3200
# c = 1
num = 6
c = SegConduct([0.1*i for i in range(1,num+1)], [math.pi/num*i for i in range(1,num)])
# factor = c * total_time / tstep /(L / (xstep + 1))**2
plot_step = 80
# print(factor)

In [21]:
def target(L, xstep, total_time, tstep, bc, ic, c, plot_step):
    test = Heat1dSimu(L, xstep, total_time, tstep, bc, ic, c, plot_step, 'cuda')
    test.start(do_plot=False)




In [22]:
%timeit target(L, xstep, total_time, tstep, bc, ic, c, plot_step)

2.68 s ± 31.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
