In [1]:
import pickle

with open("JUPITER_MASTER_SPECTRA.pkl", "rb") as f:
    data = pickle.load(f)

type(data)

dict

In [2]:
data.keys()

dict_keys(['target', 'wavelength', 'flux', 'notes'])

In [3]:
for k, v in data.items():
    print(k, type(v))


target <class 'str'>
wavelength <class 'numpy.ndarray'>
flux <class 'numpy.ndarray'>
notes <class 'str'>


In [None]:
# =========================
# 0) Imports + config
# =========================
import pickle, numpy as np, time
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

PKL_PATH = "JUPITER_MASTER_SPECTRA.pkl"

# Speed knobs
SEED = 7
N_RESAMPLE = 1024        # 512 if you want even faster
N_SYNTH = 1500           # 800–2000 is plenty for dry run
EPOCHS = 7               # 5–10
BATCH = 128
LR = 1e-3

device = "cuda" if torch.cuda.is_available() else "cpu"
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
if device == "cuda":
    torch.cuda.manual_seed_all(SEED)

# Label space: Jupiter UV -> atmospheric species (not "elements/minerals")
SPECIES = ["CH4", "NH3", "C2H2", "C2H6"]

# Simple band templates (refine later using real line/band libraries)
BANDS = {
    "CH4": [(2350, 220), (2750, 260)],
    "NH3": [(2050,  90), (2150,  90)],
    "C2H2": [(2700, 110), (2810, 100)],
    "C2H6": [(2400, 140), (2550, 120)],
}

def gaussian(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2)

# =========================
# 1) Load + clean real spectrum
# =========================
with open(PKL_PATH, "rb") as f:
    real = pickle.load(f)

w = np.asarray(real["wavelength"], dtype=float)
f = np.asarray(real["flux"], dtype=float)

mask = np.isfinite(w) & np.isfinite(f)
w, f = w[mask], f[mask]
idx = np.argsort(w)
w, f = w[idx], f[idx]

print("Loaded:", real.get("target", "unknown"), "| points:", len(w), "| device:", device)

# =========================
# 2) Resample to fixed length (fast + consistent)
# =========================
def resample_to_fixed(wave, flux, n=N_RESAMPLE):
    w_new = np.linspace(wave.min(), wave.max(), n)
    f_new = np.interp(w_new, wave, flux)
    return w_new.astype(np.float32), f_new.astype(np.float32)

w_fix, f_fix = resample_to_fixed(w, f, N_RESAMPLE)

# =========================
# 3) Preprocess -> channels (C, N)
# =========================
def robust_norm(x):
    med = np.median(x)
    iqr = np.percentile(x, 75) - np.percentile(x, 25)
    if iqr <= 0:
        iqr = 1.0
    return (x - med) / iqr

def make_channels(wave, flux):
    x = robust_norm(flux)
    d1 = np.gradient(x, wave)
    d2 = np.gradient(d1, wave)
    X = np.stack([x, d1, d2], axis=0).astype(np.float32)  # (3, N)
    return X

X_real = make_channels(w_fix, f_fix)   # (3, N)
X_real_t = torch.tensor(X_real[None, ...], dtype=torch.float32).to(device)  # (1,3,N)

# =========================
# 4) Synthetic data generator
# =========================
def sample_labels():
    y = {sp: 0 for sp in SPECIES}
    # independent balanced presence
    for sp in SPECIES:
        y[sp] = 1 if rng.random() < 0.5 else 0
    if sum(y.values()) == 0:
        y[rng.choice(SPECIES)] = 1
    return y

def synth_spectrum(wave, labels):
    # continuum: gentle curve + offset
    x = (wave - wave.min()) / (wave.max() - wave.min())
    cont = 1.0 + 0.08 * (x - 0.5) + 0.05 * (x - 0.5) ** 2
    cont += rng.normal(0, 0.01)

    spec = cont.copy()

    # multiplicative absorption dips
    for sp, present in labels.items():
        if not present:
            continue
        for (c, w0) in BANDS[sp]:
            # --- CH4 stronger + broader (key change) ---
            if sp == "CH4":
                depth = rng.uniform(0.08, 0.25)
                width = w0 * rng.uniform(0.9, 1.6)
            else:
                depth = rng.uniform(0.03, 0.15)
                width = w0 * rng.uniform(0.7, 1.3)

            dip = 1.0 - depth * gaussian(wave, c, width)
            spec *= dip

    # point noise + mild correlation
    noise = rng.normal(0, 0.01, size=wave.shape[0])
    noise = np.convolve(noise, np.ones(7)/7, mode="same")
    spec = spec + noise
    return spec.astype(np.float32)

def build_dataset(wave, n=N_SYNTH):
    X_list, Y_list = [], []
    for _ in range(n):
        lab = sample_labels()
        spec = synth_spectrum(wave, lab)
        X = make_channels(wave, spec)  # (3, N)
        y = np.array([lab[sp] for sp in SPECIES], dtype=np.float32)
        X_list.append(X)
        Y_list.append(y)
    X = np.stack(X_list, axis=0)  # (B, 3, N)
    Y = np.stack(Y_list, axis=0)  # (B, K)
    return X, Y

X, Y = build_dataset(w_fix, N_SYNTH)

perm = rng.permutation(len(X))
X, Y = X[perm], Y[perm]
n_train = int(0.85 * len(X))
X_train, Y_train = X[:n_train], Y[:n_train]
X_val,   Y_val   = X[n_train:], Y[n_train:]

train_loader = DataLoader(
    TensorDataset(torch.tensor(X_train), torch.tensor(Y_train)),
    batch_size=BATCH, shuffle=True
)
val_loader = DataLoader(
    TensorDataset(torch.tensor(X_val), torch.tensor(Y_val)),
    batch_size=BATCH, shuffle=False
)

print("Synthetic dataset:", X_train.shape, X_val.shape)

# Optional sanity check (highly recommended)
print("Label prevalence (train):")
prev = Y_train.mean(axis=0)
for sp, p in zip(SPECIES, prev):
    print(f"{sp}: {float(p):.3f}")


# =========================
# 5) Models: MLP, CNN, GRU
# =========================
K = len(SPECIES)
C = 3
N = N_RESAMPLE

class MLP(nn.Module):
    def __init__(self, c=C, n=N, k=K):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(c*n, 256),
            nn.ReLU(),
            nn.Dropout(0.15),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, k)
        )
    def forward(self, x):  # x: (B,C,N)
        return self.net(x)

class CNN1D(nn.Module):
    def __init__(self, in_ch=C, k=K):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_ch, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(128, k)
        )
    def forward(self, x):
        return self.net(x)

class GRUClassifier(nn.Module):
    def __init__(self, c=C, k=K, hidden=64):
        super().__init__()
        self.gru = nn.GRU(input_size=c, hidden_size=hidden, num_layers=1,
                          batch_first=True, bidirectional=False)
        self.head = nn.Sequential(
            nn.Linear(hidden, k)
        )
    def forward(self, x):  # x: (B,C,N) -> (B,N,C)
        x = x.permute(0, 2, 1)
        _, h = self.gru(x)          # h: (1,B,H)
        h = h[-1]                   # (B,H)
        return self.head(h)

# =========================
# 6) Train + evaluate utilities
# =========================
loss_fn = nn.BCEWithLogitsLoss()

def micro_f1_from_logits(logits, y_true, thr=0.5):
    probs = torch.sigmoid(logits)
    y_hat = (probs >= thr).float()
    tp = (y_hat * y_true).sum()
    fp = (y_hat * (1 - y_true)).sum()
    fn = ((1 - y_hat) * y_true).sum()
    denom = (2*tp + fp + fn).clamp(min=1e-8)
    return (2*tp / denom).item()

@torch.no_grad()
def evaluate(model):
    model.eval()
    total_loss = 0.0
    all_logits, all_y = [], []
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        total_loss += loss_fn(logits, yb).item() * xb.size(0)
        all_logits.append(logits)
        all_y.append(yb)
    total_loss /= len(X_val)
    logits = torch.cat(all_logits, dim=0)
    y_true = torch.cat(all_y, dim=0)
    f1 = micro_f1_from_logits(logits, y_true)
    return total_loss, f1

def train_model(model, name):
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    best = {"loss": 1e9, "state": None, "f1": 0.0}
    t0 = time.time()

    for epoch in range(1, EPOCHS + 1):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            logits = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            opt.step()

        vloss, vf1 = evaluate(model)
        if vloss < best["loss"]:
            best = {"loss": vloss, "state": {k: v.detach().cpu() for k, v in model.state_dict().items()}, "f1": vf1}
        print(f"{name} | epoch {epoch:02d} | val_loss={vloss:.4f} | microF1={vf1:.3f}")

    dt = time.time() - t0
    return best, dt

# =========================
# 7) Train all three
# =========================
results = []

mlp = MLP()
best_mlp, t_mlp = train_model(mlp, "MLP")
results.append(("MLP", best_mlp["loss"], best_mlp["f1"], t_mlp, best_mlp))

cnn = CNN1D()
best_cnn, t_cnn = train_model(cnn, "CNN1D")
results.append(("CNN1D", best_cnn["loss"], best_cnn["f1"], t_cnn, best_cnn))

gru = GRUClassifier(hidden=64)
best_gru, t_gru = train_model(gru, "GRU")
results.append(("GRU", best_gru["loss"], best_gru["f1"], t_gru, best_gru))

print("\n=== Model comparison (lower loss better) ===")
for name, lossv, f1v, dt, _ in sorted(results, key=lambda x: x[1]):
    print(f"{name:5s} | val_loss={lossv:.4f} | microF1={f1v:.3f} | train_time={dt:.1f}s")

# pick best by val_loss
best_name, best_loss, best_f1, best_dt, best_blob = sorted(results, key=lambda x: x[1])[0]
print(f"\nBest model: {best_name} | val_loss={best_loss:.4f} | microF1={best_f1:.3f}")

# =========================
# 8) Inference on real Jupiter spectrum
# =========================
def load_best_model(name, state):
    if name == "MLP":
        m = MLP()
    elif name == "CNN1D":
        m = CNN1D()
    else:
        m = GRUClassifier(hidden=64)
    m.load_state_dict(state)
    m.to(device).eval()
    return m

best_model = load_best_model(best_name, best_blob["state"])

with torch.no_grad():
    logits = best_model(X_real_t)[0]
    probs = torch.sigmoid(logits).cpu().numpy()

print("\nPredicted probabilities on REAL Jupiter spectrum:")
for sp, p in sorted(zip(SPECIES, probs), key=lambda x: -x[1]):
    print(f"{sp:>4}: {p:.3f}")

# =========================
# 9) Save best model artifact (optional)
# =========================
# save_path = f"/mnt/data/{real.get('target','TARGET').upper()}_{best_name}_best.pt"
# torch.save({
#     "model_type": best_name,
#     "state_dict": best_blob["state"],
#     "species": SPECIES,
#     "n_resample": N_RESAMPLE,
#     "notes": "Trained on synthetic band-mixtures; dry-run model."
# }, save_path)
# print("\nSaved best model to:", save_path)


Loaded: JUPITER | points: 1024 | device: cpu
Synthetic dataset: (1275, 3, 1024) (225, 3, 1024)
Label prevalence (train):
CH4: 0.518
NH3: 0.496
C2H2: 0.536
C2H6: 0.518
MLP | epoch 01 | val_loss=0.2431 | microF1=0.909
MLP | epoch 02 | val_loss=0.1464 | microF1=0.932
MLP | epoch 03 | val_loss=0.1213 | microF1=0.946
MLP | epoch 04 | val_loss=0.1198 | microF1=0.951
MLP | epoch 05 | val_loss=0.1012 | microF1=0.958
MLP | epoch 06 | val_loss=0.1119 | microF1=0.954
MLP | epoch 07 | val_loss=0.1033 | microF1=0.953
CNN1D | epoch 01 | val_loss=0.6889 | microF1=0.626
CNN1D | epoch 02 | val_loss=0.6752 | microF1=0.591
CNN1D | epoch 03 | val_loss=0.6451 | microF1=0.645
CNN1D | epoch 04 | val_loss=0.6169 | microF1=0.658
CNN1D | epoch 05 | val_loss=0.6079 | microF1=0.676
CNN1D | epoch 06 | val_loss=0.5981 | microF1=0.640
CNN1D | epoch 07 | val_loss=0.5903 | microF1=0.676
GRU | epoch 01 | val_loss=0.6912 | microF1=0.470
GRU | epoch 02 | val_loss=0.6860 | microF1=0.484
GRU | epoch 03 | val_loss=0.6762 | 