# 1. Explore the data

Load some basic libraries.

In [None]:
from goad_toolkit.config import DataConfig, FileConfig
from pathlib import Path
from loguru import logger
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

Let's set up the configuration needed to load the data.

In [None]:
# Preprocessing settings (date range, window, ...) and the location of the data files.
dataconfig = DataConfig()
fileconfig = FileConfig(data_dir=Path("../data/"))
(fileconfig, dataconfig)

Try to explore the data yourself by building your own preprocessing pipeline below.

In [None]:
from goad_toolkit.dataprocessor import DataProcessor
from goad_toolkit.datatransforms import SelectDataRange, DiffValues, ZScaler, RollingAvg


class YourPreprocessor(DataProcessor):
    """Exploratory preprocessing pipeline for the deaths / positive-tests data.

    Steps are registered on ``self.pipeline`` in this order: difference the
    ``deaths`` column, restrict to the configured date range, apply a rolling
    average to ``deaths`` and ``positivetests``, then z-score both columns
    (with ``rename=True``).
    """

    def config_pipeline(self, dataconfig: DataConfig):
        columns = ("deaths", "positivetests")

        self.pipeline.add(DiffValues, column="deaths")
        self.pipeline.add(
            SelectDataRange,
            start_date=dataconfig.start_date,
            end_date=dataconfig.end_date,
        )
        # Smooth every column first, then scale every column.
        for name in columns:
            self.pipeline.add(RollingAvg, column=name, window=dataconfig.window)
        for name in columns:
            self.pipeline.add(ZScaler, column=name, rename=True)


# Run the custom pipeline defined above and preview the first rows of the result.
preprocessor = YourPreprocessor(fileconfig, dataconfig)
df = preprocessor.process(raw=True)
df.head()

In [None]:
# Overlay the two smoothed, z-scored series. Labels and a legend make it clear
# which line is which.
plt.plot(df.index, df["deaths_zscore"], label="deaths (z-score)")
plt.plot(df.index, df["positivetests_zscore"], label="positive tests (z-score)")
plt.title("Deaths vs Positive Tests (smoothed)")
plt.xticks(rotation=45)
plt.ylabel("Z-Score")
plt.legend()

Compare your preprocessing with mine. Try to understand why I did what I did.

In [None]:
from goad_toolkit.dataprocessor import CovidDataProcessor

# Reference pipeline from the toolkit, for comparison with YourPreprocessor.
# NOTE(review): save=True presumably persists the processed data to disk —
# confirm in goad_toolkit.
covidprocessor = CovidDataProcessor(fileconfig, dataconfig)
data = covidprocessor.process(save=True)
data.head()

# 2. Model the data

We can model the vaccination effect as a logistic function.

In [None]:
from goad_toolkit.models import logistic
from goad_toolkit.visualizer import ComparePlot, PlotSettings

# Two example logistic curves on the same grid: k sets how fast the curve moves
# between 0 and 1 (its sign sets the direction), and x0 = 0 is where both cross 0.5.
x = np.linspace(-10, 10, 1000)
y1, y2 = logistic(x, k=1, x0=0), logistic(x, k=-0.5, x0=0)

example = pd.DataFrame({"x": x, "y1(k=1)": y1, "y2(k=-0.5)": y2})

settings = PlotSettings(title="Logistic Function Example")
compareplot = ComparePlot(settings)
fig, ax = compareplot.plot(data=example, x="x", y1="y1(k=1)", y2="y2(k=-0.5)")

We will need to fit the variables of the logistic function to the data.

This means:
- k: how fast the curve moves from 0 to 1 (or from 1 to 0)
- x0: the point where the curve is at 0.5
- limit: we can keep it at 1

This means we need the index of the dates as a sequence of integers (0, 1, ..., n-1), so we can fit the logistic function with x0 as the index where the curve is at 0.5 and k as its steepness. So let's add the index as an integer column to the feature matrix `X`.

In [None]:
# Feature matrix X: column 0 holds the positive tests, column 1 the integer
# time index 0..n-1 on which a logistic term can be evaluated.
x = data["positivetests"].values
# Target: the shifted deaths. This is the same column the prediction is
# compared against in the plots and the residual, so train on it as well.
y = data["deaths_shifted"].values
data_idx = np.arange(len(x))
X = np.stack([x, data_idx], axis=1)
X.shape, y.shape

Now it is up to you to add the logistic function to the model.
Below is a naive model that ignores the parameters and simply returns the number of positive tests (the first column of `X`).

Think about sensible initial parameters based on your data visualizations.

In [None]:
def covid_model(X: np.ndarray, params: list[float]) -> np.ndarray:
    """Naive baseline model. Replace it with your improved model.

    It ignores ``params`` entirely and returns the first feature column of
    ``X`` unchanged. That column holds the positive tests in the feature
    matrix built earlier.

    Args:
        X: 2-D feature matrix. Only column 0 is used here.
        params: model parameters. Unused by this baseline.

    Returns:
        Column 0 of ``X``, one value per row.
    """
    positive_tests = X[:, 0]
    return positive_tests


# Starting point for the optimiser. The training cell unpacks the fitted
# parameters as (a, b, k, x0), so keep four values in that order.
# NOTE(review): x0 = 60.0 is presumably a guess of the index where the logistic
# term crosses 0.5 — tune it from the plots above.
initial_params = [1.0, 1.0, -1.0, 60.0]  # dummy parameters
initial_predict = covid_model(X, initial_params)

# 3. Train the model

Now, with your updated model, train it so that the parameters are fitted to the data.

In [None]:
from goad_toolkit.models import train_model, mse

# Fit the model parameters with train_model, using the mse loss on (X, y).
try:
    params = train_model(X, y, covid_model, mse, initial_params)
except Exception as e:
    # Notebook boundary: report the failure with its traceback, then fall back
    # to the initial guess. The original empty list only moved the crash to
    # the unpacking below ("not enough values to unpack").
    logger.exception(f"Model training failed: {e}")
    logger.warning("Falling back to initial_params so the remaining cells can run")
    params = list(initial_params)
else:
    logger.success(f"Model training finished with: {params}")
a, b, k, x0 = params

Let's first visualize just the logistic curve. Does it align with the data, especially around the date when vaccination starts?

In [None]:
from goad_toolkit.visualizer import ComparePlotDate
import pandas as pd

# Evaluate only the logistic term, using the fitted k and x0, on the integer
# time index (column 1 of X).
data["logistic model"] = logistic(x=X[:, 1], k=k, x0=x0)

start = "2021-01-06"  # marked on the plot as the start of vaccinations
settings = PlotSettings(title="Vaccination Effect", xlabel="Date", ylabel="Scaled Values")
comparedate = ComparePlotDate(settings)
comparedate.plot(
    data=data,
    x="date",
    y1="deaths_shifted_zscore",
    y2="logistic model",
    date=start,
    datelabel="Start vaccinations",
)

If this is not what you expect, go back to section 2.

# 4. Predict

Now, if the model seems to work as expected, let's do a full prediction.

In [None]:
# Full prediction with the trained parameters. It is stored in a "predict"
# column so the plot and residual cells below can use it.
predict = covid_model(X, params)
data["predict"] = predict

In [None]:
from goad_toolkit.visualizer import ComparePlotDate

# Overlay the full model prediction on the shifted deaths series. Title and axis
# labels are set, as in the other plots of this notebook.
settings = PlotSettings(
    title="Model Prediction vs Shifted Deaths",
    xlabel="Date",
    ylabel="Deaths (shifted)",
)
compare = ComparePlotDate(settings)
compare.plot(
    data=data,
    x="date",
    y1="deaths_shifted",
    y2="predict",
    date="2021-01-06",
    datelabel="start vaccinations",
)

And let's check the residual. Besides showing how well the model fits the data, the residual may reveal patterns that give us clues about how to improve the model.

In [None]:
from goad_toolkit.visualizer import ResidualPlot

# Residual = observed shifted deaths minus model prediction. Any remaining
# structure over time hints at what the model is missing.
data["residual"] = data["deaths_shifted"].sub(data["predict"])

settings = PlotSettings(figsize=(12, 6), title="Residual Plot", xlabel="dates", ylabel="error")
resplot = ResidualPlot(settings)
fig, _ = resplot.plot(
    data=data,
    x="date",
    y="residual",
    date="2021-01-06",
    datelabel="Vaccination Started",
    interval=1,
)

# 5. Test the residual

In addition to looking at the plots, we can also test the residual.
Ideally, the errors have a mean of 0 and a low variance.

In [None]:
from goad_toolkit.analytics import DistributionFitter

# Fit distributions to the residuals (discrete=False, i.e. treating them as
# continuous values) and select the best fit via fitter.best.
fitter = DistributionFitter()
fits = fitter.fit(data["residual"], discrete=False)
best = fitter.best(fits)
best

In [None]:
from goad_toolkit.visualizer import PlotFits, PlotSettings, FitPlotSettings

# Plot the residuals with 30 bins alongside at most three of the fitted
# distributions from the previous cell (max_fits=3).
fitplotsettings = FitPlotSettings(bins=30, max_fits=3)
settings = PlotSettings(figsize=(12, 6), title="Residuals", xlabel="error", ylabel="probability")
fitplotter = PlotFits(settings)
fig = fitplotter.plot(data=data["residual"], fit_results=fits, fitplotsettings=fitplotsettings)