# In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from ipywidgets import interactive, FloatSlider, VBox, Label, Layout
from IPython.display import display
import warnings
# Silence every warning category for the whole session (same effect as
# warnings.filterwarnings('ignore') with its default arguments).
warnings.simplefilter('ignore')

# Seed NumPy's legacy global RNG once at import so that any unseeded
# np.random calls in this notebook are reproducible.
np.random.seed(42)


def plot_description(text):
    """Print *text* beneath a "Description:" heading, framed by blank lines."""
    banner = "\nDescription:\n{}\n".format(text)
    print(banner)


def multi_class_classification_demo(separation=1.0):
    """Fit a 3-class logistic regression on a toy 2-D dataset and plot it.

    Draws two side-by-side panels over the standardized feature space.
    The left panel shows the model's maximum class probability at each grid
    point. The right panel shows the predicted class (decision regions).
    The training samples are overlaid on both, coloured by their true class.

    Args:
        separation: Forwarded to ``make_classification`` as ``class_sep``.
            scikit-learn documents larger values as spreading the classes
            further apart.
            NOTE: keep this the only parameter. ``ipywidgets.interactive``
            (used by ``multi_class_classification_demo_interact``) would pick
            up any extra defaulted parameter as an additional control.

    Returns:
        None. The figure is rendered via ``plt.show()``.
    """
    # ---------------------------------------------------
    # 1. Create multi-class dataset
    # ---------------------------------------------------
    # random_state fixes the samples. No global np.random reseed is needed;
    # one would only clobber the caller's RNG state.
    X, y = make_classification(
        n_samples=300, n_features=2, n_informative=2, n_redundant=0,
        n_classes=3, n_clusters_per_class=1, class_sep=separation, random_state=42,
    )
    X = StandardScaler().fit_transform(X)

    # ---------------------------------------------------
    # 2. Train logistic regression model for multi-class
    # ---------------------------------------------------
    # With the lbfgs solver and more than two classes, scikit-learn already
    # fits the multinomial (softmax / cross-entropy) model. The former
    # ``multi_class='multinomial'`` argument is deprecated since 1.5 and is
    # therefore not passed.
    model = LogisticRegression(solver='lbfgs')
    model.fit(X, y)

    # ---------------------------------------------------
    # 3. Visualize decision boundaries
    # ---------------------------------------------------
    # Evaluate the model on a 200x200 grid padded by 1 unit around the data.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200), np.linspace(y_min, y_max, 200))
    probabilities = model.predict_proba(np.c_[xx.ravel(), yy.ravel()])

    # Column index == class label here, because the labels are 0, 1, 2.
    max_proba = probabilities.max(axis=1).reshape(xx.shape)
    predictions = probabilities.argmax(axis=1).reshape(xx.shape)

    # (title, background values, colormap) for each panel.
    panels = [
        ("Max. Prediction Probabilities", max_proba, "Blues"),
        ("Model Decision Boundaries", predictions, "viridis"),
    ]
    class_colors = ("blue", "red", "green")  # indexed by class label

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle("Multi-Class Classification with Logistic Regression", fontsize=16)

    for ax, (title, zz, cmap) in zip(axes, panels):
        ax.set_title(title, fontsize=12)
        ax.contourf(xx, yy, zz, cmap=cmap, alpha=0.5)
        for label, color in enumerate(class_colors):
            mask = y == label
            ax.scatter(X[mask, 0], X[mask, 1], color=color, label=f"Class {label}",
                       edgecolor="black", linewidth=0.5)
        ax.set_xlabel("Feature 1", fontsize=12)
        ax.set_ylabel("Feature 2", fontsize=12)
        ax.legend(title="Ground Truth")
        ax.grid(True)

    plt.show()

def multi_class_classification_demo_interact():
    """Print the demo explanation and display the interactive plot.

    Calls ``plot_description`` with the explanatory text. It then displays
    a "Controls" label and an ``ipywidgets.interactive`` panel, which re-runs
    ``multi_class_classification_demo`` whenever the slider changes. The
    slider value is passed as ``separation``.
    """
    plot_description("Example for Multi-Class Classification with a logistic regression model trained using the cross-entropy loss function."
                     " Training samples are shown in blue (class 0), red (class 1) and green (class 2). The model outputs three probability values,"
                     " one for each class, based on the sample's input features (Feature 1 & 2).\n\n"
                     "Left plot: The colored background shows the maximum probability value for each combination of Feature 1 & 2. The darker the "
                     "background, the higher the probability. Notice how the light area (i.e., the model's uncertainty) increases as training samples"
                     " of different classes mix (play around with the slider below).\n\n"
                     "Right plot: The background colors of this plot show the model decisions (i.e., assignment of input feature combinations to one class)."
                     " Purple area => Class 0, Turquoise area => Class 1, Yellow area => Class 2.")

    # --- Interactive control ---
    # The slider value becomes the ``separation`` argument of the demo.
    sep_slider = FloatSlider(
        value=1.0, min=0.1, max=3.0, step=0.1,
        description="Move Samples",
        style={'description_width': '150px'},
        layout=Layout(width='500px')
    )
    # Header above the controls. The emoji is written as an escape (U+1F4CA,
    # BAR CHART) so that the source file's encoding cannot garble it again.
    # It was previously stored as UTF-8 bytes mis-decoded as cp1252.
    ui_box = VBox([
        Label(value="\U0001F4CA Controls", layout=Layout(margin="0 0 0 0")),
    ])
    interactive_plot = interactive(multi_class_classification_demo, separation=sep_slider)
    display(ui_box, interactive_plot)