<a href="https://colab.research.google.com/gist/pouyahosseinzadeh/3f6b122fb4ba725e2dc5ed0e8cd37b87/sep-multi-instrument-dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data Collection

In [None]:
!pip install cdflib

Collecting cdflib
  Downloading cdflib-1.3.6-py3-none-any.whl.metadata (2.7 kB)
Downloading cdflib-1.3.6-py3-none-any.whl (78 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/78.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: cdflib
Successfully installed cdflib-1.3.6


# Vx

In [None]:
import os, re
import pandas as pd
import numpy as np
import requests
from pathlib import Path
from datetime import timedelta
from tqdm import tqdm
import cdflib

# ==========================
# CONFIG / PATHS
# ==========================
# Input event table; the MAIN section below requires it to contain a "FlrOnset" column.
EVENTS_CSV = Path("/content/sample_data/1998_2013_MEMSEP_dataset.csv")
OUT_CSV    = Path("/content/1998_2013_MEMSEP_dataset_with_Vx.csv")  # final output here
# Downloaded monthly CDF files are kept here so reruns can reuse them.
CACHE_DIR  = Path("/content/_omni_hro2_cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)  # side effect at import/run time

# Root of the remote directory tree; per-year subdirectories are appended by
# latest_month_cdf_url().
BASE_URL = "https://cdaweb.gsfc.nasa.gov/pub/data/omni/omni_cdaweb/hro2_5min"

FREQ = "5min"
WINDOW_MIN = 24 * 60                # window length in minutes (24 h)
NUM_STEPS = WINDOW_MIN // 5  # 288 samples per window at a 5-minute cadence
# Output column names, ordered from the highest index down to 1.
COLS = [f"Vx_{i}" for i in range(NUM_STEPS, 0, -1)]  # Vx_288 ... Vx_1

# Candidate names for Vx.
# Variables tried first: each is read as-is, i.e. treated as a 1-D series.
VX_SCALARS = [
    "Vx", "VX", "VX_GSE", "VX_GSM", "V_GSE_X", "V_GSM_X", "V_X_GSE", "V_X_GSM",
    "Vx_gse", "Vx_gsm"
]
# Variables tried next: when 2-D, column 0 is taken as the X component.
VX_VECTORS = [
    "V_GSE", "V_GSM", "VGSE", "VGSM", "PLASMA_V_GSE", "PLASMA_V_GSM"
]

def http_get(url, timeout=60):
    """Fetch *url* with a GET request and return the response object.

    Raises whatever ``requests`` raises for transport errors, and an HTTP
    error (via ``raise_for_status``) for non-success status codes.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of handing back an error page.
    response.raise_for_status()
    return response

def latest_month_cdf_url(year: int, month: int) -> str:
    """Return the URL of the highest-versioned monthly CDF listed for a month.

    The year directory listing under ``BASE_URL`` is fetched and scanned for
    names of the form ``omni_hro2_5min_YYYYMM01_vNN.cdf``. The entry with the
    numerically largest version wins.

    The returned URL uses the filename exactly as it appears in the listing
    (same zero padding and letter case) rather than a re-formatted one, so it
    always points at a file the listing actually advertises.

    Raises:
        FileNotFoundError: no matching file is listed for that month.
    """
    ydir = f"{BASE_URL}/{year:04d}/"
    html = http_get(ydir).text
    tag = f"{year:04d}{month:02d}01"
    pat = re.compile(rf"omni_hro2_5min_{tag}_v(\d+)\.cdf", re.I)

    matches = list(pat.finditer(html))
    if not matches:
        raise FileNotFoundError(f"No hro2_5min CDF for {tag} under {ydir}")

    # Compare versions numerically ("v10" > "v9"), but keep the matched text.
    newest = max(matches, key=lambda m: int(m.group(1)))
    return f"{ydir}{newest.group(0)}"

def ensure_month_cdf(year: int, month: int) -> Path:
    """Return the local path of the month's CDF, downloading it if not cached.

    The remote filename (including its version suffix) is used as the cache
    filename inside ``CACHE_DIR``. The download is written to a temporary
    ``.part`` file and renamed into place only once complete, so an
    interrupted run cannot leave a truncated file that later runs would
    mistake for a valid cache entry.
    """
    url = latest_month_cdf_url(year, month)
    local = CACHE_DIR / url.split("/")[-1]
    if not local.exists():
        data = http_get(url).content
        partial = local.with_name(local.name + ".part")
        partial.write_bytes(data)
        # Rename within the same directory; replaces any stale target.
        partial.replace(local)
    return local

def get_zvars(info) -> list:
    """Return the ``zVariables`` entry of *info* as a list.

    *info* may expose the field either through item access
    (``info["zVariables"]``) or as an attribute (``info.zVariables``).

    Raises:
        TypeError: neither access style yields a value.
    """
    try:
        # Mapping-style access first.
        names = list(info["zVariables"])
    except Exception:
        # Fall back to attribute-style access.
        attr = getattr(info, "zVariables", None)
        if attr is None:
            raise TypeError("cdf_info() has no zVariables field")
        names = list(attr)
    return names

def read_vx_from_cdf(cdf_path: Path) -> pd.DataFrame:
    """Return a DataFrame with a DatetimeIndex and one float column 'Vx'.

    The time axis is the first zVariable whose name contains "epoch"
    (case-insensitive). The velocity variable is picked in this order:

    1. the first name in ``VX_SCALARS`` present in the file, used as-is;
    2. the first name in ``VX_VECTORS`` present in the file, taking column 0
       when the array is 2-D (otherwise the array is used as-is);
    3. any zVariable whose upper-cased name contains "V" and "GSE" or "GSM"
       and whose data is 2-D, taking column 0.

    Fill values and magnitudes above 5000 become NaN. Timestamps are floored
    to 5-minute bins, then de-duplicated (last value wins) and sorted, so the
    returned index is unique even if two raw samples fall in the same bin.

    Raises:
        KeyError: no epoch variable or no usable velocity variable was found.
    """
    cdf = cdflib.CDF(str(cdf_path))
    zvars = get_zvars(cdf.cdf_info())

    # time variable
    epoch_var = next((v for v in zvars if "epoch" in v.lower()), None)
    if epoch_var is None:
        raise KeyError(f"No Epoch var in {cdf_path.name}")

    times = cdflib.cdfepoch.to_datetime(cdf.varget(epoch_var))
    times = pd.to_datetime(times, utc=True).tz_convert(None)

    def _is_matrix(arr) -> bool:
        """True when *arr* is 2-D with at least one column."""
        return hasattr(arr, "ndim") and arr.ndim == 2 and arr.shape[1] >= 1

    def _as_series(arr) -> pd.Series:
        return pd.Series(arr, index=times, name="Vx").astype(float)

    vx_series = None

    # scalar names first
    for name in VX_SCALARS:
        if name in zvars:
            vx_series = _as_series(cdf.varget(name))
            break

    # vector names, take X component [:,0]
    if vx_series is None:
        for name in VX_VECTORS:
            if name in zvars:
                arr = cdf.varget(name)
                vx_series = _as_series(arr[:, 0] if _is_matrix(arr) else arr)
                break

    # last resort: any 2D V*GSE/GSM var
    if vx_series is None:
        for name in zvars:
            upper = name.upper()
            if "V" not in upper or not ("GSE" in upper or "GSM" in upper):
                continue
            arr = cdf.varget(name)
            if _is_matrix(arr):
                vx_series = _as_series(arr[:, 0])
                break

    if vx_series is None:
        raise KeyError(f"No usable Vx var in {cdf_path.name}; first vars: {zvars[:12]}")

    # Clean fills: exact sentinel values, anything near 99999.9, and any
    # magnitude above 5000.
    vx_series = vx_series.replace(
        [99999.8984375, 99999.9, 99999, 9999, 1e31, -1e31, 1e30, -1e30], np.nan
    )
    bad = np.isclose(vx_series, 99999.9, atol=2) | (vx_series.abs() > 5000)
    vx_series = vx_series.mask(bad)

    # Snap timestamps to exact 5-min bins (independent of the index's time
    # resolution), THEN drop duplicates so snapping cannot reintroduce them.
    vx_series.index = vx_series.index.floor("5min")
    vx_series = vx_series[~vx_series.index.duplicated(keep="last")].sort_index()

    return vx_series.to_frame()  # col 'Vx'

def months_for_window(onset: pd.Timestamp):
    """Return the set of (year, month) pairs touched by *onset* and the
    instant exactly one day before it (one or two elements)."""
    day_before = onset - pd.Timedelta(days=1)
    return {(stamp.year, stamp.month) for stamp in (day_before, onset)}




def build_window(month_dfs: list, onset: pd.Timestamp) -> np.ndarray:
    """
    Build the 288-sample, 5-minute 'Vx' series for the 24 h BEFORE onset.

    The onset is snapped up to the next 5-min boundary (ceil); the grid then
    runs from ``end_aligned - 24h`` to ``end_aligned - 5min`` inclusive, so
    the last sample is strictly earlier than the onset.

    Args:
        month_dfs: DataFrames with a datetime index and a 'Vx' column. Empty
            frames (e.g. placeholders for months that failed to download) are
            ignored, whatever their index type.
        onset: event time.

    Returns:
        1-D float array of length 288, oldest sample first. Gaps are
        forward-filled, then backward-filled; if no data is available at all
        every element is -9999.0.
    """
    NUM_STEPS = 288  # 24h @ 5min
    FREQ = "5min"
    FILL = -9999.0

    # 1) snap onset to next 5-min boundary
    end_aligned = pd.Timestamp(onset).ceil(FREQ)  # e.g., 03:21 -> 03:25

    # 2) exact 5-min index: [end_aligned-24h, end_aligned-5min]
    idx = pd.date_range(end=end_aligned - pd.Timedelta(FREQ),
                        periods=NUM_STEPS, freq=FREQ)

    # 3) combine months; skip empty placeholders so their non-datetime index
    #    cannot contaminate the combined index
    frames = [m for m in month_dfs if m is not None and len(m) > 0]
    if not frames:
        return np.full(NUM_STEPS, FILL, dtype=float)
    df_all = pd.concat(frames, ignore_index=False)

    # snap source to exact 5-min bins (handles small clock skews), then
    # de-duplicate so the index is unique before reindexing
    df_all.index = pd.DatetimeIndex(df_all.index).floor(FREQ)
    df_all = df_all[~df_all.index.duplicated(keep="last")].sort_index()

    # 4) align + fill (dataset style)
    s = df_all["Vx"].reindex(idx)
    s = s.ffill().bfill().fillna(FILL)

    return s.to_numpy(dtype=float)



# ==========================
# MAIN
# ==========================
# Load the event table and normalise onset times to naive UTC timestamps.
df = pd.read_csv(EVENTS_CSV)
if "FlrOnset" not in df.columns:
    raise SystemExit("FlrOnset column missing in your CSV.")
df["FlrOnset"] = pd.to_datetime(df["FlrOnset"], utc=True).dt.tz_convert(None)

# Prefetch month files. Rows without an onset time (NaT) have no window, so
# they must not contribute a bogus (nan, nan) month key.
needed = set()
for t in df["FlrOnset"].dropna():
    needed |= months_for_window(t)

month_cache = {}
print("Downloading & parsing OMNI hro2_5min CDFs …")
for (y, m) in tqdm(sorted(needed)):
    try:
        cdf_path = ensure_month_cdf(y, m)
        month_cache[(y, m)] = read_vx_from_cdf(cdf_path)
    except Exception as e:
        # Best effort: a month that cannot be fetched/parsed becomes an empty
        # placeholder and the affected rows fall back to fill values.
        print(f"[WARN] {y}-{m:02d}: {e}")
        month_cache[(y, m)] = pd.DataFrame({"Vx": []})

# Build Vx windows; rows start out as all -9999 and are overwritten on success.
vx_mat = np.full((len(df), NUM_STEPS), -9999.0, dtype=float)
failed_rows = {}  # row position -> reason the window could not be built

print("Building 24h Vx windows before FlrOnset …")
for i, onset in enumerate(tqdm(df["FlrOnset"].tolist())):
    if pd.isna(onset):
        failed_rows[i] = "missing FlrOnset"
        continue
    months = months_for_window(onset)
    dfs = [month_cache.get(k, pd.DataFrame({"Vx": []})) for k in months]
    try:
        vals = build_window(dfs, onset)
        vx_mat[i, :] = vals
    except Exception as e:
        # keep -9999 row, but remember why instead of hiding the failure
        failed_rows[i] = repr(e)

if failed_rows:
    print(f"[WARN] {len(failed_rows)} row(s) kept as -9999 because the window could not be built")
    for i, reason in list(failed_rows.items())[:10]:
        print(f"[WARN]   row {i}: {reason}")

# Attach to end of CSV, preserve everything
vx_df = pd.DataFrame(vx_mat, columns=COLS)
df_out = pd.concat([df, vx_df], axis=1)
df_out.to_csv(OUT_CSV, index=False)
print(f"✓ Saved: {OUT_CSV}")

all_missing = int((vx_df == -9999.0).all(axis=1).sum())
print(f"[SUMMARY] Rows with all -9999: {all_missing} / {len(df)}")


Downloading & parsing OMNI hro2_5min CDFs …


100%|██████████| 163/163 [02:35<00:00,  1.05it/s]


Building 24h Vx windows before FlrOnset …


100%|██████████| 17794/17794 [01:34<00:00, 187.92it/s]


✓ Saved: /content/1998_2013_MEMSEP_dataset_with_Vx.csv
[SUMMARY] Rows with all -9999: 84 / 17794


In [None]:
import re, numpy as np, pandas as pd

IN  = "/content/1998_2013_MEMSEP_dataset_with_Vx.csv"
OUT = "/content/1998_2013_MEMSEP_dataset_with_Vx_filled.csv"

df = pd.read_csv(IN)

# grab Vx_288 … Vx_1
vx_cols = [c for c in df.columns if re.fullmatch(r"Vx_\d+", c)]
vx_cols = sorted(vx_cols, key=lambda s: int(s.split("_")[1]), reverse=True)  # 288→1

# to numeric; non-numeric -> NaN.
# copy=True: the array is edited in place below, and to_numpy() may otherwise
# hand back a (possibly read-only) view of the DataFrame's data.
X = df[vx_cols].apply(pd.to_numeric, errors="coerce").to_numpy(dtype=float, copy=True)

# treat ALL common sentinels + absurd values as missing
sentinels = { 999.989990234375, -9999, 99999.8984375, 99999.9, 99999, 9999, 1e31, -1e31, 1e30, -1e30 }
mask_bad = np.isin(X, list(sentinels)) | np.isclose(X, 99999.9, atol=2) | (~np.isfinite(X)) | (np.abs(X) > 5000)
X[mask_bad] = np.nan

# interpolate each row across the 288 points
idx = np.arange(X.shape[1])
for i in range(X.shape[0]):
    m = ~np.isnan(X[i])
    n_valid = m.sum()
    if n_valid >= 2:
        # linear between valid points; np.interp holds the edge values
        # constant outside the first/last valid point
        X[i, ~m] = np.interp(idx[~m], idx[m], X[i, m])
    elif n_valid == 1:
        X[i, :] = X[i, m][0]  # only one real point (flat fill)

# rows with no valid point at all are still NaN here --- fill them with the
# per-column medians of the other rows, then the global median as last resort
col_med = np.nanmedian(X, axis=0)
nan_rows, nan_cols = np.where(np.isnan(X))
X[nan_rows, nan_cols] = col_med[nan_cols]
gmed = np.nanmedian(X)
X[np.isnan(X)] = gmed

df[vx_cols] = X
df.to_csv(OUT, index=False)
print("Saved:", OUT)


Saved: /content/1998_2013_MEMSEP_dataset_with_Vx_filled.csv


# >10 MeV

In [None]:
import os, re, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind

# Output directory for figures and the summary CSV written by this cell.
os.makedirs("/content/sample_data/figs_gt10", exist_ok=True)

# files: display name -> CSV holding that variable's time-series columns.
# Each CSV is read independently by build_gt10_sep_vs_nsep().
# NOTE(review): the fill cell earlier in this notebook writes
# "..._with_Vx_filled.csv" to /content/, not /content/sample_data/ — confirm
# the file is moved/copied before this cell runs.
# NOTE(review): the ProtonDensity path has no "_filled" suffix, unlike the
# other entries — confirm this is intentional.
variables = {
    "Vx": "/content/sample_data/1998_2013_MEMSEP_dataset_with_Vx_filled.csv",
    "FlowSpeed": "/content/sample_data/1998_2013_MEMSEP_dataset_with_flow_speed_filled.csv",
    "ProtonDensity": "/content/sample_data/1998_2013_MEMSEP_dataset_with_proton_density.csv",
    "ProtonTemp": "/content/sample_data/1998_2013_MEMSEP_dataset_with_T_filled.csv",
    "F": "/content/sample_data/1998_2013_MEMSEP_dataset_with_F_filled.csv",
    "Xs": "/content/sample_data/1998_2013_MEMSEP_dataset_with_Xs_filled.csv",
    "Xl": "/content/sample_data/1998_2013_MEMSEP_dataset_with_Xl_filled.csv",
    "P4": "/content/sample_data/1998_2013_MEMSEP_dataset_with_P4_filled.csv",
    "P5": "/content/sample_data/1998_2013_MEMSEP_dataset_with_P5_filled.csv",
    "P6": "/content/sample_data/1998_2013_MEMSEP_dataset_with_P6_filled.csv",
}

# Values that sanitize_numeric() converts to NaN.
# NOTE(review): the Vx fill cell additionally treats 999.989990234375 as a
# sentinel; it is not listed here — confirm whether that is intended.
SENTINELS = { -9999, 99999.8984375, 99999.9, 99999, 9999, 1e31, -1e31, 1e30, -1e30 }
# Compared against the "sep_peak_2" column with >= in build_gt10_sep_vs_nsep().
THRESH_10PFU = 10.0  # >10 MeV NOAA threshold at sep_peak_2

def pick_timeseries_block(df: pd.DataFrame) -> list[str]:
    """Pick the column group that looks like the time-series block.

    Columns named ``<prefix>_<integer>`` are grouped by prefix. The winning
    group is the one with the most columns, ties broken by the widest index
    span and then by first appearance. Its columns are returned sorted by
    index, highest first, keeping only the first column seen for any repeated
    index and at most 288 columns. Returns ``[]`` when no column matches.
    """
    suffix_re = re.compile(r"^(?P<prefix>.+)_(?P<idx>\d+)$")

    by_prefix = {}
    for col in df.columns:
        hit = suffix_re.match(str(col))
        if hit is None:
            continue
        entry = (int(hit.group("idx")), col)
        by_prefix.setdefault(hit.group("prefix"), []).append(entry)

    if not by_prefix:
        return []

    def rank(prefix):
        steps = [step for step, _ in by_prefix[prefix]]
        return (len(steps), max(steps) - min(steps))

    best = max(by_prefix, key=rank)

    # Stable descending sort: equal indices keep their original column order.
    ordered = sorted(by_prefix[best], key=lambda pair: pair[0], reverse=True)

    picked, used_steps = [], set()
    for step, col in ordered:
        if step in used_steps:
            continue
        used_steps.add(step)
        picked.append(col)

    return picked[:288]

def sanitize_numeric(frame: pd.DataFrame, sentinels=None) -> np.ndarray:
    """Convert *frame* to a float array with bad values replaced by NaN.

    Every column is coerced with ``pd.to_numeric`` (non-numeric -> NaN). A
    value is then set to NaN when it equals one of *sentinels*, lies within 2
    of 99999.9, or is not finite.

    Args:
        frame: columns to convert.
        sentinels: iterable of exact fill values; defaults to the
            module-level ``SENTINELS``.

    Returns:
        A new 2-D float array with the same shape as *frame*; *frame* itself
        is not modified.
    """
    if sentinels is None:
        sentinels = SENTINELS
    numeric = frame.apply(pd.to_numeric, errors="coerce")
    # copy=True: the array is written to below, and to_numpy() may otherwise
    # return a (possibly read-only) view of the DataFrame's data.
    X = numeric.to_numpy(dtype=float, copy=True)
    mask_bad = (
        np.isin(X, list(sentinels))
        | np.isclose(X, 99999.9, atol=2)
        | ~np.isfinite(X)
    )
    X[mask_bad] = np.nan
    return X

def mean_sem_ci(X: np.ndarray):
    """Per-column NaN-aware statistics of a 2-D array.

    Returns a 4-tuple ``(mean, sem, ci, n)`` of 1-D arrays, one entry per
    column: the NaN-ignoring mean, the standard error (sample std with
    ``ddof=1`` divided by sqrt of the valid count, the count being floored at
    1 for the division), the half-width ``1.959963984540054 * sem``, and the
    number of non-NaN values as floats.
    """
    n = (~np.isnan(X)).sum(axis=0).astype(float)
    m = np.nanmean(X, axis=0)
    spread = np.nanstd(X, axis=0, ddof=1)
    sem = spread / np.sqrt(np.maximum(n, 1))
    ci = 1.959963984540054 * sem  # normal-approximation multiplier
    return m, sem, ci, n

def contiguous_regions(mask: np.ndarray):
    """Yield ``(start, end)`` index pairs (both inclusive) for every run of
    consecutive truthy entries in *mask*; yields nothing for an empty mask."""
    if mask.size == 0:
        return
    run_start = None  # index where the current run began, None when outside a run
    for pos, flag in enumerate(mask):
        if flag:
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            yield (run_start, pos - 1)
            run_start = None
    # A run that reaches the end of the mask is closed at the last index.
    if run_start is not None:
        yield (run_start, len(mask) - 1)

def build_gt10_sep_vs_nsep(path: str):
    """Load one variable's CSV and split its time-series block into two groups.

    The SEP group is rows with ``event_type == 1`` whose numeric
    ``sep_peak_2`` is at least ``THRESH_10PFU``; the NSEP group is every row
    with ``event_type == 0``. Both are passed through ``sanitize_numeric``.

    Returns:
        ``(cols, SEP_gt10, NSEP, n_sep, n_nsep)``: the selected column names,
        the two sanitized arrays, and the row count of each group.

    Raises:
        ValueError: the CSV lacks ``event_type``, a time-series block, or
            ``sep_peak_2``.
    """
    table = pd.read_csv(path)
    if "event_type" not in table.columns:
        raise ValueError(f"'event_type' column missing in {path}")

    # Discard stray index columns ("Unnamed: 0", ...) before picking the block.
    stray = [c for c in table.columns if str(c).lower().startswith("unnamed")]
    table = table.drop(columns=stray, errors="ignore")

    # time-series block
    cols = pick_timeseries_block(table)
    if not cols:
        raise ValueError(f"No time-series block found in {path}")

    if "sep_peak_2" not in table.columns:
        raise ValueError(f"'sep_peak_2' missing in {path} (needed for >10 MeV selection)")

    is_sep = table["event_type"] == 1
    is_nsep = table["event_type"] == 0
    peak = pd.to_numeric(table["sep_peak_2"], errors="coerce")
    # >10 MeV subset: flagged as SEP and peak at or above the threshold
    gt10_mask = is_sep & (peak >= THRESH_10PFU)

    sep_block = sanitize_numeric(table.loc[gt10_mask, cols])
    nsep_block = sanitize_numeric(table.loc[is_nsep, cols])

    return cols, sep_block, nsep_block, int(gt10_mask.sum()), int(is_nsep.sum())

def plot_curves_with_ci(varname, cols, SEP, NSEP, outdir):
    """Plot the mean curve of each group with a shaded band and save it.

    For each of the ``len(cols)`` time steps a Welch t-test (unequal
    variances, NaNs omitted) compares the SEP and NSEP columns; runs of steps
    with p < 0.05 are shaded grey. The x axis runs from 24 down to 0. The
    figure is written to ``<outdir>/curves_<varname>_gt10`` as both PDF and
    PNG, and then closed.
    """
    n_steps = len(cols)
    x_hours = np.linspace(24, 0, num=n_steps)

    sep_mean, _, sep_ci, sep_n = mean_sem_ci(SEP)
    nsep_mean, _, nsep_ci, nsep_n = mean_sem_ci(NSEP)

    # Column-wise significance of the group difference.
    pvals = np.full(n_steps, np.nan)
    for step in range(n_steps):
        _, p_step = ttest_ind(SEP[:, step], NSEP[:, step],
                              equal_var=False, nan_policy="omit")
        pvals[step] = p_step
    significant = pvals < 0.05

    plt.figure(figsize=(8, 4.5))

    # Grey background wherever the difference is significant.
    for first, last in contiguous_regions(significant):
        plt.axvspan(x_hours[last], x_hours[first], color="0.9", zorder=0)

    sep_label = f"SEP >10 MeV (n={int(np.nanmax(sep_n))})"
    plt.plot(x_hours, sep_mean, lw=2.0, label=sep_label)
    plt.fill_between(x_hours, sep_mean - sep_ci, sep_mean + sep_ci,
                     alpha=0.25, linewidth=0)

    nsep_label = f"NSEP (n={int(np.nanmax(nsep_n))})"
    plt.plot(x_hours, nsep_mean, lw=2.0, label=nsep_label)
    plt.fill_between(x_hours, nsep_mean - nsep_ci, nsep_mean + nsep_ci,
                     alpha=0.25, linewidth=0)

    plt.title(f"{varname}: Mean ±95% CI (shaded: p<0.05)")
    plt.xlabel("Hours Before Flare")
    plt.ylabel(varname)
    plt.xlim(24, 0)
    plt.axvline(0, color="k", ls="--", lw=1)
    plt.grid(True, ls="--", alpha=0.4)
    plt.legend(frameon=False)
    plt.tight_layout()

    base = os.path.join(outdir, f"curves_{varname}_gt10")
    plt.savefig(base + ".pdf")
    plt.savefig(base + ".png", dpi=300)
    plt.close()
    print(f"Saved: {base}.pdf/.png")

def plot_effectsize_heatmap_rownorm(effects, outdir):
    """Draw a heatmap of row-normalised group differences and save it.

    Args:
        effects: mapping ``variable name -> {"diff": 1-D array, "p": 1-D
            array}``; all arrays must have the same length. One heatmap row
            is drawn per variable, in the mapping's iteration order.
        outdir: directory receiving ``effectsize_heatmap_rownorm_gt10.pdf``
            and ``.png``.

    Each row of differences is z-scored (mean removed, divided by its sample
    std plus 1e-9) and colour-mapped between -2.5 and 2.5. A "|" marker is
    drawn at every time step whose p-value is below 0.05.
    """
    varnames = list(effects.keys())
    n_vars = len(varnames)
    D = np.vstack([effects[v]["diff"] for v in varnames])
    P = np.vstack([effects[v]["p"] for v in varnames])
    Dn = (D - np.nanmean(D, axis=1, keepdims=True)) / (np.nanstd(D, axis=1, ddof=1, keepdims=True) + 1e-9)
    x_hours = np.linspace(24, 0, D.shape[1])

    plt.figure(figsize=(14, 6))
    # imshow draws matrix row 0 at the TOP by default. The extent therefore
    # puts y=0 at the top and y=n_vars at the bottom, so matrix row i spans
    # y in [i, i+1] and lines up with the tick label and the significance
    # markers placed at y = i + 0.5 below.
    im = plt.imshow(Dn, aspect="auto", cmap="coolwarm",
                    extent=[24, 0, n_vars, 0], vmin=-2.5, vmax=2.5)
    plt.colorbar(im, label="Row-normalized effect")

    for i in range(n_vars):
        sig = P[i] < 0.05
        xs = x_hours[sig]
        ys = np.full(xs.shape, i + 0.5)
        plt.scatter(xs, ys, marker="|", s=30, c="k")

    plt.yticks(np.arange(0.5, n_vars + 0.5), varnames)
    plt.xlabel("Hours Before Flare")
    plt.ylabel("Variable")
    plt.title("Row-normalized Effect: SEP (>10 MeV) − NSEP (markers: p<0.05)", fontsize=16)
    plt.tight_layout()
    base = os.path.join(outdir, "effectsize_heatmap_rownorm_gt10")
    plt.savefig(base + ".pdf")
    plt.savefig(base + ".png", dpi=300)
    plt.close()
    print(f"Saved: {base}.pdf/.png")

def make_summary_key_times(varname, cols, SEP, NSEP, hours=(24, 12, 6, 0)):
    """Summarise the SEP-vs-NSEP difference at selected lead times.

    The ``len(cols)`` time steps are mapped onto an evenly spaced axis from
    24 down to 0; each requested hour is matched to the nearest step (the
    earliest step on ties).

    Args:
        varname: label stored in each row's "variable" field.
        cols: time-series column names; only their count is used.
        SEP, NSEP: 2-D arrays (rows = events, columns = time steps).
        hours: iterable of lead times to report. The default is an immutable
            tuple so it cannot be altered between calls.

    Returns:
        A list with one dict per requested hour, holding the matched hour,
        the NaN-ignoring group means, their difference (SEP minus NSEP) and
        the Welch t-test p-value (NaNs omitted).
    """
    n_steps = len(cols)
    x_hours = np.linspace(24, 0, num=n_steps)
    rows = []
    for h in hours:
        j = np.argmin(np.abs(x_hours - h))
        a, b = SEP[:, j], NSEP[:, j]
        mean_sep, mean_nsp = np.nanmean(a), np.nanmean(b)
        _, p = ttest_ind(a, b, equal_var=False, nan_policy="omit")
        rows.append({"variable": varname, "hours_before": float(x_hours[j]),
                     "sep_gt10_mean": mean_sep, "nsep_mean": mean_nsp,
                     "difference": mean_sep - mean_nsp, "p_value": p})
    return rows

# ---------- RUN ----------
# Per-variable results collected across the loop:
#   effects      -> inputs for the combined heatmap
#   summary_rows -> rows of the key-times summary CSV
effects = {}
summary_rows = []
outdir = "/content/sample_data/figs_gt10"

for var, path in variables.items():
    print(f"\n=== {var} ===")

    # Load and split; any failure only skips this variable.
    try:
        cols, SEP_gt10, NSEP, n_gt10, n_nsep = build_gt10_sep_vs_nsep(path)
        print(f"Counts -> SEP >10 MeV: {n_gt10}, NSEP: {n_nsep}")
    except Exception as e:
        print(f"[WARN] {var}: {e}")
        continue

    if SEP_gt10.size == 0 or NSEP.size == 0:
        print("[WARN] Empty group(s); skipping plots.")
        continue

    # curves
    plot_curves_with_ci(var, cols, SEP_gt10, NSEP, outdir)

    # effects for heatmap: per-step mean difference and Welch p-value
    diffs, pvals = [], []
    for j, _ in enumerate(cols):
        a = SEP_gt10[:, j]
        b = NSEP[:, j]
        diffs.append(np.nanmean(a) - np.nanmean(b))
        _, p = ttest_ind(a, b, equal_var=False, nan_policy="omit")
        pvals.append(p)
    effects[var] = {"diff": np.array(diffs), "p": np.array(pvals)}

    # summary rows
    summary_rows.extend(
        make_summary_key_times(var, cols, SEP_gt10, NSEP, hours=[24,12,6,0])
    )

# heatmap (only when at least one variable produced results)
if effects:
    plot_effectsize_heatmap_rownorm(effects, outdir)

# summary CSV
if summary_rows:
    out_csv = os.path.join(outdir, "summary_key_times_gt10.csv")
    df_sum = pd.DataFrame(summary_rows)
    df_sum.to_csv(out_csv, index=False)
    print(f"Saved: {out_csv}")



=== Vx ===
Counts -> SEP >10 MeV: 168, NSEP: 17542
Saved: /content/sample_data/figs_gt10/curves_Vx_gt10.pdf/.png

=== FlowSpeed ===
Counts -> SEP >10 MeV: 168, NSEP: 17542
Saved: /content/sample_data/figs_gt10/curves_FlowSpeed_gt10.pdf/.png

=== ProtonDensity ===
Counts -> SEP >10 MeV: 168, NSEP: 17542
Saved: /content/sample_data/figs_gt10/curves_ProtonDensity_gt10.pdf/.png

=== ProtonTemp ===
Counts -> SEP >10 MeV: 168, NSEP: 17542
Saved: /content/sample_data/figs_gt10/curves_ProtonTemp_gt10.pdf/.png

=== F ===
Counts -> SEP >10 MeV: 168, NSEP: 17542
Saved: /content/sample_data/figs_gt10/curves_F_gt10.pdf/.png

=== Xs ===
Counts -> SEP >10 MeV: 168, NSEP: 17542
Saved: /content/sample_data/figs_gt10/curves_Xs_gt10.pdf/.png

=== Xl ===
Counts -> SEP >10 MeV: 168, NSEP: 17542
Saved: /content/sample_data/figs_gt10/curves_Xl_gt10.pdf/.png

=== P4 ===
Counts -> SEP >10 MeV: 168, NSEP: 17542
Saved: /content/sample_data/figs_gt10/curves_P4_gt10.pdf/.png

=== P5 ===
Counts -> SEP >10 MeV: 168