In [65]:
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, '../src/')
import utils as ut
import plotter
import wave2 as w2
import wave as w
import pickle
from scipy.optimize import brentq

In [2]:
def build_assay_wave(id_, r0, D, beta, gamma=1.0):
    """
    Build a single-strain ``w.Vwave`` whose numerical settings are tuned to ``r0``.

    The grid spacing, spatial extent, burn-in durations and back-width fraction
    are obtained from ``ut.lin_from_two_points`` using hand-picked anchor values
    at r0 = 0.01 and r0 = 1 (the spacing and back width are anchored in log(r0)).
    NOTE(review): lin_from_two_points is presumably a linear interpolation through
    the two anchors — confirm in utils.
    The time step comes from ``ut.dt_from_cfl`` with coefficient 0.03.

    Returns the constructed (not yet run) ``w.Vwave``.
    """
    log_r0, log_lo, log_hi = np.log(r0), np.log(0.01), np.log(1)

    dx = ut.lin_from_two_points(log_r0, log_lo, 0.0005, log_hi, 0.005)
    dt = ut.dt_from_cfl(0.03, D, dx)

    x_extent = ut.lin_from_two_points(r0, 0.01, 3, 1, 50)
    burn_time = ut.lin_from_two_points(r0, 0.01, 200, 1, 5000)
    burn_cutoff = ut.lin_from_two_points(r0, 0.01, 50, 1, 1000)
    back_width = ut.lin_from_two_points(log_r0, log_lo, 0.5, log_hi, 3)

    pars = w.Vwave_pars(
        id_,
        tot_time=200, dt=dt, dx=dx,
        n_x_bins=int(x_extent / dx),
        M=5, r0=r0, D_coef=D, beta=beta, alpha=0.0, gamma=gamma,
        Nh=10**12, N0=10**10,
        t_burn=burn_time,
        t_burn_cutoff=burn_cutoff,
        back_width_fract=back_width,
        traj_step=int(5 / dt),
        check_step=int(100 / dt),
        traj_after_burn=True,
        verbose=False,
    )
    return w.Vwave(pars)

def build_assay_wave2(id_, r0, D, Dm, beta, beta_m, gamma=1.0, gamma_m=1.0, eps=1e-3, time=400, init_n=None, init_nh=None):
    """
    Build two competing waves (resident + mutant) as a ``w2.Vwave2``, with
    numerical settings tuned to ``r0``.

    Parameters
    ----------
    id_ : identifier forwarded to ``w2.Vwave2_pars``.
    r0 : controls grid spacing, spatial extent, burn-in and back-width fraction
        through ``ut.lin_from_two_points`` anchored at r0 = 0.01 and r0 = 1.
    D, beta, gamma : resident ``D_coef``, ``beta`` and ``gamma``.
    Dm, beta_m, gamma_m : mutant ``D_coef_m``, ``beta_m`` and ``gamma_m``.
    eps : forwarded to ``w2.Vwave2_pars`` as ``eps``.
    time : total simulated time (``tot_time``).
    init_n, init_nh : optional initial state, e.g. ``n2`` / ``nh`` of a previous
        run. When ``init_n`` is a numpy array the burn-in is disabled
        (``t_burn = t_burn_cutoff = 0``) so a previous run can be continued.

    Returns
    -------
    w2.Vwave2
        The constructed (not yet run) competition.
    """
    dx = ut.lin_from_two_points(np.log(r0), np.log(0.01), 0.0005, np.log(1), 0.005)
    # NOTE(review): dt is derived from the resident D only; if Dm > D the
    # CFL-based step may be too large for the mutant — confirm.
    dt = ut.dt_from_cfl(0.03, D, dx)

    if isinstance(init_n, np.ndarray):
        # Continuing from a given state: no burn-in.
        t_burn, t_burn_cutoff = 0, 0
    else:
        t_burn = ut.lin_from_two_points(r0, 0.01, 100, 1, 5000)
        t_burn_cutoff = ut.lin_from_two_points(r0, 0.01, 50, 1, 1000)

    p = w2.Vwave2_pars(id_, tot_time=time, dt=dt, dx=dx,
                       n_x_bins=int(ut.lin_from_two_points(r0, 0.01, 3, 1, 50)/dx),
                       M=5, r0=r0,
                       D_coef=D, beta=beta, alpha=0.0, gamma=gamma,
                       # Use the function's own mutant parameters (previously the
                       # global D_m and the resident gamma leaked in here).
                       D_coef_m=Dm, beta_m=beta_m, alpha_m=0.0, gamma_m=gamma_m,
                       is_flux=False, eps=eps,
                       Nh=10**12, N0=10**10, t_burn=t_burn, t_burn_cutoff=t_burn_cutoff,
                       back_width_fract=ut.lin_from_two_points(np.log(r0), np.log(0.01), 0.5, np.log(1), 3),
                       traj_step=int(5/dt), check_step=-1, traj_after_burn=True, verbose=False)
    return w2.Vwave2(p, init_n=init_n, init_nh=init_nh)

def compute_obs(wave):
    """
    Average the trajectory observables of a finished wave run.

    Each series is averaged over all recorded points except the last two.
    ``speed`` is taken from ``wave.traj.speed(10)``.

    Returns
    -------
    tuple
        (speed, f_tip, s_tip, P_tip, dP_tip) means.
    """
    traj = wave.traj
    series = (traj.speed(10), traj.f_tip, traj.s_tip, traj.P_tip, traj.dP_tip)
    return tuple(np.mean(values[:-2]) for values in series)

## Finding the critical mutant beta (invasion threshold) by bisection

In [3]:
class monotonic_bisection:
    """
    Step-and-halve search for the value of beta at which ``function`` crosses ``threshold``.

    Used here to find the mutant beta at the boundary between invasion and
    non-invasion. ``function`` is assumed to be monotonically non-decreasing in
    beta and to lie below ``threshold`` at ``start_val``. The search marches
    forward by ``start_delta`` until the threshold is reached. It then halves
    the step, moving back and forth, until the step is smaller than ``precision``.

    Parameters
    ----------
    start_val : float
        Initial beta; the function must be below ``threshold`` there.
    start_delta : float
        Initial step size.
    precision : float
        The search stops once the (halved) step is smaller than this.
    function : callable
        ``function(beta) -> float``; each call may be expensive.
    threshold : float, optional
        Crossing level (default 0.5).
    max_iter : int, optional
        Maximum number of loop iterations (default 50).

    After ``run()``, ``beta`` and ``f`` hold the last evaluated point and value.
    ``final_beta`` holds the estimated crossing, or NaN when none could be produced.
    """

    def __init__(self, start_val, start_delta, precision, function, threshold=0.5, max_iter=50):
        self.start_val = start_val
        self.start_delta = start_delta
        self.precision = precision
        self.function = function
        self.threshold = threshold
        self.max_iter = max_iter

    def run(self):
        """
        Run the search and return ``self`` so calls can be chained.

        ``final_beta`` is always set. It is NaN if the function is already at or
        above the threshold at ``start_val``, or if ``max_iter`` is exceeded.
        """
        # Defined up front so callers reading m.final_beta never hit AttributeError.
        self.final_beta = float('nan')

        self.beta = self.start_val
        self.f = self.function(self.beta)
        # Same >= test as in the loop: a start value already at the threshold
        # cannot be bracketed from below.
        if self.f >= self.threshold:
            print('Bisection initialized where the function is larger than the threshold')
            return self

        delta = self.start_delta
        self.beta += delta
        count = 0
        was_backward = False  # last move was a step back from an above-threshold point

        while True:
            self.f = self.function(self.beta)

            if self.f >= self.threshold:
                # At/above threshold: after halving, the crossing lies in [beta - 2*delta, beta].
                delta *= 0.5
                if abs(delta) < self.precision:
                    # NOTE(review): this is the quarter point of the final bracket,
                    # not its midpoint (beta - delta) — confirm intended.
                    self.final_beta = self.beta - 3*delta/2.0
                    break
                self.beta -= delta
                was_backward = True
            else:
                if was_backward:
                    # Below after a backward step: after halving, the bracket is [beta, beta + 2*delta].
                    delta *= 0.5
                    if abs(delta) < self.precision:
                        self.final_beta = self.beta + delta/2.0
                        break
                self.beta += delta
                was_backward = False

            count += 1
            if count > self.max_iter:
                print('Bisection reached max iterations')
                return self

        return self

def get_mutant_final_fract(r0, D_res, D_m, beta_res, b_m, eps, max_iterations=30):
    """
    Run a resident-vs-mutant competition and return the mutant's final fraction.

    A first competition with ``time=100`` is run. While both the mutant and the
    resident stay above ``eps/10`` of the total virus count, the simulation is
    continued from its last state (``n2``, ``nh``) with ``time=200``. It stops
    once one population drops below that cutoff (successful or failed invasion)
    or after more than ``max_iterations`` continuations, in which case a
    message is printed.

    Returns
    -------
    float
        ``n_mut / nvirus`` at the last recorded point of the final run.
    """
    def mutant_fraction(competition):
        traj = competition.traj
        return traj.n_mut[-1]/traj.nvirus[-1]

    def undecided(competition):
        # Neither strain has fallen below eps/10 of the total yet.
        traj = competition.traj
        cutoff = traj.nvirus[-1]*eps/10
        return traj.n_mut[-1] > cutoff and traj.n_res[-1] > cutoff

    competition = build_assay_wave2(0, r0, D_res, D_m, beta_res, b_m, time=100, eps=eps)
    competition.run()

    n_continuations = 0
    while undecided(competition):
        competition = build_assay_wave2(0, r0, D_res, D_m, beta_res, b_m, time=200, eps=eps,
                                        init_n=competition.n2, init_nh=competition.nh)
        competition.run()

        n_continuations += 1
        if n_continuations > max_iterations:
            print('Too many iterations', mutant_fraction(competition))
            break

    return mutant_fraction(competition)

In [75]:
# Resident-strain parameters.
r0 = 0.02
beta_res = 2
D_res = 5*1e-6
# Forwarded as `eps` to the competition (also sets the eps/10 decision cutoff).
eps = 0.05
# Mutant D_coef values to sweep: 30 evenly spaced points between 4e-6 and 6e-6.
D_m_list = np.linspace(4, 6, 30)*1e-6

### Running an isolated wave of the resident to compute its properties

In [76]:
# Simulate the resident strain alone; speed_res, f_tip_res and s_tip_res are
# later passed to `ineq`.
wave = build_assay_wave(id_=0, r0=r0, D=D_res, beta=beta_res)
wave.run()
(speed_res, f_tip_res, s_tip_res,
 P_tip, dP_tip) = compute_obs(wave)

### Running the competition

In [None]:
# For each mutant D_m, bisect on the mutant beta to find where the mutant's final
# fraction crosses the bisection threshold (0.5 by default). Expensive.
beta_zeros = []
for D_m in D_m_list:
    # Bind the current D_m explicitly instead of relying on the closure's late binding.
    f = lambda x, D_m=D_m: get_mutant_final_fract(r0, D_res, D_m, beta_res, x, eps)
    m = monotonic_bisection(1.902314, 0.05, 0.001, f).run()
    # run() can return without a final_beta (start already above threshold or
    # max iterations reached); record NaN instead of aborting the whole sweep.
    beta_crit = getattr(m, 'final_beta', float('nan'))
    beta_zeros.append(beta_crit)
    print(D_m, beta_crit)

In [19]:
# Cache the sweep results so the expensive competition cell need not be re-run.
# Uncomment to (re)write the pickle that the loading cell reads back.
#f = open('data/invasion_diagram_r0=%g.pickle'%r0, 'wb')
#pickle.dump(beta_zeros, f)
#f.close()

In [77]:
# Load previously computed sweep results instead of re-running the competitions.
# NOTE: pickle.load can execute arbitrary code — only load files this notebook wrote.
with open('data/invasion_diagram_r0=%g.pickle' % r0, 'rb') as f:
    beta_zeros = pickle.load(f)

### Finding the critical beta through the equation

In [78]:
def ineq(D_m, b_res, b_m, f_tip_res, s_tip_res, v, alpha=0, gamma=1):
    """
    Theoretical condition whose root in ``b_m`` gives the critical mutant beta.

    The resident tip quantities are rescaled by ``b_m / b_res``::

        f_m = (b_m/b_res) * (f_tip_res + alpha + gamma) - alpha - gamma
        s_m = (b_m/b_res) * s_tip_res

    and the function returns ``f_m - v**2/(4 D_m) - 2.3381 * (D_m s_m**2)**(1/3)``,
    where ``v`` is the resident speed.
    NOTE(review): 2.3381 presumably is |a1|, the magnitude of the first zero of
    the Airy function Ai — confirm.
    """
    scale = b_m/b_res
    f_m = scale*(f_tip_res + alpha + gamma) - alpha - gamma
    s_m = scale*s_tip_res
    diffusive_term = v**2/4/D_m
    airy_term = 2.3381*(D_m*s_m**2)**(1/3)
    return f_m - diffusive_term - airy_term

In [79]:
# Theoretical critical mutant beta: the root of `ineq` in b_m for each mutant D_m
# with a simulated counterpart. brentq requires a sign change on [1.7, 2.3].
# Collect into a list and convert once (np.append in a loop copies every time).
b_m_roots = []
for D_m in D_m_list[:len(beta_zeros)]:
    f = lambda x, D_m=D_m: ineq(D_m, beta_res, x, f_tip_res, s_tip_res, speed_res)
    b_m_roots.append(brentq(f, 1.7, 2.3))
b_m_list = np.array(b_m_roots, dtype=float)

In [80]:
# Write one tab-separated row per simulated mutant D_m: the simulated critical
# beta and the theoretical one. Rows match the 3-column header (no trailing tab),
# and `with` guarantees the file is closed even if a write fails.
header = '#D_mutant\tbeta_critical\tbeta_theory\n'
with open('data/invasion_diagram_r0=%g.tsv' % r0, 'w') as f:
    f.write(header)
    for D_m_val, beta_sim, beta_th in zip(D_m_list, beta_zeros, b_m_list):
        f.write('\t'.join((str(D_m_val), str(beta_sim), str(beta_th))) + '\n')