# Question 5 - Generalised Logistic Equation
This notebook is an interactive toy for exploring the solutions of the generalised logistic equation $\dot{N} = N\,(r - a\,(N - b)^2)$.

Import required libraries

In [None]:
import numpy as np
from scipy.integrate import odeint  # Used to numerically solve the ODE
import matplotlib.pyplot as plt

In [None]:
# Library to allow for interactive plots
from IPython.display import display
from ipywidgets import interact, widgets, interactive

Define the function (the right-hand side of the equation).

In [None]:
def logistic_generalised(N, t=0, r=1, a=1, b=1):
    """Right-hand side dN/dt of the generalised logistic equation.

        dN/dt = N * (r - a * (N - b)**2)

    ``N`` may be a scalar or a NumPy array (evaluated element-wise).
    ``t`` is not used; it is accepted so the function matches the
    ``func(y, t, *args)`` signature that ``scipy.integrate.odeint`` calls.
    """
    per_capita = r - a * (N - b) ** 2
    return N * per_capita

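Next we compute the equilibria (the zeros of the right-hand side) and their stability (the sign of the derivative of the right-hand side at each equilibrium).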

In [None]:
def calculate_equlibria(r=1, a=1, b=1):
    """Return the three equilibria (zeros) of ``logistic_generalised``.

    Setting N * (r - a * (N - b)**2) = 0 gives N = 0 and N = b +/- sqrt(r / a).

    The non-trivial pair is only real when a != 0 and r / a >= 0. Otherwise
    both are returned as ``nan``, rather than raising ZeroDivisionError
    (a == 0 with Python floats, e.g. from a slider) or emitting a
    negative-square-root warning.

    Parameters
    ----------
    r, a, b : float or array_like
        Model parameters.

    Returns
    -------
    tuple
        ``(0, b + sqrt(r/a), b - sqrt(r/a))``.
    """
    eq1 = 0
    r_arr = np.asarray(r, dtype=float)
    a_arr = np.asarray(a, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        # When a == 0 the equation degenerates to dN/dt = r N, which (for
        # r != 0) has no equilibrium besides N = 0.
        ratio = np.where(a_arr != 0, r_arr / a_arr, np.nan)
        # A negative ratio gives nan: no real non-trivial equilibria.
        offset = np.sqrt(ratio)
    eq2 = b + offset
    eq3 = b - offset
    return (eq1, eq2, eq3)

def calculate_stability(eq, r=1, a=1, b=1):
    """Derivative of the right-hand side evaluated at each equilibrium.

    For f(N) = N * (r - a * (N - b)**2),

        f'(N) = r - a * (N - b)**2 - 2 * a * N * (N - b).

    A positive value means the equilibrium is unstable, a negative value
    means it is stable.

    Parameters
    ----------
    eq : tuple
        ``(0, b + sqrt(r/a), b - sqrt(r/a))`` as returned by
        ``calculate_equlibria`` for the same r, a, b. ``nan`` entries
        propagate to ``nan`` slopes.
    r, a, b : float
        Model parameters.

    Returns
    -------
    tuple
        ``(f'(eq1), f'(eq2), f'(eq3))``.
    """
    _eq1, eq2, eq3 = eq
    # At N = 0 only the first term of f'(N) survives.
    slope_trivial = r - a * b ** 2
    # At the non-trivial roots r - a * (N - b)**2 == 0, leaving
    # f'(N) = -2 a N (N - b). Using (N - b) = +/- sqrt(r / a) directly keeps
    # the factor `a` and avoids recomputing r / a (which fails for a == 0).
    slope_upper = -2 * a * eq2 * (eq2 - b)
    slope_lower = -2 * a * eq3 * (eq3 - b)
    return (slope_trivial, slope_upper, slope_lower)

Let's plot this for different values of $a,~b,~r$.

In [None]:
# Interactive matplotlib backend so the figures below can be redrawn in place
# (fig.canvas.draw_idle()) when the ipywidgets sliders change.
# NOTE(review): `%matplotlib notebook` targets the classic Notebook; in
# JupyterLab / Notebook 7, `%matplotlib widget` (ipympl) is the usual
# replacement — confirm for your environment.
%matplotlib notebook

# Plotting simplified equation
Here we divide the derivative by $N$ and plot the per-capita growth rate $\frac{\dot{N}}{N} = r - a\,(N - b)^2$, the analogue of the growth term in the traditional logistic equation.

In [None]:
x = np.linspace(-0.2, 2, 100)

fig, ax = plt.subplots(figsize=(8, 5))


def per_capita_rate(N, r=1, a=1, b=1):
    """Per-capita growth rate (dN/dt) / N = r - a * (N - b)**2.

    Evaluated in closed form rather than as logistic_generalised(N) / N so
    that a grid point at (or numerically next to) N = 0 does not produce
    0/0 = nan or a spurious division blow-up.
    """
    return r - a * (N - b) ** 2


# Initial curve: the per-capita rate for the default parameters, matching
# the y-axis label.
line, = ax.plot(x, per_capita_rate(x))
ax.set_xlabel('N')
ax.set_ylabel('dN/dt / N')

# Zero line across exactly the plotted N range.
ax.hlines(0, xmin=x[0], xmax=x[-1], color='lightgrey')


def update1(r, a, b):
    """Redraw the per-capita growth-rate curve for the slider values r, a, b."""
    line.set_ydata(per_capita_rate(x, r=r, a=a, b=b))
    fig.canvas.draw_idle()


slid1 = interactive(update1,
                    r=widgets.FloatSlider(min=-1, max=1, step=0.01, value=1),
                    a=widgets.FloatSlider(min=-1, max=1, step=0.01, value=1),
                    b=widgets.FloatSlider(min=-1, max=1, step=0.01, value=1))
display(slid1)

# Plotting full equation

This now plots the full right-hand side: $\dot{N} = N\,(r - a\,(N - b)^2)$.

In [None]:
x = np.linspace(-0.2, 2, 100)

fig, ax = plt.subplots(figsize=(8, 5))

# Start from the default parameters; the sliders below redraw the curve.
line, = ax.plot(x, logistic_generalised(x))
ax.set_xlabel('N')
ax.set_ylabel('dN/dt')

# Horizontal reference at dN/dt = 0 spanning exactly the sampled N range;
# the curve's crossings with it are the equilibria.
ax.hlines(0, xmin=x[0], xmax=x[-1], color='lightgrey')


def update2(r, a, b):
    """Recompute dN/dt on the fixed grid ``x`` and refresh the figure."""
    rhs = logistic_generalised(x, r=r, a=a, b=b)
    line.set_ydata(rhs)
    fig.canvas.draw_idle()


# (min, max) of each slider; every slider starts at 1 with a 0.01 step.
slider_ranges = {'r': (-1, 2), 'a': (-1, 2), 'b': (-2, 2)}
slid2 = interactive(update2, **{
    name: widgets.FloatSlider(min=lo, max=hi, step=0.01, value=1)
    for name, (lo, hi) in slider_ranges.items()
})
display(slid2)

## Plotting the phase space
This plots 20 trajectories starting at $t=0$, with initial values $N(t=0)$ evenly spaced in $[-0.2, 2.2]$, and shows how they evolve in time. The colours of the trajectories have no meaning; they only help keep them separate. The equilibria are shown as grey lines, with unstable equilibria dotted.

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))
ax.set_xlabel('t')
ax.set_ylabel('N')

t = np.linspace(0, 10, 100)

# Initial conditions N(0); one colour per trajectory, only to tell them apart.
values = np.linspace(-0.2, 2.2, 20)
vcolors = plt.cm.autumn_r(np.linspace(0.3, 1., len(values)))

# Plot trajectories for the default parameters r = a = b = 1.
lines = list()
for v, col in zip(values, vcolors):
    X = odeint(logistic_generalised, v, t, args=(1, 1, 1))
    line1, = ax.plot(t, X, color=col)
    lines.append(line1)

# Plot equilibria as horizontal grey lines (styled by stability in update3).
lines_eq = list()
eq = calculate_equlibria()
for q in eq:
    line1, = ax.plot(t, q * np.ones(len(t)), color='lightgrey')
    lines_eq.append(line1)


def update3(r, a, b):
    """Re-integrate every trajectory and redraw the equilibria for r, a, b.

    Prints f'(0) = r - a*b**2, whose sign decides the stability of the
    trivial equilibrium N = 0. Each equilibrium line is dotted when the
    derivative of the right-hand side there is positive (unstable) and
    solid otherwise.
    """
    print("r-ab^2: " + str(r - a * b ** 2))
    for traj_line, v in zip(lines, values):
        traj_line.set_ydata(odeint(logistic_generalised, v, t, args=(r, a, b)))

    eq = calculate_equlibria(r, a, b)
    stab = calculate_stability(eq, r, a, b)
    for eq_line, q, s in zip(lines_eq, eq, stab):
        eq_line.set_ydata(q * np.ones(len(t)))
        eq_line.set_linestyle(":" if s > 0 else "solid")

    fig.canvas.draw_idle()


slid3 = interactive(update3,
                    r=widgets.FloatSlider(min=-1, max=1, step=0.01, value=1),
                    a=widgets.FloatSlider(min=-1, max=1, step=0.01, value=1),
                    b=widgets.FloatSlider(min=-2, max=2, step=0.01, value=1))
display(slid3)