# In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from ipywidgets import interactive, FloatSlider, VBox, Label, Layout
from IPython.display import display

# Seed NumPy's global (legacy) random number generator once, at import time,
# so any np.random-based randomness in this notebook is repeatable across runs.
np.random.seed(seed=42)


def plot_description(text):
    """Print *text* to stdout beneath a ``Description:`` heading.

    The printed output is a leading newline, the line ``Description:``, the
    text itself, and a trailing blank line.

    Parameters
    ----------
    text : object
        Content to show; formatted the same way an f-string would format it.
    """
    banner = "\nDescription:\n{}\n".format(text)
    print(banner)


def binary_classification_demo(separation=1.0):
    """Fit logistic regression on a synthetic 2-D dataset and plot the result.

    A two-class, two-feature dataset is generated with ``make_classification``
    (fixed ``random_state``, so only ``separation`` changes between calls),
    standardized with ``StandardScaler``, and used to fit a
    ``LogisticRegression``. The figure shows:

    * the model's predicted probability of class 1 over a grid covering the
      data, on an absolute 0-1 color scale (blue = class 0, red = class 1);
    * the P = 0.5 decision boundary as a black line;
    * the training samples, colored by their true class.

    Parameters
    ----------
    separation : float, default 1.0
        Passed to ``make_classification`` as ``class_sep``; larger values
        push the two classes further apart.
    """
    # ---------------------------------------------------
    # 1. Create binary dataset
    # ---------------------------------------------------
    # random_state alone makes the data reproducible, so the global NumPy RNG
    # is not reseeded here (that would mutate global state on every redraw).
    X, y = make_classification(
        n_samples=200, n_features=2, n_informative=2, n_redundant=0,
        n_clusters_per_class=1, class_sep=separation, random_state=42,
    )
    X = StandardScaler().fit_transform(X)

    # ---------------------------------------------------
    # 2. Train logistic regression model
    # ---------------------------------------------------
    model = LogisticRegression()
    model.fit(X, y)

    # ---------------------------------------------------
    # 3. Visualize probabilities and decision boundary
    # ---------------------------------------------------
    # Evaluation grid spanning the standardized data plus a margin of 1.
    f1_min, f1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    f2_min, f2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.linspace(f1_min, f1_max, 200),
                         np.linspace(f2_min, f2_max, 200))
    grid = np.c_[xx.ravel(), yy.ravel()]
    proba_class1 = model.predict_proba(grid)[:, 1].reshape(xx.shape)

    fig, ax = plt.subplots(figsize=(8, 6))
    # Fixed levels over [0, 1] pin the colors to absolute probability. With
    # automatic levels, contourf stretches the colormap to the current min/max,
    # so an uncertain model (small separation) would still look saturated
    # red/blue and hide the growing uncertainty the slider is meant to show.
    # "RdYlBu_r" maps low P(class 1) to blue and high P(class 1) to red.
    shading = ax.contourf(xx, yy, proba_class1, levels=np.linspace(0.0, 1.0, 11),
                          cmap="RdYlBu_r", alpha=0.5)
    fig.colorbar(shading, ax=ax, label="P(Class 1)")
    # Decision boundary: where the model is exactly undecided between classes.
    ax.contour(xx, yy, proba_class1, levels=[0.5], colors="black", linewidths=1.5)

    ax.scatter(X[y == 0, 0], X[y == 0, 1], color="blue", label="Class 0",
               edgecolor="black", linewidth=0.5)
    ax.scatter(X[y == 1, 0], X[y == 1, 1], color="red", label="Class 1",
               edgecolor="black", linewidth=0.5)
    ax.set_title("Binary Classification with Logistic Regression", fontsize=16)
    ax.set_xlabel("Feature 1", fontsize=14)
    ax.set_ylabel("Feature 2", fontsize=14)
    ax.legend(title="Ground Truth")
    ax.grid(True)
    plt.show()

def binary_classification_demo_interact():
    """Plot the BCE loss curves, then display the interactive classifier demo.

    Prints two explanatory descriptions via ``plot_description``. Draws the
    Binary Cross-Entropy loss for true labels 1 and 0 against the predicted
    probability. Finally, displays a ``FloatSlider`` whose value is passed as
    ``separation`` to ``binary_classification_demo`` through
    ``ipywidgets.interactive``.
    """
    # ---------------------------------------------------
    # 4. Visualize Binary Cross-Entropy loss function
    # ---------------------------------------------------
    # Stay strictly inside (0, 1): -log(0) is infinite at the endpoints.
    p = np.linspace(0.001, 0.999, 200)
    loss_positive = -np.log(p)      # loss for a sample whose true label is 1
    loss_negative = -np.log(1 - p)  # loss for a sample whose true label is 0

    plot_description("We are looking at a binary classification problem (i.e., assigning a sample to one of two classes). The Figure"
                     " below shows an example visualization of the Binary Cross-Entropy (BCE) loss function that was used to train the "
                     "logistic regression model. The BCE function calculates the loss values based on the orange line (for samples of class 0) "
                     "or based on the blue line (for samples of class 1).")

    plt.figure(figsize=(6, 4))
    # Colors are pinned because the description above refers to the lines by
    # color; relying on the default color cycle (rcParams) could silently
    # make that text wrong.
    plt.plot(p, loss_positive, color="tab:blue", label="Loss (True Label = 1)", linewidth=2)
    plt.plot(p, loss_negative, color="tab:orange", label="Loss (True Label = 0)", linewidth=2)
    plt.title("Binary Cross-Entropy Loss", fontsize=14)
    plt.xlabel("Predicted Probability (p)", fontsize=12)
    plt.ylabel("Loss", fontsize=12)
    plt.legend()
    plt.grid(True)
    plt.show()

    plot_description("Looking at the plot below, the blue and red dots represent training samples of class 0 and class 1 respectively."
                     " The background shows prediction probabilities of the logistic regression model."
                     " Bluish background: High probability for Class 0. Reddish background: High probability for Class 1. \n"
                     "The model assigns class 0 or 1 to a new \"unseen\" sample, based on these probabilities. "
                     "Notice how prediction uncertainty increases as training samples of both classes get mixed (play around with the slider below).")

    # --- Interactive control ---
    # The slider value is forwarded as binary_classification_demo(separation=...).
    sep_slider = FloatSlider(
        value=1.0, min=0.1, max=3.0, step=0.1,
        description="Move Samples",
        style={'description_width': '150px'},
        layout=Layout(width='500px')
    )

    ui_box = VBox([
        # Named escape for U+1F4CA so the emoji cannot be mangled by the file
        # being re-saved in a non-UTF-8 encoding.
        Label(value="\N{BAR CHART} Controls", layout=Layout(margin="0 0 0 0")),
    ])
    interactive_plot = interactive(binary_classification_demo, separation=sep_slider)

    display(ui_box, interactive_plot)