In [4]:
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio

from const import DATA_DIR

# ============================================================
# PLOTLY TEMPLATE
# ============================================================
# Every figure created after this point defaults to the dark theme.
pio.templates.default = "plotly_dark"

# ============================================================
# READ THE CSV
# ============================================================
# The dataset folder and the CSV inside it share the same name.
dataset_slug = "number-of-countries-that-approve-of-nuclear-weapons-treaties"
csv_path = DATA_DIR / "nuclear-weapons" / dataset_slug / f"{dataset_slug}.csv"

# Trim stray whitespace from every header first, then lower-case the two
# identifier columns so the rest of the notebook can refer to "year"/"entity".
df = (
    pd.read_csv(csv_path)
    .rename(columns=str.strip)
    .rename(columns={"Year": "year", "Entity": "entity"})
)

# ============================================================
# SHORT TITLES
# ============================================================
# Full treaty name (matched as a substring of the CSV column header)
# -> abbreviation used as the column name, legend entry and dict key below.
title_map = {
    "Partial Test Ban Treaty": "PTBT",
    "Nuclear Non-Proliferation Treaty": "NPT",
    "Comprehensive Nuclear-Test-Ban Treaty": "CTBT",
    "Treaty on the Prohibition of Nuclear Weapons": "TPNW"
}

# Collect every rename first and apply them in a single call instead of
# reassigning `df` while looping over its columns.
# Each abbreviation is assigned to at most one column (the first match wins):
# two columns renamed to the same label would give `df` duplicate column
# names and make the later `df[["year", col]]` selection ambiguous.
column_renames = {}
for col in df.columns:
    for key, short in title_map.items():
        if key in col and short not in column_renames.values():
            column_renames[col] = short
            break  # a column can stand for only one treaty

df = df.rename(columns=column_renames)

# Abbreviations that were actually found, in the CSV's column order.
treaty_cols = list(column_renames.values())

# ============================================================
# EVENT YEARS
# ============================================================
# One reference year per treaty abbreviation; the plotting section draws a
# dotted vertical marker at each of these years.
event_years = dict(PTBT=1963, NPT=1968, CTBT=1996, TPNW=2017)

# ============================================================
# RESHAPE TO LONG FORMAT
# ============================================================
# Not referenced anywhere in this cell; kept in case a later cell relies on it.
years = range(1960, 2026)

# For each treaty column: drop rows with a missing year or value, keep the
# first row per year, order chronologically, tag the rows with the treaty
# abbreviation and rename the value column to "count". The per-treaty pieces
# are then stacked into one frame with columns year / count / treaty.
# NOTE(review): if the CSV holds several entities per year, only the first
# row for each year survives drop_duplicates - confirm that is intended.
df_long = pd.concat(
    [
        df[["year", col]]
        .dropna()
        .drop_duplicates(subset=["year"])
        .sort_values("year")
        .assign(treaty=col)
        .rename(columns={col: "count"})
        for col in treaty_cols
    ]
)

# ============================================================
# BUILD THE FIGURE
# ============================================================
fig = go.Figure()

# Line colour for each treaty abbreviation.
colors = dict(
    PTBT="#1f77b4",
    NPT="#2ca02c",
    CTBT="#ff7f0e",
    TPNW="#d62728",
)

# Hover box: treaty abbreviation in bold, then year and country count.
# <extra></extra> suppresses the secondary box with the trace name.
hover_html = (
    "<b>%{text}</b><br>"
    "Year: %{x}<br>"
    "Countries: %{y}<extra></extra>"
)

for treaty in treaty_cols:
    treaty_color = colors[treaty]
    is_this_treaty = df_long["treaty"] == treaty
    treaty_rows = df_long[is_this_treaty]

    # One line per treaty; `text` repeats the abbreviation for every point so
    # the hover template can show it.
    treaty_trace = go.Scatter(
        x=treaty_rows["year"],
        y=treaty_rows["count"],
        mode="lines",
        name=treaty,
        line=dict(width=3, color=treaty_color),
        hovertemplate=hover_html,
        text=[treaty] * len(treaty_rows),
    )
    fig.add_trace(treaty_trace)

    # Dotted vertical marker at the treaty's event year, in the same colour.
    fig.add_vline(
        x=event_years[treaty],
        line=dict(color=treaty_color, dash="dot", width=1.5),
        opacity=0.6,
    )

# Light dotted reference lines along the left edge (x=1960) and the baseline
# (y=0), drawn underneath the traces.
reference_line = dict(
    line=dict(color="#e0e0e0", dash="dot", width=1.5),
    opacity=1,
    layer="below",
)
fig.add_vline(x=1960, **reference_line)
fig.add_hline(y=0, **reference_line)

# ============================================================
# LAYOUT
# ============================================================
# Bold headline plus a smaller grey subtitle on the next line.
title_html = (
    "<b>The Rise of Global Nuclear Safety</b><br>"
    "<span style='font-size:13px;color:#aaa;'>"
    "Number of countries approving major nuclear treaties over time"
    "</span>"
)

# Settings shared by both axes: dotted dark-grey grid, no zero line or axis
# line, zoom/pan disabled, white tick labels.
axis_style = dict(
    showgrid=True,
    gridcolor="#444",
    griddash="dot",
    fixedrange=True,
    zeroline=False,
    showline=False,
    linecolor="#444",
    linewidth=1,
    tickfont=dict(color="#fff"),
)

fig.update_layout(
    title=dict(text=title_html, x=0.02),
    width=1000,
    height=600,
    plot_bgcolor="#222",
    paper_bgcolor="#222",
    font=dict(size=13, color="#fff"),
    # The x axis is untitled and pinned to 1960-2025.
    xaxis=dict(title=None, range=[1960, 2025], **axis_style),
    yaxis=dict(title="Number of signatory countries", **axis_style),
    # Horizontal legend anchored by its top-right corner at the right edge.
    legend=dict(
        title="Treaty",
        orientation="h",
        y=1.07,
        x=1,
        xanchor="right",
        yanchor="top",
        font=dict(color="#fff"),
    ),
)

fig.show()

In [5]:
from const import VISUALIZATIONS_DIR

# Export the figure as a standalone HTML file.
# Assumes VISUALIZATIONS_DIR is a pathlib.Path (it is joined with `/` here).
output_path = VISUALIZATIONS_DIR / "the_rise_of_global_nuclear_safety.html"

# Create the output directory first so the export also works on a fresh
# checkout where the directory does not exist yet; without this the write
# fails with FileNotFoundError.
output_path.parent.mkdir(parents=True, exist_ok=True)

fig.write_html(output_path)