In [1]:
import plotly.graph_objects as go
import pandas as pd
import numpy as np

# ==================== LOAD DATA ====================
from const import DATA_DIR
proliferation_df = pd.read_csv(DATA_DIR / 'nuclear-weapons' / 'nuclear-weapons-proliferation' / 'nuclear-weapons-proliferation.csv')
world = proliferation_df[proliferation_df['Entity'] == 'World'].copy()

# The three status counts shown in the Sankey, in pipeline order.
metrics = [
    'Number of countries considering nuclear weapons',
    'Number of countries pursuing nuclear weapons',
    'Number of countries possessing nuclear weapons'
]

# Chronological order is required for the nearest-year lookup below.
world = world.sort_values('Year')
years = world['Year'].values
values = world[metrics].values
assert len(years) > 0, "no 'World' rows found in the proliferation dataset"

# Key years become the Sankey columns (roughly one per decade + the latest).
# Each requested year snaps to the nearest year present in the data. When a
# requested year lies outside the data range, two requests can snap to the
# same row (e.g. if the data ends in 2019, both 2020 and 2023 map to 2019).
# np.unique drops those duplicates so no empty, repeated column is drawn, and
# it also keeps the columns in chronological order.
key_years = [1945, 1950, 1960, 1970, 1980, 1990, 2000, 2010, 2020, 2023]  # adjust to your data range
key_idx = np.unique([np.argmin(np.abs(years - y)) for y in key_years])
key_years_actual = years[key_idx]

# Rows of (considering, pursuing, possessing) counts, one row per key year.
data_at_keys = values[key_idx]
assert not pd.isna(data_at_keys).any(), "missing status counts at a key year"

# ==================== BUILD SANKEY NODES & LINKS ====================
# Link buffers; the flow-estimation loop that follows fills them in.
links_source = []
links_target = []
links_value = []
links_color = []

# One hex colour per status, shared by the Sankey nodes and the legend markers.
color_map = {
    'considering': '#FFD93D',   # Gold
    'pursuing': '#FF6B6B',      # Coral red
    'possessing': '#FF3B30',    # Deep blood red
}

# Alpha applied to the "stay in the same status" links of each status.
opacity_map = {
    'considering': 0.5,
    'pursuing': 0.7,
    'possessing': 0.9,
}

# Nodes are laid out period-major: node k*3 + j is status j at key year k,
# i.e. [Considering_0, Pursuing_0, Possessing_0, Considering_1, ...].
# Labels stay blank on purpose; the years are drawn as annotations later.
node_statuses = ['considering', 'pursuing', 'possessing'] * len(key_years_actual)
labels = ["" for _ in node_statuses]
node_colors = [color_map[node_status] for node_status in node_statuses]

# Create flows between consecutive periods.
# NOTE(review): the dataset only gives yearly counts, not country-level
# transitions, so every flow below is an estimate. It presumably treats the
# three counts as mutually exclusive statuses — confirm against the source.
# Each estimate is capped so a node never sends out more countries than it
# holds now, nor receives more than it holds next period. Countries entering
# from, or dropping back to, "no interest" are simply not drawn.
def _add_link(source, target, value, color):
    """Append one Sankey link to the link buffers, skipping empty flows."""
    if value > 0:
        links_source.append(source)
        links_target.append(target)
        links_value.append(value)
        links_color.append(color)


for period in range(len(key_years_actual) - 1):
    base_idx = period * 3  # index of this period's "considering" node

    c1, p1, s1 = data_at_keys[period]      # current period
    c2, p2, s2 = data_at_keys[period + 1]  # next period

    # Possessing -> Possessing: as many as both periods allow (the count can
    # shrink when a country gives up its arsenal).
    stay_possessing = min(s1, s2)
    # Pursuing -> Possessing: net new possessors, capped by current pursuers.
    p_to_s = min(p1, max(0, s2 - s1))
    # Pursuing -> Pursuing: remaining pursuers, capped by next period's count.
    stay_pursuing = min(p1 - p_to_s, p2)
    # Considering -> Pursuing: fills the rest of next period's pursuers.
    c_to_p = min(c1, max(0, p2 - stay_pursuing))
    # Considering -> Considering: remaining considerers, capped by next count.
    stay_considering = min(c1 - c_to_p, c2)

    _add_link(base_idx + 0, base_idx + 3, stay_considering,
              f"rgba(255,217,61,{opacity_map['considering']})")
    _add_link(base_idx + 0, base_idx + 4, c_to_p, "rgba(255,170,100,0.8)")
    _add_link(base_idx + 1, base_idx + 5, p_to_s, "rgba(255,59,48,0.9)")
    _add_link(base_idx + 1, base_idx + 4, stay_pursuing,
              f"rgba(255,107,107,{opacity_map['pursuing']})")
    _add_link(base_idx + 2, base_idx + 5, stay_possessing,
              f"rgba(255,59,48,{opacity_map['possessing']})")

# ==================== FINAL FIGURE ====================
fig_prolif = go.Figure(data=[go.Sankey(
    arrangement="snap",
    node=dict(
        pad=30,
        thickness=20,
        line=dict(color="white", width=0),
        label=labels,
        color=node_colors,
        hovertemplate='%{value}<extra></extra>'
    ),
    link=dict(
        source=links_source,
        target=links_target,
        value=links_value,
        color=links_color,
        hovertemplate='%{value}<extra></extra>'
    )
)])

# Sankey traces have no legend of their own, so add one empty scatter trace
# per status purely to get a coloured legend entry.
for status in ['considering', 'pursuing', 'possessing']:
    fig_prolif.add_trace(go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(size=10, color=color_map[status]),
        name=status.capitalize(),
        showlegend=True
    ))

# Timeline annotations: one year label per Sankey column, spread evenly across
# the paper width to approximate where plotly places the columns.
n_columns = len(key_years_actual)
annotations = []
for i, year in enumerate(key_years_actual):
    # Normalised x position in [0, 1]; a lone column is centred instead of
    # dividing by zero.
    x_pos = i / (n_columns - 1) if n_columns > 1 else 0.5
    annotations.append(dict(
        x=x_pos,
        y=-0.05,  # Position below the plot
        xref="paper",
        yref="paper",
        text=str(year),
        showarrow=False,
        font=dict(size=12, color="#ddd"),
        xanchor="center",
        yanchor="top"
    ))

fig_prolif.update_layout(
    title_text="<b>The Irreversible Path: Nuclear Proliferation as a One-Way Journey</b>",
    title_font_size=26,
    title_font_color="#ffffff",
    font=dict(family="Arial, sans-serif", size=14, color="#ddd"),
    paper_bgcolor="#0b0e17",
    plot_bgcolor="rgba(0,0,0,0)",
    margin=dict(l=40, r=40, t=100, b=80),
    annotations=annotations,
    # The legend-only scatter traces create cartesian axes; hide and freeze
    # them so they do not draw over or interfere with the Sankey.
    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, fixedrange=True),
    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, fixedrange=True),
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1
    )
)

fig_prolif.show()

In [2]:
from const import VISUALIZATIONS_DIR

# Create the output folder on a fresh checkout so the HTML export cannot fail
# on a missing directory. Assumes VISUALIZATIONS_DIR is a pathlib.Path (it is
# joined with "/" below) — TODO confirm in const.
VISUALIZATIONS_DIR.mkdir(parents=True, exist_ok=True)
fig_prolif.write_html(VISUALIZATIONS_DIR / "nuclear-proliferation.html")