In [1]:
# %% [markdown]
# # Automated Figure Generation from All TensorBoard Logs (Dissertation-Ready)
#
# This notebook reads TensorBoard event files for your specified PoL experiments,
# generates polished PDF figures (vector format) suitable for direct LaTeX
# inclusion, and stores them in the `all_generated_figures_polished/` directory.
# It includes features like Savitzky-Golay smoothing, data caching,
# canonical tag mapping, and aesthetic polishes for publication quality.
# Finally, it zips the output directory for easy downloading of all figures.
#
# **How to use:**
# 1. Run “1 · Setup”.
# 2. Verify the `all_experiment_log_dirs` mapping in Section 2.
# 3. (Optional) Adjust `SMOOTH_WIN` (Savitzky–Golay window length) in Section 3.
# 4. Run “4 · Process All Logs & Generate Plots”.
# 5. Run "5 · Zip and Download All Generated Figures".

# %% [markdown] ################################################################################
# ## 1. Setup
# ----------------------------------------------------------------------------------------------

# %% 1.1 – Mount Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True) # Use force_remount if needed, otherwise False

# %% 1.2 – Install required libraries
# tbparse for reading TensorBoard files, feather-format for caching DataFrames.
!pip -q install tbparse feather-format

# %% 1.3 – Import necessary libraries
import json
import os
import pathlib
import shutil # For creating zip archives
import zipfile # For building the figure archive file by file

import matplotlib.colors as mcolors   # For colormap access
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from google.colab import files # For triggering downloads
from scipy.signal import savgol_filter # For Savitzky-Golay smoothing
from tbparse import SummaryReader

# %% [markdown] ################################################################################
# ## 2. Define Log Directories and Output Path
# ----------------------------------------------------------------------------------------------

# %% 2.1 – Map descriptive experiment names to their log directories on Google Drive.
# !!! IMPORTANT: Verify these paths are correct and match your Drive structure !!!
# (legend label, TensorBoard log directory) pairs; dict order is the processing order.
_experiment_locations = (
    ("Baseline (No WM)", "/content/drive/MyDrive/SecurePoL-with-Watermarking/logs/no_watermark"),
    ("Feature-Based WM", "/content/drive/MyDrive/SecurePoL-with-Watermarking/logs/feature_based"),
    ("Param-Perturb WM", "/content/drive/MyDrive/SecurePoL-with-Watermarking/logs/param_pert"),
    ("Non-Intrusive WM", "/content/drive/MyDrive/SecurePoL-with-Watermarking/logs/non_intrusive"),
)
all_experiment_log_dirs = {
    label: pathlib.Path(location) for label, location in _experiment_locations
}

# %% 2.2 – Output directory that receives every generated figure (vector PDF format).
# Relative path: it is created under the current working directory.
main_output_directory = pathlib.Path("all_generated_figures_polished")
main_output_directory.mkdir(parents=True, exist_ok=True)

# Echo the configuration so a wrong or missing Drive path is spotted immediately.
print(f"✓ Setup complete. Log directories to be processed:")
for name in all_experiment_log_dirs:
    path_val = all_experiment_log_dirs[name]
    print(f"  - '{name}': {path_val} (Exists: {path_val.exists()})")
print(f"Generated PDF figures will be written to: {main_output_directory.resolve()}")

# %% [markdown] ################################################################################
# ## 3. Helper Utilities and Configuration
# ----------------------------------------------------------------------------------------------

# %% 3.1 – Canonical tag names mapped to potential aliases found in TensorBoard logs.
# This helps robustly find metrics even if naming conventions varied slightly.
# Canonical metric name -> TensorBoard tag spellings accepted for it.
# Aliases are listed in lookup priority order.
TAG_MAP = dict([
    ("val_acc",    ["Acc/val", "Accuracy/val", "val_accuracy", "acc_val", "ValAcc"]),
    # "Loss/epoch" is accepted as a training-loss alias.
    ("train_loss", ["Loss/train", "loss_train", "TrainLoss", "Loss/epoch"]),
    ("val_loss",   ["Loss/val", "loss_val", "ValLoss"]),
    ("lr",         ["LR", "learning_rate", "LearningRate"]),
    # Distance metrics ('dist_*'); add further canonical tags and aliases here if needed.
    ("dist_1",     ["dist_1", "Dist/L1", "dist_l1"]),
    ("dist_2",     ["dist_2", "Dist/L2", "dist_l2"]),
    ("dist_inf",   ["dist_inf", "Dist/Linf", "dist_linf"]),
    ("dist_cos",   ["dist_cos", "Dist/Cosine", "dist_cosine"]),
])

# %% 3.2 – Smoothing configuration for plots.
# Savitzky-Golay window length. Must be an odd integer.
# Set to 1 or 0 for no smoothing. Typical values: 5, 7, 9.
SMOOTH_WIN = 5

def _apply_smoothing(series_data):
    """Applies Savitzky-Golay smoothing to a Pandas Series."""
    if SMOOTH_WIN <= 1 or len(series_data) < SMOOTH_WIN:
        return series_data # No smoothing or not enough data points

    # Ensure window length is odd and less than or equal to data length
    window_length = min(SMOOTH_WIN if SMOOTH_WIN % 2 != 0 else SMOOTH_WIN + 1, len(series_data))
    if window_length < 3 : return series_data # Savgol filter needs window >= polyorder + 1, min polyorder 0

    # Polyorder should be less than window_length. Typically 2 or 3.
    poly_order = min(2, window_length - 1)

    try:
        return savgol_filter(series_data, window_length, poly_order, mode="interp")
    except ValueError: # Handle cases where window/polyorder is still problematic for some data
        print(f"  ! Warning: Could not apply Savitzky-Golay smoothing for series of length {len(series_data)} with window {window_length}, polyorder {poly_order}. Returning raw data.")
        return series_data


# %% 3.3 – Helper function to find the actual tag name in a DataFrame.
def _find_actual_tag_name(df_columns, canonical_tag_name):
    """Finds the actual tag in DataFrame columns based on canonical name and TAG_MAP aliases."""
    if canonical_tag_name in df_columns: # Check direct match for canonical name itself
        return canonical_tag_name
    for alias in TAG_MAP.get(canonical_tag_name, []):
        if alias in df_columns:
            return alias
    return None # Tag not found

# %% 3.4 – Function to load scalar data from TensorBoard logs with caching.
def _first_numeric_value(cell):
    """Collapse one pivoted scalar cell to a plain number, or NaN when that is impossible.

    A non-empty list yields its first element if that element is numeric, a bare
    numeric value is returned as is, and anything else becomes NaN. NumPy scalar
    types count as numeric alongside Python int/float.
    """
    numeric_types = (int, float, np.integer, np.floating)
    if isinstance(cell, list):
        return cell[0] if cell and isinstance(cell[0], numeric_types) else np.nan
    return cell if isinstance(cell, numeric_types) else np.nan


def load_scalar_data_from_logdir(log_directory: pathlib.Path, run_display_name: str) -> pd.DataFrame:
    """Load all scalar series of one run from its TensorBoard event files, with a feather cache.

    Args:
        log_directory: Directory expected to contain ``events.out.tfevents.*`` files.
        run_display_name: Human-readable run label; used in messages and (sanitized)
            as the cache file name.

    Returns:
        A DataFrame with a ``step`` column plus one column per scalar tag, rows with
        an invalid step or without any metric value removed, and a fresh 0..n-1 index.
        An empty DataFrame is returned when the directory, the event files, the
        scalars or the ``step`` column are missing, or when parsing fails.

    Note:
        An existing cache file is returned without checking whether the event files
        changed; delete it from ``<output dir>/.cache`` to force a re-parse.
    """
    # Define cache path within the main output directory
    cache_storage_dir = main_output_directory / ".cache"
    cache_storage_dir.mkdir(parents=True, exist_ok=True)
    # Sanitize run_display_name for use in filenames
    sanitized_run_name = run_display_name.replace(" ", "_").replace("(", "").replace(")", "").replace("/", "_")
    cache_file_path = cache_storage_dir / f"{sanitized_run_name}_scalars.feather"

    if cache_file_path.exists():
        print(f"  ✓ Reading '{run_display_name}' data from cache: {cache_file_path}")
        try:
            return pd.read_feather(cache_file_path)
        except Exception as e:
            # A truncated or incompatible cache must not block the run: parse again instead.
            print(f"  ! Could not read cache file {cache_file_path}: {e}. Re-parsing event files.")

    print(f"  Parsing event files for '{run_display_name}' in: {log_directory} (this may take a moment)...")
    if not log_directory.exists() or not log_directory.is_dir():
        print(f"  ! Error: Log directory not found or is not a directory: {log_directory}")
        return pd.DataFrame()

    try:
        # Cheap pre-check so a directory without event files gives a clear message.
        if not any(log_directory.glob("events.out.tfevents.*")):
            print(f"  ! No 'events.out.tfevents.*' files found in {log_directory} for run '{run_display_name}'.")
            return pd.DataFrame()

        # pivot=True: one row per step, one column per scalar tag.
        reader = SummaryReader(str(log_directory), pivot=True)
        df_scalars = reader.scalars
    except Exception as e:
        print(f"  ! Error reading event files from {log_directory} for '{run_display_name}' with tbparse: {e}")
        return pd.DataFrame()

    if df_scalars.empty:
        print(f"  ! No scalar data extracted from {log_directory} for '{run_display_name}'.")
        return pd.DataFrame()

    if 'step' not in df_scalars.columns:
        print(f"  ! 'step' column not found in data from {log_directory} for '{run_display_name}'. Cannot process.")
        return pd.DataFrame()

    # Work on an independent copy so the reader's own frame is never mutated.
    df_scalars = df_scalars.copy()
    df_scalars['step'] = pd.to_numeric(df_scalars['step'], errors='coerce').astype('Int64')
    df_scalars = df_scalars.dropna(subset=['step']).copy()

    # Process columns to handle list-like entries from tbparse (take first numeric element)
    metric_columns = [col for col in df_scalars.columns if col != 'step']
    for col in metric_columns:
        df_scalars[col] = df_scalars[col].apply(_first_numeric_value)

    # Drop rows where all metric values (excluding 'step') became NaN after processing.
    # The index is rebuilt afterwards because feather serialization expects the
    # default RangeIndex; gaps left by dropped rows could otherwise make the cache write fail.
    df_scalars = df_scalars.dropna(how='all', subset=metric_columns).reset_index(drop=True)

    if not df_scalars.empty:
        try:
            df_scalars.to_feather(cache_file_path) # Write to cache
            print(f"  ✓ Data for '{run_display_name}' loaded and cached to: {cache_file_path}")
        except Exception as e:
            # Caching is best-effort: the freshly parsed data is still returned.
            print(f"  ! Error writing cache file {cache_file_path}: {e}")
    return df_scalars

# %% 3.5 – Generic function to generate and save a single-series plot.
def generate_single_metric_plot(
    steps_data: pd.Series,
    series_data: pd.Series,
    plot_title_text: str,
    y_axis_label_text: str,
    output_file_path: pathlib.Path,
    plot_color: str,
    plot_linestyle: str
    ):
    """Draw one (smoothed) metric curve and save it as a PDF.

    Args:
        steps_data: X values; must share its index with ``series_data``.
        series_data: Metric values; NaN entries are dropped before plotting.
        plot_title_text: Figure title. The part before the first '(' is reused
            as the legend label.
        y_axis_label_text: Y-axis label; also drives the unit-interval axis heuristic.
        output_file_path: Target path; its suffix is replaced by ``.pdf``.
        plot_color: Line colour passed to matplotlib.
        plot_linestyle: Line style passed to matplotlib.

    Returns:
        None. Problems (no data, save failure) are reported via ``print``.
    """

    # Drop NaNs before smoothing and align steps and series data
    series_data_cleaned = series_data.dropna()
    steps_data_aligned = steps_data[series_data_cleaned.index]

    if series_data_cleaned.empty or steps_data_aligned.empty or len(steps_data_aligned) != len(series_data_cleaned):
        print(f"  ! Skipping single plot '{plot_title_text}': Data is empty or mismatched after cleaning NaNs.")
        return

    pdf_path = output_file_path.with_suffix(".pdf")
    fig, ax = plt.subplots(figsize=(3.5, 2.5)) # Compact size for LaTeX inclusion
    try:
        # _apply_smoothing returns either an ndarray or the input Series; normalise to an array.
        series_to_plot = np.asarray(_apply_smoothing(series_data_cleaned))

        ax.plot(steps_data_aligned, series_to_plot, color=plot_color, ls=plot_linestyle, lw=1.2,
                label=plot_title_text.split('(')[0].strip())

        ax.set_xlabel("Step (or Epoch)", fontsize=9)
        ax.set_ylabel(y_axis_label_text, fontsize=9)
        ax.set_title(plot_title_text, fontsize=10)
        # tick_params also applies to ticks created later (e.g. by set_yticks below).
        ax.tick_params(axis="both", labelsize=8)

        # Use a fixed unit-interval axis for accuracies and for any non-loss metric
        # whose values already lie within it.
        # NOTE(review): this also catches small-valued metrics such as a learning
        # rate, whose curve then hugs the bottom of the axis; confirm this is wanted.
        min_val, max_val = series_to_plot.min(), series_to_plot.max()
        label_lower = y_axis_label_text.lower()
        if "accuracy" in label_lower or \
           (min_val >= -0.05 and max_val <= 1.1 and "loss" not in label_lower):
            ax.set_yticks(np.arange(0.0, 1.1, 0.2))
            ax.set_ylim(-0.05, 1.05)

        ax.grid(alpha=0.3, ls=':')
        ax.legend(fontsize=7, loc='best')
        fig.tight_layout(pad=0.3)

        try:
            fig.savefig(pdf_path)
            print(f"  ✓ Single plot saved: {pdf_path}")
        except Exception as e:
            print(f"  ! Error saving plot {pdf_path}: {e}")
    finally:
        # Always release the figure, even if plotting raised part-way through.
        plt.close(fig)


# %% [markdown] ################################################################################
# ## 4. Process All Logs & Generate Plots
# ----------------------------------------------------------------------------------------------

# %% 4.1 – Load data for all experimental runs.
# Display name -> scalar DataFrame, for every run that yielded usable data.
all_runs_data_frames = {}
print(f"--- Loading Data for All Experiments (Smoothing Window: {SMOOTH_WIN}) ---")
for display_name in all_experiment_log_dirs:
    dir_path = all_experiment_log_dirs[display_name]
    print(f"↳ Loading data for '{display_name}'...")
    df_single_run = load_scalar_data_from_logdir(dir_path, display_name)
    # A run is only kept when it has rows and a 'step' column to plot against.
    if df_single_run.empty or 'step' not in df_single_run.columns:
        print(f"  ! Failed to load valid data or 'step' column missing for '{display_name}'. This run will be skipped.")
        continue
    all_runs_data_frames[display_name] = df_single_run

# %% 4.2 – Generate individual plots for each metric in each run.
print(f"\n--- Generating Individual Metric Plots for Each Run ---")
# Every single-run figure uses the first colour of the tab10 palette and a solid line.
try:
    default_plot_color = plt.colormaps.get_cmap("tab10").colors[0]
except AttributeError: # Fallback for older matplotlib
    default_plot_color = plt.cm.get_cmap("tab10", 10)(0)
individual_plot_linestyle = '-'

for display_name, df_current_run in all_runs_data_frames.items():
    print(f"\n  Generating plots for run: '{display_name}'")

    # One sub-folder per run, named after the display name with awkward characters removed.
    sanitized_display_name_for_dir = display_name
    for unwanted, substitute in ((" ", "_"), ("(", ""), (")", ""), ("/", "_")):
        sanitized_display_name_for_dir = sanitized_display_name_for_dir.replace(unwanted, substitute)
    run_specific_figures_dir = main_output_directory / sanitized_display_name_for_dir
    run_specific_figures_dir.mkdir(exist_ok=True)

    # Core metrics are always attempted; the output file name equals the canonical name.
    metrics_to_plot_individually = [
        {"canonical": canonical_name, "ylabel": axis_label, "basename": canonical_name}
        for canonical_name, axis_label in (
            ("train_loss", "Training Loss"),
            ("val_loss", "Validation Loss"),
            ("val_acc", "Validation Accuracy"),
            ("lr", "Learning Rate"),
        )
    ]

    # Distance metrics are added only when this run logged them (under any known alias).
    for i_dist in (1, 2, 'inf', 'cos'):
        actual_dist_tag_found = _find_actual_tag_name(df_current_run.columns, f'dist_{i_dist}')
        if not actual_dist_tag_found:
            continue
        already_listed = any(
            entry["canonical"] == actual_dist_tag_found for entry in metrics_to_plot_individually
        )
        if already_listed:
            continue
        metrics_to_plot_individually.append({
            "canonical": actual_dist_tag_found,
            "ylabel": f"Distance ({str(i_dist).upper()})",
            "basename": actual_dist_tag_found.replace('/', '_')
        })

    for metric_config in metrics_to_plot_individually:
        actual_tag_name = _find_actual_tag_name(df_current_run.columns, metric_config["canonical"])
        if not actual_tag_name or actual_tag_name not in df_current_run.columns:
            continue # Metric not logged for this run
        generate_single_metric_plot(
            df_current_run['step'],
            df_current_run[actual_tag_name],
            plot_title_text=f"{actual_tag_name.replace('/', ' ')} ({display_name})",
            y_axis_label_text=metric_config["ylabel"],
            output_file_path=run_specific_figures_dir / metric_config["basename"],
            plot_color=default_plot_color,
            plot_linestyle=individual_plot_linestyle
        )

# %% [markdown]
# ### 4.3 Comparative Plots (Overlaying metrics from all runs)

# %%
print(f"\n--- Generating Comparative Plots (Smoothing Window: {SMOOTH_WIN}) ---")

# Canonical metric -> y-axis label; one overlay figure is produced per entry.
metrics_for_comparison = {
    "val_acc": "Validation Accuracy",
    "train_loss": "Training Loss",
    "val_loss": "Validation Loss",
}

try:
    comparison_colors_cmap = plt.colormaps.get_cmap("tab10")
    # Get the list of RGBA color tuples from the colormap
    base_comparison_colors = comparison_colors_cmap.colors
except AttributeError: # Fallback for older matplotlib
    comparison_colors_cmap = plt.cm.get_cmap("tab10", 10)
    base_comparison_colors = [comparison_colors_cmap(i) for i in range(10)]

# Define a list of distinct linestyles
comparison_linestyles = ['-', '--', '-.', ':']


def _collect_comparison_series(canonical_metric):
    """Collect plot-ready data for every loaded run that logged `canonical_metric`.

    Each record holds the run name, the aligned steps, the raw values, the smoothed
    values (always a NumPy array) and the last smoothed value used for ranking.
    Records are sorted best-first: descending for accuracy-like metrics
    ("acc" in the name), ascending otherwise. A NaN final value ranks last.
    """
    higher_is_better = "acc" in canonical_metric.lower()
    records = []
    for run_name, run_df in all_runs_data_frames.items():
        tag = _find_actual_tag_name(run_df.columns, canonical_metric)
        if not tag or tag not in run_df.columns:
            continue
        raw_values = (
            run_df[tag]
            .dropna()
            .apply(lambda x: x[0] if isinstance(x, list) and x else x)
            .astype(float)
        )
        if raw_values.empty:
            continue
        # _apply_smoothing may return the input Series unchanged (smoothing off or
        # too few points). Converting to an array makes [-1] positional in both
        # cases; on an integer-labelled Series it would be a label lookup and fail.
        smoothed_values = np.asarray(_apply_smoothing(raw_values), dtype=float)
        last_val = smoothed_values[-1]
        if pd.isna(last_val):
            last_val = -float('inf') if higher_is_better else float('inf')
        records.append({
            "name": run_name,
            "steps": run_df['step'][raw_values.index],
            "raw": raw_values,
            "smoothed": smoothed_values,
            "last_val": last_val,
        })
    records.sort(key=lambda rec: rec["last_val"], reverse=higher_is_better)
    return records


if not all_runs_data_frames:
    print("No data loaded from any run. Cannot generate comparative plots.")
else:
    for canonical_metric, common_y_label in metrics_for_comparison.items():
        legend_items_to_sort = _collect_comparison_series(canonical_metric)
        if not legend_items_to_sort:
            # Checked before any figure is created, so nothing is left open.
            print(f"  ! Could not generate comparative plot for '{canonical_metric}' as no valid data was found across runs.")
            continue

        fig, ax = plt.subplots(figsize=(4.5, 3.0))
        try:
            # Styles restart for every figure and follow the legend (rank) order, so a
            # given experiment is not guaranteed the same colour in every figure.
            non_baseline_style_idx = 0
            for sorted_item in legend_items_to_sort:
                run_name_to_plot = sorted_item["name"]
                if "Baseline" in run_name_to_plot:
                    current_plot_color = 'grey'
                    current_plot_linestyle = ':'
                else:
                    # Cycle through colours and linestyles together for non-baseline runs
                    current_plot_color = base_comparison_colors[non_baseline_style_idx % len(base_comparison_colors)]
                    current_plot_linestyle = comparison_linestyles[non_baseline_style_idx % len(comparison_linestyles)]
                    non_baseline_style_idx += 1

                ax.plot(sorted_item["steps"], sorted_item["smoothed"],
                        label=run_name_to_plot,
                        color=current_plot_color,
                        ls=current_plot_linestyle,
                        lw=1.2)

            ax.set_xlabel("Step (or Epoch)", fontsize=9)
            ax.set_ylabel(common_y_label, fontsize=9)
            ax.set_title(f"Comparative {common_y_label}", fontsize=10)
            # tick_params also applies to ticks created later (e.g. by set_yticks below).
            ax.tick_params(axis="both", labelsize=8)

            # Axis heuristic is based on the raw (unsmoothed) value range across all runs.
            global_min_comp = min(item["raw"].min() for item in legend_items_to_sort)
            global_max_comp = max(item["raw"].max() for item in legend_items_to_sort)
            label_lower = common_y_label.lower()
            if "accuracy" in label_lower or \
               (global_min_comp >= -0.05 and global_max_comp <= 1.1 and "loss" not in label_lower):
                ax.set_yticks(np.arange(0.0, 1.1, 0.2))
                ax.set_ylim(-0.05, 1.05)

            ax.grid(alpha=0.3, ls=':')
            ax.legend(title="Experiment", fontsize=7, title_fontsize=8, loc='best', frameon=True)
            fig.tight_layout(pad=0.3)

            output_pdf_filename = main_output_directory / f"COMPARATIVE_{canonical_metric.replace('/', '_')}.pdf"
            fig.savefig(output_pdf_filename)
            print(f"✓ Comparative plot saved: {output_pdf_filename}")
        finally:
            # Release the figure whether or not drawing/saving succeeded.
            plt.close(fig)

print(f"\n✓ All single & comparative PDF figures are in: {main_output_directory.resolve()}")


# %% [markdown]
# ## 5. Zip and Download All Generated Figures
#
# This section will create a zip file of the `all_generated_figures_polished/` directory
# and then trigger a download for that single zip file.

# %%
zip_filename = "all_dissertation_figures.zip"
directory_to_zip = main_output_directory

# Collect the files to ship. Anything under the hidden `.cache/` folder (parsed
# scalar data written by the loader into the output directory) is left out, so the
# archive holds figures only and a cache-only directory counts as "nothing to zip".
if directory_to_zip.exists():
    files_to_archive = sorted(
        candidate
        for candidate in directory_to_zip.rglob("*")
        if candidate.is_file()
        and ".cache" not in candidate.relative_to(directory_to_zip).parts
    )
else:
    files_to_archive = []

if files_to_archive:
    print(f"\nZipping the output directory: {directory_to_zip} ...")
    with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for figure_path in files_to_archive:
            # Store paths relative to the output directory, as before.
            archive.write(figure_path, arcname=figure_path.relative_to(directory_to_zip).as_posix())
    print(f"✓ Directory zipped to: {zip_filename}")

    print(f"Triggering download for {zip_filename}...")
    files.download(zip_filename)
    print(f"✓ Download initiated for {zip_filename}. Check your browser's downloads.")
else:
    print(f"\nOutput directory {directory_to_zip} is empty or does not exist. Nothing to zip or download.")


# %% [markdown]
# ## 6. For Other Specialized Figures (ROC, Top-Q, etc.)
#
# This notebook focuses on figures derivable directly from scalar TensorBoard logs.
# * **ROC Curves:** Use the `colab_roc_generator_detailed.ipynb` script (which includes AUC CI calculations).
# * **PoL-Specific Metrics & Other Custom Plots:** Run your project's own Jupyter notebooks from your GitHub repository, as they contain the specific logic to process `proof/*` files and other custom data.

Mounted at /content/drive
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for feather-format (setup.py) ... [?25l[?25hdone
✓ Setup complete. Log directories to be processed:
  - 'Baseline (No WM)': /content/drive/MyDrive/SecurePoL-with-Watermarking/logs/no_watermark (Exists: True)
  - 'Feature-Based WM': /content/drive/MyDrive/SecurePoL-with-Watermarking/logs/feature_based (Exists: True)
  - 'Param-Perturb WM': /content/drive/MyDrive/SecurePoL-with-Watermarking/logs/param_pert (Exists: True)
  - 'Non-Intrusive WM': /content/drive/MyDrive/SecurePoL-with-Watermarking/logs/non_intrusive (Exists: True)
Generated PDF figures will be written to: /content/all_generated_figures_polished
--- Loading Data for All Experiments (Smoothing Window: 5) ---
↳ Loading data for 'Baseline (No WM)'...
  Parsing event files for 'Baseline (No WM)' in: /content/drive/MyDrive/SecurePoL-with-Watermarking/logs/no_watermark (this may take a moment)...
  ✓ Data for 'Baseline (No WM)' loaded

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

✓ Download initiated for all_dissertation_figures.zip. Check your browser's downloads.
