In [None]:
import numpy as np
from ipywidgets import interactive, IntSlider, VBox, Label, Layout, fixed
from IPython.display import display
from sklearn.tree import DecisionTreeRegressor, plot_tree
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
import pandas as pd
import matplotlib.pyplot as plt

def plot_model_predictions_over_days(*args, **kwargs):
    """
    Placeholder for the prediction-over-days plotting helper.

    The stub accepts (and ignores) any positional and keyword arguments so
    that callers passing the full argument list get the explicit
    ``NotImplementedError`` below rather than a ``TypeError`` about an
    unexpected argument count.

    NOTE(review): the real implementation is presumably provided by the
    ``%run`` of ``plot_phenotyping_regression.ipynb`` later in this cell,
    which would rebind this name -- confirm.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("This function is not yet implemented.")

# Run the helper notebook as if its cells were typed into this session.
# NOTE(review): it presumably defines the real plot_model_predictions_over_days,
# replacing the placeholder stub defined above -- confirm against that notebook.
%run example_definitions/lecture_06/plot_phenotyping_regression.ipynb

def dt_regression(df_loaded, max_depth=2, min_samples_split=2):
    """
    Fit a regression tree on the phenotyping data, report test metrics and plot it.

    The function one-hot encodes ``species``, passes the remaining feature
    columns through unchanged, splits the data 70/30, fits a
    ``DecisionTreeRegressor``, prints RMSE and R² on the test split, draws the
    fitted tree, and finally calls ``plot_model_predictions_over_days``.

    Args:
        df_loaded: DataFrame containing the columns ``days_of_phenotyping``,
            ``nitrogen_applied``, ``drought_stress``, ``species`` and the
            target ``digital_biomass``.
        max_depth: Maximum depth of the tree.
        min_samples_split: Minimum number of samples required to split a node.

    Returns:
        None. Results are printed and plotted.
    """
    model_name = "Regression Tree"
    # A fixed seed keeps the fitted tree identical across slider re-runs;
    # scikit-learn permutes features at every split, so ties between equally
    # good splits could otherwise be broken differently each time.
    model = DecisionTreeRegressor(
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        random_state=42,
    )

    # ----------------------
    # Define Features/Target
    # ----------------------
    passthrough_cols = [
        "days_of_phenotyping",
        "nitrogen_applied",
        "drought_stress",
    ]
    target_col = "digital_biomass"

    preprocessor = ColumnTransformer(
        transformers=[
            ('other_features', 'passthrough', passthrough_cols),  # Leave these unchanged
            ('species_ohe', OneHotEncoder(drop='first', sparse_output=False), ['species']),
        ]
    )
    # The preprocessor is fitted on the full frame (before the split) because
    # the same fitted object is handed to the plotting helper below.
    X_array = preprocessor.fit_transform(df_loaded)

    # Column order of the transformed array follows the transformer order:
    # passthrough columns first, then the one-hot encoded species columns.
    species_encoder = preprocessor.named_transformers_['species_ohe']
    feature_names = passthrough_cols + species_encoder.get_feature_names_out(['species']).tolist()
    X_data_init = pd.DataFrame(X_array, columns=feature_names)
    y_data_init = df_loaded[target_col]

    # --------------------------
    # Train/Test Split
    # --------------------------
    X_train, X_test, y_train, y_test = train_test_split(
        X_data_init, y_data_init, test_size=0.3, random_state=42
    )

    # --------------------------
    # Fit Model
    # --------------------------
    model.fit(X_train, y_train)

    # --------------------------
    # Evaluate on the test split
    # --------------------------
    y_pred = model.predict(X_test)

    r2_test = r2_score(y_test, y_pred)
    mse_test = np.mean((y_test - y_pred) ** 2)
    rmse_test = np.sqrt(mse_test)

    print("--------------------------------------------------")
    print(f"{model_name} - Regression Results:")
    print(f"Test Set RMSE: {rmse_test:.3e}")
    print(f"Test Set R² score: {r2_test:.3f}")
    print("--------------------------------------------------")

    # --------------------------
    # Visualise the fitted tree
    # --------------------------
    plt.figure(figsize=(20, 10))
    plot_tree(
        model,
        impurity=True,
        filled=True,
        feature_names=feature_names,
        rounded=True,
        fontsize=10,
    )
    plt.title(f"{model_name} - Train Set - Tree Structure")
    plt.show()

    # Predictions over days for different nitrogen levels.
    # NOTE(review): the second positional argument is passed as None; its
    # meaning is defined by the helper -- confirm against its definition.
    plot_model_predictions_over_days(
        preprocessor,
        None,
        model,
        model_name,
        df_loaded,
        feature_names=feature_names,
    )

# ---------------------------------------------------
# 5. Interactive controls
# ---------------------------------------------------
def regression_tree_regression_interact(df_loaded):
    """
    Display slider controls that re-run ``dt_regression`` on ``df_loaded``.

    Two integer sliders are shown: "Max Depth" (1-20, initial 7) and
    "Min Samples Split" (5-50 in steps of 5, initial 25). ``df_loaded`` is
    passed through unchanged via ``fixed`` so it gets no widget of its own.

    Args:
        df_loaded: Data forwarded as the ``df_loaded`` argument of
            ``dt_regression``.

    Returns:
        None. The header label and the interactive widget are displayed.
    """

    def _make_slider(description, value, lower, upper, step):
        # Shared look for both sliders; continuous_update=False means the
        # model is refitted only when the slider is released, not while dragging.
        return IntSlider(
            value=value,
            min=lower,
            max=upper,
            step=step,
            description=description,
            continuous_update=False,
            style={'description_width': '160px'},
            layout=Layout(width="400px"),
        )

    max_depth_slider = _make_slider("Max Depth:", 7, 1, 20, 1)
    min_samples_split_slider = _make_slider("Min Samples Split:", 25, 5, 50, 5)

    # Header shown above the sliders (the original text was a mis-decoded emoji).
    ui_box = VBox([
        Label(value="📊 Controls", layout=Layout(margin="0 0 0 0")),
    ])

    interactive_plot = interactive(
        dt_regression,
        df_loaded=fixed(df_loaded),
        max_depth=max_depth_slider,
        min_samples_split=min_samples_split_slider,
    )

    display(ui_box, interactive_plot)