In [1]:
import os
import re
import pandas as pd
from pathlib import Path

def parse_classification_report(file_path, encoding="utf-8"):
    """Parse a classification_report.txt file and extract key metrics.

    The file is expected to contain the text table produced by a
    classification report: a line ``accuracy <value> ...`` and a line
    ``macro avg <precision> <recall> <f1> ...``.

    Args:
        file_path: Path (str or ``pathlib.Path``) of the report file.
        encoding: Text encoding used to read the file. Undecodable bytes are
            replaced rather than raising, because only the ASCII numeric
            columns are extracted and a stray byte in a class label must not
            make the whole report unreadable.

    Returns:
        Tuple ``(accuracy, macro_precision, macro_recall, macro_f1)`` of
        floats. Any metric whose line is not found is ``None``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(file_path, "r", encoding=encoding, errors="replace") as f:
        content = f.read()

    # "accuracy" row: only one numeric column precedes the support count.
    accuracy_match = re.search(r"accuracy\s+(\d+\.\d+)", content)
    accuracy = float(accuracy_match.group(1)) if accuracy_match else None

    # "macro avg" row: precision, recall, f1-score (support is ignored).
    macro_avg_match = re.search(
        r"macro avg\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)", content
    )
    if macro_avg_match:
        macro_precision, macro_recall, macro_f1 = (
            float(group) for group in macro_avg_match.groups()
        )
    else:
        macro_precision = macro_recall = macro_f1 = None

    return accuracy, macro_precision, macro_recall, macro_f1

def compare_folds(base_path=".", model_name="YourModel", num_folds=10):
    """Compare results across all folds.

    Looks for ``<base_path>/fold_<k>/classification_report.txt`` for
    ``k = 1 .. num_folds`` and parses each one that exists.

    Args:
        base_path: Directory containing the ``fold_<k>`` subdirectories.
        model_name: Label stored in the ``model_name`` field of every row.
        num_folds: Number of folds to look for (numbered from 1). Defaults
            to 10, matching the original fixed ``fold_1`` .. ``fold_10``.

    Returns:
        A list of dicts, one per successfully read report, with keys
        ``model_name``, ``fold_number``, ``accuracy``, ``macro_precision``,
        ``macro_recall`` and ``f1_score_macro_avg``. Metrics missing from a
        report are stored as ``None``.
    """
    def _fmt(value):
        # Missing metrics come back as None; formatting None with ":.4f" would
        # raise and make a fold that *was* recorded look like a failure.
        return "N/A" if value is None else f"{value:.4f}"

    base = Path(base_path)
    results = []

    for fold_num in range(1, num_folds + 1):
        report_file = base / f"fold_{fold_num}" / "classification_report.txt"

        if not report_file.exists():
            print(f"Warning: {report_file} not found")
            continue

        try:
            accuracy, macro_precision, macro_recall, macro_f1 = parse_classification_report(report_file)
        except (OSError, ValueError) as e:
            # Unreadable / undecodable report: skip this fold, keep the rest.
            print(f"Error processing fold_{fold_num}: {e}")
            continue

        results.append({
            'model_name': model_name,
            'fold_number': fold_num,
            'accuracy': accuracy,
            'macro_precision': macro_precision,
            'macro_recall': macro_recall,
            'f1_score_macro_avg': macro_f1,
        })

        print(
            f"Processed fold_{fold_num}: Accuracy={_fmt(accuracy)}, "
            f"Macro Precision={_fmt(macro_precision)}, "
            f"Macro Recall={_fmt(macro_recall)}, "
            f"Macro F1={_fmt(macro_f1)}"
        )

    return results

def generate_comparison_report(results, output_file="fold_comparison_results.csv"):
    """Generate comparison report and statistics.

    Builds a DataFrame from the per-fold rows, appends an ``Average`` row
    holding the mean of each metric column, prints the averages, and writes
    the table to CSV.

    Args:
        results: List of per-fold dicts with keys ``model_name``,
            ``fold_number``, ``accuracy``, ``macro_precision``,
            ``macro_recall`` and ``f1_score_macro_avg``.
        output_file: Destination CSV path. Defaults to
            ``fold_comparison_results.csv`` in the current working directory.

    Returns:
        The DataFrame including the appended ``Average`` row, or ``None``
        if ``results`` is empty (nothing is written in that case).
    """
    if not results:
        print("No results to analyze")
        return None

    df = pd.DataFrame(results)

    metric_columns = ['accuracy', 'macro_precision', 'macro_recall', 'f1_score_macro_avg']
    # Series.mean skips missing (None/NaN) metrics by default.
    averages = {column: df[column].mean() for column in metric_columns}

    # 'fold_number' becomes mixed int/str here; that is intended for the CSV.
    avg_row = pd.DataFrame([{'model_name': 'Average', 'fold_number': 'Average', **averages}])
    df = pd.concat([df, avg_row], ignore_index=True)

    print(f"Average Macro Precision: {averages['macro_precision']:.4f}")
    print(f"Average Macro Recall: {averages['macro_recall']:.4f}")
    print(f"Average Macro F1-Score: {averages['f1_score_macro_avg']:.4f}")
    print(f"\nAverage Accuracy: {averages['accuracy']:.4f}")

    df.to_csv(output_file, index=False, float_format='%.4f')
    print(f"\nResults saved to: {output_file}")

    return df



def main():
    """Parse every fold's classification report and write the comparison CSV."""
    # Configuration
    BASE_PATH = "results"  # Directory containing the fold_<k> subdirectories (relative to CWD); change if needed
    MODEL_NAME = "LEGALBERT_GNN_PROCESSED_UCREAT"  # Label written to the model_name column; change to your model

    print("Starting fold comparison analysis...")

    # Parse all fold results
    results = compare_folds(BASE_PATH, MODEL_NAME)

    if not results:
        print("No valid results found. Please check your file paths and formats.")
        return

    # Generate comparison report (prints averages and writes the CSV)
    generate_comparison_report(results)


if __name__ == "__main__":
    main()


Starting fold comparison analysis...
Processed fold_1: Accuracy=0.9119, Macro Precision=0.8743, Macro Recall=0.9410, Macro F1=0.9038
Processed fold_2: Accuracy=0.9021, Macro Precision=0.8381, Macro Recall=0.9102, Macro F1=0.8681
Processed fold_3: Accuracy=0.8160, Macro Precision=0.7367, Macro Recall=0.7802, Macro F1=0.7357
Processed fold_4: Accuracy=0.8398, Macro Precision=0.7870, Macro Recall=0.8960, Macro F1=0.8183
Processed fold_5: Accuracy=0.8494, Macro Precision=0.7894, Macro Recall=0.9235, Macro F1=0.8331
Processed fold_6: Accuracy=0.8170, Macro Precision=0.8825, Macro Recall=0.7006, Macro F1=0.7362
Processed fold_7: Accuracy=0.8812, Macro Precision=0.8270, Macro Recall=0.9295, Macro F1=0.8663
Processed fold_8: Accuracy=0.9313, Macro Precision=0.6140, Macro Recall=0.6305, Macro F1=0.6190
Processed fold_9: Accuracy=0.8580, Macro Precision=0.8027, Macro Recall=0.9205, Macro F1=0.8470
Processed fold_10: Accuracy=0.8529, Macro Precision=0.8209, Macro Recall=0.9034, Macro F1=0.8440
Av