In [11]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os
from collections import defaultdict


In [12]:
sepsis = [
    "C://Users//nikol//MT-repo//logs//train//runs//2025-09-28_18-52-04//baseline_cvae_rerun_test//metrics.json",
    "C://Users//nikol//MT-repo//logs//eval//runs//2025-09-25_10-48-59//baseline_cvae_rerun_test//metrics.json",
    "C://Users//nikol//MT-repo//logs//train//runs//2025-09-26_20-07-38//baseline_cvae_rerun_test//metrics.json"
]

bpic2012 = [
    "C://Users//nikol//MT-repo//logs//eval//runs//2025-10-03_17-54-49//baseline_cvae_rerun_test//metrics.json",
    "C://Users//nikol//MT-repo//logs//eval//runs//2025-10-15_14-03-31//baseline_cvae_rerun_test//metrics.json",
    "C://Users//nikol//MT-repo//logs//eval//runs//2025-10-15_19-20-58//baseline_cvae_rerun_test//metrics.json"
]

emergency = [
    "C://Users//nikol//MT-repo//logs//train//runs//2025-11-15_21-10-43//baseline_cvae_rerun_test//metrics.json",
    "C://Users//nikol//MT-repo//logs//eval//runs//2025-11-15_17-32-17//baseline_cvae_rerun_test//metrics.json",
    "C://Users//nikol//MT-repo//logs//train//runs//2025-11-14_21-18-42//baseline_cvae_rerun_test//metrics.json"

]

model_paths = emergency
dataset_name = "EMERGENCY"
model_names = ["cvae", "TF decoder", "TF decoder (beam)"]

In [13]:
metrics = {}
for name, path in zip(model_names, model_paths):
        try:
            with open(path, 'r') as f:
                metrics[name] = json.load(f)
        except Exception as e:
            print(f"Failed to load {path}: {e}")
            metrics[name] = None

# Example usage:
cvae_data = metrics.get('cvae', {})
cvaev2_data = metrics.get('cvaev2', {})


In [14]:
def expand_chsic_per_activity(model_metrics):
    """Expand nested 'chsic_per_activity' dict into flat chsic_act_<activity> metrics per file."""
    if not isinstance(model_metrics, dict):
        return
    for file_name, metric_dict in model_metrics.items():
        if not isinstance(metric_dict, dict):
            continue
        per_act = metric_dict.get('chsic_per_activity')
        if isinstance(per_act, dict):
            for act, val in per_act.items():
                # Only add numeric values
                if isinstance(val, (int, float)):
                    metric_dict[f'chsic_act_{act}'] = val
            # Optionally keep or remove the original aggregated dict; keep for reference
            # del metric_dict['chsic_per_activity']  # uncomment to remove original

def reshape(per_file_metrics):
    by_metric = defaultdict(lambda: defaultdict(list))
    for file_name, metric_dict in per_file_metrics.items():
        if not isinstance(metric_dict, dict):
            continue
        for metric, val in metric_dict.items():
            if val is None:
                continue
            # Only accept scalar numeric values for plotting
            if not isinstance(val, (int, float)):
                continue
            if file_name.startswith("gen"):
                category = "GEN"
            elif file_name.startswith("train"):
                category = "Train"
            else:
                category = file_name
            by_metric[metric][category].append(val)
    return by_metric

# First expand cHSIC per-activity metrics into flat metrics for each model
for model_name in model_names:
    if model_name in metrics and isinstance(metrics[model_name], dict):
        expand_chsic_per_activity(metrics[model_name])

cvae_data = reshape(metrics.get('cvae', {}))
cvaev2_data = reshape(metrics.get('cvaev2', {}))

# Collect all metric names from both models
all_metric_names = set(cvae_data.keys()) | set(cvaev2_data.keys())

boxplot_data = defaultdict(dict)
for metric in all_metric_names:
    boxplot_data[metric] = {}
    
for model_name in model_names:
    model_data = metrics.get(model_name, {})
    reshaped_data = reshape(model_data)
    for metric in all_metric_names:
        if model_name == 'cvae' and 'Train' in reshaped_data.get(metric, {}):
            boxplot_data[metric]['Train'] = reshaped_data[metric]['Train']
        
        if 'GEN' in reshaped_data.get(metric, {}):
            boxplot_data[metric][model_name] = reshaped_data[metric]['GEN']


In [15]:
def plot_boxplots(metrics, output_path='output', output_filename='results.png', data_name=dataset_name):
    """
    Plot boxplots for the given metrics.

    Parameters:
      metrics (dict): A dictionary containing the metrics to plot. The keys represent the metric names, and the values are dictionaries with method names as keys (e.g. TEST1, GEN) and a list of metric values as values.
      output_path (str): The path where the output file will be saved. Default is 'output'.
      output_filename (str): The name of the output file. Default is 'results.png'.
    """
    if not metrics:
        return
    os.makedirs(output_path, exist_ok=True)
    metrics = dict(metrics)
    metrics = {k: dict(v) for k, v in metrics.items()}
    default_figsize = plt.rcParams.get('figure.figsize')
    
    # Arrange up to 3 metrics per row
    n_metrics = len(metrics)
    ncols = 3 if n_metrics > 3 else 1
    nrows = int(np.ceil(n_metrics / ncols)) if n_metrics else 1
    fig, axs = plt.subplots(nrows, ncols, figsize=(max(1, default_figsize[0]*ncols), default_figsize[1]*nrows))
    axs = np.atleast_1d(axs).ravel()
    
    # Define consistent color mapping for categories
    color_mapping = {
        'Train': '#1f77b4',  # Blue
        'CVAE': '#ff7f0e',    # Orange
        'TF decoder': '#2ca02c',   # Green
        'VAL': '#d62728',    # Red
        'TF decoder (beam)': '#17becf',  # Purple
        'TEST2': '#8c564b',  # Brown
        'TEST3': '#e377c2',  # Pink
    }
    
    # Fallback colors for any unexpected categories
    fallback_colors = ['#17becf', '#bcbd22', '#7f7f7f', '#ff9896', '#c5b0d5', '#c49c94']
    
    for ax, (metric_name, metric_values) in zip(axs, metrics.items()):
      labels = list(metric_values.keys())
      if 'cvae' in labels:
          labels[labels.index('cvae')] = "CVAE"
      data = metric_values.values()
      if not labels:
          continue
      boxplot = ax.boxplot(data, patch_artist=True, medianprops={'color': 'black', 'linewidth': 1})
      
# Assign colors based on category labels
      colors = []
      fallback_idx = 0
      for label in labels:
          if label in color_mapping:
              colors.append(color_mapping[label])
          else:
              colors.append(fallback_colors[fallback_idx % len(fallback_colors)])
              fallback_idx += 1
        
      for patch, color in zip(boxplot['boxes'], colors):
          patch.set_facecolor(color)
          patch.set_alpha(0.7)
        
      ax.set_xticks(range(1, len(labels) + 1))
      if output_filename.startswith("hsic"):
          labels[0] = "Test"
      
      ax.set_xticklabels(labels, rotation=45, ha='right')
      metric_name = metric_name.replace('_', ' ')
      
      if metric_name == 'chsic uniform':
          metric_name = "CHSIC"
      title = f"{metric_name.upper()} - {data_name.upper()}"
      ax.set_title(title)
      ax.grid(True, alpha=0.3)
    
    # Hide any unused subplots
    for ax in axs[len(metrics):]:
        ax.axis('off')

    output_filename = f"{data_name.lower()}_{output_filename}"
    plt.tight_layout()
    plt.savefig(os.path.join(output_path, output_filename), dpi=300, bbox_inches='tight')
    plt.clf()
    plt.figure(figsize=default_figsize)
    plt.close(fig)
    
# --- Group classic metrics ---
control_flow_metrics = {m: boxplot_data[m] for m in ['cfld','2gram'] if m in boxplot_data}
time_metrics = {m: boxplot_data[m] for m in ['red', 'ctd'] if m in boxplot_data}
conformance = {m: boxplot_data[m] for m in ['conformance'] if m in boxplot_data}
cwd_metrics = {m: boxplot_data[m] for m in ['cwd'] if m in boxplot_data}

# --- HSIC aggregates (if present) ---
hsic_uniform = {m: boxplot_data[m] for m in ['chsic_uniform'] if m in boxplot_data}
hsic_weighted = {m: boxplot_data[m] for m in ['chsic_freq_weighted'] if m in boxplot_data}

# --- HSIC per-activity metrics ---
hsic_activity_metrics = {m: boxplot_data[m] for m in boxplot_data if m.startswith('chsic_act_')}


# Plot groups
plot_boxplots(control_flow_metrics, output_path=f"../output/{dataset_name.lower()}", output_filename='control_flow_metrics.pdf')
plot_boxplots(time_metrics, output_path=f"../output/{dataset_name.lower()}", output_filename='time_metrics.pdf')
plot_boxplots(conformance, output_path=f"../output/{dataset_name.lower()}", output_filename='conformance.pdf')
plot_boxplots(cwd_metrics, output_path=f"../output/{dataset_name.lower()}", output_filename='cwd_metrics.pdf')
plot_boxplots(hsic_uniform, output_path=f"../output/{dataset_name.lower()}", output_filename='hsic.pdf')
if len(hsic_activity_metrics) > 15:
    def split_dict_sorted_by_key(d: dict):
        items = sorted(d.items(), key=lambda kv: kv[0])
        mid = len(items) // 2
        return dict(items[:mid]), dict(items[mid:])
    hsic_act_part1, hsic_act_part2 = split_dict_sorted_by_key(hsic_activity_metrics)
    plot_boxplots(hsic_act_part1, output_path=f"../output/{dataset_name.lower()}", output_filename='hsic_per_activity_part1.pdf')
    plot_boxplots(hsic_act_part2, output_path=f"../output/{dataset_name.lower()}", output_filename='hsic_per_activity_part2.pdf')
else:
    plot_boxplots(hsic_activity_metrics, output_path=f"../output/{dataset_name.lower()}", output_filename='hsic_per_activity.pdf')

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>