# NIMBY Rails Network Analysis

Visual exploration of a NIMBY Rails transit network built from a database created by [nimby2sql](https://github.com/rlvelte/nimby2sql).

**Usage:** Set `DB_PATH` below to your `.db` file. If no file exists at that path, the notebook generates a sample database and uses that instead.

In [None]:
import sqlite3, os, math, random, textwrap
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import folium
from folium.plugins import HeatMap
import plotly.express as px
import plotly.graph_objects as go

# Global look & feel shared by every chart in the notebook.
sns.set_theme(palette="viridis", style="darkgrid")
plt.rcParams.update({"figure.dpi": 120})

# ---------- CONFIGURATION ----------
# SQLite file exported by nimby2sql.
DB_PATH = "nimby_rails.db"  # <-- change to your real DB path
# Fall back to generated demo data whenever that file is missing.
USE_SAMPLE = not os.path.exists(DB_PATH)
# -----------------------------------

In [None]:
def _build_sample_db(path="_sample_nimby.db"):
    """Generate a realistic sample NIMBY Rails DB for demo purposes."""
    if os.path.exists(path):
        return path

    random.seed(42)
    np.random.seed(42)

    # Stations across Germany
    base_stations = [
        ("Berlin Hbf", 13.369, 52.525),
        ("München Hbf", 11.560, 48.140),
        ("Hamburg Hbf", 10.007, 53.553),
        ("Frankfurt Hbf", 8.663, 50.107),
        ("Köln Hbf", 6.959, 50.943),
        ("Stuttgart Hbf", 9.182, 48.784),
        ("Düsseldorf Hbf", 6.794, 51.220),
        ("Leipzig Hbf", 12.382, 51.345),
        ("Dresden Hbf", 13.733, 51.040),
        ("Hannover Hbf", 9.741, 52.377),
        ("Nürnberg Hbf", 11.082, 49.446),
        ("Bremen Hbf", 8.814, 53.083),
        ("Dortmund Hbf", 7.460, 51.517),
        ("Essen Hbf", 7.014, 51.451),
        ("Mannheim Hbf", 8.470, 49.480),
        ("Karlsruhe Hbf", 8.402, 48.994),
        ("Augsburg Hbf", 10.886, 48.366),
        ("Freiburg Hbf", 7.841, 47.997),
        ("Erfurt Hbf", 11.039, 50.972),
        ("Rostock Hbf", 12.131, 54.078),
        ("Kassel-Wilhelmshöhe", 9.447, 51.313),
        ("Würzburg Hbf", 9.936, 49.802),
        ("Kiel Hbf", 10.131, 54.315),
        ("Ulm Hbf", 9.983, 48.399),
        ("Magdeburg Hbf", 11.627, 52.131),
        ("Regensburg Hbf", 12.100, 49.012),
        ("Aachen Hbf", 6.089, 50.768),
        ("Mainz Hbf", 8.259, 50.001),
        ("Potsdam Hbf", 13.066, 52.392),
        ("Lübeck Hbf", 10.670, 53.868),
    ]
    # Add smaller stations with jittered coords
    extra = []
    for i in range(40):
        base = random.choice(base_stations)
        name = f"{base[0]} {random.choice(['Nord','Süd','Ost','West','Park','Hafen'])} {i}"
        lon = base[1] + random.gauss(0, 0.25)
        lat = base[2] + random.gauss(0, 0.12)
        extra.append((name, lon, lat))
    all_stations = base_stations + extra

    conn = sqlite3.connect(path)
    c = conn.cursor()
    c.executescript(textwrap.dedent("""
        CREATE TABLE stations (
            station_id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            lon REAL NOT NULL,
            lat REAL NOT NULL
        );
        CREATE TABLE lines (
            line_id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            code TEXT NOT NULL,
            color TEXT
        );
        CREATE TABLE line_stops (
            line_id TEXT NOT NULL,
            stop_index INTEGER NOT NULL,
            station_id TEXT NOT NULL,
            arrival_s INTEGER NOT NULL,
            departure_s INTEGER NOT NULL,
            leg_distance_m REAL NOT NULL,
            PRIMARY KEY (line_id, stop_index)
        );
        CREATE VIEW line_stops_enriched AS
        SELECT ls.line_id, l.name AS line_name, l.code AS line_code,
               l.color AS line_color, ls.stop_index, ls.station_id,
               s.name AS station_name, s.lon, s.lat,
               ls.arrival_s, ls.departure_s, ls.leg_distance_m
        FROM line_stops ls
        JOIN lines l ON l.line_id = ls.line_id
        JOIN stations s ON s.station_id = ls.station_id;
    """))

    sid_map = {}
    for i, (name, lon, lat) in enumerate(all_stations):
        sid = hex(0x1000 + i)
        sid_map[i] = sid
        c.execute("INSERT INTO stations VALUES (?,?,?,?)", (sid, name, lon, lat))

    colors = ["#e6194b","#3cb44b","#4363d8","#f58231","#911eb4",
              "#42d4f4","#f032e6","#bfef45","#fabed4","#469990",
              "#dcbeff","#9A6324","#800000","#aaffc3","#808000"]
    n_lines = 15
    for li in range(n_lines):
        lid = hex(0x5000 + li)
        name = f"Line {li+1}"
        code = f"L{li+1:02d}"
        c.execute("INSERT INTO lines VALUES (?,?,?,?)",
                  (lid, name, code, colors[li % len(colors)]))
        # pick 5-12 stations for this line
        n_stops = random.randint(5, 12)
        chosen = random.sample(range(len(all_stations)), n_stops)
        # sort by longitude for a vaguely geographic order
        chosen.sort(key=lambda idx: all_stations[idx][1])
        t = random.randint(0, 3600)
        for si, idx in enumerate(chosen):
            leg = random.uniform(2000, 40000)
            speed = random.uniform(60, 160)  # km/h
            travel = leg / (speed * 1000 / 3600)
            t += int(travel)
            arr = t
            dwell = random.randint(30, 180)
            dep = arr + dwell
            t = dep
            c.execute("INSERT INTO line_stops VALUES (?,?,?,?,?,?)",
                      (lid, si, sid_map[idx], arr, dep, round(leg, 1)))

    conn.commit()
    conn.close()
    return path

if not USE_SAMPLE:
    print(f"Using real DB: {DB_PATH}")
else:
    # Always rebuild the demo DB from scratch so generator changes show up.
    if os.path.exists("_sample_nimby.db"):
        os.remove("_sample_nimby.db")
    DB_PATH = _build_sample_db()
    print(f"Using generated sample DB: {DB_PATH}")

# Connection stays open until the final cell closes it.
conn = sqlite3.connect(DB_PATH)
df, df_stations, df_lines = (
    pd.read_sql(query, conn)
    for query in (
        "SELECT * FROM line_stops_enriched",
        "SELECT * FROM stations",
        "SELECT * FROM lines",
    )
)
print(f"Loaded {len(df_stations)} stations, {len(df_lines)} lines, {len(df)} stop records")

## 1. Station Density Heatmap (Interactive Map)
An interactive Folium map showing station density as a heatmap. Busier areas (more stations & lines) glow hotter.

In [None]:
# Per-station count of distinct lines stopping there. It drives the heat
# weight and the marker size, and later sections reuse it.
station_weight = (
    df.groupby(["station_id", "station_name", "lat", "lon"])["line_id"]
    .nunique()
    .rename("lines_serving")
    .reset_index()
)

# Map extent: bounding box of every served station, centred on its midpoint.
lats, lons = station_weight["lat"], station_weight["lon"]
sw = [lats.min(), lons.min()]
ne = [lats.max(), lons.max()]
center_lat, center_lon = (sw[0] + ne[0]) / 2, (sw[1] + ne[1]) / 2

m = folium.Map(location=[center_lat, center_lon], tiles="CartoDB dark_matter")
m.fit_bounds([sw, ne], padding=[20, 20])

# Heat layer: one [lat, lon, weight] triple per station.
heat_data = station_weight[["lat", "lon", "lines_serving"]].values.tolist()
HeatMap(
    heat_data,
    radius=20,
    blur=15,
    max_zoom=13,
    gradient={0.2: "blue", 0.4: "lime", 0.6: "yellow", 1: "red"},
).add_to(m)

# Station dots, sized by how many lines serve them.
for stop in station_weight.itertuples(index=False):
    folium.CircleMarker(
        [stop.lat, stop.lon],
        radius=2 + stop.lines_serving * 1.5,
        color="white",
        fill=True,
        fill_opacity=0.7,
        tooltip=f"{stop.station_name} ({int(stop.lines_serving)} lines)",
    ).add_to(m)

m

## 2. Top Hub Stations
Stations served by the most distinct lines - the major interchange hubs of the network.

In [None]:
# The 20 stations served by the most distinct lines.
top_hubs = station_weight.nlargest(20, "lines_serving")

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.barh(
    top_hubs["station_name"],
    top_hubs["lines_serving"],
    color=sns.color_palette("magma", len(top_hubs)),
)
ax.set(xlabel="Lines Serving", title="Top 20 Hub Stations by Line Count")
ax.invert_yaxis()  # biggest hub at the top

# Write the exact count just past the end of each bar.
for rect, count in zip(bars, top_hubs["lines_serving"]):
    y_mid = rect.get_y() + rect.get_height() / 2
    ax.text(rect.get_width() + 0.1, y_mid, str(int(count)), va="center", fontsize=9)

plt.tight_layout()
plt.show()

## 3. Dwell Time Analysis
How long do trains stop at each station? Distribution of dwell times (departure - arrival) per line.

In [None]:
# Time a train spends standing at each stop.
df["dwell_s"] = df["departure_s"] - df["arrival_s"]

# Per-line dwell summary. Its index (longest mean dwell first) also fixes
# the row order of the box plot.
dwell_by_line = (
    df.groupby("line_name")["dwell_s"]
    .agg(["min", "mean", "max", "count"])
    .round(1)
    .sort_values("mean", ascending=False)
)
overall_mean_dwell = df["dwell_s"].mean()

fig, (ax_box, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))

# Left: one box per line.
sns.boxplot(data=df, x="dwell_s", y="line_name",
            order=dwell_by_line.index, ax=ax_box, palette="coolwarm")
ax_box.set(xlabel="Dwell Time (seconds)", ylabel="",
           title="Dwell Time Distribution per Line")

# Right: every dwell in the network, with the overall mean marked.
ax_hist.hist(df["dwell_s"], bins=30, color="#4363d8", edgecolor="white", alpha=0.85)
ax_hist.axvline(overall_mean_dwell, color="red", ls="--",
                label=f"Mean: {overall_mean_dwell:.0f}s")
ax_hist.set(xlabel="Dwell Time (seconds)", ylabel="Frequency",
            title="Overall Dwell Time Distribution")
ax_hist.legend()

plt.tight_layout()
plt.show()

## 4. Station-Line Connectivity Heatmap
Which stations are served by which lines? A heatmap showing the connectivity matrix.

In [None]:
# Count stops per (station, line) pair; absent pairs become 0.
pivot = df.pivot_table(
    index="station_name",
    columns="line_name",
    values="stop_index",
    aggfunc="count",
    fill_value=0,
)
# Collapse counts to a served (1) / not served (0) indicator.
pivot_binary = pivot.gt(0).astype(int)

# Most-connected stations first, then lines serving the most stations first.
station_rank = pivot_binary.sum(axis=1).sort_values(ascending=False).index
pivot_binary = pivot_binary.loc[station_rank]
line_rank = pivot_binary.sum(axis=0).sort_values(ascending=False).index
pivot_binary = pivot_binary[line_rank]

# Only the 30 best-connected stations are charted.
display_pivot = pivot_binary.head(30)

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    display_pivot,
    cmap="YlOrRd",
    linewidths=0.3,
    linecolor="gray",
    cbar_kws={"label": "Served (1) / Not (0)"},
    xticklabels=True,
    yticklabels=True,
    ax=ax,
)
ax.set(title="Station-Line Connectivity Matrix (Top 30 Stations)",
       xlabel="Line", ylabel="Station")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
plt.tight_layout()
plt.show()

## 5. Network Map with Line Routes
Interactive map showing each line as a colored polyline connecting its stations in order.

In [None]:
m2 = folium.Map(location=[center_lat, center_lon], tiles="CartoDB positron")
m2.fit_bounds([sw, ne], padding=[20, 20])

lines_by_id = df_lines.set_index("line_id")
line_colors = lines_by_id["color"].to_dict()
line_names = lines_by_id["name"].to_dict()

# One polyline per line through its stops in stop_index order; lines are
# visited in order of first appearance in df (sort=False).
for lid, stops in df.groupby("line_id", sort=False):
    coords = stops.sort_values("stop_index")[["lat", "lon"]].values.tolist()
    if len(coords) < 2:
        continue  # a single stop has no route to draw
    folium.PolyLine(
        coords,
        color=line_colors.get(lid, "#333333"),
        weight=3,
        opacity=0.8,
        tooltip=line_names.get(lid, lid),
    ).add_to(m2)

# Every station in the DB, whether or not a line serves it.
for lat, lon, station_name in zip(df_stations["lat"], df_stations["lon"], df_stations["name"]):
    folium.CircleMarker(
        [lat, lon],
        radius=4,
        color="#333",
        fill=True,
        fill_color="white",
        fill_opacity=0.9,
        weight=1,
        tooltip=station_name,
    ).add_to(m2)

m2

## 6. Leg Distance Distribution
How far apart are consecutive stops on each line?

In [None]:
df["leg_km"] = df["leg_distance_m"] / 1000
median_leg_km = df["leg_km"].median()

fig, (ax_violin, ax_kde) = plt.subplots(1, 2, figsize=(14, 5))

# Left: one violin per line, the line with the longest median leg on top.
order = (
    df.groupby("line_name")["leg_km"]
    .median()
    .sort_values(ascending=False)
    .index
)
sns.violinplot(data=df, x="leg_km", y="line_name", order=order,
               ax=ax_violin, palette="mako", inner="quart", linewidth=0.8)
ax_violin.set(xlabel="Leg Distance (km)", ylabel="",
              title="Leg Distance Distribution per Line")

# Right: density of every leg in the network, with the median marked.
sns.kdeplot(df["leg_km"], fill=True, ax=ax_kde, color="#4363d8", alpha=0.6)
ax_kde.axvline(median_leg_km, color="red", ls="--",
               label=f"Median: {median_leg_km:.1f} km")
ax_kde.set(xlabel="Leg Distance (km)", title="Overall Leg Distance Density")
ax_kde.legend()

plt.tight_layout()
plt.show()

## 7. Implied Speed Between Stops
Estimated average speed between consecutive stops (leg distance / travel time). Helps spot slow segments.

In [None]:
# Implied speed on each leg between consecutive stops of a line.
# The speed calculation treats leg_distance_m on a stop row as the distance
# travelled to *reach* that stop. The leg leaving stop i is therefore
# described by stop i+1's leg_distance_m and timed from stop i's departure
# to stop i+1's arrival. The x-axis must use that same next-leg distance;
# the row's own leg_km is the leg *into* stop i, and plotting against it
# would pair every speed with the wrong distance.
speed_df = df.sort_values(["line_id", "stop_index"]).copy()
next_stop = (
    speed_df.groupby("line_id")[["arrival_s", "leg_distance_m", "station_name"]]
    .shift(-1)
)
speed_df["next_arrival"] = next_stop["arrival_s"]
speed_df["next_leg_m"] = next_stop["leg_distance_m"]
speed_df["next_station"] = next_stop["station_name"]
speed_df["travel_s"] = speed_df["next_arrival"] - speed_df["departure_s"]

# Drop each line's last stop (no next leg) and non-positive travel times,
# which would yield infinite or negative speeds. assign() on the filtered
# frame avoids SettingWithCopyWarning.
speed_df = speed_df.dropna(subset=["travel_s", "next_leg_m"])
speed_df = speed_df[speed_df["travel_s"] > 0].assign(
    next_leg_km=lambda d: d["next_leg_m"] / 1000,
    speed_kmh=lambda d: d["next_leg_km"] / (d["travel_s"] / 3600),
)
speed_df = speed_df[speed_df["speed_kmh"] < 500]  # filter unrealistic

fig = px.scatter(
    speed_df, x="next_leg_km", y="speed_kmh",
    color="line_name", hover_data=["station_name", "next_station"],
    title="Implied Speed vs Leg Distance",
    labels={"next_leg_km": "Leg Distance (km)", "speed_kmh": "Speed (km/h)",
            "line_name": "Line", "station_name": "From", "next_station": "To"},
    opacity=0.7
)
fig.update_layout(height=500)
fig.show()

## 8. Line-Station Sunburst
Interactive sunburst showing lines in the inner ring and their stations in the outer ring. Size = number of stops.

In [None]:
# One row per (line, station) pair with how often that line stops there.
sun_df = (
    df.groupby(["line_name", "station_name"])
    .size()
    .reset_index(name="stops")
)

fig = px.sunburst(
    sun_df,
    path=["line_name", "station_name"],
    values="stops",
    title="Network Structure: Lines and Stations",
    color="stops",
    color_continuous_scale="Viridis",
)
fig.update_layout(height=600)
fig.show()

## 9. Timetable Timeline (Gantt-style)
Each line is shown as a bar spanning its earliest arrival to its latest departure, visualising its span of service.

In [None]:
# Service window per line: earliest arrival to latest departure over all of
# its stops, plus the number of stops it makes; earliest-starting first.
timeline = (
    df.groupby("line_name")
    .agg(
        start=("arrival_s", "min"),
        end=("departure_s", "max"),
        n_stops=("stop_index", "count"),
    )
    .reset_index()
    .sort_values("start")
)

# Convert to hours:minutes for display
def s_to_hm(s):
    """Format a seconds count as ``H:MM``.

    The value is truncated to whole seconds first. Hours are not wrapped,
    so 25 h 1 min is rendered as ``"25:01"``.
    """
    total_minutes = int(s) // 60
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours}:{minutes:02d}"

fig, ax = plt.subplots(figsize=(12, 5))
colors = sns.color_palette("husl", len(timeline))
rows = zip(timeline["line_name"], timeline["start"], timeline["end"], colors)
# Line names form a categorical y-axis in order of first use, so the i-th
# bar sits at y == i; that is where its time-range label is written.
for y, (line, t_start, t_end, colour) in enumerate(rows):
    ax.barh(line, t_end - t_start, left=t_start,
            color=colour, edgecolor="white", height=0.6)
    ax.text(t_end + 50, y, f"{s_to_hm(t_start)}\u2013{s_to_hm(t_end)}",
            va="center", fontsize=8)

ax.set_xlabel("Time (seconds from midnight)")
ax.set_title("Service Span per Line (Gantt-style Timeline)")
plt.tight_layout()
plt.show()

## 10. Total Route Length per Line
Sum of all leg distances for each line - which lines cover the most ground?

In [None]:
# Kilometres covered by each line (sum of its legs), sorted ascending.
route_len = (
    df.groupby("line_name")["leg_distance_m"]
    .sum()
    .div(1000)
    .sort_values()
    .rename("total_km")
    .reset_index()
)

fig = px.bar(
    route_len,
    x="total_km",
    y="line_name",
    orientation="h",
    color="total_km",
    color_continuous_scale="Turbo",
    title="Total Route Length per Line",
    labels={"total_km": "Total Distance (km)", "line_name": ""},
)
fig.update_layout(height=450, showlegend=False)
fig.show()

## 11. Correlation Heatmap of Numeric Features
How do arrival time, departure time, dwell time, leg distance, and station coordinates (lat/lon) relate to each other?

In [None]:
# Numeric columns to compare (dwell_s is added in the dwell-time section).
num_cols = ["arrival_s", "departure_s", "dwell_s", "leg_distance_m", "lat", "lon"]
corr = df[num_cols].corr()

# Hide the redundant upper triangle; the diagonal stays visible.
mask = np.triu(np.ones(corr.shape, dtype=bool), k=1)

fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(
    corr,
    mask=mask,
    annot=True,
    fmt=".2f",
    cmap="RdBu_r",
    center=0,
    square=True,
    linewidths=0.5,
    ax=ax,
)
ax.set_title("Feature Correlation Matrix")
plt.tight_layout()
plt.show()

## 12. Network Summary Statistics

In [None]:
total_km = df["leg_distance_m"].sum() / 1000
avg_stops = df.groupby("line_id")["stop_index"].count().mean()
avg_lines_per_station = station_weight["lines_serving"].mean()
max_hub = station_weight.loc[station_weight["lines_serving"].idxmax()]

# (metric, value) pairs in display order.
metrics = [
    ("Total Stations", len(df_stations)),
    ("Total Lines", len(df_lines)),
    ("Total Stop Records", len(df)),
    ("Total Network Distance (km)", f"{total_km:,.1f}"),
    ("Avg Stops per Line", f"{avg_stops:.1f}"),
    ("Avg Lines per Station", f"{avg_lines_per_station:.1f}"),
    ("Biggest Hub", f"{max_hub['station_name']} ({int(max_hub['lines_serving'])} lines)"),
    ("Mean Dwell Time (s)", f"{df['dwell_s'].mean():.0f}"),
    ("Median Leg Distance (km)", f"{df['leg_km'].median():.1f}"),
]
summary = pd.DataFrame(metrics, columns=["Metric", "Value"])
summary.style.hide(axis='index').set_properties(**{'text-align': 'left'})

In [None]:
# Release the SQLite connection opened when the data was loaded.
conn.close()
print("Done! Connection closed.")