In [6]:
from Datasets.data import AdultDataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import optuna

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import StratifiedKFold

In [7]:
# Load and prepare data
def load_data(dataset_name):
    """
    Load a dataset and split it into "old" and "new" partitions.

    Args:
        dataset_name: Name of the dataset; only 'Adult' and 'Bank' are accepted.

    Returns:
        Tuple ``(X_old, y_old, X_new, y_new, feature_names, x)``. The first
        four items come from ``split_old_new_data``; ``feature_names`` and
        ``x`` are passed through unchanged from the dataset's ``load_data()``.

    Raises:
        ValueError: If ``dataset_name`` is neither 'Adult' nor 'Bank'.
    """
    # Reject unsupported names up front so the happy path stays flat.
    if dataset_name not in ('Adult', 'Bank'):
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    # NOTE(review): 'Bank' is served by AdultDataset as well -- confirm that
    # this is intentional and not a missing Bank-specific loader.
    loader = AdultDataset()
    X, y, feature_names, x = loader.load_data()
    old_new_split = loader.split_old_new_data(X, y)
    X_old, y_old, X_new, y_new = old_new_split
    return X_old, y_old, X_new, y_new, feature_names, x

In [None]:
X_old, y_old, X_new, y_new, feature_names, x = load_data('Adult')
# Verify the split. Report the dimensions of each partition instead of
# interpolating the whole feature matrix into the message; np.shape works
# whether the loader hands back an ndarray, a DataFrame or a nested list.
print(f"Old data shape: {np.shape(X_old)}, Positive samples: {y_old.sum()}")
print(f"New data shape: {np.shape(X_new)}, Positive samples: {y_new.sum()}")

In [9]:
from Models.models import train_old_model, train_new_models
from Models.lct import LocalCorrectionTree

In [None]:
def evaluate_old_model(dataset, X, y, old_model):
    """
    Score the already-trained old model on the data (X, y).

    Args:
        dataset: Dataset name. 'CTG' and '7-point' are scored with macro
            averaging; every other name uses the scorers' default averaging.
        X: Features handed to ``old_model.predict``.
        y: Ground-truth labels the predictions are compared against.
        old_model: Fitted estimator exposing ``predict``.

    Returns:
        dict mapping 'Accuracy', 'Precision', 'Recall' and 'F1' to a
        one-element list holding the corresponding score.
    """
    predictions = old_model.predict(X)

    # Only the multiclass datasets override the scorers' default averaging.
    scorer_options = {'zero_division': 0}
    if dataset in ['CTG', '7-point']:
        scorer_options['average'] = 'macro'

    precision = precision_score(y, predictions, **scorer_options)
    recall = recall_score(y, predictions, **scorer_options)
    f1 = f1_score(y, predictions, **scorer_options)
    accuracy = accuracy_score(y, predictions)

    # Each score is wrapped in a list so callers can read it with [0].
    return {
        'Accuracy': [accuracy],
        'Precision': [precision],
        'Recall': [recall],
        'F1': [f1],
    }

In [None]:
def evaluate_new_models(dataset, models, X_test, y_test, old_model):
    """
    Score every model in ``models`` on the held-out split.

    Prediction depends on the model's name:
      * 'Ours' -- the model's output is added to the old model's
        ``predict_proba`` scores and the arg-max column is the prediction
        (assumes labels are the column indices 0..K-1 -- TODO confirm).
      * names ending in '+' -- the old model's scores are appended to
        ``X_test`` as extra feature columns before calling ``predict``.
      * anything else -- ``predict`` is called on ``X_test`` directly.

    NOTE(review): 'LGBM+C' does not *end* with '+', so it is scored through
    the plain ``X_test`` path -- confirm that matches how it was trained.

    Args:
        dataset: Dataset name; 'CTG' and '7-point' use macro averaging.
        models: Mapping of model name to fitted model.
        X_test: Test features.
        y_test: Test labels.
        old_model: Fitted old model exposing ``predict_proba``.

    Returns:
        dict with keys 'Accuracy', 'Precision', 'Recall', 'F1'. Each value is
        a list with one score per model, in the iteration order of ``models``.
    """
    # Only the multiclass datasets override the scorers' default averaging.
    scorer_options = {'zero_division': 0}
    if dataset in ['CTG', '7-point']:
        scorer_options['average'] = 'macro'

    old_scores_test = old_model.predict_proba(X_test)
    accuracies, precisions, recalls, f1_scores = [], [], [], []

    for model_name, model in models.items():
        if model_name == 'Ours':
            # LCT model: its output is an additive correction to the old scores.
            y_pred = np.argmax(old_scores_test + model.predict(X_test), axis=1)
        elif model_name.endswith('+'):
            # Models that take the old scores as additional features.
            y_pred = model.predict(np.hstack([X_test, old_scores_test]))
        else:
            # Base models see the raw features only.
            y_pred = model.predict(X_test)

        precisions.append(precision_score(y_test, y_pred, **scorer_options))
        recalls.append(recall_score(y_test, y_pred, **scorer_options))
        f1_scores.append(f1_score(y_test, y_pred, **scorer_options))
        accuracies.append(accuracy_score(y_test, y_pred))

    return {
        'Accuracy': accuracies,
        'Precision': precisions,
        'Recall': recalls,
        'F1': f1_scores,
    }


In [None]:
def plot_metrics_comparison(dataset, metrics, model_names):
    """
    Draw one panel per metric showing each model's mean score with a
    standard-deviation error bar.

    Args:
        dataset: Dataset name, shown in every panel title.
        metrics: dict keyed by 'Accuracy', 'Precision', 'Recall', 'F1'. Each
            value is a sequence of score collections, paired positionally
            with ``model_names``.
        model_names: Labels for the rows of each panel.

    Returns:
        The matplotlib figure holding the four panels.
    """
    # A base model and its '+' variant share a colour; they are told apart
    # by marker fill (hollow for names ending in '+').
    palette = {
        'L1-LR': '#1f77b4',
        'L1-LR+': '#1f77b4',
        'L2-LR': '#ff7f0e',
        'L2-LR+': '#ff7f0e',
        'DT': '#2ca02c',
        'DT+': '#2ca02c',
        'RF': '#8c564b',
        'RF+': '#8c564b',
        'LGBM': '#9467bd',
        'LGBM+': '#9467bd',
        'LGBM+C': '#9467bd',
        'Ours': '#d62728',
        'Old': '#7f7f7f'
    }
    half_width = 0.02  # half-length of the horizontal tick drawn at the mean
    margin = 0.05      # padding added on both sides of the x-range

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    row_positions = np.arange(len(model_names))

    for panel_idx, metric in enumerate(('Accuracy', 'Precision', 'Recall', 'F1')):
        ax = axes[panel_idx]
        per_model_scores = metrics[metric]

        for row, (name, scores) in enumerate(zip(model_names, per_model_scores)):
            color = palette.get(name, 'black')
            center = np.mean(scores)
            spread = np.std(scores)

            # Short horizontal tick centred on the mean.
            ax.plot([center - half_width, center + half_width], [row, row],
                    color=color, linewidth=2, solid_capstyle='butt')

            # Mean marker with a +/- one standard deviation error bar.
            ax.errorbar(center, row, xerr=spread,
                        fmt='o', color=color, capsize=3,
                        markersize=4,
                        fillstyle='none' if name.endswith('+') else 'full',
                        elinewidth=1, capthick=1)

        ax.set_yticks(row_positions)
        ax.set_yticklabels(model_names)
        ax.set_title(f'({chr(97+panel_idx)}) {metric} ({dataset})')
        ax.grid(True, alpha=0.3, linestyle='--')

        # The x-range follows the means only; error bars are not considered.
        means = [np.mean(scores) for scores in per_model_scores]
        lower = min(means) - margin
        upper = max(means) + margin
        ax.set_xlim(lower, upper)

        # Five evenly spaced ticks labelled with two decimals.
        ax.set_xticks(np.linspace(lower, upper, 5))
        ax.set_xticklabels([f'{tick:.2f}' for tick in ax.get_xticks()])

    plt.tight_layout()
    return fig

In [None]:
def print_tree(lct, node_id=0, depth=0, feature_names=None):
    """
    Print the structure of a Local Correction Tree, one node per line.

    Nodes are visited in preorder (node, left subtree, right subtree) and
    each level is indented by two spaces.

    Args:
        lct: Object whose ``nodes[i]`` unpacks to (feature_idx, threshold,
            w_node) and whose ``children[i]`` unpacks to (left_id, right_id).
        node_id: ID of the node to start printing from.
        depth: Indentation level used for the starting node.
        feature_names: Optional sequence of names indexed by feature_idx;
            when omitted, features are printed as ``X[idx]``.
    """
    # Explicit stack instead of recursion; each entry is (node id, depth).
    pending = [(node_id, depth)]
    while pending:
        current_id, level = pending.pop()
        feature_idx, threshold, w_node = lct.nodes[current_id]
        left_id, right_id = lct.children[current_id]
        pad = "  " * level

        # A feature index of -1 marks a leaf.
        if feature_idx == -1:
            print(f"{pad}Leaf: correction = {w_node.round(4)}")
            continue

        if feature_names is None:
            label = f"X[{feature_idx}]"
        else:
            label = feature_names[feature_idx]
        print(f"{pad}Node: {label} <= {threshold:.4f}")

        # Push right before left so the left subtree is popped (printed) first.
        pending.append((right_id, level + 1))
        pending.append((left_id, level + 1))

In [None]:
# Modified main execution
# For each dataset: score the old model on the new data, cross-validate the
# new models on the new data, print the last fold's correction tree, then
# plot and print mean +/- std for every model.
datasets = ['Adult']
for dataset in datasets:
    print(f"\nProcessing {dataset} dataset")

    metric_names = ['Accuracy', 'Precision', 'Recall', 'F1']
    base_models = ['L1-LR', 'L2-LR', 'DT', 'RF', 'LGBM']
    enhanced_models = ['L1-LR+', 'L2-LR+', 'DT+', 'RF+', 'LGBM+', 'LGBM+C', 'Ours']
    new_model_names = base_models + enhanced_models
    # 'Old' is deliberately last; the plotting and printing code pairs
    # model_names with the metric lists positionally.
    model_names = new_model_names + ['Old']

    # 1. Load data
    # NOTE(review): load_data as defined in this notebook rejects '7-point',
    # so this branch is currently unreachable -- confirm before relying on it.
    if dataset == '7-point':
        X_old, y_old, X_new, y_new, X_clinical = load_data(dataset)
        feature_names = None  # print_tree falls back to X[i] labels
    else:
        X_old, y_old, X_new, y_new, feature_names, x = load_data(dataset)

    # Positional fancy indexing with the fold indices below needs ndarrays.
    X_new = np.array(X_new)
    y_new = np.array(y_new)

    # 2. Train and evaluate old model with multiple runs
    n_runs = 5
    old_metrics_runs = {metric: [] for metric in metric_names}

    for run in range(n_runs):
        old_model = train_old_model(dataset, X_old, y_old)
        run_metrics = evaluate_old_model(dataset, X_new, y_new, old_model)
        for metric in old_metrics_runs:
            # evaluate_old_model wraps each score in a one-element list.
            old_metrics_runs[metric].append(run_metrics[metric][0])

    # 3. Perform 5-fold cross validation for new models
    # (the old model from the last run above supplies the scores)
    n_splits = 5
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    all_metrics = {
        metric: {model: [] for model in new_model_names}
        for metric in metric_names
    }

    for fold, (train_idx, test_idx) in enumerate(skf.split(X_new, y_new)):
        print(f"Processing fold {fold + 1}/{n_splits}")

        X_train, X_test = X_new[train_idx], X_new[test_idx]
        y_train, y_test = y_new[train_idx], y_new[test_idx]

        # Get old model predictions for enhanced models
        old_scores_train = old_model.predict_proba(X_train)
        old_scores_test = old_model.predict_proba(X_test)

        # Train base models
        base_model_dict = train_new_models(dataset, X_train, y_train, old_model, LocalCorrectionTree)

        # Train enhanced models with old model scores
        X_train_enhanced = np.hstack([X_train, old_scores_train])
        X_test_enhanced = np.hstack([X_test, old_scores_test])  # not read again in this cell
        enhanced_model_dict = train_new_models(dataset, X_train_enhanced, y_train, old_model, LocalCorrectionTree)

        # Combine all models
        # NOTE(review): on duplicate keys the second dict wins. If both
        # train_new_models calls return the same model names, the models
        # trained on raw features are discarded here -- confirm what
        # train_new_models returns.
        fold_models = {**base_model_dict, **enhanced_model_dict}

        # Evaluate models
        fold_metrics = evaluate_new_models(dataset, fold_models, X_test, y_test, old_model)

        # evaluate_new_models returns, per metric, a plain list ordered like
        # fold_models, so a model's score is looked up by position. A model
        # missing from fold_models raises KeyError here.
        fold_position = {name: position for position, name in enumerate(fold_models)}
        for metric in all_metrics:
            for model in new_model_names:
                all_metrics[metric][model].append(
                    fold_metrics[metric][fold_position[model]]
                )

    print("\nFinal Local Correction Tree Structure:")
    # The trained LCT lives in the last fold's model dict; enhanced_models
    # only holds model *names* and cannot be indexed by 'Ours'.
    if 'Ours' in fold_models:  # Check if LCT model exists
        fold_models['Ours'].simplify()  # Remove redundant nodes
        print_tree(fold_models['Ours'], feature_names=feature_names)

    # 4. Prepare metrics for plotting
    # Keep the same order as model_names: new models first, 'Old' last.
    plot_metrics = {metric: [] for metric in metric_names}

    for metric in plot_metrics:
        for model in new_model_names:
            plot_metrics[metric].append(all_metrics[metric][model])
        plot_metrics[metric].append(old_metrics_runs[metric])

    # 5. Create and save plot
    fig = plot_metrics_comparison(dataset, plot_metrics, model_names)
    # Save and close the returned figure explicitly rather than whichever
    # figure pyplot currently considers active.
    fig.savefig(f'{dataset}_model_comparison.png', bbox_inches='tight', dpi=300)
    plt.close(fig)

    # 6. Print numerical results
    print(f"\nResults for {dataset}:")
    for metric in plot_metrics:
        print(f"\n{metric}:")
        for name, values in zip(model_names, plot_metrics[metric]):
            mean_val = np.mean(values)
            std_val = np.std(values)
            print(f"{name}: {mean_val:.4f} ± {std_val:.4f}")

print("\nDone evaluating old model and new models.")