# 🌌 Inferencia Astrofísica desde Ondas Gravitacionales con Modelos Surrogados

Esta notebook implementa un flujo completo para generar señales sintéticas, entrenar modelos e inferir los **parámetros físicos** de sistemas binarios compactos (como agujeros negros o estrellas de neutrones) a partir de señales de ondas gravitacionales.

Utiliza modelos surrogados basados en **MLPs y CNNs** y se valida tanto con **datos sintéticos** como con datos reales de **GW150914 (LIGO)**.

### ✅ Objetivos:
- Simular señales GW con parámetros físicos realistas.
- Entrenar modelos surrogados para inferir esos parámetros.
- Validar estadísticamente los modelos.
- Aplicar inferencia directa a un evento real (GW150914).


In [None]:
!pip install pycbc torch h5py pandas scikit-learn matplotlib tqdm

ModuleNotFoundError: No module named 'pycbc'

In [None]:
import numpy as np
import os
import json
import torch
from pycbc.waveform import get_td_waveform
from scipy.signal import spectrogram
from tqdm import tqdm

# Output folders for the generated dataset
output_dir = "gw_dataset"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f"{output_dir}/signals", exist_ok=True)
os.makedirs(f"{output_dir}/metadata", exist_ok=True)

# Waveform settings shared by the generator call and the stored metadata.
# Keeping them in one place prevents the metadata from drifting away from
# the values actually used to build the signal.
SAMPLE_RATE = 8192   # Hz; delta_t = 1 / SAMPLE_RATE
F_LOWER = 10         # Hz; starting frequency handed to the approximant
NOISE_SIGMA = 1e-22  # standard deviation of the additive Gaussian noise

# Number of samples to generate
N = 5000
for i in tqdm(range(N), desc="Generando señales"):
    m1 = np.random.uniform(20, 80)
    m2 = np.random.uniform(10, m1)  # enforces m2 <= m1
    spin1z = np.random.uniform(-0.99, 0.99)
    spin2z = np.random.uniform(-0.99, 0.99)
    distance = np.random.uniform(100, 1000)
    # NOTE(review): the six parameters below are sampled and written to the
    # metadata, but none of them is passed to get_td_waveform, so the saved
    # signal does not depend on them.
    redshift = np.random.uniform(0.06, 0.12)       # centred around ~0.09
    eccentricity = np.random.uniform(0.0, 0.05)    # nearly circular orbits
    polarization = np.random.uniform(0.0, 2*np.pi)
    ra = np.random.uniform(0.0, 2*np.pi)
    dec = np.random.uniform(-np.pi/2, np.pi/2)
    tc = np.random.uniform(-0.1, 0.1)

    # Source-type label derived from the component masses.
    # NOTE(review): with m1 in [20, 80] and m2 in [10, m1] only the "BBH"
    # branch can be reached; the other two are kept for when the mass
    # ranges are widened.
    if m1 < 3 and m2 < 3:
        source_type = "BNS"
    elif m1 >= 5 and m2 >= 5:
        source_type = "BBH"
    else:
        source_type = "BHNS"

    try:
        hp, hc = get_td_waveform(
            approximant="IMRPhenomD",
            mass1=m1,
            mass2=m2,
            spin1z=spin1z,
            spin2z=spin2z,
            delta_t=1.0 / SAMPLE_RATE,
            f_lower=F_LOWER,
            distance=distance,
        )
    except Exception as e:
        # Best effort: report the failing draw and move on to the next one.
        print(f"Error en muestra {i}: {e}")
        continue

    # Add the noise on the NumPy array so the result does not depend on
    # in-place arithmetic between a waveform series and a plain ndarray.
    # Only the plus polarisation receives noise; hc is stored clean.
    noise = np.random.normal(0, NOISE_SIGMA, len(hp))
    hp_noisy = hp.numpy() + noise

    np.save(f"{output_dir}/signals/hp_{i}.npy", hp_noisy)
    np.save(f"{output_dir}/signals/hc_{i}.npy", hc.numpy())

    metadata = {
        "m1": m1,
        "m2": m2,
        "spin1z": spin1z,
        "spin2z": spin2z,
        "distance": distance,
        "redshift": redshift,
        "eccentricity": eccentricity,
        "polarization": polarization,
        "ra": ra,
        "dec": dec,
        "tc": tc,
        "source_type": source_type,
        "sample_rate": SAMPLE_RATE,
        "delta_t": float(hp.delta_t),
        "f_lower": F_LOWER,
        "length": len(hp),
    }

    with open(f"{output_dir}/metadata/params_{i}.json", "w") as f:
        json.dump(metadata, f, indent=2)

# Normalisation bounds and spectrogram features.
# Per-target lower/upper bounds used for min-max scaling, in the order:
# m1, m2, spin1z, spin2z, distance, redshift, eccentricity, polarization,
# ra, dec, tc.
targets_min = np.array([5, 5, -0.99, -0.99, 100, 0.06, 0.0, 0.0, 0.0, -np.pi/2, -0.1])
targets_max = np.array([80, 80, 0.99, 0.99, 1000, 0.12, 0.05, 2*np.pi, 2*np.pi, np.pi/2, 0.1])

# The stored strain carries noise with sigma 1e-22, so its spectral power is
# many orders of magnitude below the 1e-10 floor added before the log. In
# float64 the sum then equals the floor and every pixel becomes exactly -10.
# Rescaling the strain by a fixed factor lifts the power above the floor
# while keeping relative amplitudes between samples.
# NOTE(review): any inference-time preprocessing must apply the same factor.
STRAIN_SCALE = 1e21

X_list, y_list, class_list = [], [], []
signal_dir = "gw_dataset/signals"
meta_dir = "gw_dataset/metadata"
source_type_map = {"BBH": 0, "BNS": 1, "BHNS": 2}

valid_count = 0
for i in tqdm(range(N), desc="Procesando espectrogramas"):
    try:
        signal_path = os.path.join(signal_dir, f"hp_{i}.npy")
        meta_path = os.path.join(meta_dir, f"params_{i}.json")

        hp = np.load(signal_path)
        with open(meta_path, "r") as f:
            meta = json.load(f)

        f_, t_, Sxx = spectrogram(hp * STRAIN_SCALE, fs=8192, nperseg=128)
        Sxx = np.log10(Sxx + 1e-10)

        # Force an exact 64x64 map: crop first, then zero-pad what is missing.
        # NOTE(review): the crop keeps the first 64 time columns, i.e. the
        # start of the waveform rather than its end.
        Sxx = Sxx[:64, :64]
        pad_h = 64 - Sxx.shape[0]
        pad_w = 64 - Sxx.shape[1]
        Sxx = np.pad(Sxx, ((0, pad_h), (0, pad_w)), mode='constant', constant_values=0)

        targets = np.array([
            meta["m1"], meta["m2"], meta["spin1z"], meta["spin2z"],
            meta["distance"], meta["redshift"], meta["eccentricity"],
            meta["polarization"], meta["ra"], meta["dec"], meta["tc"]
        ])
        targets_norm = (targets - targets_min) / (targets_max - targets_min)
        class_id = source_type_map[meta["source_type"]]

    except Exception as e:
        # Best effort: samples whose files are missing or malformed are skipped.
        print(f"⚠️ Error en muestra {i}: {e}")
        continue

    # Append only once the whole sample has been processed, so the three
    # lists always stay the same length even when a lookup above fails.
    X_list.append(Sxx)
    y_list.append(targets_norm)
    class_list.append(class_id)
    valid_count += 1

print(f"✅ Señales válidas generadas: {valid_count}/{N}")

if not X_list:
    # np.stack on an empty list fails with an unhelpful message; be explicit.
    raise ValueError("No valid samples were loaded; cannot build training tensors.")

# Convert to tensors: X is (num_samples, 1, 64, 64), y is (num_samples, 11).
X = torch.tensor(np.stack(X_list)).unsqueeze(1).float()
y = torch.tensor(np.stack(y_list)).float()
labels = torch.tensor(class_list).long()


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class SurrogateCNN(nn.Module):
    """Three-stage convolutional network with a regression and a classification head.

    Each convolution is followed by ReLU and 2x2 max pooling. ``fc1`` expects
    128 * 6 * 6 flattened features, which is what a single-channel 64x64
    input yields after the three stages.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # initialisation and state-dict keys stay identical.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=128 * 6 * 6, out_features=256)
        self.fc_reg = nn.Linear(in_features=256, out_features=11)
        self.fc_cls = nn.Linear(in_features=256, out_features=3)

    def forward(self, x):
        """Return ``(regression_output, class_log_probabilities)`` for a batch ``x``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = self.pool(F.relu(conv(features)))

        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))

        # Log-probabilities, the form expected by a negative log-likelihood loss.
        class_log_probs = F.log_softmax(self.fc_cls(hidden), dim=1)
        return self.fc_reg(hidden), class_log_probs

In [None]:
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

# Network, data pipeline and optimisation objects
model = SurrogateCNN()
dataset = TensorDataset(X, y, labels)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

optimizer = optim.Adam(model.parameters(), lr=1e-5)
loss_reg = nn.MSELoss()
loss_cls = nn.NLLLoss()

# Weights of the regression and classification terms in the joint loss
alpha, beta = 1.0, 0.1

# Early-stopping state: training halts once the summed epoch loss has moved
# by less than `tolerance` for `patience_limit` consecutive epochs.
prev_loss, patience_counter = None, 0
patience_limit, tolerance = 5, 0.01
max_epochs = 10  # hard cap in case training is slow

for epoch in range(max_epochs):
    model.train()
    total_loss = 0
    for xb, yb, lb in loader:
        optimizer.zero_grad()
        pred_reg, pred_cls = model(xb)
        reg_term = loss_reg(pred_reg, yb)
        cls_term = loss_cls(pred_cls, lb)
        loss = alpha * reg_term + beta * cls_term
        loss.backward()
        optimizer.step()
        # Accumulates the per-batch losses (a sum, not a mean).
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Total Loss: {total_loss:.4f}")

    # Convergence check against the previous epoch
    stalled = prev_loss is not None and abs(total_loss - prev_loss) < tolerance
    patience_counter = patience_counter + 1 if stalled else 0

    if patience_counter >= patience_limit:
        print(f"✅ Convergencia alcanzada en la época {epoch+1}. Deteniendo entrenamiento.")
        break

    prev_loss = total_loss

# Persist the final weights
torch.save(model.state_dict(), "modelo_surrogado_weights.pt")


In [None]:
# Rebuild the network, load the saved weights on CPU and switch to eval mode
model = SurrogateCNN()
saved_state = torch.load("modelo_surrogado_weights.pt", map_location="cpu")
model.load_state_dict(saved_state)
model.eval()

# Forward pass on the first sample only, without tracking gradients
with torch.no_grad():
    pred_reg, pred_cls = model(X[0:1])

pred_reg = pred_reg.numpy()[0]
pred_cls = pred_cls.argmax(dim=1).item()

# Undo the min-max scaling to recover physical units
pred_reg = targets_min + pred_reg * (targets_max - targets_min)

# Physical validation: clamp selected entries to their allowed interval
clip_ranges = {
    6: (0, 1),                  # eccentricity
    9: (-np.pi/2, np.pi/2),     # declination
    10: (-0.1, 0.1),            # tc
}
for idx, (lower, upper) in clip_ranges.items():
    pred_reg[idx] = np.clip(pred_reg[idx], lower, upper)

param_names = [
    "m1", "m2", "spin1z", "spin2z", "distance", "redshift",
    "eccentricity", "polarization", "ra", "dec", "tc"
]
print("Predicción física:")
for name, val in zip(param_names, pred_reg):
    print(f"{name}: {val:.4f}")

# Map the predicted class index back to its label
tipo_fuente = ("BBH", "BNS", "BHNS")[pred_cls]
print("Tipo de fuente predicho:", tipo_fuente)


In [None]:
# Index of the sample the prediction was made on
sample_id = 0

# Ordered list of the regression targets to compare
campos = ["m1", "m2", "spin1z", "spin2z", "distance", "redshift", "eccentricity", "polarization", "ra", "dec", "tc"]

# Read the ground-truth parameters from the matching metadata file
with open(f"gw_dataset/metadata/params_{sample_id}.json") as f:
    valores_reales = json.load(f)
valores_reales_vec = np.array([valores_reales[campo] for campo in campos])

# Absolute and relative errors; the small constant guards against division by zero
errores_abs = np.abs(pred_reg - valores_reales_vec)
errores_rel = errores_abs / (np.abs(valores_reales_vec) + 1e-8)

# Print one table row per parameter
print("\n📊 Comparación modelo vs. real:")
print(f"{'Parámetro':<15} {'Predicho':>10} {'Real':>10} {'Error abs.':>12} {'Error rel.':>12}")
filas = zip(campos, pred_reg, valores_reales_vec, errores_abs, errores_rel)
for nombre, predicho, real, err_abs, err_rel in filas:
    print(f"{nombre:<15} {predicho:10.4f} {real:10.4f} {err_abs:12.4f} {err_rel*100:11.2f}%")


MAE medio (regresión): 86234352.000
Accuracy tipo de sistema: 0.200
              precision    recall  f1-score   support

         BNS       0.00      0.00      0.00        10
        BHNS       0.00      0.00      0.00         6
         BBH       0.20      1.00      0.33         4

    accuracy                           0.20        20
   macro avg       0.07      0.33      0.11        20
weighted avg       0.04      0.20      0.07        20



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
