# Set up Google Drive

Authorize access to Google Drive.

In [None]:
# Mount Google Drive so the notebook can read its input data and write its outputs.
from google.colab import drive
drive.mount('/content/gdrive')
# Directory prefix (note the trailing slash) that later cells concatenate file names onto.
drive_path = '/content/gdrive/My Drive/singapore/'
# When not running on Colab, comment out the import and mount lines above
# and set drive_path to the local directory that contains the data.

# Install library

Install the covasim, optuna, kaleido, and plotly libraries.

In [None]:
!pip install covasim -q
!pip install optuna -q
!pip install kaleido -q
!pip install plotly==5.3.1 -q

In [None]:
import covasim as cv
import numpy as np
import pandas as pd
import sciris as sc
import optuna as ot
import multiprocessing as mp
import matplotlib.pyplot as plt
import math
import subprocess
import plotly.io as pio

# Install MySQL

In [None]:
!apt-get install mysql-server > /dev/null

In [None]:
!pip install mysql-python -q
!pip install mysqlclient -q

# Prepare data for covasim

In [None]:
# Scale factor applied to every count (and to the population size in later cells).
reduction = 0.01
end_date = pd.to_datetime('2021-10-25')

sg_data = pd.read_csv(drive_path + 'sg_final.csv', parse_dates=['Date'])

# Keep only the columns used downstream and switch to the short names right away.
column_names = {'Date': 'date',
                'Daily Confirmed': 'new_diagnoses',
                'Daily Deaths': 'new_deaths',
                'residential': 'mobility',
                'facial_coverings': 'face_cover',
                'Cumulative Vaccine Doses': 'vaccine'}
sg_for_covasim = sg_data[list(column_names)].copy().rename(columns=column_names)

# Diagnoses exclude imported cases; both count columns are scaled and truncated to integers.
local_cases = sg_data['Daily Confirmed'] - sg_data['Daily Imported']
sg_for_covasim['new_diagnoses'] = (reduction * local_cases).astype('int64')
sg_for_covasim['new_deaths'] = (reduction * sg_for_covasim['new_deaths']).astype('int64')

# Both behavioural indicators are stored with their sign flipped.
sg_for_covasim['mobility'] = -sg_for_covasim['mobility']
sg_for_covasim['face_cover'] = -sg_for_covasim['face_cover'].astype('float')

# Cumulative doses are scaled and rounded, then converted to per-day doses:
# the first row keeps its cumulative value, every later row is the day-over-day difference.
cumulative_doses = np.round(reduction * sg_for_covasim['vaccine']).astype('int64')
sg_for_covasim['vaccine'] = cumulative_doses.diff().fillna(cumulative_doses.iloc[0]).astype('int64')

# The simulated period runs from the first day with at least one scaled dose to end_date (inclusive).
start_date = sg_for_covasim.loc[sg_for_covasim['vaccine'] >= 1, 'date'].min()
mask = sg_for_covasim['date'].between(start_date, end_date)

datafile = sg_for_covasim[mask]

# Define interventions

In [None]:
class Interventions:
    """Collects the Covasim interventions used by one simulation.

    Every ``add_*`` method builds a single Covasim intervention and appends it
    to ``self.interventions``.
    """

    def __init__(self):
        self.interventions = []

    def add_mobility(self, date:pd.Series, mobility:pd.Series, mo_max=1.0, mo_min=0.0,
                     mask:pd.Series=None, ma_max=1.0, ma_min=0.0):
        """Append a ``cv.change_beta`` intervention driven by mobility (and optionally masks).

        Args:
            date: days on which the transmission multiplier changes.
            mobility: one value per day, linearly rescaled to [mo_min, mo_max].
            mo_max, mo_min: target range for the rescaled mobility.
            mask: optional second series, rescaled to [ma_min, ma_max] and combined
                with mobility through their harmonic mean.
            ma_max, ma_min: target range for the rescaled mask series.
        """
        mobility = self._normalize(mobility, mo_max, mo_min)
        if mask is not None:
            mask = self._normalize(mask, ma_max, ma_min)
            total = mask + mobility
            # Harmonic mean 2ab / (a + b). A day on which both factors are 0 would
            # evaluate to 0/0 = NaN, so such days are defined as 0.
            mobility = (2 * mask * mobility / total).where(total != 0, 0.0)
        self.interventions.append(cv.change_beta(date.to_list(), mobility.to_list()))

    def add_covid_test(self, symp_prob=0.95, asymp_prob=0.0, test_delay=1):
        """Append a ``cv.test_prob`` intervention; quarantined people are always tested."""
        self.interventions.append(cv.test_prob(symp_prob=symp_prob, asymp_prob=asymp_prob,
                                               symp_quar_prob=1.0, asymp_quar_prob=1.0,
                                               test_delay=test_delay))

    def add_contact_tracing(self, trace_probs=0.8):
        """Append a ``cv.contact_tracing`` intervention with the given tracing probability."""
        self.interventions.append(cv.contact_tracing(trace_probs=trace_probs))

    def add_vaccine(self, date:pd.Series, dose:pd.Series):
        """Append a ``cv.vaccinate_num`` intervention ('pfizer') with a per-day dose count.

        Args:
            date: vaccination days.
            dose: number of doses for the matching entry of ``date``.
        """
        dose = dict(zip(date, dose))

        def age_sequence(people):
            # Indices sorted by descending age, i.e. the oldest people come first.
            return np.argsort(-people.age)

        self.interventions.append(cv.vaccinate_num(vaccine='pfizer', num_doses=dose, sequence=age_sequence))

    def _normalize(self, data:pd.Series, n_max, n_min):
        """Linearly rescale ``data`` so its minimum maps to ``n_min`` and its maximum to ``n_max``.

        A constant input has no range to stretch, so every element becomes ``n_min``.
        The result keeps the length (and, for a Series, the index) of the input so
        callers can keep treating it as a per-day series.
        """
        d_min, d_max = np.min(data), np.max(data)
        if d_max == d_min:
            flat = np.full(len(data), float(n_min))
            if isinstance(data, pd.Series):
                return pd.Series(flat, index=data.index)
            return flat
        k = (n_max - n_min) / (d_max - d_min)
        b = n_min - k * d_min
        return data * k + b

# Define pruner

Adapted from https://gist.github.com/bfs15/24045ab5e8ad007b4a09f708adfe359f

In [None]:
from typing import Dict, List, Optional
from collections import defaultdict

class ParamRepeatPruner:
    """Prunes repeated trials, so trials with the same parameters won't waste time/resources."""

    def __init__(
        self,
        study: ot.study.Study,
        repeats_max: int = 0,
        should_compare_states: List[ot.trial.TrialState] = [ot.trial.TrialState.COMPLETE],
        compare_unfinished: bool = True,
    ):
        """
        Args:
            study (ot.study.Study): Study of the trials.
            repeats_max (int, optional): Instead of pruning every repeat
                (repeats_max=0) a parameter set may be repeated up to this many
                times, which is useful when the objective is not deterministic.
                Defaults to 0.
            should_compare_states (List[ot.trial.TrialState], optional): Trial
                states a new trial is compared against. By default only COMPLETE
                trials count, so FAIL and PRUNED trials may be repeated. Pass e.g.
                [COMPLETE, FAIL, PRUNED] to skip those too.
            compare_unfinished (bool, optional): When True, unfinished trials
                (e.g. RUNNING) are also compared against. Defaults to True.
        """
        # Copy so that instances never share the default list object.
        self.should_compare_states = list(should_compare_states)
        self.repeats_max = repeats_max
        # past trial number -> numbers of the trials that repeated its parameters
        self.repeats: Dict[int, List[int]] = defaultdict(lambda: [], {})
        # same mapping, for past trials that had not finished when the repeat was seen
        self.unfinished_repeats: Dict[int, List[int]] = defaultdict(lambda: [], {})
        self.compare_unfinished = compare_unfinished
        # Assigned directly rather than through the ``study`` setter, so existing
        # trials are not registered at construction time.
        self._study = study

    @property
    def study(self) -> Optional[ot.study.Study]:
        return self._study

    @study.setter
    def study(self, study):
        self._study = study
        if self._study is not None:
            self.register_existing_trials()

    def register_existing_trials(self):
        """In case of studies with existing trials, it counts existing repeats."""
        trials = self._study.trials
        # Each trial is compared with all trials that precede it: the trial at
        # list position p is checked against trials[:p].
        for position, trial_past in enumerate(trials[1:], start=1):
            self.check_params(trial_past, False, position)

    def prune(self):
        """Check the latest trial of the study and raise ``TrialPruned`` if it is a repeat."""
        self.check_params()

    def should_compare(self, state):
        """Return True when ``state`` is one of the states trials are compared against."""
        return any(state == state_comp for state_comp in self.should_compare_states)

    def clean_unfinised_trials(self):
        """Move repeats recorded against unfinished trials to ``repeats`` once those trials reach a compared state."""
        # Trial numbers are used as list positions here.
        trials = self._study.trials
        finished = []
        for key, value in self.unfinished_repeats.items():
            if self.should_compare(trials[key].state):
                for t in value:
                    self.repeats[key].append(t)
                finished.append(key)

        # Deleted in a second pass because the dict cannot change size while iterating.
        for f in finished:
            del self.unfinished_repeats[f]

    def check_params(
        self,
        trial: Optional[ot.trial.BaseTrial] = None,
        prune_existing=True,
        ignore_last_trial: Optional[int] = None,
    ):
        """Look for an earlier trial with the same parameters as ``trial``.

        Args:
            trial: trial to check; defaults to the last trial of the study.
            prune_existing: when True, raise ``TrialPruned`` once the matched
                parameter set has been repeated more than ``repeats_max`` times.
            ignore_last_trial: slice bound; only ``trials[:ignore_last_trial]``
                are compared against.

        Returns:
            The number of the matched finished trial, or -1 if there is none
            (``None`` when no study is registered).

        Raises:
            ot.exceptions.TrialPruned: see ``prune_existing``.
        """
        if self._study is None:
            return
        trials = self._study.trials
        if trial is None:
            trial = trials[-1]
            ignore_last_trial = -1

        self.clean_unfinised_trials()

        self.repeated_idx = -1
        self.repeated_number = -1
        for idx_p, trial_past in enumerate(trials[:ignore_last_trial]):
            should_compare = self.should_compare(trial_past.state)
            should_compare |= (
                self.compare_unfinished and not trial_past.state.is_finished()
            )
            if should_compare and trial.params == trial_past.params:
                if not trial_past.state.is_finished():
                    # Remember the repeat, but keep searching for a finished match.
                    self.unfinished_repeats[trial_past.number].append(trial.number)
                    continue
                self.repeated_idx = idx_p
                self.repeated_number = trial_past.number
                break

        # Only touch ``repeats`` when a match exists; indexing the defaultdict
        # with -1 would otherwise insert a meaningless entry.
        if self.repeated_number > -1:
            self.repeats[self.repeated_number].append(trial.number)
            if prune_existing and len(self.repeats[self.repeated_number]) > self.repeats_max:
                raise ot.exceptions.TrialPruned()

        return self.repeated_number

    def get_value_of_repeats(
        self, repeated_number: int, func=lambda value_list: np.mean(value_list)
    ):
        """Aggregate the value of a trial with the values of its recorded repeats.

        Args:
            repeated_number: number of the original trial.
            func: reducer applied to the tuple of values; defaults to the mean.

        Raises:
            ValueError: if no study is registered.
        """
        if self._study is None:
            raise ValueError("No study registered.")
        trials = self._study.trials
        values = (
            trials[repeated_number].value,
            *(
                trials[tn].value
                for tn in self.repeats[repeated_number]
                if trials[tn].value is not None
            ),
        )
        return func(values)

# Define calibration

In [None]:
class Calibration:
    """Optuna calibration of the Covasim model, persisted in a local MySQL database.

    The study lives in a MySQL database named after ``study_name`` and is saved to /
    restored from ``drive_path + study_name + '.sql'``. External commands are run
    with argument lists and file handles instead of shell strings, so paths that
    contain spaces or other shell metacharacters need no escaping.

    Relies on the module-level names ``drive_path``, ``reduction``, ``start_date``
    and ``end_date``.
    """

    def __init__(self, n_trials=20, n_workers=mp.cpu_count(),
                 study_name=None, datafile=None):
        self.n_trials = n_trials
        self.n_workers = n_workers
        self.study_name = study_name
        self.storage = 'mysql://root@localhost/{}'.format(self.study_name)
        self._has_calib = False
        self.datafile = datafile
        self.study = None

    def create_calib(self):
        """Start from an empty database, create a new study in it and dump it."""
        self._reset_database()
        self.study = ot.create_study(storage=self.storage, study_name=self.study_name, direction='minimize',
                                     sampler=ot.samplers.TPESampler(seed=42))
        self._dump_database()

    def run_calib(self, n_trials=20, n_workers=mp.cpu_count()):
        """Restore the study, run ``n_trials`` trials in each of ``n_workers`` processes, then dump it."""
        self.load_calib()
        self._run_workers(n_trials=n_trials, n_workers=n_workers)
        self.study = ot.load_study(storage=self.storage, study_name=self.study_name)
        self._dump_database()

    def load_calib(self):
        """Recreate the database from the SQL dump and load the study from it.

        Raises:
            FileNotFoundError: if the dump does not exist. The file is opened
                before the database is dropped, so a missing dump leaves the
                current database untouched.
        """
        with open(self._dump_file(), 'rb') as dump:
            self._reset_database()
            subprocess.run(['mysql', '-u', 'root', self.study_name], stdin=dump)
        self.study = ot.load_study(storage=self.storage, study_name=self.study_name)

    def get_best(self):
        """Return an un-run simulation built from the best trial's parameters."""
        if self.study is None:
            self.load_calib()
        return self._objective(self.study.best_trial, return_sim=True)

    def _dump_file(self):
        """Path of the SQL dump that persists this study."""
        return drive_path + self.study_name + '.sql'

    def _reset_database(self):
        """Start the MySQL service and replace the study database with an empty one."""
        subprocess.run(['service', 'mysql', 'start'])
        for statement in ('DROP DATABASE IF EXISTS `{}`', 'CREATE DATABASE IF NOT EXISTS `{}`'):
            subprocess.run(['mysql', '-u', 'root', '-e', statement.format(self.study_name)])

    def _dump_database(self):
        """Write a mysqldump of the study database to the dump file."""
        with open(self._dump_file(), 'wb') as dump:
            subprocess.run(['mysqldump', '-u', 'root', self.study_name], stdout=dump)

    def _worker(self, n_trials):
        """Run ``n_trials`` trials against the shared storage (executed in a pool process)."""
        ot.logging.set_verbosity(ot.logging.INFO)
        study = ot.load_study(storage=self.storage, study_name=self.study_name)
        pruner_para = ParamRepeatPruner(study)
        output = study.optimize(lambda trial: self._objective(trial, pruner=pruner_para), n_trials=n_trials)
        return output

    def _run_workers(self, n_trials, n_workers):
        """Run ``_worker`` in ``n_workers`` processes, each with its own ``n_trials`` budget."""
        with mp.Pool(n_workers) as pool:
            pool.map(self._worker, [n_trials] * n_workers)
        return None

    def _objective(self, trial, pruner=None, return_sim=False):
        """Build the simulation for ``trial`` and return its mismatch (or the simulation).

        Args:
            trial: source of the ``start_shift`` and ``beta`` values.
            pruner: optional ``ParamRepeatPruner`` consulted before the run.
            return_sim: when left at False the simulation is run and the mismatch
                on ``cum_diagnoses`` is returned; otherwise the un-run simulation
                is returned.
        """
        interventions = Interventions()

        # Calibrated parameters (suggested in this order).
        start_shift = trial.suggest_int("start_shift", low=-21, high=21, step=1)
        beta = trial.suggest_float('beta', low=0.061, high=0.3, step=0.001)

        # Fixed parameters.
        pop_infected = 10
        quar_period = 7  # from https://www.covid.gov.sg/exposed/hrw

        mo_max = 1.0
        mo_min = 0.43
        ma_max = 1.0
        ma_min = 0.13
        interventions.add_mobility(date=self.datafile['date'],
                                   mobility=self.datafile['mobility'],
                                   mo_max=mo_max, mo_min=mo_min,
                                   mask=self.datafile['face_cover'],
                                   ma_max=ma_max, ma_min=ma_min)

        symp_prob = 0.9
        asymp_prob = 0.0
        test_delay = 1
        interventions.add_covid_test(symp_prob=symp_prob, asymp_prob=asymp_prob, test_delay=test_delay)

        trace_probs = 0.79
        interventions.add_contact_tracing(trace_probs=trace_probs)

        interventions.add_vaccine(date=self.datafile['date'], dose=self.datafile['vaccine'])

        # The variant is imported `start_shift` days before 2021-08-24 (after it when negative).
        b16172 = cv.variant('b16172', days=pd.to_datetime('2021-08-24') - pd.DateOffset(days=start_shift), n_imports=pop_infected)

        pars = sc.objdict(
            pop_size = math.ceil(5686000 * reduction),
            pop_infected = 0,
            use_waning = True,
            quar_period = quar_period,
            beta = beta,
            rel_death_prob = 0,
            interventions = interventions.interventions,
            pop_type = 'hybrid',
            location = 'Singapore',
            start_day = start_date,
            end_day = end_date,
        )

        if pruner is not None:
            # Raises ot.exceptions.TrialPruned for a repeated parameter set.
            # NOTE(review): without arguments the pruner inspects the study's last
            # trial; with several workers that may not be this trial - confirm.
            pruner.check_params()

        sim = cv.Sim(pars=pars, variants=[b16172], datafile=self.datafile)
        if return_sim is not False:
            return sim
        sim.run(verbose=-1)
        sim.compute_fit(keys=['cum_diagnoses'])
        return sim.fit.mismatch

# Calibrate

In [None]:
# Study name; Calibration also uses it as the MySQL database name and as the
# base name of the SQL dump, and later cells use it as a prefix for saved figures.
calib_name = 'sg_calib2'
calib = Calibration(study_name=calib_name, datafile=datafile)
# calib.create_calib()  # uncomment this line if you are creating a new calibration or you have changed parameters
# Restore the existing study from its SQL dump under drive_path.
calib.load_calib()

In [None]:
# Keep launching batches until the study holds at least 2000 trials.
# Each batch asks every worker process for 50 trials and dumps the study afterwards.
while len(calib.study.trials) < 2000:
    calib.run_calib(n_trials=50)

# Visualize

In [None]:
import matplotlib.pyplot as plt
def visualize(calib_name):
    """Plot the best calibration result and the parallel-coordinate view of the study.

    Saves ``<calib_name>_result.pdf`` (simulation curves of the best trial) and
    ``<calib_name>_para.pdf`` (parallel coordinates of all trials) under ``drive_path``.
    """
    calibration = Calibration(study_name=calib_name, datafile=datafile)
    best_sim = calibration.get_best()
    study = calibration.study
    print(study.best_params)
    best_value = study.best_value
    print(best_value)

    best_sim.run()
    print(best_sim.summary[["cum_vaccinations", "cum_vaccinated", "n_vaccinated"]])

    # Two stacked panels: cumulative counts on top, daily counts below.
    panels = {
        'Total counts': ['cum_infections', 'cum_diagnoses'],
        'Daily counts': ['new_infections', 'new_diagnoses'],
    }
    result_fig = best_sim.plot(to_plot=panels,
                               n_cols=1,
                               fig_args={'figsize': (7, 7)},
                               mpl_args={'font_size': 12},
                               show_args={'interventions': False},
                               date_args={'interval': 90})
    plt.tight_layout()
    result_fig.savefig(drive_path + calib_name + "_result.pdf", bbox_inches = 'tight')
    plt.close(result_fig)

    # Colour the trials by the log of their distance to the best objective value.
    coord_ax = ot.visualization.matplotlib.plot_parallel_coordinate(
        study,
        target=lambda t: np.log(t.value - best_value + 1),
        target_name='Log Objective Value',
    )
    plt.tight_layout()
    coord_ax.figure.savefig(drive_path + calib_name + '_para.pdf', bbox_inches = 'tight')

visualize(calib_name)

# Predict

In [None]:
def get_future(datafile, predict_days, direction='mean'):
    """Extend ``datafile`` with ``predict_days`` synthetic daily rows.

    Every added row carries the same ``mobility``, ``face_cover`` and ``vaccine``
    values, derived from the most recent rows of ``datafile``; all other columns
    are left missing for the added rows.

    Args:
        datafile: DataFrame with at least the columns ``date``, ``mobility``,
            ``face_cover`` and ``vaccine``, ordered by date.
        predict_days: number of days to append after the last date.
        direction: how the constant values are chosen:
            ``'mean'`` - mean of the last 30 rows for all three columns;
            ``'min'``  - minimum mobility/face_cover and maximum vaccine of the last 60 rows;
            ``'max'``  - maximum mobility/face_cover and minimum vaccine of the last 60 rows.

    Returns:
        A new DataFrame (index reset) with the original rows followed by the added ones.

    Raises:
        ValueError: if ``direction`` is not one of the three supported values.
    """
    # direction -> (window length, aggregate for mobility and face_cover, aggregate for vaccine)
    scenarios = {
        'mean': (30, 'mean', 'mean'),
        'min': (60, 'min', 'max'),
        'max': (60, 'max', 'min'),
    }
    if direction not in scenarios:
        raise ValueError(
            "direction must be one of {}, got {!r}".format(sorted(scenarios), direction))
    window, contact_agg, vaccine_agg = scenarios[direction]

    recent = datafile.iloc[-window:]
    mobility_m = recent['mobility'].agg(contact_agg)
    face_cover_m = recent['face_cover'].agg(contact_agg)
    vaccine_m = recent['vaccine'].agg(vaccine_agg)

    first_new_day = datafile['date'].iloc[-1] + pd.Timedelta(days=1)
    future_dates = pd.date_range(start=first_new_day, periods=max(int(predict_days), 0), freq='D')

    frames = [datafile]
    if len(future_dates) > 0:
        # One frame for the whole horizon instead of one single-row frame per day.
        frames.append(pd.DataFrame({'date': future_dates,
                                    'mobility': mobility_m,
                                    'face_cover': face_cover_m,
                                    'vaccine': vaccine_m}))

    return pd.concat(frames, ignore_index=True)

In [None]:
# Number of days to project beyond the last row of `datafile` (variable name kept as spelled).
preidct_days = 30

# Historical rows followed by one synthetic row per projected day ('mean' scenario by default).
new_datafile = get_future(datafile, preidct_days)
# datafile.fillna(0)

# NOTE: this rebinds the module-level start_date / end_date that Calibration._objective
# and test_sim read when building simulation parameters, so anything run after this cell
# simulates the extended period.
start_date = new_datafile['date'].array[0]
end_date = new_datafile['date'].array[-1]

In [None]:
def test_sim(datafile, start_shift=10, beta=0.134):
    """Build (without running) a simulation with fixed, hand-picked parameters.

    Mirrors the setup used during calibration, but takes ``start_shift`` and
    ``beta`` as plain arguments instead of Optuna suggestions. Reads the
    module-level ``reduction``, ``start_date`` and ``end_date``.

    Args:
        datafile: DataFrame providing ``date``, ``mobility``, ``face_cover`` and
            ``vaccine``; also passed to ``cv.Sim`` as its data file.
        start_shift: the variant is imported this many days before 2021-08-24.
        beta: value of the ``beta`` simulation parameter.

    Returns:
        The configured ``cv.Sim`` instance.
    """
    interventions = Interventions()

    pop_infected = 10
    quar_period = 7  # from https://www.covid.gov.sg/exposed/hrw

    mo_max = 1.0
    mo_min = 0.43
    ma_max = 1.0
    ma_min = 0.13
    interventions.add_mobility(date=datafile['date'],
                               mobility=datafile['mobility'],
                               mo_max=mo_max, mo_min=mo_min,
                               mask=datafile['face_cover'],
                               ma_max=ma_max, ma_min=ma_min)

    symp_prob = 0.9
    asymp_prob = 0.0
    test_delay = 1
    interventions.add_covid_test(symp_prob=symp_prob, asymp_prob=asymp_prob, test_delay=test_delay)

    trace_probs = 0.79
    interventions.add_contact_tracing(trace_probs=trace_probs)

    interventions.add_vaccine(date=datafile['date'], dose=datafile['vaccine'])

    b16172 = cv.variant('b16172', days=pd.to_datetime('2021-08-24') - pd.DateOffset(days=start_shift), n_imports=pop_infected)

    pars = sc.objdict(
        pop_size = math.ceil(5686000 * reduction),
        pop_infected = 0,
        use_waning = True,
        quar_period = quar_period,
        beta = beta,
        rel_death_prob = 0,
        interventions = interventions.interventions,
        pop_type = 'hybrid',
        location = 'Singapore',
        start_day = start_date,
        end_day = end_date,
    )

    sim = cv.Sim(pars=pars, variants=[b16172], datafile=datafile)
    return sim

In [None]:
# Build the simulation on the extended (history + projection) data and run it.
sim = test_sim(new_datafile)
sim.run()
# Print the vaccination entries of the run summary.
print(sim.summary[["cum_vaccinations", "cum_vaccinated", "n_vaccinated"]])

In [None]:
def plot_ax(table, ax, interval=1):
    """Configure month-based date ticks on ``ax`` and draw ``table`` onto it.

    Args:
        table: object whose ``plot(ax=...)`` method draws onto the axes.
        ax: Matplotlib axes to configure and draw on.
        interval: spacing, in months, passed to ``MonthLocator``.
    """
    import matplotlib.dates as mdates
    # One major tick every `interval` months, labelled as abbreviated month and day (e.g. "Oct-01").
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=interval))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b-%d'))
    # plt.gcf().autofmt_xdate()
    # NOTE(review): the locator/formatter are installed before plotting - confirm
    # that table.plot does not replace them.
    table.plot(ax=ax)


In [None]:
fig, ax = plt.subplots(2, 1, figsize=(7, 7), dpi=150)

# Plotted window and the shaded projection period.
plot_from = pd.to_datetime('2021-07-25')
forecast_from = pd.to_datetime('2021-10-25')
forecast_to = pd.to_datetime('2021-11-24')

# Reported data inside the plotted window.
in_window = (sg_for_covasim['date'] <= forecast_to) & (sg_for_covasim['date'] >= plot_from)
_datafile = sg_for_covasim[in_window]

# (panel title, {legend label: simulation result key}, reported values to scatter)
panels = [
    ('Total counts',
     {'cumulative infections': 'cum_infections',
      'cumulative diagnoses': 'cum_diagnoses',
      'current infections': 'n_infectious'},
     np.cumsum(_datafile['new_diagnoses'])),
    ('Daily counts',
     {'new infections': 'new_infections',
      'new diagnoses': 'new_diagnoses'},
     _datafile['new_diagnoses']),
]

for panel_ax, (panel_title, curves, observed) in zip(ax, panels):
    columns = {'Date': sim.results['date']}
    columns.update({label: sim.results[key] for label, key in curves.items()})
    results = pd.DataFrame(columns)
    results = results[results['Date'] >= plot_from].set_index(keys='Date')
    plot_ax(results, panel_ax)
    panel_ax.scatter(_datafile['date'], observed, s=20, marker='s', alpha=0.5, zorder=0, c='indigo', label='data')
    panel_ax.axvspan(forecast_from, forecast_to, facecolor='yellow', alpha=0.2)
    panel_ax.set_title(panel_title)
    panel_ax.legend()

fig.tight_layout(rect=[0, 0.03, 1, 0.95])

fig.savefig(drive_path + calib_name + "_predict.pdf", bbox_inches = 'tight')

plt.show()
plt.close(fig)

# Zero vaccination

In [None]:
# Counterfactual scenario: overwrite the dose count of every day with zero (modifies new_datafile in place).
new_datafile['vaccine'] = 0
# Rebound so the figure saved below gets its own file-name prefix.
calib_name = 'sg2_zero'

In [None]:
# Re-run the scenario; new_datafile's vaccine column was set to zero in the previous cell.
sim = test_sim(new_datafile)
sim.run()
# Print the vaccination entries of the run summary.
print(sim.summary[["cum_vaccinations", "cum_vaccinated", "n_vaccinated"]])

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(7, 7), dpi=150)

# Plotted window and the shaded projection period.
plot_from = pd.to_datetime('2021-07-25')
forecast_from = pd.to_datetime('2021-10-25')
forecast_to = pd.to_datetime('2021-11-24')

# Reported data inside the plotted window.
in_window = (sg_for_covasim['date'] <= forecast_to) & (sg_for_covasim['date'] >= plot_from)
_datafile = sg_for_covasim[in_window]

# (panel title, {legend label: simulation result key}, reported values to scatter)
panels = [
    ('Total counts',
     {'cumulative infections': 'cum_infections',
      'cumulative diagnoses': 'cum_diagnoses',
      'current infections': 'n_infectious'},
     np.cumsum(_datafile['new_diagnoses'])),
    ('Daily counts',
     {'new infections': 'new_infections',
      'new diagnoses': 'new_diagnoses'},
     _datafile['new_diagnoses']),
]

for panel_ax, (panel_title, curves, observed) in zip(ax, panels):
    columns = {'Date': sim.results['date']}
    columns.update({label: sim.results[key] for label, key in curves.items()})
    results = pd.DataFrame(columns)
    results = results[results['Date'] >= plot_from].set_index(keys='Date')
    plot_ax(results, panel_ax)
    panel_ax.scatter(_datafile['date'], observed, s=20, marker='s', alpha=0.5, zorder=0, c='indigo', label='data')
    panel_ax.axvspan(forecast_from, forecast_to, facecolor='yellow', alpha=0.2)
    panel_ax.set_title(panel_title)
    panel_ax.legend()

fig.tight_layout(rect=[0, 0.03, 1, 0.95])

fig.savefig(drive_path + calib_name + "_predict.pdf", bbox_inches = 'tight')

plt.show()
plt.close(fig)