# Gillespie Algorithm

In [None]:
import matplotlib.pyplot as plt
import numpy as np

---

In [None]:
# Four-node network: particles hop between nodes 0 <-> 1 and 2 <-> 3, and two
# of those hops are gated by how many particles sit in the other pair.
#
# +---+         +---+
# | 0 | <------ | 1 | ---+
# +---+ ------> +---+    |
#          ^             |
#          +-----------+ |
#                      | |
#          +-----------|-+
#          |           |
# +---+    V    +---+  |
# | 2 | <------ | 3 | -+
# +---+ ------> +---+
#
# d/dt x0 = -k01 s(x3) x0 + k10 x1
# d/dt x1 =  k01 s(x3) x0 - k10 x1
# d/dt x2 = -k23 x2 + k32 s(x1) x3
# d/dt x3 =  k23 x2 - k32 s(x1) x3
#
# The nonlinearity s gates the modulated transitions. In the simulation cell
# it is implemented as `sigma`, a 0/1 step on the input node's count.
#
# Each entry below is one directed hop. "rate" is the per-particle rate
# constant (k01, k10, k23, k32). The optional "input" names the node whose
# count feeds s.

transitions = [
    dict(source=0, destination=1, rate=0.7, input=3),  # k01, gated by x3
    dict(source=1, destination=0, rate=0.5),           # k10
    dict(source=2, destination=3, rate=0.5),           # k23
    dict(source=3, destination=2, rate=0.7, input=1),  # k32, gated by x1
]

In [None]:
# Fixed seed so every run of the notebook produces the same trajectory.
random = np.random.RandomState(0)

# Initial particle count per node. Each node later collects the transitions
# that leave it, so the simulation can iterate node -> outgoing transitions.
nodes = [{"count": count, "transitions": []} for count in (8, 0, 8, 0)]

# Replace the index-based transition specs with direct references to the node
# dicts, and attach each transition to its source node.
for spec in transitions:
    edge = {
        "source": nodes[spec["source"]],
        "destination": nodes[spec["destination"]],
        "rate": spec["rate"],
    }
    # Optional modulating node whose count gates this transition's rate.
    if "input" in spec:
        edge["input"] = nodes[spec["input"]]

    edge["source"]["transitions"].append(edge)

In [None]:
def sigma(x):
    """Step nonlinearity that gates a modulated transition.

    Returns 1 when the input node holds at least one particle and 0
    otherwise, so a modulated transition is either fully on or fully off.
    """
    return 1 if x >= 1 else 0

time = 0

time_history = [time]
count_history = [[node["count"] for node in nodes]]

for _ in range(300):

    # Propensity of every transition in the current state:
    #   rate * count(source), times sigma(count(input)) when modulated.
    # Computed once per step so the total rate and the cumulative selection
    # below are built from exactly the same numbers, in the same order.
    propensities = []
    for node in nodes:
        for tr in node["transitions"]:
            mod = 1
            if "input" in tr:
                mod = sigma(tr["input"]["count"])
            propensities.append((tr, mod * tr["rate"] * node["count"]))

    # Current total rate
    sum_rate = sum(weight for _, weight in propensities)

    if sum_rate <= 0:
        # Absorbing state: no transition can fire and the waiting time would
        # be infinite, so the trajectory ends here.
        break

    # Choose actual transition, with probability proportional to its propensity
    stop_rate = sum_rate * random.uniform()
    par_rate = 0
    chosen_transition = None

    for tr, weight in propensities:
        if weight <= 0:
            continue  # a zero-propensity transition can never be selected
        # Track the last eligible transition. If floating-point rounding
        # leaves stop_rate >= the final cumulative sum, it is the right pick.
        chosen_transition = tr
        par_rate += weight
        if stop_rate < par_rate:
            break

    # Apply
    chosen_transition["source"]["count"] -= 1
    chosen_transition["destination"]["count"] += 1

    # Exponential waiting time with mean 1 / sum_rate. Since 1 - u lies in
    # (0, 1], the log is finite.
    dt = -np.log(1 - random.uniform()) / sum_rate
    time += dt

    # Log
    time_history.append(time)
    count_history.append([node["count"] for node in nodes])

time_history = np.array(time_history)
count_history = np.array(count_history)

In [None]:
fig, ax = plt.subplots(figsize=(10, 2))

x = time_history
y = count_history
# count_history[i] is the state reached at time_history[i] and held until the
# next event. The value must therefore extend to the RIGHT of each jump time
# (where="post"); the default "pre" would draw every state one event early.
lines = ax.step(x, y, where="post", lw=1)
for index, line in enumerate(lines):
    line.set_label(f"node {index}")

ax.set_xlabel("time")
ax.set_ylabel("count")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), frameon=False)

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

# End the cell on a statement so the notebook does not echo a return value.
pass