In [None]:
import random
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np

@dataclass
class NutrientProfile:
    """Kinetic parameters attached to one nutrient.

    All rates are supplied per hour. Once the instance has been built,
    per-second versions of the uptake and release rates are stored as the
    plain attributes ``vm_s`` and ``gamma_s`` (they are not dataclass
    fields, so they do not take part in ``repr`` or equality).
    """
    alpha: float  # amount of nutrient needed for reproduction (fmole)
    r: float      # maximum growth rate (hr^-1)
    vm: float     # maximum uptake rate (fmole/hr)
    Km: float     # Monod constant (fmole/ml)
    gamma: float  # release rate (fmole/(cell*hr))
    d: float      # death rate (hr^-1)

    def __post_init__(self):
        """Derive the per-second uptake and release rates from the hourly ones."""
        seconds_per_hour = 60 * 60
        # Maximum uptake rate expressed in fmole/s.
        self.vm_s = self.vm / seconds_per_hour
        # Release rate expressed in fmole/(cell*s).
        self.gamma_s = self.gamma / seconds_per_hour

class Strain:
    """One microbial strain: its nutrient dependency flags and internal stores."""

    # Column positions inside a strain's dependency table.
    _UPTAKE_COLUMN = 0
    _RELEASE_COLUMN = 1

    def __init__(self,
                 strain_id: int,
                 grid_shape: tuple,
                 n_nutrients: int,
                 dependencies: np.ndarray,
                 nutrient_profiles: List[NutrientProfile]):
        """Create a strain.

        Args:
            strain_id: Index of this strain along the first axis of
                ``dependencies``.
            grid_shape: Shape of the zero-filled internal nutrient array.
            n_nutrients: Accepted for signature compatibility; not read here.
            dependencies: Table for all strains; only the slice
                ``dependencies[strain_id]`` is kept. That slice is indexed as
                ``[nutrient_id, 0]`` (uptake) and ``[nutrient_id, 1]``
                (release).
            nutrient_profiles: Profiles indexed by nutrient id.
        """
        self.strain_id = strain_id
        self.nutrient_profiles = nutrient_profiles
        # Keep only the rows that describe this strain.
        self.dependencies = dependencies[strain_id]
        # Internal nutrient concentrations, one value per grid site.
        self.S_internal = np.zeros(grid_shape)

    def should_uptake_nutrient(self, nutrient_id: int) -> bool:
        """Return True when the strain's uptake flag for the nutrient is set."""
        uptake_flag = self.dependencies[nutrient_id, self._UPTAKE_COLUMN]
        return bool(uptake_flag)

    def should_release_nutrient(self, nutrient_id: int) -> bool:
        """Return True when the strain's release flag for the nutrient is set."""
        release_flag = self.dependencies[nutrient_id, self._RELEASE_COLUMN]
        return bool(release_flag)

    def get_death_probability(self, nutrient_id: int, dt: float) -> float:
        """Return the death rate of the given nutrient profile times ``dt``."""
        profile = self.nutrient_profiles[nutrient_id]
        return profile.d * dt

    def get_division_probability(self, nutrient_id: int, dt: float) -> float:
        """Return the growth rate of the given nutrient profile times ``dt``."""
        profile = self.nutrient_profiles[nutrient_id]
        return profile.r * dt

    def check_nutrient_requirements(self, coords: tuple) -> bool:
        """Tell whether the cell at ``coords`` holds enough nutrient to divide.

        For every nutrient the strain takes up, the internal store at the
        site must be at least that nutrient's ``alpha``; nutrients the strain
        does not take up are ignored.
        """
        i, j, k = coords
        for nutrient_id, profile in enumerate(self.nutrient_profiles):
            if not self.should_uptake_nutrient(nutrient_id):
                continue
            if self.S_internal[i, j, k] < profile.alpha:
                return False
        return True

class StrainSimulation:
    """Manage a lattice simulation of several microbial strains.

    Grid encoding (as used by the methods below): ``0`` marks an empty site,
    a whole number ``k`` marks a live cell of the 1-based strain ``k``, and
    ``k - 0.5`` marks a cell of strain ``k`` that has died.
    """

    def __init__(self,
                 grid_shape: tuple,
                 n_strains: int,
                 n_nutrients: int,
                 dependencies: np.ndarray):
        """Build an empty grid, the nutrient profiles and the strain objects.

        Args:
            grid_shape: Shape of the lattice. The per-cell methods address it
                with three coordinates ``(x, y, z)``.
            n_strains: Number of ``Strain`` objects to create.
            n_nutrients: Number of external nutrient fields to allocate.
            dependencies: Table handed to every ``Strain``, which reads it as
                ``dependencies[strain][nutrient, 0 or 1]`` (uptake / release).
        """
        self.grid_shape = grid_shape
        self.n_strains = n_strains
        self.n_nutrients = n_nutrients
        self.grid = np.zeros(grid_shape)

        # Initialize nutrient profiles
        self.nutrient_profiles = self._create_nutrient_profiles()

        # Initialize strains
        self.strains = [
            Strain(i, grid_shape, n_nutrients, dependencies, self.nutrient_profiles)
            for i in range(n_strains)
        ]

        # Initialize external nutrient concentrations
        self.S_external = [np.zeros(grid_shape) for _ in range(n_nutrients)]

    def _create_nutrient_profiles(self) -> List[NutrientProfile]:
        """Create the default lysine and adenine profiles.

        Always returns exactly two profiles, independent of ``n_nutrients``.
        """
        # Lysine profile
        lys = NutrientProfile(
            alpha=5.4,
            r=0.51,
            vm=5.4 * 0.51,  # alphaL * rL
            Km=2.1e6,
            gamma=0.4,
            d=0.021
        )

        # Adenine profile
        ade = NutrientProfile(
            alpha=3.1,
            r=0.44,
            vm=3.1 * 0.44,  # alphaA * rA
            Km=1.3e6,
            gamma=0.26,
            d=0.015
        )

        return [lys, ade]

    @staticmethod
    def _decide_fate(strain: "Strain", coords: tuple, strain_id: int,
                     dt: float) -> Optional[tuple]:
        """Draw the fate of one live cell (shared by ``update_cell`` and the worker).

        Returns ``(coords, strain_id - 0.5)`` when the cell dies,
        ``('bud', coords)`` when it should divide, and ``None`` otherwise.
        """
        x, y, z = coords
        site = (x, y, z)

        # NOTE(review): death and unconstrained division both read nutrient
        # profile 0 for every strain - confirm this is intended.
        if random.random() < strain.get_death_probability(0, dt):
            return (site, strain_id - 0.5)

        if not np.any(strain.dependencies[:, 0]):
            # No uptake requirement at all: division is purely stochastic.
            if random.random() < strain.get_division_probability(0, dt):
                return ('bud', site)
        elif strain.check_nutrient_requirements(site):
            # Every required nutrient is stocked: divide.
            return ('bud', site)

        return None

    def update_cell(self, coords: tuple, strain_id: int, dt: float) -> Optional[tuple]:
        """Update a single cell and return the resulting change, if any.

        Args:
            coords: ``(x, y, z)`` position of the cell.
            strain_id: 1-based strain id as stored in the grid.
            dt: Time step used to scale the death and division rates.

        Returns:
            ``(coords, strain_id - 0.5)`` on death, ``('bud', coords)`` when
            the cell should divide, or ``None`` when nothing happens.
        """
        strain = self.strains[strain_id - 1]  # strain_id is 1-based in grid
        return self._decide_fate(strain, coords, strain_id, dt)

    def process_cell_fates_parallel(self, dt: float, n_jobs: int = -1):
        """Decide the fate of every occupied site, then apply the changes.

        Fates are computed against the current grid (in parallel when joblib
        is importable, serially otherwise) and only afterwards written back,
        so the workers never observe a half-updated grid.

        Args:
            dt: Time step forwarded to the fate computation.
            n_jobs: Forwarded to ``joblib.Parallel``; ignored in the serial
                fallback.
        """
        xs, ys, zs = np.where(self.grid != 0)
        number = len(xs)
        if number == 0:
            return

        # Visit the occupied sites in random order.
        indices = np.arange(number)
        np.random.shuffle(indices)

        # The eighth slot is kept only so the tuple layout expected by
        # ``_cell_fate_worker`` stays the same; it carries no data.
        tasks = [
            (i, self.grid, indices, xs, ys, zs, dt, None, self.strains)
            for i in range(number)
        ]

        try:
            from joblib import Parallel, delayed
        except ImportError:
            # joblib is unavailable: same computation, one task at a time.
            results = [self._cell_fate_worker(task) for task in tasks]
        else:
            results = Parallel(n_jobs=n_jobs)(
                delayed(self._cell_fate_worker)(task) for task in tasks
            )

        for result in results:
            if result is None:
                continue

            if isinstance(result[0], str) and result[0] == 'bud':
                _, coords = result
                self._handle_budding(coords)
            else:
                # Death: overwrite the site with the dead-cell marker.
                coords, new_value = result
                self.grid[coords] = new_value

    @staticmethod
    def _cell_fate_worker(args):
        """Worker for ``process_cell_fates_parallel``.

        ``args`` is the tuple ``(i, grid, indices, x, y, z, dt, unused,
        strains)``. Returns the same values as ``update_cell``, or ``None``
        for sites that hold no live cell.
        """
        i, grid, indices, x, y, z, dt, _unused, strains = args
        k = indices[i]
        coord = (int(x[k]), int(y[k]), int(z[k]))

        # Test the raw grid value: truncating first would turn a dead cell of
        # strain k (value k - 0.5) into a live cell of strain k - 1.
        value = grid[coord]
        if value == 0 or value % 1 != 0:
            return None

        strain_id = int(value)
        return StrainSimulation._decide_fate(strains[strain_id - 1], coord, strain_id, dt)

    def _handle_budding(self, coords: tuple):
        """Place a daughter of the cell at ``coords`` on a free neighbouring site.

        The 3x3 neighbourhood in the cell's own layer is tried first; if it
        has no empty site, the same neighbourhood one layer up is tried. When
        neither layer has room (or there is no layer above) nothing happens.
        """
        x, y, z = (int(c) for c in coords)
        strain_id = int(self.grid[x, y, z])
        nx, ny, nz = self.grid.shape

        # ``periodic_image`` is defined outside this class; presumably it
        # wraps an index into the valid range of the axis - TODO confirm.
        xs = np.zeros(9, dtype=int)
        ys = np.zeros(9, dtype=int)

        xs[:3] = periodic_image(x - 1, nx)
        xs[3:6] = periodic_image(x, nx)
        xs[6:] = periodic_image(x + 1, nx)

        ys[0::3] = periodic_image(y - 1, ny)
        ys[1::3] = periodic_image(y, ny)
        ys[2::3] = periodic_image(y + 1, ny)

        for floor in (0, 1):
            layer = z + floor
            if layer >= nz:
                # Already in the top layer: there is no floor above to use.
                break

            free = self.grid[xs, ys, layer] == 0
            n_free = int(np.count_nonzero(free))
            if n_free == 0:
                continue

            # Pick one of the empty sites uniformly at random.
            pick = np.random.randint(n_free)
            self.grid[xs[free][pick], ys[free][pick], layer] = strain_id
            return
