In [None]:
%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, HTML

# ----------------------------------------------------------
# Styling
# ----------------------------------------------------------
# Inject page-level CSS for the notebook output.
# - `.widget-button` rules target the reset button (it is given this class
#   further down in the cell).
# - `.explanation-container` wraps the statistics box that update_plots()
#   renders into the Output widget.
# NOTE(review): `.control-group` is not attached to any widget in this cell,
# so those rules appear unused.
# NOTE(review): `widget-button` seems to be a class ipywidgets already puts on
# every Button, so these rules likely restyle all buttons on the page, not
# just this one — confirm.
display(HTML("""
<style>
    .widget-button {
        background-color: #1976D2 !important;
        color: white !important;
        border: none !important;
        font-weight: bold !important;
        border-radius: 4px !important;
        height: 32px !important;
        margin-top: 5px !important;
        text-transform: none !important;
    }
    .widget-button:hover {
        background-color: #1565C0 !important;
        transform: translateY(-1px);
        box-shadow: 0 2px 4px rgba(0,0,0,0.2);
    }
    .widget-button:active {
        background-color: #0D47A1 !important;
    }
    .control-group {
        display: flex;
        align-items: center;
        margin-right: 15px;
    }
    .control-group label {
        margin-right: 5px;
        min-width: 120px;
    }
    .explanation-container {
        margin: 10px 0;
        padding: 10px;
        border: 1px solid #E0E0E0;
        border-radius: 5px;
    }
</style>
"""))

# Shared colour palette (hex strings) used by every panel and widget.
# Some roles deliberately share a hue: primary/residual are the same blue,
# cook/regression the same red.
colors = dict(
    primary='#1976D2',
    secondary='#FF5722',
    background='#F5F7FA',
    grid='#E0E0E0',
    text='#212121',
    residual='#1976D2',
    leverage='#00796B',
    cook='#D32F2F',
    regression='#D32F2F',
)

# Module-wide random generator; generate_data() swaps in a freshly seeded
# one whenever it is called with an explicit seed.
rng = np.random.default_rng(1337)

# ----------------------------------------------------------
# Data generation
# ----------------------------------------------------------
def generate_data(n=15, distribution='normal', seed=None):
    """Return an ``(x, y)`` sample for the regression demo.

    Parameters
    ----------
    n : int
        Number of points for the simulated presets. The ``'data'`` preset is
        a fixed 30-point data set and ignores ``n``.
    distribution : str
        ``'normal'``: x ~ N(50, 20), y = 2 + 0.8 x + N(0, 8) noise.
        ``'unequal'``: jittered, evenly spaced x on 10..90 with noise whose
        standard deviation is x / 10 (heteroscedastic).
        ``'data'``: the fixed example data set.
        Any other value behaves exactly like ``'normal'``.
    seed : int or None
        If given, the module-level ``rng`` is replaced by
        ``np.random.default_rng(seed)`` before sampling; otherwise the
        current ``rng`` keeps being used.

    Returns
    -------
    x, y : numpy.ndarray
        For the simulated presets x is sorted ascending; the fixed data set
        is returned in its stored order.
    """
    global rng
    if seed is not None:
        rng = np.random.default_rng(seed)

    if distribution == 'data':
        x = np.array([10.49, 6.13, 5.17, 13.63, 8.96, 7.74, 7.10, 9.88, 11.84, 11.85,
                      8.83, 10.96, 16.56, 2.76, 3.35, 8.01, 13.42, 9.45, 14.55, 9.37,
                      11.13, 11.14, 11.23, 7.54, 6.56, 9.73, 8.20, 12.12, 9.57, 9.74])
        y = np.array([12.49, 11.66, 2.11, 22.98, 14.38, 8.74, 7.13, 17.15, 15.58, 18.28,
                      11.85, 19.57, 18.50, 10.46, 6.45, 5.15, 16.88, 17.59, 18.30, 13.27,
                      14.37, 15.62, 16.85, 16.72, 11.97, 14.84, 7.27, 17.12, 11.98, 12.75])
        return x, y

    # Simulated presets: draw x first, then the noise, from the same rng.
    if distribution == 'unequal':
        x = np.sort(np.linspace(10, 90, n) + rng.normal(0, 2, n))
        noise = rng.normal(0, x / 10, n)  # spread grows with x
    else:
        # 'normal' and any unrecognised preset name.
        x = np.sort(rng.normal(50, 20, n))
        noise = rng.normal(0, 8, n)

    return x, 2 + 0.8 * x + noise

# ----------------------------------------------------------
# Initial data
# ----------------------------------------------------------
n_points = 15
distribution = 'normal'
# Seed 42 is also the default shown in the "Random seed" widget.
x_data, y_data = generate_data(n_points, distribution=distribution, seed=42)

# Index of the point being dragged with the mouse; None when no drag is active.
selected_point = None

# ----------------------------------------------------------
# Matplotlib configuration
# ----------------------------------------------------------
# Global defaults must be in place before the figure/axes below are created.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 13,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'axes.edgecolor': '#BDBDBD',
    'axes.linewidth': 1.2,
})

# Three stacked panels: data + fitted line, leverage, Cook's distance.
fig, (ax_reg, ax_lev, ax_cook) = plt.subplots(
    nrows=3, ncols=1, figsize=(7.5, 9), sharex=False,
)
fig.suptitle('Interactive Regression Analysis',
             fontsize=15, fontweight='bold', y=0.98)

# ----------------------------------------------------------
# Helper: compute regression, residuals, leverage, and Cook's distance
# ----------------------------------------------------------
def compute_regression_stats(x, y):
    """Fit ``y = beta0 + beta1 * x`` by least squares and return diagnostics.

    Parameters
    ----------
    x, y : array_like
        1-D sequences of equal length holding predictor and response.

    Returns
    -------
    beta : numpy.ndarray, shape (2,)
        ``[intercept, slope]``.
    y_hat : numpy.ndarray
        Fitted values ``X @ beta``.
    residuals : numpy.ndarray
        ``y - y_hat``.
    h : numpy.ndarray
        Leverages: the diagonal of the hat matrix ``X X^+``.
    cooks : numpy.ndarray
        Cook's distances ``e_i^2 / (p * MSE) * h_i / (1 - h_i)^2`` with
        ``MSE = SSE / (n - p)`` (denominator floored at 1).

    Notes
    -----
    A single pseudo-inverse of the design matrix yields both the
    coefficients and the leverages. The n x n hat matrix is therefore never
    materialised, and a rank-deficient design (e.g. all x equal) no longer
    raises ``LinAlgError``. ``1 - h`` is floored at a tiny positive value,
    so a leverage of exactly 1 gives a large finite Cook's distance instead
    of ``inf``/``nan`` (which would break axis scaling downstream).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    X = np.column_stack([np.ones_like(x), x])
    n, p = X.shape

    # For full-rank X, pinv(X) == (X'X)^-1 X'. The explicit cutoff lets an
    # exactly collinear design fall back to the minimum-norm solution.
    X_pinv = np.linalg.pinv(X, rcond=1e-10)
    beta = X_pinv @ y
    y_hat = X @ beta
    residuals = y - y_hat

    # diag(X @ X_pinv), computed row by row without the n x n product.
    h = np.einsum('ij,ji->i', X, X_pinv)

    mse = np.sum(residuals ** 2) / max(n - p, 1)
    one_minus_h = np.maximum(1.0 - h, 1e-12)
    cooks = residuals ** 2 / (p * max(mse, 1e-12)) * (h / one_minus_h ** 2)

    return beta, y_hat, residuals, h, cooks

# ----------------------------------------------------------
# Initial plots
# ----------------------------------------------------------
# These module-level results only seed the artists below; update_plots()
# recomputes everything from x_data / y_data on every change.
beta, y_hat, residuals, leverages, cooks = compute_regression_stats(x_data, y_data)

# 1) Regression scatter + line
# scat_reg, line_reg, scat_lev and scat_cook are kept as globals and are
# updated in place (set_offsets / set_data) by update_plots().
# picker=5 enables pick events on this scatter (5-point tolerance), which
# is what lets on_pick() start a drag.
scat_reg = ax_reg.scatter(x_data, y_data, s=60, c=colors['residual'],
                         alpha=0.85, edgecolor='white', linewidth=1.2, picker=5, zorder=5)
# NOTE(review): label='Regression' is set, but no legend is created in this cell.
line_reg, = ax_reg.plot(x_data, y_hat, color=colors['regression'], 
                       linewidth=3, label='Regression', zorder=4)
ax_reg.set_ylabel('y', fontweight='bold')
ax_reg.set_title('Regression: Drag points to refit line')
ax_reg.grid(True, linestyle='--', alpha=0.6, color=colors['grid'])


# 2) Leverage plot
scat_lev = ax_lev.scatter(x_data, leverages, s=60, c=colors['leverage'],
                         alpha=0.85, edgecolor='white', linewidth=1.2)
# Disabled: horizontal reference line at the mean leverage.
#ax_lev.axhline(np.mean(leverages), color='#9E9E9E', linestyle='--', linewidth=1)
ax_lev.set_ylabel('Leverage $h_{ii}$', fontweight='bold')
ax_lev.set_title('Leverage of each point')
ax_lev.grid(True, linestyle='--', alpha=0.6, color=colors['grid'])

# 3) Cook's distance plot
scat_cook = ax_cook.scatter(x_data, cooks, s=60, c=colors['cook'],
                           alpha=0.85, edgecolor='white', linewidth=1.2)
ax_cook.set_ylabel("Cook's distance", fontweight='bold')
ax_cook.set_xlabel('x', fontweight='bold')
ax_cook.set_title('Cook\'s distance of each point')
ax_cook.grid(True, linestyle='--', alpha=0.6, color=colors['grid'])


# Keep the top 5% of the figure free for the suptitle placed at y=0.98.
fig.tight_layout(rect=[0, 0, 1, 0.95])

# ----------------------------------------------------------
# Widgets and explanation box
# ----------------------------------------------------------
# Output area that update_plots() clears and refills with the live statistics.
explanation_out = widgets.Output()

# Option values are the `distribution` names understood by generate_data().
dist_selector = widgets.Dropdown(
    options=[('Normal', 'normal'),
             ('Unequal variance', 'unequal'),
             ('Example data', 'data')],
    value='normal',
    description='Preset:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='230px')
)

# Default 42 matches the seed used to build the initial data set.
seed_input = widgets.BoundedIntText(
    value=42,
    min=0,
    max=1_000_000,
    step=1,
    description='Random seed:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='160px')
)

# NOTE(review): generate_data() ignores n for the 'data' preset (it always
# returns the fixed 30-point set), so this control has no visible effect there.
n_points_input = widgets.BoundedIntText(
    value=n_points,
    min=5,
    max=50,
    step=1,
    description='Points:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='130px')
)

# Regenerates the data with the current widget values (wired to reset_data).
reset_button = widgets.Button(
    description='Reset data',
    layout=widgets.Layout(width='150px', height='32px'),
)
# NOTE(review): ipywidgets likely already tags Buttons with 'widget-button';
# this call may be redundant — confirm.
reset_button.add_class('widget-button')

# ----------------------------------------------------------
# Update functions
# ----------------------------------------------------------
def _padded_limits(lo, hi, frac=0.1, min_pad=1.0):
    """Return ``(lo, hi)`` widened by ``frac`` of the span on each side.

    When the span is zero (all values identical), a fixed ``min_pad`` is used
    instead, so matplotlib never receives identical lower and upper limits.
    """
    span = hi - lo
    pad = frac * span if span > 0 else min_pad
    return lo - pad, hi + pad


def update_plots():
    """Refit the regression to the current data and refresh every panel.

    Reads the module-level ``x_data`` / ``y_data``, moves the scatter and
    line artists in place, rescales all three axes, rewrites the statistics
    box in ``explanation_out`` and schedules a canvas redraw.
    """
    # Globals are only read and the artists are mutated in place, so no
    # ``global`` statement is needed.
    beta, y_hat, residuals, h, c = compute_regression_stats(x_data, y_data)

    # Regression panel: move the points and refit the line.
    scat_reg.set_offsets(np.column_stack([x_data, y_data]))
    line_reg.set_data(x_data, y_hat)

    # Diagnostic panels use the same x positions as the data.
    scat_lev.set_offsets(np.column_stack([x_data, h]))
    scat_cook.set_offsets(np.column_stack([x_data, c]))

    # Axis limits: 10 % padding around the data, safe for zero spans.
    x_lim = _padded_limits(x_data.min(), x_data.max())
    for ax in (ax_reg, ax_lev, ax_cook):
        ax.set_xlim(*x_lim)
    ax_reg.set_ylim(*_padded_limits(y_data.min(), y_data.max()))
    ax_lev.set_ylim(max(0, h.min() - 0.02), h.max() + 0.05)
    # set_ylim rejects inf/nan, so scale only on finite Cook's distances.
    finite_c = c[np.isfinite(c)]
    c_max = finite_c.max() if finite_c.size else 0.0
    ax_cook.set_ylim(0, max(c_max * 1.1, 0.1))

    # R^2 is undefined when y has no variance; report nan instead of
    # dividing by zero.
    ss_tot = np.sum((y_data - np.mean(y_data)) ** 2)
    r2 = 1 - np.sum(residuals ** 2) / ss_tot if ss_tot > 0 else float('nan')
    # Render "+ -1.23" as "- 1.23".
    intercept_sign = '-' if beta[0] < 0 else '+'
    # NOTE: the "MSE" shown here is SSE / n, whereas Cook's distance uses
    # SSE / (n - p).
    with explanation_out:
        explanation_out.clear_output(wait=True)
        display(HTML(f"""
        <div class="explanation-container">
            <h4 style="margin:0 0 8px 0; color:{colors['primary']}">Live Statistics:</h4>
            <p><b>Regression:</b> $\\hat{{y}} = {beta[1]:.2f}x {intercept_sign} {abs(beta[0]):.2f}$ | 
               <b>R&sup2;:</b> {r2:.3f} | 
               <b>MSE:</b> {np.mean(residuals**2):.2f}</p>
            <ul style="margin:4px 0 0 18px; padding:0;">
                <li><b>Top:</b> Data + regression line (drag any point)</li>
                <li><b>Middle:</b> Leverage values $h_{{ii}}$</li>
                <li><b>Bottom:</b> Cook's distances (influence)</li>
            </ul>
        </div>
        """))

    fig.canvas.draw_idle()

def reset_data(_=None):
    """Regenerate the data set from the current widget settings and redraw.

    The unused positional argument lets this function act directly as a
    Button ``on_click`` handler as well as being callable with no arguments.
    """
    global x_data, y_data, n_points, distribution
    n_points = n_points_input.value
    distribution = dist_selector.value
    x_data, y_data = generate_data(
        n_points, distribution=distribution, seed=seed_input.value
    )
    update_plots()


def on_param_change(change):
    """Widget observer: regenerate the data whenever a watched value changes."""
    if change["name"] != "value":
        return
    reset_data()


# Editing the preset, seed or point count regenerates immediately; the
# button regenerates with the settings as they currently are.
dist_selector.observe(on_param_change, names='value')
seed_input.observe(on_param_change, names='value')
n_points_input.observe(on_param_change, names='value')
reset_button.on_click(reset_data)

# ----------------------------------------------------------
# Mouse interaction: drag ANY data point (x AND y)
# ----------------------------------------------------------
def on_pick(event):
    """Start a drag when a point of the regression scatter is picked.

    Records the picked point's index in ``selected_point`` and paints it in
    the highlight colour.
    """
    global selected_point
    if event.artist is not scat_reg or not len(event.ind):
        return
    selected_point = event.ind[0]
    # Change only the face colour: Collection.set_color would also overwrite
    # the white marker edges the scatter was created with.
    face = [colors['residual']] * len(x_data)
    face[selected_point] = colors['secondary']
    scat_reg.set_facecolor(face)
    fig.canvas.draw_idle()


def on_motion(event):
    """While a point is held, move it (both x and y) to the cursor and refit."""
    global selected_point, x_data, y_data
    if selected_point is None:
        return
    if event.inaxes is not ax_reg:
        return
    if event.xdata is None or event.ydata is None:
        return

    # x_data is deliberately not re-sorted: selected_point keeps indexing the
    # same point for the rest of the drag.
    x_data[selected_point] = event.xdata
    y_data[selected_point] = event.ydata
    update_plots()


def on_release(event):
    """End the drag and restore the default face colour of all points."""
    global selected_point
    if selected_point is not None:
        scat_reg.set_facecolor(colors['residual'])
        fig.canvas.draw_idle()
    selected_point = None


fig.canvas.mpl_connect('pick_event', on_pick)
fig.canvas.mpl_connect('motion_notify_event', on_motion)
fig.canvas.mpl_connect('button_release_event', on_release)

# ----------------------------------------------------------
# Layout
# ----------------------------------------------------------
# Each input sits in its own Box with a right margin; the reset button
# closes the row.
controls = widgets.HBox(
    [
        widgets.Box([control], layout=widgets.Layout(margin='0 10px 0 0'))
        for control in (dist_selector, seed_input, n_points_input)
    ]
    + [reset_button]
)

# Control row above the statistics box, then render the initial state.
display(widgets.VBox([controls, explanation_out]))
update_plots()
