In [5]:
# --- 🧠 Visualization Agent Testing Notebook ---

# ✅ Step 0: Setup Environment (Run this first)
import os
import sys
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Ensure the project root is in the path
if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())

# ✅ Step 1: Imports
from agent.nodes import (
    visualization_specification_agent,
    visualization_rendering_agent,
    _render_chart  # Import the render function for direct testing
)
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

# ✅ Step 2: Mock Analysis Results (Simulating output from analysis_computation_agent)
# This simulates what the analysis_computation_agent would return

# Create a more complete mock that matches what the agent actually produces
mock_computed_metrics = {
    "rmse": {
        "Random Forest": 42.5,
        "XGBoost": 38.2,
        "LightGBM": 40.1
    },
    "r2_score": {
        "Random Forest": 0.78,
        "XGBoost": 0.82,
        "LightGBM": 0.80
    },
    "mae": {
        "Random Forest": 35.2,
        "XGBoost": 31.8,
        "LightGBM": 33.5
    }
}

# Create raw data format (simulating SQL query results)
mock_raw_data = [
    {
        "success": True,
        "data": [
            {"model_name": "Random Forest", "metric_name": "rmse", "metric_value": 42.5},
            {"model_name": "Random Forest", "metric_name": "r2_score", "metric_value": 0.78},
            {"model_name": "Random Forest", "metric_name": "mae", "metric_value": 35.2},
            {"model_name": "XGBoost", "metric_name": "rmse", "metric_value": 38.2},
            {"model_name": "XGBoost", "metric_name": "r2_score", "metric_value": 0.82},
            {"model_name": "XGBoost", "metric_name": "mae", "metric_value": 31.8},
            {"model_name": "LightGBM", "metric_name": "rmse", "metric_value": 40.1},
            {"model_name": "LightGBM", "metric_name": "r2_score", "metric_value": 0.80},
            {"model_name": "LightGBM", "metric_name": "mae", "metric_value": 33.5},
        ],
        "row_count": 9,
        "columns": ["model_name", "metric_name", "metric_value"]
    }
]

mock_analysis_results = {
    "computed_metrics": mock_computed_metrics,
    "raw_data": mock_raw_data,
    "patterns": [
        "XGBoost shows best performance across all metrics",
        "Random Forest has highest RMSE indicating larger prediction errors",
        "All models show good R² scores above 0.75"
    ],
    "anomalies": [],
    "statistical_summary": {
        "mean_rmse": 40.3,
        "mean_r2": 0.80,
        "mean_mae": 33.5
    },
    "data_row_count": 9
}

# ✅ Step 3: Initialize State
state = {
    "user_query": "Compare Random Forest, XGBoost, and LightGBM for NRx forecasting performance",
    "comparison_type": "performance",
    "requires_visualization": True,
    "analysis_results": mock_analysis_results,
    "execution_path": [],
    "models_requested": ["Random Forest", "XGBoost", "LightGBM"],
    "use_case": "NRx_forecasting"
}

print("="*80)
print("🧪 VISUALIZATION AGENT TESTING")
print("="*80)
print(f"\n📋 Initial State:")
print(f"   - User Query: {state['user_query']}")
print(f"   - Comparison Type: {state['comparison_type']}")
print(f"   - Requires Visualization: {state['requires_visualization']}")
print(f"   - Analysis Results Keys: {list(state['analysis_results'].keys())}")
print(f"   - Data Row Count: {state['analysis_results']['data_row_count']}")

# ✅ Step 4: Run Visualization Specification Agent
print("\n" + "="*80)
print("🎨 STEP 1: GENERATING VISUALIZATION SPECIFICATIONS")
print("="*80)

try:
    state_after_spec = visualization_specification_agent(state)
    state.update(state_after_spec)
    
    viz_specs = state.get("visualization_specs", [])
    
    if viz_specs:
        print(f"\n✅ Generated {len(viz_specs)} Visualization Specification(s):\n")
        for i, spec in enumerate(viz_specs, 1):
            print(f"📊 Spec {i}:")
            print(f"   - Type: {spec.get('type')}")
            print(f"   - Title: {spec.get('title')}")
            print(f"   - Data Key: {spec.get('data_key')}")
            print(f"   - X-axis: {spec.get('x')}")
            print(f"   - Y-axis: {spec.get('y')}")
            print(f"   - Additional Params: {spec.get('additional_params')}")
            print()
    else:
        print("⚠️ No visualization specs generated!")
        print("This might mean:")
        print("  - requires_visualization was False")
        print("  - LLM failed to generate specs")
        print("  - Error in visualization_specification_agent")
        
except Exception as e:
    print(f"\n❌ Error in visualization_specification_agent:")
    print(f"   {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()

# ✅ Step 5: Run Visualization Rendering Agent
print("\n" + "="*80)
print("📊 STEP 2: RENDERING CHARTS")
print("="*80)

try:
    state_after_render = visualization_rendering_agent(state)
    state.update(state_after_render)
    
    rendered_charts = state.get("rendered_charts", [])
    
    if rendered_charts:
        print(f"\n✅ {len(rendered_charts)} Chart(s) Rendered Successfully!\n")
    else:
        print("\n⚠️ No charts were rendered.")
        print("\nDebugging Info:")
        print(f"  - visualization_specs exists: {bool(state.get('visualization_specs'))}")
        print(f"  - Number of specs: {len(state.get('visualization_specs', []))}")
        print(f"  - analysis_results exists: {bool(state.get('analysis_results'))}")
        
except Exception as e:
    print(f"\n❌ Error in visualization_rendering_agent:")
    print(f"   {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()

# ✅ Step 6: Display the Charts
print("\n" + "="*80)
print("🖼️ STEP 3: DISPLAYING CHARTS")
print("="*80 + "\n")

rendered_charts = state.get("rendered_charts", [])

if rendered_charts:
    for i, chart_data in enumerate(rendered_charts, 1):
        title = chart_data.get("title", f"Chart {i}")
        chart_type = chart_data.get("type", "unknown")
        figure = chart_data.get("figure")
        
        print(f"{'='*80}")
        print(f"📈 Chart {i}: {title} ({chart_type})")
        print(f"{'='*80}\n")
        
        if figure:
            try:
                # In Jupyter, this will display inline
                figure.show()
                print(f"\n✅ Chart displayed successfully\n")
            except Exception as e:
                print(f"❌ Error displaying chart: {e}\n")
        else:
            print("❌ No figure object found\n")
else:
    print("⚠️ No charts to display.")
    print("\nTroubleshooting steps:")
    print("1. Check if visualization_specs were generated")
    print("2. Check if analysis_results has the correct data structure")
    print("3. Look for errors in the rendering step above")

# ✅ Step 7: Manual Test - Create Chart Directly
print("\n" + "="*80)
print("🔧 STEP 4: MANUAL CHART TEST (Bypass Agents)")
print("="*80 + "\n")

print("Creating a simple bar chart directly from mock data...\n")

try:
    # Extract data from mock_raw_data
    data_rows = mock_raw_data[0]['data']
    df = pd.DataFrame(data_rows)
    
    print(f"DataFrame shape: {df.shape}")
    print(f"DataFrame columns: {df.columns.tolist()}")
    print(f"\nFirst few rows:")
    print(df.head())
    
    # Create a simple bar chart
    fig = px.bar(
        df[df['metric_name'] == 'rmse'],
        x='model_name',
        y='metric_value',
        title='Model Performance Comparison (RMSE)',
        labels={'metric_value': 'RMSE', 'model_name': 'Model'},
        color='model_name'
    )
    
    fig.update_layout(
        template="plotly_white",
        height=400,
        showlegend=True
    )
    
    print("\n✅ Manual chart created successfully!")
    fig.show()
    
except Exception as e:
    print(f"❌ Error creating manual chart: {e}")
    import traceback
    traceback.print_exc()

# ✅ Step 8: Summary
print("\n" + "="*80)
print("📊 TEST SUMMARY")
print("="*80 + "\n")

print(f"✅ Visualization Specs Generated: {len(state.get('visualization_specs', []))}")
print(f"✅ Charts Rendered: {len(state.get('rendered_charts', []))}")
print(f"✅ Execution Path: {' → '.join(state.get('execution_path', []))}")

if not state.get('rendered_charts'):
    print("\n⚠️ TROUBLESHOOTING NEEDED")
    print("\nPossible issues:")
    print("1. LLM (Gemini) not generating proper visualization specs")
    print("2. Data format mismatch in _render_chart function")
    print("3. Missing GEMINI_API_KEY environment variable")
    print("4. Analysis results structure doesn't match expected format")
    print("\nCheck the debug output above for specific error messages.")
else:
    print("\n✅ ALL TESTS PASSED!")
    print("Your visualization pipeline is working correctly.")

print("\n" + "="*80)



🧪 VISUALIZATION AGENT TESTING

📋 Initial State:
   - User Query: Compare Random Forest, XGBoost, and LightGBM for NRx forecasting performance
   - Comparison Type: performance
   - Requires Visualization: True
   - Analysis Results Keys: ['computed_metrics', 'raw_data', 'patterns', 'anomalies', 'statistical_summary', 'data_row_count']
   - Data Row Count: 9

🎨 STEP 1: GENERATING VISUALIZATION SPECIFICATIONS

✅ Generated 3 Visualization Specification(s):

📊 Spec 1:
   - Type: None
   - Title: None
   - Data Key: None
   - X-axis: None
   - Y-axis: None
   - Additional Params: None

📊 Spec 2:
   - Type: None
   - Title: None
   - Data Key: None
   - X-axis: None
   - Y-axis: None
   - Additional Params: None

📊 Spec 3:
   - Type: None
   - Title: None
   - Data Key: None
   - X-axis: None
   - Y-axis: None
   - Additional Params: None


📊 STEP 2: RENDERING CHARTS
Rendering chart: None
Data key: raw_data
Data type: <class 'list'>
DataFrame shape: (9, 3)
DataFrame columns: ['model_name', '

Traceback (most recent call last):
  File "C:\Users\Admin\AppData\Local\Temp\ipykernel_15176\447822257.py", line 233, in <module>
    fig.show()
  File "c:\Users\Admin\Documents\GitHub\analytics_chatbot\.venv\Lib\site-packages\plotly\basedatatypes.py", line 3420, in show
    return pio.show(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Admin\Documents\GitHub\analytics_chatbot\.venv\Lib\site-packages\plotly\io\_renderers.py", line 415, in show
    raise ValueError(
ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed
