In [3]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd

# Step activation
def step(x):
    return (x > 0).astype(int)

# Forward pass
def forward(x, W1, b1, W2, b2):
    h = step(x @ W1.T + b1)     # hidden layer
    out = step(h @ W2.T + b2)   # output layer
    return h, out

# Defaults
default_params = {
    "W11": 1.0, "W12": 1.0, "b1": -0.5,
    "W21": 1.0, "W22": 1.0, "b2": -1.5,
    "Wout1": 1.0, "Wout2": 1.0, "bout": -0.5
}

# Data
X = np.array([[0,0],[0,1],[1,0],[1,1]])
y = np.array([0,1,1,0])

# Plot + truth table
def plot_and_table(params, output_table):
    W1 = np.array([[params["W11"], params["W12"]],
                   [params["W21"], params["W22"]]])
    b1 = np.array([params["b1"], params["b2"]])
    W2 = np.array([[params["Wout1"], params["Wout2"]]])
    b2 = np.array([params["bout"]])

    # grid
    xx, yy = np.meshgrid(np.linspace(-0.25, 1.25, 200),
                         np.linspace(-0.25, 1.25, 200))
    X_grid = np.c_[xx.ravel(), yy.ravel()]

    h_grid = step(X_grid @ W1.T + b1)
    out_grid = step(h_grid @ W2.T + b2)
    regions = [h_grid[:,0], h_grid[:,1], out_grid[:,0]]

    # XOR outputs
    h, y_pred = forward(X, W1, b1, W2, b2)

    # Layout: 1 row, 4 columns (3 plots + table)
    fig, axes = plt.subplots(1, 4, figsize=(14,4),
                             gridspec_kw={'width_ratios':[1,1,1,1.2]})
    titles = ["Hidden 1", "Hidden 2", "Output"]

    # plots
    for ax, Z, title in zip(axes[:3], regions, titles):
        Z = Z.reshape(xx.shape)
        ax.contourf(xx, yy, Z, alpha=0.3,
                    cmap=ListedColormap(['#ffcccc','#ccffcc']))
        ax.set_title(title)
        ax.set_xlim(-0.25, 1.25)
        ax.set_ylim(-0.25, 1.25)
        ax.set_xticks([0,1])
        ax.set_yticks([0,1])
        ax.grid(True)
        for (xi, yi) in zip(X, y):
            color = "green" if yi == 1 else "red"
            ax.scatter(*xi, c=color, s=100, edgecolor='k')

    # truth table in last subplot
    table = pd.DataFrame({
        "x1": X[:,0],
        "x2": X[:,1],
        "h1": h[:,0],
        "h2": h[:,1],
        "target": y,
        "pred": y_pred[:,0],
        "OK": ["‚úì" if yi==yp else "‚úó" for yi, yp in zip(y,y_pred[:,0])] #"‚úì/‚úó": ["‚úÖ" if yi==yp else "‚ùå" for yi, yp in zip(y,y_pred[:,0])]
    })
    axes[3].axis("off")
    tbl = axes[3].table(cellText=table.values,
                        colLabels=table.columns,
                        loc="center")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1.2,1.2)

    # color cells in "pred" column
    n_rows = table.shape[0]
    target_col_idx = list(table.columns).index("target")
    pred_col_idx = list(table.columns).index("pred")
    ok_col_idx = list(table.columns).index("OK")
    for i in range(n_rows):
        correct = table.iloc[i]["target"] == table.iloc[i]["pred"]
        color = "#c8f7c5" if correct else "#f7c5c5"
        tbl[(i+1, target_col_idx)].set_facecolor(color)
        tbl[(i+1, pred_col_idx)].set_facecolor(color)
        tbl[(i+1, ok_col_idx)].set_facecolor(color)

    plt.tight_layout()
    plt.show()

# Sliders
sliders = {k: widgets.FloatSlider(min=-1, max=1, step=0.1, value=v, description=k)
           for k,v in default_params.items()}

# Buttons
button_random = widgets.Button(description="üé≤ Randomize")
button_reset = widgets.Button(description="üîÑ Reset")

# Helper: randomize with boundary in [0,1]^2
def random_hyperplane():
    # Random point in [0,1]^2
    x1, x2 = np.random.rand(2)
    # Random weights
    w1, w2 = np.random.uniform(-2, 2, 2)
    # Compute bias so line crosses point
    b = - (w1 * x1 + w2 * x2)
    return w1, w2, b

def randomize(_):
    # Hidden 1 and 2
    for (W1k, W2k, bk) in [("W11", "W12", "b1"),
                           ("W21", "W22", "b2")]:
        w1, w2, b = random_hyperplane()
        sliders[W1k].value = np.round(w1, 1)
        sliders[W2k].value = np.round(w2, 1)
        sliders[bk].value = np.round(b, 1)
    # Output neuron: same logic but now inputs are h1,h2 ‚àà [0,1]
    w1, w2, b = random_hyperplane()
    sliders["Wout1"].value = np.round(w1, 1)
    sliders["Wout2"].value = np.round(w2, 1)
    sliders["bout"].value = np.round(b, 1)

def reset(_):
    for k,v in default_params.items():
        sliders[k].value = v

button_random.on_click(randomize)
button_reset.on_click(reset)

# Organize sliders in groups
box_hidden1 = widgets.VBox([sliders["W11"], sliders["W12"], sliders["b1"]])
box_hidden2 = widgets.VBox([sliders["W21"], sliders["W22"], sliders["b2"]])
box_output  = widgets.VBox([sliders["Wout1"], sliders["Wout2"], sliders["bout"]])
ui_sliders = widgets.HBox([box_hidden1, box_hidden2, box_output])
ui_buttons = widgets.HBox([button_random, button_reset])

# Output for plots
plot_output = widgets.Output()

def update(**kwargs):
    with plot_output:
        clear_output(wait=True)
        plot_and_table(kwargs, None)

ui = widgets.interactive_output(update, sliders)

# Dashboard: sliders+buttons on top, plots+table below
dashboard = widgets.VBox([ui_sliders, ui_buttons, plot_output])
display(dashboard, ui)


VBox(children=(HBox(children=(VBox(children=(FloatSlider(value=1.0, description='W11', max=1.0, min=-1.0), Flo‚Ä¶

Output()