In [4]:
import pandas as pd
from pathlib import Path
import plotly.graph_objects as go

# ============================================================
# LOAD DATA
# ============================================================
# The CSV sits in a sub-folder of DATA_DIR; header names are whitespace-
# trimmed and the two key columns are lower-cased for the rest of the notebook.
DATA_DIR = Path("..") / "data" / "nuclear-weapons"
csv_folder = "number-of-countries-that-approve-of-nuclear-weapons-treaties"
csv_filename = "number-of-countries-that-approve-of-nuclear-weapons-treaties.csv"

df = (
    pd.read_csv(DATA_DIR / csv_folder / csv_filename)
    .rename(columns=str.strip)
    .rename(columns={"Year": "year", "Entity": "entity"})
)

# ============================================================
# SHORT TITLES
# ============================================================
# Maps a substring of each raw treaty column header to the compact label
# (acronym + year) used as the trace name, hover text and dict key below.
title_map = {
    "Partial Test Ban Treaty": "PTBT (1963)",
    "Nuclear Non-Proliferation Treaty": "NPT (1968)",
    "Comprehensive Nuclear-Test-Ban Treaty": "CTBT (1996)",
    "Treaty on the Prohibition of Nuclear Weapons": "TPNW (2017)"
}

# Build the whole rename mapping in one pass, then rename once (instead of
# copying the frame on every match). Each raw column maps to at most one
# treaty (first matching key wins). Two raw columns collapsing onto the same
# short label would create duplicate column names that break the per-treaty
# selection further down, so that case fails loudly here.
treaty_cols = []   # short titles, in the order their source columns appear
rename_map = {}    # raw column name -> short title
for col in df.columns:
    short = next((s for key, s in title_map.items() if key in col), None)
    if short is None:
        continue  # not a treaty column; leave it untouched
    if short in treaty_cols:
        raise ValueError(
            f"Column {col!r} maps to {short!r}, which another column already "
            "uses; make the title_map keys more specific."
        )
    rename_map[col] = short
    treaty_cols.append(short)

df = df.rename(columns=rename_map)

# ============================================================
# EVENT YEARS
# ============================================================
# Year at which a dotted vertical marker is drawn for each treaty.
# Keys must match the short titles produced from title_map.
event_years = dict(
    zip(
        ("PTBT (1963)", "NPT (1968)", "CTBT (1996)", "TPNW (2017)"),
        (1963, 1968, 1996, 2017),
    )
)

# ============================================================
# DATA CLEAN
# ============================================================
# NOTE(review): `years` is not referenced by any visible cell; kept only in
# case another cell relies on it.
years = range(1960, 2026)

# Fail with a clear message rather than pd.concat's "No objects to concatenate".
if not treaty_cols:
    raise ValueError(
        "No treaty columns were recognised in the CSV; "
        "check title_map against df.columns."
    )

# Reshape wide -> long: one row per (year, treaty), with the value in "count".
treaty_frames = []
for treaty in treaty_cols:
    treaty_frame = (
        df[["year", treaty]]
        .dropna()
        # NOTE(review): if the CSV holds several entities per year, this keeps
        # whichever row appears first — confirm the file has a single entity.
        .drop_duplicates(subset=["year"])
        .sort_values("year")
        .rename(columns={treaty: "count"})
        .assign(treaty=treaty)
    )
    treaty_frames.append(treaty_frame)

# ignore_index: the per-treaty slices share source row labels; give the long
# frame a clean, unique RangeIndex instead of duplicated labels.
df_long = pd.concat(treaty_frames, ignore_index=True)

# ============================================================
# PLOT
# ============================================================
# One line per treaty, plus a dotted vertical marker in the same colour at
# the treaty's entry in event_years.
fig = go.Figure()

colors = {
    "PTBT (1963)": "#1f77b4",
    "NPT (1968)": "#2ca02c",
    "CTBT (1996)": "#ff7f0e",
    "TPNW (2017)": "#d62728"
}

for treaty in treaty_cols:
    series = df_long.loc[df_long["treaty"] == treaty]
    tint = colors[treaty]

    fig.add_trace(
        go.Scatter(
            x=series["year"],
            y=series["count"],
            mode="lines",
            name=treaty,
            line={"width": 3, "color": tint},
            # %{text} carries the treaty label so the tooltip can bold it.
            text=[treaty] * len(series),
            hovertemplate=(
                "<b>%{text}</b><br>"
                "Year: %{x}<br>"
                "Countries: %{y}<extra></extra>"
            ),
        )
    )

    # Dotted marker at the treaty's event year, matching the trace colour.
    fig.add_vline(
        x=event_years[treaty],
        line={"color": tint, "dash": "dot", "width": 1.5},
        opacity=0.6,
    )

# ============================================================
# LAYOUT
# ============================================================
# Both axes share the same light dotted grid.
grid_style = {"showgrid": True, "gridcolor": "#e0e0e0", "griddash": "dot"}

# Bold headline with a smaller grey subtitle underneath.
title_html = (
    "<b>The Rise of Global Nuclear Safety</b><br>"
    "<span style='font-size:13px;color:#666;'>"
    "Number of countries approving major nuclear treaties over time"
    "</span>"
)

fig.update_layout(
    title={"text": title_html, "x": 0.02},
    width=1100,
    height=600,
    plot_bgcolor="white",
    paper_bgcolor="white",
    font={"size": 13},
    xaxis={"title": "Year", "range": [1960, 2025], **grid_style},
    yaxis={"title": "Number of signatory countries", **grid_style},
    # Horizontal legend positioned under the plotting area (negative y).
    legend={"title": "Treaty", "orientation": "h", "y": -0.25, "x": 0.02},
)


fig


In [5]:
from const import VISUALIZATIONS_DIR

# Export the interactive figure as HTML. Create the output directory first
# so a fresh checkout doesn't fail with FileNotFoundError; Path() accepts
# either a str or a Path for VISUALIZATIONS_DIR.
output_dir = Path(VISUALIZATIONS_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
fig.write_html(output_dir / "the_rise_of_global_nuclear_safety.html")