In [None]:
from math import log
from random import random, seed
from warnings import warn


class Base(dict):
    """Container for Stochastic Simulation Algorithms (SSA).

    The instance itself is a ``dict`` mapping every tracked quantity to the
    list of its values along the current trajectory.  A ``"time"`` key is
    required; every other key is treated as a species count and is advanced
    in lock-step with ``"time"``.

    Subclasses must implement :meth:`exit` (stopping rule for a trajectory)
    and :meth:`reset` (clean-up before the next trajectory starts).
    """

    def __init__(
        self,
        initial_conditions,
        propensities,
        stoichiometry,
        rng=None,
    ):
        """Initialize SSA.

        Parameters
        ----------
        initial_conditions : mapping of key -> list
            Starting history for every tracked key, including ``"time"``.
            The lists are copied, so simulating never mutates the caller's
            mapping.
        propensities : mapping of reaction key -> callable
            ``propensities[k](self)`` returns the current rate of reaction k.
        stoichiometry : mapping of reaction key -> mapping of species -> number
            Change in each species when reaction k fires.  Species left out
            of a reaction are carried forward unchanged.  Reaction keys must
            match those of ``propensities``.
        rng : random.Random, optional
            Source of uniform variates.  Defaults to the global generator of
            the ``random`` module.
        """
        self.propen = list(propensities.items())
        self.stoich = list(stoichiometry.items())
        self._random = random if rng is None else rng.random
        # copy every history so appends/resets never leak into the caller's lists
        super().__init__(
            (key, list(values)) for key, values in initial_conditions.items()
        )

    def exit(self, *args, **kwargs):
        """Return True if the current trajectory is finished, else False.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def reset(self, *args, **kwargs):
        """Clean up the trajectory on or after exit.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def _weights(self):
        """Evaluate every propensity for the current state.

        Returns ``(weights, total)`` where ``weights`` is a list of
        ``(reaction key, propensity)`` pairs.  Raises RuntimeError when no
        reaction can fire, since the trajectory could never advance.
        """
        weights = [(key, rate(self)) for key, rate in self.propen]
        total = sum(weight for _, weight in weights)
        # `not total > 0` also rejects NaN
        if not total > 0.0:
            raise RuntimeError(
                "no reaction has positive propensity but exit() is False"
            )
        return weights, total

    def _exponential(self, rate):
        """Draw an exponential waiting time with the given positive rate."""
        # 1 - U lies in (0, 1], so the logarithm is always finite
        return -log(1.0 - self._random()) / rate

    def _fire(self, sojourn, reaction_stoich):
        """Append one event: advance time by ``sojourn`` and apply a reaction."""
        unknown = set(reaction_stoich) - set(self)
        if unknown:
            raise KeyError(
                f"stoichiometry refers to untracked species {sorted(map(str, unknown))}"
            )
        times = self["time"]
        times.append(times[-1] + sojourn)
        # every species gets a new entry so all histories stay aligned with time
        for key, history in self.items():
            if key != "time":
                history.append(history[-1] + reaction_stoich.get(key, 0))

    def _snapshot(self):
        """Return an items view over a copy of the current trajectory."""
        # copying the lists keeps the yielded data intact after reset()
        return {key: list(history) for key, history in self.items()}.items()

    def direct(self):
        """Indefinite generator of Gillespie direct-method trajectories.

        Each trajectory is simulated until :meth:`exit` returns True.  It is
        then yielded as an items view over a snapshot copy, and :meth:`reset`
        is called before the next trajectory starts.

        Raises
        ------
        RuntimeError
            If every propensity is zero while :meth:`exit` is still False.
        """
        while True:
            stoich = dict(self.stoich)
            while not self.exit():

                # init step: evaluate propensities and partition
                weights, partition = self._weights()

                # monte carlo step 1: next reaction time
                sojourn = self._exponential(partition)

                # monte carlo step 2: pick reaction k with probability w_k / partition
                threshold = partition * self._random()
                chosen = None
                cumulative = 0.0
                for key, weight in weights:
                    cumulative += weight
                    if weight > 0.0:
                        # also the fallback if rounding leaves threshold >= total
                        chosen = key
                        if threshold < cumulative:
                            break

                # final step: update time and reaction species
                self._fire(sojourn, stoich[chosen])
            yield self._snapshot()
            self.reset()

    def first_reaction(self):
        """Indefinite generator of first-reaction-method trajectories.

        Every reaction with positive propensity draws its own exponential
        waiting time, and the earliest one fires.  Trajectories are yielded
        and reset exactly as in :meth:`direct`.

        Raises
        ------
        RuntimeError
            If every propensity is zero while :meth:`exit` is still False.
        """
        while True:
            stoich = dict(self.stoich)
            while not self.exit():

                # monte carlo step: one candidate time per reaction that can fire
                weights, _ = self._weights()
                candidates = [
                    (key, self._exponential(weight))
                    for key, weight in weights
                    if weight > 0.0
                ]
                chosen, sojourn = min(candidates, key=lambda pair: pair[1])

                # update next reaction time and reaction species
                self._fire(sojourn, stoich[chosen])
            yield self._snapshot()
            self.reset()
            

In [None]:
class Epidemic(Base):
    """Epidemic without vital dynamics.

    Reads susceptible and infected counts from the ``"s"`` and ``"i"``
    histories.  A trajectory only finishes once both are zero.  A run in
    which the infection dies out while susceptibles remain is reset in place
    and simulated again.
    """

    def exit(self):
        """Return True when nobody is susceptible or infected any more.

        If ``i`` is zero but ``s`` is not, the trajectory is reset with
        ``tail=True`` and False is returned so the simulation starts over.
        """
        susceptible = self["s"][-1]
        infected = self["i"][-1]
        if infected != 0:
            return False
        if susceptible == 0:
            return True
        # infection extinct before reaching everyone: discard this run
        self.reset(tail=True)
        return False

    def reset(self, tail=False):
        """Truncate every history back to its first (initial) entry.

        With ``tail=True`` the last entry of each history is deleted first.
        """
        # NOTE(review): the tail deletion is subsumed by the truncation below
        # whenever a history has 2+ entries; with exactly one entry it would
        # remove the initial value itself — confirm intent.
        histories = list(self.values())
        if tail:
            for history in histories:
                del history[-1]
        for history in histories:
            del history[1:]
            

In [None]:
# initial species counts and the starting time of every trajectory
initial_conditions = {
    "s": [299],
    "i": [1],
    "r": [0],
    "time": [0.0],
}
# misspelled alias kept because later cells refer to this name
initital_conditions = initial_conditions


# total population, derived from the initial counts so the infection term
# stays consistent if those counts are edited
POPULATION = sum(initial_conditions[name][0] for name in ("s", "i", "r"))
INFECTION_RATE = 10.0  # rate constant of reaction 0 (s -> i)
RECOVERY_RATE = 10.0   # per-infected rate constant of reaction 1 (i -> r)


# propensity functions: reaction 0 is infection, reaction 1 is recovery
propensities = {
    0: lambda d: INFECTION_RATE * d["s"][-1] * d["i"][-1] / POPULATION,
    1: lambda d: RECOVERY_RATE * d["i"][-1],
}


# change in species for each propensity
stoichiometry = {
    0: {"s": -1, "i": 1, "r": 0},
    1: {"s": 0, "i": -1, "r": 1},
}


In [None]:
from matplotlib import pyplot, rcParams


# make figure 10" x 3", 200 dots per inch
rcParams["figure.figsize"] = 10, 3
rcParams["figure.dpi"] = 200


# how many trajectories to overlay, and the seed that makes them reproducible
N_TRAJECTORIES = 1
SEED = 0
seed(SEED)


# one panel per species; "time" is the shared horizontal axis
species = [name for name in initital_conditions if name != "time"]
figure, axes = pyplot.subplots(1, len(species), squeeze=False)
axes = axes[0]


# simulate from a private copy of the initial conditions so re-running this
# cell always starts from the same state
# NOTE(review): Epidemic.exit only accepts runs that end with s == 0 and
# i == 0; with the rates configured above this may need many restarts.
epidemic = Epidemic(
    {key: list(values) for key, values in initital_conditions.items()},
    propensities,
    stoichiometry
)
trajectories = 0
for trajectory in epidemic.direct():
    trajectories += 1
    # snapshot the histories before the generator resets them
    history = {key: list(values) for key, values in trajectory}
    for ax, name in zip(axes, species):
        # SSA paths are piecewise constant between reaction events
        ax.step(history["time"], history[name], where="post")
    # exit condition
    if trajectories >= N_TRAJECTORIES:
        break


# label every panel so the figure stands on its own
for ax, name in zip(axes, species):
    ax.set(title=name, xlabel="time", ylabel="count")
figure.tight_layout()
pyplot.show()
        