In [None]:
import matplotlib.pyplot as plt
import numpy as np
import hjb
import hjb_nocost
from matplotlib import animation
from IPython.display import HTML

from scipy.interpolate import interpn
from scipy.integrate import solve_ivp

In [None]:
# Number of grid cells in each direction; reused by later cells to size arrays.
mx = my = 100

# Build the problem through the local `hjb` module, using the SharpClaw solver.
claw = hjb.setup(solver_type='sharpclaw', mx=mx, my=my, min_fac=0.4)

# Output schedule: later cells read the stored snapshots from claw.frames.
claw.num_output_times = 200
claw.tfinal = 100.

# Solver options (all set before the run starts).
claw.solver.cfl_desired = 1.48
claw.solver.cfl_max = 1.49
claw.solver.lim_type = 1
#claw.solver.order = 2
claw.solver.call_before_step_each_stage = True

claw.run()

In [None]:
# Single-frame snapshot of aux component 2 (called sigma in later cells).
# X, Y are the cell-center coordinates of the solver grid and are reused by
# the plotting cells that follow.
X, Y = claw.frames[0].grid.p_centers

i = 20
switch = claw.frames[i].aux[2, :, :]

plt.pcolor(X, Y, switch, shading='auto')
plt.colorbar()

In [None]:
# Two-panel animation canvas: u = q[0] on the left, sigma = aux[2] on the right.
#
# The color limits are defined once so that the colorbars, the initial draw and
# every animated frame share the same scale.  (Previously the left colorbar was
# built with vmin=-0.25 while the animated frames used vmin=-0.4, so the
# colorbar did not describe what the animation showed.)
U_CLIM = (-0.4, 0.)
SIGMA_CLIM = (0., 3.)

fig, axes = plt.subplots(1,2,figsize=(12,8))
u = claw.frames[0].q[0,:,:]
sigma = claw.frames[0].aux[2,:,:]

pc = axes[0].pcolor(X,Y,u,shading='auto', vmin=U_CLIM[0],vmax=U_CLIM[1])
plt.colorbar(pc, ax=axes[0])
axes[0].axis('image')
pcs = axes[1].pcolor(X,Y,sigma,shading='auto', vmin=SIGMA_CLIM[0],vmax=SIGMA_CLIM[1])
plt.colorbar(pcs, ax=axes[1])
axes[1].axis('image')

# Artists drawn by the most recent plot_frame call.  They are removed before
# the next frame is drawn so collections do not pile up on the axes (which
# would make each successive frame slower to render).  The initial `pc`/`pcs`
# artists are left in place because the colorbars are attached to them.
_frame_artists = []

def plot_frame(i):
    """Draw output frame ``i`` of ``claw`` onto the two animation panels.

    Parameters
    ----------
    i : int
        Index into ``claw.frames``.

    Returns
    -------
    None
        The axes in ``axes`` are updated in place.
    """
    # Discard what the previous call drew.
    while _frame_artists:
        _frame_artists.pop().remove()

    u = claw.frames[i].q[0,:,:]
    sigma = claw.frames[i].aux[2,:,:]
    _frame_artists.append(
        axes[0].pcolor(X,Y,u,shading='auto',vmin=U_CLIM[0],vmax=U_CLIM[1]))
    _frame_artists.append(
        axes[1].pcolor(X,Y,sigma,shading='auto',vmin=SIGMA_CLIM[0],vmax=SIGMA_CLIM[1]))

In [None]:
# Step plot_frame through every stored frame (an integer `frames` is treated
# by FuncAnimation as range(frames)) and embed the result as a JS/HTML player.
anim = animation.FuncAnimation(fig, plot_frame, frames=len(claw.frames))
HTML(anim.to_jshtml())

# Forward solve with optimal control

In [None]:
# Gather aux component 2 and the output time of each stored frame into dense
# arrays, laid out as sigmaopt[ix, iy, frame] and ttt[frame].
n_out = claw.num_output_times + 1
ttt = np.zeros(n_out)
sigmaopt = np.zeros((mx, my, n_out))
for i in range(n_out):
    frame = claw.frames[i]
    ttt[i] = frame.t
    sigmaopt[..., i] = frame.aux[2, :, :]

In [None]:
# Cell-center coordinates along each grid dimension; these serve as the
# spatial interpolation axes for the forward solve.
x, y = claw.grid.x.centers, claw.grid.y.centers

In [None]:
# Time horizon (taken from the solver run) and the model constants used by
# the forward solve and the phase-plane plot.
T = claw.tfinal
gamma, beta = 0.1, 0.3

def rhs(t, v):
    """Right-hand side of the controlled two-component system.

    Parameters
    ----------
    t : float
        Current time of the forward integration.
    v : array-like of length 2
        State ``(x, y)``.

    Returns
    -------
    numpy.ndarray
        Length-2 array ``(dx/dt, dy/dt)``.

    Notes
    -----
    The control ``sigma`` is read from the tabulated ``sigmaopt`` by linear
    interpolation at ``(x, y, T - t)``.  When ``y`` is below the smallest grid
    value in ``y`` the table is not consulted and ``sigma`` is taken as zero.
    """
    if v[1] >= np.min(y):
        table_axes = (x.squeeze(), y.squeeze(), ttt.squeeze())
        sigma = interpn(table_axes, sigmaopt, (v[0], v[1], T-t))
    else:
        sigma = 0.

    # Term that leaves the x equation and enters the y equation.
    transfer = sigma*gamma*v[1]*v[0]

    dv = np.zeros(2)
    dv[0] = -transfer
    dv[1] = transfer - gamma*v[1]
    return dv

# Initial state for the forward solve.
x0 = 0.99
y0 = 0.01 # Initial infected
v0 = np.array([x0, y0])

# Report the solution at t = 0, 1, 2, ... (strictly below T).
times = np.arange(0,T)

# max_step keeps the integrator from stepping far past changes in the
# interpolated control.
solution = solve_ivp(
    rhs,
    [0,T],
    v0,
    method='RK23',
    t_eval=times,
    rtol=1.e-3,
    atol=1.e-3,
    max_step=5e-2,
)

In [None]:
# Time histories of the two state components from the forward solve.
xsol = solution.y[0,:]
ysol = solution.y[1,:]

plt.figure(figsize=(12,6))
# Plot against solution.t (the times actually returned) so the cell still
# works if the integrator stopped before reaching every requested time.
plt.plot(solution.t,xsol,lw=3,label='x')
plt.plot(solution.t,ysol,lw=3,label='y')
plt.xlabel('t')
# Only x and y are drawn here; the former third legend entry
# ('sigma/sigma0') had no matching curve and was silently ignored.
plt.legend();

In [None]:
# Phase-plane view: streamlines of the field (-beta*x*y, beta*x*y - gamma*y)
# with the computed forward trajectory drawn on top.
sigma0 = beta/gamma
N1, N2 = 10, 5

# NOTE: this rebinds the notebook-level X and Y (set earlier from the solver
# grid) to a uniform 100x100 mesh on the unit square.
Y, X = np.mgrid[0:1:100j, 0:1:100j]
V = beta*X*Y - gamma*Y
U = -beta*X*Y

# Streamline seeds: N1 points along the line x + y = 1, followed by N2 points
# just above the x-axis between x = 1/sigma0 and x = 1.
x_points = [*np.linspace(0,1,N1), *np.linspace(1./sigma0,1,N2)]
y_points = [*(1.-np.linspace(0,1,N1)), *([1.e-6]*N2)]
seed_points = np.array([x_points, y_points])

plt.figure(figsize=(6,6))
plt.streamplot(
    X, Y, U, V,
    start_points=seed_points.T,
    integration_direction='forward',
    maxlength=1000,
    broken_streamlines=False,
    linewidth=1,
)
# Line x + y = 1.
plt.plot([0,1],[1,0],'-k',alpha=0.5)
# Trajectory from the forward solve.
plt.plot(xsol,ysol,'-k')
# Vertical dashed line at x = gamma/beta, drawn up to the line x + y = 1.
plt.plot([gamma/beta, gamma/beta],[0,1-gamma/beta],'--k',alpha=0.5)
plt.xlim(0,1)
plt.ylim(0,1)
plt.xlabel('x')
plt.ylabel('y');