# 2. LP Model for Finding a Single MSM Plane

This is the Julia code for the MATH271 class at Carleton College in Spring 2025 term with Rob Thompson.

Final Project: Understanding Breast Cancer Diagnosis and Prognosis via Linear Programming
Authors: Palmy Klangsathorn and Angelina Kong
Date: 06/05/2025

This project explores the application of linear programming techniques to the problem of breast cancer diagnosis and prognosis. We will utilize the Wisconsin Breast Cancer Diagnostic (WBCD) dataset to build and evaluate a predictive model. Our approach will focus on formulating the classification task as a linear program, contrasting it with traditional methods, and analyzing its performance through various statistical metrics and visualizations.

Moreover, we reproduced the Multisurface Method (MSM) and the Multisurface Method Tree (MSM-T) to classify whether a tumor is malignant or benign. The MSM uses an LP model to construct a series of separating planes that divide the space into malignant and benign regions. When the two sets of points are linearly separable, a plane is placed between them. Each plane is found by solving an LP problem whose objective is to create the largest distance between the plane and the data points. When the two sets of points are not linearly separable, the MSM-T is introduced. MSM-T constructs a plane that minimizes the average distance of misclassified points to the plane, which reduces the number of misclassified points. Because it uses a decision tree to classify new points, the MSM-T can accurately classify breast cancer samples as malignant or benign. Thus, MSM and MSM-T produce an interpretable and efficient classification of malignant and benign tumors.

In [None]:
"""
    read_to_dict(file_path::String) -> Dict{String, Vector{Any}}

Read a comma-separated text file (such as `wdbc.data`) into a dictionary.

For every usable line, the first field (stripped) becomes the key. The value is
a `Vector{Any}` whose first entry is the stripped second field (the diagnosis,
`"M"` or `"B"` in `wdbc.data`), followed by the remaining non-empty fields.
Fields that parse as `Float64` are stored as numbers; anything else is stored
as the stripped text.

Lines that are blank, or that do not contain at least a key and a second field,
are skipped. If `file_path` does not point to an existing file, the message
`Error: File not found` is printed and an empty dictionary is returned.
"""
function read_to_dict(file_path::String)
    data_dict = Dict{String, Vector{Any}}()

    # Checked up front: comparing `SystemError.errnum` (an errno value) against
    # `Base.UV_ENOENT` (a libuv status code) does not reliably detect a missing file.
    if !isfile(file_path)
        println("Error: File not found")
        return data_dict
    end

    open(file_path, "r") do io
        for line in eachline(io)
            # split each element in the line by comma
            elements = split(line, ',')

            # Skip blank lines and lines too short to hold both a key and a
            # diagnosis; indexing `elements[2]` on them would throw.
            if length(elements) < 2 || all(isempty, elements)
                continue
            end

            # the first element is the key and the rest form the list
            key = String(strip(elements[1]))

            # second element is the diagnosis label, kept as text
            data_list = Any[String(strip(elements[2]))]

            for raw in @view elements[3:end]
                val_str = strip(raw)
                isempty(val_str) && continue
                # `tryparse` returns `nothing` instead of throwing on non-numeric text
                parsed = tryparse(Float64, val_str)
                push!(data_list, parsed === nothing ? val_str : parsed)
            end

            data_dict[key] = data_list
        end
    end

    return data_dict
end

# Load the WDBC dataset into a dictionary keyed by sample ID.
file_path = "wdbc.data"
wdbc_data = read_to_dict(file_path)
println("Number of entries in dictionary: $(length(wdbc_data))")

# ----------------------------------------------------------------
# Sanity check: print the record stored for one particular sample ID.
specific_key = "848406"
if !haskey(wdbc_data, specific_key)
    println("\nKey '$specific_key' not found in the dictionary.")
else
    println("\nData for key '$specific_key':")
    println(wdbc_data[specific_key])
end

# Column layout: ID, diagnosis label, then three groups of ten features each.
col_names = [
    "ID",
    "Diagnosis",
    # group 1
    "radius1", "texture1", "perimeter1", "area1", "smoothness1",
    "compactness1", "concavity1", "concave_points1", "symmetry1", "fractal_dimension1",
    # group 2
    "radius2", "texture2", "perimeter2", "area2", "smoothness2",
    "compactness2", "concavity2", "concave_points2", "symmetry2", "fractal_dimension2",
    # group 3
    "radius3", "texture3", "perimeter3", "area3", "smoothness3",
    "compactness3", "concavity3", "concave_points3", "symmetry3", "fractal_dimension3",
]

# Everything after the first two columns (ID and Diagnosis) is a feature.
print("\nThere are  $(length(col_names)) columns")
print("\nThere are  $(length(col_names)-2) features")

In [None]:
# Step 2: Data Preprocessing

using DataFrames

# A DataFrame gives us a structured, tabular view of the dataset that supports
# filtering, grouping, and summarizing. DataFrames identifies columns by Symbol,
# so the string column names are converted first.
column_names = map(Symbol, col_names)

# Build one NamedTuple per sample: the dictionary key supplies the ID column and
# the stored list supplies the diagnosis followed by every feature value.
row_tuples = []
for (sample_id, record) in wdbc_data
    row = Any[sample_id, record[1]]   # ID and Diagnosis
    append!(row, record[2:end])       # all feature values
    push!(row_tuples, NamedTuple{Tuple(column_names)}(row))
end

# Assemble the table; the following steps operate on `df`.
df = DataFrame(row_tuples)

In [None]:
# Step 3: Encode diagnosis: M = 1, B = -1

# The 'Diagnosis' column holds the string labels "M" (malignant) and "B" (benign).
# The LP model needs numeric class labels, so each label is mapped element-wise
# to an integer:
# - "M" (malignant) becomes 1
# - anything else (i.e. "B", benign) becomes -1

df.Diagnosis = map(label -> label == "M" ? 1 : -1, df.Diagnosis)

In [None]:
# Step 4: Extract features and normalize (using StatsBase.zscore)

using StatsBase
using DataFrames

# Labels come from the encoded Diagnosis column.
y = df.Diagnosis

# Features are every column from the 3rd onward. zscore(·, 1) standardizes along
# dimension 1, i.e. each column (feature) ends up with mean 0 and standard
# deviation 1.
#
# The raw features live on very different scales (e.g. a radius versus an area
# versus a dimensionless smoothness). Without standardization, the features with
# the largest magnitudes would dominate distance computations and the LP
# objective; after it, every feature contributes on a comparable scale.
X = zscore(Matrix(df[:, 3:end]), 1)

# Confirm that the feature matrix and label vector line up.
println("\n--- Checking sizes before shuffleobs ---")
println("Size of X (rows, cols): ", size(X))
println("Length of y (rows): ", length(y))
println("Are X rows and y length equal? ", size(X, 1) == length(y))
println("------------------------------------")

In [None]:
# Step 5: Train-test split

using MLJBase
using Random

# Total number of observations (rows of X).
num_observations = size(X, 1)

# Randomly permuted row indices 1..num_observations.
indices = Random.shuffle(1:num_observations)

# The first 80% (rounded down) of the shuffled indices form the training set,
# the remainder the test set.
split_point = floor(Int, num_observations * 0.8)
train_indices = indices[begin:split_point]
test_indices = indices[(split_point + 1):end]

# Slice features and labels with the same index sets so rows stay aligned.
X_train, y_train = X[train_indices, :], y[train_indices]
X_test, y_test = X[test_indices, :], y[test_indices]

# Integer-typed label vectors, as required by the MSM functions' signatures.
y_train_numeric = convert(Vector{Int}, y_train)
y_test_numeric = convert(Vector{Int}, y_test)


println("X_train dimensions: ", size(X_train))
println("y_train length: ", length(y_train))
println("X_test dimensions: ", size(X_test))
println("y_test length: ", length(y_test))

In [None]:
using JuMP
using HiGHS 
using LinearAlgebra 
using Statistics 
using Plots 
using DataFrames 
using CSV 
using Random 
import Pkg; Pkg.add("MathOptInterface")
import MathOptInterface as MOI 

"""
    solve_msm_plane_lp(X_subset, y_subset, C_msm) -> (w, b, rho, status)

Find one separating plane for the given subset of data by solving a linear
program with HiGHS.

The LP maximizes `rho - C_msm * sum(zeta)` subject to

- `y_subset[i] * (dot(X_subset[i, :], w) + b) >= rho - zeta[i]` for every row `i`,
- `-t[j] <= w[j] <= t[j]` and `sum(t) <= 1`, a linearized bound on the L1 norm of `w`,
- `zeta >= 0`, `rho >= 0`, `t >= 0`.

# Arguments
- `X_subset`: one row per sample, one column per feature.
- `y_subset`: label for each row of `X_subset`.
- `C_msm`: weight of the slack penalty in the objective.

# Returns
The tuple `(w, b, rho, status)`. When the solver does not report `MOI.OPTIMAL`,
a message is printed and `w` is a zero vector with `b = rho = 0.0`; `status`
always carries the solver's termination status.
"""
function solve_msm_plane_lp(X_subset::Matrix{Float64}, y_subset::Vector{Int64}, C_msm::Float64)
    n_subset, d_subset = size(X_subset)

    model_msm = JuMP.Model(HiGHS.Optimizer)
    set_silent(model_msm)

    @variables(model_msm, begin
        w[1:d_subset]            # plane normal
        b                        # plane offset
        zeta[1:n_subset] >= 0    # slack variables for misclassification
        rho >= 0                 # margin variable
        t[1:d_subset] >= 0       # auxiliary variables bounding |w[j]|
    end)

    # Maximize the margin (rho) while penalizing misclassification (zeta).
    @objective(model_msm, Max, rho - C_msm * sum(zeta))

    # Each sample must sit at least `rho` on its own side, up to its slack.
    for i in 1:n_subset
        signed_score = y_subset[i] * (dot(X_subset[i, :], w) + b)
        @constraint(model_msm, signed_score >= rho - zeta[i])
    end

    # Linearized L1 norm: |w[j]| <= t[j], and the t[j] sum to at most 1.
    for j in 1:d_subset
        @constraint(model_msm, w[j] <= t[j])
        @constraint(model_msm, w[j] >= -t[j])
    end
    @constraint(model_msm, sum(t) <= 1.0)

    optimize!(model_msm)

    status = termination_status(model_msm)

    # Anything other than a proven optimum is reported as a failure.
    if status != MOI.OPTIMAL
        println("solve_msm_plane_lp failed with status: $status")
        return zeros(d_subset), 0.0, 0.0, status
    end

    return value.(w), value(b), value(rho), status
end



In [None]:
# Return true when at least one plane places `x` strictly beyond its margin on the
# side matching `label`: above `rho` for label 1, below `-rho` for label -1.
function _msm_covered_by_any_plane(x, label, all_w, all_b, all_rho)
    for p_idx in eachindex(all_w)
        val = dot(x, all_w[p_idx]) + all_b[p_idx]
        if (label == 1 && val > all_rho[p_idx]) ||
           (label == -1 && val < -all_rho[p_idx])
            return true
        end
    end
    return false
end

"""
    run_msm(X_train, y_train_numeric; max_planes=5, C_msm=1.0, convergence_threshold=0.01)

Orchestrate the search for multiple MSM planes.

Each iteration fits one plane with `solve_msm_plane_lp` on the training points
that are still counted as misclassified, then re-evaluates *all* training points
against *all* planes found so far. A point counts as classified once any plane
puts it strictly beyond that plane's margin on the side of its true label;
every other point stays in the misclassified set for the next iteration.

# Arguments
- `X_train`: one row per training sample.
- `y_train_numeric`: labels, compared against `1` and `-1`.

# Keywords
- `max_planes`: maximum number of planes (iterations).
- `C_msm`: slack penalty passed through to `solve_msm_plane_lp`.
- `convergence_threshold`: stop once the misclassified fraction is at or below this value.

# Returns
`(all_w, all_b, all_rho)`: the normals, offsets, and margins of the planes found,
in the order they were fitted.

Training stops early when no misclassified points remain, when the LP solve is
not optimal, when the convergence threshold is reached, or when an iteration
fails to reduce the number of misclassified points. Progress is printed.
"""
function run_msm(X_train::Matrix{Float64}, y_train_numeric::Vector{Int64};
                 max_planes::Int=5, C_msm::Float64=1.0, convergence_threshold::Float64=0.01)

    all_w = Vector{Vector{Float64}}()
    all_b = Vector{Float64}()
    all_rho = Vector{Float64}()

    n_samples = size(X_train, 1)
    # Before any plane exists, every training point counts as misclassified.
    current_misclassified_indices = collect(1:n_samples)
    previous_num_misclassified = n_samples + 1

    println("\n--- Starting MSM Training ---")

    for k in 1:max_planes
        if isempty(current_misclassified_indices)
            println("Iteration $k: All points classified. Stopping.")
            break
        end

        X_subset = X_train[current_misclassified_indices, :]
        y_subset = y_train_numeric[current_misclassified_indices]

        println("Iteration $k: Training plane on $(length(current_misclassified_indices)) misclassified points.")

        w_k, b_k, rho_k, status = solve_msm_plane_lp(X_subset, y_subset, C_msm)

        if status != MOI.OPTIMAL
            println("Iteration $k: solve_msm_plane_lp did not find an optimal solution. Stopping.")
            break
        end

        push!(all_w, w_k)
        push!(all_b, b_k)
        push!(all_rho, rho_k)

        # Recalculate the misclassified set from scratch against *all* planes found
        # so far: a point is misclassified if no plane pushes it beyond its margin
        # on the correct side.
        current_misclassified_indices = [
            i for i in 1:n_samples
            if !_msm_covered_by_any_plane(X_train[i, :], y_train_numeric[i],
                                          all_w, all_b, all_rho)
        ]

        num_misclassified = length(current_misclassified_indices)
        misclassification_ratio = num_misclassified / n_samples

        println("Iteration $k Summary:")
        println("  Planes found so far: $(length(all_w))")
        println("  Current misclassified points: $num_misclassified / $n_samples ($(round(misclassification_ratio*100, digits=2))%)")

        # Convergence criteria
        if num_misclassified == 0
            println("All training points perfectly classified. Stopping.")
            break
        end
        if misclassification_ratio <= convergence_threshold
            println("Convergence threshold reached. Stopping.")
            break
        end
        if num_misclassified >= previous_num_misclassified && k > 1 # No progress
            println("No further reduction in misclassified points. Stopping.")
            break
        end
        previous_num_misclassified = num_misclassified
    end

    println("--- MSM Training Finished ---")
    return all_w, all_b, all_rho
end

In [None]:
"""
    predict_msm(x, all_w, all_b, all_rho) -> Int

Classify a single point `x` by letting every plane vote.

Plane `k` votes `+1` when `dot(x, all_w[k]) + all_b[k]` is greater than
`all_rho[k]`, votes `-1` when it is less than `-all_rho[k]`, and abstains
otherwise. The result is `1` if positive votes outnumber negative ones, `-1` in
the opposite case, and `0` on a tie (including when no plane votes).
"""
function predict_msm(x::Vector{Float64}, all_w::Vector{Vector{Float64}}, all_b::Vector{Float64}, all_rho::Vector{Float64})
    # Positive votes minus negative votes.
    net_votes = 0

    for k in eachindex(all_w)
        score = dot(x, all_w[k]) + all_b[k]
        margin = all_rho[k]

        if score > margin          # positive side beyond the margin
            net_votes += 1
        elseif score < -margin     # negative side beyond the margin
            net_votes -= 1
        end
    end

    # sign() collapses the vote balance to 1, -1, or 0 (tie).
    return sign(net_votes)
end

"""
    predict_msm_batch(X, all_w, all_b, all_rho) -> Vector{Int}

Apply `predict_msm` to every row of `X` and collect the predictions, one per row.
"""
function predict_msm_batch(X::Matrix{Float64}, all_w::Vector{Vector{Float64}}, all_b::Vector{Float64}, all_rho::Vector{Float64})
    return Int[predict_msm(X[row, :], all_w, all_b, all_rho) for row in 1:size(X, 1)]
end

In [None]:
# evaluation metrics

"""
    calculate_accuracy(y_true, y_pred)

Return the fraction of entries of `y_pred` that equal the corresponding entry of
`y_true`, i.e. the number of element-wise matches divided by `length(y_true)`.
"""
function calculate_accuracy(y_true::Vector{Int}, y_pred::Vector{Int})
    matches = y_true .== y_pred
    return count(matches) / length(y_true)
end

In [None]:
# Fit the MSM planes on the training data.
println("\n--- Running MSM Training on Training Data ---")
all_w_msm, all_b_msm, all_rho_msm = run_msm(X_train, y_train_numeric; C_msm=0.1, max_planes=10)

# Predict on the training set with the fitted planes and report the accuracy.
y_train_pred_msm = predict_msm_batch(X_train, all_w_msm, all_b_msm, all_rho_msm)
train_accuracy_msm = calculate_accuracy(y_train_numeric, y_train_pred_msm)
println("\nMSM Training Accuracy: $(round(train_accuracy_msm * 100, digits=2))%")

# Same evaluation on the held-out test set.
println("\n--- Evaluating MSM on Test Data ---")
y_test_pred_msm = predict_msm_batch(X_test, all_w_msm, all_b_msm, all_rho_msm)
test_accuracy_msm = calculate_accuracy(y_test_numeric, y_test_pred_msm)
println("MSM Test Accuracy: $(round(test_accuracy_msm * 100, digits=2))%")

In [None]:
using MLJ 
using CategoricalArrays 
using Printf 

println("\n--- Evaluating MSM on Test Data ---")
y_test_pred_msm = predict_msm_batch(X_test, all_w_msm, all_b_msm, all_rho_msm)

# --- Confusion Matrix and Classification Metrics for MSM ---

# Labels are recoded to 0 (benign, originally -1) and 1 (malignant).
y_test_true_resolved = map(label -> label == -1 ? 0 : 1, y_test_numeric)

# predict_msm_batch can also return 0 ("uncertain", a tied vote). Binary metrics
# need a definite class, so every uncertain prediction is forced to be wrong:
# - true label benign    -> counted as a malignant prediction (false positive)
# - true label malignant -> counted as a benign prediction (false negative)
y_test_pred_resolved = map(y_test_pred_msm, y_test_numeric) do pred, true_label
    if pred == 1
        1
    elseif pred == -1
        0
    else
        true_label == -1 ? 1 : 0
    end
end


# MLJ's confusion_matrix works on categorical arrays with ordered levels.
y_true_cat = categorical(y_test_true_resolved, levels=[0, 1], ordered=true)
y_pred_cat = categorical(y_test_pred_resolved, levels=[0, 1], ordered=true)


# Confusion matrix: rows are predicted classes, columns are true classes.
# With levels [0, 1]:
#   [1, 1] true negatives   (actual 0, predicted 0)
#   [2, 1] false positives  (actual 0, predicted 1) - Type I error (false alarm)
#   [1, 2] false negatives  (actual 1, predicted 0) - Type II error (missed malignant)
#   [2, 2] true positives   (actual 1, predicted 1)
conf_matrix = MLJ.confusion_matrix(y_pred_cat, y_true_cat)
println("\n--- MSM Confusion Matrix (Test Data) ---")
println(conf_matrix)

# Pull the four cells out of the matrix.
true_pos = conf_matrix[2, 2]
true_neg = conf_matrix[1, 1]
false_pos = conf_matrix[2, 1]
false_neg = conf_matrix[1, 2]

println("\n--- MSM Classification Metrics (Test Data) ---")
# https://juliaai.github.io/StatisticalMeasures.jl/dev/confusion_matrices/
# Every ratio falls back to 0.0 when its denominator is zero.

# Accuracy: (TP + TN) / Total
total_samples = true_pos + true_neg + false_pos + false_neg
accuracy_val = iszero(total_samples) ? 0.0 : (true_pos + true_neg) / total_samples
println(@sprintf("Accuracy:    %.4f", accuracy_val))

# Precision: TP / (TP + FP) - share of predicted positives that are truly positive
precision_val = iszero(true_pos + false_pos) ? 0.0 : true_pos / (true_pos + false_pos)
println(@sprintf("Precision:   %.4f", precision_val))

# Recall (Sensitivity): TP / (TP + FN) - share of actual positives that were found
recall_val = iszero(true_pos + false_neg) ? 0.0 : true_pos / (true_pos + false_neg)
println(@sprintf("Recall:      %.4f", recall_val))

# F1-Score: harmonic mean of precision and recall
f1_val = iszero(precision_val + recall_val) ? 0.0 : 2 * (precision_val * recall_val) / (precision_val + recall_val)
println(@sprintf("F1-Score:    %.4f", f1_val))

# Specificity: TN / (TN + FP) - share of actual negatives that were found
specificity_val = iszero(true_neg + false_pos) ? 0.0 : true_neg / (true_neg + false_pos)
println(@sprintf("Specificity: %.4f", specificity_val))

## Plotting

In [None]:
using Printf

# Column layout of the dataset, repeated here so this cell can run on its own.
col_names = [
    "ID", "Diagnosis",
    "radius1", "texture1", "perimeter1", "area1", "smoothness1", "compactness1", "concavity1", "concave_points1", "symmetry1", "fractal_dimension1",
    "radius2", "texture2", "perimeter2", "area2", "smoothness2", "compactness2", "concavity2", "concave_points2", "symmetry2", "fractal_dimension2",
    "radius3", "texture3", "perimeter3", "area3", "smoothness3", "compactness3", "concavity3", "concave_points3", "symmetry3", "fractal_dimension3"
]

# Everything after ID and Diagnosis is a feature; used as x-axis tick labels.
feature_names = col_names[3:end]

# The figures are saved under images/; create the directory if it is missing so
# that savefig does not fail on a fresh checkout.
mkpath("images")

println("\n--- Plotting Feature Coefficients for each MSM Plane ---")
for (idx, w_plane) in enumerate(all_w_msm)
    # One (x, y, text) annotation per bar, showing the coefficient value.
    annotations = []

    max_abs_coeff_plane = isempty(w_plane) ? 0.0 : maximum(abs, w_plane)

    annotation_fontsize = 5
    # Labels of tall bars are drawn inside the bar, at this fraction of its height.
    inside_margin_factor = 0.8

    # Bars at or below this magnitude are too short to hold a label inside, so
    # their label goes just past the bar's end instead. Fixed fallbacks are used
    # when every coefficient is zero.
    if iszero(max_abs_coeff_plane)
        outside_margin = 0.05
        min_coeff_for_inside_text = 0.01
    else
        outside_margin = max_abs_coeff_plane * 0.05
        min_coeff_for_inside_text = max_abs_coeff_plane * 0.1
    end

    for i in eachindex(w_plane)
        coeff_value = w_plane[i]
        text_label = @sprintf "%.2f" coeff_value

        y_pos = if abs(coeff_value) > min_coeff_for_inside_text
            # Inside the bar; the sign of the coefficient carries through.
            coeff_value * inside_margin_factor
        elseif coeff_value < 0
            # Short negative bar: label below its end.
            coeff_value - outside_margin
        else
            # Short positive bar or zero: label above its end.
            coeff_value + outside_margin
        end

        push!(annotations, (i, y_pos, Plots.text(text_label, annotation_fontsize)))
    end

    # the bar plot with the generated annotations
    p_coeffs_msm = bar(1:length(w_plane), w_plane,
                       title="MSM Plane $(idx) Feature Coefficients",
                       xlabel="Feature Name",
                       ylabel="Coefficient Value",
                       xticks=(1:length(feature_names), feature_names),
                       xrotation=90,
                       legend=false,
                       size=(1200, 700),
                       bottom_margin=15Plots.mm,
                       left_margin=10Plots.mm,
                       annotations=annotations)

    display(p_coeffs_msm)
    savefig(p_coeffs_msm, "images/msm_plane_$(idx)_coeffs.png")
end

In [None]:
using Plots
using MultivariateStats # this is for PCA
using StatsPlots 

"""
    predict_msm_score(x_new, all_w_msm, all_b_msm)

Return the raw score of `x_new` under the *first* plane only, i.e.
`dot(x_new, all_w_msm[1]) + all_b_msm[1]`.

Returns `NaN` when either plane list is empty or when the length of `x_new`
differs from the length of the first weight vector.
"""
function predict_msm_score(x_new::Vector{Float64}, all_w_msm::Vector{Vector{Float64}}, all_b_msm::Vector{Float64})
    # No plane to score against.
    (isempty(all_w_msm) || isempty(all_b_msm)) && return NaN

    first_w = all_w_msm[1]
    # Dimension mismatch between the point and the plane.
    length(x_new) == length(first_w) || return NaN

    return dot(x_new, first_w) + all_b_msm[1]
end


"""
    predict_msm_class(X, all_w_msm, all_b_msm) -> Vector{Int}

Classify every row of `X` using the *first* plane only: a row gets `1` when
`dot(row, all_w_msm[1]) + all_b_msm[1]` is strictly positive and `-1` otherwise.

When either plane list is empty, every entry of the result is the sentinel `-99`.
"""
function predict_msm_class(X::Matrix{Float64}, all_w_msm::Vector{Vector{Float64}}, all_b_msm::Vector{Float64})
    n_samples = size(X, 1)

    # Without a plane there is nothing to classify with.
    if isempty(all_w_msm) || isempty(all_b_msm)
        return fill(-99, n_samples)
    end

    first_w = all_w_msm[1]
    first_b = all_b_msm[1]

    return Int[dot(X[row, :], first_w) + first_b > 0 ? 1 : -1 for row in 1:n_samples]
end

# 2D visualization of the MSM result.
# - If the training data already has exactly 2 features, draw the score surface of
#   the first plane, the training points, and the line of every plane.
# - Otherwise, project the training data to 2 dimensions with PCA and plot the
#   points colored by their true label.

d = size(X_train, 2)

# Both branches save figures under images/; create the directory if it is missing
# so that savefig does not fail.
mkpath("images")

if d == 2
    println("\n--- Generating 2D Decision Boundary Plot for MSM ---")

    x_min, x_max = extrema(X_train[:, 1])
    y_min, y_max = extrema(X_train[:, 2])

    # Pad the plotted area by 10% of the data range on each side.
    x_buffer = (x_max - x_min) * 0.1
    y_buffer = (y_max - y_min) * 0.1
    plot_x_min, plot_x_max = x_min - x_buffer, x_max + x_buffer
    plot_y_min, plot_y_max = y_min - y_buffer, y_max + y_buffer

    num_grid_points = 100
    test_range_x = range(plot_x_min, stop=plot_x_max, length=num_grid_points)
    test_range_y = range(plot_y_min, stop=plot_y_max, length=num_grid_points)

    # Score of the first plane on a regular grid. Plots' contour functions expect
    # the matrix laid out as Z[row, col] = f(x[col], y[row]) (rows follow y,
    # columns follow x), so y is the first (row) iterator here.
    Z = [predict_msm_score([x_val, y_val], all_w_msm, all_b_msm)
         for y_val in test_range_y, x_val in test_range_x]

    p_2d_msm = plot(;
        xlim=(plot_x_min, plot_x_max),
        ylim=(plot_y_min, plot_y_max),
        aspect_ratio=1,
        title="MSM Decision Boundaries (2D Features)",
        xlabel="Feature 1",
        ylabel="Feature 2",
        legend=:outertopright
    )

    contourf!(
        p_2d_msm,
        test_range_x,
        test_range_y,
        Z;
        levels=[0],
        color=cgrad(:redsblues),
        alpha=0.7,
        colorbar_title="Predicted Score",
        label="",
    )

    # Training points, split by true class.
    X1 = X_train[y_train_numeric .== -1, :]
    X2 = X_train[y_train_numeric .== 1, :]

    scatter!(p_2d_msm, X1[:, 1], X1[:, 2]; color=:red, label="Class -1 (Benign)", markershape=:circle, markersize=5)
    scatter!(p_2d_msm, X2[:, 1], X2[:, 2]; color=:blue, label="Class 1 (Malignant)", markershape=:xcross, markersize=5)

    if !isempty(all_w_msm) && !isempty(all_b_msm)
        for (idx, w_plane) in enumerate(all_w_msm)
            b_plane = all_b_msm[idx]
            if length(w_plane) == 2 # it's a 2D plane for plotting its line
                # w[1]*x + w[2]*y + b = 0
                # y = (-w[1]*x - b) / w[2]
                if abs(w_plane[2]) > 1e-6
                    plot!(p_2d_msm, test_range_x, x -> (-w_plane[1]*x - b_plane) / w_plane[2],
                          label="Plane $(idx)", linestyle=:dot, linewidth=2, color=:black)
                elseif abs(w_plane[1]) > 1e-6
                    # (Near-)vertical line: x = -b / w[1]
                    vline!(p_2d_msm, [-b_plane / w_plane[1]], label="Plane $(idx)", linestyle=:dot, linewidth=2, color=:black)
                end
                # Both coefficients ~0: the plane defines no line, so nothing is drawn.
            end
        end
    end

    display(p_2d_msm)
    savefig(p_2d_msm, "images/msm_2d_decision_boundaries.png")
    println("Displayed 'MSM Decision Boundaries (2D)' plot.")

else
    # Data does not have exactly 2 features: use PCA for a 2D visualization.
    println("\n--- Attempting 2D Visualization via PCA for MSM ---")

    # Prepare data for PCA: MultivariateStats.jl expects (features x samples)
    X_train_for_pca = permutedims(X_train)

    # Perform PCA to reduce to 2 dimensions
    M = MultivariateStats.fit(MultivariateStats.PCA, X_train_for_pca; maxoutdim=2)

    # Transform training data to the 2D PCA space
    X_train_pca = MultivariateStats.transform(M, X_train_for_pca) # Result will be 2 x n_samples

    # Transpose back for plotting: (n_samples x 2)
    X_train_pca_plot = permutedims(X_train_pca)

    # Determine plot limits from the transformed data, padded by 10% of the range
    x_min_pca, x_max_pca = extrema(X_train_pca_plot[:, 1])
    y_min_pca, y_max_pca = extrema(X_train_pca_plot[:, 2])

    x_buffer_pca = (x_max_pca - x_min_pca) * 0.1
    y_buffer_pca = (y_max_pca - y_min_pca) * 0.1
    plot_x_min_pca, plot_x_max_pca = x_min_pca - x_buffer_pca, x_max_pca + x_buffer_pca
    plot_y_min_pca, plot_y_max_pca = y_min_pca - y_buffer_pca, y_max_pca + y_buffer_pca


    # --- True Labels in PCA Space ---
    p_pca_true = plot(;
        xlim=(plot_x_min_pca, plot_x_max_pca),
        ylim=(plot_y_min_pca, plot_y_max_pca),
        aspect_ratio=1,
        title="MSM Data in PCA Space (True Labels)",
        xlabel="Principal Component 1",
        ylabel="Principal Component 2",
        legend=:outertopright
    )

    # Separate data points by true class for plotting
    X1_pca = X_train_pca_plot[y_train_numeric .== -1, :]
    X2_pca = X_train_pca_plot[y_train_numeric .== 1, :]

    scatter!(p_pca_true, X1_pca[:, 1], X1_pca[:, 2]; color=:red, label="Class -1 (Benign)", markershape=:circle, markersize=5)
    scatter!(p_pca_true, X2_pca[:, 1], X2_pca[:, 2]; color=:blue, label="Class 1 (Malignant)", markershape=:xcross, markersize=5)

    display(p_pca_true)
    savefig(p_pca_true, "images/msm_pca_true_labels.png")
    println("Displayed 'MSM Data in PCA Space (True Labels)' plot.")
end