In [9]:
import json
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import math

def plot_sembas_results(file, num, image_path=None):
    """
    Show one SEMBAS result entry from a JSON file as an interactive 3D scatter plot.

    Failed points are drawn red and successful points blue. Hovering over a point
    shows its 1-based index, its outcome and its coordinates.

    Args:
        file: Path to the JSON file to load. Its top-level value is indexed with
            ``num``. The selected entry must have these keys:
            "requests" — points whose first three values are plotted as x, y, z.
            "results" — one truthy/falsy outcome per point.
            "description" — used as the plot title.
        num: Index (or key) of the entry to plot.
        image_path: Optional path. When given, the figure is also exported as a
            static image with ``fig.write_image``. Plotly needs an image-export
            engine (e.g. kaleido) installed for this to work.
    """
    # Load only the entry we need from the JSON file.
    with open(file, 'r') as f:
        entry = json.load(f)[num]

    coords = entry["requests"]
    results = entry["results"]

    # Split the points into per-axis coordinate lists.
    x = [pt[0] for pt in coords]
    y = [pt[1] for pt in coords]
    z = [pt[2] for pt in coords]

    # 0 -> red (failed), 1 -> blue (successful) via the two-stop colorscale below.
    colors = [1 if res else 0 for res in results]

    # zip keeps the hover labels aligned with the plotted points and avoids an
    # IndexError when "results" is longer than "requests".
    hover = [
        f'Point {i+1}<br>{"Successful" if res else "Failed"}<br>({px:.2f}, {py:.2f}, {pz:.2f})'
        for i, (res, px, py, pz) in enumerate(zip(results, x, y, z))
    ]

    fig = go.Figure(data=[go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        marker=dict(
            size=10,
            color=colors,
            colorscale=[[0, 'red'], [1, 'blue']],
            opacity=1.0,
            line=dict(width=2, color='black')
        ),
        text=hover,
        hoverinfo='text'
    )])

    fig.update_layout(
        title=entry["description"],
        scene=dict(
            xaxis_title='Progress Weight',
            yaxis_title='Bounds Weight',
            zaxis_title='Proximity Weight'
        ),
        width=1000,
        height=800
    )

    fig.show()
    if image_path is not None:
        fig.write_image(image_path)


def plot_sembas_grid_html(files_and_params, output_filename="combined_sembas_results.html"):
    """
    Combine the grid pages of several SEMBAS result files into one HTML report.

    Args:
        files_and_params: Iterable of ``(file, start_index, num_plots)`` tuples.
            Each tuple is forwarded to
            ``plot_sembas_grid(file, start_index, num_plots, save_html=False)``,
            which must return that file's page figures as a list.
        output_filename: Path of the HTML file to write (overwritten if it exists).

    Returns:
        output_filename, for convenience.
    """
    import html
    import os

    sections = []
    # The plotly.js loader is emitted once, with the first figure. 'cdn' pins
    # the script URL to the installed plotly version. The old
    # "plotly-latest.min.js" URL is frozen at plotly.js 1.x and can mis-render
    # figures produced by a newer plotly.
    include_js = 'cdn'

    for file, start_index, num_plots in files_and_params:
        print(f"Processing {file}...")
        figures = plot_sembas_grid(file, start_index, num_plots, save_html=False)
        base_name = os.path.basename(file)

        for page_no, fig in enumerate(figures, start=1):
            # to_html produces the plot div plus its bootstrap script, with the
            # figure JSON escaped so it is safe inside a <script> element.
            plot_div = fig.to_html(full_html=False, include_plotlyjs=include_js)
            include_js = False
            section_title = html.escape(f"{base_name} - Page {page_no}")
            sections.append(f"""
    <div class="dataset-section">
        <h2>{section_title}</h2>
        <div class="plot-container">
{plot_div}
        </div>
    </div>
""")

    header = """<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Combined Sembas Results Visualization</title>
    <style>
        body { font-family: Arial, sans-serif; margin: 20px; }
        .page { margin-bottom: 50px; page-break-after: always; }
        h1 { text-align: center; color: #333; }
        h2 { color: #666; border-bottom: 2px solid #ccc; padding-bottom: 10px; }
        .plot-container { width: 100%; height: 1700px; }
        .dataset-section { margin-bottom: 60px; }
    </style>
</head>
<body>
    <h1>Combined Sembas Results Visualization</h1>
"""
    footer = """
</body>
</html>
"""
    html_content = header + "".join(sections) + footer

    # Write explicitly as UTF-8 so non-ASCII descriptions survive on any platform.
    with open(output_filename, 'w', encoding='utf-8') as f:
        f.write(html_content)

    print(f"Combined HTML file saved as: {output_filename}")
    return output_filename


def _sembas_scatter3d(entry, marker_size):
    """
    Build the 3D scatter trace for one SEMBAS result entry.

    Args:
        entry: One element of the results JSON. It must contain:
            "requests" — points whose first three values are used as x, y, z.
            "results" — one truthy/falsy outcome per point.
        marker_size: Marker size passed to plotly.

    Returns:
        A ``go.Scatter3d`` trace with failed points in red and successful points
        in blue, hover labels, and no legend entry.
    """
    coords = entry["requests"]
    results = entry["results"]

    x = [pt[0] for pt in coords]
    y = [pt[1] for pt in coords]
    z = [pt[2] for pt in coords]

    # 0 -> red (failed), 1 -> blue (successful) via the two-stop colorscale.
    colors = [1 if res else 0 for res in results]

    # zip keeps the labels aligned with the plotted points and avoids an
    # IndexError when "results" is longer than "requests".
    hover = [
        f'Point {i+1}<br>{"Successful" if res else "Failed"}<br>({px:.2f}, {py:.2f}, {pz:.2f})'
        for i, (res, px, py, pz) in enumerate(zip(results, x, y, z))
    ]

    return go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        marker=dict(
            size=marker_size,
            color=colors,
            colorscale=[[0, 'red'], [1, 'blue']],
            opacity=1.0,
            line=dict(width=2, color='black')
        ),
        text=hover,
        hoverinfo='text',
        showlegend=False
    )


def plot_sembas_grid(file, start_index=0, num_plots=None, save_html=True):
    """
    Plot SEMBAS result entries from a JSON file as pages of 2x2 3D scatter plots.

    Each page holds up to four consecutive entries, laid out row-major. Each
    subplot is titled with its entry's "description" and uses the colouring and
    axis labels of ``plot_sembas_results``.

    Args:
        file: Path to a JSON file holding a list of entries. Each entry has the
            keys "requests", "results" and "description".
        start_index: Index (non-negative) of the first entry to plot.
        num_plots: Maximum number of entries to plot. None means every entry
            from ``start_index`` to the end. The count is clamped to the entries
            that exist, so no empty pages are produced.
        save_html: When True (default), each page is shown with ``fig.show()``
            and None is returned. When False, nothing is displayed and the page
            figures are returned, so a caller such as ``plot_sembas_grid_html``
            can embed them in a report.

    Returns:
        A list of plotly Figures (one per page) when ``save_html`` is False,
        otherwise None.
    """
    with open(file, 'r') as f:
        data = json.load(f)

    # Clamp the requested range to existing entries so "Page x/y" is accurate
    # and no all-empty pages are drawn.
    available = max(len(data) - start_index, 0)
    if num_plots is None:
        num_plots = available
    else:
        num_plots = max(min(num_plots, available), 0)
    stop_index = start_index + num_plots

    plots_per_page = 4
    positions = [(1, 1), (1, 2), (2, 1), (2, 2)]  # 2x2 grid, row-major
    num_pages = math.ceil(num_plots / plots_per_page)

    figures = []
    for page in range(num_pages):
        page_start = start_index + page * plots_per_page
        page_entries = data[page_start:min(page_start + plots_per_page, stop_index)]

        fig = make_subplots(
            rows=2, cols=2,
            specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}],
                   [{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
            subplot_titles=[entry["description"] for entry in page_entries]
        )

        for (row, col), entry in zip(positions, page_entries):
            fig.add_trace(_sembas_scatter3d(entry, marker_size=8), row=row, col=col)
            # Label only the scenes that actually hold data, targeting the
            # subplot directly instead of computing its "sceneN" layout key.
            fig.update_scenes(
                xaxis_title='Progress Weight',
                yaxis_title='Bounds Weight',
                zaxis_title='Proximity Weight',
                row=row, col=col
            )

        fig.update_layout(
            title_text=f"Sembas Results Grid - Page {page + 1}/{num_pages}",
            width=1400,
            height=1000
        )

        figures.append(fig)
        if save_html:
            fig.show()

    # Return the figures only when asked. A bare call at the end of a notebook
    # cell would otherwise print the reprs of every figure.
    return None if save_html else figures

In [28]:
# Plot up to 8 bounds-test entries, starting at index 6, as pages of 2x2 grids.
# Other datasets plotted with the same helper:
#   'paper_data/scalar_collision_vector_pass.json' -> start_index=4, num_plots=4
#   'paper_data/scalar_pass.json'                  -> start_index=0, num_plots=4
file = 'paper_data/bounds_test.json'
plot_sembas_grid(
    file,
    num_plots=8,
    start_index=6,
)