# %% Imports and shared plotting helpers
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np


def plot_description(text):
    """Print *text* under a ``Description:`` heading, framed by blank lines."""
    message = "\nDescription:\n{}\n".format(text)
    print(message)


# %% Investigate model predictions over days for different nitrogen levels
def plot_model_predictions_over_days(
    preprocessor,
    y_scaler,
    model,
    model_name,
    integrated_data,
    ax=None,
):
    """Plot predicted and observed biomass over phenotyping days, side by side.

    Two fixed scenarios are drawn, one per panel: ``barley_2`` with
    ``drought_stress == 1`` and ``bread_wheat`` with ``drought_stress == 0``.
    In each panel, for the nitrogen levels 25 and 130, the model is evaluated
    on days 1..20 and drawn as a dotted line, and the matching rows of
    ``integrated_data`` are overlaid with ``sns.lineplot``.

    Args:
        preprocessor: Object whose ``transform`` accepts a DataFrame with the
            columns ``days_of_phenotyping``, ``species``,
            ``nitrogen_applied`` and ``drought_stress``.
        y_scaler: Object whose ``inverse_transform`` accepts the model output
            reshaped to a single column and returns the values to plot.
        model: Object whose ``predict`` accepts the preprocessor output.
        model_name: Text used as the prefix of each panel title.
        integrated_data: DataFrame holding the four feature columns listed
            above plus ``digital_biomass`` (the observed values).
        ax: ``None`` to create a new 1x2 figure (which is then shown), an
            array-like holding exactly two Axes, or a single Axes. A single
            Axes has its subplot slot split in two: it keeps the left half
            and a new Axes is added in the right half.

    Returns:
        The two Axes that were drawn on, left panel first.

    Raises:
        ValueError: If ``ax`` is array-like but does not hold exactly two
            Axes.
    """
    scenarios = [("barley_2", 1), ("bread_wheat", 0)]
    test_days = np.arange(1, 21)
    test_nitrogen_levels = [25, 130]

    title_font_size = 12
    axis_font_size = 10
    species_name_map = {
        "barley_2": "Barley v2",
        "durum": "Durum Wheat",
        "barley_1": "Barley v1",
        "bread_wheat": "Bread Wheat",
    }

    axes, new_fig = _resolve_prediction_axes(ax)

    for ax_loc, (test_species, test_drought_stress) in zip(axes, scenarios):
        # Restrict observations to this panel's species and drought condition
        # so the ground-truth lines are comparable with the predictions.
        observed = integrated_data[
            (integrated_data["species"] == test_species)
            & (integrated_data["drought_stress"] == test_drought_stress)
        ]

        for n_val in test_nitrogen_levels:
            # Scalars are broadcast against the ``test_days`` column.
            test_df = pd.DataFrame(
                {
                    "days_of_phenotyping": test_days,
                    "species": test_species,
                    "nitrogen_applied": n_val,
                    "drought_stress": test_drought_stress,
                }
            )

            # Preprocess, predict, and undo the target scaling.
            features = preprocessor.transform(test_df)
            pred_scaled = np.asarray(model.predict(features))
            pred = y_scaler.inverse_transform(pred_scaled.reshape(-1, 1)).ravel()

            ax_loc.plot(
                test_days,
                pred,
                label=f"Prediction (N={n_val})",
                marker="o" if n_val == 25 else "x",
                markeredgewidth=2,
                markersize=7,
                linewidth=2,
                alpha=0.9,
                linestyle=":",
            )

            gt = observed[observed["nitrogen_applied"] == n_val].sort_values(
                "days_of_phenotyping"
            )
            if not gt.empty:
                # ``markers=True`` is ignored by seaborn when no ``style``
                # variable is given, so request the marker from matplotlib
                # directly to make the observed points visible.
                sns.lineplot(
                    data=gt,
                    x="days_of_phenotyping",
                    y="digital_biomass",
                    ax=ax_loc,
                    marker="o",
                    markersize=10,
                    label=f"Ground Truth (N={n_val})",
                )

        species_label = species_name_map.get(test_species, test_species)
        cond_label = "Drought" if test_drought_stress else "No Drought"
        ax_loc.set_xlabel("Days of Phenotyping", fontsize=axis_font_size)
        ax_loc.set_ylabel("Biomass (dm³)", fontsize=axis_font_size)
        ax_loc.set_title(
            f"{model_name} - {species_label} - {cond_label}",
            fontsize=title_font_size,
        )
        ax_loc.grid(True)
        # Tick and legend fonts follow the axis font size, with a floor of 8.
        ax_loc.tick_params(axis="both", labelsize=max(8, axis_font_size - 2))
        ax_loc.legend(loc="upper left", fontsize=max(8, axis_font_size - 4))

    # Only lay out and show figures created here; caller-owned figures are
    # left for the caller to finalize.
    if new_fig is not None:
        new_fig.tight_layout()
        plt.show()

    return axes


def _resolve_prediction_axes(ax):
    """Return ``(axes, new_fig)``; ``new_fig`` is None unless a figure was created."""
    if ax is None:
        new_fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        return axes, new_fig

    if hasattr(ax, "__len__"):
        grid = np.asarray(ax, dtype=object)
        if grid.size != 2:
            raise ValueError(
                "ax must be a single Axes or hold exactly two Axes, "
                f"got {grid.size}"
            )
        # Hand a flat 2-element container back unchanged; flatten 2-D grids
        # such as the (1, 2) array from ``plt.subplots(..., squeeze=False)``.
        return (ax if grid.ndim == 1 else grid.ravel()), None

    fig = ax.figure
    get_spec = getattr(ax, "get_subplotspec", None)
    spec = get_spec() if get_spec is not None else None
    if spec is None:
        # No subplot slot to split (e.g. axes placed with ``add_axes``):
        # fall back to adding a right-hand subplot on the figure.
        return [ax, fig.add_subplot(1, 2, 2)], None

    # Split the slot occupied by ``ax`` so the second panel sits beside it
    # instead of being drawn on top of it.
    halves = spec.subgridspec(1, 2)
    ax.set_subplotspec(halves[0, 0])
    return [ax, fig.add_subplot(halves[0, 1])], None