# 3. RSA Optimization

This is the Julia code for the MATH271 class at Carleton College in the Spring 2025 term with Rob Thompson.

Final Project: Understanding Breast Cancer Diagnosis and Prognosis via Linear Programming
Authors: Palmy Klangsathorn and Angelina Kong
Date: 06/05/2025

This project explores the application of linear programming techniques to the problem of breast cancer diagnosis and prognosis. We will utilize the Breast Cancer Wisconsin (Prognostic) dataset to build and evaluate a predictive model. Our approach will focus on formulating the classification task as a linear program, contrasting it with traditional methods, and analyzing its performance through various statistical metrics and visualizations.

In this part, the Recurrence Surface Approximation (RSA) technique also builds on linear programming, this time to estimate a patient's risk of breast cancer recurring. RSA is a prognostic method: it is used to predict the likelihood and timing of recurrence. A patient can be classified as recurrent once a recurrence has actually been observed, but knowing that a patient is non-recurrent is harder, because a recurrence may still occur in the future, after the last follow-up. To address this concern, RSA works with what is known: the Time to Recur (TTR) for patients whose cancer has recurred and the Disease Free Survival time (DFS) for patients whose cancer has not.

In [None]:
# make sure to have these packages
# using Pkg
# Pkg.add("DataFrames")
# Pkg.add("StatsBase")
# Pkg.add("LinearAlgebra")
# Pkg.add("JuMP")
# Pkg.add("HiGHS") 
# Pkg.add("MLJBase") 

# load necessary packages
using DataFrames
using StatsBase
using LinearAlgebra 
using JuMP
using HiGHS 
import MathOptInterface as MOI 
using Random 
using MLJBase

In [None]:
# 1. load data

"""
    read_to_dict(file_path::String) -> Dict{String, Vector{Any}}

Parse a comma-separated data file into a dictionary keyed by the first field of each
line.

Each value is a 32-element `Vector{Any}` laid out as:

1. the second field, as a `String`;
2. the third field, as a `Float64` or `missing`;
3. the next 30 fields, each as a `Float64` or `missing`.

A numeric field that is empty, equal to `"?"`, or not parseable as a `Float64` is stored
as `missing`, so that every unusable value is represented the same way. Fields after the
30th feature are ignored, and rows with fewer than 30 features are padded with `missing`.

Blank lines and lines whose fields are all empty are skipped silently. Lines with fewer
than three fields are skipped with a warning. When two lines share a key, the later line
replaces the earlier one.

If `file_path` does not exist, `"Error: File not found"` is printed and an empty
dictionary is returned. Any other I/O error propagates to the caller.
"""
function read_to_dict(file_path::String)
    data_dict = Dict{String, Vector{Any}}()

    # Test for existence up front. `open` reports a missing file through a SystemError
    # whose `errnum` is the libc errno, which does not equal the libuv code
    # `Base.UV_ENOENT`, so matching on that constant never fires.
    if !ispath(file_path)
        println("Error: File not found")
        return data_dict
    end

    expected_feature_count = 30

    # Convert one raw field to Float64; empty, "?" and unparseable text become missing.
    function parse_field(raw)
        text = strip(raw)
        (isempty(text) || text == "?") && return missing
        return something(tryparse(Float64, text), missing)
    end

    open(file_path, "r") do io
        for line in eachline(io)
            isempty(strip(line)) && continue

            elements = split(line, ',')
            all(isempty, elements) && continue

            # The key, the label and the time field are all required.
            if length(elements) < 3
                println("Warning: Skipping line with fewer than 3 fields: $(line)")
                continue
            end

            key = String(strip(elements[1]))
            data_list = Vector{Any}()
            sizehint!(data_list, expected_feature_count + 2)

            push!(data_list, String(strip(elements[2])))
            push!(data_list, parse_field(elements[3]))

            # Features start at field 4; read at most `expected_feature_count` of them.
            last_feature = min(length(elements), expected_feature_count + 3)
            for i in 4:last_feature
                push!(data_list, parse_field(elements[i]))
            end

            # Pad short rows so every record has the same length.
            while length(data_list) < expected_feature_count + 2
                push!(data_list, missing)
            end

            data_dict[key] = data_list
        end
    end

    return data_dict
end

# Path to the raw data file, resolved relative to the current working directory
file_path = "wpbc.data"
# Parse the file into a Dict keyed by the first field of each line (see read_to_dict)
wpbc_data = read_to_dict(file_path)

In [None]:
# Column buffers: one per output column, filled record by record
ids = String[]
event_strs = String[]
times = Vector{Union{Float64, Missing}}()
features_data = [Vector{Union{Float64, Missing}}() for _ in 1:30]

# NOTE(review): rows are appended in Dict iteration order, which is not the file order.
for (id, data) in wpbc_data
    # A well-formed record holds the event label, the time, and 30 feature values
    if length(data) != 32
        println("Warning: Row with ID $id has unexpected data length $(length(data)). Expected 32 values. Skipping.")
        continue
    end

    push!(ids, id)
    push!(event_strs, data[1])
    push!(times, data[2])

    # Positions 3..32 of the record map onto feature columns 1..30
    for (column, value) in zip(features_data, view(data, 3:32))
        push!(column, value)
    end
end

# Assemble the DataFrame from the identifier, event, and time columns
df_prognosis = DataFrame(ID = ids, Event_str = event_strs, Time = times)

# Attach the feature columns as feature_1 ... feature_30
for (i, column) in enumerate(features_data)
    df_prognosis[!, Symbol("feature_$(i)")] = column
end
# println("First 5 rows (ID, Event, Time, first 3 features):")
# println(first(df_prognosis[:, [:ID, :Event_str, :Time, :feature_1, :feature_2, :feature_3]], 5)) 
# println("Total columns in df_prognosis: $(names(df_prognosis))")

# Names of the 30 feature columns that take part in imputation
features_rsa_names = [Symbol("feature_$(i)") for i in 1:30]

# Mean-impute each numeric feature column in place
for col_name in features_rsa_names
    column = df_prognosis[!, col_name]
    if eltype(column) <: Union{Missing, Number} && any(ismissing, column)
        # The mean is taken over the observed (non-missing) entries only
        replace!(column, missing => mean(skipmissing(column)))
    end
end

# Same treatment for the Time column
if any(ismissing, df_prognosis.Time)
    mean_time = mean(skipmissing(df_prognosis.Time))
    replace!(df_prognosis.Time, missing => mean_time)
end

# No missings remain, so narrow the element types to plain Float64
for col_name in features_rsa_names
    df_prognosis[!, col_name] = convert(Vector{Float64}, df_prognosis[!, col_name])
end
df_prognosis.Time = convert(Vector{Float64}, df_prognosis.Time)




In [None]:
# 2. Preprocessing: encode the event label, extract the feature matrix, standardize

# Numeric event indicator: 1 where Event_str is "R", 0 for every other label
df_prognosis.Event = Int.(df_prognosis.Event_str .== "R")

# Every column other than the identifier, label, and time columns is a feature
features_rsa_final = names(df_prognosis, Not([:ID, :Event_str, :Time, :Event]))
X_rsa_raw = Matrix(select(df_prognosis, features_rsa_final))

# Standardize each feature column (z-score computed along dimension 1)
X_rsa = zscore(X_rsa_raw, 1)

# Time and event vectors used by the later cells
Time_rsa = df_prognosis[!, :Time]
Event_rsa = df_prognosis[!, :Event]


In [None]:
# 3. Train-Test Split for RSA Data
# Seed the global RNG so the shuffled row order is repeatable
Random.seed!(123)
num_observations_rsa = size(X_rsa, 1)
indices_rsa = Random.shuffle!(collect(1:num_observations_rsa))

# The first 80% of the shuffled rows (rounded down) train the model; the rest test it
split_point_rsa = floor(Int, num_observations_rsa * 0.8)
train_indices_rsa = indices_rsa[1:split_point_rsa]
test_indices_rsa = indices_rsa[(split_point_rsa + 1):end]

X_train_rsa = X_rsa[train_indices_rsa, :]
Time_train_rsa = Time_rsa[train_indices_rsa]
Event_train_rsa = Event_rsa[train_indices_rsa]

X_test_rsa = X_rsa[test_indices_rsa, :]
Time_test_rsa = Time_rsa[test_indices_rsa]
Event_test_rsa = Event_rsa[test_indices_rsa]

# Subsets passed to solve_rsa_lp: Event == 1 rows are recurrent, Event == 0 rows censored
X_recurrent_train = X_train_rsa[findall(isone, Event_train_rsa), :]
TTR_recurrent_train = Time_train_rsa[findall(isone, Event_train_rsa)]

X_censored_train = X_train_rsa[findall(iszero, Event_train_rsa), :]
DFS_censored_train = Time_train_rsa[findall(iszero, Event_train_rsa)]

# The same partition of the held-out rows
X_recurrent_test = X_test_rsa[findall(isone, Event_test_rsa), :]
TTR_recurrent_test = Time_test_rsa[findall(isone, Event_test_rsa)]

X_censored_test = X_test_rsa[findall(iszero, Event_test_rsa), :]
DFS_censored_test = Time_test_rsa[findall(iszero, Event_test_rsa)]




In [None]:
# 4. Define solve_rsa_lp function, this is the function for optimizing
"""
    solve_rsa_lp(X_recurrent, TTR_recurrent, X_censored, DFS_censored;
                 gamma_param=1.0, Omega_param=1.0) -> (w, b, status)

Fit the linear predictor `x -> dot(x, w) + b` by solving a linear program with HiGHS.

With `N_R` recurrent rows and `N_C` censored rows, the program is

    minimize    (1/N_R) * sum(y[i] + z[i]) + (gamma_param/N_C) * sum(v[i])
    subject to  pred_R[i] - TTR_recurrent[i] <= y[i]      (over-estimation)
                TTR_recurrent[i] - pred_R[i] <= z[i]      (under-estimation)
                pred_C[i] + v[i] >= DFS_censored[i]       (prediction below DFS)
                sum(abs.(w)) <= Omega_param
                y, z, v >= 0

where `pred_R` and `pred_C` are the predictions for the recurrent and censored rows. A
term of the objective is dropped when its group is empty.

# Arguments
- `X_recurrent`: `N_R x d` feature matrix of the recurrent group.
- `TTR_recurrent`: length-`N_R` vector of targets for the recurrent group.
- `X_censored`: `N_C x d` feature matrix of the censored group.
- `DFS_censored`: length-`N_C` vector of lower bounds for the censored group.

# Keywords
- `gamma_param`: weight of the censored-group penalty in the objective.
- `Omega_param`: upper bound on the L1 norm of `w`.

# Returns
`(w, b, status)`, where `status` is the solver termination status. When the status is
neither `MOI.OPTIMAL` nor `MOI.LOCALLY_OPTIMAL`, `w` is `zeros(d)` and `b` is `0.0`, so
callers must check `status` before using the coefficients.

# Throws
- `DimensionMismatch` if the column counts of the two matrices differ, or if a target
  vector's length differs from the row count of its matrix.
"""
function solve_rsa_lp(X_recurrent::AbstractMatrix{<:Real}, TTR_recurrent::AbstractVector{<:Real},
                      X_censored::AbstractMatrix{<:Real}, DFS_censored::AbstractVector{<:Real};
                      gamma_param::Real=1.0, Omega_param::Real=1.0)

    # The loops below index from 1, so reject offset-indexed inputs explicitly
    Base.require_one_based_indexing(X_recurrent, TTR_recurrent, X_censored, DFS_censored)

    N_R, d = size(X_recurrent)
    N_C = size(X_censored, 1)

    # Fail early with a clear message instead of a BoundsError part-way through model
    # construction (or silently ignoring surplus targets)
    size(X_censored, 2) == d || throw(DimensionMismatch(
        "X_recurrent has $d columns but X_censored has $(size(X_censored, 2))"))
    length(TTR_recurrent) == N_R || throw(DimensionMismatch(
        "X_recurrent has $N_R rows but TTR_recurrent has $(length(TTR_recurrent)) entries"))
    length(DFS_censored) == N_C || throw(DimensionMismatch(
        "X_censored has $N_C rows but DFS_censored has $(length(DFS_censored)) entries"))

    model_rsa = JuMP.Model(HiGHS.Optimizer)
    set_silent(model_rsa)

    # variables
    @variable(model_rsa, w[1:d])
    @variable(model_rsa, b) # Bias term
    @variable(model_rsa, y[1:N_R] >= 0) # over-estimation slack, recurrent rows
    @variable(model_rsa, z[1:N_R] >= 0) # under-estimation slack, recurrent rows
    @variable(model_rsa, v[1:N_C] >= 0) # shortfall below DFS, censored rows

    # auxiliary variables for L1 regularization on w
    @variable(model_rsa, t[1:d] >= 0)

    # Each error term is averaged over its group; an empty group contributes nothing
    recurrent_term = N_R > 0 ? (sum(y) + sum(z)) / N_R : 0.0
    censored_term = N_C > 0 ? gamma_param * sum(v) / N_C : 0.0
    @objective(model_rsa, Min, recurrent_term + censored_term)

    # constraints for recurrent patients (i in R)
    for i in 1:N_R
        predicted_ttr = @expression(model_rsa, sum(X_recurrent[i, j] * w[j] for j in 1:d) + b)
        @constraint(model_rsa, predicted_ttr - TTR_recurrent[i] <= y[i])
        @constraint(model_rsa, TTR_recurrent[i] - predicted_ttr <= z[i])
    end

    # constraints for censored patients (i in C)
    for i in 1:N_C
        predicted_ttr = @expression(model_rsa, sum(X_censored[i, j] * w[j] for j in 1:d) + b)
        @constraint(model_rsa, predicted_ttr + v[i] >= DFS_censored[i])
    end

    # regularization (L1 norm on w): -t[j] <= w[j] <= t[j] and sum(t) <= Omega_param
    for j in 1:d
        @constraint(model_rsa, w[j] <= t[j])
        @constraint(model_rsa, w[j] >= -t[j])
    end
    @constraint(model_rsa, sum(t) <= Omega_param)

    optimize!(model_rsa)

    status = termination_status(model_rsa)

    if status == MOI.OPTIMAL || status == MOI.LOCALLY_OPTIMAL
        return value.(w), value(b), status
    else
        # No usable solution: return neutral coefficients together with the status
        return zeros(d), 0.0, status
    end
end



In [None]:
# 5. Define predict_rsa function
"""
    predict_rsa(x_new, w_opt, b_opt)

Return the linear prediction `dot(x_new, w_opt) + b_opt`.

Any real-valued vectors are accepted, including array views, so a matrix row can be
passed as `view(X, i, :)` without copying it. `dot` throws a `DimensionMismatch` when
`x_new` and `w_opt` have different lengths.
"""
function predict_rsa(x_new::AbstractVector{<:Real}, w_opt::AbstractVector{<:Real}, b_opt::Real)
    return dot(x_new, w_opt) + b_opt
end

In [None]:
# --- 6. Run RSA Optimization ---

# Ratio used for the classification metrics. A zero denominator (for example when no
# sample is predicted positive) yields 0.0 rather than NaN, so the report never shows
# "NaN%".
safe_ratio(numerator, denominator) = denominator == 0 ? 0.0 : numerator / denominator

println("\n--- Running RSA Optimization on Training Data ---")
w_rsa_opt, b_rsa_opt, rsa_status = solve_rsa_lp(
    X_recurrent_train, TTR_recurrent_train,
    X_censored_train, DFS_censored_train;
    gamma_param=0.5,
    Omega_param=80.0 
)

if rsa_status == MOI.OPTIMAL || rsa_status == MOI.LOCALLY_OPTIMAL
    println("RSA optimization successful!")
    # --- 7. Evaluate RSA Performance on Test Data ---
    println("\n--- Evaluating RSA Performance on Test Data ---")

    # One predicted time per test row
    predicted_ttr_test = [predict_rsa(X_test_rsa[i,:], w_rsa_opt, b_rsa_opt) for i in 1:size(X_test_rsa, 1)]

    println("\n--- RSA Classification and Confusion Matrix (Test Data) ---")

    # Classification threshold: the median TTR of the recurrent training rows, with a
    # fixed 36-month fallback when the training set contains no recurrent rows.
    if !isempty(TTR_recurrent_train)
        classification_threshold = median(TTR_recurrent_train)
        println("Using classification threshold (median TTR from training recurrent data): $(round(classification_threshold, digits=2)) months")
    else
        classification_threshold = 36.0 # Example: 36 months (3 years)
        println("Warning: No recurrent training data. Using default threshold: $(round(classification_threshold, digits=2)) months")
    end

    # RSA predicts a time, not a class. To obtain a class, a predicted time at or below
    # the threshold is labelled Recurrent (1) and a longer predicted time is labelled
    # Non-Recurrent (0). The Event column (1 = recurrent, 0 = non-recurrent/censored)
    # serves as the ground truth.
    predicted_classes = [p_ttr <= classification_threshold ? 1 : 0 for p_ttr in predicted_ttr_test]
    true_classes = Event_test_rsa

    # Confusion matrix components
    TP = sum((predicted_classes .== 1) .& (true_classes .== 1))
    TN = sum((predicted_classes .== 0) .& (true_classes .== 0))
    FP = sum((predicted_classes .== 1) .& (true_classes .== 0))
    FN = sum((predicted_classes .== 0) .& (true_classes .== 1))

    println("\n--- RSA Confusion Matrix (Test Data, based on TTR threshold) ---")
    println("True Positives (TP): $TP")
    println("True Negatives (TN): $TN")
    println("False Positives (FP): $FP")
    println("False Negatives (FN): $FN")

    # Classification metrics, each guarded against an empty denominator.
    # NOTE: the global `precision` shadows the function of the same name in Base.
    accuracy = safe_ratio(TP + TN, TP + TN + FP + FN)
    precision = safe_ratio(TP, TP + FP)
    recall = safe_ratio(TP, TP + FN) # Also known as Sensitivity
    specificity = safe_ratio(TN, TN + FP)
    f1_score = safe_ratio(2 * (precision * recall), precision + recall)

    println("\n--- RSA Classification Metrics (Test Data) ---")
    println("Accuracy:    $(round(accuracy * 100, digits=2))%")
    println("Precision:   $(round(precision * 100, digits=2))%")
    println("Recall:      $(round(recall * 100, digits=2))%")
    println("F1-Score:    $(round(f1_score * 100, digits=2))%")
    println("Specificity: $(round(specificity * 100, digits=2))%")

else
    println("RSA optimization failed. Cannot perform classification evaluation.")
end

In [None]:
# Print a short summary of the fitted coefficient vector
println("--- Inspection of w_rsa_opt ---")
for (label, detail) in (
    "Type of w_rsa_opt: " => typeof(w_rsa_opt),
    "Length of w_rsa_opt: " => length(w_rsa_opt),
    # At most the first five coefficients (all of them if there are fewer)
    "First few elements of w_rsa_opt: " => first(w_rsa_opt, 5),
)
    println(label, detail)
end

## Plotting

In [None]:
using Plots
using Printf

println("\n--- Generating Visualizations ---")

# Column labels; the first two are not features, the remaining 30 label the bars
col_names = [
    "ID", "Diagnosis",
    "radius1", "texture1", "perimeter1", "area1", "smoothness1", "compactness1", "concavity1", "concave_points1", "symmetry1", "fractal_dimension1",
    "radius2", "texture2", "perimeter2", "area2", "smoothness2", "compactness2", "concavity2", "concave_points2", "symmetry2", "fractal_dimension2",
    "radius3", "texture3", "perimeter3", "area3", "smoothness3", "compactness3", "concavity3", "concave_points3", "symmetry3", "fractal_dimension3"
]
feature_names = col_names[3:end]

println("Generating bar plot for RSA feature coefficients...")

# (x position, y position, text) for the value label of each bar
annotations = Tuple{Int, Float64, String}[]

max_abs_coeff = isempty(w_rsa_opt) ? 0.0 : maximum(abs.(w_rsa_opt))
# Bars taller than 10% of the largest bar get their label inside the bar
min_coeff_for_inside = 0.1 * max_abs_coeff
annotation_fontsize = 6
inside_margin_factor = 0.8

for (i, coeff) in enumerate(w_rsa_opt)
    label = @sprintf("%.2f", coeff)
    # Large bars: label at 80% of the bar height. Small bars: label just above the bar.
    y_offset = abs(coeff) > min_coeff_for_inside ? inside_margin_factor * coeff : coeff + 0.02 * max_abs_coeff
    push!(annotations, (i, y_offset, label))
end

# Plotting
p_coeffs_rsa = bar(
    1:length(w_rsa_opt), w_rsa_opt,
    title = "RSA Model Feature Coefficients",
    xlabel = "Feature Index",
    ylabel = "Coefficient Value",
    legend = false,
    size = (1200, 700),
    xticks = (1:length(feature_names), feature_names),
    xrotation = 90,
    bottom_margin = 15Plots.mm,
    left_margin = 10Plots.mm,
    annotations = annotations,
    annotationfontsize = annotation_fontsize,
)

display(p_coeffs_rsa)

# savefig does not create missing directories, so make sure the target folder exists
mkpath("images")
savefig(p_coeffs_rsa, "images/rsa_feature_coefficients.png")

println("Displayed 'RSA Model Feature Coefficients' plot.")


In [None]:

# 2. Scatter Plot: Predicted vs. Actual TTR for Recurrent Patients (Test Set)
if !isempty(TTR_recurrent_test)
    println("\nGenerating scatter plot for Predicted vs. Actual TTR (Recurrent Patients)...")
    
    # Predictions for the recurrent test rows, one per row of X_recurrent_test
    predicted_ttr_recurrent_test = [predict_rsa(X_recurrent_test[i,:], w_rsa_opt, b_rsa_opt) for i in 1:size(X_recurrent_test, 1)]

    p_ttr_scatter = scatter(TTR_recurrent_test, predicted_ttr_recurrent_test,
                            title="RSA: Predicted vs. Actual TTR (Recurrent, Test Set)",
                            xlabel="Actual Time To Recurrence (TTR)",
                            ylabel="Predicted Time To Recurrence (TTR)",
                            legend=false,
                            markeralpha=0.7,
                            markersize=5,
                            color=:blue)
    
    # Add a y=x reference line spanning the data. The lower end stays at 0 unless a
    # value is negative (the linear predictor is not constrained to be non-negative).
    min_val = min(0, minimum(TTR_recurrent_test), minimum(predicted_ttr_recurrent_test))
    max_val = max(maximum(TTR_recurrent_test), maximum(predicted_ttr_recurrent_test))
    plot!(p_ttr_scatter, [min_val, max_val], [min_val, max_val], linecolor=:red, linestyle=:dash, label="y=x")
    
    display(p_ttr_scatter)
    # Save the plot; savefig does not create missing directories, so create it first
    mkpath("images")
    savefig(p_ttr_scatter, "images/rsa_predicted_vs_actual_ttr.png")
    println("Displayed 'Predicted vs. Actual TTR (Recurrent)' plot.")
else
    println("\nSkipping Predicted vs. Actual TTR plot: No recurrent patients in the test set.")
end


# 3. Scatter Plot: Predicted TTR vs. DFS for Censored Patients (Test Set)
if !isempty(DFS_censored_test)
    println("\nGenerating scatter plot for Predicted TTR vs. DFS (Censored Patients)...")

    # Predictions for the censored test rows, one per row of X_censored_test
    predicted_ttr_censored_test = [predict_rsa(X_censored_test[i,:], w_rsa_opt, b_rsa_opt) for i in 1:size(X_censored_test, 1)]

    # Split the points by whether the prediction falls below the observed DFS, which is
    # the condition the censored-group constraint of the LP penalizes
    violating_indices = findall(predicted_ttr_censored_test .< DFS_censored_test)
    non_violating_indices = findall(predicted_ttr_censored_test .>= DFS_censored_test)

    p_dfs_scatter = scatter(DFS_censored_test[non_violating_indices], predicted_ttr_censored_test[non_violating_indices],
                            label="Predicted >= DFS",
                            markeralpha=0.7,
                            markersize=5,
                            color=:green)
    
    scatter!(p_dfs_scatter, DFS_censored_test[violating_indices], predicted_ttr_censored_test[violating_indices],
             label="Predicted < DFS",
             markeralpha=0.7,
             markersize=5,
             color=:red)

    plot!(p_dfs_scatter, title="RSA: Predicted TTR vs. DFS (Censored, Test Set)",
                         xlabel="Disease-Free Survival Time (DFS)",
                         ylabel="Predicted Time To Recurrence (TTR)",
                         legend=:topleft)

    # Add a y=x reference line (points on or above it satisfy prediction >= DFS). The
    # lower end stays at 0 unless a value is negative.
    min_val_dfs = min(0, minimum(DFS_censored_test), minimum(predicted_ttr_censored_test))
    max_val_dfs = max(maximum(DFS_censored_test), maximum(predicted_ttr_censored_test))
    plot!(p_dfs_scatter, [min_val_dfs, max_val_dfs], [min_val_dfs, max_val_dfs], linecolor=:blue, linestyle=:dash, label="y=x (Ideal)")
    
    display(p_dfs_scatter)
    # Save the plot; savefig does not create missing directories, so create it first
    mkpath("images")
    savefig(p_dfs_scatter, "images/rsa_predicted_vs_dfs_censored.png")
    println("Displayed 'Predicted TTR vs. DFS (Censored)' plot.")
else
    println("\nSkipping Predicted TTR vs. DFS plot: No censored patients in the test set.")
end