# The University of Hong Kong
## DASC7600 Data Science Project 2024

# Import modules

In [1]:
import numpy as np
import pandas as pd
import warnings
import matplotlib.pyplot as plt

from scipy import integrate, optimize
from typing import List
from sklearn import metrics

warnings.filterwarnings('ignore')

# Pre-defined Values

In [2]:
# Hong Kong 2020 population figure; source document:
# https://gia.info.gov.hk/general/202402/20/P2024022000221_448985_1_1708412192189.pdf
hk_popul_2020 = 7_520_500

# Functions

In [3]:
def get_date_count(df: pd.DataFrame,
                   col: str) -> pd.DataFrame:
    """Count rows per calendar day in ``df[col]``, including zero-count days.

    The column is parsed with ``pd.to_datetime`` and truncated to midnight
    before counting. Parsing first ensures several things:

    * the earliest/latest day are found chronologically, not lexically;
    * timestamps carrying a time-of-day fall into their day's bucket;
    * different spellings of the same date are merged.

    Rows whose date is missing are ignored.

    Args:
        df: table containing the date column.
        col: name of the column holding the dates.

    Returns:
        DataFrame with columns ``[col, 'count']``. There is one row per
        calendar day from the earliest to the latest date (inclusive), in
        ascending order. ``count`` is 0 for days with no rows. The result is
        empty when the column holds no valid dates.
    """
    dates = pd.to_datetime(df[col]).dropna()
    if dates.empty:
        # date_range(NaT, NaT) would raise; return an empty, well-typed frame.
        return pd.DataFrame({col: pd.Series([], dtype='datetime64[ns]'),
                             'count': pd.Series([], dtype='int64')})

    daily = dates.dt.normalize().value_counts().sort_index()
    full_range = pd.date_range(daily.index.min(), daily.index.max(), freq='D')
    # Days without any rows get an explicit 0 instead of being absent.
    daily = daily.reindex(full_range, fill_value=0)
    return pd.DataFrame({col: daily.index, 'count': daily.values})

# Load Data

In [4]:
# Load the Hong Kong COVID case table from the ./data/std_data/hk/ directory.
covid_hk_std = pd.read_csv(filepath_or_buffer='./data/std_data/hk/covid_hk_std.csv')

# SIR Model

In [5]:
class SIR_model:
    """Fit a basic SIR (Susceptible-Infected-Removed) model to a count series.

    The model integrated here is::

        dS/dt = -beta * S * I
        dI/dt =  beta * S * I - gamma * I
        dR/dt =  gamma * I

    ``(beta, gamma)`` are chosen with Nelder-Mead to minimise the squared
    error between the simulated I compartment and ``I_counts``, sampled at
    t = 0, 1, ..., len(I_counts) - 1.

    NOTE(review): this notebook passes *daily new case counts* as
    ``I_counts``, while the I compartment is the number currently infected.
    Confirm that this comparison is the intended one.
    """

    # Starting guess (beta, gamma) used by get_optm_params unless overridden.
    DEFAULT_INITIAL_PARAMS = (0.001, 1)

    def __init__(self, I_counts: List[int], initial_counts: List[int],
                 initial_params=DEFAULT_INITIAL_PARAMS):
        """
        Args:
            I_counts: observed values compared against the I compartment,
                one per time unit.
            initial_counts: initial ``[S, I, R]`` values at t = 0.
            initial_params: starting ``(beta, gamma)`` for the optimiser.
        """
        self.I_counts = I_counts
        self.initial_counts = initial_counts
        self.t_length = len(I_counts)
        # Spacing of the solver's output grid; sum_of_square keeps every
        # round(1/step)-th sample so it lines up with I_counts.
        self.step = 1
        self.initial_params = tuple(initial_params)
        # Set by fit(): the (beta, gamma) returned by get_optm_params().
        self.optm_params = None

    @staticmethod
    def SIR_DE(t, initial_counts, beta, gamma):
        """Return ``[dS/dt, dI/dt, dR/dt]`` for state ``[S, I, R]``.

        ``t`` is unused (the system is autonomous); it is required by
        ``solve_ivp``'s callback signature.

        Raises:
            ValueError: if the state does not have exactly three entries.
        """
        if len(initial_counts) != 3:
            raise ValueError('Length of initial_counts should be 3.')

        S, I, R = initial_counts
        infection = beta * S * I
        removal = gamma * I
        return [-infection, infection - removal, removal]

    @staticmethod
    def solve_SIR(t_length, initial_counts, step, params):
        """Integrate the SIR system on ``[0, t_length - 1]``.

        Args:
            t_length: number of observation points (integration ends at
                ``t_length - 1``).
            initial_counts: initial ``[S, I, R]``.
            step: spacing of the evaluation grid.
            params: ``(beta, gamma)``.

        Returns:
            The ``scipy.integrate.solve_ivp`` result. ``sol.y`` rows are
            S, I, R evaluated at ``sol.t``.
        """
        return integrate.solve_ivp(
            SIR_model.SIR_DE,
            [0, t_length-1],
            initial_counts,
            t_eval=np.arange(0, t_length-1+step, step),
            args=tuple(params))

    @staticmethod
    def sum_of_square(params, t_length, initial_counts, step, I_counts):
        """Squared-error objective between simulated I and ``I_counts``.

        Returns ``np.inf`` in two cases, so that the optimiser moves away
        from those parameters instead of crashing:

        * the parameters are non-physical (negative rate);
        * the integration fails (it then returns fewer samples than there
          are observations).
        """
        beta, gamma = params
        if beta < 0 or gamma < 0:
            return np.inf

        sol = SIR_model.solve_SIR(t_length, initial_counts, step, params)
        observed = np.asarray(I_counts, dtype=float)
        # round() rather than int(): int(1/step) truncates, e.g. for step=1/3.
        predicted = sol.y[1][::int(round(1 / step))]
        if not sol.success or predicted.shape != observed.shape:
            return np.inf
        return float(np.sum((predicted - observed) ** 2))

    def get_optm_params(self):
        """Return the ``(beta, gamma)`` array minimising ``sum_of_square``."""
        sol = optimize.minimize(
            SIR_model.sum_of_square,
            self.initial_params,
            args=(self.t_length, self.initial_counts, self.step, self.I_counts),
            method='Nelder-Mead')
        return sol.x

    def fit(self):
        """Optimise the parameters, store them in ``self.optm_params`` and
        return the ``solve_ivp`` solution for them."""
        self.optm_params = self.get_optm_params()
        return SIR_model.solve_SIR(self.t_length, self.initial_counts,
                                   self.step, self.optm_params)

    def plot(self, predict, true_I_counts):
        """Plot the S, I and R trajectories of ``predict`` (a ``solve_ivp``
        result), overlaying ``true_I_counts`` as black stars at t = 0, 1, ...
        """
        plt.figure(figsize=(12, 4))
        plt.plot(predict.t, predict.y[0])
        plt.plot(predict.t, predict.y[1])
        plt.plot(predict.t, predict.y[2])
        plt.plot(np.arange(0, len(true_I_counts)), true_I_counts, "k*:")
        # A real boolean: any non-empty string (even "False") is truthy.
        plt.grid(True)
        plt.legend(["Susceptible", "Infected", "Removed", "Original Data"])
        plt.show()

    def plot_fitted_model(self):
        """Fit the model and plot the result against the observed counts."""
        self.plot(self.fit(), self.I_counts)

In [6]:
# Rows per 'report_date' (presumably one row per reported case), with days
# that have no rows filled in as 0.
covid_hk_new_case_cnt = get_date_count(df=covid_hk_std, col='report_date')

In [7]:
# Positions 10..35 (26 consecutive days) of the daily series, used as the
# first-wave window. NOTE(review): these are hard-coded positions, not dates;
# re-check them if the underlying data changes.
first_wave_daily_cnt = covid_hk_new_case_cnt["count"].iloc[10:36].tolist()

In [9]:
# Initial compartments: S0 = total count over the window, I0 = 1, R0 = 0.
SIR_model_hk = SIR_model(
    I_counts=first_wave_daily_cnt,
    initial_counts=[sum(first_wave_daily_cnt), 1, 0],
)
# Uncomment to run the optimiser and plot the fitted curves:
# SIR_model_hk.plot_fitted_model()