In [None]:
from Datasets.data import AdultDataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import optuna

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import StratifiedKFold

In [3]:
# Load and prepare data
def load_data(dataset_name):
    """
    Load a dataset and split it into its "old" and "new" partitions.

    Args:
        dataset_name: Name of the dataset; 'Adult' and 'Bank' are accepted.

    Returns:
        Tuple (X_old, y_old, X_new, y_new, feature_names, x), where the first
        four items come from AdultDataset.split_old_new_data and the last two
        are passed through unchanged from AdultDataset.load_data.

    Raises:
        ValueError: If dataset_name is not one of the accepted names.
    """
    # Reject unsupported names up front.
    if dataset_name not in ('Adult', 'Bank'):
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    # NOTE(review): 'Bank' is served by AdultDataset as well -- confirm that
    # this is intended and not a missing Bank-specific loader.
    loader = AdultDataset()
    X, y, feature_names, x = loader.load_data()
    X_old, y_old, X_new, y_new = loader.split_old_new_data(X, y)
    return X_old, y_old, X_new, y_new, feature_names, x

In [None]:
X_old, y_old, X_new, y_new, feature_names, x = load_data('Adult')
# Verify the split: report the dimensions of each partition (not its raw
# contents) together with the number of positive labels it contains.
# np.shape is used so this works whether the partitions are arrays or other
# array-likes.
print(f"Old data shape: {np.shape(X_old)}, Positive samples: {y_old.sum()}")
print(f"New data shape: {np.shape(X_new)}, Positive samples: {y_new.sum()}")

In [None]:
from Models.models import train_old_model, train_new_models
from Models.lct import LocalCorrectionTree

In [6]:
def evaluate_old_model(dataset, X, y, old_model):
    """
    Score an already-fitted model on (X, y).

    Args:
        dataset: Dataset name; 'CTG' and '7-point' switch precision, recall
            and F1 to macro averaging, every other name uses the scorers'
            default averaging.
        X: Feature matrix passed to old_model.predict.
        y: Ground-truth labels.
        old_model: Object exposing predict(X).

    Returns:
        Dict with keys 'Accuracy', 'Precision', 'Recall' and 'F1', each mapped
        to a single-element list holding the score.
    """
    predictions = old_model.predict(X)

    # Extra scorer arguments: macro averaging for the multiclass datasets,
    # nothing extra (scorer default) otherwise.
    averaging = {'average': 'macro'} if dataset in ('CTG', '7-point') else {}

    precision = precision_score(y, predictions, zero_division=0, **averaging)
    recall = recall_score(y, predictions, zero_division=0, **averaging)
    f1 = f1_score(y, predictions, zero_division=0, **averaging)
    accuracy = accuracy_score(y, predictions)

    # Values are wrapped in lists so callers can extend() per-fold accumulators.
    scores = (('Accuracy', accuracy), ('Precision', precision),
              ('Recall', recall), ('F1', f1))
    return {label: [value] for label, value in scores}

In [7]:
def evaluate_new_models(dataset, models, X_test, y_test, base_models_dict):
    """
    Score a collection of new models on a held-out set.

    Args:
        dataset: Dataset name; 'CTG' and '7-point' switch precision, recall
            and F1 to macro averaging.
        models: Mapping of model name -> model. The name selects how
            predictions are produced:
              * 'Ours': model.predict(X_test) is added to the old model's
                predict_proba output and the arg-max column is the label.
              * names ending in '+' or '+C': the old model's predict_proba
                output is appended to X_test as extra columns before predict.
              * anything else: model.predict(X_test) directly.
        X_test: Test features.
        y_test: Test labels.
        base_models_dict: Mapping of old-model name -> old model. The key used
            for a given entry of `models` is its name with every '+' and 'C'
            character removed, so that key must be present.

    Returns:
        Dict keyed by 'Accuracy', 'Precision', 'Recall', 'F1'; each value is a
        dict mapping model name -> scalar score.
    """
    # Extra scorer arguments: macro averaging for the multiclass datasets,
    # nothing extra (scorer default) otherwise.
    averaging = {'average': 'macro'} if dataset in ('CTG', '7-point') else {}
    results = {label: {} for label in ('Accuracy', 'Precision', 'Recall', 'F1')}

    for name, estimator in models.items():
        # The old model is resolved by name for every entry, before branching.
        reference_name = name.replace('+', '').replace('C', '')
        reference_scores = base_models_dict[reference_name].predict_proba(X_test)

        if name == 'Ours':
            # Correction model: shift the old scores, then take the arg-max.
            shift = estimator.predict(X_test)
            predictions = np.argmax(reference_scores + shift, axis=1)
        elif name.endswith(('+', '+C')):
            # Score-augmented model: old scores are appended as features.
            augmented = np.hstack([X_test, reference_scores])
            predictions = estimator.predict(augmented)
        else:
            # Plain model: features only.
            predictions = estimator.predict(X_test)

        precision = precision_score(y_test, predictions, zero_division=0, **averaging)
        recall = recall_score(y_test, predictions, zero_division=0, **averaging)
        f1 = f1_score(y_test, predictions, zero_division=0, **averaging)
        accuracy = accuracy_score(y_test, predictions)

        for label, value in (('Accuracy', accuracy), ('Precision', precision),
                             ('Recall', recall), ('F1', f1)):
            results[label][name] = value

    return results


In [8]:
def plot_metrics_comparison(dataset, metrics, model_names):
    """
    Draw a 1x4 panel figure comparing models on Accuracy, Precision, Recall
    and F1.

    Each model occupies one row per panel: a short horizontal bar centred on
    the mean of its samples, plus a marker with horizontal error bars of one
    standard deviation.

    Args:
        dataset: Dataset name, shown in each panel title.
        metrics: Mapping with keys 'Accuracy', 'Precision', 'Recall', 'F1';
            each value is a sequence aligned with model_names whose items are
            the samples (e.g. per-fold scores) for that model.
        model_names: Row labels, in plotting order.

    Returns:
        The matplotlib Figure holding the four panels.
    """
    # Base models and their '+' variants share a colour; unknown names fall
    # back to black at lookup time.
    palette = {
        'L1-LR': '#1f77b4',
        'L1-LR+': '#1f77b4',
        'L2-LR': '#ff7f0e',
        'L2-LR+': '#ff7f0e',
        'DT': '#2ca02c',
        'DT+': '#2ca02c',
        'RF': '#8c564b',
        'RF+': '#8c564b',
        'LGBM': '#9467bd',
        'LGBM+': '#9467bd',
        'LGBM+C': '#9467bd',
        'Ours': '#d62728',
        'Old': '#7f7f7f'
    }
    half_bar = 0.02   # half-length of the horizontal bar drawn at the mean
    x_margin = 0.05   # padding added on both sides of the range of means

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    for panel, metric in enumerate(['Accuracy', 'Precision', 'Recall', 'F1']):
        ax = axes[panel]
        row_positions = np.arange(len(model_names))
        per_model = metrics[metric]

        for row, (name, samples) in enumerate(zip(model_names, per_model)):
            tint = palette.get(name, 'black')
            centre = np.mean(samples)
            spread = np.std(samples)

            # Horizontal bar centred on the mean.
            ax.plot([centre - half_bar, centre + half_bar], [row, row],
                    color=tint, linewidth=2, solid_capstyle='butt')

            # Marker with +/- one standard deviation; names ending in '+'
            # get a hollow marker, all others a filled one.
            ax.errorbar(centre, row, xerr=spread,
                        fmt='o', color=tint, capsize=3,
                        markersize=4,
                        fillstyle='none' if name.endswith('+') else 'full',
                        elinewidth=1, capthick=1)

        ax.set_yticks(row_positions)
        ax.set_yticklabels(model_names)
        # Panels are lettered (a), (b), (c), (d).
        ax.set_title(f'({chr(97+panel)}) {metric} ({dataset})')
        ax.grid(True, alpha=0.3, linestyle='--')

        # The x-range is derived from the means only (not the error bars).
        means = [np.mean(samples) for samples in per_model]
        lower = min(means) - x_margin
        upper = max(means) + x_margin
        ax.set_xlim(lower, upper)

        # Five evenly spaced ticks, labelled with two decimals.
        ax.set_xticks(np.linspace(lower, upper, 5))
        ax.set_xticklabels([f'{tick:.2f}' for tick in ax.get_xticks()])

    plt.tight_layout()
    return fig

In [9]:
def print_tree(lct, node_id=0, depth=0, feature_names=None):
    """
    Recursively print the subtree of a Local Correction Tree rooted at node_id.

    Args:
        lct: Tree object exposing `nodes[node_id]` as a
            (feature index, threshold, correction) triple and
            `children[node_id]` as a pair of child node ids.
        node_id: Id of the node to print first.
        depth: Nesting level of that node; each level indents by two spaces.
        feature_names: Optional sequence used to label split features; when
            omitted, features are shown as X[index].

    Returns:
        None. Output is written with print().
    """
    split_feature, split_value, correction = lct.nodes[node_id]
    first_child, second_child = lct.children[node_id]
    pad = "  " * depth

    # A feature index of -1 marks a leaf: show its correction and stop.
    if split_feature == -1:
        print(f"{pad}Leaf: correction = {correction.round(4)}")
        return

    if feature_names is None:
        label = f"X[{split_feature}]"
    else:
        label = feature_names[split_feature]
    print(f"{pad}Node: {label} <= {split_value:.4f}")

    # Descend into both children, one level deeper, in stored order.
    for child in (first_child, second_child):
        print_tree(lct, child, depth + 1, feature_names)

In [None]:
# Evaluate the old (base) models on stratified folds of the old data, save a
# comparison figure per dataset and print mean +/- std for every metric.
datasets = ['Adult']
for dataset in datasets:
    print(f"\nProcessing {dataset} dataset")
    
    # Define only base models
    base_models = ['L1-LR', 'L2-LR', 'DT', 'RF', 'LGBM']
    metric_names = ('Accuracy', 'Precision', 'Recall', 'F1')
    
    # 1. Load data
    # NOTE(review): load_data in this notebook raises ValueError for
    # '7-point', so the first branch is unreachable with the current dataset
    # list.
    if dataset == '7-point':
        X_old, y_old, X_new, y_new, X_clinical = load_data(dataset)
    else:
        X_old, y_old, X_new, y_new, feature_names, x = load_data(dataset)
    
    # Fancy indexing by fold indices below requires numpy arrays.
    X_new, X_old = np.array(X_new), np.array(X_old)
    y_new, y_old = np.array(y_new), np.array(y_old)
    
    # 2. Train base models on old data
    base_models_dict = train_old_model(dataset, X_new, X_old, y_new, y_old)
    
    # 3. Score the base models on each of the stratified folds.
    # NOTE(review): the models are obtained once above and are not refit per
    # fold, so only the held-out part of each split is used here; the spread
    # across folds reflects subset-to-subset variation of the scores.
    n_splits = 5
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    all_metrics = {
        metric: {model: [] for model in base_models}
        for metric in metric_names
    }
    
    for fold, (_, test_idx) in enumerate(skf.split(X_old, y_old)):
        print(f"Processing fold {fold + 1}/{n_splits}")
        
        X_test, y_test = X_old[test_idx], y_old[test_idx]
        
        # Evaluate each base model individually
        for model_name in base_models:
            fold_metrics = evaluate_old_model(dataset, X_test, y_test, base_models_dict[model_name])
            
            # evaluate_old_model returns one-element lists per metric.
            for metric in all_metrics:
                all_metrics[metric][model_name].extend(fold_metrics[metric])
    
    # 4. Prepare metrics for plotting: one list of per-fold scores per model,
    #    ordered like base_models.
    plot_metrics = {
        metric: [all_metrics[metric][model] for model in base_models]
        for metric in metric_names
    }
    
    # 5. Create and save plot (operate on the returned figure explicitly
    #    rather than on pyplot's implicit current figure).
    fig = plot_metrics_comparison(dataset, plot_metrics, base_models)
    fig.savefig(f'{dataset}_base_models_comparison.png', bbox_inches='tight', dpi=400)
    plt.close(fig)
    
    # 6. Print numerical results
    print(f"\nBase Model Results for {dataset}:")
    for metric in plot_metrics:
        print(f"\n{metric}:")
        for name, values in zip(base_models, plot_metrics[metric]):
            mean_val = np.mean(values)
            std_val = np.std(values)
            print(f"{name}: {mean_val:.4f} ± {std_val:.4f}")

print("\nDone evaluating base models.")

In [None]:
# Evaluate the base models, the score-augmented ('+') models and the Local
# Correction Tree ('Ours') on stratified folds of the new data, save a
# comparison figure per dataset, print mean +/- std for every metric and
# finally print the LCT structure.
datasets = ['Adult']
for dataset in datasets:
    print(f"\nProcessing {dataset} dataset")
    
    # Define base and enhanced models
    base_models = ['L1-LR', 'L2-LR', 'DT', 'RF', 'LGBM']
    enhanced_models = ['L1-LR+', 'L2-LR+', 'DT+', 'RF+', 'LGBM+', 'LGBM+C', 'Ours']
    all_models = base_models + enhanced_models
    metric_names = ('Accuracy', 'Precision', 'Recall', 'F1')
    
    # Old model whose scores the LCT ('Ours') corrects.
    # NOTE(review): assumed to be LGBM -- it is the last base model handed to
    # train_new_models below and the model this cell previously passed for
    # the enhanced models. Confirm against Models.models.
    lct_base_model = 'LGBM'
    
    # 1. Load data
    # NOTE(review): load_data in this notebook raises ValueError for
    # '7-point', so the first branch is unreachable with the current dataset
    # list.
    if dataset == '7-point':
        X_old, y_old, X_new, y_new, X_clinical = load_data(dataset)
    else:
        X_old, y_old, X_new, y_new, feature_names, x = load_data(dataset)
    
    # Fancy indexing by fold indices below requires numpy arrays.
    X_new, X_old = np.array(X_new), np.array(X_old)
    y_new, y_old = np.array(y_new), np.array(y_old)
    
    # 2. Train base models on old data
    base_models_dict = train_old_model(dataset, X_new, X_old, y_new, y_old)
    
    # 3. Train enhanced models on new data
    enhanced_models_dict = {}
    for base_model_name in base_models:
        enhanced_models_dict.update(train_new_models(dataset, X_new, y_new, base_models_dict[base_model_name], LocalCorrectionTree))
    
    # evaluate_new_models resolves the old model by name (the model name with
    # every '+' and 'C' removed), so 'Ours' needs an entry of its own. An
    # existing 'Ours' entry, if any, is left untouched.
    old_models_for_eval = dict(base_models_dict)
    old_models_for_eval.setdefault('Ours', base_models_dict[lct_base_model])
    
    # 4. Perform 5-fold cross validation
    n_splits = 5
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    all_metrics = {
        metric: {model: [] for model in all_models}
        for metric in metric_names
    }
    
    for fold, (train_idx, test_idx) in enumerate(skf.split(X_new, y_new)):
        print(f"Processing fold {fold + 1}/{n_splits}")
        
        X_train, X_test = X_new[train_idx], X_new[test_idx]
        y_train, y_test = y_new[train_idx], y_new[test_idx]
        
        # Evaluate each base model individually
        for model_name in base_models:
            fold_metrics = evaluate_old_model(dataset, X_test, y_test, base_models_dict[model_name])
            
            # evaluate_old_model returns one-element lists per metric.
            for metric in all_metrics:
                all_metrics[metric][model_name].extend(fold_metrics[metric])
        
        # Evaluate each enhanced model individually
        # NOTE(review): only the LCT is refit on the fold's training split;
        # the other enhanced models come from train_new_models, which was
        # given all of X_new above. Confirm this asymmetry is intended.
        for model_name in enhanced_models:
            model = enhanced_models_dict[model_name]
            
            if model_name == 'Ours':
                # Special handling for LocalCorrectionTree: refit on this
                # fold using the old model's scores on the training split.
                old_model_scores = old_models_for_eval[model_name].predict_proba(X_train)
                model.fit(X_train, y_train, old_model_scores)
            
            # evaluate_new_models expects a {name: model} mapping plus the
            # full old-model mapping, and returns {metric: {name: score}}.
            fold_metrics = evaluate_new_models(
                dataset, {model_name: model}, X_test, y_test, old_models_for_eval
            )
            
            for metric in all_metrics:
                all_metrics[metric][model_name].append(fold_metrics[metric][model_name])
    
    # 5. Prepare metrics for plotting: one list of per-fold scores per model,
    #    ordered like all_models.
    plot_metrics = {
        metric: [all_metrics[metric][model] for model in all_models]
        for metric in metric_names
    }
    
    # 6. Create and save plot (operate on the returned figure explicitly
    #    rather than on pyplot's implicit current figure).
    fig = plot_metrics_comparison(dataset, plot_metrics, all_models)
    fig.savefig(f'{dataset}_all_models_comparison.png', bbox_inches='tight', dpi=400)
    plt.close(fig)
    
    # 7. Print numerical results
    print(f"\nAll Model Results for {dataset}:")
    for metric in plot_metrics:
        print(f"\n{metric}:")
        for name, values in zip(all_models, plot_metrics[metric]):
            mean_val = np.mean(values)
            std_val = np.std(values)
            print(f"{name}: {mean_val:.4f} ± {std_val:.4f}")

    # 8. Print final Local Correction Tree structure (as left by the fit on
    #    the last fold).
    if 'Ours' in enhanced_models_dict:
        print("\nFinal Local Correction Tree Structure:")
        enhanced_models_dict['Ours'].simplify()
        print_tree(enhanced_models_dict['Ours'], feature_names=feature_names)

print("\nDone evaluating all models.")