# SnackTrack ML — Exploratory Data Analysis

This notebook performs a comprehensive EDA across **all** data sources available to the
SnackTrack recommendation engine:

| Source | What we look at |
|--------|----------------|
| **Kaggle datasets** (Parquet) | Recipe nutrition, ingredients, ratings, daily food logs |
| **PostgreSQL database** | Live recipes, user interactions, meal logs, taste profiles, dietary preferences |

### Goals

1. Understand the **shape and quality** of every dataset
2. Visualize **nutrition distributions** and compare across sources
3. Explore **diet labels, allergens, and cuisine types**
4. Analyze **user interaction patterns** (ratings, temporal trends)
5. Examine **meal patterns** from daily food logs
6. Produce a consolidated **data quality summary**

> **Prerequisite**: Run `00_download_datasets.ipynb` first to download and convert datasets to Parquet.

In [None]:
import sys
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Silence only FutureWarning; every other warning category is still shown.
warnings.filterwarnings("ignore", category=FutureWarning)

# Allow imports from the parent directory, so the `notebooks.utils.*` imports
# below can resolve.
# NOTE(review): presumably the notebook is executed from inside `notebooks/` --
# confirm if the working directory differs.
sys.path.insert(0, "..")

from notebooks.utils.plot_helpers import setup_plot_style, SNACKTRACK_COLORS, PALETTE

# Apply the shared plot style from plot_helpers (its definition is not in this file).
setup_plot_style()

print("Environment ready.")

## 1. Load All Datasets

We load from **two** sources:

- **Database** — via `notebooks.utils.db_connect.get_connection()` and the `load_*_from_db()` helpers
  in `notebooks.utils.data_loader`. The load is wrapped in `try/except` so the notebook still works
  when no database is available.
- **Kaggle Parquet files** — via `notebooks.utils.data_loader.load_kaggle_dataset()`.

In [None]:
from notebooks.utils.data_loader import (
    load_kaggle_dataset,
    load_recipes_from_db,
    load_interactions_from_db,
    load_meal_logs_from_db,
    load_user_profiles_from_db,
    load_dietary_preferences_from_db,
)
from notebooks.utils.db_connect import get_connection

# ---- Database data (optional) ----
# Start from empty frames so later cells can test `.empty` whether or not the
# database could be reached.
db_recipes = pd.DataFrame()
db_interactions = pd.DataFrame()
db_meal_logs = pd.DataFrame()
db_user_profiles = pd.DataFrame()
db_dietary_prefs = pd.DataFrame()
db_available = False

try:
    conn = get_connection()
    try:
        db_recipes = load_recipes_from_db(conn)
        db_interactions = load_interactions_from_db(conn)
        db_meal_logs = load_meal_logs_from_db(conn)
        db_user_profiles = load_user_profiles_from_db(conn)
        db_dietary_prefs = load_dietary_preferences_from_db(conn)
    finally:
        # Always release the connection, including when a loader raises
        # part-way through; otherwise the connection would be left open.
        conn.close()
    db_available = True
    print("Database connected. Loaded:")
    print(f"  recipes:       {len(db_recipes):>7,} rows")
    print(f"  interactions:  {len(db_interactions):>7,} rows")
    print(f"  meal_logs:     {len(db_meal_logs):>7,} rows")
    print(f"  user_profiles: {len(db_user_profiles):>7,} rows")
    print(f"  dietary_prefs: {len(db_dietary_prefs):>7,} rows")
except Exception as e:
    # Deliberately broad: the database is optional, so any failure (connecting,
    # loading, or closing) is reported and the notebook continues without it.
    print(f"Database not available ({type(e).__name__}: {e}).")
    print("Continuing with Kaggle data only.")

# ---- Kaggle datasets (from Parquet) ----
# Every successfully loaded frame ends up in this dict, keyed by dataset name
# (or by the Parquet file stem for multi-file datasets).
kaggle_data: dict[str, pd.DataFrame] = {}

KAGGLE_NAMES = [
    "diet_recommendations",
    "daily_food_nutrition",
    "medical_diet",
    "epicurious",
    "recipe_ingredients",
    "global_food_nutrition",
    "recipes_64k",
    "food_recommendation",
]

# Food.com has multiple files -- try common sub-file naming
KAGGLE_MULTI = [
    "foodcom_reviews",
    "foodcom_interactions",
]


def _report_loaded(label, frame):
    """Print one aligned status line with the row and column count of *frame*."""
    print(f"  {label:<30s} {frame.shape[0]:>9,} rows  {frame.shape[1]:>3} cols")


for dataset_name in KAGGLE_NAMES:
    try:
        kaggle_data[dataset_name] = load_kaggle_dataset(dataset_name)
    except FileNotFoundError:
        print(f"  {dataset_name:<30s} NOT FOUND (run notebook 00 first)")
    else:
        _report_loaded(dataset_name, kaggle_data[dataset_name])

# For multi-file datasets, try loading each sub-file
from notebooks.utils.dataset_downloader import DATA_DIR

for dataset_name in KAGGLE_MULTI:
    # A single combined file takes priority over per-sub-file Parquet files.
    try:
        kaggle_data[dataset_name] = load_kaggle_dataset(dataset_name)
    except FileNotFoundError:
        pass
    else:
        _report_loaded(dataset_name, kaggle_data[dataset_name])
        continue

    # Fall back to sub-files: data/<name>__<stem>.parquet
    parquet_paths = sorted(DATA_DIR.glob(f"{dataset_name}__*.parquet"))
    for parquet_path in parquet_paths:
        stem_key = parquet_path.stem  # e.g. foodcom_reviews__recipes
        try:
            kaggle_data[stem_key] = pd.read_parquet(parquet_path)
            _report_loaded(stem_key, kaggle_data[stem_key])
        except Exception as exc:
            print(f"  {stem_key:<30s} LOAD ERROR: {exc}")

    if not parquet_paths:
        print(f"  {dataset_name:<30s} NOT FOUND")

print(f"\nTotal Kaggle datasets loaded: {len(kaggle_data)}")

## 2. Recipe Corpus Analysis

We combine the nutrition columns from every loaded source (database recipes and Kaggle
datasets) into one table and examine the distribution of key nutritional fields:
**calories, protein, carbs, fat, fiber, sugar, sodium**.

This helps us understand:
- Value ranges across datasets (do we need clipping/normalization?)
- Whether certain datasets are biased toward high/low calorie foods
- How well the Kaggle data matches our database recipes

In [None]:
# Collect nutrition columns from every source that has them
NUTRITION_COLS = ["calories", "protein_g", "carbs_g", "fat_g", "fiber_g", "sugar_g", "sodium_mg"]

# Mapping of common column name variants (lower-cased) to our canonical names
COL_ALIASES = {
    "protein": "protein_g",
    "protein_g": "protein_g",
    "carbohydrates": "carbs_g",
    "carbs": "carbs_g",
    "carbs_g": "carbs_g",
    "total_carbohydrate_g": "carbs_g",
    "fat": "fat_g",
    "fat_g": "fat_g",
    "total_fat_g": "fat_g",
    "fiber": "fiber_g",
    "fiber_g": "fiber_g",
    "dietary_fiber_g": "fiber_g",
    "sugar": "sugar_g",
    "sugar_g": "sugar_g",
    "sugars_g": "sugar_g",
    "sodium": "sodium_mg",
    "sodium_mg": "sodium_mg",
    "sodium_mg_": "sodium_mg",
    "calories": "calories",
    "energy_kcal": "calories",
}


def extract_nutrition(df: pd.DataFrame, source: str) -> pd.DataFrame:
    """Extract and rename nutrition columns from a dataset.

    Column labels are matched against ``COL_ALIASES`` case-insensitively and
    with surrounding whitespace ignored. When several columns map to the same
    canonical name, the first one in column order is used.

    Args:
        df: Dataset to pull nutrition columns from.
        source: Label written to the ``source`` column of the result.

    Returns:
        A frame holding one numeric column per matched canonical name plus a
        ``source`` column, or an empty ``DataFrame`` when nothing matches.
        Values that cannot be parsed as numbers become ``NaN``.
    """
    renamed = {}
    # Iterate by position so duplicate column labels still yield a Series
    # (label-based `df[col]` would return a DataFrame for duplicates).
    for position, col in enumerate(df.columns):
        # `str(col)` guards against non-string labels such as integers.
        canonical = COL_ALIASES.get(str(col).strip().lower())
        if canonical is None or canonical in renamed:
            continue
        # Coerce to numeric so text-typed columns do not break the
        # quantile/median statistics computed on the combined table.
        renamed[canonical] = pd.to_numeric(df.iloc[:, position], errors="coerce")
    if not renamed:
        return pd.DataFrame()
    out = pd.DataFrame(renamed)
    out["source"] = source
    return out


# Candidate sources in priority order: database recipes first, then every
# loaded Kaggle dataset (each may or may not contain nutrition columns).
nutrition_sources = [("db_recipes", db_recipes), *kaggle_data.items()]

# Keep only non-empty extractions. An empty frame has no "source" column, so
# letting one through would make the summary below fail when it is the only
# entry.
nutrition_frames = []
for name, df in nutrition_sources:
    nf = extract_nutrition(df, name)
    if not nf.empty:
        nutrition_frames.append(nf)

if nutrition_frames:
    all_nutrition = pd.concat(nutrition_frames, ignore_index=True)
    print(f"Combined nutrition data: {len(all_nutrition):,} rows from {all_nutrition['source'].nunique()} sources")
    print(f"Sources: {sorted(all_nutrition['source'].unique())}")
    print()
    display(all_nutrition.describe().round(1))
else:
    # Later cells test `all_nutrition.empty`, so always define it.
    all_nutrition = pd.DataFrame()
    print("No nutrition data found in any loaded dataset.")

In [None]:
if all_nutrition.empty:
    print("Skipping nutrition plots -- no data available.")
else:
    available_cols = [c for c in NUTRITION_COLS if c in all_nutrition.columns]

    # --- Histograms of each nutrition column (grid of at most 4 per row) ---
    n_cols = min(4, len(available_cols))
    n_rows = -(-len(available_cols) // n_cols)  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows))
    axes = np.atleast_2d(axes).flatten()

    for idx, (ax, col) in enumerate(zip(axes, available_cols)):
        values = all_nutrition[col].dropna()
        # Drop values above the 99th percentile for readability; the median
        # marker is still computed on the unclipped values.
        cutoff = values.quantile(0.99)
        ax.hist(values[values <= cutoff], bins=60, color=PALETTE[idx % len(PALETTE)],
                alpha=0.7, edgecolor="white")
        ax.set_title(col, fontsize=13)
        ax.set_ylabel("Count")
        ax.axvline(values.median(), color="red", linestyle="--", linewidth=1,
                   label=f"median={values.median():.0f}")
        ax.legend(fontsize=9)

    # Hide grid cells that have no column to show.
    for spare_ax in axes[len(available_cols):]:
        spare_ax.set_visible(False)

    fig.suptitle("Nutrition Distributions (clipped at 99th percentile)", fontsize=15, y=1.02)
    plt.tight_layout()
    plt.show()

    # --- Box plots by source (only meaningful with at least two sources) ---
    sources = sorted(all_nutrition["source"].unique())
    if len(sources) > 1:
        box_cols = available_cols[:4]
        fig, axes = plt.subplots(1, len(box_cols), figsize=(18, 5))
        axes = np.atleast_1d(axes)

        for ax, col in zip(axes, box_cols):
            per_source = all_nutrition[[col, "source"]].dropna()
            ceiling = per_source[col].quantile(0.95)
            per_source = per_source[per_source[col] <= ceiling]
            sns.boxplot(data=per_source, x="source", y=col, ax=ax, palette=PALETTE)
            ax.set_title(col, fontsize=12)
            ax.tick_params(axis="x", rotation=45)

        fig.suptitle("Nutrition Comparison Across Datasets (clipped at 95th pctl)", fontsize=14, y=1.02)
        plt.tight_layout()
        plt.show()

## 3. Diet Label & Allergen Distributions

The `global_food_nutrition` dataset contains allergen flags (`contains_gluten`,
`contains_dairy`, etc.) and diet/cuisine labels. We visualize these to understand
the label balance for knowledge-based filtering.

In [None]:
gfn = kaggle_data.get("global_food_nutrition", pd.DataFrame())

if gfn.empty:
    print("global_food_nutrition not loaded -- skipping diet label & allergen analysis.")
else:
    # --- Allergen flags: one horizontal bar per `contains_*` column ---
    allergen_cols = [c for c in gfn.columns if c.startswith("contains_")]
    if not allergen_cols:
        print("No allergen flag columns found.")
    else:
        allergen_counts = gfn[allergen_cols].sum().sort_values(ascending=True)

        fig, ax = plt.subplots(figsize=(10, max(4, len(allergen_cols) * 0.5)))
        allergen_counts.plot.barh(ax=ax, color=SNACKTRACK_COLORS["accent"], edgecolor="white")
        ax.set_xlabel("Number of foods")
        ax.set_title("Allergen Flag Frequency (global_food_nutrition)", fontsize=13)
        # Annotate each bar with its count, slightly right of the bar end.
        for bar_pos, total in enumerate(allergen_counts.values):
            ax.text(total + 50, bar_pos, f"{total:,}", va="center", fontsize=9)
        plt.tight_layout()
        plt.show()

    # --- Diet labels: use the first candidate column that exists ---
    diet_col = next(
        (c for c in ["diet_labels", "diet_label", "diet_type", "dietary_label"] if c in gfn.columns),
        None,
    )
    if diet_col:
        diet_counts = gfn[diet_col].value_counts().head(20)
        fig, ax = plt.subplots(figsize=(12, 5))
        diet_counts.plot.bar(ax=ax, color=SNACKTRACK_COLORS["primary"], edgecolor="white")
        ax.set_ylabel("Count")
        ax.set_title(f"Top 20 Diet Labels ({diet_col})", fontsize=13)
        ax.tick_params(axis="x", rotation=45)
        plt.tight_layout()
        plt.show()

    # --- Cuisine types: same first-match lookup ---
    cuisine_col = next(
        (c for c in ["cuisine_types", "cuisine_type", "cuisine"] if c in gfn.columns),
        None,
    )
    if cuisine_col:
        cuisine_counts = gfn[cuisine_col].value_counts().head(20)
        fig, ax = plt.subplots(figsize=(12, 5))
        cuisine_counts.plot.bar(ax=ax, color=SNACKTRACK_COLORS["secondary"], edgecolor="white")
        ax.set_ylabel("Count")
        ax.set_title(f"Top 20 Cuisine Types ({cuisine_col})", fontsize=13)
        ax.tick_params(axis="x", rotation=45)
        plt.tight_layout()
        plt.show()

## 4. User Interaction Analysis

We examine user–recipe interactions from two potential sources:

1. **Food.com interactions** (Kaggle) -- ratings, review counts, temporal activity
2. **Database `user_interactions`** -- SnackTrack-native events (`cook`, `log`, `rate`, `view`, `swap_accept`, `swap_reject`)

In [None]:
# --- Food.com interactions (Kaggle) ---
# Try several possible keys for the Food.com interaction data; the first key
# present in kaggle_data is used.
candidate_keys = (
    "foodcom_interactions",
    "foodcom_interactions__interactions_train",
    "foodcom_interactions__train",
    "foodcom_interactions__RAW_interactions",
)
matched_key = next((k for k in candidate_keys if k in kaggle_data), None)

if matched_key is None:
    foodcom_int = None
    print("Food.com interaction data not available.")
else:
    foodcom_int = kaggle_data[matched_key]
    print(f"Using Food.com interactions from: {matched_key}  ({len(foodcom_int):,} rows)")

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    rating_ax, user_ax, time_ax = axes

    # (a) Rating distribution
    rating_col = next((c for c in ["rating", "score", "review_score"] if c in foodcom_int.columns), None)
    if not rating_col:
        rating_ax.text(0.5, 0.5, "No rating column found", ha="center", va="center",
                       transform=rating_ax.transAxes)
    else:
        rating_ax.hist(foodcom_int[rating_col].dropna(), bins=20, color=SNACKTRACK_COLORS["primary"],
                       edgecolor="white", alpha=0.8)
        rating_ax.set_title("Rating Distribution", fontsize=13)
        rating_ax.set_xlabel("Rating")
        rating_ax.set_ylabel("Count")

    # (b) Interactions per user (log-scaled counts)
    user_col = next((c for c in ["user_id", "authorid", "author_id"] if c in foodcom_int.columns), None)
    if not user_col:
        user_ax.text(0.5, 0.5, "No user_id column found", ha="center", va="center",
                     transform=user_ax.transAxes)
    else:
        per_user = foodcom_int[user_col].value_counts()
        user_ax.hist(per_user.values, bins=100, color=SNACKTRACK_COLORS["accent"],
                     edgecolor="white", alpha=0.8, log=True)
        user_ax.set_title("Interactions per User (log scale)", fontsize=13)
        user_ax.set_xlabel("Number of interactions")
        user_ax.set_ylabel("Number of users")

    # (c) Temporal pattern: interactions per calendar year
    date_col = next(
        (c for c in ["date", "submitted", "created_at", "review_date"] if c in foodcom_int.columns),
        None,
    )
    if not date_col:
        time_ax.text(0.5, 0.5, "No date column found", ha="center", va="center",
                     transform=time_ax.transAxes)
    else:
        # Unparseable dates become NaT and are dropped before plotting.
        dates = pd.to_datetime(foodcom_int[date_col], errors="coerce").dropna()
        time_ax.hist(dates.dt.year, bins=30, color=SNACKTRACK_COLORS["secondary"],
                     edgecolor="white", alpha=0.8)
        time_ax.set_title("Interactions Over Time", fontsize=13)
        time_ax.set_xlabel("Year")
        time_ax.set_ylabel("Count")

    fig.suptitle("Food.com Interaction Analysis", fontsize=15, y=1.02)
    plt.tight_layout()
    plt.show()

# --- Database interactions ---
if db_interactions.empty:
    print("No database interactions available.")
else:
    print(f"\nDatabase interactions: {len(db_interactions):,} rows")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    type_ax, user_ax = axes

    # (a) How often each interaction type occurs
    type_counts = db_interactions["interaction_type"].value_counts()
    type_counts.plot.bar(ax=type_ax, color=PALETTE[:len(type_counts)], edgecolor="white")
    type_ax.set_title("Interaction Types (DB)", fontsize=13)
    type_ax.set_ylabel("Count")
    type_ax.tick_params(axis="x", rotation=45)

    # (b) Distribution of interaction counts across users
    per_user_db = db_interactions["user_id"].value_counts()
    user_ax.hist(per_user_db.values, bins=50, color=SNACKTRACK_COLORS["accent"],
                 edgecolor="white", alpha=0.8)
    user_ax.set_title("Interactions per User (DB)", fontsize=13)
    user_ax.set_xlabel("Number of interactions")
    user_ax.set_ylabel("Number of users")

    plt.tight_layout()
    plt.show()

## 5. Meal Patterns

The `daily_food_nutrition` dataset contains individual food log entries with
**meal types** (`breakfast`, `lunch`, `dinner`, `snack`) and food categories.
We analyze these patterns to inform our RNN meal-sequence model.

In [None]:
dfn = kaggle_data.get("daily_food_nutrition", pd.DataFrame())

if dfn.empty:
    print("daily_food_nutrition not loaded -- skipping meal pattern analysis.")
else:
    print(f"daily_food_nutrition: {len(dfn):,} rows, {dfn.shape[1]} columns")
    print(f"Columns: {list(dfn.columns)}\n")

    fig, axes = plt.subplots(1, 2, figsize=(16, 5))
    meal_ax, category_ax = axes

    # (a) Meal type distribution, using the first candidate column that exists
    meal_col = next((c for c in ["meal_type", "meal", "mealtype"] if c in dfn.columns), None)
    if not meal_col:
        meal_ax.text(0.5, 0.5, "No meal_type column found", ha="center", va="center",
                     transform=meal_ax.transAxes)
    else:
        meal_counts = dfn[meal_col].value_counts()
        colors = PALETTE[:len(meal_counts)]
        meal_ax.pie(meal_counts.values, labels=meal_counts.index, autopct="%1.1f%%",
                    colors=colors, startangle=90)
        meal_ax.set_title(f"Meal Type Distribution ({meal_col})", fontsize=13)

    # (b) Food category breakdown (15 most frequent values)
    cat_col = next(
        (c for c in ["food_category", "category", "food_group", "food_type"] if c in dfn.columns),
        None,
    )
    if not cat_col:
        category_ax.text(0.5, 0.5, "No food category column found", ha="center", va="center",
                         transform=category_ax.transAxes)
    else:
        cat_counts = dfn[cat_col].value_counts().head(15)
        cat_counts.plot.barh(ax=category_ax, color=SNACKTRACK_COLORS["primary"], edgecolor="white")
        category_ax.set_xlabel("Count")
        category_ax.set_title(f"Top 15 Food Categories ({cat_col})", fontsize=13)
        # Put the most frequent category at the top of the chart.
        category_ax.invert_yaxis()

    plt.tight_layout()
    plt.show()

# Also report the database meal logs, when any were loaded
db_has_meal_logs = not db_meal_logs.empty
if db_has_meal_logs:
    print(f"\nDB meal_logs: {len(db_meal_logs):,} rows")
if db_has_meal_logs and "meal_type" in db_meal_logs.columns:
    meal_type_counts = db_meal_logs["meal_type"].value_counts()
    print(meal_type_counts.to_string())

## 6. Data Quality Summary

A consolidated view of every dataset's size, completeness, and key features.

In [None]:
def _summary_row(label, frame):
    """Build one summary-table row (size, null share, first columns) for *frame*."""
    # Share of null cells over all cells; a frame with no cells reports 0.
    null_pct = (frame.isnull().sum().sum() / frame.size * 100) if frame.size > 0 else 0
    leading_columns = ", ".join(list(frame.columns)[:6])
    return {
        "Dataset": label,
        "Rows": f"{frame.shape[0]:,}",
        "Columns": frame.shape[1],
        "Null %": f"{null_pct:.1f}%",
        "Key Features": leading_columns + ("..." if frame.shape[1] > 6 else ""),
    }


# Database tables
db_tables = {
    "db: recipes": db_recipes,
    "db: interactions": db_interactions,
    "db: meal_logs": db_meal_logs,
    "db: user_profiles": db_user_profiles,
    "db: dietary_prefs": db_dietary_prefs,
}

# Empty database tables are left out; Kaggle datasets are listed in name order.
summary_rows = [_summary_row(label, frame) for label, frame in db_tables.items() if not frame.empty]
summary_rows += [_summary_row(label, frame) for label, frame in sorted(kaggle_data.items())]

if not summary_rows:
    print("No datasets were loaded. Run notebook 00 first.")
else:
    summary_df = pd.DataFrame(summary_rows)

    # Style the table: branded header row, padded cells, no index column
    header_props = [
        ("background-color", SNACKTRACK_COLORS["primary"]),
        ("color", "white"),
        ("font-weight", "bold"),
    ]
    table_styles = [
        {"selector": "th", "props": header_props},
        {"selector": "td", "props": [("padding", "6px 12px")]},
    ]
    styled = summary_df.style.set_properties(**{"text-align": "left"})
    styled = styled.set_table_styles(table_styles)
    styled = styled.hide(axis="index")
    display(styled)

print("\nEDA complete. Proceed to notebook 02 for content-based filtering analysis.")