In [1]:
import random
import hydra
import torch
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from collections import deque
from scipy.integrate import odeint
from dqn_agent import Agent
from omegaconf import DictConfig, OmegaConf

In [2]:
def sir(y, t, beta, gamma, u):
    S, I = y
    dydt = np.array([-beta * S * I - u * S, beta * S * I - gamma * I])
    return dydt

class SirEnvironment:
    def __init__(self, S0=990, I0=10):
        self.state = np.array([S0, I0])
        self.beta = 0.002
        self.gamma = 0.5

    def reset(self, S0=990, I0=10):
        self.state = np.array([S0, I0])
        self.beta = 0.002
        self.gamma = 0.5
        return self.state

    def step(self, action):
        sol = odeint(sir, self.state, np.linspace(0, 1, 101), args=(self.beta, self.gamma, action))
        new_state = sol[-1, :]
        S0, I0 = self.state
        S, I = new_state
        self.state = new_state
        reward = - I - action
        done = True if new_state[1] < 1.0 else False
        return (new_state, reward, False, 0)

In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [7]:
# 2. Train DQN Agent
env = SirEnvironment()
agent = Agent(state_size=2, action_size=2, seed=0)

S = np.linspace(0, 1000, 101)
I = np.linspace(0, 1000, 101)

SS, II = np.meshgrid(S, I)

vf = np.zeros((len(I), len(S)))
af = np.zeros((len(I), len(S)))

for si, s in enumerate(S):
    for ii, i in enumerate(I):
        v = agent.qnetwork_local.forward(torch.tensor([float(s), float(i)]).to(device))
        v = v.detach().cpu().numpy()
        vf[si, ii] = np.max(v)
        af[si, ii] = np.argmax(v)

vf[SS + II > 1000] = None
af[SS + II > 1000] = None

In [8]:
import plotly.graph_objects as go

fig = go.Figure(data =
    go.Contour(
        z=-vf,
        x=S,
        y=I
    ))
fig.show()

In [9]:
fig = go.Figure(data =
    go.Contour(
        z=af,
        x=S,
        y=I
    ))
fig.show()

In [10]:

## Parameters
n_episodes=1000
max_t=30
eps_start=1.0
eps_end= 0.0000
eps_decay=0.995

## Loop to learn
scores = []                        # list containing scores from each episode
scores_window = deque(maxlen=100)  # last 100 scores
eps = eps_start                    # initialize epsilon
for i_episode in range(1, n_episodes+1):
    state = env.reset()
    score = 0
    actions = []
    for t in range(max_t):
        action = agent.act(state, eps)
        actions.append(action)
        next_state, reward, done, _ = env.step(action)
        agent.step(state, action, reward, next_state, done)
        state = next_state
        score += reward
        if done:
            break
    scores_window.append(score)       # save most recent score
    scores.append(score)              # save most recent score
    eps = max(eps_end, eps_decay*eps) # decrease epsilon
    print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")

torch.save(agent.qnetwork_local.state_dict(), 'checkpoint.pth')

Episode 1000	Average Score: -82.46

In [11]:
# 3. Visualize Controlled SIR Dynamics
agent.qnetwork_local.load_state_dict(torch.load('checkpoint.pth'))
env = SirEnvironment()
state = env.reset()
max_t = 30
states = state
reward_sum = 0.
actions = []
for t in range(max_t):
    action = agent.act(state, eps=0.0)
    actions = np.append(actions, action)
    next_state, reward, done, _ = env.step(action)
    reward_sum += reward
    states = np.vstack((states, next_state))
    state = next_state

In [12]:
# Create figure with secondary y-axis
fig = make_subplots(specs=[[{"secondary_y": True}]])
# Add traces
fig.add_trace(
    go.Scatter(x=list(range(max_t+1)), y=states[:,0].flatten(), name="susceptible",
        mode='lines+markers'),
    secondary_y=False,
)
fig.add_trace(
    go.Scatter(x=list(range(max_t+1)), y=states[:,1].flatten(), name="infected",
        mode='lines+markers'),
    secondary_y=False,
)
fig.add_trace(
    go.Scatter(x=list(range(max_t+1)), y=actions, name="vaccine",
        mode='lines+markers'),
    secondary_y=True,
)
# Add figure title
fig.update_layout(
    title_text=f'{reward_sum:.2f}: SIR model with control'
)
# Set x-axis title
fig.update_xaxes(title_text="day")
# Set y-axes titles
fig.update_yaxes(title_text="Population", secondary_y=False)
fig.update_yaxes(title_text="Vaccine", secondary_y=True)

In [14]:
S = np.linspace(0, 1000, 101)
I = np.linspace(0, 1000, 101)

SS, II = np.meshgrid(S, I)

vf = np.zeros((len(I), len(S)))
af = np.zeros((len(I), len(S)))

for si, s in enumerate(S):
    for ii, i in enumerate(I):
        v = agent.qnetwork_local.forward(torch.tensor([float(s), float(i)]).to(device))
        v = v.detach().cpu().numpy()
        vf[si, ii] = np.max(v)
        af[si, ii] = np.argmax(v)

vf[SS + II > 1000] = None
af[SS + II > 1000] = None

In [15]:
import plotly.graph_objects as go

fig = go.Figure(data =
    go.Contour(
        z=-vf,
        x=S,
        y=I
    ))
fig.show()

In [16]:
fig = go.Figure(data =
    go.Contour(
        z=af,
        x=S,
        y=I
    ))
fig.show()