In [1]:
# General packages
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.ticker import PercentFormatter
from tqdm import tqdm  # progress bars

# MODEL 1 infrastructure
from model1_class import Model1
# TODO: replace the star import with explicit names
# (presumably `from run_enkf import prepare_enkf, run_enkf`) once confirmed.
from run_enkf import *
# MODEL 2 infrastructure
from model2_class import Model2
from benchmarking_error_simple import benchmarking_error_simple

#%% Benchmarking the fan charts of both models

start_year = 1990
end_year = 2022

# Simulation length in monthly steps. The year span is inclusive of both the
# start year and the final year, hence the "+ 1".
period_length = 12 * (end_year - start_year + 1)
print("period length is ", period_length)

# Model 1 configuration. `distribution` is either "Pareto_lognormal" or
# "exponential_pareto".
model_params1 = dict(
    population_size=100,
    growth_rate=0.025,
    b_begin=1.3,
    distribution="exponential_pareto",
    start_year=start_year,
    uncertainty_para=0.1,
)

# Model 2 configuration. Population size, growth rate, distribution and start
# year match model 1; note that uncertainty_para is 0 here (0.1 for model 1).
# `distribution` is either "Pareto_lognormal" or "exponential_pareto".
model_params2 = dict(
    population_size=100,
    concavity=1,
    growth_rate=0.025,
    distribution="exponential_pareto",
    start_year=start_year,
    adaptive_sensitivity=0.02,
    uncertainty_para=0,
)


# The filter is deliberately kept out of these runs: filter_freq is set far
# above the number of simulated steps so that filtering is never applied and
# each ensemble just runs forward.
FILTER_FREQ_OFF = 1000
# run_enkf appears to advance one step per month (period_length steps in
# total). Presumably filtering fires every filter_freq steps (TODO confirm in
# run_enkf), so guard the "filter off" intent against longer periods.
assert FILTER_FREQ_OFF > period_length, (
    f"filter_freq={FILTER_FREQ_OFF} would be reached within {period_length} steps; "
    "raise FILTER_FREQ_OFF to keep the filter disabled"
)

# Settings shared by both ensembles so the two models are set up identically.
enkf_settings = dict(
    uncertainty_obs=0.5,
    ensemble_size=30,
    macro_state_dim=4,
    filter_freq=FILTER_FREQ_OFF,
)

enkf1 = prepare_enkf(Model1, model_params1, **enkf_settings)
enkf2 = prepare_enkf(Model2, model_params2, **enkf_settings)
run_enkf(enkf1, start_year, end_year, filter_freq=FILTER_FREQ_OFF)
run_enkf(enkf2, start_year, end_year, filter_freq=FILTER_FREQ_OFF)

# Benchmark comparison. benchmarking_error_simple is a separate object from the
# EnKF ensembles prepared with prepare_enkf/run_enkf above.
# The distributions are taken from the model configurations so the benchmark
# cannot silently drift from what the two models were set up with.
# NOTE(review): the first argument (30) presumably mirrors the ensemble size of
# the EnKF runs; TODO confirm and keep the two values in sync.
benchmark = benchmarking_error_simple(
    30,
    distribution_model1=model_params1["distribution"],
    distribution_model2=model_params2["distribution"],
)
benchmark.collect_data(start_year, end_year)
benchmark.compute_error(start_year=start_year, end_year=end_year)

# Figure 2: wealth groups over time (A, B) and fan charts (C, D) for both
# models, plus the benchmark comparison (E) spanning the bottom row.
FIG2_PATH = 'fig2.png'


def _label_panel(ax, letter, model_name, y):
    """Write a panel letter at x=0 and, optionally, a model name at x=40.

    Both positions are in the axes' data coordinates, at height ``y``.
    """
    ax.text(0, y, letter, fontsize=12)
    if model_name is not None:
        ax.text(40, y, model_name, fontsize=12)


fig = plt.figure(figsize=(10, 10))
gs = gridspec.GridSpec(3, 2, figure=fig)  # 3 rows of equal height, 2 columns
ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[1, 1])
ax4 = fig.add_subplot(gs[2, :])  # bottom panel spans both columns

enkf1.models[0].plot_wealth_groups_over_time(ax0, start_year, end_year)
enkf2.models[0].plot_wealth_groups_over_time(ax1, start_year, end_year)
enkf1.plot_fanchart(ax2)
enkf2.plot_fanchart(ax3)
benchmark.plot_graph(ax4)

# Top row. A single legend sits outside, to the right of panel B.
_label_panel(ax0, 'A', 'Model 1', 0.85)
ax1.legend(loc=(1.05, -0.15), frameon=False)
_label_panel(ax1, 'B', 'Model 2', 0.85)

# Fan charts: show y values as percentages. set_yticklabels() without
# set_yticks() attaches labels to whatever ticks exist by position (and warns),
# so labels drift if the tick locations differ. A formatter labels each tick
# from its actual value instead.
# Assumes the fan-chart y values are fractions in [0, 1], as the previously
# hard-coded '0%'..'100%' labels implied.
ax2.yaxis.set_major_formatter(PercentFormatter(xmax=1, decimals=0))
ax3.yaxis.set_major_formatter(PercentFormatter(xmax=1, decimals=0))
_label_panel(ax2, 'C', 'Model 1', 1.05)
_label_panel(ax3, 'D', 'Model 2', 1.05)

# Benchmark panel: place the letter just above the current top of the y-axis.
_, y_max = ax4.get_ylim()
_label_panel(ax4, 'E', None, y_max + 0.01)

fig.tight_layout()
fig.savefig(FIG2_PATH, dpi=300)
plt.show()

period length is  396


Iterations ENKF <class 'model1_class.Model1'>: 100%|██████████| 396/396 [00:26<00:00, 15.11it/s]
Iterations ENKF <class 'model2_class.Model2'>: 100%|██████████| 396/396 [00:38<00:00, 10.35it/s]
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '60%', '80%'])
  ax.set_yticklabels(['0%', '0%', '20%', '40%', '6