In [1]:
# General imports
import pandas as pd
import numpy as np
import matplotlib
from matplotlib import pyplot as plt

# GluonTS imports
from gluonts.dataset.common import ListDataset
from gluonts.torch.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.dataset.split import split

# SimbaML imports
from simba_ml.simulation import distributions, generators
from simba_ml.simulation import kinetic_parameters as kinetic_parameters_module
from simba_ml.simulation import noisers
from simba_ml.simulation import species, system_model


In [2]:
start_date = pd.to_datetime('2020-02-20')
offset = 22

prediction_length = 7
context_length = 7

In [3]:
name = "SIR - Covid-19 - Data Augmentation"
# Population obtained form:
# https://www-genesis.destatis.de/genesis/online?operation=abruftabelleBearbeiten&levelindex=1&levelid=1676991208921&auswahloperation=abruftabelleAuspraegungAuswaehlen&auswahlverzeichnis=ordnungsstruktur&auswahlziel=werteabruf&code=12411-0001&auswahltext=&werteabruf=Value+retrieval#abreadcrumb
specieses = [
    species.Species("Suspectible", distributions.Constant(83166711-100), contained_in_output=False, min_value=0), #83166711
    species.Species("Infected", distributions.Constant(100), contained_in_output=False, min_value=0),
    species.Species("Recovered", distributions.Constant(0), contained_in_output=False, min_value=0),
    species.Species("Cumulative Infected", distributions.Constant(100), contained_in_output=True, min_value=0),
]

kinetic_parameters: dict[str, kinetic_parameters_module.KineticParameter] = {
    "beta": kinetic_parameters_module.ConstantKineticParameter(distributions.ContinuousUniformDistribution(0.32, 0.35)),
    "gamma": kinetic_parameters_module.ConstantKineticParameter(distributions.ContinuousUniformDistribution(0.123, 0.125)),
}

def deriv(_t: float, y: list[float], arguments: dict[str, float]) -> tuple[float, float, float]:
    """Defines the derivative of the function at the point _.

    Args:
        y: Current y vector.
        arguments: Dictionary of arguments configuring the problem.

    Returns:
        Tuple[float, float, float]
    """
    S, I, R, _ = y
    N = S + I + R
    

    dS_dt = -arguments["beta"] * S * I / N
    dI_dt = arguments["beta"] * S * I / N - (arguments["gamma"]) * I
    dR_dt = arguments["gamma"] * I
    dC_dt = arguments["beta"] * S * I / N
    return dS_dt, dI_dt, dR_dt, dC_dt



noiser = noisers.AdditiveNoiser(distributions.NormalDistribution(0, 42*10**3))

sm = system_model.SystemModel(
            name,
            specieses,
            kinetic_parameters,
            deriv=deriv,
            noiser=noiser,
            timestamps=distributions.Constant(100)
        )
    


In [4]:
simulations = generators.TimeSeriesGenerator(sm).generate_signals(n=100)
simulations_new_cases = [simulation.assign(new_cases = simulation["Cumulative Infected"].diff()) for simulation in simulations]
sim_targets = [{"target": simulation["new_cases"].iloc[20:100].to_numpy(), "start": start_date} for simulation in simulations_new_cases]

In [22]:
real_data =  pd.read_csv('data/rki_case_numbers_germany.csv')
real_data = real_data.loc[50:150].reset_index(drop=True)
real_target = [{"target": real_data["new_cases_7d_average"].to_numpy(), "start": start_date}]


In [23]:

target = [{"target": real_target[0]["target"][:offset], "start": start_date}] + sim_targets 

In [24]:
dataset = ListDataset(target, freq='d')
real_dataset = ListDataset(real_target, freq='d')

train_real, test_gen = split(real_dataset, offset=offset)


AttributeError: 'TrainingDataset' object has no attribute 'plot'

In [25]:
# Training with augmented dataset
model = SimpleFeedForwardEstimator(
    prediction_length=prediction_length, context_length=context_length,  trainer_kwargs={"max_epochs": 30}
)
predictor = model.train(dataset)

test_data = test_gen.generate_instances(prediction_length=prediction_length, windows=1)
forecasts_mix = list(predictor.predict(test_data.input))

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/umurcankaya/anaconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | SimpleFeedForwardModel | 3.2 K  | train
---------------------------------------------------------
3.2 K     Trainable params
0         Non-trainable params
3.2 K     Total params
0.013     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 50: 'train_loss' reached 15.91741 (best 15.91741), saving model to '/Users/umurcankaya/Library/Mobile Documents/com~apple~CloudDocs/PhD/Act-i-ML/Scripts/SimbaML_examples/src/covid_data_augmentation/lightning_logs/version_2/checkpoints/epoch=0-step=50.ckpt' as top 1
Epoch 1, global step 100: 'train_loss' reached 15.22418 (best 15.22418), saving model to '/Users/umurcankaya/Library/Mobile Documents/com~apple~CloudDocs/PhD/Act-i-ML/Scripts/SimbaML_examples/src/covid_data_augmentation/lightning_logs/version_2/checkpoints/epoch=1-step=100.ckpt' as top 1
Epoch 2, global step 150: 'train_loss' reached 14.82530 (best 14.82530), saving model to '/Users/umurcankaya/Library/Mobile Documents/com~apple~CloudDocs/PhD/Act-i-ML/Scripts/SimbaML_examples/src/covid_data_augmentation/lightning_logs/version_2/checkpoints/epoch=2-step=150.ckpt' as top 1
Epoch 3, global step 200: 'train_loss' was not in top 1
Epoch 4, global step 250: 'train_loss' reached 14.75086 (best 14.75086), saving

In [26]:
# Training with only real-world dataset
del model
model = SimpleFeedForwardEstimator(
    prediction_length=prediction_length, context_length=context_length,  trainer_kwargs={"max_epochs": 30}, weight_decay=0.01
)
predictor = model.train(train_real)

test_data = test_gen.generate_instances(prediction_length=prediction_length, windows=1)
forecasts_obs_only = list(predictor.predict(test_data.input))

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | SimpleFeedForwardModel | 3.2 K  | train
---------------------------------------------------------
3.2 K     Trainable params
0         Non-trainable params
3.2 K     Total params
0.013     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 50: 'train_loss' reached 11.35503 (best 11.35503), saving model to '/Users/umurcankaya/Library/Mobile Documents/com~apple~CloudDocs/PhD/Act-i-ML/Scripts/SimbaML_examples/src/covid_data_augmentation/lightning_logs/version_3/checkpoints/epoch=0-step=50.ckpt' as top 1
Epoch 1, global step 100: 'train_loss' reached 8.45071 (best 8.45071), saving model to '/Users/umurcankaya/Library/Mobile Documents/com~apple~CloudDocs/PhD/Act-i-ML/Scripts/SimbaML_examples/src/covid_data_augmentation/lightning_logs/version_3/checkpoints/epoch=1-step=100.ckpt' as top 1
Epoch 2, global step 150: 'train_loss' reached 7.89309 (best 7.89309), saving model to '/Users/umurcankaya/Library/Mobile Documents/com~apple~CloudDocs/PhD/Act-i-ML/Scripts/SimbaML_examples/src/covid_data_augmentation/lightning_logs/version_3/checkpoints/epoch=2-step=150.ckpt' as top 1
Epoch 3, global step 200: 'train_loss' reached 7.49804 (best 7.49804), saving model to '/Users/umurcankaya/Library/Mobile Documents/com~app

In [27]:
# Requires pdflatex
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    'text.usetex': True,
    'pgf.rcfonts': False,
})

In [52]:
forecast_date = pd.to_datetime(real_data.loc[offset, "day_idx"])
test_data = test_gen.generate_instances(prediction_length=prediction_length, windows=1)

forecasts_obs_only = list(predictor.predict(test_data.input))

medium = 11
large = 12

plt.rc('font', size=large)         
plt.rc('axes', titlesize=large)     
plt.rc('axes', labelsize=large)    
plt.rc('xtick', labelsize=medium)   
plt.rc('ytick', labelsize=medium)    
plt.rc('legend', fontsize=large)   
plt.rc('figure', titlesize=large)  

fig = plt.gcf()
fig.set_size_inches(6.8, 4.8)

# Plot groud truth time series
real_data["day_idx"] = pd.to_datetime(real_data["day_idx"])
ax1 = plt.plot(real_data["day_idx"][:offset+prediction_length+1], real_data["new_cases_7d_average"][:offset+prediction_length+1],
               label="Ground Truth", color="#332288", linewidth=2.5)

# Plot forecast of model trained with augmented dataset
fcoo = forecasts_mix[0].to_sample_forecast(num_samples=10000)
fcoo.start_date = forecast_date
fcoo.start_date.freq = "D"

fcoo.samples = np.array([[real_data['new_cases_7d_average'][offset]] + list(a) for a in fcoo.samples])
fcoo.plot(intervals=[0.5, 0.85], color="#44AA99");

# Plot forecast of model trained with only the real-world dataset
fcoo = forecasts_obs_only[0].to_sample_forecast(num_samples=10000)
fcoo.start_date = forecast_date
fcoo.start_date.freq = "D"

fcoo.samples = np.array([[real_data['new_cases_7d_average'][offset]] + list(a) for a in fcoo.samples])
fcoo.plot(intervals=[0.5, 0.85], color="#AA4499");  


# Set correct legend
ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()

new_handles = [handles[0]]
new_labels = [labels[0]]
first = True
for handle, label in zip(handles[1:], labels[1:]):
    if "%" not in label:
        new_handles.append(handle)
        if first:
            new_labels.append("Forecast: Synthetically Augmented Data")

            first = False
        else:
            new_labels.append("Forecast: Only Real Data")


ax.legend(new_handles, new_labels, loc="upper left", fontsize="medium", ncols=1)    

# Set ticks
tick_dates = pd.date_range(start="2020-02-22", periods=offset+prediction_length, freq="D")[::4]
ax.set_xticks(tick_dates)

tick_labels = [date.strftime('%d\n%b') for date in tick_dates]
ax.set_xticklabels(tick_labels)

plt.setp(ax.get_xticklabels(), rotation=0, ha="center");

# Set axis labels
plt.xlabel("Date (2020)", fontsize="large")
plt.ylabel("New Cases (7-day average)", fontsize="large")

# Add train cutoff visualisation
plt.vlines(x=forecast_date, ymin=0, ymax=14000, color="black", linestyle="dashed", linewidth=1)
plt.text(pd.to_datetime('2020-03-12 21:30'), 4500, "Training cutoff", fontsize=12, color="black", rotation=90, horizontalalignment='right');

# Shade the background of the training data
start_date = pd.to_datetime('2020-02-01')
end_date = forecast_date
ax.axvspan(start_date, end_date, facecolor='grey', alpha=0.14)

box = ax.get_position()
print(box.height)

# Set axis limits
plt.xlim(pd.to_datetime('2020-02-20'), pd.to_datetime('2020-03-20'))
plt.ylim(0, 12000)


plt.savefig('figure2.pdf', bbox_inches='tight')
plt.close()


0.77


In [47]:
ax.get_legend_handles_labels()
for handle, label in zip(handles[1:], labels[1:]):
    print(handle, label)

In [49]:
ax.get_legend_handles_labels()

([<matplotlib.lines.Line2D at 0x31f175310>], ['Ground Truth'])

In [None]:
ax.