# Pull data from World Bank WDI and save to local CSV

In [1]:
"""
Pull data from World Bank WDI and save to local CSV

Dependencies:
    pip install wbgapi pandas
"""

import wbgapi as wb
import pandas as pd
import os

def fetch_and_save_wdi(
    indicators: list,
    countries: list = None,
    years: list = None,
    output_dir: str = "data/raw",
    output_filename: str = "wdi_data.csv"
):
    """
    Fetch World Bank WDI data and save it as a CSV file.

    Parameters:
        indicators: List of indicator codes, e.g. ["NY.GDP.PCAP.KD", "SL.EMP.TOTL.SP.ZS", "NY.GDP.MKTP.KD.ZG"]
        countries: List of country ISO-2/ISO-3 codes; None (default) requests every available economy
        years: List of years, e.g. list(range(2000, 2024)); None (default) means 1960 through the current calendar year
        output_dir: Local save directory (created if it does not exist)
        output_filename: Output file name

    Returns:
        The path of the CSV file that was written.
    """
    os.makedirs(output_dir, exist_ok=True)

    if countries is None:
        # wb.economy.list() yields one record per economy rather than bare
        # codes, so it cannot be passed as `economy=`. "all" is wbgapi's own
        # selector for every economy.
        countries = "all"

    # If years not specified, use 1960 through the current year
    if years is None:
        years = list(range(1960, pd.Timestamp.now().year + 1))

    # Fetch data. The downstream cleaning step reads this file as one row per
    # (economy, series) with one YR#### column per year; labels=True adds the
    # human-readable "Country" and "Series" name columns.
    df = wb.data.DataFrame(
        indicators,
        economy=countries,
        time=years,
        labels=True
    )

    # Turn the index levels into regular columns; "economy" becomes "country".
    # NOTE(review): with the wide layout above there appears to be no "time"
    # column, so that part of the rename looks like a no-op - confirm.
    df = df.reset_index().rename(columns={"economy": "country", "time": "year"})

    # Save as CSV (no index column); utf-8-sig writes a BOM at the start of the file
    output_path = os.path.join(output_dir, output_filename)
    df.to_csv(output_path, index=False, encoding="utf-8-sig")

    print(f"Saved WDI data to {output_path}")
    return output_path

if __name__ == "__main__":
    # Indicator codes to download
    indicators = ["NY.GDP.PCAP.KD", "SL.EMP.TOTL.SP.ZS", "NY.GDP.MKTP.KD.ZG"]

    # Economies to download
    countries = ["USA", "CHN", "IND"]

    # Years 2000 through 2023 inclusive
    first_year, last_year = 2000, 2023
    years = list(range(first_year, last_year + 1))

    # Bundle the arguments, then run the fetch-and-save step
    job = {
        "indicators": indicators,
        "countries": countries,
        "years": years,
        "output_dir": "../data/raw",
        "output_filename": "economic_dev_2000_2023.csv",
    }
    fetch_and_save_wdi(**job)

Saved WDI data to ../data/raw/economic_dev_2000_2023.csv


# Data cleaning

In [2]:
"""
Use DuckDB in Python to perform SQL cleaning/unpivoting of WDI wide table, and export long-format CSV.
Dependencies:
    pip install duckdb pandas
"""

import duckdb
import pandas as pd
import os

RAW_CSV        = "../data/raw/economic_dev_2000_2023.csv"
OUTPUT_CSV     = "../data/clean/wdi_long_clean.csv"
DB_FILE        = "../data/tmp/wdi.duckdb"

# Year columns of the wide table (YR2000 ... YR2023), built once so the SELECT
# list and the UNPIVOT list cannot drift apart. Must match the years that the
# raw CSV was downloaded with.
YEAR_COLUMNS = ", ".join(f"YR{year}" for year in range(2000, 2024))

os.makedirs(os.path.dirname(OUTPUT_CSV), exist_ok=True)
os.makedirs(os.path.dirname(DB_FILE), exist_ok=True)

# 1) Connect to DuckDB (file storage), or create in-memory database with ":memory:"
con = duckdb.connect(database=DB_FILE, read_only=False)
try:
    # 2) Read CSV into a DuckDB table raw_wdi, renaming the code/name columns
    #    to distinct identifiers.
    # NOTE(review): this query expects the raw CSV to carry both "country" and
    # "Country" (and "series"/"Series") headers, and DuckDB identifiers are
    # case-insensitive - confirm that the quoted names resolve to the label
    # columns and not to the code columns.
    con.execute(f"""
CREATE OR REPLACE TABLE raw_wdi AS
SELECT
    country        AS country_code,
    series         AS indicator_code,
    "Country"      AS country_name,
    "Series"       AS indicator_name,
    {YEAR_COLUMNS}
FROM read_csv_auto('{RAW_CSV}');
""")

    # 3) Use SQL UNPIVOT to transform wide table to long table, filter out NULLs.
    #    The unpivoted column name ('YR2005') is stripped to an integer year.
    con.execute(f"""
CREATE OR REPLACE TABLE clean_wdi AS
SELECT
    country_code,
    indicator_code,
    country_name,
    indicator_name,
    CAST(REPLACE(col, 'YR', '') AS INTEGER) AS year,
    value
FROM raw_wdi
UNPIVOT (
    value FOR col IN (
        {YEAR_COLUMNS}
    )
)
WHERE value IS NOT NULL
ORDER BY country_code, indicator_code, year;
""")

    # 4) Pull the cleaned long table into pandas for export
    df_clean = con.execute("SELECT * FROM clean_wdi").df()
finally:
    # Release the database file even if one of the statements above fails
    con.close()

# Export cleaned long table to local CSV
df_clean.to_csv(OUTPUT_CSV, index=False, encoding="utf-8-sig")
print("Clean data saved to", OUTPUT_CSV)

Clean data saved to ../data/clean/wdi_long_clean.csv


# Exploratory analysis

In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import os

# Configuration: input table and the folder that receives every output file
INPUT_CSV = "../data/clean/wdi_long_clean.csv"
FIGURES_DIR = "../figures/exploratory"
os.makedirs(FIGURES_DIR, exist_ok=True)

# Load the cleaned table; the code below uses its country_code,
# indicator_code, year and value columns
df = pd.read_csv(INPUT_CSV)

# Descriptive statistics of `value` for each (indicator, country) pair,
# rounded to two decimals
stats_path = os.path.join(FIGURES_DIR, "descriptive_stats.csv")
values_by_indicator_country = df.groupby(['indicator_code','country_code'])['value']
desc = values_by_indicator_country.describe().round(2)
desc.to_csv(stats_path)

# Mean value table: one row per country, one column per indicator
means_path = os.path.join(FIGURES_DIR, "mean_pivot.csv")
mean_by_country_indicator = df.groupby(['country_code','indicator_code'])['value'].mean()
pivot_mean = mean_by_country_indicator.unstack().round(2)
pivot_mean.to_csv(means_path)

# 1. Trend line chart: one figure per indicator, one marked line per country,
#    with the last row of each country's data labelled with its value
for indicator in df['indicator_code'].unique():
    indicator_rows = df.loc[df['indicator_code'] == indicator]
    fig, ax = plt.subplots(figsize=(10, 6))
    for country in indicator_rows['country_code'].unique():
        country_rows = indicator_rows.loc[indicator_rows['country_code'] == country]
        ax.plot(
            country_rows['year'],
            country_rows['value'],
            marker='o',
            linewidth=2,
            label=country,
        )
        # Label the final row (in file order) just above-left of its marker
        final_year = country_rows['year'].iloc[-1]
        final_value = country_rows['value'].iloc[-1]
        ax.text(final_year, final_value, f"{final_value:.1f}", va='bottom', ha='right')
    ax.set_title(f"Trend of {indicator}", fontsize=14)
    ax.set_xlabel("Year", fontsize=12)
    ax.set_ylabel("Value", fontsize=12)
    plt.setp(ax.get_xticklabels(), rotation=45)
    ax.legend(title="Country")
    ax.grid(linestyle='--', linewidth=0.5)
    fig.tight_layout()
    fig.savefig(os.path.join(FIGURES_DIR, f"trend_{indicator}.png"))
    plt.close(fig)

# 2. Bar comparison for the most recent year present in the data, with the
#    value printed on each bar. The year is taken from the data once and used
#    for both the title and the file name so the two always agree.
latest_year = df['year'].max()
latest = df[df['year'] == latest_year]
# An empty table has no maximum year; the loop below is then skipped as well
year_label = None if latest.empty else int(latest_year)
for indicator in latest['indicator_code'].unique():
    sub = latest[latest['indicator_code']==indicator]
    plt.figure(figsize=(8, 5))
    bars = plt.bar(sub['country_code'], sub['value'])
    plt.title(f"{indicator} in {year_label}", fontsize=14)
    plt.xlabel("Country", fontsize=12)
    plt.ylabel("Value", fontsize=12)
    # Annotate each bar at its tip: above positive bars, below negative ones,
    # so the label never overlaps the bar itself
    for bar in bars:
        h = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width()/2,
            h,
            f"{h:.1f}",
            ha='center',
            va='bottom' if h >= 0 else 'top',
        )
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, f"bar_{indicator}_{year_label}.png"))
    plt.close()

# 3. Correlation heatmap: one figure per country showing the pairwise
#    correlation between indicators across years, drawn with pcolormesh and
#    annotated with the coefficient in every cell
for country in df['country_code'].unique():
    sub = df[df['country_code']==country]
    # Years as rows, indicators as columns, so corr() compares indicators
    wide = sub.pivot(index='year', columns='indicator_code', values='value')
    corr = wide.corr()
    # pcolormesh draws cell (i, j) over [j, j+1] x [i, i+1], so both the tick
    # labels and the annotations belong at the cell centres (index + 0.5)
    centers = [k + 0.5 for k in range(len(corr))]
    plt.figure(figsize=(6, 5))
    # Fix the colour scale to the full correlation range so that the figures
    # for different countries are directly comparable
    mesh = plt.pcolormesh(corr.values, edgecolors='k', linewidth=0.5, vmin=-1, vmax=1)
    plt.xticks(centers, corr.columns, rotation=45)
    plt.yticks(centers, corr.index)
    plt.title(f"Indicator Correlation ({country})", fontsize=14)
    # Annotate each cell
    for i, row in enumerate(corr.values):
        for j, coefficient in enumerate(row):
            plt.text(j + 0.5, i + 0.5, f"{coefficient:.2f}", ha='center', va='center')
    plt.colorbar(mesh)
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, f"corr_{country}.png"))
    plt.close()

# 4. Histogram: distribution of each indicator's values pooled over all
#    countries and years (15 bins)
for indicator in df['indicator_code'].unique():
    indicator_values = df.loc[df['indicator_code'] == indicator, 'value']
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(indicator_values, bins=15)
    ax.set_title(f"Distribution of {indicator}", fontsize=14)
    ax.set_xlabel("Value", fontsize=12)
    ax.set_ylabel("Frequency", fontsize=12)
    # Horizontal guide lines only
    ax.grid(axis='y', linestyle='--', linewidth=0.5)
    fig.tight_layout()
    fig.savefig(os.path.join(FIGURES_DIR, f"hist_{indicator}.png"))
    plt.close(fig)

# 5. Boxplot: compare the distribution of each indicator across countries,
#    one figure per indicator with one box per country
for indicator in df['indicator_code'].unique():
    sub = df[df['indicator_code']==indicator]
    # Countries in order of first appearance; computed once and reused for
    # both the data and the tick labels so they stay aligned
    country_codes = sub['country_code'].unique()
    plt.figure(figsize=(8, 5))
    data = [sub.loc[sub['country_code']==c, 'value'] for c in country_codes]
    plt.boxplot(data)
    # boxplot places the boxes at x = 1..N by default. Labelling through
    # xticks avoids the `labels=` keyword, which Matplotlib 3.9 deprecated
    # in favour of `tick_labels=`, and works on either side of that change.
    plt.xticks(range(1, len(country_codes) + 1), country_codes)
    plt.title(f"Boxplot of {indicator} by Country", fontsize=14)
    plt.xlabel("Country", fontsize=12)
    plt.ylabel("Value", fontsize=12)
    plt.grid(axis='y', linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, f"box_{indicator}.png"))
    plt.close()

# 6. Trend smoothing: centred moving average over a window of 3 rows,
#    one figure per indicator and one line per country
for indicator in df['indicator_code'].unique():
    indicator_rows = df.loc[df['indicator_code'] == indicator]
    fig, ax = plt.subplots(figsize=(10, 6))
    for country in indicator_rows['country_code'].unique():
        is_country = indicator_rows['country_code'] == country
        values_by_year = indicator_rows.loc[is_country].set_index('year')['value']
        # The centred window leaves the first and last points undefined (NaN)
        smoothed = values_by_year.rolling(window=3, center=True).mean()
        ax.plot(smoothed.index, smoothed.values, linewidth=2, label=country)
    ax.set_title(f"3-Year Moving Avg of {indicator}", fontsize=14)
    ax.set_xlabel("Year", fontsize=12)
    ax.set_ylabel("Smoothed Value", fontsize=12)
    plt.setp(ax.get_xticklabels(), rotation=45)
    ax.legend(title="Country")
    ax.grid(linestyle='--', linewidth=0.5)
    fig.tight_layout()
    fig.savefig(os.path.join(FIGURES_DIR, f"smooth_{indicator}.png"))
    plt.close(fig)

print("Improved exploratory analysis complete; figures saved to", FIGURES_DIR)

Improved exploratory analysis complete; figures saved to ../figures/exploratory
