# In [None]:
#-------------------------
# import packages
#-------------------------

import requests
import pycountry
import pandas as pd

#-------------------------
# retrieve IMF trade data
#-------------------------

# ISO alpha-2 code for every country listed by pycountry
country_codes = [entry.alpha_2 for entry in pycountry.countries]

# root of the IMF SDMX JSON REST service; request paths are appended to it
base_url = 'http://dataservices.imf.org/REST/SDMX_JSON.svc/'

def build_dots_url(base, reporter, indicator, freq, startPeriod, endPeriod):
    """
    Return the CompactData/DOT request URL for one reporter country.

    The partner dimension is omitted from the key so that all partners are
    returned. The query string must not contain spaces: 'startPeriod = 1980'
    would be percent-encoded into a different parameter name and value.
    """
    key = f'CompactData/DOT/{freq}.{reporter}.{indicator}'
    return f'{base}{key}?startPeriod={startPeriod}&endPeriod={endPeriod}'


def extract_dots_series(payload):
    """
    Return payload['CompactData']['DataSet']['Series'] from a decoded response.

    Raises KeyError when a level is missing (e.g. no 'Series' element) and
    TypeError when a level is not a mapping (e.g. null).
    """
    return payload['CompactData']['DataSet']['Series']


def get_imf_dots_data(reporter, indicator = 'TXG_FOB_USD', freq = 'Q', startPeriod = '1980', endPeriod = '2024'):
    """
    Retrieve IMF DOTS data for a given reporter country.

    reporter: ISO alpha-2 code for the reporter country (e.g., 'GB', 'US', 'CN', etc.)
    indicator: IMF indicator code ('TMG_CIF_USD' for imports, 'TXG_FOB_USD' for exports)
    freq: frequency ('A' for annual, 'Q' for quarterly, 'M' for monthly)
    startPeriod: starting period (e.g., '2000')
    endPeriod: ending period (e.g., '2023')

    Returns the 'Series' element of the JSON response. It may be a single
    dict or a list of dicts. Returns None, after printing a message, when the
    request fails, the HTTP status is not 200, or the payload cannot be
    decoded or lacks that element.
    """
    full_url = build_dots_url(base_url, reporter, indicator, freq, startPeriod, endPeriod)
    print(f"requesting url: {full_url}")
    # include a timeout and accept header for json
    try:
        response = requests.get(full_url, headers = {"Accept": "application/json"}, timeout = 30)
    except requests.exceptions.RequestException as e:
        print(f"network error for {reporter}: {e}")
        return None
    if response.status_code != 200:
        print(f"error fetching data for {reporter}: {response.status_code}")
        return None
    try:
        # .json() raises a ValueError subclass on a non-JSON body; the
        # navigation raises KeyError/TypeError when the structure differs
        return extract_dots_series(response.json())
    except (KeyError, TypeError, ValueError) as e:
        print(f"error parsing data for {reporter}: {e}")
        return None

# reporter country codes to query -- narrow this list to fetch a subset
reporters = country_codes

# reporter code -> 'Series' payload for every request that returned data
data_dict = {}

for code in reporters:
    print(f"\nFetching data for {code}...")
    fetched = get_imf_dots_data(code)
    if fetched is None:
        print(f"No data returned for {code}.")
        continue
    data_dict[code] = fetched

#-------------------------
# format data as table and save
#-------------------------

# list to store all observation records
records = []

# output columns, passed explicitly so that an empty result still produces a
# frame with these columns (otherwise df['time_period'] below raises KeyError)
record_columns = ['reporter', 'counterpart_area', 'time_period', 'value']


def _as_list(node):
    """Return a list unchanged, wrap a dict as [dict], and map anything else (e.g. None) to []."""
    if isinstance(node, list):
        return node
    if isinstance(node, dict):
        return [node]
    return []


# both the 'Series' node and each series' 'Obs' node may be a single dict or a
# list of dicts, so normalise both to lists before flattening
for rep, series in data_dict.items():
    for s in _as_list(series):
        # counterpart_area is taken from the series level, not per observation
        cp = s.get('@COUNTERPART_AREA')
        for o in _as_list(s.get('Obs')):
            records.append({
                'reporter': rep,
                'counterpart_area': cp,
                'time_period': o.get('@TIME_PERIOD'),
                'value': o.get('@OBS_VALUE')
            })

# convert the list of records into a dataframe
df = pd.DataFrame(records, columns = record_columns)

# convert time_period to datetime; unparseable values become NaT
df['time_period'] = pd.to_datetime(df['time_period'], errors = 'coerce')

# save the dataframe to a CSV file
df.to_csv("../../data/raw/imf_exports_quarterly.csv", index = False)