In [None]:
# %% Import libraries
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display
from srcLocal.train import *

In [None]:
# Load the training data
X = np.load('X.npy')  # shape: (n_samples, n_features); features are [zenith, azimuth, u1, u2, u3]
y = np.load('y.npy')  # shape: (n_samples, n_targets); targets are [vei_east, vei_west]

# Fail fast if the files do not match the layout the rest of the notebook relies on:
# the visualizer queries the model with exactly five features per point, and the
# summary below reads y.shape[1].
assert X.ndim == 2 and y.ndim == 2, f'expected 2-D arrays, got X{X.shape} and y{y.shape}'
assert X.shape[0] == y.shape[0], (
    f'sample count mismatch: {X.shape[0]} feature rows vs {y.shape[0]} target rows'
)
assert X.shape[1] == 5, (
    f'expected 5 features [zenith, azimuth, u1, u2, u3], got {X.shape[1]}'
)
print(f'n_samples: {X.shape[0]}, n_features: {X.shape[1]}, n_targets: {y.shape[1]}')

In [None]:
# Split data into train (~75 %), validation (~12.5 %) and test (~12.5 %) sets:
# first hold out 25 % of the samples, then halve that hold-out into validation and test.
X_train, X_holdout, y_train, y_holdout = train_test_split(X, y, test_size=0.25, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(
    X_holdout, y_holdout, test_size=0.5, random_state=42
)
del X_holdout, y_holdout  # the intermediate hold-out is fully split above

In [None]:
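# %% Conduct hyperparameter optimization
# Disabled: uncomment the call below to re-run the search. The hyperparameters
# hard-coded in the training cell presumably come from this search (TODO confirm).
# hyperparameter_optimization(X_train, X_val, X_test, y_train, y_val, y_test)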

In [None]:
# %% Train the best model on the training split (X_train / y_train) and score it on the test split
# Decision-tree hyperparameters, presumably the ones selected by the search above.
best_dtree_params = dict(
    criterion='squared_error',
    splitter='best',
    max_depth=30,
    min_samples_split=2,
    min_samples_leaf=1,
)
model, _ = train_model_dtree(X_train=X_train, y_train=y_train, **best_dtree_params)
test_nmae = nMAE(y_test, model.predict(X_test))
print(f'Normalized MAE on test set: {test_nmae:.1f} %')

In [None]:
class SurfaceVisualizer:
    """
    Interactive 3D view of a model's predictions over the upper hemisphere.

    For every (zenith, azimuth) direction on a regular grid the model is queried
    with features ``[zenith, azimuth, u1, u2, u3]``. The selected output column
    is used as the radius of a spherical surface rendered with Plotly. Four
    sliders (three tint levels and the sensor index) are stacked in a single
    column below the plot and re-render it whenever one of them is released.
    """

    def __init__(self, model, res=50, n_sensors=2):
        """
        Build the widgets and draw the initial plot.

        Parameters:
        -----------
        model : trained model
            Object with a ``predict(features)`` method that takes an array of
            shape (n_points, 5) ordered as [zenith, azimuth, u1, u2, u3] and
            returns either (n_points, n_outputs) or, for single-output
            models, (n_points,).
        res : int, optional
            Number of grid points along zenith and along azimuth (default=50).
            Must be at least 2.
        n_sensors : int, optional
            Number of model outputs selectable with the 'Sensor' slider
            (default=2).

        Raises:
        -------
        ValueError
            If ``res < 2`` or ``n_sensors < 1``.
        """
        if res < 2:
            raise ValueError(f'res must be at least 2, got {res}')
        if n_sensors < 1:
            raise ValueError(f'n_sensors must be at least 1, got {n_sensors}')
        self.model = model
        self.res = res

        # Three identical tint sliders (range 0-100, step 1); they only fire
        # on release because continuous_update=False.
        self.s1, self.s2, self.s3 = (
            widgets.FloatSlider(
                value=0, min=0, max=100, step=1,
                description=f'Tint region {i}', continuous_update=False
            )
            for i in (1, 2, 3)
        )
        # 1-based index of the model output to plot.
        self.s4 = widgets.IntSlider(
            value=1, min=1, max=n_sensors, step=1,
            description='Sensor', continuous_update=False
        )
        sliders = [self.s1, self.s2, self.s3, self.s4]

        # Output container for the figure
        self.out = widgets.Output()

        # Build UI: plot on top, controls stacked in one column below.
        self.controls = widgets.VBox(sliders)
        self.app = widgets.VBox([self.out, self.controls])

        # Connect sliders -> re-render on value change
        for slider in sliders:
            slider.observe(self._update_surface, names='value')

        # Initial render
        self._update_surface()


    # -------------------------------------------------
    # Core Logic
    # -------------------------------------------------
    def _predict(self, u1, u2, u3, target):
        """
        Evaluate the model on the hemispherical grid and return Cartesian surface data.

        Parameters:
        -----------
        u1, u2, u3 : float
            Tint levels, held constant over the whole grid.
        target : int
            1-based index of the model output (sensor) to plot.

        Returns:
        --------
        X, Y, Z : ndarray, each of shape (res, res)
            Cartesian coordinates of the surface. Rows follow azimuth and
            columns follow zenith (``np.meshgrid`` default 'xy' indexing).
        """
        # Zenith covers the upper hemisphere and azimuth a full turn. Both
        # endpoints are included so the surface closes at the 0 / 2*pi seam.
        zens, azis = np.meshgrid(
            np.linspace(0, np.pi / 2, self.res),
            np.linspace(0, 2 * np.pi, self.res),
        )
        z_flat = zens.ravel()
        a_flat = azis.ravel()
        n_points = z_flat.size

        # Feature order must match training: [zenith, azimuth, u1, u2, u3].
        features = np.column_stack(
            (z_flat, a_flat, np.full((n_points, 3), (u1, u2, u3)))
        )

        # Single-output regressors return a 1-D array; promote it to one column
        # so the 1-based target index works uniformly.
        pred = np.asarray(self.model.predict(features))
        if pred.ndim == 1:
            pred = pred[:, np.newaxis]
        r = pred[:, target - 1]

        # Spherical -> Cartesian, with the prediction as the radius.
        X = r * np.sin(z_flat) * np.cos(a_flat)
        Y = r * np.sin(z_flat) * np.sin(a_flat)
        Z = r * np.cos(z_flat)

        # Back to the 2-D grid layout Plotly's Surface expects.
        return (
            X.reshape(zens.shape),
            Y.reshape(zens.shape),
            Z.reshape(zens.shape)
        )


    def _make_figure(self, u1, u2, u3, target):
        """
        Create the Plotly figure for the given slider values.

        The colour scale and axis ranges are fixed, so the view stays
        comparable across slider updates.
        """
        X, Y, Z = self._predict(u1, u2, u3, target)

        # NOTE(review): no ``surfacecolor`` is set, so Plotly colours the
        # surface by its z coordinate (height), not by the predicted value.
        fig = go.Figure(
            data=[go.Surface(
                x=X, y=Y, z=Z,
                colorscale='Viridis',
                cmin=0, cmax=0.2,
                colorbar=dict(title='Value')
            )]
        )

        # NOTE(review): hard-coded ranges. The y axis shows only y <= 0 (half of
        # the azimuth sweep); adjust if predictions exceed about 0.4.
        fig.update_layout(
            width=700,
            height=500,
            margin=dict(l=0, r=0, t=30, b=0),
            title=f'Incidence operator of sensor {target}',
            scene=dict(
                aspectmode='cube',
                xaxis=dict(range=[-0.4, 0.4]),
                yaxis=dict(range=[-.4, 0]),
                zaxis=dict(range=[0, .4]),
                camera=dict(eye=dict(x=0.1, y=-1.5, z=0.5))
            ),
        )
        return fig

    # -------------------------------------------------
    # Update Handler
    # -------------------------------------------------
    def _update_surface(self, change=None):
        """
        Re-render the plot from the current slider values.

        This is the ``observe`` callback for every slider; ``change`` is the
        change notification and is unused. It is also called once from
        ``__init__`` for the initial render.
        """
        with self.out:
            # wait=True keeps the old figure until the new one is displayed.
            self.out.clear_output(wait=True)
            fig = self._make_figure(
                self.s1.value,
                self.s2.value,
                self.s3.value,
                self.s4.value
            )
            display(fig)

    # -------------------------------------------------
    # Display the UI
    # -------------------------------------------------
    def show(self):
        """Render the plot and its slider column in the notebook."""
        display(self.app)

In [None]:
# Interactive viewer on a finer 100 x 100 direction grid than the class default (50).
visualizer = SurfaceVisualizer(model, res=100)
visualizer.show()