In [24]:
# Preliminaries
import os
import numpy as np
import pandas as pd
import pyarrow.feather as feather
# If you want pyarrow as the backend for parquet (recommended)
# pip install pyarrow
# and optionally:
# pd.set_option("io.parquet.engine", "pyarrow")

#---------------------------------------
# Years
#---------------------------------------
years = list(range(2016, 2023))
base_path = os.path.join("..") 
#---------------------------------------
# Paths (adjust as needed)
#---------------------------------------
path_comtrade = os.path.join(base_path, "raw", "comtrade")
path_wits     = os.path.join(base_path, "raw", "wits")
path_oecd_tc  = os.path.join(base_path, "raw", "oecd")
path_out      = os.path.join(base_path, "clean", "flows")

In [25]:
#---------------------------------------
# Constants / helpers
#---------------------------------------
toosmall_territories = [
    "Åland Islands ","American Samoa","Antarctica","Bonaire","Bouvet Island","Br. Antarctic Terr.",
    "Br. Indian Ocean Terr.","Br. Virgin Isds","Christmas Isds","Cocos Isds","Cook Isds","Curaçao",
    "Europe EU, nes","Faeroe Isds","Falkland Isds (Malvinas)","Fr. South Antarctic Terr.","FS Micronesia",
    "Heard Island and McDonald Islands","French Polynesia","Gibraltar","Guam","Guernsey",
    "Heard Island and McDonald Islands","Holy See (Vatican City State)","Isle of Man ","Jersey","Libya",
    "Liechtenstein ","Marshall Isds","Martinique (Overseas France)","Metropolitan France","Montserrat",
    "Neutral Zone","New Caledonia","N. Mariana Isds","Norfolk Isds",
    "Norway, excluding Svalbard and Jan Mayen","Niue","Pitcairn","Puerto Rico ",
    "Saint Barthélemy","Saint Helena","Saint Maarten","Saint Martin (French part) ",
    "Saint Pierre and Miquelon","South Georgia and the South Sandwich Islands",
    "Svalbard and Jan Mayen Islands ","Switzerland ","Taiwan, Province of China","Tokelau",
    "Turks and Caicos Isds","United States Minor Outlying Islands","US Misc. Pacific Isds",
    "United States of America","Wallis and Futuna Isds","Western Sahara"
]

eu_countries = [
    "Austria","Belgium","Bulgaria","Croatia","Cyprus","Czechia","Denmark","Estonia",
    "Finland","France","Germany","Greece","Hungary","Ireland","Italy","Latvia","Lithuania",
    "Luxembourg","Malta","Netherlands","Poland","Portugal","Romania",
    "Slovakia","Slovenia","Spain","Sweden"
]

# Name normalizations
rename_map = {
    "USA":                              "United States",
    "Russian Federation":               "Russia",
    "United Rep. of Tanzania":          "Tanzania",
    "Rep. of Korea":                    "South Korea",
    "China, Hong Kong SAR":             "Hong Kong",
    "China, Macao SAR":                 "Macau",
    "Bolivia (Plurinational State of)": "Bolivia",
    "Viet Nam":                         "Vietnam",
    "Brunei Darussalam":                "Brunei",
    "Lao People's Dem. Rep.":           "Laos",
    "Dem. People's Rep. of Korea":      "North Korea",
    "Dem. Rep. of the Congo":           "DR Congo",
    "Côte d'Ivoire":                    "Ivory Coast",
    "Bosnia Herzegovina":               "Bosnia and Herzegovina",
    "Rep. of Moldova":                  "Moldova",
    "Dominican Rep.":                   "Dominican Republic",
    "Cayman Isds":                      "Cayman Islands",
    "Faroe Isds":                       "Faroe Islands",
    "Solomon Isds":                     "Solomon Islands",
    "Cabo Verde":                       "Cape Verde",
    "Timor-Leste":                      "East Timor",
    "State of Palestine":               "Palestine",
    "Central African Rep.":             "Central African Republic",
    "Saint Kitts and Nevis":            "St. Kitts and Nevis",
    "Saint Vincent and the Grenadines": "St. Vincent and the Grenadines",
    "Saint Lucia":                      "St. Lucia"
}

rename_tariffs = {
    "Korea, Rep.":                 "South Korea",
    "Slovak Republic":             "Slovakia",
    "Czech Republic":              "Czechia",
    "Russian Federation":          "Russia",
    "Congo, Dem. Rep.":            "DR Congo",
    "Congo, Rep.":                 "Congo",
    "Iran, Islamic Rep.":          "Iran",
    "Lao PDR":                     "Laos",
    "Kyrgyz Republic":             "Kyrgyzstan",
    "Cote d'Ivoire":               "Ivory Coast",
    "Hong Kong, China":            "Hong Kong",
    "Macao":                       "Macau",
    "Egypt, Arab Rep.":            "Egypt",
    "Bahamas, The":                "Bahamas",
    "Gambia, The":                 "Gambia",
    "Serbia, FR(Serbia/Montenegro)": "Serbia",
    "Syrian Arab Republic":        "Syria",
    "Ethiopia(excludes Eritrea)":  "Ethiopia",
    "Turkey":                      "Türkiye",
    "Korea, Dem. Rep.":            "North Korea"
}

rename_tc = {
    "Myanmar (Burma)":          "Myanmar",
    "Bosnia & Herzegovina":     "Bosnia and Herzegovina",
    "Trinidad & Tobago":        "Trinidad and Tobago",
    "Congo - Kinshasa":         "DR Congo",
    "Congo - Brazzaville":      "Congo",
    "Turkey":                   "Türkiye",
    "Timor-Leste":              "East Timor",
    "St. Vincent & Grenadines": "St. Vincent and the Grenadines",
    "St. Kitts & Nevis":        "St. Kitts and Nevis",
    "Hong Kong SAR China":      "Hong Kong",
    "Antigua & Barbuda":        "Antigua and Barbuda"
}

def normalize_with_map(series: pd.Series, mapping: dict) -> pd.Series:
  """
  Mimics the R normalize_with_map: only replace when key is in mapping.
  """
  return series.replace(mapping)

#---------------------------------------
# Load transport costs once & normalize
#---------------------------------------
tc_feather_path = os.path.join(path_oecd_tc, "transport_costs.feather")

tc_table = feather.read_table(tc_feather_path)
tc = tc_table.to_pandas()

# Match your R code logic
tc = tc.rename(columns={"hs": "hs4"})
tc["exporter"] = normalize_with_map(tc["exporter"], rename_tc)
tc["importer"] = normalize_with_map(tc["importer"], rename_tc)

In [26]:
#---------------------------------------
# MAIN LOOP — Creates flows_final_{year}
#---------------------------------------
flows_final_by_year = {}  # dict instead of globals

for yy in years:
    # 1) Flows for year yy
    flows_y = feather.read_table(
        os.path.join(path_comtrade, f"bulktrade_hs6_{yy}.feather")
    )

    flows_y = (
        flows_y[["exporter", "importer", "period", "hs6", "hs6_desc",
                 "primary_value", "kg"]]
        .rename(columns={"period": "year"})
    )

    # Convert year to numeric
    flows_y["year"] = pd.to_numeric(flows_y["year"], errors="coerce")

    # Filter
    flows_y = flows_y[
        (~flows_y["exporter"].isin(toosmall_territories)) &
        (~flows_y["importer"].isin(toosmall_territories)) &
        (flows_y["kg"].notna()) &
        (flows_y["kg"] > 0) &
        (flows_y["primary_value"].notna())
    ]

    # Normalize country names
    flows_y["exporter"] = normalize_with_map(flows_y["exporter"], rename_map)
    flows_y["importer"] = normalize_with_map(flows_y["importer"], rename_map)

    # 2) Tariffs for year yy
    tariffs_y = feather.read_table(
        os.path.join(path_wits, f"tariffs_hs6_{yy}.feather")
    )
    tariffs_y = tariffs_y[["exporter", "importer", "year", "hs", "tau"]]

    tariffs_y["exporter"] = normalize_with_map(tariffs_y["exporter"], rename_tariffs)
    tariffs_y["importer"] = normalize_with_map(tariffs_y["importer"], rename_tariffs)

    # EU zero-tariffs logic
    # EU-EU -> 0
    mask_eu_eu = tariffs_y["exporter"].isin(eu_countries) & tariffs_y["importer"].isin(eu_countries)
    tariffs_y.loc[mask_eu_eu, "tau"] = 0

    # UK-EU pre-2021 -> 0
    mask_uk_eu = (
        ((tariffs_y["exporter"] == "United Kingdom") & tariffs_y["importer"].isin(eu_countries)) |
        ((tariffs_y["importer"] == "United Kingdom") & tariffs_y["exporter"].isin(eu_countries))
    ) & (tariffs_y["year"] < 2021)
    tariffs_y.loc[mask_uk_eu, "tau"] = 0

    # EU pooling for missing importer-EU tariffs
    eu_pool_y = (
        tariffs_y[tariffs_y["importer"].isin(eu_countries)]
        .groupby(["exporter", "year", "hs"], as_index=False)
        .agg(tau_eu=("tau", "mean"))
    )
    # merge & use tau_eu to fill missing tau when importer is EU
    tariffs_y = tariffs_y.merge(eu_pool_y, on=["exporter", "year", "hs"], how="left")
    mask_eu_missing = tariffs_y["importer"].isin(eu_countries) & tariffs_y["tau"].isna()
    tariffs_y.loc[mask_eu_missing, "tau"] = tariffs_y.loc[mask_eu_missing, "tau_eu"]
    tariffs_y = tariffs_y.drop(columns=["tau_eu"])

    # distinct(importer, exporter, year, hs, .keep_all = TRUE)
    tariffs_y = tariffs_y.drop_duplicates(
        subset=["importer", "exporter", "year", "hs"], keep="first"
    )

    # 3) Merge flows + tariffs + transport costs (year yy only)
    flows3_y = flows_y.merge(
        tariffs_y,
        left_on=["exporter", "importer", "year", "hs6"],
        right_on=["exporter", "importer", "year", "hs"],
        how="left"
    )
    flows3_y["hs4"] = flows3_y["hs6"].str[:4]

    # merge with transport costs for that year
    tc_yy = tc[tc["year"] == yy]
    flows_final_y = flows3_y.merge(
        tc_yy,
        on=["importer", "exporter", "year", "hs4"],
        how="left",
        suffixes=("", "_tc")
    ).drop(columns=["hs4", "hs"])

    flows_final_by_year[yy] = flows_final_y

TypeError: Index must either be string or integer

In [None]:
#---------------------------------------
# Bind all years
#---------------------------------------
final_all = pd.concat(flows_final_by_year.values(), ignore_index=True)

# Keep only rows with non-missing freight
final_all = final_all[final_all["freight"].notna()].copy()

# freight as proportion initially
final_all["freight"] = final_all["freight"] / 100.0

# fob for most importers
mask_non_special = ~final_all["importer"].isin(["Canada", "Bermuda", "South Africa"])
final_all["fob"] = final_all["primary_value"]
final_all.loc[mask_non_special, "fob"] = (
    final_all.loc[mask_non_special, "primary_value"]
    - final_all.loc[mask_non_special, "primary_value"] * final_all.loc[mask_non_special, "freight"]
)

# For Canada, Bermuda, South Africa: reconstruct cost then freight
final_all["cost"] = final_all["primary_value"]
mask_special = final_all["importer"].isin(["Canada", "Bermuda", "South Africa"])
final_all.loc[mask_special, "cost"] = (
    final_all.loc[mask_special, "primary_value"]
    / (1 - final_all.loc[mask_special, "freight"])
)
final_all["freight"] = final_all["cost"] - final_all["fob"]

# Drop cost and primary_value
final_all = final_all.drop(columns=["cost", "primary_value"])

In [None]:
#---------------------------------------
# Add all the exp/imp totals via groupby + transform
#---------------------------------------
# exporter, year, hs6
final_all["exp_prodvalue_toworld"] = (
    final_all.groupby(["exporter", "year", "hs6"])["fob"].transform("sum")
)
final_all["exp_prodkg_toworld"] = (
    final_all.groupby(["exporter", "year", "hs6"])["kg"].transform("sum")
)

# importer, year, hs6
final_all["imp_prodvalue_fromworld"] = (
    final_all.groupby(["importer", "year", "hs6"])["fob"].transform("sum")
)
final_all["imp_prodkg_fromworld"] = (
    final_all.groupby(["importer", "year", "hs6"])["kg"].transform("sum")
)

# exporter, year
final_all["exp_totalvalue_toworld"] = (
    final_all.groupby(["exporter", "year"])["fob"].transform("sum")
)
final_all["exp_totalkg_toworld"] = (
    final_all.groupby(["exporter", "year"])["kg"].transform("sum")
)

# importer, year
final_all["imp_totalvalue_fromworld"] = (
    final_all.groupby(["importer", "year"])["fob"].transform("sum")
)
final_all["imp_totalkg_fromworld"] = (
    final_all.groupby(["importer", "year"])["kg"].transform("sum")
)

# exporter, importer, year
final_all["exp_totalvalue_toimp"] = (
    final_all.groupby(["exporter", "importer", "year"])["fob"].transform("sum")
)
final_all["exp_totalkg_toimp"] = (
    final_all.groupby(["exporter", "importer", "year"])["kg"].transform("sum")
)

In [None]:
#---------------------------------------
# Fill tau over time (down then up) by exporter-importer-hs6
#---------------------------------------
final_all = final_all.sort_values(["exporter", "importer", "hs6", "year"])

group_keys = ["exporter", "importer", "hs6"]

# Pass 1: down (ffill)
final_all["tau_down_work"] = (
    final_all.groupby(group_keys)["tau"].ffill()
)
final_all["tau_after_down"] = np.where(
    final_all["tau"].isna(), final_all["tau_down_work"], final_all["tau"]
)

# Pass 2: up (bfill where still missing)
final_all["tau_up_work"] = (
    final_all.groupby(group_keys)["tau_after_down"].bfill()
)
final_all["tau_final"] = np.where(
    final_all["tau_after_down"].isna(),
    final_all["tau_up_work"],
    final_all["tau_after_down"]
)

# Replace tau, drop helpers
final_all["tau"] = final_all["tau_final"]
final_all = final_all.drop(columns=["tau_down_work", "tau_up_work", "tau_after_down", "tau_final"])

# tau: missing -> 0, then scale /100 to get ad valorem
final_all["tau"] = final_all["tau"].fillna(0.0)
final_all["tau"] = final_all["tau"] / 100.0

In [None]:
#---------------------------------------
# Keep relevant columns & order (mirrors your select())
#---------------------------------------
desired_cols = [
    "exporter", "importer", "year", "hs6", "hs6_desc",
    # R did some reordering using positions; here we explicitly choose
    "kg", "tau", "freight", "fob",
    "exp_prodvalue_toworld", "exp_prodkg_toworld",
    "imp_prodvalue_fromworld", "imp_prodkg_fromworld",
    "exp_totalvalue_toworld", "exp_totalkg_toworld",
    "imp_totalvalue_fromworld", "imp_totalkg_fromworld",
    "exp_totalvalue_toimp", "exp_totalkg_toimp"
]

# keep only those that actually exist (just in case tc had extra cols)
existing_desired_cols = [c for c in desired_cols if c in final_all.columns]
final_all2 = final_all[existing_desired_cols].copy()

#---------------------------------------
# Split back into flows_final_{year} and write parquet
#---------------------------------------
for yy in years:
    flows_final_y = final_all2[final_all2["year"] == yy].copy()
    flows_final_by_year[yy] = flows_final_y  # store in dict if needed

    out_path = os.path.join(path_out, f"allflows_forCCAs_{yy}.parquet")
    flows_final_y.to_parquet(out_path, index=False)
    print(f"Wrote {out_path}")