In [6]:
import pickle, numpy as np, time
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from scipy.signal import savgol_filter
import optuna

# -------------------------
# Config
# -------------------------
PKL_PATH = r"Spectral Analysis Planets pkl\JUPITER_MASTER_SPECTRA.pkl"

SEED = 7
N_RESAMPLE = 1024

# Final training
N_SYNTH_FINAL = 4000
EPOCHS_FINAL = 25
PATIENCE_FINAL = 7

BATCH = 128

# Optuna
N_TRIALS = 25
N_SYNTH_TUNE = 1600
EPOCHS_TUNE = 12
PATIENCE_TUNE = 4

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
if device == "cuda":
    torch.cuda.manual_seed_all(SEED)

SPECIES = ["CH4", "NH3", "C2H2", "C2H6"]

BANDS = {
    "CH4": [(2350, 260), (2750, 320)],
    "NH3": [(2050, 140), (2150, 140)],
    "C2H2": [(2700, 140), (2810, 130)],
    "C2H6": [(2400, 170), (2550, 150)],
}

def gaussian(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2)

# -------------------------
# Load + resample
# -------------------------
with open(PKL_PATH, "rb") as f:
    real = pickle.load(f)

w = np.asarray(real["wavelength"], dtype=float)
f = np.asarray(real["flux"], dtype=float)
mask = np.isfinite(w) & np.isfinite(f)
w, f = w[mask], f[mask]
idx = np.argsort(w)
w, f = w[idx], f[idx]

def resample_to_fixed(wave, flux, n=N_RESAMPLE):
    w_new = np.linspace(wave.min(), wave.max(), n)
    f_new = np.interp(w_new, wave, flux)
    return w_new.astype(np.float32), f_new.astype(np.float32)

w_fix, f_fix = resample_to_fixed(w, f, N_RESAMPLE)
print("Loaded:", real.get("target","unknown"), "| points:", len(w_fix), "| device:", device)

# -------------------------
# Baseline + channels (baseline window is tunable)
# -------------------------
def compute_baseline(flux, win=151, poly=3):
    n = len(flux)
    win = int(win)
    if win >= n:
        win = n - 1
    if win < 11:
        win = 11
    if win % 2 == 0:
        win += 1
    b = savgol_filter(flux.astype(float), window_length=win, polyorder=poly)
    eps = 1e-12
    b = np.clip(b, np.percentile(b, 1), np.percentile(b, 99)) + eps
    return b.astype(np.float32)

def make_channels(wave, flux, win):
    base = compute_baseline(flux, win=win, poly=3)
    r = (flux / base) - 1.0
    r = (r - np.median(r)) / (np.std(r) + 1e-8)
    d1 = np.gradient(r, wave)
    d2 = np.gradient(d1, wave)
    return np.stack([r, d1, d2], axis=0).astype(np.float32)

# -------------------------
# Dataset builder (DETERMINISTIC via seed)
# -------------------------
def build_dataset(wave, n, win, seed):
    local_rng = np.random.default_rng(seed)
    baseline = compute_baseline(f_fix, win=win, poly=3)

    def sample_labels_local():
        y = {sp: 0 for sp in SPECIES}
        y["CH4"] = 1 if local_rng.random() < 0.85 else 0
        y["NH3"] = 1 if (y["CH4"] and local_rng.random() < 0.25) else 0
        y["C2H2"] = 1 if local_rng.random() < 0.20 else 0
        y["C2H6"] = 1 if local_rng.random() < 0.20 else 0
        if sum(y.values()) == 0:
            y["CH4"] = 1
        return y

    def synth_spectrum_local(labels):
        cont = baseline.copy()
        x = (wave - wave.min()) / (wave.max() - wave.min())
        drift = 1.0 + local_rng.normal(0, 0.01) + local_rng.normal(0, 0.01) * (x - 0.5)
        spec = cont * drift

        for sp, present in labels.items():
            if not present:
                continue
            for (c, w0) in BANDS[sp]:
                if sp == "CH4":
                    depth = local_rng.uniform(0.08, 0.28)
                elif sp == "C2H2":
                    depth = local_rng.uniform(0.02, 0.10)
                else:
                    depth = local_rng.uniform(0.03, 0.18)

                width = w0 * local_rng.uniform(0.85, 1.25)
                c_jit = c + local_rng.normal(0, 15)
                spec *= (1.0 - depth * gaussian(wave, c_jit, width))

        sigma = 0.01 * (np.max(spec) - np.min(spec) + 1e-8)
        noise = local_rng.normal(0, sigma, size=wave.shape[0])
        noise = np.convolve(noise, np.ones(7)/7, mode="same")
        return (spec + noise).astype(np.float32)

    X_list, Y_list = [], []
    for _ in range(n):
        lab = sample_labels_local()
        spec = synth_spectrum_local(lab)
        X_list.append(make_channels(wave, spec, win=win))
        Y_list.append(np.array([lab[sp] for sp in SPECIES], dtype=np.float32))

    X = np.stack(X_list, axis=0)
    Y = np.stack(Y_list, axis=0)

    perm = local_rng.permutation(len(X))
    X, Y = X[perm], Y[perm]
    n_train = int(0.85 * len(X))
    return X[:n_train], Y[:n_train], X[n_train:], Y[n_train:]

# -------------------------
# Train MLP once
# -------------------------
def train_mlp_once(h1, h2, drop1, drop2, lr, wd, win, n_synth, epochs, patience, seed):
    X_train, Y_train, X_val, Y_val = build_dataset(w_fix, n=n_synth, win=win, seed=seed)

    train_loader = DataLoader(
        TensorDataset(torch.tensor(X_train), torch.tensor(Y_train)),
        batch_size=BATCH, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(torch.tensor(X_val), torch.tensor(Y_val)),
        batch_size=BATCH, shuffle=False
    )

    C, N = 3, N_RESAMPLE
    K = len(SPECIES)

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Flatten(),
                nn.Linear(C*N, h1),
                nn.ReLU(),
                nn.Dropout(drop1),
                nn.Linear(h1, h2),
                nn.ReLU(),
                nn.Dropout(drop2),
                nn.Linear(h2, K)
            )
        def forward(self, x):
            return self.net(x)

    model = MLP().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    loss_fn = nn.BCEWithLogitsLoss()

    @torch.no_grad()
    def eval_val():
        model.eval()
        tot = 0.0
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            tot += loss_fn(model(xb), yb).item() * xb.size(0)
        return tot / len(X_val)

    best_loss, best_state, bad = 1e9, None, 0

    for _ in range(epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            logits = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

        vloss = eval_val()
        if vloss < best_loss - 1e-4:
            best_loss = vloss
            best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    model.load_state_dict(best_state)
    model.eval()

    # Real Jupiter inference logits (for Optuna + final report)
    X_real = make_channels(w_fix, f_fix, win=win)
    X_real_t = torch.tensor(X_real[None, ...], dtype=torch.float32).to(device)
    with torch.no_grad():
        real_logits = model(X_real_t)[0].cpu().numpy()

    return float(best_loss), real_logits, best_state

# -------------------------
# Optuna objective
# -------------------------
def objective(trial):
    # suggest first
    h1 = trial.suggest_categorical("h1", [256, 384, 512, 768])
    h2 = trial.suggest_categorical("h2", [128, 192, 256, 384])
    drop1 = trial.suggest_float("drop1", 0.10, 0.35)
    drop2 = trial.suggest_float("drop2", 0.00, 0.25)
    lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-6, 3e-4, log=True)
    win = trial.suggest_categorical("baseline_win", [101, 151, 201, 251])
    T = trial.suggest_float("temp", 1.0, 1.7)

    seed = 10_000 + trial.number

    vloss, logits, _ = train_mlp_once(
        h1=h1, h2=h2, drop1=drop1, drop2=drop2,
        lr=lr, wd=wd, win=win,
        n_synth=N_SYNTH_TUNE,
        epochs=EPOCHS_TUNE,
        patience=PATIENCE_TUNE,
        seed=seed
    )

    probs = 1 / (1 + np.exp(-logits / T))

    penalty = 0.0
    ch4 = float(probs[SPECIES.index("CH4")])
    if ch4 < 0.5:
        penalty += (0.5 - ch4) * 2.0

    non_ch4_max = float(np.max([probs[SPECIES.index(s)] for s in SPECIES if s != "CH4"]))
    if non_ch4_max < 0.10:
        penalty += (0.10 - non_ch4_max) * 1.0

    trial.set_user_attr("jupiter_probs", {sp: float(p) for sp, p in zip(SPECIES, probs)})
    return float(vloss + penalty)

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=N_TRIALS)

print("\nBest params:", study.best_params)
print("Best objective:", study.best_value)
print("Jupiter probs (best trial):", study.best_trial.user_attrs["jupiter_probs"])

# -------------------------
# Final train with best params
# -------------------------
bp = study.best_params
FINAL_SEED = 999

final_vloss, final_logits, final_state = train_mlp_once(
    h1=bp["h1"], h2=bp["h2"], drop1=bp["drop1"], drop2=bp["drop2"],
    lr=bp["lr"], wd=bp["weight_decay"], win=bp["baseline_win"],
    n_synth=N_SYNTH_FINAL, epochs=EPOCHS_FINAL, patience=PATIENCE_FINAL,
    seed=FINAL_SEED
)

T = bp["temp"]
final_probs = 1 / (1 + np.exp(-final_logits / T))

print("\nFINAL (Optuna-tuned) Jupiter probabilities:")
for sp, p in sorted(zip(SPECIES, final_probs), key=lambda x: -x[1]):
    print(f"{sp:>4}: {p:.3f}")

# Save tuned model
save_path = "JUPITER_MLP_OPTUNA.pt"
torch.save({
    "planet": "JUPITER",
    "species": SPECIES,
    "bands": BANDS,
    "best_params": bp,
    "temp": T,
    "val_loss": final_vloss,
    "state_dict": final_state,
    "n_resample": N_RESAMPLE,
}, save_path)
print("\nSaved:", save_path)

[32m[I 2026-02-08 17:46:11,109][0m A new study created in memory with name: no-name-b6e0d86c-3723-477f-80b6-89c905785425[0m


Loaded: JUPITER | points: 1024 | device: cpu


[32m[I 2026-02-08 17:46:15,332][0m Trial 0 finished with value: 0.45023446083068847 and parameters: {'h1': 256, 'h2': 256, 'drop1': 0.25237939807920184, 'drop2': 0.04207927375960574, 'lr': 0.0001998040602020569, 'weight_decay': 5.4752328836882275e-06, 'baseline_win': 101, 'temp': 1.2371647590388182}. Best is trial 0 with value: 0.45023446083068847.[0m
[32m[I 2026-02-08 17:46:18,008][0m Trial 1 finished with value: 0.41696301897366844 and parameters: {'h1': 384, 'h2': 192, 'drop1': 0.19203631002213528, 'drop2': 0.09483857646754132, 'lr': 0.00013131511314116314, 'weight_decay': 2.1306818578067545e-05, 'baseline_win': 101, 'temp': 1.5271151217442496}. Best is trial 1 with value: 0.41696301897366844.[0m
[32m[I 2026-02-08 17:46:20,778][0m Trial 2 finished with value: 0.4391616145769755 and parameters: {'h1': 512, 'h2': 256, 'drop1': 0.21658611187466575, 'drop2': 0.10816339259499558, 'lr': 0.0003118040949231639, 'weight_decay': 2.3166643234161838e-06, 'baseline_win': 151, 'temp': 1.1


Best params: {'h1': 384, 'h2': 192, 'drop1': 0.33524365842588316, 'drop2': 0.1651058652976325, 'lr': 0.0007110310543639465, 'weight_decay': 4.907620008540501e-06, 'baseline_win': 101, 'temp': 1.3345064815359007}
Best objective: 0.3833982288837433
Jupiter probs (best trial): {'CH4': 0.9137865304946899, 'NH3': 0.1581314653158188, 'C2H2': 0.48879021406173706, 'C2H6': 0.5005111694335938}

FINAL (Optuna-tuned) Jupiter probabilities:
 CH4: 0.974
C2H6: 0.073
C2H2: 0.052
 NH3: 0.007

Saved: JUPITER_MLP_OPTUNA.pt
