In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score
from scipy.stats import sem # Standard error of the mean

def evaluate_models_label_shift_with_uncertainty(model, og, test_datasets, n_experiments=10, n_classes=10):
    """Compare a label-shift-aware model against a baseline across target priors.

    For every ``(ratio, test_dataset)`` pair in ``test_datasets`` the test set
    is materialised into NumPy arrays and both models are evaluated
    ``n_experiments`` times. ``model`` is queried through
    ``predict_biased_label_shift`` with a uniform source prior and a target
    prior that puts ``ratio`` on class 0 and spreads the remaining mass evenly
    over the other ``n_classes - 1`` classes; ``og`` is queried through
    ``predict``. Accuracy, macro precision and macro recall are averaged over
    the repetitions, and their standard error of the mean is recorded. The
    collected metrics are printed, plotted (one figure per metric, with a
    +/- 1 SEM band), and returned.

    Args:
        model: Object exposing
            ``predict_biased_label_shift(X, source_prior, target_prior)``.
        og: Baseline classifier exposing ``predict(X)`` (labelled "RF" in the
            plots).
        test_datasets: Mapping from the class-0 target prior (``ratio``) to an
            iterable of ``(x, y)`` pairs. ``x`` must support
            ``.detach().cpu().numpy()`` (presumably a torch tensor; confirm
            against callers).
        n_experiments: Number of repeated evaluations per ratio; must be >= 1.
            With a single run the standard error is undefined (NaN).
        n_classes: Number of classes assumed by the priors; must be >= 2.
            Defaults to 10, matching the previously hard-coded priors.

    Returns:
        dict: For each key in ``'accuracies'``, ``'precisions'``,
        ``'recalls'`` and their ``'og_'`` counterparts, a list of per-ratio
        means plus a ``'<key>_err'`` list of per-ratio SEMs; ``'ratios'``
        holds the ratios. All lists follow the iteration order of
        ``test_datasets`` (only the plots are sorted by ratio).

    Raises:
        ValueError: If ``n_experiments < 1``, ``n_classes < 2``, or a test
            dataset yields no samples.
    """
    if n_experiments < 1:
        raise ValueError(f"n_experiments must be >= 1, got {n_experiments}")
    if n_classes < 2:
        raise ValueError(f"n_classes must be >= 2, got {n_classes}")

    metric_keys = ('accuracies', 'precisions', 'recalls',
                   'og_accuracies', 'og_precisions', 'og_recalls')

    # Pre-create every list (means, 'ratios', then the parallel '<key>_err'
    # lists) so plotting never hits a missing key, e.g. for an empty input.
    metrics = {key: [] for key in metric_keys}
    metrics['ratios'] = []
    metrics.update({key + '_err': [] for key in metric_keys})

    # Training prior is assumed uniform over all classes.
    source_prior = [1 / n_classes] * n_classes

    # Iterate over each test dataset
    for ratio, test_dataset in test_datasets.items():
        print('Test Bias Ratio: ', ratio)

        # .cpu() makes this work for tensors that live on an accelerator too.
        pairs = [(x.detach().cpu().numpy().flatten(), y) for x, y in test_dataset]
        if not pairs:
            raise ValueError(f"Test dataset for ratio {ratio} yields no samples")
        X_test = np.array([x for x, _ in pairs])
        y_test = np.array([y for _, y in pairs])

        # Target prior: `ratio` on class 0, the rest split evenly.
        target_prior = [ratio] + [(1 - ratio) / (n_classes - 1)] * (n_classes - 1)

        runs = {key: [] for key in metric_keys}
        for _ in range(n_experiments):
            # Both predictors are re-queried each run in case either is
            # stochastic; for deterministic predictors the SEM is simply 0.
            predictions = model.predict_biased_label_shift(X_test, source_prior, target_prior)
            og_predictions = og.predict(X_test)
            for prefix, preds in (('', predictions), ('og_', og_predictions)):
                runs[prefix + 'accuracies'].append(accuracy_score(y_test, preds))
                # zero_division=0 yields the same score as the default but
                # avoids a warning per run when a class is never predicted.
                runs[prefix + 'precisions'].append(
                    precision_score(y_test, preds, average='macro', zero_division=0))
                runs[prefix + 'recalls'].append(
                    recall_score(y_test, preds, average='macro', zero_division=0))

        # Calculate means and standard errors
        for key, values in runs.items():
            metrics[key].append(np.mean(values))
            metrics[key + '_err'].append(sem(values))

        metrics['ratios'].append(ratio)

    print(metrics)

    # Plot against ascending ratios; otherwise lines and uncertainty bands
    # zigzag whenever test_datasets is not keyed in sorted order.
    order = np.argsort(metrics['ratios'])
    ratios = np.asarray(metrics['ratios'])[order]

    def plot_with_uncertainty(key, label, color):
        """Plot the mean curve for `key` with a +/- 1 SEM shaded band."""
        means = np.asarray(metrics[key], dtype=float)[order]
        errs = np.asarray(metrics[key + '_err'], dtype=float)[order]
        plt.plot(ratios, means, label=label, color=color)
        plt.fill_between(ratios, means - errs, means + errs, color=color, alpha=0.2)

    for metric in ('accuracies', 'precisions', 'recalls'):
        plt.figure(figsize=(10, 6))
        plot_with_uncertainty(metric, 'Label Shift ' + metric.capitalize(), 'red')
        plot_with_uncertainty('og_' + metric, 'RF ' + metric.capitalize(), 'blue')
        plt.xlabel('Fraction of 0')
        plt.ylabel('Performance Metrics')
        plt.title('Performance Metrics for Different Sample Ratios with Uncertainty')
        plt.legend()
        plt.grid(True)
        plt.show()

    return metrics