In [5]:
import netCDF4 as nc
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from dotenv import load_dotenv
from pandas.plotting import parallel_coordinates
import importlib
import plotly.express as px
import plotly.graph_objects as go
import os
import glob
import pandas as pd
import json

import utils.db_tools as db_tools
from utils.db_tools import (
    get_db,
    filter_df,
    make_animation,
    get_data,
    metrics_grid,
    plot_grid,
    compute_metrics
)

from classify import classify_trajectories

importlib.reload(db_tools)

<module 'utils.db_tools' from '/cluster/home/vogtva/pde-solvers-cuda/analysis/utils/db_tools.py'>

In [17]:
model = "bruss"
run_id = "ball_sampling2"
load_dotenv()
data_dir = os.getenv("DATA_DIR")
output_dir = os.getenv("OUT_DIR")
df = pd.read_csv(f"{output_dir}/{model}/{run_id}/classification_metrics.csv")
df_class = classify_trajectories(
    df, steady_threshold=1, osc_threshold=1.28, dev_threshold=1.28
)
# df_class.value_counts("category")
# df = df_class.copy()
df = get_db(os.path.join(data_dir, model, run_id))
# df = df[df["filename"].apply(os.path.exists)].reset_index(drop=True)
df["op"] = df["original_point"].astype(str)

In [7]:
# plt.figure(figsize=(10, 6))
# sns.scatterplot(x=df_class['A'], y=df_class['B'], hue=df_class["category"])
# plt.xlabel('A')
# plt.ylabel('B')
# plt.title('Scatter plot of A vs B for Sampling Centers')
# plt.show()

fig = px.scatter(
    df_class,
    x="A",
    y="B",
    color="category",
    title="Scatter plot of A vs B",
    labels={"A": "A", "B": "B"},
    width=800,
    height=800,
)

# Display the plot in the notebook
fig.show()


In [8]:
df_class.value_counts("category")

category
SS    378
I     246
BU      1
Name: count, dtype: int64

In [9]:
df_class["op"] = df_class["original_point"].astype(str)
for _, df1 in df_class.groupby("op"):
    original_point = df1.iloc[0]["original_point"]
    print(original_point, df1.value_counts("category").to_dict())

{'A': 0.5, 'B': 0.875, 'Du': 3.0, 'Dv': 48.0} {'SS': 25}
{'A': 0.5, 'B': 1.0, 'Du': 3.0, 'Dv': 24.0} {'SS': 25}
{'A': 0.5, 'B': 1.25, 'Du': 1.0, 'Dv': 11.0} {'SS': 25}
{'A': 0.5, 'B': 1.25, 'Du': 1.0, 'Dv': 14.0} {'SS': 24, 'BU': 1}
{'A': 0.5, 'B': 1.25, 'Du': 1.0, 'Dv': 4.0} {'SS': 25}
{'A': 0.5, 'B': 1.25, 'Du': 1.0, 'Dv': 8.0} {'SS': 25}
{'A': 0.5, 'B': 1.25, 'Du': 2.0, 'Dv': 32.0} {'SS': 25}
{'A': 0.75, 'B': 0.9375, 'Du': 1.0, 'Dv': 11.0} {'SS': 25}
{'A': 0.75, 'B': 1.5, 'Du': 1.0, 'Dv': 14.0} {'SS': 25}
{'A': 0.75, 'B': 1.5, 'Du': 2.0, 'Dv': 22.0} {'SS': 25}
{'A': 0.75, 'B': 1.875, 'Du': 3.0, 'Dv': 48.0} {'I': 25}
{'A': 1.0, 'B': 1.25, 'Du': 2.0, 'Dv': 8.0} {'SS': 25}
{'A': 1.0, 'B': 1.25, 'Du': 3.0, 'Dv': 33.0} {'SS': 25}
{'A': 1.0, 'B': 1.75, 'Du': 1.0, 'Dv': 18.0} {'SS': 25}
{'A': 1.0, 'B': 3.0, 'Du': 2.0, 'Dv': 22.0} {'I': 25}
{'A': 1.5, 'B': 2.625, 'Du': 1.0, 'Dv': 14.0} {'I': 21, 'SS': 4}
{'A': 1.5, 'B': 3.0, 'Du': 2.0, 'Dv': 16.0} {'SS': 25}
{'A': 1.5, 'B': 4.5, 'Du': 3.0, 

In [34]:
df_filt = df[df.op == "{'A': 0.5, 'B': 1.25, 'Du': 1.0, 'Dv': 14.0}"]
for i, row in df_filt.iterrows():
    print(row["Nt"])

60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000
60000


In [14]:
t = np.linspace(0, 100, 100)
for _, df1 in df.groupby("op"):
    original_point = df1.iloc[0]["original_point"]
    print(original_point)
    all_metrics = []

    # Collect metrics for all rows in group
    for _, row in df1.iterrows():
        d = get_data(row)
        dev, dt, dx = compute_metrics(row, 0)
        all_metrics.append(dev)

    # Convert to numpy array for easier computation
    all_metrics = np.array(all_metrics)

    # Compute mean and std
    avg_metric = np.mean(all_metrics, axis=0)
    std_metric = np.std(all_metrics, axis=0)

    # Create figure
    fig = go.Figure()

    # Add shaded area for standard deviation
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([t, t[::-1]]),
            y=np.concatenate(
                [avg_metric + std_metric, (avg_metric - std_metric)[::-1]]
            ),
            fill="toself",
            fillcolor="rgba(0,100,80,0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            showlegend=False,
        )
    )

    # Add mean line
    fig.add_trace(
        go.Scatter(
            x=t,
            y=avg_metric,
            mode="lines",
            name="Average Deviation",
            hovertemplate="Index: %{x}<br>Deviation: %{y:.2f}<extra></extra>",
        )
    )

    # Update layout
    fig.update_layout(
        title=f"Deviation Metrics for Original Point: {original_point}",
        xaxis_title="Time Step/Index",
        yaxis_title="Deviation Value",
        hovermode="x unified",
        showlegend=True,
        template="plotly_white",
    )

    fig.show()


{'A': 0.5, 'B': 0.875, 'Du': 3.0, 'Dv': 48.0}


{'A': 0.5, 'B': 1.0, 'Du': 3.0, 'Dv': 24.0}


{'A': 0.5, 'B': 1.25, 'Du': 1.0, 'Dv': 11.0}


{'A': 0.5, 'B': 1.25, 'Du': 1.0, 'Dv': 14.0}



overflow encountered in divide


invalid value encountered in subtract



{'A': 0.5, 'B': 1.25, 'Du': 1.0, 'Dv': 4.0}


{'A': 0.5, 'B': 1.25, 'Du': 1.0, 'Dv': 8.0}


{'A': 0.5, 'B': 1.25, 'Du': 2.0, 'Dv': 32.0}


{'A': 0.75, 'B': 0.9375, 'Du': 1.0, 'Dv': 11.0}


{'A': 0.75, 'B': 1.5, 'Du': 1.0, 'Dv': 14.0}


{'A': 0.75, 'B': 1.5, 'Du': 2.0, 'Dv': 22.0}


{'A': 0.75, 'B': 1.875, 'Du': 3.0, 'Dv': 48.0}


{'A': 1.0, 'B': 1.25, 'Du': 2.0, 'Dv': 8.0}


{'A': 1.0, 'B': 1.25, 'Du': 3.0, 'Dv': 33.0}


{'A': 1.0, 'B': 1.75, 'Du': 1.0, 'Dv': 18.0}


{'A': 1.0, 'B': 3.0, 'Du': 2.0, 'Dv': 22.0}


{'A': 1.5, 'B': 2.625, 'Du': 1.0, 'Dv': 14.0}


{'A': 1.5, 'B': 3.0, 'Du': 2.0, 'Dv': 16.0}


{'A': 1.5, 'B': 4.5, 'Du': 3.0, 'Dv': 12.0}


{'A': 2.0, 'B': 4.0, 'Du': 1.0, 'Dv': 11.0}


{'A': 2.0, 'B': 5.0, 'Du': 2.0, 'Dv': 16.0}


{'A': 2.0, 'B': 6.0, 'Du': 2.0, 'Dv': 28.0}


{'A': 2.0, 'B': 6.0, 'Du': 2.0, 'Dv': 32.0}


{'A': 3.0, 'B': 9.0, 'Du': 2.0, 'Dv': 16.0}


{'A': 3.0, 'B': 9.0, 'Du': 2.0, 'Dv': 32.0}


{'A': 5.0, 'B': 6.25, 'Du': 2.0, 'Dv': 8.0}
