In [1]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
import umap

In [5]:
# The 19 sovereign-state members of the G20, keyed by the name used as the standard
# label throughout this notebook and mapped to ISO 3166-1 alpha-3 codes.
# The European Union (and, since 2023, the African Union) are also G20 members but are
# not countries, so they are intentionally not listed here.
G20_COUNTRIES = {
    "Argentina": "ARG",
    "Australia": "AUS",
    "Brazil": "BRA",
    "Canada": "CAN",
    "China": "CHN",
    "France": "FRA",
    "Germany": "DEU",
    "India": "IND",
    "Indonesia": "IDN",
    "Italy": "ITA",
    "Japan": "JPN",
    "Mexico": "MEX",
    "Russia": "RUS",
    "Saudi Arabia": "SAU",
    "South Africa": "ZAF",
    "Korea, Rep.": "KOR",
    "Turkey": "TUR",
    "United Kingdom": "GBR",
    "United States": "USA",
}

# Alternative spellings that might appear in the SIPRI / World Bank sources. These are
# only consulted as a fallback when a record's ISO code is not a G20 code.
# Every key must be a key of G20_COUNTRIES.
G20_ALTERNATIVE_NAMES = {
    "Korea, Rep.": ["South Korea", "Republic of Korea", "Korea, South"],
    "Russia": ["Russian Federation"],
    # World Bank switched to "Turkiye" and SIPRI to "Türkiye" after the 2022 renaming
    "Turkey": ["Türkiye", "Turkiye"],
    "United Kingdom": ["UK", "Great Britain"],
    "United States": ["USA", "US", "United States of America"],
}

# Reverse mapping from ISO to standard name for visualization
ISO_TO_NAME = {iso: name for name, iso in G20_COUNTRIES.items()}

# Cheap invariants: aliases point at real standard names, and ISO codes are unique
# (otherwise ISO_TO_NAME would silently drop a country).
assert set(G20_ALTERNATIVE_NAMES) <= set(G20_COUNTRIES), "alias key not in G20_COUNTRIES"
assert len(ISO_TO_NAME) == len(G20_COUNTRIES), "duplicate ISO code in G20_COUNTRIES"

# Primary data loading function - reads the JSON file containing merged SIPRI and World Bank data
def load_data(filepath):
    """Load the merged dataset from a JSON file.

    Parameters:
    filepath (str or os.PathLike): Path to the JSON file.

    Returns:
    The parsed JSON document. For the merged dataset this is expected to be a dict
    whose 'countries' key holds the country records.

    Raises:
    FileNotFoundError: if the file does not exist.
    json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON is UTF-8 by specification (RFC 8259). Relying on the platform default
    # encoding (e.g. cp1252 on Windows) garbles or rejects non-ASCII country names.
    with open(filepath, 'r', encoding='utf-8') as file:
        return json.load(file)

# Extracts only G20 countries from the full dataset for focused analysis
def extract_g20_countries(data):
    """Extract G20 country records, matching by ISO code first and by name as a fallback.

    Parameters:
    data (dict): Parsed dataset; country records are read from its 'countries' list
        (a missing key yields no records).

    Returns:
    list[dict]: Shallow copies of the matching records (inputs are not modified), each
        with two extra keys added:
        'std_name' - the standard G20 name (a key of G20_COUNTRIES)
        'std_ISO'  - the standard ISO code for that country
    """
    g20_iso_set = set(G20_COUNTRIES.values())

    # Map every accepted spelling to its standard name once, so the name fallback is a
    # single dict lookup per record instead of a scan over all alias lists.
    name_to_std = {name: name for name in G20_COUNTRIES}
    for std_name, alternatives in G20_ALTERNATIVE_NAMES.items():
        for alt in alternatives:
            name_to_std.setdefault(alt, std_name)

    g20_data = []
    for country in data.get('countries', []):
        # .get so a record lacking one of the keys falls through instead of raising
        iso = country.get('ISO')
        name = country.get('name')

        # First try matching by ISO code (more reliable)
        if iso in g20_iso_set:
            std_name = ISO_TO_NAME.get(iso, name)
            std_iso = iso
        # If ISO doesn't match, try matching by name as fallback
        elif name in name_to_std:
            std_name = name_to_std[name]
            # The record's own ISO may be non-standard (or absent); prefer the canonical code
            std_iso = G20_COUNTRIES.get(std_name, iso)
        else:
            continue

        country_copy = country.copy()
        country_copy['std_name'] = std_name
        # Set in both branches so every returned record has the same shape
        country_copy['std_ISO'] = std_iso
        g20_data.append(country_copy)

    return g20_data

# Transforms nested country time series data into a flat pandas DataFrame for analysis
def convert_to_dataframe(countries_data):
    """Flatten country time series into a long-format DataFrame (one row per country-year).

    Parameters:
    countries_data (list[dict]): Country records with 'name', 'ISO' and a 'time_series'
        list of per-year dicts (each containing 'year' plus indicator values). Optional
        'std_name' / 'std_ISO' keys take precedence over 'name' / 'ISO'. A record without
        'time_series' contributes no rows.

    Returns:
    DataFrame: 'country', 'ISO', 'year' first, followed by every other key found in the
        per-year dicts. An empty result still has the 'country', 'ISO' and 'year'
        columns, so downstream column filters do not raise KeyError.
    """
    id_columns = ['country', 'ISO', 'year']
    rows = []

    for country in countries_data:
        # Conditional lookups so 'name'/'ISO' are only required when no std_* key exists
        country_name = country['std_name'] if 'std_name' in country else country['name']
        country_iso = country['std_ISO'] if 'std_ISO' in country else country['ISO']

        for year_data in country.get('time_series', []):
            row = {'country': country_name, 'ISO': country_iso, 'year': year_data['year']}
            row.update(year_data)
            # Re-apply the identifiers: a per-year dict carrying its own 'country' or
            # 'ISO' key must not replace the standardised labels chosen above.
            row['country'] = country_name
            row['ISO'] = country_iso
            rows.append(row)

    if not rows:
        return pd.DataFrame(columns=id_columns)
    return pd.DataFrame(rows)


# Load the merged dataset, keep the G20 members, and flatten to one row per country-year
DATA_PATH = Path('../../data/all_data_merged.json')  # resolved against the kernel's working directory

data = load_data(DATA_PATH)
g20_data = extract_g20_countries(data)
df = convert_to_dataframe(g20_data)

# Sanity check: report any G20 member for which no rows were produced
found_iso = set(df['ISO']) if 'ISO' in df.columns else set()
missing_iso = sorted(set(G20_COUNTRIES.values()) - found_iso)
if missing_iso:
    print(f"Warning: no data found for G20 members: {', '.join(missing_iso)}")

df.head()

     country  ISO  year  military_expenditure  military_expenditure_gdp  \
0  Argentina  ARG  1949                5268.4                       NaN   
1  Argentina  ARG  1950                4176.1                   0.02914   
2  Argentina  ARG  1951                2830.8                       NaN   
3  Argentina  ARG  1952                2882.9                       NaN   
4  Argentina  ARG  1953                3759.1                       NaN   

   Rural population (% of total population)  \
0                                       NaN   
1                                       NaN   
2                                       NaN   
3                                       NaN   
4                                       NaN   

   International migrant stock (% of population)  \
0                                            NaN   
1                                            NaN   
2                                            NaN   
3                                            NaN   
4     

In [6]:
# Creates a visual comparison between two indicators with country labels
def create_scatter_plot(df, x_indicator, y_indicator, year=None, countries=None, title=None, use_iso=True):
    """
    Create a labelled scatter plot of two indicators, one point per remaining row.

    Parameters:
    df (DataFrame): DataFrame with 'ISO', 'country' and 'year' columns plus indicator columns
    x_indicator (str): Name of the indicator column for the x-axis
    y_indicator (str): Name of the indicator column for the y-axis
    year (int): Optional specific year to visualize; all years are plotted when None
    countries (list): Optional list of countries to include. These are ISO codes when
        use_iso is True and values of the 'country' column otherwise
    title (str): Optional title for the plot; defaults to a description of the indicators
    use_iso (bool): Whether points are coloured/labelled by ISO code or by country name

    Returns:
    None. A missing indicator column, or no rows left after filtering, is reported with
    print() and the function returns without plotting.
    """
    id_column = 'ISO' if use_iso else 'country'

    # Validate indicator names before filtering and say exactly which one is absent.
    # dict.fromkeys de-duplicates while keeping order (x_indicator may equal y_indicator).
    missing = [ind for ind in dict.fromkeys((x_indicator, y_indicator)) if ind not in df.columns]
    if missing:
        print(f"Error: Indicator(s) not found in data: {', '.join(missing)}")
        return

    if year is not None:
        df = df[df['year'] == year]

    if countries is not None:
        df = df[df[id_column].isin(countries)]

    # Drop rows with missing values for the selected indicators
    plot_df = df.dropna(subset=[x_indicator, y_indicator])

    if plot_df.empty:
        print("No data available for the selected indicators and filters")
        return

    fig, ax = plt.subplots(figsize=(12, 8))

    sns.scatterplot(data=plot_df, x=x_indicator, y=y_indicator, hue=id_column, s=100, ax=ax)

    # Label every point with its identifier just above the marker
    for label, x_val, y_val in zip(plot_df[id_column], plot_df[x_indicator], plot_df[y_indicator]):
        ax.annotate(label, (x_val, y_val), textcoords="offset points", xytext=(0, 5), ha='center')

    ax.set_xlabel(x_indicator)
    ax.set_ylabel(y_indicator)

    title_suffix = f" ({year})" if year is not None else ""
    ax.set_title(title or f"Relationship between {x_indicator} and {y_indicator}{title_suffix}")
    ax.grid(True, linestyle='--', alpha=0.7)

    # seaborn labels legend entries with the hue values; expand ISO codes to country names
    handles, labels = ax.get_legend_handles_labels()
    if use_iso:
        labels = [ISO_TO_NAME.get(iso, iso) for iso in labels]
    ax.legend(handles=handles, labels=labels)

    # Run layout last so the final title and legend are taken into account
    fig.tight_layout()
    plt.show()

# test the function
# NOTE(review): as captured below, this call reports missing indicator columns. The
# merged frame's columns printed earlier are snake_case (e.g. 'military_expenditure_gdp',
# whose sample values such as 0.02914 look like a fraction rather than a percent),
# not 'Military expenditure (% of GDP)'. Check df.columns for the actual GDP column name
# before relying on this plot.
create_scatter_plot(df, 'GDP (constant 2010 US$)', 'Military expenditure (% of GDP)', year=2018, title="GDP vs Military Expenditure")


Error: Indicators GDP (constant 2010 US$) or Military expenditure (% of GDP) not found in data
