In [27]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# Read the layer-wise sweep results and keep only the rows for checkpoint step 0.
raw_results = pd.read_csv(
    "/workspaces/src/models/bionemo-framework/attack/analysis/virulence_results/layerwise_sweep_40b.csv"
)
raw_results_step0 = raw_results.loc[raw_results["checkpoint_step"] == 0]

# Split the step-0 rows on the boolean `shuffle_labels` column.
shuffle_flag = raw_results_step0["shuffle_labels"]
results_no_shuffle = raw_results_step0[shuffle_flag.eq(False)]
results_shuffle = raw_results_step0[shuffle_flag.eq(True)]

# One colour per shuffle condition, reused in every panel.
color_no_shuffle = "#6a66e9"  # Blue
color_shuffle = "#f9715d"     # Red

# One row, three panels: Pearson | R² | feature magnitude.
fig = make_subplots(
    rows=1,
    cols=3,
    horizontal_spacing=0.06,
    column_widths=[0.33, 0.33, 0.34],
)


def _layer_scatter(frame, column, label, color, dashed=False, group=None, in_legend=False):
    """Build a lines+markers trace of `frame[column]` against `frame['layer']`.

    `dashed=True` selects the dashed line / triangle marker style used for the
    shuffled-label condition. `group` is only forwarded when given, so traces
    without a legend group stay ungrouped.
    """
    line_style = dict(color=color, width=2)
    marker_style = dict(color=color, size=8)
    if dashed:
        line_style["dash"] = 'dash'
        marker_style["symbol"] = 'triangle-up'
    options = dict(
        x=frame['layer'],
        y=frame[column],
        mode='lines+markers',
        name=label,
        line=line_style,
        marker=marker_style,
        showlegend=in_legend,
    )
    if group is not None:
        options["legendgroup"] = group
    return go.Scatter(**options)


# (subplot column, trace) pairs, in the order they are added to the figure.
# Only the first panel's traces appear in the legend; the second panel reuses
# the same legend groups so toggling an entry affects both panels.
panel_traces = [
    # Panel 1: Layer vs Pearson (both shuffle conditions)
    (1, _layer_scatter(results_no_shuffle, 'test/pearson', 'No Label Shuffling',
                       color_no_shuffle, group='no_shuffle', in_legend=True)),
    (1, _layer_scatter(results_shuffle, 'test/pearson', 'With Label Shuffling',
                       color_shuffle, dashed=True, group='shuffle', in_legend=True)),
    # Panel 2: Layer vs R² (both shuffle conditions)
    (2, _layer_scatter(results_no_shuffle, 'test/r2', 'No Shuffle',
                       color_no_shuffle, group='no_shuffle')),
    (2, _layer_scatter(results_shuffle, 'test/r2', 'Shuffled',
                       color_shuffle, dashed=True, group='shuffle')),
    # Panel 3: Layer vs Magnitude (no shuffle only)
    (3, _layer_scatter(results_no_shuffle, 'magnitude/combined_mean', 'Feature Magnitude',
                       color_no_shuffle)),
]
for panel_col, panel_trace in panel_traces:
    fig.add_trace(panel_trace, row=1, col=panel_col)

# Figure-wide layout: size, horizontal legend above the panels, fonts, template.
fig.update_layout(
    height=400,  # single row of panels
    width=1600,
    title_font_size=18,
    title_x=0.5,  # centre the title
    showlegend=True,
    legend=dict(
        orientation="h",
        yanchor="top",
        y=1.14,
        xanchor="center",
        x=0.86,
        font=dict(size=20),  # legend font size
    ),
    margin=dict(t=0, b=60, r=0, l=40),
    template="plotly_white",
    plot_bgcolor='white',
    font=dict(family="Arial, sans-serif", size=18),
)

# Frame styling shared by every x- and y-axis.
axis_frame_style = dict(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    gridcolor='lightgrey',
    title_standoff=0,
    tickfont=dict(size=20),   # tick label font size
    title_font=dict(size=22),  # axis title font size
)

# Per-panel y-axis settings, keyed by subplot column.
y_axis_specs = {
    1: dict(title_text="Pearson Correlation", range=[-0.15, 0.53]),
    2: dict(title_text="R²", range=[-0.25, 0.28]),
    3: dict(
        title_text="Feature Magnitude",
        type="log",
        tickvals=[1, 10000, 100000000, 1000000000000, 10000000000000000],
        ticktext=["10<sup>0</sup>", "10<sup>4</sup>", "10<sup>8</sup>", "10<sup>12</sup>", "10<sup>16</sup>"],
    ),
}

for panel_col, y_spec in y_axis_specs.items():
    # x-axis: same title everywhere, one tick every 5 layers.
    fig.update_xaxes(
        title_text="Layer ID",
        dtick=5,
        row=1,
        col=panel_col,
        **axis_frame_style,
    )
    # y-axis: panel-specific title/range plus the shared frame.
    fig.update_yaxes(
        zerolinecolor='lightgrey',
        row=1,
        col=panel_col,
        **axis_frame_style,
        **y_spec,
    )

# Add captions to subplots
# Caption for Fig1
# fig.add_annotation(
#     text="(a) Pearson Correlation",
#     xref="paper", yref="paper",
#     x=0.17,  # First column position
#     y=-0.15,  # Below the plot
#     showarrow=False,
#     font=dict(family="Times, serif", size=24),
#     xanchor="center"
# )

# # Caption for Fig2
# fig.add_annotation(
#     text="(b) R² Score",
#     xref="paper", yref="paper",
#     x=0.5,  # Second column position
#     y=-0.15,  # Below the plot
#     showarrow=False,
#     font=dict(family="Times, serif", size=24),
#     xanchor="center"
# )

# # Caption for Fig3
# fig.add_annotation(
#     text="(c) Feature Magnitude",
#     xref="paper", yref="paper",
#     x=0.83,  # Third column position
#     y=-0.15,  # Below the plot
#     showarrow=False,
#     font=dict(family="Times, serif", size=24),
#     xanchor="center"
# )

# Render the figure inline.
fig.show()

# Write the figure as SVG first ...
fig.write_image("layerwise_sweep_40b_step0.svg")

# ... then let Inkscape turn that SVG into a PDF. The call is the last
# expression so its CompletedProcess (including the return code) is displayed.
import subprocess
inkscape_command = [
    "inkscape",
    "layerwise_sweep_40b_step0.svg",
    "--export-pdf=layerwise_sweep_40b_step0.pdf",
]
subprocess.run(inkscape_command)



** Message: 23:45:50.681: Invalid glyph found, continuing...
** Message: 23:45:50.681: Invalid glyph found, continuing...
** Message: 23:45:50.681: Invalid glyph found, continuing...
** Message: 23:45:50.681: Invalid glyph found, continuing...
** Message: 23:45:50.681: Invalid glyph found, continuing...
** Message: 23:45:50.681: Invalid glyph found, continuing...
** Message: 23:45:50.681: Invalid glyph found, continuing...
** Message: 23:45:50.681: Invalid glyph found, continuing...
** Message: 23:45:50.681: Invalid glyph found, continuing...
** Message: 23:45:50.681: Invalid glyph found, continuing...


CompletedProcess(args=['inkscape', 'layerwise_sweep_40b_step0.svg', '--export-pdf=layerwise_sweep_40b_step0.pdf'], returncode=0)

# 40B Log-Likelihood DMS Results

In [4]:
import os
import glob
import pandas as pd
import numpy as np

# Root directory of the DMS result files read in this section. The path is
# relative, so it is resolved against the notebook's working directory.
# It is the default `base_path` of collect_fitness_data_by_model, which reads
# from the "Virus" sub-directory beneath it.
BASE_PATH = "dms_results/likelihood/virus_reproduction/full"  # adjust if needed

# Helpers mirrored from plot_fitness_spearman.py

def extract_steps_from_model_name(model_name: str):
    """Return the integer step count encoded in a fine-tuned model name.

    Only names containing both ``"_ncbi_"`` and ``"samples="`` are inspected.
    The name is split on ``"_"`` and the first all-digit token that directly
    follows a ``"1m"`` token is returned as an ``int``. ``None`` is returned
    when the name does not match or no such token exists.
    """
    looks_finetuned = "_ncbi_" in model_name and "samples=" in model_name
    if not looks_finetuned:
        return None

    tokens = model_name.split("_")
    # Walk consecutive (token, next token) pairs; the last token has no successor.
    for marker, candidate in zip(tokens, tokens[1:]):
        if marker == "1m" and candidate.isdigit():
            return int(candidate)
    return None


def clean_model_name(model_name: str) -> str:
    """Shorten a model directory name into a display label.

    * Names containing both ``"_ncbi_"`` and ``"samples="`` become
      ``"steps=<S>_samples=<N>"``, where ``<S>`` is the token after the last
      ``"1m"`` token and ``<N>`` is the value of the last ``samples=`` token.
      This applies only when both are non-empty.
    * Otherwise everything from the first ``"_epoch"`` onwards is dropped.
    * Names matching neither rule are returned unchanged.
    """
    if "_ncbi_" in model_name and "samples=" in model_name:
        tokens = model_name.split("_")
        # Tokens that immediately follow a "1m" token (the last one wins).
        step_hits = [nxt for cur, nxt in zip(tokens, tokens[1:]) if cur == "1m"]
        # Values of every "samples=<value>" token (the last one wins).
        sample_hits = [tok.split("=")[1] for tok in tokens if tok.startswith("samples=")]
        steps = step_hits[-1] if step_hits else None
        samples = sample_hits[-1] if sample_hits else None
        if steps and samples:
            return f"steps={steps}_samples={samples}"

    # partition() yields the whole name as head when "_epoch" is absent.
    head, _, _ = model_name.partition("_epoch")
    return head


def collect_fitness_data_by_model(model_substrings, base_path=BASE_PATH):
    """Collect per-file |Spearman| values for the selected model directories.

    Model directories are the sub-directories of ``<base_path>/Virus``. For
    every selected model, each ``*_fitness.csv`` file in its directory is read
    and the first cell of the first data row is stored as an absolute float.
    Files without data rows and NaN cells are skipped; files that fail to
    read or parse are reported with ``print`` and skipped.

    Args:
        model_substrings: A substring or a sequence of substrings; a model is
            selected when any substring occurs in its directory name. A falsy
            value, or a one-element sequence containing ``"all"``, selects
            every available model. The base models ``nemo2_evo2_40b_1m``,
            ``esm2_650m`` and ``esm2_3b`` are always selected when present.
        base_path: Directory that contains the ``Virus`` sub-directory.

    Returns:
        ``(data, selected)``: ``data`` maps each selected model name to its
        list of values, and ``selected`` lists the model names. Both follow
        sorted name order, and files are read in sorted order, so repeated
        runs produce identical output. ``({}, [])`` is returned, after
        printing a warning, when the ``Virus`` directory does not exist.
    """
    taxon_path = os.path.join(base_path, "Virus")
    if not os.path.exists(taxon_path):
        print(f"Warning: {taxon_path} does not exist")
        return {}, []

    available_models = [
        entry
        for entry in os.listdir(taxon_path)
        if os.path.isdir(os.path.join(taxon_path, entry))
    ]

    if isinstance(model_substrings, str):
        model_substrings = [model_substrings]

    # Always include base models if present.
    chosen = {
        base
        for base in ("nemo2_evo2_40b_1m", "esm2_650m", "esm2_3b")
        if base in available_models
    }

    select_all = not model_substrings or (
        len(model_substrings) == 1 and model_substrings[0] == "all"
    )
    if select_all:
        chosen.update(available_models)
    else:
        chosen.update(
            model
            for model in available_models
            if any(substring in model for substring in model_substrings)
        )

    # Sets iterate in arbitrary order; sort so the result is reproducible.
    selected = sorted(chosen)

    data = {}
    for model_name in selected:
        values = []
        pattern = os.path.join(taxon_path, model_name, "*_fitness.csv")
        # glob order is filesystem-dependent; sort for a stable value order.
        for file_path in sorted(glob.glob(pattern)):
            try:
                df = pd.read_csv(file_path)
                if len(df.index) == 0:
                    continue
                spearman_value = df.iloc[0, 0]
                if pd.notna(spearman_value):
                    values.append(abs(float(spearman_value)))
            except Exception as e:
                # Best effort: one unreadable file must not abort the scan.
                print(f"Error reading {file_path}: {e}")
        data[model_name] = values

    return data, selected



In [7]:

# Gather the |rho| values. Besides the always-included base models, only
# model directories whose name contains 'human' are selected.
model_filters = ["human"]
fitness_data, selected_models = collect_fitness_data_by_model(model_filters, base_path=BASE_PATH)
print(f"Selected models ({len(selected_models)}):", selected_models)

# Partition the collected models into Evo2 (base / fine-tuned) and ESM2.
all_models = list(fitness_data.keys())

nemo_base_models = [name for name in all_models if name == "nemo2_evo2_40b_1m"]
nemo_finetuned_candidates = [
    name
    for name in all_models
    if "evo2" in name and not name.startswith("esm2") and name not in nemo_base_models
]


def _finetune_sort_key(name):
    """Sort by parsed step count; names without a step count go last."""
    steps = extract_steps_from_model_name(name)
    return (steps is None, steps or float('inf'))


nemo_finetuned_models = sorted(nemo_finetuned_candidates, key=_finetune_sort_key)

# ESM2 models are shown in this fixed order, restricted to those collected.
esm2_fixed_order = ["esm2_650m", "esm2_3b"]
esm2_models = [name for name in esm2_fixed_order if name in all_models]


def _esm2_display_label(name):
    """Map an ESM2 directory name to its display label."""
    lowered = name.lower()
    if lowered.startswith("esm2_650m"):
        return "ESM2-650M"
    if lowered.startswith("esm2_3b"):
        return "ESM2-3B"
    return clean_model_name(name)


# Display labels for each group.
labels_evo2 = [clean_model_name(name) for name in nemo_base_models]
labels_esm2 = [_esm2_display_label(name) for name in esm2_models]

# Numeric step labels for the Evo2 panel; a name without a parsable step
# count (the base model) is mapped to 0.
models_left = nemo_base_models
parsed_steps = [extract_steps_from_model_name(name) for name in models_left]
steps_labels_evo2 = [0 if value is None else value for value in parsed_steps]

# Per-model mean plus the raw points behind it
def stats_for(models):
    """Return ``(means, points)`` for the given model names.

    ``points[i]`` is the list stored in the module-level ``fitness_data`` for
    ``models[i]`` (an empty list when the model is absent). ``means[i]`` is
    the mean of that list, or ``0.0`` when the list is empty.
    """
    points = [fitness_data.get(name, []) for name in models]
    means = [np.mean(vals) if len(vals) > 0 else 0.0 for vals in points]
    return means, points

# Means and raw points for each plotted group.
evo2_stats = stats_for(nemo_base_models)
esm2_stats = stats_for(esm2_models)
means_evo2, points_evo2 = evo2_stats
means_esm2, points_esm2 = esm2_stats

# Report how many models ended up in each group.
print("Group sizes:")
for size_prefix, group_means in (("  evo2:", means_evo2), ("  esm2:", means_esm2)):
    print(size_prefix, len(group_means))




Selected models (8): ['esm2_650m', 'evo2_7b_1m_200_ncbi_virus_human_host_full_species_samples=1600', 'evo2_7b_1m_500_ncbi_virus_human_host_full_species_samples=4000', 'evo2_7b_1m_1000_ncbi_virus_human_host_full_species_samples=8000', 'evo2_7b_1m_2000_ncbi_virus_human_host_full_species_samples=16000', 'esm2_3b', 'evo2_7b_1m_100_ncbi_virus_human_host_full_species_samples=800', 'nemo2_evo2_40b_1m']
Group sizes:
  evo2: 1
  esm2: 2


In [9]:


# Plot 1×2 subplots with shared y-axis in the house style
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Panel widths are proportional to the number of bars in each panel
# (at least 1, so an empty group still gets a column).
left_count = max(1, len(means_evo2))
right_count = max(1, len(means_esm2))
fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.08, column_widths=[left_count, right_count])

# Categorical x labels: step counts for the Evo2 panel, display names for ESM2.
x_evo2_labels = [str(s) for s in steps_labels_evo2]
x_esm2_labels = labels_esm2


def _add_bar_panel(figure, col, x_labels, means, points, color, legend_name, hover_prefix):
    """Add one panel: a bar of means plus the individual points over each bar.

    Args:
        figure: Figure created by ``make_subplots`` to draw into.
        col: 1-based subplot column.
        x_labels: Categorical x label of each bar.
        means: Bar heights, aligned with ``x_labels``.
        points: Per-bar lists of individual values, aligned with ``x_labels``.
        color: Colour used for both the bars and the points.
        legend_name: Legend entry of the bar trace.
        hover_prefix: Text shown before the x label in the point hover box.
    """
    figure.add_trace(
        go.Bar(
            x=x_labels,
            y=means,
            marker_color=color,
            opacity=0.7,
            name=legend_name,
            width=0.6,
        ),
        row=1, col=col
    )
    for label, vals in zip(x_labels, points):
        if not vals:
            continue  # nothing to overlay for this bar
        figure.add_trace(
            go.Scatter(
                x=[label] * len(vals),
                y=vals,
                mode='markers',
                marker=dict(
                    color=color,
                    size=8,
                    opacity=0.65,
                    line=dict(color="white", width=0.5)
                ),
                showlegend=False,
                hovertemplate=f"{hover_prefix}: {label}<br>|ρ|: %{{y}}<extra></extra>"
            ),
            row=1, col=col
        )


# Left: Evo2 base model. means_evo2 is computed from nemo_base_models, i.e.
# "nemo2_evo2_40b_1m", so the legend entry names the 40B model.
_add_bar_panel(
    fig, 1, x_evo2_labels, means_evo2, points_evo2,
    color='#3366CC',
    legend_name='Evo2-40B (Log-Likelihood)',
    hover_prefix="Step",
)

# Right: ESM2 models.
_add_bar_panel(
    fig, 2, x_esm2_labels, means_esm2, points_esm2,
    color='#86bff2',
    legend_name='ESM2 (Masked-Marginal)',
    hover_prefix="Model",
)

# Styling shared by every axis of this figure.
axis_frame = dict(showline=True, linecolor='black', mirror=True)
tick_font = dict(size=25, family="Arial")

# x-axes: categorical ticks placed exactly at the bar labels of each panel.
fig.update_xaxes(
    title_text="Fine-tuning Steps",
    tickmode='array',
    tickvals=x_evo2_labels,
    ticktext=x_evo2_labels,
    tickangle=0,
    title_standoff=0,
    tickfont=tick_font,
    row=1, col=1,
    **axis_frame,
)
fig.update_xaxes(
    title_text="",
    tickmode='array',
    tickvals=x_esm2_labels,
    ticktext=x_esm2_labels,
    tickangle=0,
    tickfont=tick_font,
    row=1, col=2,
    **axis_frame,
)

# All plotted points, flattened. The y-range below is fixed by hand rather
# than derived from these values.
y_values = [v for vals in points_evo2 + points_esm2 for v in vals]
y_min = -0.02
y_max = 0.7

# Same y-range and styling on both panels; only the left one carries a title.
for panel_col, y_title in ((1, '|ρ|'), (2, '')):
    fig.update_yaxes(
        title_text=y_title,
        range=[y_min, y_max],
        dtick=0.2,  # one tick every 0.2
        ticks='outside',
        gridcolor='lightgrey',
        zerolinecolor='lightgrey',
        showgrid=True,
        tickfont=tick_font,
        row=1, col=panel_col,
        **axis_frame,
    )

# Figure-wide layout.
fig.update_layout(
    template='plotly_white',
    width=1600,
    height=400,
    showlegend=True,
    legend=dict(
        orientation="h",  # horizontal legend
        yanchor="bottom",
        y=1.02,  # just above the plotting area
        xanchor="center",
        x=0.785,  # legend centre, in paper coordinates
        bgcolor='rgba(255,255,255,0.8)'
    ),
    margin=dict(t=0, b=0, l=60, r=0),
    font=dict(family='Arial, sans-serif', size=25),
    plot_bgcolor='white',
    bargap=0.0,
    bargroupgap=0.0
)

# Render the figure inline.
fig.show()

# Export as SVG, then convert the SVG to PDF with Inkscape.
# NOTE(review): a later cell in this notebook writes the same
# "dms_40b.svg"/"dms_40b.pdf" names, so whichever cell runs last overwrites
# the other's output - confirm this is intended.
out_file_svg, out_file_pdf = "dms_40b.svg", "dms_40b.pdf"
fig.write_image(out_file_svg)

import subprocess
inkscape_command = ["inkscape", out_file_svg, "--export-pdf=" + out_file_pdf]
subprocess.run(inkscape_command)





CompletedProcess(args=['inkscape', 'dms_40b.svg', '--export-pdf=dms_40b.pdf'], returncode=0)

['nemo2_evo2_40b_1m']
{'esm2_3b': [0.3968587032641512, 0.3629478092172354, 0.0715248351847422, 0.0475148382250067, 0.4051901246171898, 0.0588952602623445, 0.0582138669003198, 0.2665642773739244, 0.5457441495770475, 0.0963254385572049, 0.2765712366523658, 0.2842958301587995, 0.1103221321651832, 0.1409388734825082, 0.5754052530139516, 0.4965301057898278], 'esm2_650m': [0.2368393425434439, 0.1506119511963866, 0.0713658966235152, 0.0123111443944239, 0.4937479846256914, 0.5051654318593204, 0.4727315813653386, 0.1334829245607084, 0.0969929936168829, 0.4296594899569331, 0.0293046996127599, 0.0315951786801824, 0.247600163549855, 0.2040886731127742, 0.10270803598389, 0.1058746595124704], 'nemo2_evo2_40b_1m': [0.0655088113985751, 0.0587416572928384, 0.0738268823801091, 0.1146006749156355, 0.0485369328833895, 0.0353145856767904, 0.0330708661417322, 0.0031046119235095, 0.0310024275917144, 0.0307642006212578, 0.2510401746671068, 0.0537023677576423, 0.0267356981413096, 0.0206824146981627, 0.03200009

CompletedProcess(args=['inkscape', 'dms_probe_40b.svg', '--export-pdf=dms_probe_40b.pdf'], returncode=0)

In [28]:
# pyright: reportMissingTypeStubs=false
import os
import glob
import pandas as pd  # type: ignore[import]
import numpy as np
import plotly.graph_objects as go  # type: ignore[import-untyped]
from plotly.subplots import make_subplots  # type: ignore[import-untyped]
from plot_utils import (
    get_best_train_rmse_layer_stats,
)

# === Configuration ===
# Name of the model sub-directory read from each zero-shot results directory.
MODEL_NAME = "nemo2_evo2_40b_1m"

# Directories holding one sub-directory per model with *_fitness.csv files.
BASE_DIR_FULL = (
    "/workspaces/src/models/bionemo-framework/attack/analysis/dms_results/likelihood/virus_reproduction/full/Virus"
)
BASE_DIR_H5 = (
    "/workspaces/src/models/bionemo-framework/attack/analysis/dms_results/likelihood/virus_reproduction/h5_samples=624_seed=42_test/Virus"
)

# CSV passed to get_best_train_rmse_layer_stats for the probe results.
PROBE_CSV = (
    "/workspaces/src/models/bionemo-framework/attack/analysis/dms_results/probe_results/closed_form/probe_results_40b.csv"
)


def _abs_float_or_none(value):
    """Return ``abs(float(value))``, or ``None`` if the value is NaN or not numeric."""
    if not pd.notna(value):
        return None
    try:
        return abs(float(value))  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return None


def read_zero_shot_spearman_values(zs_base_dir: str, model_dir: str) -> list:
    """Read per-dataset |rho| from *_fitness.csv for the given model directory.

    For each file, the first row of the ``spearman`` column is used when that
    column exists and holds a numeric, non-NaN value; otherwise the first cell
    of the first row is used as a fallback. Files that yield no usable value
    are skipped.

    Args:
        zs_base_dir: Directory containing one sub-directory per model.
        model_dir: Name of the model sub-directory to read.

    Returns:
        A list of absolute float values, one per usable file, in sorted
        file-path order so repeated runs agree. The list is empty when the
        model directory does not exist.

    Files that cannot be read or parsed are reported with ``print`` and
    skipped rather than raising.
    """
    model_path = os.path.join(zs_base_dir, model_dir)
    values: list = []
    if not os.path.isdir(model_path):
        return values

    # glob order is filesystem-dependent; sort for a reproducible result.
    for file_path in sorted(glob.glob(os.path.join(model_path, "*_fitness.csv"))):
        try:
            df = pd.read_csv(file_path)
        except Exception as exc:
            # Best effort: report the unreadable file instead of hiding it.
            print(f"Error reading {file_path}: {exc}")
            continue
        if df.empty:
            continue

        val = None
        if "spearman" in df.columns:
            val = _abs_float_or_none(df["spearman"].iloc[0])
        # Fallback to first cell if needed
        if val is None:
            val = _abs_float_or_none(df.iloc[0, 0])
        if val is not None:
            values.append(val)
    return values


# === Load data ===
# Zero-shot |rho| values of MODEL_NAME from the two result directories.
zs_full_values = read_zero_shot_spearman_values(zs_base_dir=BASE_DIR_FULL, model_dir=MODEL_NAME)
zs_h5_values = read_zero_shot_spearman_values(zs_base_dir=BASE_DIR_H5, model_dir=MODEL_NAME)

# Probe values from closed_form following analysis notebook (best train_rmse layer).
# NOTE(review): the helper is assumed to return a 3-tuple whose second item is
# the per-dataset values and whose third is their mean - confirm in plot_utils.
layer_indices = list(range(50))
best_layer_stats = get_best_train_rmse_layer_stats(PROBE_CSV, layer_indices)
_, probe_values_h5, probe_mean_h5 = best_layer_stats
# Normalise a missing result to an empty list so len() works downstream.
probe_values_h5 = [] if probe_values_h5 is None else probe_values_h5


# === Build plot ===
# Single axes with two groups on the x-axis: 0 => Full, 1 => H5
group_positions = {"Full": 0, "H5": 1}

fig = make_subplots(rows=1, cols=1)


def _mean_bar(group, values, **bar_options):
    """Build a one-bar trace at `group`'s x position.

    The bar height is the mean of `values`, or NaN when `values` is empty.
    Opacity and width are shared by all bars; everything else is forwarded.
    """
    height = float(np.mean(values)) if len(values) > 0 else np.nan
    return go.Bar(
        x=[group_positions[group]],
        y=[height],
        opacity=0.7,
        width=0.3,
        **bar_options,
    )


# Bars. With width 0.3, the offsets centre the Full bar on x=0 and place the
# two H5 bars side by side around x=1 (centres at 0.85 and 1.15).
bar_traces = [
    # Full group: one bar (zero-shot only)
    _mean_bar(
        "Full", zs_full_values,
        name="Evo2-40B (Log-Likelihood)",
        marker_color="#3366CC",
        offset=-0.15,
        showlegend=True,
        legendgroup="zs",
    ),
    # H5 group: zero-shot bar (legend entry already provided by the bar above)
    _mean_bar(
        "H5", zs_h5_values,
        name="Evo2-40B (Log-Likelihood)",
        marker_color="#3366CC",
        offset=-0.3,
        showlegend=False,
        legendgroup="zs",
    ),
    # H5 group: probe bar
    _mean_bar(
        "H5", probe_values_h5,
        name="Evo2-40B (Probe)",
        marker_color="#E377C2",
        offset=0,
        legendgroup="probe",
    ),
]
for bar_trace in bar_traces:
    fig.add_trace(bar_trace, row=1, col=1)

# Individual points drawn over the centre of each bar:
# (values, x position, colour, hover template).
point_specs = [
    (zs_full_values, group_positions["Full"], "#3366CC",
     "Group: Full<br>|ρ|: %{y}<extra></extra>"),
    (zs_h5_values, group_positions["H5"] - 0.15, "#3366CC",
     "Group: H5 (Zero-shot)<br>|ρ|: %{y}<extra></extra>"),
    (probe_values_h5, group_positions["H5"] + 0.15, "#E377C2",
     "Group: H5 (Probe)<br>|ρ|: %{y}<extra></extra>"),
]
for point_values, x_center, point_color, hover_text in point_specs:
    if len(point_values) == 0:
        continue  # no points for this bar
    fig.add_trace(
        go.Scatter(
            x=[x_center] * len(point_values),
            y=point_values,
            mode="markers",
            marker=dict(color=point_color, size=8, opacity=0.65, line=dict(color="white", width=0.5)),
            showlegend=False,
            hovertemplate=hover_text,
        ),
        row=1, col=1,
    )

# Styling shared by both axes.
axis_frame = dict(showline=True, linecolor="black", mirror=True)

# x-axis: one labelled tick per group position.
tickvals = [group_positions[group_key] for group_key in ("Full", "H5")]
ticktext = ["BioRiskEval-Mut", "BioRiskEval-Mut-Probe"]
fig.update_xaxes(
    tickmode="array",
    tickvals=tickvals,
    ticktext=ticktext,
    tickangle=0,
    tickfont=dict(size=25, family="Arial"),
    **axis_frame,
)

# y-axis: fixed range with a tick every 0.1.
y_min, y_max = -0.02, 0.52
fig.update_yaxes(
    title_text='|ρ|',
    range=[y_min, y_max],
    dtick=0.1,
    ticks='outside',
    gridcolor='lightgrey',
    zerolinecolor='lightgrey',
    showgrid=True,
    tickfont=dict(size=25, family='Arial'),
    **axis_frame,
)

# Layout styling: the horizontal legend is positioned inside the plot area.
legend_style = dict(
    orientation="h",
    yanchor="bottom",
    y=0.72,
    xanchor="center",
    x=0.34,
    bgcolor='rgba(255,255,255,0.8)'
)
fig.update_layout(
    template='plotly_white',
    width=600,
    height=400,
    showlegend=True,
    legend=legend_style,
    margin=dict(t=0, b=0, l=60, r=0),
    font=dict(family='Arial, sans-serif', size=25),
    plot_bgcolor='white',
    bargap=0.1,
    bargroupgap=0.1,
)

# Render the figure inline.
fig.show()

# Export as SVG, then convert the SVG to PDF with Inkscape.
# NOTE(review): an earlier cell in this notebook writes the same
# "dms_40b.svg"/"dms_40b.pdf" names, so running this cell overwrites that
# figure - confirm this is intended.
out_file_svg, out_file_pdf = "dms_40b.svg", "dms_40b.pdf"
fig.write_image(out_file_svg)

import subprocess
inkscape_command = ["inkscape", out_file_svg, "--export-pdf=" + out_file_pdf]
subprocess.run(inkscape_command)





CompletedProcess(args=['inkscape', 'dms_40b.svg', '--export-pdf=dms_40b.pdf'], returncode=0)