# Case Study: Hierarchical Clustering for Customer Segmentation (Real Data)

**Dataset:** *Wholesale Customers Dataset* (UCI Machine Learning Repository)  
This dataset contains **annual spending amounts** across multiple product categories for customers of a wholesale distributor.

---

## Learning Objectives (Class Demo)

In this case study, we will:

1. Load and explore **real-world customer purchasing data**
2. Preprocess the data using **log transformation** and **standardization**
3. Apply **hierarchical (agglomerative) clustering** with different linkage methods
4. Visualize **dendrograms** to understand hierarchical structure
5. Select an appropriate number of clusters (\(K\)) and analyze **customer segments**
6. Interpret clusters using **original feature pairs (no PCA)** and business intuition

---

> **Note:**  
> The dataset is downloaded directly from the UCI Machine Learning Repository using a public URL, ensuring reproducibility and transparency for classroom demonstrations.
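
A direct download only works while the UCI URL stays reachable. Below is a minimal sketch of a local cache. It assumes you are allowed to write a file next to the notebook. `load_csv_cached` and the cache path are hypothetical names, not part of pandas or the dataset.

```python
from pathlib import Path

import pandas as pd


def load_csv_cached(source: str, cache_path: str) -> pd.DataFrame:
    """Read a CSV from `source` (URL or local path), keeping a copy at `cache_path`.

    Hypothetical helper: the first call fetches the file and saves it locally;
    later calls read the local copy, so the notebook can also run offline.
    """
    cache = Path(cache_path)
    if cache.exists():
        return pd.read_csv(cache)
    frame = pd.read_csv(source)
    cache.parent.mkdir(parents=True, exist_ok=True)  # create e.g. a data/ folder
    frame.to_csv(cache, index=False)
    return frame
```

In the loading cell, `df = load_csv_cached(UCI_CSV_URL, "data/wholesale_customers.csv")` would replace the plain `pd.read_csv` call.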


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.cluster.hierarchy as sch

from sklearn.preprocessing import StandardScaler
from sklearn.cluster import AgglomerativeClustering

plt.rcParams["figure.dpi"] = 120
np.random.seed(42)  # agglomerative clustering itself is deterministic; the seed only affects random steps you add


## 1) Load the dataset (UCI Wholesale Customers)

Columns:
- `Channel`, `Region`: integer-coded categorical variables (customer channel and region), not continuous quantities
- Spending categories: `Fresh`, `Milk`, `Grocery`, `Frozen`, `Detergents_Paper`, `Delicassen`
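
The integer codes in `Channel` and `Region` are easier to interpret as text labels. Below is a minimal sketch. It assumes the code meanings from the UCI dataset description: Channel 1 = Horeca (hotel/restaurant/café), 2 = Retail; Region 1 = Lisbon, 2 = Oporto, 3 = Other. `add_readable_labels` and the `_label` column names are hypothetical.

```python
import pandas as pd

# Code meanings as listed in the UCI dataset description (verify against the dataset page)
CHANNEL_LABELS = {1: "Horeca", 2: "Retail"}
REGION_LABELS = {1: "Lisbon", 2: "Oporto", 3: "Other"}


def add_readable_labels(frame: pd.DataFrame) -> pd.DataFrame:
    """Return a copy with text versions of Channel and Region (original codes kept)."""
    out = frame.copy()
    out["Channel_label"] = out["Channel"].map(CHANNEL_LABELS)
    out["Region_label"] = out["Region"].map(REGION_LABELS)
    return out


# Tiny illustrative frame (not real data)
demo = pd.DataFrame({"Channel": [1, 2, 2], "Region": [3, 1, 2]})
print(add_readable_labels(demo))
```

Applying it to `df_demo` before the cross-tabs in Section 7 makes their column headers self-explanatory.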


In [None]:
UCI_CSV_URL = r"https://archive.ics.uci.edu/ml/machine-learning-databases/00292/Wholesale%20customers%20data.csv"

df = pd.read_csv(UCI_CSV_URL)
display(df.head())
print("Shape:", df.shape)
df.describe(include="all")


## 2) Select features for clustering

For a clean customer-segmentation demo, we cluster using only the **spending categories** (continuous features).
We will **exclude** `Channel` and `Region` from clustering (we can use them later to interpret clusters).


In [None]:
spend_cols = ["Fresh", "Milk", "Grocery", "Frozen", "Detergents_Paper", "Delicassen"]
X = df[spend_cols].copy()

# Optional (recommended): log-transform to reduce skew in spending data
X_log = np.log1p(X)

# Standardize for distance-based clustering
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_log)

X_scaled.shape
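
The two preprocessing steps can be checked by hand. Below is a minimal sketch on synthetic right-skewed data, not the UCI features. It shows that `log1p` removes most of the skew and that `StandardScaler` equals subtracting the column mean and dividing by the population standard deviation (`ddof=0`). The `skewness` helper is a hypothetical illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Synthetic right-skewed "spending" (illustrative only, not the UCI data)
rng = np.random.default_rng(0)
spend = rng.lognormal(mean=8.0, sigma=1.0, size=(200, 3))


def skewness(a):
    """Sample skewness per column: mean of cubed z-scores."""
    z = (a - a.mean(axis=0)) / a.std(axis=0)
    return (z ** 3).mean(axis=0)


logged = np.log1p(spend)  # log(1 + x), safe when a category is 0
# StandardScaler by hand: subtract the column mean, divide by the population std (ddof=0)
manual = (logged - logged.mean(axis=0)) / logged.std(axis=0)
sk = StandardScaler().fit_transform(logged)

print("skew before log:", skewness(spend).round(2))
print("skew after log: ", skewness(logged).round(2))
print("matches StandardScaler:", np.allclose(manual, sk))  # True
```

Standardizing matters here because Euclidean distance would otherwise be dominated by whichever category has the largest spread.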


## 3) Dendrograms (compare linkage methods)

We compare:
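- **Single**: distance between the *closest* pair of points; prone to *chaining* (long, straggly clusters)
- **Average**: mean of all pairwise distances between the two clusters; a middle ground
- **Complete**: distance between the *farthest* pair of points; tends to form compact clusters
- **Ward**: merges the two clusters whose union gives the smallest increase in total within-cluster variance (often a strong default)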

To avoid clutter, each dendrogram shows only the top five levels of merges (`truncate_mode="level"`, `p=5`) and hides leaf labels.
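
The four linkage rules differ only in how they measure the distance between two clusters. Below is a minimal sketch that computes each rule by hand on a toy 1-D example and compares the result with the height SciPy reports for the final merge. `linkage_distance` is a hypothetical helper. For Ward, it reports the value on SciPy's scale, which is sqrt(2 × the increase in within-cluster sum of squares).

```python
import numpy as np
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import cdist


def linkage_distance(A, B, method):
    """Distance between clusters A and B (rows are points) under one linkage rule.

    Ward is reported on SciPy's scale:
    sqrt(2 * |A||B| / (|A| + |B|)) * ||centroid(A) - centroid(B)||.
    """
    D = cdist(A, B)  # all cross-cluster Euclidean distances
    if method == "single":
        return D.min()   # closest pair
    if method == "complete":
        return D.max()   # farthest pair
    if method == "average":
        return D.mean()  # mean over all cross pairs
    if method == "ward":
        na, nb = len(A), len(B)
        gap = np.linalg.norm(A.mean(axis=0) - B.mean(axis=0))
        return np.sqrt(2 * na * nb / (na + nb)) * gap
    raise ValueError(f"unknown method: {method}")


# Toy 1-D data: points 0 and 1 merge first, then {0, 1} joins {10} in the final merge
X_toy = np.array([[0.0], [1.0], [10.0]])
A, B = X_toy[:2], X_toy[2:]
for method in ["single", "average", "complete", "ward"]:
    by_hand = linkage_distance(A, B, method)
    from_scipy = sch.linkage(X_toy, method=method)[-1, 2]  # height of the last merge
    print(f"{method:>8}: by hand {by_hand:.4f}   scipy {from_scipy:.4f}")
```

On this toy data, single gives 9, average 9.5, and complete 10. The spread between those three numbers is the "middle ground" behaviour described above.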


In [None]:
linkages = ["single", "average", "complete", "ward"]

fig, axes = plt.subplots(2, 2, figsize=(12, 7))
axes = axes.ravel()

for ax, method in zip(axes, linkages):
    Z = sch.linkage(X_scaled, method=method)
    sch.dendrogram(Z, ax=ax, truncate_mode="level", p=5, no_labels=True)
    ax.set_title(f"{method.capitalize()} linkage")
    ax.set_ylabel("Distance")

plt.tight_layout()
plt.show()
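
The dendrograms are compared by eye. The cophenetic correlation adds one number per linkage: the correlation between the original pairwise distances and the heights at which each pair of points is first merged. Below is a minimal sketch using `scipy.cluster.hierarchy.cophenet`. The helper name `cophenetic_scores` is hypothetical, and the demo runs on random data rather than the UCI features.

```python
import numpy as np
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import pdist


def cophenetic_scores(X, methods=("single", "average", "complete", "ward")):
    """Cophenetic correlation of the tree built by each linkage method.

    Values near 1 mean the merge heights preserve the original pairwise
    Euclidean distances well; this measures faithfulness to distances,
    not how useful the resulting segments are.
    """
    d = pdist(X)  # condensed Euclidean distance matrix
    scores = {}
    for method in methods:
        Z = sch.linkage(X, method=method)
        c, _ = sch.cophenet(Z, d)  # returns (correlation, cophenetic distances)
        scores[method] = c
    return scores


# Random demo data (illustrative only)
rng = np.random.default_rng(1)
X_demo = rng.normal(size=(50, 4))
for m, c in cophenetic_scores(X_demo).items():
    print(f"{m:>8}: {c:.3f}")
```

`cophenetic_scores(X_scaled)` would score the four linkages on the standardized spending data. Treat the score as one input among several.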


## 4) Choose a number of clusters (K) and fit Agglomerative Clustering

For the class demo, start with `K=4` or `K=5` and discuss how changing `K` alters the segments.
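
One common heuristic for suggesting a starting `K` is to cut the dendrogram inside its largest vertical gap between consecutive merge heights. Below is a minimal sketch, checked on synthetic blobs rather than the UCI data. `k_from_largest_gap` is a hypothetical helper, and on real data the suggested `K` may not be the most useful for the business.

```python
import numpy as np
import scipy.cluster.hierarchy as sch


def k_from_largest_gap(Z):
    """Suggest K by cutting inside the largest gap between merge heights.

    Z has n - 1 rows; after merges 0..i there are n - (i + 1) clusters.
    Assumes heights never decrease (true for single, average, complete, ward).
    """
    heights = Z[:, 2]
    n = Z.shape[0] + 1
    i = int(np.argmax(np.diff(heights)))
    return n - (i + 1)


# Synthetic check: three tight, well-separated blobs (illustrative only)
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
X_blobs = np.vstack([c + 0.1 * rng.normal(size=(10, 2)) for c in centers])

Z_blobs = sch.linkage(X_blobs, method="ward")
k = k_from_largest_gap(Z_blobs)
labels_k = sch.fcluster(Z_blobs, t=k, criterion="maxclust")  # SciPy labels start at 1
print(k, np.bincount(labels_k)[1:])
```

On the standardized data, `sch.fcluster(sch.linkage(X_scaled, "ward"), t=K, criterion="maxclust")` typically yields the same partition as the scikit-learn model below. Only the label numbers differ, because SciPy labels start at 1.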


In [None]:
K = 4
linkage_method = "ward"  # try: "complete", "average", "single", "ward"

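# `metric` needs scikit-learn >= 1.2 (older versions call it `affinity`); Ward supports only Euclidean distance
model = AgglomerativeClustering(n_clusters=K, linkage=linkage_method, metric="euclidean")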
labels = model.fit_predict(X_scaled)

df_demo = df.copy()
df_demo["Cluster"] = labels

print("Cluster sizes (largest → smallest):", np.sort(np.bincount(labels))[::-1])
df_demo.head()


## 5) Cluster plots (NO PCA): use interpretable feature pairs

Because the data is 6D, we visualize clusters using **pairs of original spending features**.

Recommended pairs for a clean story:
- `Grocery` vs `Detergents_Paper` (often separates Retail-like clients from Horeca-like ones, i.e. hotels, restaurants, and cafés)
- `Milk` vs `Grocery`
- `Fresh` vs `Frozen`

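We plot the **raw spending values** (not the log-transformed, standardized features used for clustering) so the axes stay in monetary units, and we use log-scaled axes to reduce the visual effect of skew.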


In [None]:
# Pick a feature pair to visualize (no PCA)
x_col, y_col = "Grocery", "Detergents_Paper"

plt.figure(figsize=(6.2, 5.2))
plt.scatter(df_demo[x_col], df_demo[y_col], c=df_demo["Cluster"], cmap="tab10", alpha=0.85, s=28)
plt.title(f"Agglomerative Clustering ({linkage_method}, K={K})")
plt.xlabel(x_col)
plt.ylabel(y_col)

# Log scale improves visibility because spending features are skewed
plt.xscale("log")
plt.yscale("log")

plt.tight_layout()
plt.show()
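
The plot above colors points by cluster but has no legend. Below is a minimal sketch of a variant that plots each cluster separately, so matplotlib can build a legend that includes cluster sizes. It assumes all plotted values are positive, as the log axes require. `scatter_with_legend` is a hypothetical helper, and the demo frame is synthetic.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def scatter_with_legend(frame, x_col, y_col, cluster_col="Cluster", ax=None):
    """Scatter one feature pair on log axes, one color per cluster, with a legend."""
    if ax is None:
        _, ax = plt.subplots(figsize=(6.2, 5.2))
    cmap = plt.get_cmap("tab10")
    for k in sorted(frame[cluster_col].unique()):
        part = frame[frame[cluster_col] == k]
        ax.scatter(part[x_col], part[y_col], s=28, alpha=0.85,
                   color=cmap(int(k) % 10), label=f"Cluster {k} (n={len(part)})")
    ax.set_xscale("log")  # spending is right-skewed; log axes need values > 0
    ax.set_yscale("log")
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_col)
    ax.legend(title="Segment", fontsize=8)
    return ax


# Synthetic demo frame (illustrative only, not the UCI data)
rng = np.random.default_rng(0)
toy_plot = pd.DataFrame({
    "Grocery": rng.lognormal(8, 1, 60),
    "Detergents_Paper": rng.lognormal(6, 1, 60),
    "Cluster": np.repeat([0, 1, 2], 20),
})
ax = scatter_with_legend(toy_plot, "Grocery", "Detergents_Paper")
```

In the notebook, `scatter_with_legend(df_demo, x_col, y_col)` followed by `plt.show()` reproduces the plot above with a legend.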


### Optional: show multiple pairs (small multiples)

Plotting several pairs side by side shows how each segment differs across product categories, still without PCA.


In [None]:
pairs = [("Grocery", "Detergents_Paper"),
         ("Milk", "Grocery"),
         ("Fresh", "Frozen"),
         ("Frozen", "Delicassen")]

fig, axes = plt.subplots(2, 2, figsize=(11, 8))
axes = axes.ravel()

for ax, (x_col, y_col) in zip(axes, pairs):
    ax.scatter(df_demo[x_col], df_demo[y_col], c=df_demo["Cluster"], cmap="tab10", alpha=0.85, s=24)
    ax.set_title(f"{x_col} vs {y_col}")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_col)

plt.suptitle(f"Clusters visualized on feature pairs (no PCA) — {linkage_method}, K={K}", y=1.02)
plt.tight_layout()
plt.show()


## 6) Segment summary table (business interpretation)

We summarize each cluster by its **mean** and **median** spending in each category (raw scale). Because spending is right-skewed, the median is less affected by a few very large customers.


In [None]:
# Defensive: if 'Cluster' ever ends up as both an index level and a column,
# groupby raises an "ambiguous" error; dropping the index avoids that
df_demo = df_demo.reset_index(drop=True)

summary_mean = df_demo.groupby("Cluster", as_index=False)[spend_cols].mean()
summary_median = df_demo.groupby("Cluster", as_index=False)[spend_cols].median()

print("Mean spend by cluster:")
display(summary_mean)

print("Median spend by cluster:")
display(summary_median)
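
The mean and median tables above are on the raw scale. A log-scale view complements them. Below is a minimal sketch that averages `log1p` spending per cluster and maps the result back with `expm1`. This gives a geometric-mean-style "typical spend" that very large accounts pull up far less than the raw mean. `log_scale_profile` is a hypothetical helper, and the demo frame is synthetic.

```python
import numpy as np
import pandas as pd


def log_scale_profile(frame, cols, cluster_col="Cluster"):
    """Per-cluster typical spend: expm1 of the mean of log1p values."""
    logged = np.log1p(frame[cols])  # new DataFrame on the log scale
    logged[cluster_col] = frame[cluster_col].values
    return np.expm1(logged.groupby(cluster_col)[cols].mean())


# Synthetic demo: cluster 0 has one small and one large customer
toy_spend = pd.DataFrame({"Fresh": [9.0, 99.0, 4.0, 4.0], "Cluster": [0, 0, 1, 1]})
print(log_scale_profile(toy_spend, ["Fresh"]))
# cluster 0: sqrt(10 * 100) - 1 ≈ 30.62 (raw mean would be 54); cluster 1: 4.0
```

On the notebook data, `log_scale_profile(df_demo, spend_cols)` gives one such row per cluster.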


## 7) Optional: Compare clusters vs Channel/Region

Even though we excluded `Channel`/`Region` from clustering, we can use them to interpret segments.


In [None]:
# Cross-tab: how clusters relate to Channel/Region (interpretation)
ct_channel = pd.crosstab(df_demo["Cluster"], df_demo["Channel"], normalize="index")
ct_region  = pd.crosstab(df_demo["Cluster"], df_demo["Region"], normalize="index")

print("Cluster → Channel distribution (row-normalized):")
display(ct_channel)

print("Cluster → Region distribution (row-normalized):")
display(ct_region)
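
The cross-tabs can be condensed into one agreement number per external variable. Below is a minimal sketch using scikit-learn's `adjusted_rand_score`. The score is 1.0 for identical groupings regardless of label names and about 0 for chance-level agreement. `agreement_with` is a hypothetical helper.

```python
import pandas as pd
from sklearn.metrics import adjusted_rand_score


def agreement_with(frame, ref_col, cluster_col="Cluster"):
    """Adjusted Rand index between cluster labels and an external column."""
    return adjusted_rand_score(frame[ref_col], frame[cluster_col])


# Synthetic demo: same grouping, different label names
toy_labels = pd.DataFrame({"Channel": [1, 1, 2, 2], "Cluster": [3, 3, 0, 0]})
print(agreement_with(toy_labels, "Channel"))  # 1.0
```

With `K=4` clusters and only two channels, the index cannot reach 1 even if every cluster is pure in `Channel`. Read `agreement_with(df_demo, "Channel")` alongside the row-normalized table, not instead of it.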
