In [4]:
import pandas as pd
import numpy as np
# import pymc as pm
# import arviz as az
# from sklearn.preprocessing import StandardScaler
import matplotlib. pyplot as plt
import warnings
import chardet
warnings.filterwarnings('ignore')

def read_csv_auto(filepath, *, sample_size=100_000, fallback_encoding='latin1', **kwargs):
    """Read a CSV file, guessing its text encoding with chardet first.

    The first ``sample_size`` bytes are passed to ``chardet.detect`` and the
    guessed encoding is used for ``pd.read_csv``. If decoding with that guess
    fails (a sample can look like ASCII while later bytes are not), or the
    guessed codec name is unknown to Python, the file is re-read with
    ``fallback_encoding``.

    Parameters
    ----------
    filepath : str or path-like
        Path of the CSV file.
    sample_size : int, keyword-only
        Number of leading bytes used for detection (default 100000).
    fallback_encoding : str, keyword-only
        Encoding used when the detected one fails (default 'latin1', which can
        decode any byte sequence).
    **kwargs
        Forwarded unchanged to ``pd.read_csv``.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    Any error from ``pd.read_csv`` other than a decoding / unknown-codec error
    (e.g. a parse error or an invalid keyword) propagates directly instead of
    triggering a retry with the fallback encoding.
    """
    with open(filepath, 'rb') as f:
        result = chardet.detect(f.read(sample_size))
    encoding = result['encoding']
    # Some detector versions may report None; treat it as zero confidence.
    confidence = result.get('confidence') or 0.0

    print(f"Detected: {encoding} (confidence: {confidence:.2%})")

    try:
        return pd.read_csv(filepath, encoding=encoding, **kwargs)
    except (UnicodeDecodeError, LookupError):
        # Only a decoding failure or an unknown codec justifies a retry;
        # everything else is a genuine error and must not be masked.
        print(f"Decoding with {encoding} failed; falling back to {fallback_encoding}")
        return pd.read_csv(filepath, encoding=fallback_encoding, **kwargs)

# =============================================================================
# Configuration
# =============================================================================

# Commodity group to forecast (matched against the 'LUTO' column in step 2).
GROUP = "Fruits (nec)"

_RULE = '=' * 70
print(f"\n{_RULE}")
print(f"预测商品: {GROUP}")
print(f"{_RULE}\n")

# =============================================================================
# Step 1: load the input tables
# =============================================================================

print("=== 步骤 1: 读取数据 ===\n")

# Input files; relative paths resolve against the notebook's working directory.
TRADE_CSV = "../2_processed_data/tradeF_final.csv"                     # pre-processed trade data
SSP2_CSV = "../0_original_data/data.ssp2.csv"                          # SSP2 covariates
GDP_PC_CSV = "../0_original_data/gdp_pc_ppp_2005_wdi.csv"              # GDP per capita
GRAVITY_CSV = "../0_original_data/dynamic gravity data AU od do.csv"   # gravity data

tradeF = read_csv_auto(TRADE_CSV)
ssp2 = read_csv_auto(SSP2_CSV)
gdp_pc = read_csv_auto(GDP_PC_CSV)
gravity = read_csv_auto(GRAVITY_CSV)

print("✓ 数据读取完成")

# =============================================================================
# Step 2: extract this commodity's export data
# =============================================================================

print("\n=== 步骤 2: 提取出口数据 ===\n")

# Inclusive study window for the trade panel.
YEAR_MIN, YEAR_MAX = 1990, 2014

# Export-quantity rows for the selected commodity, excluding Australia as a
# partner (presumably Australia is the reporting country — confirm).
exports = tradeF[
    (tradeF['Element'] == "Export Quantity")
    & (tradeF['LUTO'] == GROUP)
    & (tradeF['Partner Countries'] != "Australia")
].copy()

# Year columns are named either '1990' or, R-style, 'X1990'.
year_cols = [
    col for col in exports.columns
    if str(col).isdigit() or (str(col).startswith('X') and str(col)[1:].isdigit())
]

# Wide -> long: one row per (partner, factor, year).
exports_long = exports.melt(
    id_vars=['ISO3 Code', 'Factor'],
    value_vars=year_cols,
    var_name='year',
    value_name='exports',
)

# Strip the optional 'X' prefix before casting; int('X1990') would raise.
exports_long['year'] = exports_long['year'].astype(str).str.lstrip('X').astype(int)

# One row per (year, partner ISO3).
# NOTE(review): groupby().sum() treats missing export values as 0, so a
# country-year with no reported value ends up with trade == 0 — confirm intended.
trade_data = (
    exports_long
    .groupby(['year', 'ISO3 Code'], as_index=False)['exports'].sum()
    .rename(columns={'ISO3 Code': 'ISO', 'exports': 'trade'})
)

trade_data = trade_data[trade_data['year'].between(YEAR_MIN, YEAR_MAX)]

print(f"出口数据: {trade_data.shape}")
print(f"国家数: {trade_data['ISO'].nunique()}")

# =============================================================================
# Step 3: merge covariates onto the trade panel
# =============================================================================

print("\n=== 步骤 3: 合并协变量 ===\n")

# GDP per capita: wide (one column per year) -> long (one row per country-year).
# Year columns are selected by name rather than by position (`columns[2:]`), so
# an extra non-year column (e.g. a trailing unnamed one) cannot break the cast.
gdp_year_cols = [
    col for col in gdp_pc.columns
    if str(col).isdigit() or (str(col).startswith('X') and str(col)[1:].isdigit())
]
assert gdp_year_cols, "no year columns found in the GDP per capita table"
gdp_pc_long = gdp_pc.melt(
    id_vars=['Country Code'],
    value_vars=gdp_year_cols,
    var_name='year',
    value_name='gdp_pc',
)
gdp_pc_long['year'] = gdp_pc_long['year'].astype(str).str.lstrip('X').astype(int)

# Gravity covariates, restricted to origin 'AUS'.
gravity_au = gravity.loc[
    gravity['iso3_o'] == 'AUS',
    ['iso3_d', 'year', 'distance', 'lat_d', 'lng_d'],
].copy()

# Left-join each covariate table on (ISO3, year). A left join never drops trade
# rows, but a duplicated key on the right would silently duplicate them, so the
# row count is checked after every merge.
n_trade_rows = len(trade_data)

data = trade_data.merge(
    gravity_au, left_on=['ISO', 'year'], right_on=['iso3_d', 'year'], how='left'
)
assert len(data) == n_trade_rows, "gravity merge duplicated rows: (iso3_d, year) not unique"

data = data.merge(
    gdp_pc_long, left_on=['ISO', 'year'], right_on=['Country Code', 'year'], how='left'
)
assert len(data) == n_trade_rows, "GDP merge duplicated rows: (Country Code, year) not unique"

data = data.merge(
    ssp2[['country.code', 'year', 'Population.WB', 'Urban.population.pct.WB']],
    left_on=['ISO', 'year'],
    right_on=['country.code', 'year'],
    how='left',
)
assert len(data) == n_trade_rows, "SSP2 merge duplicated rows: (country.code, year) not unique"

print(f"合并后: {data.shape}")

# Drop rows missing the response or an economic/demographic covariate.
# NOTE(review): the gravity columns (distance, lat_d, lng_d) are not in the
# subset, so rows without a gravity match are kept with NaN — confirm intended.
# NOTE(review): an earlier displayed `data` output shows gdp_pc == 0.0 for some
# countries (e.g. AFG, YEM); zeros survive dropna and yield gdp_ppp == 0 —
# check whether 0 encodes "missing" in the GDP file.
data = data.dropna(subset=['trade', 'gdp_pc', 'Population.WB', 'Urban.population.pct.WB'])

print(f"删除缺失后: {data.shape}")

# =============================================================================
# Step 4: derived variables
# =============================================================================

print("\n=== 步骤 4: 计算变量 ===\n")

# Total GDP (PPP) = GDP per capita x population, plus its square.
data['gdp_ppp'] = data['gdp_pc'] * data['Population.WB']
data['gdp_ppp_sqrd'] = data['gdp_ppp'] ** 2

# log(1 + trade) keeps zero-trade rows finite; np.log1p is the numerically
# accurate form of np.log(x + 1).
data['log_trade'] = np.log1p(data['trade'])

print("✓ 变量计算完成")

# =============================================================================
# Step 5: standardisation
# =============================================================================

print("\n=== 步骤 5: 标准化 ===\n")

# Fitted scalers keyed by column name (filled by the loop below once enabled).
scalers = {}
# Columns to standardise; names must match the columns of `data` exactly.
# ('Urban.population.pct.WB' previously had a stray space and would KeyError.)
vars_to_scale = ['gdp_ppp', 'gdp_ppp_sqrd', 'Population.WB', 'Urban.population.pct.WB']

missing_cols = [col for col in vars_to_scale if col not in data.columns]
assert not missing_cols, f"columns to scale not found in data: {missing_cols}"

# NOTE: the steps below are disabled; they need the StandardScaler / pymc / arviz
# imports that are commented out in the import block.
# for var in vars_to_scale:
#     scaler = StandardScaler()
#     data[f'{var}_norm'] = scaler.fit_transform(data[[var]])
#     scalers[var] = scaler
# 
# print("✓ 标准化完成")
# 
# # =============================================================================
# # 第 6 步：创建索引
# # =============================================================================
# 
# print("\n=== 步骤 6: 创建索引 ===\n")
# 
# countries = sorted(data['ISO'].unique())
# country_map = {c: i for i, c in enumerate(countries)}
# data['country_idx'] = data['ISO'].map(country_map)
# 
# print(f"国家数: {len(countries)}")
# 
# # =============================================================================
# # 第 7 步：贝叶斯模型
# # =============================================================================
# 
# print("\n=== 步骤 7: 贝叶斯模型 ===\n")

预测商品: Fruits (nec)

=== 步骤 1: 读取数据 ===
Detected: ISO-8859-1 (confidence: 72.97%)
Detected: ascii (confidence: 100.00%)
Detected: Windows-1252 (confidence: 73.00%)
Detected: ascii (confidence: 100.00%)
✓ 数据读取完成

=== 步骤 2: 提取出口数据 ===
出口数据: (5675, 3)
国家数: 227

=== 步骤 3: 合并协变量 ===

合并后: (5675, 12)
删除缺失后: (5050, 12)

=== 步骤 4: 计算变量 ===

✓ 变量计算完成

=== 步骤 5: 标准化 ===


In [3]:
# Display the merged modelling frame built in the preceding cell.
# NOTE(review): this cell's execution count (In [3]) is lower than the preceding
# cell's (In [4]), and its last index labels (~4273) are below the 5050 rows that
# cell now reports, so the output shown appears to be from an earlier run.
# Restart and run all cells top to bottom before relying on it.
data

Unnamed: 0,year,ISO,trade,iso3_d,distance,lat_d,lng_d,Country Code,gdp_pc,country.code,Population.WB,Urban.population.pct.WB,gdp_ppp,gdp_ppp_sqrd,log_trade
0,1990,AFG,0.0,AFG,11041.0,34.649338,67.113739,AFG,0.000000,AFG,12412308.0,21.177,0.000000e+00,0.000000e+00,0.000000
1,1990,AGO,25.0,AGO,12940.0,-11.704597,16.376490,AGO,5362.583904,AGO,11848386.0,37.144,6.353796e+10,4.037073e+21,3.258097
2,1990,ALB,0.0,ALB,15154.0,41.094173,20.045958,ALB,2958.437076,ALB,3286542.0,36.428,9.723028e+09,9.453727e+19,0.000000
4,1990,ARE,1601.0,ARE,11502.0,25.094378,55.454674,ARE,136664.243500,ARE,1828432.0,79.051,2.498813e+11,6.244065e+22,7.379008
5,1990,ARG,0.0,ARG,12074.0,-34.092197,-63.942871,ARG,10299.163910,ARG,32618651.0,86.984,3.359448e+11,1.128589e+23,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4269,2014,VUT,305.0,VUT,2947.0,-16.622950,167.746600,VUT,3390.384458,VUT,263888.0,24.861,8.946818e+08,8.004555e+17,5.723585
4271,2014,WSM,6.0,WSM,4862.0,-13.841540,-171.738700,WSM,5655.987338,WSM,192221.0,19.149,1.087200e+09,1.182003e+18,1.945910
4272,2014,YEM,0.0,YEM,11905.0,14.599030,45.970100,YEM,0.000000,YEM,25823485.0,34.165,0.000000e+00,0.000000e+00,0.000000
4273,2014,ZAF,3.0,ZAF,10517.0,-29.420470,25.905810,ZAF,14516.851390,ZAF,54545991.0,64.312,7.918360e+11,6.270043e+23,1.386294
