# In [3]:
from pathlib import Path
import numpy as np
import plotly.graph_objects as go
from plotly.colors import qualitative as qpal
from datetime import timedelta
import pandas as pd

# Create folder to save figures
FIGDIR = Path("figures")
FIGDIR.mkdir(exist_ok=True)

# 1) LOAD & SPLIT THE DATASET
# Load cleaned data. StockCode is read as text so the column has a single type:
# with inferred dtypes, pandas may parse the file in chunks and end up with both
# int and str values in one column (see the read_csv `low_memory` notes), and the
# same code would then act as two different keys in StockCode groupings.
# NOTE(review): assumes the CSV mixes numeric-looking and alphanumeric codes — confirm.
df = pd.read_csv("online_retail_cleaned.csv", dtype={"StockCode": str})
# Unparseable timestamps become NaT; their Year is then NaN.
df["InvoiceDate"] = pd.to_datetime(df["InvoiceDate"], errors="coerce")
df["Year"] = df["InvoiceDate"].dt.year


def _short_label(text, n=28):
    """Return ``str(text)`` cut to at most ``n`` characters, ending in "…" when truncated."""
    text = "" if text is None else str(text)
    return text if len(text) <= n else text[: n - 1] + "…"


def _build_item_table(df, metric, year_a, year_b, max_items):
    """Return per-item ``metric`` totals for the top items of ``year_a`` and ``year_b``.

    Rows with ``Quantity <= 0`` or a missing ``InvoiceDate`` are ignored. The
    ``max_items`` best StockCodes of each year are kept (their union). The result
    has one row per selected StockCode with columns ``StockCode``, ``year_a``,
    ``year_b`` (0.0 when the item has no total that year) and ``item_label``,
    sorted by the ``year_b`` total, descending.

    Raises:
        ValueError: if ``metric`` is neither "Revenue" nor a column of ``df``,
            or if neither year has any qualifying rows.
    """
    # Boolean filtering / dropna / assign all return new frames: the caller's df is never mutated.
    data = df[df["Quantity"] > 0].dropna(subset=["InvoiceDate"])
    if "Year" not in data.columns:
        # Allow frames that carry InvoiceDate but no precomputed Year column.
        data = data.assign(Year=pd.to_datetime(data["InvoiceDate"], errors="coerce").dt.year)
    if metric == "Revenue":
        # assign() avoids SettingWithCopyWarning on the filtered slice.
        data = data.assign(Revenue=data["Quantity"] * data["Price"])
    if metric not in data.columns:
        raise ValueError(f"unknown metric {metric!r}: not 'Revenue' and not a column of df")

    in_years = data[data["Year"].isin([year_a, year_b])]
    totals = in_years.groupby(["Year", "StockCode"], as_index=False)[metric].sum()

    def top_codes(year):
        year_totals = totals[totals["Year"] == year]
        return (year_totals.sort_values(metric, ascending=False)
                .head(max_items)["StockCode"].tolist())

    # Union of both years' top items, order-preserving and de-duplicated.
    selected = list(dict.fromkeys(top_codes(year_a) + top_codes(year_b)))
    if not selected:
        raise ValueError(f"no rows with positive Quantity for {year_a} or {year_b}")

    table = (totals[totals["StockCode"].isin(selected)]
             .pivot_table(index="StockCode", columns="Year", values=metric, aggfunc="sum")
             .fillna(0.0)
             .reset_index())
    # A year with no sales among the selected items yields no pivot column.
    for year in (year_a, year_b):
        if year not in table.columns:
            table[year] = 0.0

    # Label = first description seen for the code (shortened); fall back to the code itself.
    descriptions = (data.dropna(subset=["Description"])
                        .drop_duplicates("StockCode")
                        .set_index("StockCode")["Description"])
    table["item_label"] = (table["StockCode"].map(descriptions)
                           .fillna(table["StockCode"])
                           .map(_short_label))
    return table.sort_values(year_b, ascending=False).reset_index(drop=True)


def sankey_top5_year_to_items(df, metric="Revenue", year_a=2010, year_b=2011, max_items=5, save_html_path=None):
    """Build, show and export a 3-column Sankey: Year A → Items → Year B.

    The ``max_items`` top items by ``metric`` in each year are shown, so there are
    up to ``2 * max_items`` item nodes. A link's width is the item's ``metric``
    total in that year; zero totals produce no link.

    - Each item gets a colour from a fixed 7-colour palette, applied to its node
      and its links. The palette repeats when there are more items than colours.
    - Year nodes have fixed neutral colours.

    Args:
        df: Frame with ``Quantity``, ``InvoiceDate``, ``StockCode`` and
            ``Description`` columns, plus ``Price`` when ``metric == "Revenue"``.
            A ``Year`` column is used if present, otherwise derived from ``InvoiceDate``.
        metric: "Revenue" (computed as Quantity * Price) or the name of a numeric
            column of ``df`` (e.g. "Quantity") to sum per item and year.
        year_a: Year shown as the left column.
        year_b: Year shown as the right column; items are ordered by this year's total.
        max_items: Number of top items taken from each year.
        save_html_path: Destination of the interactive HTML export. When None,
            writes ``FIGDIR / "sankey_<metric lowercased>_<year_a>_<year_b>.html"``.

    Raises:
        ValueError: for an unknown ``metric`` or when neither year has data.
    """
    table = _build_item_table(df, metric, year_a, year_b, max_items)
    n_items = len(table)

    # --- Nodes: 0 = year A, 1..n_items = items (table order), n_items + 1 = year B ---
    node_labels = [str(year_a)] + table["item_label"].tolist() + [str(year_b)]
    idx_year_a, idx_year_b = 0, n_items + 1

    # FT-style muted palette, assigned by item position (not by label) so two items
    # whose truncated descriptions coincide remain separate nodes with their own links.
    palette = ["#964977", "#4f7ca3", "#f0a6ca", "#91cf60", "#67a9cf", "#fc8d62", "#8da0cb"]
    item_colors = [palette[i % len(palette)] for i in range(n_items)]

    # --- Links: year A → item and item → year B, coloured like the item ---
    sources, targets, values, link_colors = [], [], [], []
    for pos, (v_a, v_b) in enumerate(zip(table[year_a], table[year_b])):
        item_node, color = pos + 1, item_colors[pos]
        if v_a > 0:
            sources.append(idx_year_a)
            targets.append(item_node)
            values.append(float(v_a))
            link_colors.append(color)
        if v_b > 0:
            sources.append(item_node)
            targets.append(idx_year_b)
            values.append(float(v_b))
            link_colors.append(color)

    # Year nodes neutral, item nodes share their links' colour.
    node_colors = ["#ccebc5"] + item_colors + ["#b3cde3"]

    # Fixed layout: years at the edges, items spread evenly down the middle column.
    # n_items >= 1 here (_build_item_table raises otherwise), so the lists line up.
    node_x = [0.0] + [0.5] * n_items + [1.0]
    item_y = np.linspace(0.05, 0.95, n_items).tolist() if n_items > 1 else [0.5]
    node_y = [0.5] + item_y + [0.5]

    valueformat = ".2f" if metric == "Revenue" else ".0f"
    hover_value = "%{value:.2f}" if metric == "Revenue" else "%{value:.0f}"
    hovertemplate = "%{source.label} → %{target.label}<br>" + metric + ": " + hover_value

    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        valueformat=valueformat,
        node=dict(
            label=node_labels,
            pad=14,
            thickness=11,
            color=node_colors,
            x=node_x,
            y=node_y
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=link_colors,
            hovertemplate=hovertemplate
        )
    )])

    title = f"Top-{max_items} Items by {metric}: {year_a} vs {year_b}"
    fig.update_layout(title=dict(text=title, x=0.5, y=0.95), font=dict(size=12), hoverlabel=dict(bgcolor="white"))
    fig.show()

    # HTML export: honour the caller's path, otherwise use the default under FIGDIR.
    if save_html_path is None:
        html_path = FIGDIR / f"sankey_{metric.lower()}_{year_a}_{year_b}.html"
    else:
        html_path = Path(save_html_path)
    fig.write_html(html_path)

        # Revenue comparison
sankey_top5_year_to_items(df, metric="Revenue",  year_a=2010, year_b=2011, max_items=5)

# Quantity comparison
sankey_top5_year_to_items(df, metric="Quantity", year_a=2010, year_b=2011, max_items=5)
