In [None]:
import ast
import optuna
import pandas as pd
import numpy as np
import random

from src.models.affinity import AffinityBuilder
from src.models.greedy import GreedyLayout
from src.models.forbidden_pairs import get_forbidden_pairs
from src.models.ga_optimizer import GeneticLayoutOptimizer  # bản GA đã mix fitness
from src.plots import LayoutVisualizer
from src.config import PROCESSED_DATA_DIR, INTERIM_DATA_DIR


# ----- 1. DATA LOADER -----
class DataLoader:
    """Load the CSV inputs of the layout optimisation and derive lookup data.

    Attributes:
        assoc_rules: association-rules table; ``antecedent`` and ``consequent``
            hold Python-literal collections of category names.
        freq_itemsets: frequent-itemsets table; ``items`` holds Python-literal
            collections of category names.
        layout_real: store layout with exact duplicate rows removed; must have
            ``Category``, ``x`` and ``y`` columns. Missing flag/size columns
            are added and filled with 0.
        margin_matrix: optional margin table (first CSV column is the index),
            or ``None`` when no path was given.
        all_items: sorted list of every category found in the rules, the
            itemsets or the layout.
        real_categories: set of the non-null ``Category`` values of the layout.
        positions: ``(x, y)`` tuples, one per layout row, in file order.
        hard_rules: rules derived from the layout; currently only
            ``must_group_refrigerated``.
    """

    def __init__(
        self,
        assoc_rules_path,
        freq_itemsets_path,
        layout_real_path,
        margin_matrix_path=None,
    ):
        """Read all input files and build the derived attributes.

        Args:
            assoc_rules_path: path of the association-rules CSV.
            freq_itemsets_path: path of the frequent-itemsets CSV.
            layout_real_path: path of the real layout CSV.
            margin_matrix_path: optional path of the margin-matrix CSV.

        Raises:
            ValueError: if the layout has no ``x``/``y`` columns.
        """
        self.assoc_rules = pd.read_csv(assoc_rules_path)
        self.freq_itemsets = pd.read_csv(freq_itemsets_path)
        self.layout_real = pd.read_csv(layout_real_path).drop_duplicates(keep="first")
        self.margin_matrix = (
            pd.read_csv(margin_matrix_path, index_col=0)
            if margin_matrix_path is not None
            else None
        )
        self._process()

    def _process(self):
        """Derive ``all_items``, ``real_categories``, ``positions`` and ``hard_rules``."""
        # Parse the item collections safely (literal_eval, never eval).
        antecedents = self.assoc_rules["antecedent"].apply(ast.literal_eval)
        consequents = self.assoc_rules["consequent"].apply(ast.literal_eval)
        itemsets = self.freq_itemsets["items"].apply(ast.literal_eval)

        all_items = set()
        for column in (antecedents, consequents, itemsets):
            for parsed in column:
                all_items.update(parsed)

        # Also include every category that appears in the layout file.
        layout_cats = self.layout_real["Category"].dropna().astype(str).tolist()
        all_items.update(layout_cats)
        self.all_items = sorted(all_items)

        # Categories actually present in the real layout.
        self.real_categories = set(self.layout_real["Category"].dropna().unique())

        # Slot positions come from the x, y columns.
        if {"x", "y"}.issubset(self.layout_real.columns):
            self.positions = list(zip(self.layout_real["x"], self.layout_real["y"]))
        else:
            raise ValueError("layout_real.csv không có cột x,y")

        # Missing flag and size columns default to 0.
        for col in ["is_refrigerated", "is_entrance", "is_cashier", "width", "height"]:
            if col not in self.layout_real.columns:
                self.layout_real[col] = 0

        # Hard rule: refrigerated categories must be grouped together.
        # dropna() keeps a missing Category from turning into the literal
        # string "nan"/"None"; drop_duplicates() keeps a category that spans
        # several rows from being listed more than once.
        is_cold = self.layout_real["is_refrigerated"] == 1
        refrig_cats = (
            self.layout_real.loc[is_cold, "Category"]
            .dropna()
            .astype(str)
            .drop_duplicates()
            .tolist()
        )
        self.hard_rules = {}
        if refrig_cats:
            self.hard_rules["must_group_refrigerated"] = refrig_cats


# ----- 2. HARD RULE ENGINE -----
class HardRuleEngine:
    """Check hard placement rules against a layout given as an ordered sequence.

    Recognised rule keys (all optional):
        must_together: iterable of ``(a, b)`` pairs that must be adjacent.
        must_order: iterable of ``(a, b)`` pairs where ``a`` must come before ``b``.
        must_group_refrigerated: iterable of categories that must occupy one
            contiguous run of positions.
    """

    def __init__(self, rules_dict):
        # A missing or empty rule set means "no constraints".
        self.rules = rules_dict or {}

    @staticmethod
    def _first_positions(layout):
        """Map each category to the index of its first occurrence in ``layout``."""
        positions = {}
        for idx, cat in enumerate(layout):
            positions.setdefault(cat, idx)
        return positions

    def check_must_together(self, layout):
        """Return True when every ``must_together`` pair is present and adjacent."""
        pos = self._first_positions(layout)
        return all(
            a in pos and b in pos and abs(pos[a] - pos[b]) == 1
            for a, b in self.rules.get("must_together", [])
        )

    def check_must_order(self, layout):
        """Return True when, for every ``must_order`` pair, ``a`` precedes ``b``."""
        pos = self._first_positions(layout)
        return all(
            a in pos and b in pos and pos[a] < pos[b]
            for a, b in self.rules.get("must_order", [])
        )

    def check_group_refrigerated(self, layout):
        """Return True when the refrigerated categories found in ``layout`` are contiguous.

        Categories named by the rule but absent from ``layout`` are ignored;
        the rule holds trivially when none of them is present.
        """
        refrig_items = self.rules.get("must_group_refrigerated", [])
        if not refrig_items:
            return True
        pos = self._first_positions(layout)
        # A set of positions: a category listed twice in the rule must not be
        # counted twice, otherwise a contiguous group would be reported broken.
        idxs = {pos[cat] for cat in refrig_items if cat in pos}
        if not idxs:
            return True
        return max(idxs) - min(idxs) + 1 == len(idxs)

    def check_all(self, layout):
        """Run every check and return a ``{rule_name: bool}`` dict."""
        return {
            "must_together": self.check_must_together(layout),
            "must_order": self.check_must_order(layout),
            "must_group_refrigerated": self.check_group_refrigerated(layout),
        }


# ----- 3. PIPELINE (mixed fitness: association + entrance) -----
class LayoutOptimizationPipeline:
    def __init__(
        self,
        data,
        n_trials=30,
        n_gen_final=80,
        use_optuna=True,
        output_path=None,
        ga_selection="tournament",
        ga_crossover="PMX",
        ga_mutation="shuffle",
        ga_adaptive=True,
        ga_ensemble_runs=1,
        hard_rules=None,
        seed=42,
    ):
        """Store the run configuration and prepare the rule engine and affinity builder.

        Args:
            data: loaded inputs; must expose ``positions``, ``real_categories``,
                ``all_items``, ``assoc_rules``, ``freq_itemsets`` and
                ``margin_matrix`` (and optionally a ``hard_rules`` dict).
            n_trials: number of Optuna trials used by ``tune()``.
            n_gen_final: GA generations used by ``run_final()``.
            use_optuna: stored as given; no method visible in this file reads it.
            output_path: CSV written by ``run_final()``; a falsy value selects
                ``PROCESSED_DATA_DIR / "layout_real_mapped.csv"``.
            ga_selection, ga_crossover, ga_mutation, ga_adaptive: forwarded to
                the GA optimizer.
            ga_ensemble_runs: values above 1 switch the GA to ensemble mode.
            hard_rules: extra rules; they override same-named rules from ``data``.
            seed: seed applied to the global NumPy and ``random`` generators and
                later to the Optuna sampler.
        """
        # Make the run reproducible: seed both global RNGs.
        self.seed = seed
        np.random.seed(seed)
        random.seed(seed)

        # Plain configuration.
        self.data = data
        self.n_trials = n_trials
        self.n_gen_final = n_gen_final
        self.use_optuna = use_optuna
        if output_path:
            self.output_path = output_path
        else:
            self.output_path = PROCESSED_DATA_DIR / "layout_real_mapped.csv"
        self.ga_selection = ga_selection
        self.ga_crossover = ga_crossover
        self.ga_mutation = ga_mutation
        self.ga_adaptive = ga_adaptive
        self.ga_ensemble_runs = ga_ensemble_runs

        # Shortcuts to the loader's slot positions and category lists.
        self.positions = data.positions
        self.real_categories = data.real_categories
        self.all_items = data.all_items

        # Merge the loader's rules with the caller's (caller wins on key
        # clashes), then drop references to unknown categories.
        combined_rules = {}
        for source in (getattr(data, "hard_rules", None), hard_rules):
            if isinstance(source, dict):
                combined_rules.update(source)
        sanitized = self._sanitize_rules(combined_rules, self.all_items)
        self.hard_rule_engine = HardRuleEngine(sanitized)

        # Builder for the category-to-category affinity matrix.
        self.affinity_builder = AffinityBuilder(
            data.assoc_rules,
            data.freq_itemsets,
            data.all_items,
            data.margin_matrix,
        )

    @staticmethod
    def _sanitize_rules(rules, all_items):
        """Loại bỏ các Category không tồn tại trong all_items để tránh .index() lỗi."""
        if not rules:
            return {}
        known = set(all_items)
        out = {}
        for k, v in rules.items():
            if k in ("must_together", "must_order"):
                out[k] = [(a, b) for (a, b) in v if a in known and b in known]
            elif k == "must_group_refrigerated":
                out[k] = [cat for cat in v if cat in known]
            else:
                out[k] = v
        return out

    def _get_real_layout_as_seed(self):
        # Sắp xếp layout thực tế theo tọa độ (y trước, x sau) và lọc theo all_items
        layout_real = (
            self.data.layout_real.sort_values(by=["y", "x"])["Category"]
            .astype(str)
            .tolist()
        )
        known = set(self.all_items)
        return [c for c in layout_real if c in known]

    def _filter_layout(self, best_layout):
        return best_layout[: len(self.positions)]

    # ---------- helpers cho mixed fitness ----------
    def _build_coords_and_entrance(self, override_entr_xy=None):
        """
        Trả về:
        - coords: danh sách (cx, cy) của các slot theo thứ tự sort (y, x)
        - entr_xy: tâm của Entrance (từ flag/tên; nếu không có thì fallback heuristic)
        """
        # slots theo thứ tự (y,x)
        slots = self.data.layout_real.sort_values(["y", "x"]).reset_index(drop=True)
        # đảm bảo có width/height
        if "width" not in slots.columns:
            slots["width"] = 0
        if "height" not in slots.columns:
            slots["height"] = 0
        coords = list(
            zip(slots["x"] + slots["width"] * 0.5, slots["y"] + slots["height"] * 0.5)
        )

        # nếu user truyền sẵn toạ độ entrance thì dùng luôn
        if override_entr_xy is not None:
            return coords, tuple(override_entr_xy)

        df_lr = self.data.layout_real.copy()

        # 1) ưu tiên theo cờ is_entrance
        if (
            "is_entrance" in df_lr.columns
            and df_lr["is_entrance"].fillna(0).astype(int).any()
        ):
            row = (
                df_lr.loc[df_lr["is_entrance"].fillna(0).astype(int) == 1]
                .sort_values(["y", "x"])
                .iloc[0]
            )
            ex = float(row["x"]) + float(row.get("width", 0)) * 0.5
            ey = float(row["y"]) + float(row.get("height", 0)) * 0.5
            return coords, (ex, ey)

        # 2) theo tên Category (nhiều biến thể)
        name_col = df_lr["Category"].astype(str).str.lower()
        name_hits = name_col.isin({"entrance", "entry", "door", "cửa vào", "lối vào"})
        if name_hits.any():
            row = df_lr.loc[name_hits].sort_values(["y", "x"]).iloc[0]
            ex = float(row["x"]) + float(row.get("width", 0)) * 0.5
            ey = float(row["y"]) + float(row.get("height", 0)) * 0.5
            return coords, (ex, ey)

        # 3) fallback heuristic: chọn điểm trên-cùng-trái
        row = df_lr.sort_values(["y", "x"]).iloc[0]
        ex = float(row["x"]) + float(row.get("width", 0)) * 0.5
        ey = float(row["y"]) + float(row.get("height", 0)) * 0.5
        return coords, (ex, ey)

    def _build_cat_support(self):
        # support lớn nhất của mỗi category qua các itemset chứa nó
        cat_support = {c: 0.0 for c in self.all_items}
        if (
            "items" in self.data.freq_itemsets.columns
            and "support" in self.data.freq_itemsets.columns
        ):
            for _, r in self.data.freq_itemsets.iterrows():
                try:
                    items = ast.literal_eval(r["items"])
                except Exception:
                    continue
                sup = float(r["support"])
                for it in items:
                    if it in cat_support:
                        cat_support[it] = max(cat_support[it], sup)
        return cat_support

    def _build_pairs_list(self, affinity, threshold: float):
        # lấy các cặp có affinity >= threshold
        pairs = []
        for a in self.all_items:
            for b in self.all_items:
                if a >= b:
                    continue
                w = float(affinity.loc[a, b])
                if w >= threshold:
                    pairs.append((a, b, w))
        return pairs

    # ---------- Optuna objective ----------
    def objective(self, trial):
        """Optuna objective: evaluate one sampled hyper-parameter set.

        Samples the affinity, GA and mixed-fitness parameters from ``trial``,
        builds the affinity matrix, creates a greedy starting layout, runs the
        GA for a fixed 30 generations (ensemble mode when
        ``self.ga_ensemble_runs > 1``) and returns the GA's best fitness.

        The best layout (truncated by ``_filter_layout``) and the GA logbook
        are stored on the trial as user attributes ``best_layout`` and
        ``logbook`` so ``tune()`` can read them back.

        Args:
            trial: the Optuna trial to draw parameters from.

        Returns:
            The best fitness reported by the GA for this parameter set.
        """
        # NOTE: the order of the suggest_* calls below is part of the sampled
        # sequence; keep it stable.
        lift_threshold = trial.suggest_float("lift_threshold", 0.0, 2.0)
        penalty = trial.suggest_int("penalty", 10, 150)
        pop_size = trial.suggest_int("pop_size", 100, 500, step=50)
        greedy_ratio = trial.suggest_float("greedy_ratio", 0.05, 0.5)
        w_lift = trial.suggest_float("w_lift", 0.1, 1.0)
        w_conf = trial.suggest_float("w_conf", 0.0, 1.0)
        w_margin = trial.suggest_float("w_margin", 0.0, 1.0)
        gamma = trial.suggest_float("gamma", 0.5, 4.0)

        # mixed-fitness weights
        w_aff = trial.suggest_float("w_aff", 0.5, 2.0)
        w_pair = trial.suggest_float("w_pair", 0.0, 2.0)
        w_entr = trial.suggest_float("w_entr", 0.0, 2.0)
        gamma_support = trial.suggest_float("gamma_support", 0.5, 1.5)
        pair_threshold = trial.suggest_float("pair_threshold", 0.5, 0.9)

        # Rescale the three affinity weights so they sum to 1. The sum is
        # positive because w_lift is sampled from [0.1, 1.0]. Only these local
        # values are rescaled; the trial keeps the raw sampled numbers.
        weight_sum = w_lift + w_conf + w_margin
        w_lift, w_conf, w_margin = (
            w_lift / weight_sum,
            w_conf / weight_sum,
            w_margin / weight_sum,
        )

        # Affinity matrix: build, then normalize, then kernelize with gamma.
        affinity = self.affinity_builder.build_affinity(
            lift_threshold=lift_threshold,
            w_lift=w_lift,
            w_conf=w_conf,
            w_margin=w_margin,
        )
        affinity = self.affinity_builder.normalize(affinity)
        affinity = self.affinity_builder.kernelize(affinity, gamma=gamma)

        # Greedy layout refined by local search; handed to the GA below.
        greedy_module = GreedyLayout(self.all_items)
        layout_greedy = greedy_module.init_layout(affinity)
        layout_greedy = greedy_module.local_search(layout_greedy, affinity)

        forbidden_pairs = get_forbidden_pairs(affinity, self.all_items)

        # --- mixed-fitness inputs ---
        coords, entr_xy = self._build_coords_and_entrance()
        cat_support = self._build_cat_support()
        pairs_list = self._build_pairs_list(affinity, threshold=pair_threshold)

        # Real-world layout, passed to the GA as an extra initial individual.
        layout_real = self._get_real_layout_as_seed()

        ga_module = GeneticLayoutOptimizer(
            self.all_items,
            affinity,
            forbidden_pairs,
            penalty=penalty,
            greedy_ratio=greedy_ratio,
            selection=self.ga_selection,
            crossover=self.ga_crossover,
            mutation=self.ga_mutation,
            adaptive=self.ga_adaptive,
            hard_rules=self.hard_rule_engine.rules,
            # NEW: mixed params
            coords=coords,
            entr_xy=entr_xy,
            cat_support=cat_support,
            pairs_list=pairs_list,
            w_aff=w_aff,
            w_pair=w_pair,
            w_entr=w_entr,
            gamma_support=gamma_support,
        )

        # Tuning uses a fixed, short GA budget (ngen=30) in both modes.
        if self.ga_ensemble_runs > 1:
            best_run, _ = ga_module.run_ensemble(
                ngen=30,
                pop_size=pop_size,
                greedy_layout=layout_greedy,
                n_runs=self.ga_ensemble_runs,
                init_population_extra=[layout_real],
            )
            best_layout = best_run["best_layout"]
            best_fitness = best_run["best_fitness"]
            logbook = best_run["logbook"]
        else:
            best_layout, best_fitness, logbook = ga_module.run(
                ngen=30,
                pop_size=pop_size,
                greedy_layout=layout_greedy,
                record_logbook=True,
                init_population_extra=[layout_real],
            )

        filtered_best_layout = self._filter_layout(best_layout)
        fitness = best_fitness

        # Keep the artefacts of this trial so tune() can retrieve the winner's.
        trial.set_user_attr("best_layout", filtered_best_layout)
        trial.set_user_attr("logbook", logbook)
        return fitness

    def tune(self):
        """Run the Optuna search and remember the winning trial.

        Creates a maximizing study with a TPE sampler seeded by ``self.seed``,
        optimizes ``self.objective`` for ``self.n_trials`` trials, then stores
        ``study``, ``best_params``, ``best_layout`` and ``best_logbook`` on the
        instance and prints the best parameters and layout.

        Returns:
            The finished Optuna study.
        """
        search = optuna.create_study(
            direction="maximize",
            sampler=optuna.samplers.TPESampler(seed=self.seed),
        )
        search.optimize(self.objective, n_trials=self.n_trials)

        # Only publish results once the search has finished.
        self.study = search
        self.best_params = search.best_params
        winner_attrs = search.best_trial.user_attrs
        self.best_layout = winner_attrs["best_layout"]
        self.best_logbook = winner_attrs["logbook"]

        print("Best Optuna params:", self.best_params)
        print("Best layout (from Optuna):", self.best_layout)
        return search

    def run_final(self):
        """Run the GA once more with the tuned parameters and export the layout.

        Uses ``self.best_params`` (set by ``tune()``) to rebuild the affinity
        matrix and the GA exactly as ``objective()`` does, but with
        ``self.n_gen_final`` generations. The best layout is truncated to the
        number of slots, assigned to the slots in ``(y, x)`` order and written
        to ``self.output_path`` as CSV with columns ``Category, x, y, width,
        height, cx, cy``.

        Side effects: sets ``layout_optimal_xy``, ``affinity``,
        ``best_fitness`` and ``ga_logbook`` on the instance, and prints the
        layout, the hard-rule check results and the fitness.

        Returns:
            tuple: ``(layout_optimal_xy, final_fitness)``.

        Raises:
            ValueError: if the three affinity weights do not sum to a positive
                value.
        """
        p = self.best_params

        # objective() rescales w_lift / w_conf / w_margin to sum to 1 before
        # building the affinity matrix, but best_params holds the raw sampled
        # values. Apply the same rescaling so the final run uses the weights
        # the winning trial was scored with. Already-normalized weights are
        # left unchanged by this step.
        weight_sum = p["w_lift"] + p["w_conf"] + p["w_margin"]
        if weight_sum <= 0:
            raise ValueError("w_lift + w_conf + w_margin must be positive")

        affinity = self.affinity_builder.build_affinity(
            lift_threshold=p["lift_threshold"],
            w_lift=p["w_lift"] / weight_sum,
            w_conf=p["w_conf"] / weight_sum,
            w_margin=p["w_margin"] / weight_sum,
        )
        affinity = self.affinity_builder.normalize(affinity)
        affinity = self.affinity_builder.kernelize(affinity, gamma=p["gamma"])

        # Greedy layout refined by local search; handed to the GA below.
        greedy_module = GreedyLayout(self.all_items)
        layout_greedy = greedy_module.init_layout(affinity)
        layout_greedy = greedy_module.local_search(layout_greedy, affinity)

        forbidden_pairs = get_forbidden_pairs(affinity, self.all_items)
        layout_real = self._get_real_layout_as_seed()

        # --- mixed-fitness inputs (same as in objective) ---
        coords, entr_xy = self._build_coords_and_entrance()
        cat_support = self._build_cat_support()
        pairs_list = self._build_pairs_list(affinity, threshold=p["pair_threshold"])

        ga_module = GeneticLayoutOptimizer(
            self.all_items,
            affinity,
            forbidden_pairs,
            penalty=p["penalty"],
            greedy_ratio=p["greedy_ratio"],
            selection=self.ga_selection,
            crossover=self.ga_crossover,
            mutation=self.ga_mutation,
            adaptive=self.ga_adaptive,
            hard_rules=self.hard_rule_engine.rules,
            # mixed-fitness parameters
            coords=coords,
            entr_xy=entr_xy,
            cat_support=cat_support,
            pairs_list=pairs_list,
            w_aff=p["w_aff"],
            w_pair=p["w_pair"],
            w_entr=p["w_entr"],
            gamma_support=p["gamma_support"],
        )

        if self.ga_ensemble_runs > 1:
            best_run, _ = ga_module.run_ensemble(
                ngen=self.n_gen_final,
                pop_size=p["pop_size"],
                greedy_layout=layout_greedy,
                n_runs=self.ga_ensemble_runs,
                init_population_extra=[layout_real],
            )
            best_layout = best_run["best_layout"]
            best_fitness = best_run["best_fitness"]
            logbook = best_run["logbook"]
        else:
            best_layout, best_fitness, logbook = ga_module.run(
                ngen=self.n_gen_final,
                pop_size=p["pop_size"],
                greedy_layout=layout_greedy,
                record_logbook=True,
                init_population_extra=[layout_real],
            )

        # ---------- Export: assign categories to slots in (y, x) order, keeping width/height ----------
        filtered_best_layout = [str(c) for c in self._filter_layout(best_layout)]
        slots = self.data.layout_real.sort_values(["y", "x"])[
            ["x", "y", "width", "height"]
        ].reset_index(drop=True)
        n = min(len(filtered_best_layout), len(slots))
        # Positional slice: the first n slots, also correct when n == 0.
        used_slots = slots.iloc[:n]
        layout_optimal_xy = pd.DataFrame(
            {
                "Category": filtered_best_layout[:n],
                "x": used_slots["x"].to_list(),
                "y": used_slots["y"].to_list(),
                "width": used_slots["width"].to_list(),
                "height": used_slots["height"].to_list(),
            }
        )
        # Slot centres.
        layout_optimal_xy["cx"] = (
            layout_optimal_xy["x"] + layout_optimal_xy["width"] / 2.0
        )
        layout_optimal_xy["cy"] = (
            layout_optimal_xy["y"] + layout_optimal_xy["height"] / 2.0
        )

        layout_optimal_xy.to_csv(self.output_path, index=False)

        print(f"\nBest layout: {filtered_best_layout}")
        print(
            "Hard rule checks:", self.hard_rule_engine.check_all(filtered_best_layout)
        )
        final_fitness = best_fitness
        print(f"Best fitness (GA): {final_fitness:.4f}")

        # Keep the results for plot_all() and for callers.
        self.layout_optimal_xy = layout_optimal_xy
        self.affinity = affinity
        self.best_fitness = final_fitness
        self.ga_logbook = pd.DataFrame(logbook)
        return layout_optimal_xy, final_fitness

    def plot_all(self):
        if not hasattr(self, "layout_optimal_xy"):
            print("You must run run_final() before plotting.")
            return
        LayoutVisualizer.plot_affinity_heatmap(self.affinity)
        LayoutVisualizer.plot_affinity_bar(self.affinity)
        LayoutVisualizer.plot_ga_convergence(self.ga_logbook)
        if hasattr(self, "study"):
            LayoutVisualizer.plot_optuna_trials(self.study)
        LayoutVisualizer.plot_spring_layout(self.affinity, threshold=0.8)


# ----- Example usage -----
if __name__ == "__main__":
    # Extra hard rules can be supplied here; the pipeline merges them with the
    # loader's rules and sanitizes the result.
    hard_rules = {}

    # Input files of the run (no margin matrix).
    input_paths = {
        "assoc_rules_path": PROCESSED_DATA_DIR / "association_rules.csv",
        "freq_itemsets_path": PROCESSED_DATA_DIR / "frequent_itemsets.csv",
        "layout_real_path": INTERIM_DATA_DIR / "layout.csv",
        "margin_matrix_path": None,
    }
    df = DataLoader(**input_paths)

    # Genetic-algorithm operators and ensemble size.
    ga_settings = {
        "ga_selection": "tournament",
        "ga_crossover": "PMX",
        "ga_mutation": "shuffle",
        "ga_adaptive": True,
        "ga_ensemble_runs": 3,
    }
    pipeline = LayoutOptimizationPipeline(
        data=df,
        n_trials=20,
        n_gen_final=100,
        hard_rules=hard_rules,
        seed=42,
        **ga_settings,
    )

    # Tune, run the final optimisation, then plot the diagnostics.
    pipeline.tune()
    pipeline.run_final()
    pipeline.plot_all()