In [4]:
import os
import pandas as pd
import matplotlib.pyplot as plt

In [5]:
#!/usr/bin/env python3
# recsys_imbalance_viz.py
# Visualize genre imbalance with grouped bars, long-tail, cumulative, and heatmap.

import argparse
from pathlib import Path
from typing import Optional, List, Tuple

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

# If you're on a headless server, uncomment:
# matplotlib.use("Agg")


def build_synthetic_df() -> pd.DataFrame:
    """Return the built-in demo dataset of genre shares.

    Each genre carries a raw "typical" and "desired" weight; both weight
    columns are rescaled independently so that each sums to 100.

    Returns:
        DataFrame with columns ``genre``, ``typical_pct`` and ``desired_pct``,
        one row per genre.
    """
    # (genre, typical weight, desired weight)
    rows = [
        ("Fantasy", 14, 10),
        ("Science Fiction", 13, 10),
        ("Romance", 12, 10),
        ("Mystery", 11, 9),
        ("Thriller", 10, 9),
        ("Historical", 8, 8),
        ("Adult", 7, 8),
        ("Horror", 6, 7),
        ("Children's", 5, 7),
        ("Adventure", 5, 7),
        ("Classics", 4, 7),
        ("Nonfiction", 3, 6),
        ("Drama", 2, 6),
    ]
    names, typical_raw, desired_raw = zip(*rows)

    def as_percent(weights):
        # Rescale raw weights so the column totals 100.
        arr = np.array(weights, dtype=float)
        return arr / arr.sum() * 100.0

    return pd.DataFrame({
        "genre": list(names),
        "typical_pct": as_percent(typical_raw),
        "desired_pct": as_percent(desired_raw),
    })


def load_df(csv_path: Optional[Path], normalize: bool) -> pd.DataFrame:
    """Load genre shares from a CSV, or fall back to the synthetic dataset.

    Args:
        csv_path: CSV file with ``genre``, ``typical_pct`` and ``desired_pct``
            columns. When ``None`` or when the path does not exist, the
            synthetic dataset from ``build_synthetic_df`` is used instead.
        normalize: When true, rescale ``typical_pct`` and ``desired_pct`` so
            each sums to 100. A column whose sum is not positive is left as is.

    Returns:
        DataFrame with at least the three required columns. Values read from
        a CSV that cannot be parsed as numbers are replaced by 0.0.

    Raises:
        ValueError: The CSV exists but lacks one of the required columns.

    Warns:
        UserWarning: A path was given but does not exist, so synthetic data
            is being substituted.
    """
    import warnings

    value_cols = ["typical_pct", "desired_pct"]

    if csv_path is not None and csv_path.exists():
        df = pd.read_csv(csv_path)
        missing = {"genre", *value_cols} - set(df.columns)
        if missing:
            # Sorted so the message is stable from run to run.
            raise ValueError("CSV is missing columns: {}".format(sorted(missing)))
        for col in value_cols:
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
    else:
        if csv_path is not None:
            # Keep the best-effort fallback, but make it visible: a mistyped
            # path would otherwise yield charts of demo data with no hint.
            warnings.warn(
                "CSV path {} not found; falling back to synthetic data".format(csv_path),
                stacklevel=2,
            )
        df = build_synthetic_df()

    if normalize:
        for col in value_cols:
            total = float(df[col].sum())
            if total > 0:
                df[col] = df[col] / total * 100.0
    return df


def compute_stats(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` extended with per-genre gap statistics.

    Added columns:
        gap_pct: ``desired_pct - typical_pct``.
        abs_gap_pct: absolute value of ``gap_pct``.
        gap_contribution_pct: each row's share of the summed absolute gap,
            in percent; all zeros when that sum is not positive.

    The input frame is not modified.
    """
    out = df.copy()

    out["gap_pct"] = out["desired_pct"] - out["typical_pct"]
    magnitude = out["gap_pct"].abs()
    out["abs_gap_pct"] = magnitude

    overall = float(magnitude.sum())
    if overall > 0:
        contribution = (magnitude / overall * 100.0).to_numpy()
    else:
        # No gap anywhere (or an undefined total): nothing to apportion.
        contribution = np.zeros(len(out), dtype=float)
    out["gap_contribution_pct"] = contribution
    return out


def save_grouped_bar(df: pd.DataFrame, outdir: Path) -> Path:
    """Save a side-by-side bar chart of typical vs desired share per genre.

    Args:
        df: Frame with ``genre``, ``typical_pct`` and ``desired_pct`` columns.
            One pair of bars is drawn per row, in row order.
        outdir: Directory that receives the PNG.

    Returns:
        Path of the written ``grouped_bar_typical_vs_desired.png``.
    """
    positions = np.arange(len(df))
    bar_width = 0.4
    path = outdir / "grouped_bar_typical_vs_desired.png"

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Offset each series by half a bar so the pair straddles its tick.
        ax.bar(positions - bar_width / 2, df["typical_pct"].values,
               width=bar_width, label="Typical")
        ax.bar(positions + bar_width / 2, df["desired_pct"].values,
               width=bar_width, label="Desired")
        ax.set_xticks(positions)
        ax.set_xticklabels(df["genre"].tolist(), rotation=30, ha="right")
        ax.set_ylabel("Share of Recommendations (%)")
        ax.set_title("Genre Distribution: Typical vs Desired")
        ax.legend()
        fig.tight_layout()
        fig.savefig(path, dpi=200, bbox_inches="tight")
    finally:
        # Close even when plotting or saving fails; otherwise the figure stays
        # registered with pyplot and accumulates across repeated runs.
        plt.close(fig)
    return path


def _save_line_chart(values: np.ndarray, labels: List[str], ylabel: str,
                     title: str, path: Path) -> Path:
    """Draw ``values`` as a marked line with one labelled tick per point and save it."""
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(values, marker="o")
        ax.set_xticks(np.arange(len(labels)))
        ax.set_xticklabels(labels, rotation=30, ha="right")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path, dpi=200, bbox_inches="tight")
    finally:
        # Close even on failure so the figure does not linger in pyplot's registry.
        plt.close(fig)
    return path


def save_long_tail(df: pd.DataFrame, outdir: Path) -> Tuple[Path, Path]:
    """Save the long-tail and cumulative long-tail charts of ``typical_pct``.

    Genres are ordered by ``typical_pct`` descending. The first chart plots the
    share per genre; the second plots the running total in that same order.

    Args:
        df: Frame with ``genre`` and ``typical_pct`` columns.
        outdir: Directory that receives both PNGs.

    Returns:
        ``(long_tail_path, cumulative_path)``.
    """
    ranked = df.sort_values("typical_pct", ascending=False).reset_index(drop=True)
    labels = ranked["genre"].tolist()
    shares = ranked["typical_pct"]

    longtail_path = _save_line_chart(
        shares.values,
        labels,
        "Share of Recommendations (%)",
        "Long Tail: Typical Distribution by Genre (Sorted Desc)",
        outdir / "long_tail_typical.png",
    )
    cum_path = _save_line_chart(
        shares.cumsum().values,
        labels,
        "Cumulative Share (%)",
        "Cumulative Long Tail: Typical Distribution (Sorted Desc)",
        outdir / "long_tail_typical_cumulative.png",
    )
    return longtail_path, cum_path


def save_heatmap(df: pd.DataFrame, outdir: Path) -> Path:
    """Save a three-row heatmap: typical share, desired share, and their gap.

    Args:
        df: Frame with ``genre``, ``typical_pct`` and ``desired_pct`` columns.
            ``gap_pct`` is used when present; otherwise it is derived here as
            ``desired_pct - typical_pct`` so the function also works on a frame
            that has not been through ``compute_stats``.
        outdir: Directory that receives the PNG.

    Returns:
        Path of the written ``heatmap_typical_desired_gap.png``.
    """
    if "gap_pct" in df.columns:
        gap = df["gap_pct"]
    else:
        gap = df["desired_pct"] - df["typical_pct"]

    # Rows: typical, desired, gap. Columns: genres in frame order.
    # All three rows share one colour scale.
    grid = np.vstack([
        df["typical_pct"].values,
        df["desired_pct"].values,
        gap.values,
    ])
    path = outdir / "heatmap_typical_desired_gap.png"

    fig, ax = plt.subplots(figsize=(14, 4))
    try:
        image = ax.imshow(grid, aspect="auto")
        ax.set_yticks([0, 1, 2])
        ax.set_yticklabels(["Typical %", "Desired %", "Gap (Desired - Typical)"])
        ax.set_xticks(np.arange(len(df)))
        ax.set_xticklabels(df["genre"].tolist(), rotation=30, ha="right")
        fig.colorbar(image, ax=ax, label="Percent")
        ax.set_title("Genre Coverage Heatmap")
        fig.tight_layout()
        fig.savefig(path, dpi=200, bbox_inches="tight")
    finally:
        # Close even on failure so the figure does not linger in pyplot's registry.
        plt.close(fig)
    return path


def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line options for the visualization script.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``csv`` (Path or None), ``outdir`` (Path) and
        ``normalize`` (bool). Unrecognized arguments are dropped silently.
    """
    cli = argparse.ArgumentParser(
        description="Visualize genre imbalance (Typical vs Desired).",
    )
    cli.add_argument(
        "--csv",
        type=Path,
        default=None,
        help="Path to CSV with columns: genre, typical_pct, desired_pct",
    )
    cli.add_argument(
        "--outdir",
        type=Path,
        default=Path("./recsys_viz"),
        help="Output directory for figures and stats",
    )
    cli.add_argument(
        "--normalize",
        action="store_true",
        help="Normalize typical/desired columns to 100%% each",
    )

    # Notebook/IPython kernels inject their own arguments (e.g. --f=...);
    # parse_known_args tolerates them where parse_args would exit.
    known, _extras = cli.parse_known_args(argv)
    return known


def main(argv: Optional[List[str]] = None):
    """Run the full pipeline: load data, compute stats, write CSV and figures.

    Args:
        argv: Argument list forwarded to ``parse_args``; ``None`` means use
            the process command line.

    Side effects:
        Creates the output directory if needed, writes
        ``genre_imbalance_stats.csv`` (values rounded to 3 decimals) plus four
        PNG figures into it, and prints the path of every file written.
    """
    args = parse_args(argv)
    args.outdir.mkdir(parents=True, exist_ok=True)

    df = compute_stats(load_df(args.csv, normalize=args.normalize))

    stats_path = args.outdir / "genre_imbalance_stats.csv"
    df.round(3).to_csv(stats_path, index=False)

    bar_path = save_grouped_bar(df, args.outdir)
    longtail_path, cum_path = save_long_tail(df, args.outdir)
    heat_path = save_heatmap(df, args.outdir)

    written = [
        ("Stats CSV:", stats_path),
        ("Grouped Bar:", bar_path),
        ("Long Tail:", longtail_path),
        ("Cumulative Long Tail:", cum_path),
        ("Heatmap:", heat_path),
    ]
    print("Saved:")
    for label, path in written:
        # Pad to one column wider than the longest label so every path is
        # separated from its label and the paths line up.
        print("  {:<22}{}".format(label, path))


if __name__ == "__main__":
    # Script usage (terminal): python recsys_imbalance_viz.py --csv ... --outdir ...
    # Notebook usage (Jupyter): just execute this cell. main() is called without
    # an argument list, and parse_args() relies on parse_known_args, which
    # ignores kernel-injected extras such as --f=...
    main()


Saved:
  Stats CSV:           recsys_viz/genre_imbalance_stats.csv
  Grouped Bar:         recsys_viz/grouped_bar_typical_vs_desired.png
  Long Tail:           recsys_viz/long_tail_typical.png
  Cumulative Long Tail:recsys_viz/long_tail_typical_cumulative.png
  Heatmap:             recsys_viz/heatmap_typical_desired_gap.png
