In [3]:
import os, sys
import numpy as np

# Make the project root importable so that `src/` can be imported from a
# notebook whose working directory is notebooks/. The root is taken to be the
# parent of the current working directory.
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if ROOT_DIR not in sys.path:
    # Appended (not inserted at the front) so already-listed entries keep priority.
    sys.path.append(ROOT_DIR)

from src.generate_signals import generate_baseband

# -------------------------
# Dataset parameters
# -------------------------
# Modulation schemes to generate. The position of each name in this list is
# used as its integer class label further down, so the order matters.
modulations = [
    "ASK2", "ASK4",          # ASK family
    "BPSK", "QPSK", "PSK8",  # PSK family
    "QAM16", "QAM32",        # QAM family
]

# Candidate SNR values (dB); one is drawn at random for every example.
snr_values = [0, 5, 10, 15, 20]

# Symbols per example (this is the input length seen by the CNN).
Nsym = 256

# Number of examples generated for each modulation.
N_per_mod = 1000

# Train/validation split: fraction of the shuffled dataset used for training.
train_ratio = 0.8

# -------------------------
# Generation
# -------------------------
# Integer class label for every modulation name, e.g. {"ASK2": 0, ...}.
label_map = dict(zip(modulations, range(len(modulations))))

# Per-example accumulators: I/Q tensors, class labels and the SNR that was used.
X_list, y_list, snr_list = [], [], []

for mod in modulations:
    label = label_map[mod]
    print(f"Generando {N_per_mod} ejemplos de {mod}...")

    for _ in range(N_per_mod):
        # Draw the SNR first and then synthesize the signal; this order is kept
        # so the sequence of global-RNG calls is unchanged.
        snr_db = np.random.choice(snr_values)
        out = generate_baseband(mod, Nsym=Nsym, snr_db=snr_db)

        # Noisy baseband signal; expected to be a complex vector of length
        # Nsym -- NOTE(review): confirm against src.generate_signals.
        s = out["s_noisy"]

        # Split into real (I) and imaginary (Q) rows -> shape (2, Nsym), float32.
        iq = np.stack((np.real(s), np.imag(s))).astype(np.float32)

        X_list.append(iq)
        y_list.append(label)
        snr_list.append(snr_db)

# Collapse the Python lists into arrays; examples are stacked along axis 0.
X = np.stack(X_list)                            # (N_total, 2, Nsym)
y = np.asarray(y_list, dtype=np.int64)          # (N_total,)
snr_arr = np.asarray(snr_list, dtype=np.int64)  # (N_total,)

print("X shape:", X.shape)
print("y shape:", y.shape)

# -------------------------
# Shuffle and split into train/val
# -------------------------
N_total = len(X)

# One shared permutation keeps signals, labels and SNRs aligned after shuffling.
idx = np.random.permutation(N_total)
X, y, snr_arr = X[idx], y[idx], snr_arr[idx]

# The first N_train shuffled examples are the training set; the rest validate.
N_train = int(train_ratio * N_total)
train_part = slice(None, N_train)
val_part = slice(N_train, None)

X_train, y_train, snr_train = X[train_part], y[train_part], snr_arr[train_part]
X_val, y_val, snr_val = X[val_part], y[val_part], snr_arr[val_part]

print("Train:", X_train.shape, y_train.shape)
print("Val  :", X_val.shape, y_val.shape)

# -------------------------
# Save to data/
# -------------------------
save_path = os.path.join(ROOT_DIR, "data", "iq_dataset.npz")

# np.savez does not create missing directories. On a fresh checkout without
# data/ the write would fail with FileNotFoundError after the whole dataset
# has been generated, so make sure the target directory exists first.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# A single .npz archive holds both splits plus the modulation names, so class
# indices can be mapped back to names when the file is loaded.
np.savez(
    save_path,
    X_train=X_train,
    y_train=y_train,
    snr_train=snr_train,
    X_val=X_val,
    y_val=y_val,
    snr_val=snr_val,
    modulations=np.array(modulations),
)

print("Dataset guardado en:", save_path)


Generando 1000 ejemplos de ASK2...
Generando 1000 ejemplos de ASK4...
Generando 1000 ejemplos de BPSK...
Generando 1000 ejemplos de QPSK...
Generando 1000 ejemplos de PSK8...
Generando 1000 ejemplos de QAM16...
Generando 1000 ejemplos de QAM32...
X shape: (7000, 2, 256)
y shape: (7000,)
Train: (5600, 2, 256) (5600,)
Val  : (1400, 2, 256) (1400,)
Dataset guardado en: c:\Users\walte\Documents\ProyectoCom2\data\iq_dataset.npz
