# STEP: IMPORT FUNCTIONS
import pandas as pd
from data_gathering import gather_data_features, gather_data_actuals
from data_exploration import calculate_correlation
from mappings import import_country_mapping
from data_visualization import global_scatter_plot

In [None]:
# STEP: LOAD DATA
# Load actuals data (unpacked into the four per-year variables and the all-years variable)
(
    data_cm_actual_2018,
    data_cm_actual_2019,
    data_cm_actual_2020,
    data_cm_actual_2021,
    data_cm_actual_allyears,
) = gather_data_actuals()

# Load features data (same unpacking pattern as the actuals)
(
    data_cm_features_2017,
    data_cm_features_2018,
    data_cm_features_2019,
    data_cm_features_2020,
    data_cm_features_allyears,
) = gather_data_features()


# Directory holding the stored baseline-model CRPS score files
output_directory = r"C:\Users\Uwe Drauz\Documents\bachelor_thesis_local\personal_competition_data\Baseline_Model"

# Read the country-specific and the global CRPS score file of every year, keyed by year.
# The paths are derived from output_directory so the directory is spelled out only once.
crps_scores_all_year_country_specific = {}
crps_scores_all_year_global = {}
for year in [2015, 2016, 2017, 2018, 2019]:
    file_path = rf'{output_directory}\crps_scores_all_year_{year}_country_specific.parquet'
    crps_scores_all_year_country_specific[year] = pd.read_parquet(file_path)
    file_path = rf'{output_directory}\crps_scores_all_year_{year}_global.parquet'
    crps_scores_all_year_global[year] = pd.read_parquet(file_path)

# Per-year aliases of the country-specific score frames
crps_scores_all_year_country_specific_year_2015 = crps_scores_all_year_country_specific[2015]
crps_scores_all_year_country_specific_year_2016 = crps_scores_all_year_country_specific[2016]
crps_scores_all_year_country_specific_year_2017 = crps_scores_all_year_country_specific[2017]
crps_scores_all_year_country_specific_year_2018 = crps_scores_all_year_country_specific[2018]
crps_scores_all_year_country_specific_year_2019 = crps_scores_all_year_country_specific[2019]

# Per-year aliases of the global score frames
crps_scores_all_year_global_year_2015 = crps_scores_all_year_global[2015]
crps_scores_all_year_global_year_2016 = crps_scores_all_year_global[2016]
crps_scores_all_year_global_year_2017 = crps_scores_all_year_global[2017]
crps_scores_all_year_global_year_2018 = crps_scores_all_year_global[2018]
crps_scores_all_year_global_year_2019 = crps_scores_all_year_global[2019]

country_mapping = import_country_mapping()

In [None]:
# STEP: DATA EXPLORATION
# Inputs: every column of the features frame except the four dropped below
input_variables = data_cm_features_allyears.columns.drop(
    ['index', 'country_id', 'month_id', 'ged_sb']
).tolist()
# Targets: ged_sb and its tlag_1 ... tlag_6 columns
target_variables = ['ged_sb'] + [f'ged_sb_tlag_{lag}' for lag in range(1, 7)]
correlation_values = calculate_correlation(
    data_cm_features_allyears,
    target_variables=target_variables,
    input_variables=input_variables,
)



In [None]:
# Display the all-years features frame
data_cm_features_allyears

In [None]:
# Display the all-years actuals frame
data_cm_actual_allyears

In [None]:
# Keep only correlations whose absolute value reaches the threshold
corr_threshold = 0.1
mask = correlation_values.abs().ge(corr_threshold)

# Index with the mask; entries below the threshold are filtered out
correlation_values_filtered = correlation_values[mask]

correlation_values_filtered

**Understand spatial lag features**

In [None]:

# Show a feature, its "_48" variant and its "splag_" variant for a few countries in month 121
country_neighbours = [1, 2, 4, 9]
data_cm_features_allyears.loc[
    data_cm_features_allyears['country_id'].isin(country_neighbours)
    & (data_cm_features_allyears['month_id'] == 121),
    ['month_id', 'country_id', 'vdem_v2x_libdem', 'vdem_v2x_libdem_48', 'splag_vdem_v2x_libdem'],
]

**Understand decay function features**

In [None]:
# country_neighbours is (re)assigned here but the filter below uses country_id 133 only
country_neighbours = [1, 2, 4, 9]
# Show ged_sb next to its three decay columns for country 133 across all months
data_cm_features_allyears.loc[
    data_cm_features_allyears['country_id'] == 133,
    ['month_id', 'country_id', 'ged_sb', 'decay_ged_sb_5', 'decay_ged_sb_100', 'decay_ged_sb_500'],
]

In [None]:
# Show ged_sb, decay_ged_sb_5 and its "splag_1_" variant for a few countries in month 121
country_neighbours = [1, 2, 4, 9]
data_cm_features_allyears.loc[
    data_cm_features_allyears['country_id'].isin(country_neighbours)
    & (data_cm_features_allyears['month_id'] == 121),
    ['month_id', 'country_id', 'ged_sb', 'decay_ged_sb_5', 'splag_1_decay_ged_sb_5'],
]

Plot independent variables vs conflict fatalities in scatter plot

In [None]:
# Export one scatter plot per independent variable against ged_sb (nothing is shown inline)
output_directory = r'C:\Users\Uwe Drauz\Documents\bachelor_thesis_local\personal_competition_data\Plots\scatterplots_global'
independent_variables = data_cm_features_allyears.columns.drop(
    ['index', 'country_id', 'month_id', 'ged_sb']
).tolist()
for independent_variable in independent_variables:
    global_scatter_plot(
        data_cm_features_allyears,
        independent_variable,
        'ged_sb',
        output_directory,
        show=False,
        export=True,
    )

Plot independent variables vs conflict fatalities in scatter plots, restricted to non-zero conflict data

In [None]:
# Same export as for the full data, but only for rows with ged_sb > 0
output_directory_non_zero = r'C:\Users\Uwe Drauz\Documents\bachelor_thesis_local\personal_competition_data\Plots\scatterplots_global_non_zero'
independent_variables = data_cm_features_allyears.columns.drop(
    ['index', 'country_id', 'month_id', 'ged_sb']
).tolist()
data_cm_features_allyears_non_zero = data_cm_features_allyears.loc[data_cm_features_allyears['ged_sb'] > 0]
for independent_variable in independent_variables:
    global_scatter_plot(
        data_cm_features_allyears_non_zero,
        independent_variable,
        'ged_sb',
        output_directory_non_zero,
        show=False,
        export=True,
    )

In [None]:
# Summary statistics of ged_sb over all rows
data_cm_features_allyears['ged_sb'].describe()


In [None]:
# Summary statistics of ged_sb over the rows with ged_sb > 0
data_cm_features_allyears_non_zero['ged_sb'].describe()

In [None]:
# Show ged_sb next to its tlag_1 ... tlag_6 columns
data_cm_features_allyears[['month_id', 'country_id', 'ged_sb', 'ged_sb_tlag_1', 'ged_sb_tlag_2', 'ged_sb_tlag_3', 'ged_sb_tlag_4', 'ged_sb_tlag_5', 'ged_sb_tlag_6']]

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
def plot_variable_distribution(df: pd.DataFrame, variable: str, *, figsize: tuple = (10, 6), kde: bool = False):
    """Show a histogram of one column of a data frame.

    Args:
        df: Frame containing the column to plot.
        variable: Name of the column handed to seaborn as the x variable.
        figsize: Figure size passed to ``plt.subplots``; defaults to the
            previously hard-coded (10, 6).
        kde: Whether to overlay a kernel density estimate; defaults to the
            previous behaviour (no overlay).
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.histplot(data=df, x=variable, ax=ax, kde=kde)
    ax.set_title(f'Distribution of {variable}')
    plt.show()

In [None]:
# Histogram of ged_sb over all rows
plot_variable_distribution(data_cm_features_allyears, 'ged_sb')

In [None]:
# Histogram of ged_sb over the rows with ged_sb > 0
plot_variable_distribution(data_cm_features_allyears_non_zero, 'ged_sb')

In [None]:
# Show (without exporting) the scatter plot for 'wdi_sm_pop_refg_or' and 'ged_sb' on the rows with ged_sb > 0
global_scatter_plot(data_cm_features_allyears_non_zero, 'wdi_sm_pop_refg_or', 'ged_sb', output_directory, show=True, export=False)

In [None]:
# Rows with ged_sb greater than 5
data_cm_features_allyears_non_zero_big =  data_cm_features_allyears[data_cm_features_allyears['ged_sb'] > 5]

In [None]:
# Histogram of ged_sb over the rows with ged_sb > 5
plot_variable_distribution(data_cm_features_allyears_non_zero_big, 'ged_sb')

In [None]:
import pymc as pm

In [None]:
# Load the stored result and actuals parquet artifacts of one run for one evaluation year
run_id = 'eae5db7dfa6e4408a8fb0f916968d033'
evaluation_year = 2018
# NOTE: this file name has no underscore before the year, while the actuals file name below has one
results_all_countries_path = fr"C:\Users\Uwe Drauz\PycharmProjects\bachelor_thesis\mlruns\3\{run_id}\artifacts\BaselineModel_results_all_countries{evaluation_year}.parquet"
results_all_countries = pd.read_parquet(results_all_countries_path)
results_actuals_all_countries_path = fr"C:\Users\Uwe Drauz\PycharmProjects\bachelor_thesis\mlruns\3\{run_id}\artifacts\BaselineModel_results_actuals_{evaluation_year}.parquet"
results_actuals = pd.read_parquet(results_actuals_all_countries_path)

### Prior Predictive Checks

In [None]:
import scipy.stats as stats
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

In [None]:

# Prior predictive simulation: draw the dispersion parameter alpha from a Gamma
# prior, then draw Negative Binomial observations with mean mu for every alpha.
nsim = 100  # number of alpha draws (rows of y)
nobs = 100  # number of simulated observations per alpha draw (columns of y)
a_alpha = 0.1  # Gamma shape
b_alpha = 1  # Gamma rate (scipy takes scale = 1 / rate)

# Generate alpha using Gamma distribution
alpha = stats.gamma.rvs(a=a_alpha, scale=1/b_alpha, size=nsim)

# Initialize DataFrame to store Negative Binomial data
y = pd.DataFrame(index=range(nsim), columns=range(nobs))

# Simulate target data following NB distribution
mu = 1
# scipy's nbinom(n, p) has mean n * (1 - p) / p. With n = alpha, a mean of mu
# therefore needs p = alpha / (mu + alpha); the previous mu / (mu + alpha)
# produced a mean of alpha**2 / mu instead.
p_success = alpha / (mu + alpha)
for i in range(nsim):
    # Generate Negative Binomial data using the alpha parameter
    y.loc[i, :] = stats.nbinom.rvs(n=alpha[i], p=p_success[i], size=nobs)




In [None]:
# Plot the histogram using Seaborn and keep the axes it drew on.
# Explicit axes handles are used throughout: after twinx() the "current" axes is
# the twin, so plt.gca()/plt.xlabel would label the twin instead of the histogram axes.
ax1 = sns.histplot(alpha, bins=30, label='Histogram', color='blue')

# Create a second y-axis to plot the density
ax2 = ax1.twinx()

# Plot the density on the second y-axis
sns.kdeplot(alpha, ax=ax2, label='Density', color='red')

# Labeling and titles
ax1.set_title('Histogram and Density of Alpha')
ax1.set_xlabel('Alpha')
ax1.set_ylabel('Frequency (Count)')
ax2.set_ylabel('Density')

# Add legends (one per axes, so both stay visible)
ax1.legend(loc='upper left')
ax2.legend(loc='upper right')

plt.show()


In [None]:
# Plot histogram and density for each of the first 20 simulated rows of y.
# Explicit axes handles are used: after twinx() the "current" axes is the twin,
# so plt.gca()/plt.xlabel would label the twin instead of the histogram axes.
for obs in range(20):
    print(f"alpha value: {alpha[obs]}")
    ax1 = sns.histplot(y.loc[obs, :], bins=30, label='Histogram', color='blue')

    # Create a second y-axis to plot the density
    ax2 = ax1.twinx()

    # Plot the density of the same row on the second y-axis
    sns.kdeplot(y.loc[obs, :], ax=ax2, label='Density', color='red')

    # Labeling and titles (the title names the row actually plotted)
    ax1.set_title(f'Histogram and Density of y (Row {obs})')
    ax1.set_xlabel('y')
    ax1.set_ylabel('Frequency (Count)')
    ax2.set_ylabel('Density')

    # Add legends (one per axes, so both stay visible)
    ax1.legend(loc='upper left')
    ax2.legend(loc='upper right')

    plt.show()


In [None]:
# Load 'Actuals' data
(
    data_cm_actual_2018,
    data_cm_actual_2019,
    data_cm_actual_2020,
    data_cm_actual_2021,
    data_cm_actual_allyears,
) = gather_data_actuals()
# Load features data
(
    data_cm_features_2017,
    data_cm_features_2018,
    data_cm_features_2019,
    data_cm_features_2020,
    data_cm_features_allyears,
) = gather_data_features()

In [None]:
# Actuals frames keyed by their year as a string
dict_actuals = dict(zip(
    ("2018", "2019", "2020", "2021"),
    (data_cm_actual_2018, data_cm_actual_2019, data_cm_actual_2020, data_cm_actual_2021),
))

In [None]:
# All countries present in the actuals data
actual_countries = data_cm_actual_allyears['country_id'].unique()
# Countries with at least one row of ged_sb > 0 in the features data
feature_countries_non_zero = data_cm_features_allyears.loc[
    data_cm_features_allyears['ged_sb'] > 0, 'country_id'
].unique()
# Countries that satisfy both conditions
feature_and_actuals_countries_non_zero = list(set(feature_countries_non_zero) & set(actual_countries))

# Countries which are in actual_countries but not in feature_and_actuals_countries_non_zero
countries_in_actuals_without_observations = list(
    set(actual_countries) - set(feature_and_actuals_countries_non_zero)
)

In [None]:
# Print, per year, the mean of 'ged_sb' in the actuals restricted to the countries
# listed in feature_and_actuals_countries_non_zero
for year, df in dict_actuals.items():
    df = df[df['country_id'].isin(feature_and_actuals_countries_non_zero)]
    print(f"Average value of 'ged_sb' in {year}: {df['ged_sb'].mean()}")

In [None]:
# Load a stored parquet file from the competition_results folder (path is relative to the working directory)
X = pd.read_parquet("competition_results/test_window_2018_baseline_validate_on_forecast_horizon.parquet")

In [None]:
from mappings import map_month_id_to_datetime
# Convert month_id 468 with the mapping helper and display the result
map_month_id_to_datetime(468)

In [None]:
import pandas as pd
# Load the stored all-countries results of one mlruns run for 2018 (path is relative to the working directory)
results_all_countries_2018 = pd.read_parquet("mlruns/3/bd79953d84ac4d3d8d59415a286c3fc3/artifacts/BaselineModel_results_all_countries2018.parquet")

### Categorizing countries into developed and developing countries

In [None]:
import pandas as pd
import numpy as np

In [None]:
# List of countries in the views competition data.
# The spellings used here (e.g. "Congo, DRC", "Solomon Is.", "Timor Leste") are the
# reference the categorised lists below are compared against.
views_country_list = [
    "Guyana", "Suriname", "Trinidad and Tobago", "Venezuela", "Samoa", "Tonga",
    "Argentina", "Bolivia", "Brazil", "Chile", "Ecuador", "Paraguay", "Peru", "Uruguay",
    "Guatemala", "Mexico", "Barbados", "Dominica", "Grenada", "St. Lucia", "St. Vincent and the Grenadines",
    "Dominican Republic", "Haiti", "Jamaica", "Bahamas", "Belize", "Colombia", "Costa Rica", "Cuba",
    "El Salvador", "Honduras", "Nicaragua", "Panama", "Antigua and Barbuda", "St. Kitts and Nevis",
    "Iceland", "Ireland", "United Kingdom", "Cape Verde", "Cote d'Ivoire", "Ghana", "Liberia",
    "Portugal", "Spain", "Burkina Faso", "Guinea", "Guinea-Bissau", "Mali",
    "Senegal", "Sierra Leone", "The Gambia", "Djibouti", "Eritrea", "Ethiopia", "Mongolia",
    "Iraq", "Jordan", "Kazakhstan", "Norway", "Russia", "Sweden",
    "Algeria", "Cameroon", "Central African Republic", "Tunisia", "Benin",
    "Equatorial Guinea", "Kiribati", "Niger", "Nigeria", "Sao Tome and Principe", "Togo",
    "Albania", "Bosnia and Herzegovina", "Croatia", "Italy", "Macedonia", "Malta",
    "Bulgaria", "Cyprus", "Georgia", "Greece", "Lebanon",
    "Turkey", "Austria", "Czech Republic", "Denmark", "Hungary", "Poland",
    "Slovakia", "Slovenia", "Belgium", "France", "Luxembourg", "Netherlands",
    "Switzerland", "Belarus", "Estonia", "Finland", "Latvia", "Lithuania", "Moldova",
    "Romania", "Ukraine", "Maldives", "Oman", "Somalia", "Sri Lanka", "Turkmenistan",
    "Uzbekistan", "Yemen", "Armenia", "Azerbaijan", "Bahrain", "Iran", "Kuwait",
    "Qatar", "Saudi Arabia", "United Arab Emirates", "Afghanistan", "Kyrgyzstan",
    "Nepal", "Pakistan", "Tajikistan", "Bangladesh", "Bhutan", "Brunei",
    "Japan", "North Korea", "Palau", "Philippines", "South Korea",
    "Cambodia", "Laos", "Myanmar", "Thailand", "Vietnam", "Marshall Is.",
    "Micronesia", "Botswana", "Burundi", "Rwanda", "Zambia", "Zimbabwe",
    "Comoros", "Lesotho", "Malawi", "Mozambique", "South Africa", "Swaziland",
    "Angola", "Congo", "Congo, DRC", "Fiji", "Gabon", "Namibia",
    "New Zealand", "Madagascar", "Mauritius", "Seychelles", "Timor Leste",
    "Australia", "Nauru", "Papua New Guinea", "Solomon Is.", "Tuvalu",
    "Vanuatu", "Canada", "Germany", "Taiwan", "China", "Malaysia",
    "Singapore", "Indonesia", "Libya", "Chad", "Israel", "Syria",
    "Egypt", "India", "Montenegro", "Kosovo", "Serbia", "United States",
    "Uganda", "Kenya", "Tanzania", "Morocco", "Mauritania", "Sudan", "South Sudan"
]


**List of categorized countries based on the World Economic Situation and Prospects (WESP) 2022 report**
To be found here: https://www.un.org/development/desa/dpad/wp-content/uploads/sites/45/WESP2022_ANNEX.pdf

In [None]:
# List of developed economies from Table A of the WESP 2022 annex,
# written with the same country spellings as views_country_list
developed_economies = [
    "Canada", "United States", "Australia", "Japan", "New Zealand",
    "Austria", "Belgium", "Denmark", "Finland", "France", "Germany",
    "Greece", "Ireland", "Italy", "Luxembourg", "Netherlands", "Portugal",
    "Spain", "Sweden", "Bulgaria", "Croatia", "Cyprus", "Czech Republic",
    "Estonia", "Hungary", "Latvia", "Lithuania", "Malta", "Poland",
    "Romania", "Slovakia", "Slovenia", "Iceland", "Norway", "Switzerland",
    "United Kingdom"
]

In [None]:
# List of economies in transition from Table B of the WESP 2022 annex,
# written with the same country spellings as views_country_list
economies_in_transition = [
    "Albania", "Bosnia and Herzegovina", "Montenegro", "Macedonia", "Serbia",
    "Armenia", "Azerbaijan", "Belarus", "Georgia", "Kazakhstan", "Kyrgyzstan",
    "Moldova", "Russia", "Tajikistan", "Turkmenistan",
    "Ukraine", "Uzbekistan"
]

In [None]:
# List of developing economies from Table C of the WESP 2022 annex.
# Country names use the spellings of views_country_list so that name-based
# comparisons and lookups line up ("Sao Tome and Principe", "Timor Leste").
developing_economies = [
    # North Africa
    "Algeria", "Egypt", "Libya", "Mauritania", "Morocco", "Sudan", "Tunisia",
    # Central Africa
    "Cameroon", "Central African Republic", "Chad", "Congo, DRC", "Equatorial Guinea", "Gabon", "Sao Tome and Principe",
    # East Africa
    "Burundi", "Congo", "Comoros", "Djibouti", "Eritrea", "Ethiopia", "Kenya", "Madagascar", "Rwanda", "Somalia", "South Sudan", "Tanzania", "Uganda",
    # Southern Africa
    "Angola", "Botswana", "Swaziland", "Lesotho", "Malawi", "Mauritius", "Mozambique", "Namibia", "South Africa", "Zambia", "Zimbabwe",
    # West Africa
    "Benin", "Burkina Faso", "Cape Verde", "Cote d'Ivoire", "The Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Liberia", "Mali", "Niger", "Nigeria", "Senegal", "Sierra Leone", "Togo",
    # Asia
    "Brunei", "Cambodia", "China", "North Korea", "Fiji", "Hong Kong SAR", "Indonesia", "Kiribati", "Laos", "Malaysia", "Mongolia", "Myanmar", "Papua New Guinea", "Philippines", "South Korea", "Samoa", "Singapore", "Solomon Is.", "Taiwan", "Thailand", "Timor Leste", "Vanuatu", "Vietnam",
    "Afghanistan", "Bangladesh", "Bhutan", "India", "Iran", "Maldives", "Nepal", "Pakistan", "Sri Lanka",
    "Bahrain", "Iraq", "Israel", "Jordan", "Kuwait", "Lebanon", "Oman", "Qatar", "Saudi Arabia", "State of Palestine", "Syria", "Turkey", "United Arab Emirates", "Yemen",
    # Latin America and the Caribbean
    "Bahamas", "Barbados", "Belize", "Guyana", "Jamaica", "Suriname", "Trinidad and Tobago", "Costa Rica", "Cuba", "Dominican Republic", "El Salvador", "Guatemala", "Haiti", "Honduras", "Mexico", "Nicaragua", "Panama",
    "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Paraguay", "Peru", "Uruguay", "Venezuela"
]


In [None]:
# List of the least developed economies from Table F of the WESP 2022 annex (based on the UN list of LDCs),
# written with the same country spellings as views_country_list
least_developed_countries = [
    "Angola", "Benin", "Burkina Faso", "Burundi", "Central African Republic",
    "Chad", "Comoros", "Congo, DRC", "Djibouti", "Eritrea",
    "Ethiopia", "The Gambia", "Guinea", "Guinea-Bissau", "Lesotho", "Liberia",
    "Madagascar", "Malawi", "Mali", "Mauritania", "Mozambique", "Niger",
    "Rwanda", "Sao Tome and Principe", "Senegal", "Sierra Leone", "Somalia",
    "South Sudan", "Sudan", "Togo", "Uganda", "Tanzania",
    "Zambia", "Cambodia", "Kiribati", "Laos",
    "Myanmar", "Solomon Is.", "Timor Leste", "Tuvalu", "Afghanistan",
    "Bangladesh", "Bhutan", "Nepal", "Yemen", "Haiti"
]

In [None]:
# List of small island developing states from Table H of the WESP 2022 annex (based on the UN list of SIDS).
# "Timor Leste" is spelled as in views_country_list and least_developed_countries so
# that the same country is not carried under two different names.
small_islands_developing_states = [
    "Antigua and Barbuda", "Bahamas", "Bahrain", "Barbados", "Belize", "Cape Verde",
    "Comoros", "Cuba", "Dominica", "Dominican Republic", "Micronesia", "Fiji",
    "Grenada", "Guinea-Bissau", "Guyana", "Haiti", "Jamaica", "Kiribati",
    "Maldives", "Marshall Is.", "Mauritius", "Nauru", "Palau", "Papua New Guinea",
    "St. Kitts and Nevis", "St. Lucia", "St. Vincent and the Grenadines", "Samoa",
    "Sao Tome and Principe", "Seychelles", "Singapore", "Solomon Is.", "Suriname",
    "Timor Leste", "Tonga", "Trinidad and Tobago", "Tuvalu", "Vanuatu",
    "American Samoa", "Anguilla", "Aruba", "Bermuda", "British Virgin Islands",
    "Cayman Islands", "Commonwealth of Northern Marianas", "Cook Islands", "Curaçao",
    "French Polynesia", "Guadeloupe", "Guam", "Martinique", "Montserrat",
    "New Caledonia", "Niue", "Puerto Rico", "Sint Maarten",
    "Turks and Caicos Islands", "U.S. Virgin Islands",
]


In [None]:
# Countries that are not fully recognized; kept separate because they appear in
# views_country_list but in none of the WESP-based lists above
not_recognized_countries = ['Kosovo']

Construct different sets of countries 

In [None]:
# All three WESP 2022 economy types in one list (developed, in transition, developing)
all_economies_wesp = [*developed_economies, *economies_in_transition, *developing_economies]

In [None]:
# WESP 2022 economy types extended by LDCs and SIDS, with duplicate names removed
all_economies_wesp_plus_UN = list(
    {*all_economies_wesp, *least_developed_countries, *small_islands_developing_states}
)

In [None]:
# Developing economies from the WESP 2022 report extended by LDCs, SIDS and the
# not fully recognized countries, with duplicate names removed
developing_economies_extended = list(set(
    developing_economies + least_developed_countries + small_islands_developing_states + not_recognized_countries
))

In [None]:
# Extended developing economies with the LDC names taken out
ldc_names = set(least_developed_countries)
developing_economies_extended_without_LDC = list(set(developing_economies_extended) - ldc_names)

In [None]:
df_all_countries = pd.DataFrame(all_economies_wesp_plus_UN, columns=['country'])
# Collect the countries in views_country_list which are not in all_economies_wesp_plus_UN
divergences = [
    country for country in views_country_list if country not in all_economies_wesp_plus_UN
]
divergences_df = pd.DataFrame(divergences, columns=['country'])
divergences_df

In [None]:
def are_lists_distinct(*lists):
    """Return True when no element occurs more than once across all given lists.

    A repeat inside a single list counts as a duplicate as well. Elements must
    be hashable.
    """
    flattened = [element for group in lists for element in group]
    return len(set(flattened)) == len(flattened)

Explore Data of different country sets

In [None]:
from mappings import import_country_mapping
from data_gathering import gather_data_features, gather_data_actuals
from data_preparation import preprocess_data

# Load the country mapping; its 'name' and 'country_id' columns are used below
country_mapping = import_country_mapping()

In [None]:
# Load features data
(
    data_cm_features_2017,
    data_cm_features_2018,
    data_cm_features_2019,
    data_cm_features_2020,
    data_cm_features_allyears,
) = gather_data_features()

In [None]:
# Retrieve country_ids for different country sets
def _ids_for_names(names):
    """Return the country_id values of the mapping rows whose name is in *names*."""
    name_matches = country_mapping['name'].isin(names)
    return country_mapping.loc[name_matches, 'country_id'].tolist()


developed_countries_ids = _ids_for_names(developed_economies)
countries_in_transition_ids = _ids_for_names(economies_in_transition)
developing_countries_extended_without_LDC_ids = _ids_for_names(developing_economies_extended_without_LDC)
least_developed_countries_ids = _ids_for_names(least_developed_countries)


In [None]:
# Plot the distribution of the ged_sb variable for different country sets based on the features data
import seaborn as sns
import matplotlib.pyplot as plt
def plot_variable_distribution(df: pd.DataFrame, variable: str, *, figsize: tuple = (8, 4), kde: bool = True):
    """Show a histogram of one column of a data frame, by default with a density overlay.

    Args:
        df: Frame containing the column to plot.
        variable: Name of the column handed to seaborn as the x variable.
        figsize: Figure size passed to ``plt.subplots``; defaults to the
            previously hard-coded (8, 4).
        kde: Whether to overlay a kernel density estimate; defaults to the
            previous behaviour (overlay enabled).
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.histplot(data=df, x=variable, ax=ax, kde=kde)
    ax.set_title(f'Distribution of {variable}')
    plt.show()

In [None]:
# Summary statistics of the all-years features frame
data_cm_features_allyears.describe()

In [None]:
# Plot the distribution of the ged_sb variable over all countries
plot_variable_distribution(data_cm_features_allyears, 'ged_sb')
# Same, after np.log on the rows with ged_sb > 0
# (np.log is applied to the whole filtered frame, so every column is transformed, not only ged_sb)
plot_variable_distribution(np.log(data_cm_features_allyears[data_cm_features_allyears['ged_sb'] > 0]), 'ged_sb')
# Same, restricted to the rows with ged_sb < 1000
plot_variable_distribution(data_cm_features_allyears[data_cm_features_allyears['ged_sb'] < 1000], 'ged_sb')
# Same, restricted to the rows with ged_sb >= 1000
plot_variable_distribution(data_cm_features_allyears[data_cm_features_allyears['ged_sb'] >= 1000], 'ged_sb')

In [None]:
# ged_sb distribution for developed economies: all rows, log of the positive rows, < 1000, >= 1000
developed_rows = data_cm_features_allyears[data_cm_features_allyears['country_id'].isin(developed_countries_ids)]
plot_variable_distribution(developed_rows, 'ged_sb')
plot_variable_distribution(np.log(developed_rows[developed_rows['ged_sb'] > 0]), 'ged_sb')
plot_variable_distribution(developed_rows[developed_rows['ged_sb'] < 1000], 'ged_sb')
plot_variable_distribution(developed_rows[developed_rows['ged_sb'] >= 1000], 'ged_sb')

In [None]:
# ged_sb distribution for economies in transition: all rows, log of the positive rows, < 1000, >= 1000
transition_rows = data_cm_features_allyears[data_cm_features_allyears['country_id'].isin(countries_in_transition_ids)]
plot_variable_distribution(transition_rows, 'ged_sb')
plot_variable_distribution(np.log(transition_rows[transition_rows['ged_sb'] > 0]), 'ged_sb')
plot_variable_distribution(transition_rows[transition_rows['ged_sb'] < 1000], 'ged_sb')
plot_variable_distribution(transition_rows[transition_rows['ged_sb'] >= 1000], 'ged_sb')

In [None]:
# ged_sb distribution for developing economies without LDCs: all rows, log of the positive rows, < 1000, >= 1000
developing_rows = data_cm_features_allyears[
    data_cm_features_allyears['country_id'].isin(developing_countries_extended_without_LDC_ids)
]
plot_variable_distribution(developing_rows, 'ged_sb')
plot_variable_distribution(np.log(developing_rows[developing_rows['ged_sb'] > 0]), 'ged_sb')
plot_variable_distribution(developing_rows[developing_rows['ged_sb'] < 1000], 'ged_sb')
plot_variable_distribution(developing_rows[developing_rows['ged_sb'] >= 1000], 'ged_sb')

In [None]:
# ged_sb distribution for least developed countries: all rows, log of the positive rows, < 1000, >= 1000
ldc_rows = data_cm_features_allyears[data_cm_features_allyears['country_id'].isin(least_developed_countries_ids)]
plot_variable_distribution(ldc_rows, 'ged_sb')
plot_variable_distribution(np.log(ldc_rows[ldc_rows['ged_sb'] > 0]), 'ged_sb')
plot_variable_distribution(ldc_rows[ldc_rows['ged_sb'] < 1000], 'ged_sb')
plot_variable_distribution(ldc_rows[ldc_rows['ged_sb'] >= 1000], 'ged_sb')

In [None]:
# Maximum and mean of ged_sb over all countries
ged_sb_all_countries = data_cm_features_allyears['ged_sb']
print(f"Max value of ged_sb over all countries: {ged_sb_all_countries.max()}")
print(f"Average value of ged_sb over all countries: {ged_sb_all_countries.mean()}")

In [None]:
# Maximum of ged_sb for each country set
def _max_ged_sb(country_ids):
    """Return the maximum ged_sb over the feature rows whose country_id is in *country_ids*."""
    in_set = data_cm_features_allyears['country_id'].isin(country_ids)
    return data_cm_features_allyears[in_set]['ged_sb'].max()


print(f"Max value of ged_sb for developed countries: {_max_ged_sb(developed_countries_ids)}")
print(f"Max value of ged_sb for countries in transition: {_max_ged_sb(countries_in_transition_ids)}")
print(f"Max value of ged_sb for developing countries: {_max_ged_sb(developing_countries_extended_without_LDC_ids)}")
print(f"Max value of ged_sb for least developed countries: {_max_ged_sb(least_developed_countries_ids)}")

In [None]:
# Mean of ged_sb for each country set
def _mean_ged_sb(country_ids):
    """Return the mean ged_sb over the feature rows whose country_id is in *country_ids*."""
    in_set = data_cm_features_allyears['country_id'].isin(country_ids)
    return data_cm_features_allyears[in_set]['ged_sb'].mean()


print(f"Average value of ged_sb for developed countries: {_mean_ged_sb(developed_countries_ids)}")
print(f"Average value of ged_sb for countries in transition: {_mean_ged_sb(countries_in_transition_ids)}")
print(f"Average value of ged_sb for developing countries: {_mean_ged_sb(developing_countries_extended_without_LDC_ids)}")
print(f"Average value of ged_sb for least developed countries: {_mean_ged_sb(least_developed_countries_ids)}")

In [None]:
# Split the features frame into one frame per country set
feature_country_ids = data_cm_features_allyears['country_id']
df_developed_countries = data_cm_features_allyears.loc[feature_country_ids.isin(developed_countries_ids)]
df_countries_in_transition = data_cm_features_allyears.loc[feature_country_ids.isin(countries_in_transition_ids)]
df_developing_countries_extended_without_LDC = data_cm_features_allyears.loc[
    feature_country_ids.isin(developing_countries_extended_without_LDC_ids)
]
df_least_developed_countries = data_cm_features_allyears.loc[feature_country_ids.isin(least_developed_countries_ids)]


In [None]:
# Set-specific frames keyed by the name of their country set
dict_country_sets = dict(
    developed_countries=df_developed_countries,
    countries_in_transition=df_countries_in_transition,
    developing_countries_extended_without_LDC=df_developing_countries_extended_without_LDC,
    least_developed_countries=df_least_developed_countries,
)

In [None]:
# For each country set, first select features with "temporary_feature_selection" and
# then standardize the result with "temporary_standardization".
# NOTE(review): neither helper is defined or imported in the visible part of this file —
# confirm where they come from before running this cell.
for key in dict_country_sets.keys():
    df = dict_country_sets[key]
    df = temporary_feature_selection(df)
    df = temporary_standardization(df)
    # Replace the stored frame with the processed one
    dict_country_sets[key] = df

In [None]:
# Write every set-specific frame to "data_cm_features_allyears_{key}.parquet" in the temp folder
for key, df in dict_country_sets.items():
    df.to_parquet(f"C:/Users/Uwe Drauz/Documents/bachelor_thesis_local/personal_competition_data/temp/data_cm_features_allyears_{key}.parquet")

In [None]:
# Column selections handed to the preprocessing step
lagged_covariates = [
    "ged_sb",
    "ged_sb_tsum_24",
    "decay_ged_sb_5",
    "decay_ged_sb_100",
    "decay_ged_sb_500",
]
important_vdem_features = [
    "vdem_v2x_veracc",
    "vdem_v2x_horacc",
    "vdem_v2xnp_client",
    "vdem_v2x_divparctrl",
    "vdem_v2xpe_exlpol",
    "vdem_v2xpe_exlsocgr",
]
important_wdi_features = [
    "wdi_ms_mil_xpnd_zs",
    "wdi_sm_pop_refg_or",
    "wdi_sm_pop_netm",
    "wdi_sp_pop_grow",
    "wdi_dt_oda_odat_pc_zs",
]
# vdem features first, then wdi features
covariates = [*important_vdem_features, *important_wdi_features]

In [None]:
from importlib import reload



In [None]:
# Run the project preprocessing on the all-years features with standardization switched on
data = preprocess_data(
    df=data_cm_features_allyears,
    covariates=covariates,
    lagged_covariates=lagged_covariates,
    standardize=True,
)

In [None]:
# Store the preprocessed frame in the temp folder
data.to_parquet("C:/Users/Uwe Drauz/Documents/bachelor_thesis_local/personal_competition_data/temp/special_feature_selection_std.parquet")

In [None]:
from data_visualization import global_scatter_plot

In [None]:
# Scatter plot of every selected vdem feature against ged_sb
for vdem in important_vdem_features:
    global_scatter_plot(
        data,
        vdem,
        "ged_sb",
        output_dir="C:/Users/Uwe Drauz/Documents/bachelor_thesis_local/personal_competition_data/temp",
    )

In [None]:
# Scatter plot of every selected wdi feature against ged_sb
for wdi in important_wdi_features:
    global_scatter_plot(
        data,
        wdi,
        "ged_sb",
        output_dir="C:/Users/Uwe Drauz/Documents/bachelor_thesis_local/personal_competition_data/temp",
    )

In [None]:
# Scatter plot for the single column "decay_ged_sb_100_tlag_3" and ged_sb
global_scatter_plot(data, "decay_ged_sb_100_tlag_3", "ged_sb", output_dir="C:/Users/Uwe Drauz/Documents/bachelor_thesis_local/personal_competition_data/temp")

In [None]:
# Rows with test_score above 2000, ordered by country and month
# NOTE(review): result_data_2021 is not assigned in the visible part of this file
# (only result_data_2021_FH is, further down) — confirm which frame is meant
result_data_2021[result_data_2021['test_score'] > 2000].sort_values(by=['country_id', 'month_id'])

In [None]:
# Column-wise sums over the rows with test_score above 2000 (every column is summed, not only test_score)
# NOTE(review): result_data_2021 is not assigned in the visible part of this file — confirm which frame is meant
result_data_2021[result_data_2021['test_score'] > 2000].sum()


In [None]:
from mappings import map_month_id_to_datetime
from mappings import map_country_id_to_country_name

In [None]:
# Reload the mappings module and re-import the helper from the reloaded module.
# reload comes from importlib, imported in an earlier cell.
import mappings
reload(mappings)
from mappings import map_country_id_to_country_name

In [None]:
# Load actuals data
(
    data_cm_actual_2018,
    data_cm_actual_2019,
    data_cm_actual_2020,
    data_cm_actual_2021,
    data_cm_actual_allyears,
) = gather_data_actuals()

In [None]:
# Convert month_id 498 with the mapping helper and display the result
map_month_id_to_datetime(498)

In [None]:
# Look up country_id 133 with the mapping helper and display the result
map_country_id_to_country_name(133)

In [None]:
# Rows of the 2021 actuals for country_id 133
data_cm_actual_2021[(data_cm_actual_2021['country_id'] == 133)]

In [None]:
import pandas as pd
import os
from mappings import map_country_id_to_country_name, map_month_id_to_datetime
from data_gathering import gather_data_features, gather_data_actuals
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Load the actuals: one frame per year 2018-2021 followed by the all-years frame.
(
    data_cm_actual_2018,
    data_cm_actual_2019,
    data_cm_actual_2020,
    data_cm_actual_2021,
    data_cm_actual_allyears,
) = gather_data_actuals()

In [None]:
# For each year: average 'ged_sb' per country_id, then keep the ids of the
# countries whose average is strictly greater than 5 (as a plain list).
countries_with_avg_more_than_5_fatalities_2018 = (
    data_cm_actual_2018.groupby('country_id')['ged_sb']
    .mean()
    .loc[lambda avg: avg > 5]
    .index.tolist()
)

countries_with_avg_more_than_5_fatalities_2019 = (
    data_cm_actual_2019.groupby('country_id')['ged_sb']
    .mean()
    .loc[lambda avg: avg > 5]
    .index.tolist()
)

countries_with_avg_more_than_5_fatalities_2020 = (
    data_cm_actual_2020.groupby('country_id')['ged_sb']
    .mean()
    .loc[lambda avg: avg > 5]
    .index.tolist()
)

countries_with_avg_more_than_5_fatalities_2021 = (
    data_cm_actual_2021.groupby('country_id')['ged_sb']
    .mean()
    .loc[lambda avg: avg > 5]
    .index.tolist()
)


In [None]:
# Baseline-model result tables for 2018-2021, one parquet artifact per run folder
# under mlruns/3, bound to the *_FH names.
result_data_2018_FH = pd.read_parquet("mlruns/3/5b3acdf73aeb4eecb8bb007157f3535d/artifacts/BaselineModel_results_all_countries2018.parquet")
result_data_2019_FH = pd.read_parquet("mlruns/3/df6e7240234c40cebc96397963780cf9/artifacts/BaselineModel_results_all_countries2019.parquet")
result_data_2020_FH = pd.read_parquet("mlruns/3/301944cf5760482da0291a8e9daca4e1/artifacts/BaselineModel_results_all_countries2020.parquet")
result_data_2021_FH = pd.read_parquet("mlruns/3/3cc10d101573487aa558f41b21de99c5/artifacts/BaselineModel_results_all_countries2021.parquet")

In [None]:
# Same artifact file names read from a second set of run folders, bound to the
# *_FH_and_M names.
result_data_2018_FH_and_M = pd.read_parquet("mlruns/3/9d8897f7efd242bdb33fdfe0c8c1eb6a/artifacts/BaselineModel_results_all_countries2018.parquet")
result_data_2019_FH_and_M = pd.read_parquet("mlruns/3/80caad1b10bf4513b5b5fbb9e8827d49/artifacts/BaselineModel_results_all_countries2019.parquet")
result_data_2020_FH_and_M = pd.read_parquet("mlruns/3/47685285ae8b45d38e966b28d5c27579/artifacts/BaselineModel_results_all_countries2020.parquet")
result_data_2021_FH_and_M = pd.read_parquet("mlruns/3/021e0fc647be407d8924e4d134093adc/artifacts/BaselineModel_results_all_countries2021.parquet")

In [None]:
# Restrict each FH result table to the rows whose country_id appears in that
# year's list of countries averaging more than 5 fatalities.
result_data_2018_FH = result_data_2018_FH.loc[
    lambda rows: rows['country_id'].isin(countries_with_avg_more_than_5_fatalities_2018)
]
result_data_2019_FH = result_data_2019_FH.loc[
    lambda rows: rows['country_id'].isin(countries_with_avg_more_than_5_fatalities_2019)
]
result_data_2020_FH = result_data_2020_FH.loc[
    lambda rows: rows['country_id'].isin(countries_with_avg_more_than_5_fatalities_2020)
]
result_data_2021_FH = result_data_2021_FH.loc[
    lambda rows: rows['country_id'].isin(countries_with_avg_more_than_5_fatalities_2021)
]


In [None]:
# Restrict each FH_and_M result table to the rows whose country_id appears in
# that year's list of countries averaging more than 5 fatalities.
result_data_2018_FH_and_M = result_data_2018_FH_and_M.loc[
    lambda rows: rows['country_id'].isin(countries_with_avg_more_than_5_fatalities_2018)
]
result_data_2019_FH_and_M = result_data_2019_FH_and_M.loc[
    lambda rows: rows['country_id'].isin(countries_with_avg_more_than_5_fatalities_2019)
]
result_data_2020_FH_and_M = result_data_2020_FH_and_M.loc[
    lambda rows: rows['country_id'].isin(countries_with_avg_more_than_5_fatalities_2020)
]
result_data_2021_FH_and_M = result_data_2021_FH_and_M.loc[
    lambda rows: rows['country_id'].isin(countries_with_avg_more_than_5_fatalities_2021)
]

In [None]:
def get_test_scores(df):
    """Rearrange a long result table into a country x month grid of test scores.

    Args:
        df: Frame with the columns ``country_id``, ``month_id`` and
            ``test_score``; any other columns are ignored.

    Returns:
        A float DataFrame indexed by ``country_id`` with one column per
        ``month_id``. Row and column labels keep their order of first
        appearance in ``df``. When a (country_id, month_id) pair occurs more
        than once, the last occurrence wins. Pairs that never occur are NaN.
        An empty ``df`` yields an empty DataFrame.
    """
    if df.empty:
        return pd.DataFrame()

    # pivot() rejects duplicate (index, column) pairs, so keep only the last
    # row per pair - the same row a sequential cell-by-cell fill would keep.
    last_per_cell = df.drop_duplicates(subset=['country_id', 'month_id'], keep='last')
    grid = last_per_cell.pivot(index='country_id', columns='month_id', values='test_score')

    # pivot() sorts the labels; restore first-appearance order so the layout
    # of downstream plots follows the order of the input rows.
    grid = grid.reindex(index=df['country_id'].unique(), columns=df['month_id'].unique())

    # Drop the axis names added by pivot() so the result carries plain labels.
    grid = grid.rename_axis(index=None, columns=None)

    return grid.astype(float)

In [None]:
def get_rolling_window_lengths(df, mlruns_dir="mlruns", experiment_id="2"):
    """Build a country x month grid of the ``rolling_window_length`` logged per run.

    For every row of ``df`` the file
    ``<mlruns_dir>/<experiment_id>/<run_id>/params/rolling_window_length`` is
    read and its content, parsed as an integer, is placed at
    ``[country_id, month_id]``. A missing file is reported on stdout and
    leaves the cell empty instead of aborting the whole table.

    Args:
        df: Frame with the columns ``run_id``, ``country_id`` and ``month_id``.
        mlruns_dir: Root folder that contains the run folders.
        experiment_id: Sub-folder of ``mlruns_dir`` in which the runs are looked up.

    Returns:
        A DataFrame indexed by ``country_id`` with one column per ``month_id``,
        both in order of first appearance. When a pair occurs more than once,
        the last occurrence wins. The dtype is integer when every cell is
        filled, otherwise float with NaN in the empty cells. An empty ``df``
        yields an empty DataFrame.

    Raises:
        ValueError: If a parameter file does not contain an integer.
    """
    if df.empty:
        return pd.DataFrame()

    def _read_length(run_id):
        # os.path.join keeps the lookup portable; on Windows it produces the
        # backslash-separated path this code used before.
        param_path = os.path.join(
            mlruns_dir, str(experiment_id), str(run_id), 'params', 'rolling_window_length'
        )
        try:
            with open(param_path, 'r') as f:
                return int(f.read().strip())
        except FileNotFoundError:
            print(f"File not found for RunID: {run_id}")
            return float('nan')

    lengths = df[['country_id', 'month_id']].copy()
    lengths['rolling_window_length'] = [_read_length(run_id) for run_id in df['run_id']]

    # pivot() rejects duplicate (index, column) pairs, so keep only the last
    # row per pair - the same row a sequential cell-by-cell fill would keep.
    lengths = lengths.drop_duplicates(subset=['country_id', 'month_id'], keep='last')
    grid = lengths.pivot(index='country_id', columns='month_id', values='rolling_window_length')

    # pivot() sorts the labels; restore first-appearance order and drop the
    # axis names it adds.
    grid = grid.reindex(index=df['country_id'].unique(), columns=df['month_id'].unique())
    grid = grid.rename_axis(index=None, columns=None)

    # An integer cast cannot represent NaN; fall back to float when any cell
    # is empty (missing file or a pair that never occurs in df).
    if grid.isna().any().any():
        return grid.astype(float)
    return grid.astype(int)

def rename_indices(df):
    """Relabel both axes of ``df`` in place and return the same object.

    Row labels are passed through ``map_country_id_to_country_name``; column
    labels are passed through ``map_month_id_to_datetime`` and formatted with
    ``'%B %Y'``.
    """
    def _month_label(month_id):
        return map_month_id_to_datetime(month_id).strftime('%B %Y')

    country_labels = df.index.map(map_country_id_to_country_name)
    df.index = country_labels
    month_labels = df.columns.map(_month_label)
    df.columns = month_labels
    return df

In [None]:
def compute_medians(df):
    """Return ``(median per column, median per row)`` of ``df``, each cast to int."""
    per_column, per_row = (df.median(axis=direction) for direction in ("index", "columns"))
    return per_column.astype(int), per_row.astype(int)


In [None]:
def compute_means(df):
    """Return ``(mean per column, mean per row)`` of ``df``."""
    per_column, per_row = (df.mean(axis=direction) for direction in ("index", "columns"))
    return per_column, per_row

In [None]:
def concatenate_dataframes(dfs):
    """Join frames side by side and return them with ordered columns.

    Args:
        dfs: Iterable of DataFrames to concatenate along the column axis.

    Returns:
        The concatenated frame. When every column label is a string in
        ``'%B %Y'`` form (for example ``'January 2018'``), the columns are
        ordered chronologically; a plain label sort would place 'April'
        before 'January'. Any other labels are sorted with ``sort_index``.
    """
    combined = pd.concat(dfs, axis=1)
    labels = combined.columns

    if len(labels) and all(isinstance(label, str) for label in labels):
        try:
            month_starts = pd.to_datetime(labels, format='%B %Y')
        except ValueError:
            # Labels are strings but not month names: use the plain sort below.
            month_starts = None
        if month_starts is not None:
            return combined.iloc[:, month_starts.argsort()]

    return combined.sort_index(axis=1)


In [None]:
def create_heatmap_org(df, figsize, cmap, year, cv_approach, countries: str, export=False, export_dir=None):
    """Draw a heatmap of ``df`` on a fixed 1-36 colour scale and optionally save it.

    Args:
        df: Table to plot; passed unchanged to ``sns.heatmap``.
        figsize: Figure size handed to ``plt.figure``.
        cmap: Colour map handed to ``sns.heatmap``.
        year: Year shown in the title and used in the export file name.
        cv_approach: Label shown in the title; spaces are replaced by
            underscores for the export file name.
        countries: Suffix used in the export file name.
        export: When true, the figure is saved as a PNG before it is shown.
        export_dir: Folder for the PNG. Defaults to the author's local
            ``rolling_window_heatmap`` plot folder; it is created if missing.
    """
    if export_dir is None:
        export_dir = r"C:\Users\Uwe Drauz\Documents\bachelor_thesis_local\personal_competition_data\Plots\rolling_window_heatmap"

    plt.figure(figsize=figsize)
    ax = sns.heatmap(df, cmap=cmap, vmin=1, vmax=36, annot=False, cbar=True)

    # Large fonts and extra padding so the labels stay readable on big figures.
    plt.title(f'Training Months for {year} and CV-Approach {cv_approach}', fontsize=34, pad=40)
    plt.xlabel('Forecasted Months', fontsize=32, labelpad=10)
    plt.ylabel('Countries', fontsize=32, labelpad=10)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)

    # Enlarge the tick labels of the colour bar that sns.heatmap attached.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=24)

    if export:
        # The title keeps the readable label; only the file name gets underscores.
        file_label = str(cv_approach).replace(" ", "_")
        # savefig does not create folders, so make sure the target exists.
        os.makedirs(export_dir, exist_ok=True)
        export_path = os.path.join(
            export_dir, f"rolling_window_length_{file_label}_{year}_{countries}.png"
        )
        plt.savefig(export_path, dpi=300, bbox_inches='tight')

    plt.show()

In [None]:
def create_heatmap_org_test_score(df, figsize, cmap, year, cv_approach, countries: str, export=False, export_dir=None):
    """Draw a heatmap of ``df`` with an automatic colour scale and optionally save it.

    Args:
        df: Table to plot; passed unchanged to ``sns.heatmap``.
        figsize: Figure size handed to ``plt.figure``.
        cmap: Colour map handed to ``sns.heatmap``.
        year: Year shown in the title and used in the export file name.
        cv_approach: Label shown in the title; spaces are replaced by
            underscores for the export file name.
        countries: Suffix used in the export file name.
        export: When true, the figure is saved as a PNG before it is shown.
        export_dir: Folder for the PNG. Defaults to the author's local
            ``baseline_test_score_heatmap`` plot folder; it is created if missing.
    """
    if export_dir is None:
        export_dir = r"C:\Users\Uwe Drauz\Documents\bachelor_thesis_local\personal_competition_data\Plots\baseline_test_score_heatmap"

    plt.figure(figsize=figsize)
    ax = sns.heatmap(df, cmap=cmap, annot=False, cbar=True)

    # Large fonts and extra padding so the labels stay readable on big figures.
    plt.title(f'Test Score for {year} and CV-Approach {cv_approach}', fontsize=34, pad=40)
    plt.xlabel('Forecasted Months', fontsize=32, labelpad=10)
    plt.ylabel('Countries', fontsize=32, labelpad=10)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)

    # Enlarge the tick labels of the colour bar that sns.heatmap attached.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=24)

    if export:
        # The title keeps the readable label; only the file name gets underscores.
        file_label = str(cv_approach).replace(" ", "_")
        # savefig does not create folders, so make sure the target exists.
        os.makedirs(export_dir, exist_ok=True)
        export_path = os.path.join(
            export_dir, f"test_score_{file_label}_{year}_{countries}.png"
        )
        plt.savefig(export_path, dpi=300, bbox_inches='tight')

    plt.show()

In [None]:
def create_heatmap_org_test_score_vmax(df, figsize, cmap, year, cv_approach, countries: str, vmax, export=False, export_dir=None):
    """Draw a heatmap of ``df`` with a capped colour scale and optionally save it.

    Args:
        df: Table to plot; passed unchanged to ``sns.heatmap``.
        figsize: Figure size handed to ``plt.figure``.
        cmap: Colour map handed to ``sns.heatmap``.
        year: Year shown in the title and used in the export file name.
        cv_approach: Label shown in the title; spaces are replaced by
            underscores for the export file name.
        countries: Suffix used in the export file name.
        vmax: Upper end of the colour scale, also appended to the file name.
        export: When true, the figure is saved as a PNG before it is shown.
        export_dir: Folder for the PNG. Defaults to the author's local
            ``baseline_test_score_heatmap`` plot folder; it is created if missing.
    """
    if export_dir is None:
        export_dir = r"C:\Users\Uwe Drauz\Documents\bachelor_thesis_local\personal_competition_data\Plots\baseline_test_score_heatmap"

    plt.figure(figsize=figsize)
    ax = sns.heatmap(df, cmap=cmap, vmax=vmax, annot=False, cbar=True)

    # Large fonts and extra padding so the labels stay readable on big figures.
    plt.title(f'Test Score for {year} and CV-Approach {cv_approach}', fontsize=34, pad=40)
    plt.xlabel('Forecasted Months', fontsize=32, labelpad=10)
    plt.ylabel('Countries', fontsize=32, labelpad=10)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)

    # Enlarge the tick labels of the colour bar that sns.heatmap attached.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=24)

    if export:
        # The title keeps the readable label; only the file name gets underscores.
        file_label = str(cv_approach).replace(" ", "_")
        # savefig does not create folders, so make sure the target exists.
        os.makedirs(export_dir, exist_ok=True)
        export_path = os.path.join(
            export_dir, f"test_score_{file_label}_{year}_{countries}_vmax{vmax}.png"
        )
        plt.savefig(export_path, dpi=300, bbox_inches='tight')

    plt.show()

Rolling window lengths per country-month combination

In [None]:
# Country x month tables of rolling window lengths for the FH_and_M results,
# one per year.
result_data_2018_FH_and_M_rwl = get_rolling_window_lengths(result_data_2018_FH_and_M)
result_data_2019_FH_and_M_rwl = get_rolling_window_lengths(result_data_2019_FH_and_M)
result_data_2020_FH_and_M_rwl = get_rolling_window_lengths(result_data_2020_FH_and_M)
result_data_2021_FH_and_M_rwl = get_rolling_window_lengths(result_data_2021_FH_and_M)

In [None]:
# Replace the ids on both axes with readable labels. rename_indices changes the
# table it is given, so re-running this cell would map the labels a second time.
result_data_2018_FH_and_M_rwl = rename_indices(result_data_2018_FH_and_M_rwl)
result_data_2019_FH_and_M_rwl = rename_indices(result_data_2019_FH_and_M_rwl)
result_data_2020_FH_and_M_rwl = rename_indices(result_data_2020_FH_and_M_rwl)
result_data_2021_FH_and_M_rwl = rename_indices(result_data_2021_FH_and_M_rwl)

In [None]:
# Same tables for the FH results.
result_data_2018_FH_rwl = get_rolling_window_lengths(result_data_2018_FH)
result_data_2019_FH_rwl = get_rolling_window_lengths(result_data_2019_FH)
result_data_2020_FH_rwl = get_rolling_window_lengths(result_data_2020_FH)
result_data_2021_FH_rwl = get_rolling_window_lengths(result_data_2021_FH)


In [None]:
# Relabel the FH tables the same way.
result_data_2018_FH_rwl = rename_indices(result_data_2018_FH_rwl)
result_data_2019_FH_rwl = rename_indices(result_data_2019_FH_rwl)
result_data_2020_FH_rwl = rename_indices(result_data_2020_FH_rwl)
result_data_2021_FH_rwl = rename_indices(result_data_2021_FH_rwl)

Test scores per country-month combination

In [None]:
# Country x month tables of test scores for the FH results, one per year.
result_data_2018_FH_test_score = get_test_scores(result_data_2018_FH)
result_data_2019_FH_test_score = get_test_scores(result_data_2019_FH)
result_data_2020_FH_test_score = get_test_scores(result_data_2020_FH)
result_data_2021_FH_test_score = get_test_scores(result_data_2021_FH)

In [None]:
# Replace the ids on both axes with readable labels. rename_indices changes the
# table it is given, so re-running this cell would map the labels a second time.
result_data_2018_FH_test_score = rename_indices(result_data_2018_FH_test_score)
result_data_2019_FH_test_score = rename_indices(result_data_2019_FH_test_score)
result_data_2020_FH_test_score = rename_indices(result_data_2020_FH_test_score)
result_data_2021_FH_test_score = rename_indices(result_data_2021_FH_test_score)


In [None]:
# Same tables for the FH_and_M results.
result_data_2018_FH_and_M_test_score = get_test_scores(result_data_2018_FH_and_M)
result_data_2019_FH_and_M_test_score = get_test_scores(result_data_2019_FH_and_M)
result_data_2020_FH_and_M_test_score = get_test_scores(result_data_2020_FH_and_M)
result_data_2021_FH_and_M_test_score = get_test_scores(result_data_2021_FH_and_M)


In [None]:
# Relabel the FH_and_M tables the same way.
result_data_2018_FH_and_M_test_score = rename_indices(result_data_2018_FH_and_M_test_score)
result_data_2019_FH_and_M_test_score = rename_indices(result_data_2019_FH_and_M_test_score)
result_data_2020_FH_and_M_test_score = rename_indices(result_data_2020_FH_and_M_test_score)
result_data_2021_FH_and_M_test_score = rename_indices(result_data_2021_FH_and_M_test_score)

In [None]:
# Join the four yearly rolling-window tables side by side, once per CV approach.
all_years_FH_rwl = concatenate_dataframes([result_data_2018_FH_rwl, result_data_2019_FH_rwl, result_data_2020_FH_rwl, result_data_2021_FH_rwl])
all_years_FH_and_M_rwl = concatenate_dataframes([result_data_2018_FH_and_M_rwl, result_data_2019_FH_and_M_rwl, result_data_2020_FH_and_M_rwl, result_data_2021_FH_and_M_rwl])

In [None]:
# Median rolling window length per month (column) and per country (row).
FH_and_M_median_by_month, FH_and_M_median_by_country = compute_medians(all_years_FH_and_M_rwl)
FH_median_by_month, FH_median_by_country = compute_medians(all_years_FH_rwl)

In [None]:
# Mean rolling window length per month (column) and per country (row).
FH_and_M_mean_by_month, FH_and_M_mean_by_country = compute_means(all_years_FH_and_M_rwl)
FH_mean_by_month, FH_mean_by_country = compute_means(all_years_FH_rwl)

In [None]:
# Plot settings shared by the heatmap cells below.
# figsize_all_countries is not used in the cells visible here.
figsize_all_countries = (12, 48)
figsize_selected_on_avg = (18, 15)
cmap = "Blues"

In [None]:
# Rolling-window heatmaps: four years for "FH", then four years for "FH and M".
# export=True makes each call save a PNG as well as show the figure.
create_heatmap_org(result_data_2018_FH_rwl, figsize_selected_on_avg, cmap, 2018, "FH", countries="selected_on_avg", export=True)
create_heatmap_org(result_data_2019_FH_rwl, figsize_selected_on_avg, cmap, 2019, "FH", countries="selected_on_avg", export=True)
create_heatmap_org(result_data_2020_FH_rwl, figsize_selected_on_avg, cmap, 2020, "FH", countries="selected_on_avg", export=True)
create_heatmap_org(result_data_2021_FH_rwl, figsize_selected_on_avg, cmap, 2021, "FH", countries="selected_on_avg", export=True)
create_heatmap_org(result_data_2018_FH_and_M_rwl, figsize_selected_on_avg, cmap, 2018, "FH and M", countries="selected_on_avg", export=True)
create_heatmap_org(result_data_2019_FH_and_M_rwl, figsize_selected_on_avg, cmap, 2019, "FH and M", countries="selected_on_avg", export=True)
create_heatmap_org(result_data_2020_FH_and_M_rwl, figsize_selected_on_avg, cmap, 2020, "FH and M", countries="selected_on_avg", export=True)
create_heatmap_org(result_data_2021_FH_and_M_rwl, figsize_selected_on_avg, cmap, 2021, "FH and M", countries="selected_on_avg", export=True)

In [None]:
# Colour map for the test-score heatmaps (same value as set above).
cmap = "Blues"

In [None]:
# Test-score heatmaps in the same order: "FH" for 2018-2021, then "FH and M".
create_heatmap_org_test_score(result_data_2018_FH_test_score, figsize_selected_on_avg, cmap, 2018, "FH", countries="selected_on_avg", export=True)
create_heatmap_org_test_score(result_data_2019_FH_test_score, figsize_selected_on_avg, cmap, 2019, "FH", countries="selected_on_avg", export=True)
create_heatmap_org_test_score(result_data_2020_FH_test_score, figsize_selected_on_avg, cmap, 2020, "FH", countries="selected_on_avg", export=True)
create_heatmap_org_test_score(result_data_2021_FH_test_score, figsize_selected_on_avg, cmap, 2021, "FH", countries="selected_on_avg", export=True)
create_heatmap_org_test_score(result_data_2018_FH_and_M_test_score, figsize_selected_on_avg, cmap, 2018, "FH and M", countries="selected_on_avg", export=True)
create_heatmap_org_test_score(result_data_2019_FH_and_M_test_score, figsize_selected_on_avg, cmap, 2019, "FH and M", countries="selected_on_avg", export=True)
create_heatmap_org_test_score(result_data_2020_FH_and_M_test_score, figsize_selected_on_avg, cmap, 2020, "FH and M", countries="selected_on_avg", export=True)
create_heatmap_org_test_score(result_data_2021_FH_and_M_test_score, figsize_selected_on_avg, cmap, 2021, "FH and M", countries="selected_on_avg", export=True)

In [None]:
# Build a PNG path in the local Plots folder from validation_approach and year.
# NOTE(review): validation_approach is not assigned anywhere in the visible part of
# the notebook, and year is only left over from an earlier loop - confirm both are
# defined before running this cell, otherwise it raises NameError.
export_path = rf"C:\Users\Uwe Drauz\Documents\bachelor_thesis_local\personal_competition_data\Plots\rolling_window_length_{validation_approach}_{year}.png"

In [None]:
# Export the plots to the path with the corresponding name


In [None]:
# Count the distinct countries that have at least one "test_score" greater than 5
# in the 2018 FH_and_M results.
result_data_2018_FH_and_M[result_data_2018_FH_and_M['test_score'] > 5]['country_id'].nunique()

In [None]:
# Reshape the long table into a month x year grid of "passengers" values and draw
# it as a heatmap. The arguments are passed by keyword because pandas 2.0 and
# later reject positional arguments to DataFrame.pivot.
# NOTE(review): this rebinds `data` and needs "month", "year" and "passengers"
# columns - confirm `data` holds such a table before running this cell.
data = data.pivot(index="month", columns="year", values="passengers")
sns.heatmap(data, cmap="rocket")
plt.show()

In [None]:
# Load the "glue" dataset through seaborn, reshape it into a Model x Task grid
# of Score values and draw it with the reversed "Blues" colour map.
glue = sns.load_dataset("glue").pivot(index="Model", columns="Task", values="Score")
sns.heatmap(glue, cmap="Blues_r")