<a href="https://colab.research.google.com/github/sperez1989/oecd-inflation-streamlit/blob/main/ALY6110_FinalGroupProject_PySpark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================
# 0. Install and import dependencies
# ============================================
!pip install -q pyspark plotly pandas pyarrow

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.clustering import KMeans

import pandas as pd

# --------------------------------------------
# 1. Spark session
# --------------------------------------------
# getOrCreate() reuses a session that is already active in this runtime,
# so re-running the cell does not start a second one.
spark = SparkSession.builder.appName(
    "OECD_Inflation_Consumption_Clustering"
).getOrCreate()

# Bare expression kept from the notebook; Jupyter only renders it when it is
# the last statement of a cell.
spark

# --------------------------------------------
# 2. Load the CPI and COICOP CSV extracts
# --------------------------------------------

cpi_path = "/content/COICOP Datasets/Monthly Consumer Price Indices (CPI, HICP).csv"
coicop_path = "/content/COICOP Datasets/Annual Household Final Consumption Expenditure (COICOP).csv"

# Both files are read the same way: the first row is the header and column
# types are inferred from the data.
cpi_raw = spark.read.options(header=True, inferSchema=True).csv(cpi_path)
coicop_raw = spark.read.options(header=True, inferSchema=True).csv(coicop_path)

# count() forces a full read of each file.
# NOTE(review): the recorded run printed 1048575 CPI rows, which is exactly
# Excel's worksheet row limit minus the header row -- confirm the CPI CSV was
# not truncated by a spreadsheet export.
print("CPI rows:", cpi_raw.count())
print("COICOP rows:", coicop_raw.count())

# Inspect the inferred column names and types of both inputs.
cpi_raw.printSchema()
coicop_raw.printSchema()

# --------------------------------------------
# 3. Basic cleaning and renaming
# --------------------------------------------

# Same names as the imports at the top of the file (harmless re-import).
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Project the CPI columns used downstream under shorter names, then keep only
# rows with FREQ "M", measure "CPI", unit "PA" and transformation "GY".
# NOTE(review): "PA" / "GY" presumably select year-on-year percentage change;
# confirm against the "Unit of measure" / "Transformation19" label columns.
# NOTE(review): methodology and adjustment are selected but never filtered; if
# the extract holds several variants per country/category/month, the averages
# computed later mix them -- confirm.
cpi = cpi_raw.select(
    F.col("REF_AREA").alias("country"),
    F.col("EXPENDITURE14").alias("category"),
    F.col("TIME_PERIOD").alias("time_period"),
    F.col("OBS_VALUE").cast("double").alias("cpi_value"),
    F.col("FREQ"),
    F.col("METHODOLOGY8").alias("methodology"),
    F.col("MEASURE10").alias("measure"),
    F.col("UNIT_MEASURE"),
    F.col("ADJUSTMENT16").alias("adjustment"),
    F.col("TRANSFORMATION18").alias("transformation"),
).where(
    (F.col("FREQ") == "M")
    & (F.col("measure") == "CPI")
    & (F.col("UNIT_MEASURE") == "PA")
    & (F.col("transformation") == "GY")
)

# Derive the calendar year from the observation timestamp and drop everything
# before 2019.
cpi = cpi.withColumn("year", F.year("time_period").cast("int")).where(
    F.col("year") >= 2019
)

# Annual household expenditure: project the columns used downstream, keep only
# rows with FREQ "A" and years from 2019 onwards.
# NOTE(review): UNIT_MEASURE, PRICE_BASE, UNIT_MULT and currency are carried
# along but never filtered; if the extract holds several variants per
# country/category/year, later sums and lags combine them -- confirm.
coicop = coicop_raw.select(
    F.col("REF_AREA").alias("country"),
    F.col("EXPENDITURE18").alias("category"),
    F.col("TIME_PERIOD").cast("int").alias("year"),
    F.col("OBS_VALUE").cast("double").alias("exp_value"),
    F.col("FREQ"),
    F.col("UNIT_MEASURE"),
    F.col("PRICE_BASE"),
    F.col("UNIT_MULT"),
    F.col("CURRENCY42").alias("currency"),
).where((F.col("FREQ") == "A") & (F.col("year") >= 2019))

# Quick look at the first rows of both cleaned tables.
cpi.select("country", "category", "time_period", "year", "cpi_value").show(n=5)
coicop.select("country", "category", "year", "exp_value").show(n=5)

# --------------------------------------------
# 4. Collapse monthly CPI to one value per year
# --------------------------------------------

# One row per (country, category, year): the mean of the monthly cpi_value
# observations that fall in that year.
cpi_annual = cpi.groupBy("country", "category", "year").agg(
    F.mean("cpi_value").alias("cpi_annual_avg")
)

cpi_annual.show(n=5)

# ============================================
# 5. Compute expenditure shares and growth
# ============================================

# Total expenditure per country and year: the sum of exp_value over every
# COICOP row that survived cleaning.
# NOTE(review): if the extract also contains an aggregate ("total") category
# or several unit / price-base variants, this sum double counts -- confirm.
total_exp = (
    coicop
    .groupBy("country", "year")
    .agg(F.sum("exp_value").alias("total_exp"))
)

# Share of each category in the country-year total. The division is guarded
# because a zero denominator raises DIVIDE_BY_ZERO when Spark runs in ANSI
# mode; with the guard the share is NULL instead (F.when without otherwise
# yields NULL), matching the non-ANSI result.
coicop_enriched = (
    coicop
    .join(total_exp, on=["country", "year"], how="left")
    .withColumn(
        "exp_share",
        F.when(
            F.col("total_exp") != 0,
            F.col("exp_value") / F.col("total_exp"),
        ),
    )
)

# Rows of one (country, category) series in chronological order.
w = Window.partitionBy("country", "category").orderBy("year")

# Year-over-year growth relative to the previous row of the series. It is only
# computed when that previous row exists, is non-zero (same ANSI concern as
# above) and belongs to the immediately preceding year; otherwise the ratio
# would span a gap of several years and not be a year-over-year figure. In all
# other cases the growth is NULL.
coicop_enriched = (
    coicop_enriched
    .withColumn("exp_value_lag", F.lag("exp_value").over(w))
    .withColumn(
        "exp_growth_yoy",
        F.when(
            F.col("exp_value_lag").isNotNull()
            & (F.col("exp_value_lag") != 0)
            & ((F.col("year") - F.lag("year").over(w)) == 1),
            (F.col("exp_value") - F.col("exp_value_lag")) / F.col("exp_value_lag"),
        ),
    )
)

coicop_enriched.select(
    "country", "category", "year",
    "exp_value", "total_exp", "exp_share", "exp_growth_yoy"
).show(10)

# --------------------------------------------
# 6. Combine expenditure and CPI per country / category / year
# --------------------------------------------

# Inner join: only (country, category, year) combinations present in both the
# expenditure table and the annual CPI table are kept.
analytics = coicop_enriched.alias("exp").join(
    cpi_annual.alias("cpi"),
    ["country", "category", "year"],
    "inner",
)

preview = analytics.select(
    "country", "category", "year",
    "cpi_annual_avg", "exp_value", "exp_share", "exp_growth_yoy",
)
preview.show(10)

# ============================================
# 7. Create summary stats per country & category (2020–2024)
# ============================================

# Rows from 2020 onwards; reused by the export sections further down, so its
# definition is left open-ended.
analytics_recent = analytics.filter(F.col("year") >= 2020)

# Per (country, category) means of CPI, expenditure share and expenditure
# growth. The upper bound keeps the averages inside the 2020-2024 window that
# the column names advertise; without it any later year present in both inputs
# would silently be included.
country_category_summary = (
    analytics_recent
    .filter(F.col("year") <= 2024)
    .groupBy("country", "category")
    .agg(
        F.avg("cpi_annual_avg").alias("avg_cpi_2020_24"),
        F.avg("exp_share").alias("avg_share_2020_24"),
        F.avg("exp_growth_yoy").alias("avg_growth_2020_24")
    )
)

country_category_summary.show(10)

# --------------------------------------------
# 8. Country-level feature matrix for clustering
# --------------------------------------------


def _pivot_metric(metric_col, prefix):
    """Return one row per country with one column per category for a metric.

    The summary table is pivoted on "category" taking the first value of
    ``metric_col`` per cell, missing cells are filled with 0, and every
    category column is renamed to ``prefix`` + the category name so the three
    metric tables can be joined without column clashes.
    """
    wide = (
        country_category_summary
        .groupBy("country")
        .pivot("category")
        .agg(F.first(metric_col))
        .fillna(0)
    )
    for category_col in wide.columns:
        if category_col != "country":
            wide = wide.withColumnRenamed(category_col, f"{prefix}{category_col}")
    return wide


# NOTE(review): fillna(0) makes a category with no data for a country look
# like a value of exactly 0 to the clustering step -- confirm that is intended.
features_inflation = _pivot_metric("avg_cpi_2020_24", "cpi_")
features_share = _pivot_metric("avg_share_2020_24", "share_")
features_growth = _pivot_metric("avg_growth_2020_24", "growth_")

# Stitch the three wide tables together on the country key.
features = features_inflation.join(features_share, "country", "inner").join(
    features_growth, "country", "inner"
)

features.show(5)
print("Number of countries:", features.count())

# --------------------------------------------
# 9. K-Means clustering of countries
# --------------------------------------------

# Every column except the country key is a clustering feature.
feature_cols = [name for name in features.columns if name != "country"]

# Pack the feature columns into a single vector column.
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features_vec")
features_vec = assembler.transform(features)

# Standardise each feature (centre on the mean, scale to unit standard
# deviation) so that no single feature dominates the distance computation.
scaler = StandardScaler(
    inputCol="features_vec",
    outputCol="features_scaled",
    withMean=True,
    withStd=True,
)
scaler_model = scaler.fit(features_vec)
features_scaled = scaler_model.transform(features_vec)

# Fixed number of clusters and a fixed seed for the initialisation.
k = 4
kmeans = KMeans(featuresCol="features_scaled", predictionCol="cluster", k=k, seed=42)
kmeans_model = kmeans.fit(features_scaled)
clustered = kmeans_model.transform(features_scaled)

clustered.select("country", "cluster").show(50)

# --------------------------------------------
# 10. Cluster profiles
# --------------------------------------------

cluster_summary = clustered

# Profile columns are picked by prefix: CPI features and growth features.
# The "share_" features are not part of the profile.
cpi_cols = [name for name in cluster_summary.columns if name.startswith("cpi_")]
growth_cols = [name for name in cluster_summary.columns if name.startswith("growth_")]

# One mean per selected feature, named "avg_<feature>".
agg_exprs = [
    F.avg(name).alias(f"avg_{name}")
    for name in [*cpi_cols, *growth_cols]
]

cluster_profile = cluster_summary.groupBy("cluster").agg(*agg_exprs)

cluster_profile.show(truncate=False)

# --------------------------------------------
# 11. Export tables for the dashboard
# --------------------------------------------

# a) Time series: Canada against the cross-country average.
# The "oecd_*" figures are plain means over every country present in
# analytics_recent for that year and category (Canada included).
oecd_avg = analytics_recent.groupBy("year", "category").agg(
    F.avg("cpi_annual_avg").alias("oecd_cpi"),
    F.avg("exp_share").alias("oecd_exp_share"),
    F.avg("exp_growth_yoy").alias("oecd_exp_growth"),
)

canada_series = analytics_recent.filter(F.col("country") == "CAN")

# Canada's own values, renamed with a "can_" prefix, side by side with the
# averages for the same year and category.
canada_columns = canada_series.select(
    "year",
    "category",
    F.col("cpi_annual_avg").alias("can_cpi"),
    F.col("exp_share").alias("can_exp_share"),
    F.col("exp_growth_yoy").alias("can_exp_growth"),
)
canada_vs_oecd = canada_columns.join(oecd_avg, ["year", "category"], "left")

# b) Cluster assignment by country.
# Both results are collected to pandas on the driver before anything is
# written, then saved as CSV files in the working directory.
cluster_results_pd = clustered.select("country", "cluster").toPandas()
canada_vs_oecd_pd = canada_vs_oecd.toPandas()

cluster_results_pd.to_csv("cluster_results.csv", index=False)
canada_vs_oecd_pd.to_csv("canada_vs_oecd_timeseries.csv", index=False)

print("Saved: cluster_results.csv, canada_vs_oecd_timeseries.csv")

# --------------------------------------------
# 12. Cluster assignments dataframe
# --------------------------------------------

# Unique (country, cluster) pairs, used below to attach a cluster label to
# every country.
clusters_df = clustered.select("country", "cluster").distinct()

# --------------------------------------------
# 13. CPI time series - Canada + clusters
# --------------------------------------------

# Average CPI per cluster, year and category, labelled "Cluster <n>".
cluster_ts_clusters = (
    analytics_recent.join(clusters_df, "country", "inner")
    .groupBy("cluster", "year", "category")
    .agg(F.avg("cpi_annual_avg").alias("avg_cpi"))
    .withColumn(
        "group",
        F.concat(F.lit("Cluster "), F.col("cluster").cast("string")),
    )
)

# Average CPI for Canada, treated as a group of its own; its cluster column is
# a typed NULL so the schema lines up with the cluster rows.
cluster_ts_canada = (
    analytics_recent.where(F.col("country") == "CAN")
    .groupBy("year", "category")
    .agg(F.avg("cpi_annual_avg").alias("avg_cpi"))
    .withColumn("cluster", F.lit(None).cast("int"))
    .withColumn("group", F.lit("Canada"))
)

# Stack the cluster rows and the Canada rows, then sort for a stable export.
cluster_rows = cluster_ts_clusters.select(
    "group", "cluster", "year", "category", "avg_cpi"
)
canada_rows = cluster_ts_canada.select(
    "group", "cluster", "year", "category", "avg_cpi"
)
cluster_timeseries_all = cluster_rows.unionByName(canada_rows).orderBy(
    "group", "year", "category"
)

cluster_timeseries_pd = cluster_timeseries_all.toPandas()
cluster_timeseries_pd.to_csv("cluster_timeseries.csv", index=False)

# --------------------------------------------
# 14. Expenditure share & growth - Canada + clusters
# --------------------------------------------

# Averages per cluster, year and category, labelled "Cluster <n>".
cluster_exp_clusters = (
    analytics_recent.join(clusters_df, "country", "inner")
    .groupBy("cluster", "year", "category")
    .agg(
        F.avg("exp_share").alias("avg_exp_share"),
        F.avg("exp_growth_yoy").alias("avg_exp_growth"),
    )
    .withColumn(
        "group",
        F.concat(F.lit("Cluster "), F.col("cluster").cast("string")),
    )
)

# Averages for Canada, treated as a group of its own; its cluster column is a
# typed NULL so the schema lines up with the cluster rows.
cluster_exp_canada = (
    analytics_recent.where(F.col("country") == "CAN")
    .groupBy("year", "category")
    .agg(
        F.avg("exp_share").alias("avg_exp_share"),
        F.avg("exp_growth_yoy").alias("avg_exp_growth"),
    )
    .withColumn("cluster", F.lit(None).cast("int"))
    .withColumn("group", F.lit("Canada"))
)

# Stack the cluster rows and the Canada rows, then sort for a stable export.
cluster_exp_rows = cluster_exp_clusters.select(
    "group", "cluster", "year", "category", "avg_exp_share", "avg_exp_growth"
)
canada_exp_rows = cluster_exp_canada.select(
    "group", "cluster", "year", "category", "avg_exp_share", "avg_exp_growth"
)
cluster_expenditure_all = cluster_exp_rows.unionByName(canada_exp_rows).orderBy(
    "group", "year", "category"
)

cluster_expenditure_pd = cluster_expenditure_all.toPandas()
cluster_expenditure_pd.to_csv("cluster_expenditure.csv", index=False)

# --------------------------------------------
# 15. Country list per cluster
# --------------------------------------------

# Collect the assignments to pandas, de-duplicate, and order them by cluster
# and then by country code before writing.
cluster_country_list_pd = clusters_df.toPandas()
cluster_country_list_pd = cluster_country_list_pd.drop_duplicates()
cluster_country_list_pd = cluster_country_list_pd.sort_values(["cluster", "country"])

cluster_country_list_pd.to_csv("cluster_country_list.csv", index=False)

# Summary of every CSV written by sections 11-15.
print(
    "Saved: cluster_results.csv, canada_vs_oecd_timeseries.csv, "
    "cluster_timeseries.csv, cluster_expenditure.csv, cluster_country_list.csv"
)


CPI rows: 1048575
COICOP rows: 305702
root
 |-- STRUCTURE: string (nullable = true)
 |-- STRUCTURE_ID: string (nullable = true)
 |-- STRUCTURE_NAME: string (nullable = true)
 |-- ACTION: string (nullable = true)
 |-- REF_AREA: string (nullable = true)
 |-- Reference area: string (nullable = true)
 |-- FREQ: string (nullable = true)
 |-- Frequency of observation: string (nullable = true)
 |-- METHODOLOGY8: string (nullable = true)
 |-- Methodology9: string (nullable = true)
 |-- MEASURE10: string (nullable = true)
 |-- Measure11: string (nullable = true)
 |-- UNIT_MEASURE: string (nullable = true)
 |-- Unit of measure: string (nullable = true)
 |-- EXPENDITURE14: string (nullable = true)
 |-- Expenditure15: string (nullable = true)
 |-- ADJUSTMENT16: string (nullable = true)
 |-- Adjustment17: string (nullable = true)
 |-- TRANSFORMATION18: string (nullable = true)
 |-- Transformation19: string (nullable = true)
 |-- TIME_PERIOD: timestamp (nullable = true)
 |-- Time period: string (nul