In [None]:
import numpy as np
import matplotlib.pyplot as plt

def gillespie_2_type_pop(Nit, fA, gA, NA0, fB, gB, NB0, K, tlist, alpha, beta, gamma,
                         rng=None):
    """Simulate a two-type birth-death population with the Gillespie algorithm.

    Each of ``Nit`` independent runs starts at time 0 from ``NA0`` type-A and
    ``NB0`` type-B individuals and evolves through four reactions:

    * A replication: rate ``fA * ((NA + NB) / K) * NA``
    * B replication: rate ``fB * ((NA + NB) / K) * NB``
    * A death:       rate ``gA * NA``
    * B death:       rate ``gB * NB``

    A run stops as soon as either type reaches zero, no reaction can fire
    (total rate 0), or time passes ``max(tlist)``. The final state is then
    carried forward to every remaining sample time.

    NOTE(review): the replication rate *increases* with ``(NA + NB) / K``
    (logistic growth would use ``1 - (NA + NB) / K``). Once the net drift is
    positive, the population accelerates, and a run may need an impractically
    large number of events to reach ``max(tlist)``. Confirm the intended
    growth law.

    Parameters
    ----------
    Nit : int
        Number of independent runs.
    fA, gA, fB, gB : float
        Replication (f) and death (g) coefficients of types A and B.
    NA0, NB0 : int
        Initial counts of A and B.
    K : float
        Scale dividing the total population in the replication rates.
    tlist : array_like
        Increasing sample times. Row 0 always receives the initial state.
    alpha, beta, gamma : float
        Currently unused; accepted so existing callers keep working.
    rng : numpy.random.Generator, optional
        Source of uniform random numbers. Defaults to the global ``np.random``
        state, so ``np.random.seed`` still controls the results.

    Returns
    -------
    NAlist, NBlist : ndarray, shape (len(tlist), Nit)
        Counts of A and B at each sample time (rows) for each run (columns).
    """
    draw_uniform = (np.random if rng is None else rng).random
    tlist = np.asarray(tlist)
    n_times = len(tlist)
    t_end = np.max(tlist)  # loop-invariant; computed once instead of per event
    # (dNA, dNB) applied by each reaction, in the same order as `rates` below.
    steps = ((1, 0), (0, 1), (-1, 0), (0, -1))

    NAlist = np.full((n_times, Nit), np.nan)
    NBlist = np.full((n_times, Nit), np.nan)

    for i in range(Nit):
        NA, NB = NA0, NB0
        t = 0.0
        NAlist[0, i] = NA
        NBlist[0, i] = NB
        q = 1  # index of the next sample time that still needs a value

        while NA != 0 and NB != 0 and t < t_end:
            # Compute the transition rates for the growth function used.
            density = (NA + NB) / K
            rates = (
                fA * density * NA,  # A replicates
                fB * density * NB,  # B replicates
                gA * NA,            # A dies
                gB * NB,            # B dies
            )
            T = rates[0] + rates[1] + rates[2] + rates[3]
            if T <= 0:
                break  # no reaction can fire; the state is frozen from here on

            # Exponential waiting time. 1 - u lies in (0, 1], so the log is finite.
            t += -np.log(1.0 - draw_uniform()) / T

            # Every sample time passed before this event sees the PRE-event state,
            # including the last one. The bound check comes first so tlist is
            # never indexed past its end.
            while q < n_times and tlist[q] < t:
                NAlist[q, i] = NA
                NBlist[q, i] = NB
                q += 1

            # Sampling tower: pick reaction k with probability rates[k] / T.
            # `<=` skips zero-rate reactions even if the draw is exactly 0.
            threshold = draw_uniform() * T
            k = 0
            cumulative = rates[0]
            while cumulative <= threshold and k < 3:
                k += 1
                cumulative += rates[k]
            dA, dB = steps[k]
            NA += dA
            NB += dB

        # Carry the final state forward to any sample times not yet filled.
        NAlist[q:, i] = NA
        NBlist[q:, i] = NB

    return NAlist, NBlist

# Set parameters
Nit = 100                    # number of independent stochastic runs
fA, gA, NA0 = 1, 0.1, 90     # type A: replication coefficient, death rate, initial count
fB, gB, NB0 = 1.01, 0.1, 1   # type B: replication coefficient, death rate, initial count
alpha = 1                    # alpha, beta, gamma are passed through to the simulator
beta = 2
gamma = 1.5
K = 100
tlist = np.linspace(0, 100, 100)  # sample times at which population sizes are recorded

# Run the Python function
NAlist, NBlist = gillespie_2_type_pop(Nit, fA, gA, NA0, fB, gB, NB0, K, tlist, alpha, beta, gamma)

# Plot the results. Each column of NAlist / NBlist is one run, so plotting the
# 2-D arrays draws every run at once. Colouring by type keeps the A and B
# trajectories distinguishable, and a single legend entry per type is enough.
lines_A = plt.plot(tlist, NAlist, color='tab:blue', alpha=0.3)
lines_B = plt.plot(tlist, NBlist, color='tab:orange', alpha=0.3)
lines_A[0].set_label('Type A')
lines_B[0].set_label('Type B')

plt.xlabel('Time')
plt.ylabel('Population Size')
plt.legend()
plt.show()

# Number of runs in which type B is still present at the final sample time
survival_count_B = np.sum(NBlist[-1, :] > 0)
print(f"Number of times Type B survives: {survival_count_B} out of {Nit} simulations.")

# Vary fB values: sweep the type-B replication coefficient and compare the
# simulated survival fraction of B with closed-form fixation probabilities.
fB_values = np.linspace(1.01, 2, 20)
# NOTE(review): `f` came from the chained assignment `f=fB_values = ...` and is
# not used in this cell; kept only in case a later cell reads it.
f = fB_values
survival_counts = []     # fraction of runs in which type B survives, per fB
functional3_values = []  # (1 - exp(-2 s)) / (1 - exp(-2 Keff s)), per fB
functional_values = []   # (1 - x) / (1 - x**Keff) with x = fA*gB / (fB*gA), per fB
functional2_values = []  # s = 1 - fA / fB, per fB
fB_values_2 = np.linspace(1.01, 2, 100)
MeanErrors = []          # signed relative error (predicted - empirical) / empirical
RMSErrors = []           # squared relative error; averaged and rooted below

# Keff does not depend on fB, so compute it once outside the loop.
Keff = K * (1 - gA / fA)  # Change Keff acc to population function used

for fB in fB_values:
    # Run the Python function
    NAlist, NBlist = gillespie_2_type_pop(Nit, fA, gA, NA0, fB, gB, NB0, K, tlist, alpha, beta, gamma)

    # Fraction of the Nit runs in which type B is present at the final sample
    # time; the mean of the boolean mask divides the count by Nit.
    survival_count_B = np.mean(NBlist[-1, :] > 0)
    survival_counts.append(survival_count_B)

    s = 1 - fA / fB
    fitness_ratio = fA * gB / (fB * gA)

    functional_value = (1 - fitness_ratio) / (1 - fitness_ratio ** Keff)
    functional_values.append(functional_value)

    functional2_value = s
    functional2_values.append(functional2_value)

    functional3_value = (1 - np.exp(-2 * s)) / (1 - np.exp(-2 * Keff * s))
    functional3_values.append(functional3_value)

    # The relative error is undefined when no run survived; record NaN so the
    # averages below can skip that point instead of producing inf.
    if survival_count_B > 0:
        MeanError = (functional_value - survival_count_B) / survival_count_B
    else:
        MeanError = np.nan
    MeanErrors.append(MeanError)

    RMSError = MeanError ** 2
    RMSErrors.append(RMSError)

# Plot both on the same graph
plt.figure(figsize=(10, 6))

# Plot survival_count_B and function_values against fB on the same y-axis
plt.plot(fB_values, survival_counts, marker='o', color='#8a2be2', label='Empirical')
plt.plot(fB_values, functional_values, color='#4e0707', label='Predicted')

plt.xlabel('Fitness of mutant')
plt.ylabel('Fixation Probability')
plt.ylim(0, 1)
plt.title('Estimating fixation probability')
plt.legend()
plt.show()

# Average over however many fB values were swept (NaN entries are skipped).
print(f"Mean relative error: {np.nanmean(MeanErrors)}")
print(f"RMS relative error: {np.sqrt(np.nanmean(RMSErrors))}")