In [None]:
import numpy as np
from scipy import ndimage
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import numpy.typing as npt
from typing import List, Tuple, Any
from dataclasses import dataclass

In [None]:
def convolution_laplacian_3D(M):
    """Discrete 7-point Laplacian of a 3D field (unit grid spacing).

    Boundary conditions are imposed by padding the field with one ghost
    layer per side before applying the stencil:

    * axes 0 and 1 (x, y): periodic -- each ghost layer is a copy of the
      layer on the opposite side of the grid;
    * axis 2 (z): no flux -- the ghost layer outside the boundary equals the
      first layer inside it, so the central difference across the boundary
      layer is zero.

    Parameters:
    M: 3D array-like holding the field values

    Returns:
    Float array with the same shape as M containing the Laplacian.
    """
    shp = np.shape(M)

    # Padded working copy: one ghost layer on every face.
    M_padded = np.zeros((shp[0] + 2, shp[1] + 2, shp[2] + 2))
    M_padded[1:-1, 1:-1, 1:-1] = M

    # Z first (no flux): ghost layer mirrors the layer one step inside.
    M_padded[:, :, 0] = M_padded[:, :, 2]
    M_padded[:, :, -1] = M_padded[:, :, -3]

    # X direction (periodic): ghost layer copies the opposite interior face.
    M_padded[0, :, :] = M_padded[-2, :, :]
    M_padded[-1, :, :] = M_padded[1, :, :]

    # Y direction (periodic).
    M_padded[:, 0, :] = M_padded[:, -2, :]
    M_padded[:, -1, :] = M_padded[:, 1, :]

    # Alternative 27 point stencil:
    # k = 1/26*[[[2,3,2],[3,6,3],[2,3,2]],[[3,6,3],[6,-88,6],[3,6,3]],[[2,3,2],[3,6,3],[2,3,2]]]

    # 7 point stencil
    k = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
                  [[0, 1, 0], [1, -6, 1], [0, 1, 0]],
                  [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], dtype=float)

    # ndimage.convolve returns an array shaped like its input, so a single
    # layer is stripped from each face to drop the ghost cells. The values
    # kept are interior points whose neighbours all lie inside M_padded, so
    # the convolution's own edge handling never influences them.
    L = ndimage.convolve(M_padded, k, mode='constant', cval=0.0)
    M_upd = L[1:-1, 1:-1, 1:-1]

    return M_upd


In [None]:
def update_nutrient_fields(S_list, D_list, cells, diff, dt, uptake_rates, release_rates,
                           km_list=None):
    """Update nutrient concentration fields for one diffusion timestep.

    Each field is advanced with an explicit step: diffusion (D * Laplacian,
    scaled by dt) plus the amounts released and minus the amounts taken up
    by the cells during this step. Results are clipped at zero.

    Parameters:
    S_list: list of 3D nutrient concentration arrays
    D_list: list of diffusion coefficients, one per nutrient
    cells: list of Cell objects
    diff: grid refinement factor (forwarded to the cell interaction step)
    dt: timestep
    uptake_rates: maximum uptake rate per nutrient (forwarded as vm_list)
    release_rates: release rate per nutrient
    km_list: half-saturation constant per nutrient. The cell interaction
        step needs it for its Michaelis-Menten kinetics, so it must be given.

    Returns:
    List of new concentration arrays, one per nutrient.

    Raises:
    ValueError: if km_list is not supplied.
    """
    if km_list is None:
        raise ValueError(
            "update_nutrient_fields requires km_list (half-saturation constants): "
            "process_cell_nutrient_interactions takes vm_list, km_list and release_rates"
        )

    n_nutrients = len(S_list)
    shape_S = S_list[0].shape
    S_new_list = []

    uptake_matrix, release_matrix = process_cell_nutrient_interactions(
        cells, S_list, shape_S, diff, dt, uptake_rates, km_list, release_rates
    )

    for i in range(n_nutrients):
        laplacian = convolution_laplacian_3D(S_list[i])
        diffusion_term = D_list[i] * laplacian
        # The interaction step already multiplies its rates by dt, so these
        # matrices hold per-step amounts and must not be scaled by dt again.
        source_sink_amount = release_matrix[i] - uptake_matrix[i]
        S_new = S_list[i] + dt * diffusion_term + source_sink_amount
        S_new = np.maximum(S_new, 0)
        S_new_list.append(S_new)

    return S_new_list

In [None]:
def update_cell_population(cells, dt_cell):
    """
    Apply one death/reproduction step to every living cell.

    A living cell first faces a random death test with probability
    death_rate * dt_cell. Survivors reproduce when, for every nutrient they
    take up, their internal store is not below the corresponding alpha.

    Parameters:
    cells: list of Cell objects
    dt_cell: timestep for cell updates

    Returns:
    List of new cells created during reproduction
    List of indices (into cells) of cells that died during this step
    """
    offspring = []
    casualties = []

    for position, organism in enumerate(cells):
        if organism.is_alive:
            if np.random.random() < organism.death_rate * dt_cell:
                # Death: flag the cell and record where it sat in the list.
                organism.is_alive = False
                casualties.append(position)
            else:
                # Only nutrients flagged in the uptake column gate reproduction.
                required = (
                    n for n in range(organism.dependencies.shape[0])
                    if organism.dependencies[n, 0]
                )
                well_fed = all(
                    not (organism.internal_nutrients[n] < organism.alpha[n])
                    for n in required
                )
                if well_fed:
                    child = organism.reproduce()
                    # reproduce() may decline to produce a daughter cell.
                    if child is not None:
                        offspring.append(child)

    return offspring, casualties


In [None]:
def michaelis_menten_uptake(S: float, vm: float, km: float) -> float:
    """
    Michaelis-Menten uptake rate: vm * S / (S + km)

    Parameters:
    S: local nutrient concentration
    vm: maximum uptake rate
    km: half-saturation constant

    Returns:
    Uptake rate
    """
    saturating_numerator = vm * S
    denominator = S + km
    return saturating_numerator / denominator

In [None]:
def process_cell_nutrient_interactions(cells, S_list, shape_S, diff, dt, vm_list, km_list, release_rates):
    """
    Process nutrient uptake/release for all cells using Michaelis-Menten kinetics

    Every living cell is mapped to the grid voxel (x // diff, y // diff,
    z // diff); cells that fall outside the grid are skipped. For each
    nutrient the cell's dependencies row decides whether it takes up
    (column 0) and/or releases (column 1) that nutrient. The cell is notified
    of each exchange through process_nutrient_uptake / process_nutrient_release.

    Parameters:
    cells: list of Cell objects
    S_list: list of nutrient concentration arrays
    shape_S: shape of nutrient grids
    diff: grid refinement factor
    dt: timestep
    vm_list: list of maximum uptake rates for each nutrient
    km_list: list of half-saturation constants for each nutrient
    release_rates: list of release rates for each nutrient

    Returns:
    uptake_matrix: array of shape (n_nutrients, *shape_S) with the amount
        taken up per voxel during this step (rate already multiplied by dt)
    release_matrix: array of the same shape with the amount released per
        voxel during this step
    """
    n_nutrients = len(S_list)
    uptake_matrix = np.zeros((n_nutrients, *shape_S))
    release_matrix = np.zeros((n_nutrients, *shape_S))

    for cell in cells:
        if not cell.is_alive:
            continue

        i, j, k = int(cell.x // diff), int(cell.y // diff), int(cell.z // diff)

        if not (0 <= i < shape_S[0] and 0 <= j < shape_S[1] and 0 <= k < shape_S[2]):
            continue

        for nutrient_idx in range(n_nutrients):
            local_concentration = S_list[nutrient_idx][i, j, k]

            if cell.dependencies[nutrient_idx, 0]:  # uptake
                # Uptake rate from the concentration at the start of the step
                uptake_rate = michaelis_menten_uptake(
                    local_concentration,
                    vm_list[nutrient_idx],
                    km_list[nutrient_idx]
                )

                # Cap by what is still left in the voxel after the cells
                # handled earlier in this step: capping every cell by the
                # full concentration would let several cells sharing a voxel
                # remove more nutrient than it contains.
                desired_uptake = uptake_rate * dt
                remaining = max(
                    local_concentration - uptake_matrix[nutrient_idx, i, j, k], 0.0
                )
                actual_uptake = min(desired_uptake, remaining)

                uptake_matrix[nutrient_idx, i, j, k] += actual_uptake
                cell.process_nutrient_uptake(nutrient_idx, actual_uptake)

            if cell.dependencies[nutrient_idx, 1]:  # release
                release_amount = release_rates[nutrient_idx] * dt
                release_matrix[nutrient_idx, i, j, k] += release_amount
                cell.process_nutrient_release(nutrient_idx, release_amount)

    return uptake_matrix, release_matrix

In [None]:
@dataclass
class GridChunk:
    """Data class for holding chunk information for parallel processing

    One instance describes a rectangular sub-volume of the nutrient grid
    together with the data needed to process it on its own.
    """
    # Lower corner of the sub-volume; used as the start of the slice on each axis.
    start_idx: Tuple[int, int, int]
    # Upper corner of the sub-volume; used as the (exclusive) slice end on each axis.
    end_idx: Tuple[int, int, int]
    # One sub-array per nutrient field, cut out between start_idx and end_idx.
    S_chunk: List[npt.NDArray]
    # Cell objects assigned to this sub-volume.
    cells_in_chunk: List[Any]


In [None]:
def split_grid_into_chunks(S_list: List[npt.NDArray], cells: List[Any], 
                          n_chunks: int) -> List[GridChunk]:
    """
    Split the 3D grid and cells into chunks for parallel processing

    The grid is cut into the same number of slabs along every axis: the
    integer part of the cube root of n_chunks, but at least 1 and never more
    than the axis length. The last slab on each axis absorbs any remainder so
    that every voxel belongs to exactly one chunk.

    Parameters:
    S_list: list of 3D nutrient arrays sharing one shape
    cells: list of Cell objects; only cells with is_alive set are assigned
    n_chunks: requested number of chunks (the actual count is the product of
        the per-axis slab counts and may be smaller)

    Returns:
    List of GridChunk objects ordered with axis 0 outermost and axis 2
    innermost. Each chunk holds copies of its nutrient sub-arrays.
    """
    shape_S = S_list[0].shape
    # Cubic root for 3D splitting; clamp so n_chunks < 1 cannot yield 0 slabs.
    chunks_per_dim = max(1, int(np.cbrt(n_chunks)))

    def axis_bounds(length):
        """Return (start, end) index pairs covering range(length) without gaps."""
        pieces = max(1, min(chunks_per_dim, length))
        size = length // pieces
        bounds = [(p * size, (p + 1) * size) for p in range(pieces)]
        # Floor division can leave trailing voxels; give them to the last slab.
        bounds[-1] = (bounds[-1][0], length)
        return bounds

    live_cells = [cell for cell in cells if cell.is_alive]
    chunks = []

    for x0, x1 in axis_bounds(shape_S[0]):
        for y0, y1 in axis_bounds(shape_S[1]):
            for z0, z1 in axis_bounds(shape_S[2]):
                start_idx = (x0, y0, z0)
                end_idx = (x1, y1, z1)

                # Get cells in this chunk.
                # NOTE(review): positions are compared with grid indices
                # unscaled, whereas process_cell_nutrient_interactions maps
                # positions with `cell.x // diff` -- confirm the units agree.
                cells_in_chunk = [
                    cell for cell in live_cells
                    if x0 <= cell.x < x1 and y0 <= cell.y < y1 and z0 <= cell.z < z1
                ]

                # Get nutrient chunks (copies, so the chunk can be shipped
                # to a worker without aliasing the main grid)
                S_chunk = [S[x0:x1, y0:y1, z0:z1].copy() for S in S_list]

                chunks.append(GridChunk(start_idx, end_idx, S_chunk, cells_in_chunk))

    return chunks

In [None]:
def process_chunk(chunk: GridChunk, D_list: List[float], diff: int, dt: float,
                 uptake_rates: npt.NDArray, release_rates: npt.NDArray) -> Tuple[List[npt.NDArray], List[Any]]:
    """
    Advance the nutrient fields of one grid chunk by a single timestep.

    Returns the chunk's updated nutrient arrays together with the chunk's
    list of cells, so the caller can merge both back into the global state.
    """
    advanced_fields = update_nutrient_fields(
        S_list=chunk.S_chunk,
        D_list=D_list,
        cells=chunk.cells_in_chunk,
        diff=diff,
        dt=dt,
        uptake_rates=uptake_rates,
        release_rates=release_rates,
    )
    result = (advanced_fields, chunk.cells_in_chunk)
    return result


In [None]:
def parallel_simulate_system(S_list: List[npt.NDArray], D_list: List[float], 
                           cells: List[Any], diff: int, dt_diff: float, 
                           dt_cell: float, uptake_rates: npt.NDArray, 
                           release_rates: npt.NDArray, total_time: float):
    """
    Parallel implementation of system simulation

    The run consists of total_time / dt_cell cell steps. Each cell step first
    performs dt_cell / dt_diff nutrient steps (grid chunks processed in a
    process pool) and then one death/reproduction update of the cells
    (cell groups processed in a thread pool).

    Parameters:
    S_list: list of 3D nutrient arrays; they are updated in place
    D_list: diffusion coefficient per nutrient
    cells: list of Cell objects (the list passed in is not modified, the
        cell objects are)
    diff: grid refinement factor
    dt_diff: nutrient (diffusion) timestep
    dt_cell: cell update timestep; the number of nutrient steps per cell
        step is dt_cell / dt_diff, so both must be in the same time unit
    uptake_rates: uptake rates per nutrient
    release_rates: release rates per nutrient
    total_time: simulated time, in the same unit as dt_cell

    Returns:
    (S_list, cells): the nutrient arrays and the list of surviving plus
    newly created cells.

    NOTE(review): each chunk is advanced by update_nutrient_fields on its own
    sub-array, so the Laplacian's boundary handling is applied at chunk
    edges as well as at the true grid boundary -- confirm this is intended.
    """
    steps_per_cell_update = _whole_steps(dt_cell, dt_diff)
    n_cell_updates = _whole_steps(total_time, dt_cell)

    # Determine number of chunks based on CPU cores
    n_cores = mp.cpu_count()
    n_chunks = n_cores * 2  # Using 2 chunks per core for better load balancing

    with ProcessPoolExecutor(max_workers=n_cores) as process_executor, \
            ThreadPoolExecutor(max_workers=n_cores) as thread_executor:

        for _cell_step in range(n_cell_updates):
            # Parallel nutrient updates
            for _ in range(steps_per_cell_update):
                chunks = split_grid_into_chunks(S_list, cells, n_chunks)

                future_results = [
                    process_executor.submit(
                        process_chunk, chunk, D_list, diff, dt_diff,
                        uptake_rates, release_rates
                    )
                    for chunk in chunks
                ]
                chunk_results = [future.result() for future in future_results]

                # Merge chunk results back into main grid
                for chunk, (S_new_chunk, updated_cells) in zip(chunks, chunk_results):
                    region = tuple(
                        slice(lo, hi) for lo, hi in zip(chunk.start_idx, chunk.end_idx)
                    )
                    for S, S_new in zip(S_list, S_new_chunk):
                        S[region] = S_new

                    # process_chunk returns the chunk's cells in the order of
                    # chunk.cells_in_chunk, so pair them positionally. This
                    # avoids a quadratic search and cannot confuse two cells
                    # that happen to share a position.
                    for original, updated in zip(chunk.cells_in_chunk, updated_cells):
                        original.internal_nutrients = updated.internal_nutrients

            # Parallel cell population updates (plain list slices; no need to
            # push Cell objects through a numpy array)
            cell_groups = _split_evenly(cells, n_cores)
            future_results = [
                thread_executor.submit(update_cell_population, group, dt_cell)
                for group in cell_groups
            ]

            # Collect and process cell updates
            new_cells = []
            dead_indices = set()
            offset = 0
            for group, future in zip(cell_groups, future_results):
                group_new_cells, group_dead_indices = future.result()
                new_cells.extend(group_new_cells)
                # Indices are relative to the group; shift to the full list.
                dead_indices.update(idx + offset for idx in group_dead_indices)
                offset += len(group)

            # Update cell list
            cells = [cell for idx, cell in enumerate(cells)
                     if idx not in dead_indices and cell.is_alive]
            cells.extend(new_cells)

    return S_list, cells


def _whole_steps(span, step):
    """Number of whole steps of size `step` in `span`, tolerant of float noise."""
    ratio = span / step
    nearest = round(ratio)
    # e.g. 0.3 / 0.1 evaluates to 2.9999999999999996; treat that as 3 steps.
    if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
        return int(nearest)
    return int(ratio)


def _split_evenly(items, n_groups):
    """Split a list into n_groups consecutive slices whose sizes differ by at most one."""
    base, extra = divmod(len(items), n_groups)
    groups = []
    start = 0
    for g in range(n_groups):
        stop = start + base + (1 if g < extra else 0)
        groups.append(items[start:stop])
        start = stop
    return groups


In [None]:
# Time and dimensions
side_len = 750 #microns
box_height = 300 #microns
cell_size = 5 #microns
cell_to_diff_ratio = 10 #cell sides per diffusion grid voxel side
density = 10 #mm^-2

dt_diff = 1 #s
dt_cell = 0.1 #hour
dt_cell_to_diff_ratio = 360 #diffusion steps per cell step (0.1 hour / 1 s)
total_time = 10 #hours

#number of cell update steps in the whole run
nsteps = int(total_time/dt_cell)
#NOTE(review): this axis is scaled by dt_diff although nsteps counts cell steps - confirm intended units
time = np.arange(nsteps)/dt_diff + dt_diff

D = 20 #um^2/s
DS = D/np.square(cell_size*cell_to_diff_ratio) #cell_side^2/s
#print(DS)

#Diffusion grid: one voxel spans cell_size*cell_to_diff_ratio microns per side
nx = ny = int(side_len/(cell_size*cell_to_diff_ratio))
#NOTE(review): nz is derived from box_height with the same voxel size - confirm
nz = int(box_height/(cell_size*cell_to_diff_ratio))
S1 = np.zeros((nx, ny, nz))
S2 = np.zeros((nx, ny, nz))
#both nutrients share the same scaled diffusion coefficient
D1 = D2 = DS

In [None]:
filepath = 'data/cosmo_claude_220125_'

n_strains = 2
n_nutrients = 2

#Dependency table indexed as [strain, nutrient, role].
#Role 0 (first column) is UPTAKE, role 1 (second column) is RELEASE.
#An entry of 1 means the strain performs that role for that nutrient, 0 means it does not.
dependencies = np.zeros((n_strains,n_nutrients,2))

#Classic CoSMO type crossfeeding, set in one indexed assignment:
#  strain 1 takes up nutrient 1 and releases nutrient 2
#  strain 2 takes up nutrient 2 and releases nutrient 1
dependencies[[0, 0, 1, 1], [0, 1, 1, 0], [0, 1, 0, 1]] = 1


In [None]:
# Nutrient initialization
#Lys
alphaL = 5.4 #nutrient required for reproduction (fmole)
rL = 0.51 #Maximum growth rate (hr^-1)
vmL = alphaL*rL #maximum uptake rate (fmole/hr)
KL = 2.1e6 #Monod constant (fmole/ml)
gammaL = 0.4 #release rate (fmole/(cell*hr))
#hours -> seconds: 60**2 = 3600 (the `^` operator is bitwise XOR in Python, 60^2 == 62)
gammaL_s = gammaL/(60**2) #release rate (fmole/(cell*s))
dL = 0.021 #death rate (hr^-1)

#Ade
alphaA = 3.1 #nutrient required for reproduction
rA = 0.44 #Maximum growth rate
vmA = alphaA*rA #maximum uptake rate
vmA_s = vmA/(60**2) #maximum uptake rate per second
KA = 1.3e6 #Monod constant
gammaA = 0.26 #release rate
dA = 0.015 #death rate (hr^-1). Should this be an order of magnitude higher?


#Per-nutrient parameter lists, ordered [Lys, Ade]; these hold the hourly values
alpha_list = [alphaL,alphaA]
vm_list = [vmL,vmA]
gamma_list = [gammaL,gammaA]
km_list = [KL,KA]
d_list = [dL,dA]
r_list = [rL,rA]


#Rates handed to the simulation call
uptake_rates = np.array([0.1, 0.2])
release_rates = np.array([0.05, 0.1])


In [None]:


# Launch the chunked parallel simulation on the two nutrient fields.
# NOTE(review): `initial_cells` and `diff` are not assigned in the cells shown
# above -- confirm they are defined before this cell runs.
S_list_final, final_cells = parallel_simulate_system(
    S_list=[S1, S2],
    D_list=[D1, D2],
    cells=initial_cells,
    diff=diff,
    dt_diff=dt_diff,
    dt_cell=dt_cell,
    uptake_rates=uptake_rates,
    release_rates=release_rates,
    total_time=total_time,
)