# Gaussian Linear Model

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import ipywidgets as widgets
from ipywidgets import interact

def _signed_term(coef, digits):
    """Return ``" + c"`` or ``" - c"`` where ``c`` is ``|coef|`` at *digits* decimals.

    Used so a negative slope renders as ``0.50 - 1.00x`` instead of
    ``0.50+-1.00x``. Rounding happens first and ``+ 0.0`` folds a negative
    zero into positive zero, so a value like ``-1e-17`` prints as ``+ 0.00``.
    """
    coef = round(coef, digits) + 0.0
    sign = "-" if coef < 0 else "+"
    return f" {sign} {abs(coef):.{digits}f}"


def plot_interactive(x=4, a=0.5, b=1, var_noise=1):
    """Plot a sample from Y = a + b*X + noise and the law of Y given X = x.

    Draws 100 points with X ~ Uniform(-10, 10) and Gaussian noise of variance
    ``var_noise``. It overlays the regression line a + b*x and the
    N(a + b*x, var_noise) density of Y | X = x, drawn sideways at the
    vertical line X = x.

    Parameters
    ----------
    x : float
        Point at which the conditional density of Y is drawn.
    a : float
        Intercept of the mean function.
    b : float
        Slope of the mean function.
    var_noise : float
        Noise variance; values below 1e-6 are clamped to 1e-6.

    Side effects: opens a new matplotlib figure and calls ``plt.show()``.
    Returns ``None``.
    """
    # Clamp so sqrt/scale below stay strictly positive.
    var_noise = max(var_noise, 1e-6)
    sigma = np.sqrt(var_noise)

    # A private RandomState seeded with 42 yields exactly the draws that
    # np.random.seed(42) + np.random.* did, without resetting the global
    # NumPy RNG for the rest of the notebook on every redraw.
    rng = np.random.RandomState(42)
    n = 100  # Number of points
    X = rng.uniform(-10, 10, n)
    noise = rng.normal(0, sigma, n)
    Y = a + b * X + noise

    # Conditional law of Y at the chosen x, evaluated over mean ± 3 sigma.
    mean_mu_x = a + b * x
    mu = np.linspace(mean_mu_x - 3 * sigma, mean_mu_x + 3 * sigma, 500)
    pdf_mu = norm.pdf(mu, loc=mean_mu_x, scale=sigma)

    X_sorted = np.sort(X)

    plt.figure(figsize=(10, 6))
    plt.scatter(X, Y, color='blue', alpha=0.5, label='Sampled (Xi, Yi)')
    plt.plot(X_sorted, a + b * X_sorted, color='red', linewidth=2,
             label=r'$\mu(x) = E[Y|X=x] = $' + f"{a:.2f}{_signed_term(b, 2)}" + '$x$')
    # The density is drawn sideways: horizontal offset from x is the pdf value.
    plt.plot(x + pdf_mu, mu, color="green",
             label=f"$Y|X=x \\sim N({mean_mu_x:.1f}, {var_noise:.1f})$")
    plt.axvline(x=x, color="green", linestyle="--",
                label=f"x={x:.4f}", linewidth=0.7)
    plt.axhline(0, color='black', linewidth=0.7, linestyle='--')
    plt.axvline(0, color='black', linewidth=0.7, linestyle='--')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.title(f'Linear Regression: Y = {a:.4f}{_signed_term(b, 4)}X + Noise')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()


def reset_parameters(b):
    """Put every slider back to its starting value.

    ``b`` is whatever ``Button.on_click`` passes in (the clicked button);
    it is not used here.
    """
    starting_values = (
        (x_slider, 4),
        (a_slider, 0.5),
        (b_slider, 1.0),
        (var_noise_slider, 1.0),
    )
    for slider, start in starting_values:
        slider.value = start

# Shared appearance for all sliders: a wide description column and a wide track.
slider_style = {'description_width': '150px'}
slider_layout = widgets.Layout(width='500px')

# One slider per keyword argument of plot_interactive.
x_slider = widgets.FloatSlider(
    value=4,
    min=-10,
    max=10,
    step=0.1,
    description='Value of x:',
    style=slider_style,
    layout=slider_layout,
)
a_slider = widgets.FloatSlider(
    value=0.5,
    min=-2,
    max=2,
    step=0.1,
    description='Intercept (a):',
    style=slider_style,
    layout=slider_layout,
)
b_slider = widgets.FloatSlider(
    value=1.0,
    min=-2,
    max=2,
    step=0.1,
    description='Slope (b):',
    style=slider_style,
    layout=slider_layout,
)
var_noise_slider = widgets.FloatSlider(
    value=1,
    min=0.1,
    max=5,
    step=0.1,
    description='Variance of Noise:',
    style=slider_style,
    layout=slider_layout,
)

# Clicking the button calls reset_parameters.
reset_button = widgets.Button(
    description="Reset",
    layout=widgets.Layout(width='400px'),
)
reset_button.on_click(reset_parameters)

# Show the reset button above the interactive controls.
# NOTE(review): `display` is not imported here; this assumes an IPython/Jupyter
# session, where it is available as a builtin.
display(reset_button)

# Wire each slider to the matching argument of plot_interactive.
# The trailing semicolon suppresses the notebook's echo of the return value.
interact(
    plot_interactive,
    x=x_slider,
    a=a_slider,
    b=b_slider,
    var_noise=var_noise_slider,
);


Button(description='Reset', layout=Layout(width='400px'), style=ButtonStyle())

interactive(children=(FloatSlider(value=4.0, description='Value of x:', layout=Layout(width='500px'), max=10.0…