In [None]:
import numpy as np
from scipy import ndimage
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import numpy.typing as npt
from typing import List, Tuple, Any
from dataclasses import dataclass

@dataclass
class GridChunk:
    """A rectangular sub-box of the 3D grid, packaged as one unit of parallel work.

    Attributes:
        start_idx: lower (i, j, k) corner of the box in global grid indices;
            split_grid_into_chunks uses it as the inclusive slice start.
        end_idx: upper (i, j, k) corner in global grid indices; used as the
            exclusive slice stop.
        S_chunk: one sub-array per nutrient field, covering the box.
        cells_in_chunk: the cell objects assigned to this box.
    """
    # Inclusive lower corner of the box (global indices).
    start_idx: Tuple[int, int, int]
    # Exclusive upper corner of the box (global indices).
    end_idx: Tuple[int, int, int]
    # Field sub-arrays, in the same order as the full field list.
    S_chunk: List[npt.NDArray]
    # Cells whose positions fall inside the box.
    cells_in_chunk: List[Any]

def split_grid_into_chunks(S_list: List[npt.NDArray], cells: List[Any],
                           n_chunks: int) -> List[GridChunk]:
    """
    Split the 3D nutrient grid and the live cells into box-shaped chunks.

    The grid is cut into ``c`` slabs per axis, where ``c`` is the largest
    integer with ``c**3 <= n_chunks`` (at least 1). Slab boundaries are spread
    over the full axis length, so every grid index belongs to exactly one box
    even when an axis length is not divisible by ``c``.

    Args:
        S_list: nutrient fields; all are assumed to share the 3D shape of
            ``S_list[0]``.
        cells: cell objects exposing ``is_alive``, ``x``, ``y`` and ``z``.
        n_chunks: requested number of chunks (rounded down to a perfect cube,
            minimum 1).

    Returns:
        GridChunks in (i, j, k) order. ``S_chunk`` holds a copy of each field's
        sub-block ``[start_idx, end_idx)``; ``cells_in_chunk`` holds the live
        cells whose (x, y, z) lie in that half-open box, in their original
        order. Zero-width boxes (an axis shorter than ``c``) are omitted.
    """
    shape_S = S_list[0].shape

    # Largest c with c**3 <= n_chunks, computed with integers so we never rely
    # on a float cube root being exact for perfect cubes; always at least one.
    chunks_per_dim = 1
    while (chunks_per_dim + 1) ** 3 <= n_chunks:
        chunks_per_dim += 1

    # Per-axis boundaries spanning the whole axis; box sizes differ by at most
    # one, and (unlike a fixed s // c stride) no trailing indices are dropped.
    bounds = [
        [(size * b) // chunks_per_dim for b in range(chunks_per_dim + 1)]
        for size in shape_S
    ]

    chunks = []
    for i in range(chunks_per_dim):
        for j in range(chunks_per_dim):
            for k in range(chunks_per_dim):
                start_idx = (bounds[0][i], bounds[1][j], bounds[2][k])
                end_idx = (bounds[0][i + 1], bounds[1][j + 1], bounds[2][k + 1])

                # Axes shorter than chunks_per_dim produce empty boxes; nothing to do there.
                if any(lo == hi for lo, hi in zip(start_idx, end_idx)):
                    continue

                # Live cells whose global position falls inside this half-open box
                cells_in_chunk = [
                    cell for cell in cells if cell.is_alive and
                    start_idx[0] <= cell.x < end_idx[0] and
                    start_idx[1] <= cell.y < end_idx[1] and
                    start_idx[2] <= cell.z < end_idx[2]
                ]

                # Copies, so the chunk can be pickled/modified independently of the full grid
                region = tuple(slice(lo, hi) for lo, hi in zip(start_idx, end_idx))
                S_chunk = [S[region].copy() for S in S_list]

                chunks.append(GridChunk(start_idx, end_idx, S_chunk, cells_in_chunk))

    return chunks

def process_chunk(chunk: GridChunk, D_list: List[float], diff: int, dt: float,
                 uptake_rates: npt.NDArray, release_rates: npt.NDArray) -> Tuple[List[npt.NDArray], List[Any]]:
    """
    Advance one grid chunk by a single nutrient step.

    This is the function submitted to the ProcessPoolExecutor, so ``chunk``
    and its cells arrive in the worker as pickled copies. That is why the
    chunk's cells are handed back alongside the new fields: the parent process
    copies their state onto its own cell objects.

    Args:
        chunk: the grid sub-box with its field slices and cells.
        D_list, diff, dt, uptake_rates, release_rates: forwarded unchanged to
            ``update_nutrient_fields``.

    Returns:
        ``(fields returned by update_nutrient_fields, chunk.cells_in_chunk)``.

    NOTE(review): ``chunk.S_chunk`` is a local sub-array, while the cells keep
    the global (x, y, z) positions used by split_grid_into_chunks. Unless
    ``update_nutrient_fields`` subtracts ``chunk.start_idx``, indexing the
    sub-array by cell position will be off by the chunk offset — confirm.
    """
    fields = chunk.S_chunk
    local_cells = chunk.cells_in_chunk

    S_new_chunk = update_nutrient_fields(
        fields, D_list, local_cells, diff, dt, uptake_rates, release_rates
    )
    return S_new_chunk, local_cells

def _whole_steps(span: float, step: float) -> int:
    """Return how many whole ``step`` intervals fit in ``span``, tolerating float round-off."""
    # Plain int() truncation turns quotients like 0.7 / 0.1 == 6.999999999999999
    # into one step too few; the tiny epsilon absorbs representation error
    # without rounding up genuinely fractional ratios.
    return int(np.floor(span / step + 1e-9))


def _split_evenly(items: List[Any], n_parts: int) -> List[List[Any]]:
    """Split ``items`` into ``n_parts`` contiguous slices whose lengths differ by at most one."""
    # Same size distribution as np.array_split (the first len % n_parts slices
    # get one extra element), but cell objects are never coerced into a NumPy
    # object array.
    base, extra = divmod(len(items), n_parts)
    parts = []
    start = 0
    for part_idx in range(n_parts):
        stop = start + base + (1 if part_idx < extra else 0)
        parts.append(items[start:stop])
        start = stop
    return parts


def _index_cells_by_position(cells: List[Any]) -> dict:
    """Map each ``(x, y, z)`` position to the cells located there, preserving list order."""
    index = {}
    for cell in cells:
        index.setdefault((cell.x, cell.y, cell.z), []).append(cell)
    return index


def parallel_simulate_system(S_list: List[npt.NDArray], D_list: List[float],
                             cells: List[Any], diff: int, dt_diff: float,
                             dt_cell: float, uptake_rates: npt.NDArray,
                             release_rates: npt.NDArray, total_time: float):
    """
    Run the coupled nutrient / cell simulation with chunked parallelism.

    Each outer step first performs ``dt_cell / dt_diff`` nutrient sub-steps.
    Every sub-step splits the grid into chunks, updates them in a process pool
    and writes the results back. After the sub-steps, the cell population is
    updated in a thread pool.

    Args:
        S_list: nutrient fields; **modified in place** and also returned.
        D_list, diff, uptake_rates, release_rates: forwarded to ``process_chunk``.
        cells: initial cell objects; they need ``x``, ``y``, ``z``,
            ``is_alive`` and ``internal_nutrients`` attributes.
        dt_diff: nutrient time step.
        dt_cell: cell-update time step; ideally a whole multiple of ``dt_diff``.
        total_time: simulated duration. A fractional remainder of
            ``total_time / dt_cell`` is not simulated.

    Returns:
        ``(S_list, cells)``: the same field arrays and the final cell list
        (survivors followed by newly created cells).

    Notes:
        * ProcessPoolExecutor workers must be able to import ``process_chunk``
          and the functions it calls. With the "spawn" start method (default on
          Windows and macOS), functions defined in a notebook are not
          importable, so move them into a ``.py`` module in that case.
        * NOTE(review): each chunk is updated on a copy of exactly its own
          sub-block, with no halo/ghost layer. Unless ``update_nutrient_fields``
          compensates, nothing diffuses across chunk boundaries — TODO confirm.
        * Worker results are matched back to cells by ``(x, y, z)``. Cells
          sharing a position all receive the last matching result.
    """
    steps_per_cell_update = _whole_steps(dt_cell, dt_diff)
    n_cell_updates = _whole_steps(total_time, dt_cell)

    n_cores = mp.cpu_count()
    # Requested chunk count (2 per core for load balancing). split_grid_into_chunks
    # rounds this down to a perfect cube, so fewer chunks may actually be made.
    n_chunks = n_cores * 2

    with ProcessPoolExecutor(max_workers=n_cores) as process_executor:
        with ThreadPoolExecutor(max_workers=n_cores) as thread_executor:

            for _cell_step in range(n_cell_updates):
                # During the nutrient sub-steps this process only writes
                # internal_nutrients, so cell positions are stable and can be
                # indexed once (replaces an O(updated x cells) scan per sub-step).
                cells_by_position = _index_cells_by_position(cells)

                # Nutrient sub-steps, parallel over grid chunks
                for _ in range(steps_per_cell_update):
                    chunks = split_grid_into_chunks(S_list, cells, n_chunks)

                    future_results = [
                        process_executor.submit(
                            process_chunk, chunk, D_list, diff, dt_diff,
                            uptake_rates, release_rates
                        )
                        for chunk in chunks
                    ]
                    chunk_results = [future.result() for future in future_results]

                    # Merge chunk results back into the full grid
                    for chunk, (S_new_chunk, updated_cells) in zip(chunks, chunk_results):
                        region = (
                            slice(chunk.start_idx[0], chunk.end_idx[0]),
                            slice(chunk.start_idx[1], chunk.end_idx[1]),
                            slice(chunk.start_idx[2], chunk.end_idx[2]),
                        )
                        for i, S in enumerate(S_list):
                            S[region] = S_new_chunk[i]

                        # Workers operated on pickled copies; copy their state back
                        for updated_cell in updated_cells:
                            position = (updated_cell.x, updated_cell.y, updated_cell.z)
                            for cell in cells_by_position.get(position, ()):
                                cell.internal_nutrients = updated_cell.internal_nutrients

                # Cell population update, parallel over contiguous slices of cells
                cell_chunks = _split_evenly(cells, n_cores)
                future_results = [
                    thread_executor.submit(update_cell_population, cell_chunk, dt_cell)
                    for cell_chunk in cell_chunks
                ]

                new_cells = []
                dead_indices = set()  # set, so the filter below is O(n) overall
                offset = 0
                for cell_chunk, future in zip(cell_chunks, future_results):
                    chunk_new_cells, chunk_dead_indices = future.result()
                    new_cells.extend(chunk_new_cells)
                    # Indices are relative to the slice; shift them to the full list
                    dead_indices.update(idx + offset for idx in chunk_dead_indices)
                    offset += len(cell_chunk)

                cells = [cell for idx, cell in enumerate(cells)
                         if idx not in dead_indices and cell.is_alive]
                cells.extend(new_cells)

    return S_list, cells

# 7-point stencil: -6 at the centre, +1 on each of the six face neighbours.
# Built once and marked read-only instead of being rebuilt on every call.
_LAPLACIAN_KERNEL_3D = np.array([
    [[0, 0, 0],
     [0, 1, 0],
     [0, 0, 0]],
    [[0, 1, 0],
     [1, -6, 1],
     [0, 1, 0]],
    [[0, 0, 0],
     [0, 1, 0],
     [0, 0, 0]]
], dtype=float)
_LAPLACIAN_KERNEL_3D.setflags(write=False)


def compute_laplacian_3d(S, mode="reflect", cval=0.0):
    """Compute the 7-point discrete Laplacian of a 3D array by convolution.

    Each output voxel is the sum of its six face neighbours minus six times its
    own value. The grid spacing is taken as 1; divide by ``h**2`` for spacing ``h``.

    Args:
        S: 3D array.
        mode: boundary handling passed to ``scipy.ndimage.convolve``. The
            default ``"reflect"`` (the original behaviour) mirrors the array at
            its edges, so the result sums to zero over the grid up to round-off
            (zero-flux boundary). ``"wrap"`` gives periodic boundaries and
            ``"constant"`` pads with ``cval``.
        cval: fill value used outside the array when ``mode="constant"``.

    Returns:
        Array with the same shape as ``S``. The dtype follows
        ``ndimage.convolve``'s default, i.e. that of ``S``.
    """
    # The stencil is symmetric, so convolution and correlation coincide.
    return ndimage.convolve(S, _LAPLACIAN_KERNEL_3D, mode=mode, cval=cval)

# Placeholder: the real implementation is not part of this notebook. Because
# this cell (re)defines the name, it replaces any earlier definition in the
# kernel. It raises instead of silently returning None, so a missing
# implementation fails at the call site rather than later on a None result.
def process_cell_nutrient_interactions(cells, S_list, shape_S, diff, dt, uptake_rates, release_rates):
    """Process nutrient uptake/release for all cells and update their states.

    Placeholder only; paste the original implementation here.

    Raises:
        NotImplementedError: always, until the implementation is supplied.
    """
    raise NotImplementedError(
        "process_cell_nutrient_interactions is a placeholder; insert the original implementation"
    )

def update_cell_population(cells, dt_cell):
    """Update cell population through death and reproduction.

    Expected contract (from its use in parallel_simulate_system): receives a
    contiguous slice of the cell list and the cell time step, and returns
    ``(new_cells, dead_indices)``, where ``dead_indices`` are positions within
    that slice.

    Placeholder only: the implementation is not included in this notebook, and
    redefining the name here replaces any earlier definition in the kernel.

    Raises:
        NotImplementedError: always, until the implementation is supplied.
    """
    # Raising (rather than returning None) makes the missing implementation
    # obvious instead of surfacing as "cannot unpack non-iterable NoneType".
    raise NotImplementedError(
        "update_cell_population is a placeholder; insert the original implementation"
    )

def update_nutrient_fields(S_list, D_list, cells, diff, dt, uptake_rates, release_rates):
    """Update nutrient concentration fields for one diffusion timestep.

    Expected contract (from its use in process_chunk / parallel_simulate_system):
    returns a sequence with one array per entry of ``S_list``. Each array must
    be assignable back into the grid region that the corresponding input
    covered.

    Placeholder only: the implementation is not included in this notebook, and
    redefining the name here replaces any earlier definition in the kernel.

    Raises:
        NotImplementedError: always, until the implementation is supplied.
    """
    # Raising (rather than returning None) makes the missing implementation
    # obvious instead of surfacing later as "'NoneType' object is not subscriptable".
    raise NotImplementedError(
        "update_nutrient_fields is a placeholder; insert the original implementation"
    )

In [None]:
# Setup
# Time steps. dt_cell should be a whole multiple of dt_diff. Any fractional
# remainder of total_time / dt_cell is not simulated (1000 / 36 -> 27 cell
# updates, i.e. 972 time units).
dt_diff = 0.1
dt_cell = 36.0
total_time = 1000.0

# The grid size (nx, ny, nz), `diff` and `initial_cells` must come from an
# earlier cell. Fail fast on a fresh kernel, listing every missing name at once.
_missing = [name for name in ("nx", "ny", "nz", "diff", "initial_cells")
            if name not in globals()]
if _missing:
    raise NameError(
        "Define " + ", ".join(_missing) + " before running the simulation cell"
    )
del _missing

# Initial conditions
S1 = np.zeros((nx, ny, nz))
S2 = np.zeros((nx, ny, nz))
D1, D2 = 1.0, 0.5
uptake_rates = np.array([0.1, 0.2])
release_rates = np.array([0.05, 0.1])

# Run parallel simulation. The fields are updated in place, so S_list_final
# holds the same array objects as S1 and S2.
S_list_final, final_cells = parallel_simulate_system(
    [S1, S2], [D1, D2], initial_cells, diff,
    dt_diff, dt_cell, uptake_rates, release_rates,
    total_time
)