In [59]:
%load_ext autoreload
%autoreload 2

import numpy as np
import tensorflow as tf

import sys
sys.path.append('/mnt/c/Users/kheut/code/covid19-forecasting/tf_model_1p5/')

from enum import Enum

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense

import tensorflow_probability as tfp
from scipy.stats import beta, truncnorm


# Local imports from model.py, data.py
from model import CovidModel, LogPoissonProb, get_logging_callbacks, Comp, Vax
from model_config import ModelConfig, ModelVar, replace_keys
from data import read_data, create_warmup
#from plots import make_all_plots

import scipy
import copy

import matplotlib
import matplotlib.pyplot as plt

import json

plt.rcParams.update({'font.size': 20}) # set plot font sizes

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [60]:
old_config_path = '/mnt/c/Users/kheut/code/covid19-forecasting/tf_model_1p5/model_config.json'
new_config_path = '/mnt/c/Users/kheut/code/covid19-forecasting/tf_model_1p5/model_config_new_format.json'

In [61]:
with open(old_config_path, 'r') as f:
    old_config_json = json.load(f)

In [69]:
new_config_json = {}

var_transforms = [
    ('T_serial', 'softplus'),
    ('delta', 'sigmoid'),
    ('epsilon', 'sigmoid'),
]

for this_var, this_transform in var_transforms:
    new_config_json[this_var] = copy.deepcopy(old_config_json[this_var])
    
    parsed_transform = ModelVar._parse_transform(this_transform)
    
    for part in new_config_json[this_var].keys():
        if part =='prior':
            continue
        for param in new_config_json[this_var][part].keys():
            new_config_json[this_var][part][param] = parsed_transform.inverse(
                tf.cast(
                    new_config_json[this_var][part][param],
                tf.float32)
            ).numpy()

    new_config_json[this_var]['mean_transform']=this_transform
    
var_transforms = [
    ('rho', 'sigmoid'),
    ('eff', 'sigmoid'),
    ('lambda', 'softplus'),
    ('nu', 'softplus'),
    ('warmup', 'scale_100_softplus'),
    ('init_count', 'scale_100_softplus'),
]

for this_var, this_transform in var_transforms:
    new_config_json[this_var] = copy.deepcopy(old_config_json[this_var])
    
    parsed_transform = ModelVar._parse_transform(this_transform)
    
    for compartment in new_config_json[this_var].keys():
        for part in new_config_json[this_var][compartment].keys():
            if part =='prior':
                continue
            for vax in new_config_json[this_var][compartment][part].keys():
                for param in new_config_json[this_var][compartment][part][vax].keys():
                    if param != 'slope':
                        new_config_json[this_var][compartment][part][vax][param] = parsed_transform.inverse(
                            tf.cast(
                                new_config_json[this_var][compartment][part][vax][param],
                            tf.float32)
                        ).numpy()

    new_config_json[this_var]['mean_transform']=this_transform
    
new_config_json = replace_keys(new_config_json, str, from_tensor=True)

In [70]:
old_config_json

{'T_serial': {'prior': {'loc': 5.8, 'scale': 1},
  'value': {'loc': 5.3, 'scale': 1}},
 'delta': {'prior': {'a': 1, 'b': 1}, 'value': {'loc': 0.03, 'scale': 0.2}},
 'epsilon': {'prior': {'a': 1, 'b': 1}, 'value': {'loc': 0.2, 'scale': 0.3}},
 'rho': {'M': {'prior': {'0': {'a': 4.5, 'b': 3.1}},
   'value': {'0': {'loc': 0.35, 'scale': 0.1}}},
  'G': {'prior': {'0': {'a': 1.1, 'b': 9.6}},
   'value': {'0': {'loc': 0.12, 'scale': 0.1}}},
  'I': {'prior': {'0': {'a': 1.1, 'b': 9.6}},
   'value': {'0': {'loc': 0.12, 'scale': 0.1}}},
  'D': {'prior': {'0': {'a': 1.1, 'b': 9.6}},
   'value': {'0': {'loc': 0.12, 'scale': 0.1}}}},
 'eff': {'M': {'prior': {'1': {'a': 1, 'b': 1}},
   'value': {'1': {'loc': 0.5, 'scale': 0.1}}},
  'G': {'prior': {'1': {'a': 1, 'b': 1}},
   'value': {'1': {'loc': 0.5, 'scale': 0.1}}},
  'I': {'prior': {'1': {'a': 1, 'b': 1}},
   'value': {'1': {'loc': 0.5, 'scale': 0.1}}},
  'D': {'prior': {'1': {'a': 1, 'b': 1}},
   'value': {'1': {'loc': 0.5, 'scale': 0.1}}}},
 '

In [71]:
new_config_json

{'T_serial': {'prior': {'loc': '5.8', 'scale': '1'},
  'value': {'loc': '5.2949963', 'scale': '0.54132485'},
  'mean_transform': 'softplus'},
 'delta': {'prior': {'a': '1', 'b': '1'},
  'value': {'loc': '-3.4760988', 'scale': '-1.3862944'},
  'mean_transform': 'sigmoid'},
 'epsilon': {'prior': {'a': '1', 'b': '1'},
  'value': {'loc': '-1.3862944', 'scale': '-0.84729785'},
  'mean_transform': 'sigmoid'},
 'rho': {'M': {'prior': {0: {'a': '4.5', 'b': '3.1'}},
   'value': {0: {'loc': '-0.6190392', 'scale': '-2.1972246'}}},
  'G': {'prior': {0: {'a': '1.1', 'b': '9.6'}},
   'value': {0: {'loc': '-1.9924302', 'scale': '-2.1972246'}}},
  'I': {'prior': {0: {'a': '1.1', 'b': '9.6'}},
   'value': {0: {'loc': '-1.9924302', 'scale': '-2.1972246'}}},
  'D': {'prior': {0: {'a': '1.1', 'b': '9.6'}},
   'value': {0: {'loc': '-1.9924302', 'scale': '-2.1972246'}}},
  'mean_transform': 'sigmoid'},
 'eff': {'M': {'prior': {1: {'a': '1', 'b': '1'}},
   'value': {1: {'loc': '0.0', 'scale': '-2.1972246'}}}

In [72]:
with open(new_config_path, 'w') as f:
    json.dump(new_config_json, f, indent = 4)

In [73]:
json.dumps(json.loads(json.dumps(new_config_json), parse_float=lambda x: round(float(x), 10)))

'{"T_serial": {"prior": {"loc": "5.8", "scale": "1"}, "value": {"loc": "5.2949963", "scale": "0.54132485"}, "mean_transform": "softplus"}, "delta": {"prior": {"a": "1", "b": "1"}, "value": {"loc": "-3.4760988", "scale": "-1.3862944"}, "mean_transform": "sigmoid"}, "epsilon": {"prior": {"a": "1", "b": "1"}, "value": {"loc": "-1.3862944", "scale": "-0.84729785"}, "mean_transform": "sigmoid"}, "rho": {"M": {"prior": {"0": {"a": "4.5", "b": "3.1"}}, "value": {"0": {"loc": "-0.6190392", "scale": "-2.1972246"}}}, "G": {"prior": {"0": {"a": "1.1", "b": "9.6"}}, "value": {"0": {"loc": "-1.9924302", "scale": "-2.1972246"}}}, "I": {"prior": {"0": {"a": "1.1", "b": "9.6"}}, "value": {"0": {"loc": "-1.9924302", "scale": "-2.1972246"}}}, "D": {"prior": {"0": {"a": "1.1", "b": "9.6"}}, "value": {"0": {"loc": "-1.9924302", "scale": "-2.1972246"}}}, "mean_transform": "sigmoid"}, "eff": {"M": {"prior": {"1": {"a": "1", "b": "1"}}, "value": {"1": {"loc": "0.0", "scale": "-2.1972246"}}}, "G": {"prior": {