# In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import json
import pandas as pd
import numpy as np
from pathlib import Path

def load_results(baseline_path: str, method_path: str):
    """Load the baseline and method result files.

    Args:
        baseline_path: Path of the JSON file holding the baseline results.
        method_path: Path of the JSON file holding the method results.

    Returns:
        A ``(baseline_results, method_results)`` tuple with the decoded
        content of the two files, in that order.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    def _read_json(path):
        # Decode explicitly as UTF-8 so the result does not depend on the
        # locale's default text encoding.
        with open(path, encoding='utf-8') as f:
            return json.load(f)

    # The baseline file is read first, then the method file.
    baseline_results = _read_json(baseline_path)
    method_results = _read_json(method_path)
    return baseline_results, method_results

def set_style():
    """Apply the seaborn matplotlib style and install a 4-colour husl palette.

    Returns:
        The four colours returned by ``sns.color_palette("husl", 4)``; the
        same palette is also set as seaborn's default palette.
    """
    # Matplotlib 3.6 renamed its bundled seaborn style sheets to
    # 'seaborn-v0_8*' and newer releases no longer accept the plain 'seaborn'
    # name, so use whichever of the two this installation provides. If
    # neither is available the current matplotlib style is left untouched.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    colors = sns.color_palette("husl", 4)
    sns.set_palette(colors)
    return colors

def plot_performance_comparison(baseline_results, method_results, colors):
    """Save a grouped bar chart of precision, recall and F1 for four methods.

    Two baselines ('baseline_gnn' and 'random_forest' from
    ``baseline_results``) are compared with the 'family' and 'group' entries
    of ``method_results['test_results']``. The figure is written to
    'performance_comparison.png' and then closed.

    Args:
        baseline_results: Mapping with 'baseline_gnn' and 'random_forest'
            entries, each holding an 'overall' dict of scores.
        method_results: Mapping whose 'test_results' entry holds 'family' and
            'group' dicts with scores under ['metrics']['overall'].
        colors: Indexable colour collection; its first three entries colour
            the Precision, Recall and F1 bars respectively.
    """
    score_columns = (('Precision', 'precision'), ('Recall', 'recall'), ('F1', 'f1'))

    def make_row(method_name, method_type, overall):
        # One table row: identifying labels followed by the three scores.
        row = {'Method': method_name, 'Type': method_type}
        for column, key in score_columns:
            row[column] = overall[key]
        return row

    # Baselines first, then the two levels of the temporal-symbolic model.
    rows = [
        make_row('Traditional GNN', 'Baseline',
                 baseline_results['baseline_gnn']['overall']),
        make_row('Random Forest', 'Baseline',
                 baseline_results['random_forest']['overall']),
    ]
    test_results = method_results['test_results']
    for method_name, level in (('Family-Level', 'family'), ('Group-Level', 'group')):
        rows.append(make_row(method_name, 'Temporal-Symbolic Model',
                             test_results[level]['metrics']['overall']))

    df = pd.DataFrame(rows)

    fig, ax = plt.subplots(figsize=(12, 6))
    width = 0.25
    positions = np.arange(len(df['Method'].unique()))

    # One bar series per metric, shifted right by one bar width each time.
    for offset, metric in enumerate(['Precision', 'Recall', 'F1']):
        ax.bar(positions + offset * width,
               df[metric],
               width,
               alpha=0.8,
               color=colors[offset],
               label=metric)

    ax.set_xlabel('Methods')
    ax.set_ylabel('Score')
    ax.set_title('Performance Comparison: Baselines vs Temporal-Symbolic Model')
    # Ticks sit under the middle (Recall) bar of every group.
    plt.xticks(positions + width, df['Method'], rotation=45)

    # Annotate each method with its type just below the x axis; x is in data
    # coordinates and y in axes coordinates because of the x-axis transform.
    label_transform = ax.get_xaxis_transform()
    for position, type_label in enumerate(df['Type']):
        ax.text(position, -0.05, type_label,
                rotation=45, ha='right', va='top',
                transform=label_transform)

    ax.legend()
    plt.tight_layout()
    plt.savefig('performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()

def plot_novel_detection(baseline_results, method_results, colors):
    """Save a bar chart comparing novel-detection precision, recall and F1.

    The 'isolation_forest' and 'one_class_svm' baselines are plotted next to
    the 'family' and 'group' novel-detection scores found in
    ``method_results['test_results']``. The figure is written to
    'novel_detection_comparison.png' and then closed.

    Args:
        baseline_results: Mapping with 'isolation_forest' and 'one_class_svm'
            entries, each a dict with 'precision', 'recall' and 'f1'.
        method_results: Mapping whose 'test_results' entry holds 'family' and
            'group' dicts with scores under ['novel_detection']['overall'].
        colors: Sliceable colour collection; its first three entries are used
            as the palette for the three metrics.
    """
    rows = []

    # Baseline detectors: display name is the key with underscores replaced
    # by spaces and title-cased.
    for baseline_key in ('isolation_forest', 'one_class_svm'):
        scores = baseline_results[baseline_key]
        rows.append({
            'Method': baseline_key.replace('_', ' ').title(),
            'Type': 'Baseline',
            'Precision': scores['precision'],
            'Recall': scores['recall'],
            'F1': scores['f1'],
        })

    # Temporal-symbolic model at family and group level.
    test_results = method_results['test_results']
    for level in ('family', 'group'):
        scores = test_results[level]['novel_detection']['overall']
        rows.append({
            'Method': f'{level.title()}-Level',
            'Type': 'Temporal-Symbolic Model',
            'Precision': scores['precision'],
            'Recall': scores['recall'],
            'F1': scores['f1'],
        })

    df = pd.DataFrame(rows)

    # Long format: one row per (method, metric) pair, as seaborn's hue expects.
    long_df = df.melt(id_vars=['Method', 'Type'],
                      var_name='Metric',
                      value_name='Score')

    plt.figure(figsize=(10, 6))
    sns.barplot(data=long_df, x='Method', y='Score', hue='Metric',
                palette=colors[:3])

    plt.xticks(rotation=45)
    plt.title('Novel Detection Performance')

    # Annotate each method with its type just below the x axis; x is in data
    # coordinates and y in axes coordinates because of the x-axis transform.
    ax = plt.gca()
    label_transform = ax.get_xaxis_transform()
    for position, type_label in enumerate(df['Type']):
        ax.text(position, -0.05, type_label,
                rotation=45, ha='right', va='top',
                transform=label_transform)

    plt.tight_layout()
    plt.savefig('novel_detection_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()

def _history_frame(epochs):
    """Flatten one model's per-epoch history records into a DataFrame.

    Each record contributes its 'epoch', the training 'accuracy' and 'loss',
    and a validation score computed as the mean 'recall' over the
    ``data['val']['per_class']`` entries that are dicts with 'support' > 0.
    """
    rows = []
    for data in epochs:
        row = {
            'epoch': data['epoch'],
            'train_acc': data['train']['accuracy'],
            'train_loss': data['train']['loss'],
        }
        recalls = [
            m['recall'] for m in data['val']['per_class'].values()
            if isinstance(m, dict) and m['support'] > 0
        ]
        # NaN (rather than a numpy warning) when no class has support.
        row['val_acc'] = np.mean(recalls) if recalls else np.nan
        rows.append(row)
    # Naming the columns keeps them present even when there are no epochs.
    return pd.DataFrame(rows, columns=['epoch', 'train_acc', 'train_loss', 'val_acc'])


def plot_training_history(method_results):
    """Save the family- and group-level training curves as two stacked plots.

    For each of the 'family' and 'group' entries of
    ``method_results['training_history']`` one subplot shows training
    accuracy and the validation score on the left y axis and training loss on
    a twin right y axis. Note that the curve labelled 'Validation Accuracy'
    is the mean per-class recall over classes with support > 0. The figure is
    written to 'training_history.png' and then closed.

    Args:
        method_results: Mapping whose 'training_history' entry holds, under
            'family' and 'group', an iterable of per-epoch records.
    """
    history = method_results['training_history']
    fig, (family_ax, group_ax) = plt.subplots(2, 1, figsize=(12, 10))

    for model_type, ax in (('family', family_ax), ('group', group_ax)):
        metrics = _history_frame(history[model_type])

        ax.plot(metrics['epoch'], metrics['train_acc'], label='Train Accuracy')
        ax.plot(metrics['epoch'], metrics['val_acc'], label='Validation Accuracy')

        # The twin axis gets its own name: rebinding the second subplot's
        # variable here would make the 'group' curves draw on top of the
        # 'family' subplot and leave the second subplot empty.
        loss_ax = ax.twinx()
        loss_ax.plot(metrics['epoch'], metrics['train_loss'], 'r--', label='Train Loss')

        ax.set_title(f"{model_type.title()}-Level Training")
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Accuracy')
        loss_ax.set_ylabel('Loss')

        # Single legend combining the lines of both y axes.
        lines1, labels1 = ax.get_legend_handles_labels()
        lines2, labels2 = loss_ax.get_legend_handles_labels()
        loss_ax.legend(lines1 + lines2, labels1 + labels2)

    fig.tight_layout()
    fig.savefig('training_history.png', dpi=300)
    plt.close(fig)

def plot_group_distribution(method_results):
    """Save a bar chart of the sample count of every known test group.

    Counts are read from
    ``method_results['data_statistics']['test']['group']['known_groups']``
    and the figure is written to 'group_distribution.png' and then closed.

    Args:
        method_results: Mapping containing the nested 'data_statistics'
            entry described above; 'known_groups' maps group -> count.
    """
    group_stats = method_results['data_statistics']['test']['group']

    # One record per group so the table has 'group' and 'count' columns.
    records = []
    for group_id, sample_count in group_stats['known_groups'].items():
        records.append({'group': group_id, 'count': sample_count})
    df = pd.DataFrame(records)

    plt.figure(figsize=(12, 6))
    sns.barplot(x='group', y='count', data=df)
    plt.title('Behavioral Group Distribution')
    plt.xlabel('Group ID')
    plt.ylabel('Sample Count')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig('group_distribution.png', dpi=300, bbox_inches='tight')
    plt.close()

if __name__ == "__main__":
    # Hard-coded locations of the two result JSON files.
    baseline_json = '/data/saranyav/gcn_new/baseline_results.json'
    report_json = '/data/saranyav/gcn_new/final_report.json'

    # Style is configured before anything is loaded or drawn.
    colors = set_style()
    baseline_results, method_results = load_results(baseline_json, report_json)

    # Comparison charts need both result sets and the shared palette.
    plot_performance_comparison(baseline_results, method_results, colors)
    plot_novel_detection(baseline_results, method_results, colors)

    # The remaining charts only use the method's own report.
    plot_training_history(method_results)
    plot_group_distribution(method_results)