<a href="https://colab.research.google.com/github/srijani19/Testform1/blob/main/DeepFAVIB_QPSK.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Uniform QPSK, N=8 (No WGAN)

#!/usr/bin/env python

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Use TF1-style graph execution (disable eager mode in TF2).
# This cell builds its model from tf.compat.v1 placeholders and evaluates it
# through a tf.compat.v1.Session, both of which need graph mode.
tf.compat.v1.disable_eager_execution()

# ——— Hyperparameters ———
# SNR grid (dB): models are trained at TRAIN_SNRs and swept over TEST_SNRs.
TRAIN_SNRs = [0, 4, 10, 13]
TEST_SNRs = np.linspace(0, 15, 16)  # 16 evenly spaced points, 0→15 dB

# Dataset sizes (same number of samples for training and testing).
N_train = 100_000
N_test = N_train

# Network shape.
N_z = 8                         # number of latent categories
hidden_dims = [300, 200, 100]   # hidden widths, reused for encoder and decoder

# Loss / channel knobs.
λ = 0.05    # weight of the KL term in the loss
τ = 1.0     # Gumbel-softmax temperature
ε = 0.01    # forward-link error probability

# Optimisation.
batch_size = 2000   # samples per mini-batch
epochs = 1000       # optimisation steps per training SNR
lr = 1e-3           # Adam learning rate

# QPSK constellation: the four (±1, ±1) corners, listed counter-clockwise
# from the first quadrant and scaled by 1/sqrt(2) so every point has unit power.
const = np.array(
    [
        [1, 1],
        [-1, 1],
        [-1, -1],
        [1, -1],
    ],
    np.float32,
) / np.sqrt(2)

# ---------------- Dense Layer Helper ----------------
def dense(x, in_dim, out_dim, name, act=tf.nn.relu):
    """Create a fully connected layer and apply it to ``x``.

    Every call creates two new variables, ``W_<name>`` (normal initialisation
    with standard deviation ``1/sqrt(in_dim/2)``) and ``b_<name>`` (zeros).

    Args:
        x: 2-D input tensor whose last dimension is ``in_dim``.
        in_dim: Number of input features.
        out_dim: Number of output features.
        name: Suffix used in the variable names.
        act: Activation applied to ``x @ W + b``; a falsy value (e.g. ``None``)
            returns the affine output unchanged.

    Returns:
        The layer output tensor.
    """
    init_std = 1/np.sqrt(in_dim/2)
    weights = tf.Variable(tf.random.normal([in_dim, out_dim], stddev=init_std), name='W_'+name)
    bias = tf.Variable(tf.zeros([out_dim]), name='b_'+name)
    pre_act = tf.matmul(x, weights) + bias
    if act:
        return act(pre_act)
    return pre_act

# ---------------- Build Encoder/Decoder Graph ----------------
def build_model():
    """Build the encoder / categorical bottleneck / decoder graph.

    The default graph is reset first, so any previously built model is
    discarded.  The encoder maps a noisy I/Q sample to ``N_z`` logits, a
    Gumbel-softmax sample (training) or hard arg-max one-hot (evaluation) is
    passed through a simulated forward link and decoded into 4 symbol logits.

    Returns:
        Tuple ``(x_ph, y_ph, is_tr, logits_x, loss, train_op)`` where
        ``x_ph`` is the ``[None, 2]`` float input placeholder, ``y_ph`` the
        ``[None]`` int32 label placeholder, ``is_tr`` a boolean placeholder
        defaulting to ``True``, ``logits_x`` the decoder output, ``loss`` the
        scalar ``CE + λ·KL`` and ``train_op`` one Adam step on that loss.
    """
    tf.compat.v1.reset_default_graph()
    x_ph = tf.compat.v1.placeholder(tf.float32, [None, 2], name='x')
    y_ph = tf.compat.v1.placeholder(tf.int32,   [None],    name='y')
    is_tr = tf.compat.v1.placeholder_with_default(True, shape=(), name='is_tr')

    # Encoder
    h = x_ph
    for i, d in enumerate(hidden_dims):
        h = dense(h, int(h.shape[1]), d, f'enc{i}')
    logits_z = dense(h, hidden_dims[-1], N_z, 'logitz', act=None)

    # Gumbel-softmax: soft relaxed sample while training, hard one-hot otherwise.
    u = tf.random.uniform(tf.shape(logits_z))
    g = -tf.math.log(-tf.math.log(u + 1e-20) + 1e-20)
    z_soft = tf.nn.softmax((logits_z + g) / τ)
    z_hard = tf.one_hot(tf.argmax(logits_z, 1), N_z)
    z = tf.cond(is_tr, lambda: z_soft, lambda: z_hard)

    # Forward-link errors: with probability ε a sample's latent code is
    # replaced as a whole by a uniformly drawn one-hot code.  The decision is
    # made once per sample (not once per coordinate); mixing coordinates of two
    # codes would produce vectors that are not valid symbols.
    n_rows = tf.shape(z)[0]
    flip_row = tf.random.uniform([n_rows, 1]) < ε
    mask = tf.tile(flip_row, [1, N_z])
    rand_z = tf.one_hot(tf.random.uniform([n_rows], 0, N_z, tf.int32), N_z)
    z_noisy = tf.where(mask, rand_z, z)

    # Decoder
    h2 = z_noisy
    for i, d in enumerate(hidden_dims):
        h2 = dense(h2, int(h2.shape[1]), d, f'dec{i}')
    logits_x = dense(h2, hidden_dims[-1], 4, 'logits_x', act=None)

    # Loss: symbol cross-entropy plus λ times KL(q(z|x) || uniform prior).
    ce = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_ph, logits=logits_x)
    )
    qz = tf.nn.softmax(logits_z)
    kl = tf.reduce_mean(
        tf.reduce_sum(qz * (tf.math.log(qz + 1e-20) - tf.math.log(1.0 / N_z)), axis=1)
    )
    loss = ce + λ * kl
    train_op = tf.compat.v1.train.AdamOptimizer(lr).minimize(loss)

    return x_ph, y_ph, is_tr, logits_x, loss, train_op

# ---------------- Main Training & Evaluation ----------------
def main():
    """Train at each SNR in ``TRAIN_SNRs``, sweep ``TEST_SNRs`` and plot SER.

    For every training SNR a fresh noisy QPSK data set is drawn, ``epochs``
    Adam steps are run on random mini-batches, and the symbol error rate is
    then measured over the test SNR grid.  The curves are drawn on a
    semilog-y plot.
    """
    x_ph, y_ph, is_tr, logits_x, loss, train_op = build_model()
    # Build the prediction op once; creating it inside the evaluation loop
    # would add a new node to the graph for every evaluated batch.
    pred_op = tf.argmax(logits_x, 1)
    results = {}

    # The context manager closes the session (and frees its resources) on exit.
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        # NOTE(review): variables are initialised only once, so each training
        # SNR continues from the weights left by the previous one — confirm
        # this warm start is intended for the "train @ X dB" curves.
        for snr in TRAIN_SNRs:
            print(f"\n=== TRAIN @ {snr} dB ===")
            sigma = np.sqrt(1/(2*10**(snr/10)))
            tx = np.random.randint(0, 4, N_train)
            x_train = const[tx] + sigma*np.random.randn(N_train, 2)

            # Training loop
            for ep in range(epochs):
                idx = np.random.choice(N_train, batch_size, replace=False)
                feed = {x_ph: x_train[idx], y_ph: tx[idx], is_tr: True}
                sess.run(train_op, feed)
                if ep % 200 == 0:
                    lval = sess.run(loss, feed)
                    print(f"  epoch {ep:4d}, loss {lval:.4f}")

            # Evaluation
            ser = []
            for s in TEST_SNRs:
                sigte = np.sqrt(1/(2*10**(s/10)))
                tx_te = np.random.randint(0, 4, N_test)
                x_te = const[tx_te] + sigte*np.random.randn(N_test, 2)
                preds = np.concatenate([
                    sess.run(pred_op, {x_ph: x_te[i:i+batch_size], is_tr: False})
                    for i in range(0, N_test, batch_size)
                ])
                ser.append(np.mean(preds != tx_te))
            results[snr] = ser

    plt.figure(figsize=(8,6))
    for snr, ser in results.items():
        plt.semilogy(TEST_SNRs, ser, 'o-', label=f"train @ {snr} dB")
    plt.grid(True, which='both', ls='--', alpha=0.5)
    plt.xlabel("Test SNR (dB)")
    plt.ylabel("SER")
    plt.title("Fig 4(b) – Deep FAVIB (ε=0.01)")
    plt.xlim(0,17.5)
    plt.ylim(1e-5,1)
    plt.legend()
    plt.tight_layout()
    plt.show()

if __name__=="__main__":
    main()


KeyboardInterrupt: 

In [None]:
#Uniform QPSK, N=8 (With WGAN-GP)

#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Use TF1 graph mode (for TF2 environments).
# The model below is built from tf.compat.v1 placeholders and executed through
# a tf.compat.v1.Session, which requires eager execution to be off.
tf.compat.v1.disable_eager_execution()

# ======================== Parameters =========================
# SNR grid (dB): training points and the evaluation sweep (16 points, 0→15 dB).
TRAIN_SNRs = [0, 4, 10, 13]
TEST_SNRs = np.linspace(0, 15, 16)

# Dataset sizes.
N_train = int(1e5)
N_test = int(1e5)

# Network shape.
N_z = 8                   # number of latent categories
hidden_dims = [100, 50]   # hidden widths, reused for encoder and decoder

# Loss weights.
lambda_kl = 0.01      # weight of the KL term in the generator loss
beta_gp = 1.0         # weight of the gradient penalty in the critic loss
g_loss_scale = 0.01   # weight of the adversarial term in the generator loss

# Training knobs.
batch_size = 2000   # samples per mini-batch
epochs = 1000       # optimisation steps per training SNR
tau = 1.0           # Gumbel-softmax temperature
epsilon = 0.01      # forward-channel latent error probability
lr = 1e-4           # Adam learning rate (critic and generator)

# QPSK constellation: the four (±1, ±1) corners scaled to unit power.
const = np.array(
    [[1, 1], [-1, 1], [-1, -1], [1, -1]],
    dtype=np.float32,
) / np.sqrt(2)

# =================== Dense Layer Helper ===================
def dense(input, in_dim, out_dim, name, activation=tf.nn.relu):
    """Create a fully connected layer and apply it to ``input``.

    Each call creates new variables ``W_<name>`` (normal initialisation with
    standard deviation ``1/sqrt(in_dim/2)``) and ``b_<name>`` (zeros).

    Args:
        input: 2-D tensor whose last dimension is ``in_dim``.
        in_dim: Number of input features.
        out_dim: Number of output features.
        name: Suffix used in the variable names.
        activation: Callable applied to the affine output; a falsy value
            (e.g. ``None``) returns the affine output unchanged.

    Returns:
        The layer output tensor.
    """
    weight_init = tf.random.normal([in_dim, out_dim], stddev=1/np.sqrt(in_dim/2))
    kernel = tf.Variable(weight_init, name=f"W_{name}")
    offset = tf.Variable(tf.zeros([out_dim]), name=f"b_{name}")
    affine = tf.matmul(input, kernel) + offset
    if not activation:
        return affine
    return activation(affine)

# ========================= Model Graph ==========================
def build_model():
    """Build the Deep FAVIB graph with a WGAN-GP critic on the latent code.

    The default graph is reset first.  The encoder maps noisy I/Q samples to
    ``N_z`` logits; a Gumbel-softmax sample (training) or hard one-hot
    (evaluation) passes through a simulated forward channel and is decoded to
    4 symbol logits.  A small critic scores latent codes: uniformly drawn
    one-hot codes are "real", the encoder's soft samples are "fake".

    Returns:
        Tuple ``(x, y, is_training, logits_x, loss, train_op_g, train_op_d)``:
        the input / label / mode placeholders, the decoder logits, the
        generator loss ``CE + lambda_kl*KL + g_loss_scale*g_loss``, one Adam
        step on that loss over the encoder/decoder variables, and one Adam
        step on the critic loss over the critic variables.
    """
    tf.compat.v1.reset_default_graph()
    x = tf.compat.v1.placeholder(tf.float32, [None, 2])   # noisy I/Q input
    y = tf.compat.v1.placeholder(tf.int32,   [None])      # true QPSK index
    is_training = tf.compat.v1.placeholder_with_default(True, shape=())

    # Dynamic batch size, so the graph does not depend on the fed batch
    # having exactly `batch_size` rows.
    n_rows = tf.shape(x)[0]

    # ----- Encoder -----
    h = x
    for i, dim in enumerate(hidden_dims):
        h = dense(h, int(h.shape[1]), dim, f"enc{i}")
    logits_z = dense(h, int(h.shape[1]), N_z, "logitz", activation=None)

    # ----- Gumbel-Softmax -----
    # Both logs are guarded with a small constant, like the other cells.
    u = tf.random.uniform(tf.shape(logits_z))
    g = -tf.math.log(-tf.math.log(u + 1e-20) + 1e-20)
    z_soft = tf.nn.softmax((logits_z + g) / tau)
    z_hard = tf.one_hot(tf.argmax(logits_z, axis=1), N_z)
    z = tf.cond(is_training, lambda: z_soft, lambda: z_hard)

    # ----- Forward channel latent noise -----
    # One decision per sample: with probability epsilon the whole latent code
    # is replaced by a uniformly drawn one-hot code.  A per-coordinate mask
    # would splice two codes together and produce invalid symbols.
    flip_row = tf.random.uniform([n_rows, 1]) < epsilon
    noise_mask = tf.tile(flip_row, [1, N_z])
    z_rand = tf.one_hot(
        tf.random.uniform([n_rows], 0, N_z, dtype=tf.int32),
        N_z
    )
    z_noisy = tf.where(noise_mask, z_rand, z)

    # ----- Decoder -----
    h_dec = z_noisy
    for i, dim in enumerate(hidden_dims):
        h_dec = dense(h_dec, int(h_dec.shape[1]), dim, f"dec{i}")
    logits_x = dense(h_dec, int(h_dec.shape[1]), 4, "logits_x", activation=None)

    # Every trainable variable created so far belongs to the encoder/decoder
    # (including the decoder output layer).  Snapshot them now instead of
    # filtering by name substrings, which is easy to get wrong.
    g_vars = list(tf.compat.v1.trainable_variables())

    # =================== WGAN-GP Critic ===================
    # The critic weights are created exactly once and shared by every critic
    # call below; creating them inside critic() would give the real, fake and
    # interpolated inputs three unrelated networks.
    critic_width = 50
    W_d1 = tf.Variable(
        tf.random.normal([N_z, critic_width], stddev=1/np.sqrt(N_z/2)),
        name="W_d1"
    )
    b_d1 = tf.Variable(tf.zeros([critic_width]), name="b_d1")
    W_d2 = tf.Variable(
        tf.random.normal([critic_width, 1], stddev=1/np.sqrt(critic_width/2)),
        name="W_d2"
    )
    b_d2 = tf.Variable(tf.zeros([1]), name="b_d2")
    d_vars = [W_d1, b_d1, W_d2, b_d2]

    def critic(z_in):
        hidden = tf.nn.relu(tf.matmul(z_in, W_d1) + b_d1)
        return tf.matmul(hidden, W_d2) + b_d2

    z_real = tf.one_hot(
        tf.random.uniform([n_rows], 0, N_z, dtype=tf.int32),
        N_z
    )
    D_real = critic(z_real)
    D_fake = critic(z_soft)

    # Gradient penalty on random interpolations between real and fake codes.
    alpha     = tf.random.uniform([n_rows, 1])
    interp    = alpha * z_real + (1 - alpha) * z_soft
    D_interp  = critic(interp)
    grads     = tf.gradients(D_interp, [interp])[0]
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-8)
    gp        = tf.reduce_mean(tf.square(grad_norm - 1.0))

    # ======================== Losses ============================
    ce = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits_x)
    )
    q_z = tf.nn.softmax(logits_z)
    kl  = tf.reduce_mean(
        tf.reduce_sum(q_z * (tf.math.log(q_z + 1e-20) - tf.math.log(1.0 / N_z)), axis=1)
    )
    d_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real) + beta_gp * gp
    g_loss = -tf.reduce_mean(D_fake)

    loss = ce + lambda_kl * kl + g_loss_scale * g_loss

    # Each optimizer only touches its own variable set.
    train_op_d = tf.compat.v1.train.AdamOptimizer(lr).minimize(d_loss, var_list=d_vars)
    train_op_g = tf.compat.v1.train.AdamOptimizer(lr).minimize(loss,   var_list=g_vars)

    return x, y, is_training, logits_x, loss, train_op_g, train_op_d

# ===================== Training & Evaluation =====================
def run():
    """Train at each SNR in ``TRAIN_SNRs`` and measure SER over ``TEST_SNRs``.

    For every training SNR a fresh noisy QPSK data set is drawn and
    ``epochs`` steps are run; the critic is updated on every fifth step and
    the generator on every step.

    Returns:
        Dict mapping each training SNR to a list of symbol error rates, one
        per entry of ``TEST_SNRs``.
    """
    x_ph, y_ph, is_tr, logits_x, loss, train_op_g, train_op_d = build_model()
    # Build the prediction op once instead of adding a new argmax node to the
    # graph for every evaluated batch.
    pred_op = tf.argmax(logits_x, axis=1)
    results = {}

    # The context manager closes the session on exit.
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        for snr in TRAIN_SNRs:
            print(f"\nTraining @ {snr} dB")
            σ = np.sqrt(1 / (2 * 10**(snr / 10)))
            tx = np.random.randint(0, 4, N_train)
            x_train = const[tx] + σ * np.random.randn(N_train, 2)

            for ep in range(epochs):
                # Sample without replacement, as the other cells do.
                idx = np.random.choice(N_train, batch_size, replace=False)
                feed = {x_ph: x_train[idx], y_ph: tx[idx], is_tr: True}
                if ep % 5 == 0:
                    sess.run(train_op_d, feed)
                sess.run(train_op_g, feed)

                if ep % 200 == 0:
                    lv = sess.run(loss, feed)
                    print(f"Epoch {ep}, Loss = {lv:.4f}")

            # ----- Evaluation -----
            ser = []
            for snr_te in TEST_SNRs:
                σ_te = np.sqrt(1 / (2 * 10**(snr_te / 10)))
                tx_te = np.random.randint(0, 4, N_test)
                x_te  = const[tx_te] + σ_te * np.random.randn(N_test, 2)

                preds = np.concatenate([
                    sess.run(pred_op, {x_ph: x_te[i:i + batch_size], is_tr: False})
                    for i in range(0, N_test, batch_size)
                ])
                ser.append(np.mean(preds != tx_te))

            results[snr] = ser

    return results

# ============================ Run ===============================
# Train/evaluate; `results` maps each training SNR to its SER curve.
results = run()

# ============================ Plot ==============================
# One semilog-y curve per training SNR, drawn through the Axes API.
fig, ax = plt.subplots(figsize=(8, 6))
for snr, ser in results.items():
    ax.semilogy(TEST_SNRs, ser, marker='o', label=f"Train @ {snr} dB")
ax.set_title("Figure 4(b) – Deep FAVIB + Balanced WGAN-GP")
ax.set_xlabel("Test SNR (dB)")
ax.set_ylabel("SER")
ax.grid(True, which='both', ls='--', alpha=0.5)
ax.legend()
ax.set_ylim(1e-4, 1)
plt.tight_layout()
plt.show()


In [None]:
#Non-Uniform QPSK, N=8 (No WGAN)

#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# ---------------- TF Setup ----------------
# Graph mode is required by the placeholders / Session used below.
tf.compat.v1.disable_eager_execution()   # run in TF1-style graph mode
np.random.seed(0)                        # numpy reproducibility
# NOTE(review): the graph-level seed is attached to the default graph that
# exists at this point; build_model() calls reset_default_graph(), so confirm
# the seed is still in effect for the graph that is actually trained.
tf.compat.v1.set_random_seed(0)          # tensorflow reproducibility

# ---------------- Experiment Parameters ----------------
# SNR grid (dB): training points and the 16-point evaluation sweep.
TRAIN_SNRs = [0, 4, 10, 13]
TEST_SNRs = np.linspace(0, 15, 16)

# Dataset sizes (identical for training and testing).
N_train = 100_000
N_test = N_train

# Network shape and loss knobs.
N_z = 8                   # number of latent categories
hidden_dims = [100, 50]   # hidden widths, reused for encoder and decoder
lambda_kl = 0.1           # weight of the KL(q||r) term
tau_start = 1.0           # Gumbel-softmax temperature at the first step ...
tau_end = 0.7             # ... and at the last step
epsilon = 0.01            # forward-link latent error probability

# Optimisation.
batch_size = 2000   # samples per mini-batch
epochs = 1000       # optimisation steps per training SNR

# Learning rate schedule end points (start -> end).
lr_g_start = 3e-4
lr_g_end = 1e-4

# ---------------- Non-uniform priors ----------------
# Prior r(z) over the N_z latent categories, renormalised in place (float32).
latent_probs = np.array(
    [0.30, 0.20, 0.15, 0.10, 0.08, 0.07, 0.06, 0.04],
    np.float32,
)
latent_probs /= latent_probs.sum()

# QPSK constellation: the four (±1, ±1) corners scaled to unit power.
const = np.array(
    [[1, 1], [-1, 1], [-1, -1], [1, -1]],
    np.float32,
) / np.sqrt(2.0)

M = 4   # number of constellation points
# Non-uniform source probabilities, indexed like `const`.
p_qpsk = np.array([0.50, 0.20, 0.20, 0.10], np.float32)

# ---------------- Dense Layer ----------------
def dense(x, in_dim, out_dim, name, act=tf.nn.relu):
    """Create a fully connected layer and apply it to ``x``.

    Each call creates new variables ``W_<name>`` (normal initialisation with
    standard deviation ``1/sqrt(max(in_dim/2, 1))``) and ``b_<name>`` (zeros).

    Args:
        x: 2-D input tensor whose last dimension is ``in_dim``.
        in_dim: Number of input features.
        out_dim: Number of output features.
        name: Suffix used in the variable names.
        act: Activation applied to ``x @ W + b``; a falsy value (e.g. ``None``)
            returns the affine output unchanged.

    Returns:
        The layer output tensor.
    """
    init_std = 1/np.sqrt(max(in_dim/2, 1))
    kernel = tf.Variable(tf.random.normal([in_dim, out_dim], stddev=init_std), name=f"W_{name}")
    bias = tf.Variable(tf.zeros([out_dim]), name=f"b_{name}")
    affine = tf.matmul(x, kernel) + bias
    if act:
        return act(affine)
    return affine

# ---------------- Build Model ----------------
def build_model():
    """Build the encoder / categorical bottleneck / decoder graph.

    The default graph is reset first.  The encoder maps noisy I/Q samples to
    ``N_z`` logits; a Gumbel-softmax sample is used while training and a hard
    one-hot at evaluation.  Forward-channel latent errors are applied only
    when ``is_tr`` is ``False``.

    Returns:
        Tuple ``(x, y, is_tr, tau_ph, lr_ph, logits_x, ce, kl_q_r, loss,
        train_op, perplexity)``: input / label / mode placeholders, the
        temperature and learning-rate placeholders (with defaults
        ``tau_start`` / ``lr_g_start``), decoder logits, cross-entropy,
        KL(q||r), total loss ``ce + lambda_kl*kl_q_r``, one Adam step on that
        loss, and ``exp`` of the mean entropy of q(z|x).
    """
    tf.compat.v1.reset_default_graph()
    # The graph-level seed belongs to a graph object; the reset above created
    # a new default graph, so the seed has to be applied again to keep
    # TensorFlow's random ops reproducible.
    tf.compat.v1.set_random_seed(0)

    x = tf.compat.v1.placeholder(tf.float32, [None, 2], name="x")
    y = tf.compat.v1.placeholder(tf.int32,   [None],   name="y")
    is_tr  = tf.compat.v1.placeholder_with_default(True, shape=(), name="is_tr")
    tau_ph = tf.compat.v1.placeholder_with_default(tau_start, shape=(), name="tau")
    lr_ph  = tf.compat.v1.placeholder_with_default(lr_g_start, shape=(), name="lr")

    prior = tf.constant(latent_probs, tf.float32)

    # Encoder
    h = x
    for i, d in enumerate(hidden_dims):
        h = dense(h, int(h.shape[1]), d, f"enc{i}")
    logits_z = dense(h, int(h.shape[1]), N_z, "logitz", act=None)

    # Gumbel–Softmax
    u = tf.random.uniform(tf.shape(logits_z))
    g = -tf.math.log(-tf.math.log(u + 1e-20) + 1e-20)
    z_soft = tf.nn.softmax((logits_z + g) / tau_ph)
    z_hard = tf.one_hot(tf.argmax(logits_z, axis=1), N_z)
    z_clean = tf.cond(is_tr, lambda: z_soft, lambda: z_hard)

    # Forward channel latent flips (evaluation only).
    # One decision per sample: with probability epsilon the whole code is
    # replaced by a code drawn from the prior.
    n_rows = tf.shape(z_clean)[0]
    flip_row = tf.random.uniform([n_rows, 1]) < epsilon
    noise_mask = tf.tile(flip_row, [1, N_z])
    lat_logp = tf.math.log(prior[None, :] + 1e-20)
    # categorical() returns [rows_of_logits, num_samples] == [1, n_rows];
    # row 0 holds the n_rows independent draws.
    z_rand_idx = tf.random.categorical(lat_logp, n_rows)[0]
    z_rand = tf.one_hot(z_rand_idx, N_z)
    z_used = tf.cond(is_tr, lambda: z_clean, lambda: tf.where(noise_mask, z_rand, z_clean))

    # Decoder
    h2 = z_used
    for i, d in enumerate(hidden_dims):
        h2 = dense(h2, int(h2.shape[1]), d, f"dec{i}")
    logits_x = dense(h2, int(h2.shape[1]), M, "logits_x", act=None)

    # Loss: CE + KL(q||r)
    ce = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits_x))
    qz = tf.nn.softmax(logits_z)
    kl_q_r = tf.reduce_mean(
        tf.reduce_sum(qz * (tf.math.log(qz + 1e-20) - tf.math.log(prior + 1e-20)), axis=1)
    )
    loss = ce + lambda_kl * kl_q_r
    train_op = tf.compat.v1.train.AdamOptimizer(lr_ph, beta1=0.5, beta2=0.9).minimize(loss)

    # Diagnostic: exp of the mean per-sample entropy of q(z|x).
    entropy_q = -tf.reduce_mean(tf.reduce_sum(qz * tf.math.log(qz + 1e-20), axis=1))
    perplexity = tf.exp(entropy_q)

    return x, y, is_tr, tau_ph, lr_ph, logits_x, ce, kl_q_r, loss, train_op, perplexity

# ---------------- Training & Evaluation ----------------
def run():
    """Train at each SNR in ``TRAIN_SNRs`` and measure SER over ``TEST_SNRs``.

    During training the learning rate follows a cosine decay from
    ``lr_g_start`` to ``lr_g_end`` and the Gumbel temperature is interpolated
    linearly between ``tau_start`` and ``tau_end``.

    Returns:
        Dict mapping each training SNR to a list of symbol error rates, one
        per entry of ``TEST_SNRs``.
    """
    (x_ph, y_ph, is_tr, tau_ph, lr_ph,
     logits_x, ce_t, kl_t, loss_t, train_op, perp_t) = build_model()
    # Build the prediction op once instead of adding a new argmax node to the
    # graph for every evaluated batch.
    pred_op = tf.argmax(logits_x, axis=1)
    results = {}

    # The context manager closes the session on exit.
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        for snr in TRAIN_SNRs:
            print(f"\n=== TRAIN @ {snr} dB (NO-GAN; params match WGAN) ===")
            sigma = np.sqrt(1.0 / (2.0 * 10.0**(snr / 10.0)))
            tx = np.random.choice(M, size=N_train, p=p_qpsk)
            x_train = const[tx] + sigma * np.random.randn(N_train, 2).astype(np.float32)

            for ep in range(epochs):
                frac = ep / max(epochs - 1, 1)
                lr_now  = lr_g_end + 0.5 * (lr_g_start - lr_g_end) * (1 + np.cos(np.pi * frac))
                tau_now = tau_end +      (tau_start - tau_end) * (1 - frac)

                idx = np.random.choice(N_train, batch_size, replace=False)
                xb, yb = x_train[idx], tx[idx]
                sess.run(train_op, {x_ph: xb, y_ph: yb, is_tr: True, tau_ph: tau_now, lr_ph: lr_now})

                if ep % 200 == 0:
                    ce_v, kl_v, loss_v, perp_v = sess.run(
                        [ce_t, kl_t, loss_t, perp_t],
                        {x_ph: xb, y_ph: yb, is_tr: True, tau_ph: tau_now}
                    )
                    print(f"  ep {ep:4d} | CE {ce_v:.4f}  KL(q||r) {kl_v:.4f}  "
                          f"Tot {loss_v:.4f}  Perp~{perp_v:.2f}")

            # SER evaluation
            ser = []
            for s in TEST_SNRs:
                sigma_te = np.sqrt(1.0 / (2.0 * 10.0**(s / 10.0)))
                tx_te = np.random.choice(M, size=N_test, p=p_qpsk)
                x_te  = const[tx_te] + sigma_te * np.random.randn(N_test, 2).astype(np.float32)

                preds = np.concatenate([
                    sess.run(pred_op, {x_ph: x_te[i:i+batch_size], is_tr: False})
                    for i in range(0, N_test, batch_size)
                ])
                ser.append(np.mean(preds != tx_te))
            results[snr] = ser

    return results

# ---------------- Run & Plot ----------------
if __name__ == "__main__":
    # `results` maps each training SNR to its SER curve over TEST_SNRs.
    results = run()

    # One semilog-y curve per training SNR, drawn through the Axes API.
    fig, ax = plt.subplots(figsize=(8, 6))
    for snr, ser in results.items():
        ax.semilogy(TEST_SNRs, ser, marker='o', label=f"Train @ {snr} dB")
    ax.set_title("Deep FAVIB (NO GAN) — Non-Uniform QPSK")
    ax.set_xlabel("Test SNR (dB)")
    ax.set_ylabel("SER")
    ax.grid(True, which='both', ls='--', alpha=0.5)
    ax.legend()
    ax.set_ylim(1e-4, 1)
    plt.tight_layout()
    plt.show()


In [None]:
#Non-Uniform QPSK, N=8 (With WGAN-GP)
#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# ----- TF1 graph mode (compatible inside TF2) -----
# Required by the placeholders / Session used below.
tf.compat.v1.disable_eager_execution()

# ----- Reproducibility -----
np.random.seed(0)
# NOTE(review): the graph-level seed is attached to the default graph that
# exists at this point; build_model() calls reset_default_graph(), so confirm
# the seed is still in effect for the graph that is actually trained.
tf.compat.v1.set_random_seed(0)

# ---------------- Experiment Parameters ----------------
# SNR grid (dB): training points and the 16-point evaluation sweep.
TRAIN_SNRs = [0, 4, 10, 13]
TEST_SNRs = np.linspace(0, 15, 16)

# Dataset sizes (identical for training and testing).
N_train = 100_000
N_test = N_train

# Network shape and loss knobs.
N_z = 8                   # number of latent categories
hidden_dims = [100, 50]   # hidden widths, reused for encoder and decoder
lambda_kl = 0.1           # weight of the KL(q||r) term
tau_start = 1.0           # Gumbel-softmax temperature at the first step ...
tau_end = 0.7             # ... and at the last step
epsilon = 0.01            # latent flip probability (test-time forward noise)

# Optimisation.
batch_size = 2000   # samples per mini-batch
epochs = 1000       # optimisation steps per training SNR

# Learning rate schedule end points (start -> end) for generator and critic.
lr_g_start = 3e-4
lr_g_end = 1e-4
lr_d_start = 3e-4
lr_d_end = 1e-4

# WGAN-GP specific params.
critic_steps = 3      # critic updates per generator update
beta_gp = 10.0        # gradient penalty weight
g_loss_scale = 0.01   # weight of the adversarial term in the generator loss

# ---------------- Latent prior r(z) ----------------
# Prior over the N_z latent categories, renormalised in place (float32).
latent_probs = np.array(
    [0.30, 0.20, 0.15, 0.10, 0.08, 0.07, 0.06, 0.04],
    np.float32,
)
latent_probs /= latent_probs.sum()

# ---------------- Modulation: Non-uniform QPSK ----------------
# The four (±1, ±1) corners scaled to unit power.
const = np.array(
    [[1, 1], [-1, 1], [-1, -1], [1, -1]],
    np.float32,
) / np.sqrt(2.0)
M = 4   # number of constellation points
# Non-uniform source probabilities, indexed like `const`.
p_qpsk = np.array([0.50, 0.20, 0.20, 0.10], np.float32)

# ---------------- Dense Layer ----------------
def dense(x, in_dim, out_dim, name, act=tf.nn.relu):
    """Create a fully connected layer and apply it to ``x``.

    Each call creates new variables ``W_<name>`` (normal initialisation with
    standard deviation ``1/sqrt(max(in_dim/2, 1))``) and ``b_<name>`` (zeros).

    Args:
        x: 2-D input tensor whose last dimension is ``in_dim``.
        in_dim: Number of input features.
        out_dim: Number of output features.
        name: Suffix used in the variable names.
        act: Activation applied to ``x @ W + b``; ``None`` returns the affine
            output unchanged.

    Returns:
        The layer output tensor.
    """
    init_std = 1 / np.sqrt(max(in_dim / 2, 1))
    kernel = tf.Variable(tf.random.normal([in_dim, out_dim], stddev=init_std), name=f"W_{name}")
    bias = tf.Variable(tf.zeros([out_dim]), name=f"b_{name}")
    affine = tf.matmul(x, kernel) + bias
    if act is None:
        return affine
    return act(affine)

# ---------------- Build Model ----------------
def build_model():
    """Build the Deep FAVIB graph with a WGAN-GP critic on the latent code.

    The default graph is reset first.  The encoder maps noisy I/Q samples to
    ``N_z`` logits; a Gumbel-softmax sample is used while training and a hard
    one-hot at evaluation, where forward-channel latent errors are also
    applied.  A critic scores latent codes: one-hot codes drawn from the
    prior are "real", the encoder's soft samples are "fake".

    Returns:
        Tuple ``(x, y, is_tr, tau_ph, lr_g, lr_d, logits_x, ce, kl_q_r,
        d_loss, g_wgan, total_gen_loss, train_op_g, train_op_d, perplexity)``:
        the placeholders, decoder logits, the individual loss terms, one Adam
        step for the generator (encoder/decoder variables) and one for the
        critic (critic variables), and ``exp`` of the mean entropy of q(z|x).
    """
    tf.compat.v1.reset_default_graph()
    # The graph-level seed belongs to a graph object; the reset above created
    # a new default graph, so the seed has to be applied again to keep
    # TensorFlow's random ops reproducible.
    tf.compat.v1.set_random_seed(0)

    # --- Placeholders ---
    x = tf.compat.v1.placeholder(tf.float32, [None, 2], name="x")     # noisy I/Q input
    y = tf.compat.v1.placeholder(tf.int32, [None], name="y")          # ground-truth symbols
    is_tr = tf.compat.v1.placeholder_with_default(True, shape=(), name="is_tr")    # training flag
    tau_ph = tf.compat.v1.placeholder_with_default(tau_start, shape=(), name="tau")  # Gumbel τ
    lr_g = tf.compat.v1.placeholder_with_default(lr_g_start, shape=(), name="lr_g")  # gen LR
    lr_d = tf.compat.v1.placeholder_with_default(lr_d_start, shape=(), name="lr_d")  # critic LR

    prior = tf.constant(latent_probs, tf.float32)
    lat_logp = tf.math.log(prior[None, :] + 1e-20)   # [1, N_z] log-prior
    # Dynamic batch size, so the graph does not depend on the fed batch
    # having exactly `batch_size` rows.
    n_rows = tf.shape(x)[0]

    # --- Encoder: x -> logits over latent categories ---
    h = x
    for i, d in enumerate(hidden_dims):
        h = dense(h, int(h.shape[1]), d, f"enc{i}")
    logits_z = dense(h, int(h.shape[1]), N_z, "logitz", act=None)

    # --- Gumbel-Softmax bottleneck ---
    u = tf.random.uniform(tf.shape(logits_z))
    g = -tf.math.log(-tf.math.log(u + 1e-20) + 1e-20)
    z_soft = tf.nn.softmax((logits_z + g) / tau_ph)        # training: soft sample
    z_hard = tf.one_hot(tf.argmax(logits_z, axis=1), N_z)  # eval: hard one-hot
    z_clean = tf.cond(is_tr, lambda: z_soft, lambda: z_hard)

    # --- Forward channel latent noise at test-time ---
    # One decision per sample: with probability epsilon the whole code is
    # replaced by a code drawn from the prior.
    flip_row = tf.random.uniform([n_rows, 1]) < epsilon
    noise_mask = tf.tile(flip_row, [1, N_z])
    # categorical() returns [rows_of_logits, num_samples] == [1, n_rows];
    # row 0 holds the n_rows independent draws.
    z_rand_idx = tf.random.categorical(lat_logp, n_rows)[0]
    z_rand = tf.one_hot(z_rand_idx, N_z)
    z_used = tf.cond(is_tr, lambda: z_clean,
                     lambda: tf.where(noise_mask, z_rand, z_clean))

    # --- Decoder: latent -> logits over QPSK symbols ---
    h2 = z_used
    for i, d in enumerate(hidden_dims):
        h2 = dense(h2, int(h2.shape[1]), d, f"dec{i}")
    logits_x = dense(h2, int(h2.shape[1]), M, "logits_x", act=None)

    # --- Task/regularization losses ---
    ce = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits_x)
    )
    qz = tf.nn.softmax(logits_z)
    kl_q_r = tf.reduce_mean(
        tf.reduce_sum(qz * (tf.math.log(qz + 1e-20) - tf.math.log(prior + 1e-20)), axis=1)
    )

    # ---------------- WGAN-GP critic in latent space ----------------
    with tf.compat.v1.variable_scope("critic"):
        # The critic weights are created exactly once and shared by every
        # critic() call.  tf.Variable ignores variable-scope reuse, so
        # creating them inside critic() would build a separate, unrelated
        # network per call (and the generator would be scored by a critic
        # that is never trained).
        critic_width = 64
        W_d1 = tf.Variable(
            tf.random.normal([N_z, critic_width], stddev=1 / np.sqrt(max(N_z / 2, 1))),
            name="W_d1"
        )
        b_d1 = tf.Variable(tf.zeros([critic_width]), name="b_d1")
        W_d2 = tf.Variable(
            tf.random.normal([critic_width, 1], stddev=1 / np.sqrt(max(critic_width / 2, 1))),
            name="W_d2"
        )
        b_d2 = tf.Variable(tf.zeros([1]), name="b_d2")
        d_vars = [W_d1, b_d1, W_d2, b_d2]

        def critic(z_in):
            hidden = tf.nn.relu(tf.matmul(z_in, W_d1) + b_d1)
            return tf.matmul(hidden, W_d2) + b_d2

        # Real latent codes: one independent draw from r(z) per sample.
        z_real_idx = tf.random.categorical(lat_logp, n_rows)[0]
        z_real = tf.one_hot(z_real_idx, N_z)

        # Critic scores
        D_real = critic(z_real)                             # score for real codes
        D_fake_detached = critic(tf.stop_gradient(z_soft))  # score for fake codes (detached)

        # Gradient penalty on interpolations
        alpha = tf.random.uniform([n_rows, 1], 0.0, 1.0)
        interp = alpha * z_real + (1.0 - alpha) * tf.stop_gradient(z_soft)
        D_interp = critic(interp)
        grads = tf.gradients(D_interp, [interp])[0]
        grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-8)
        gp = tf.reduce_mean(tf.square(grad_norm - 1.0))

        # Critic loss: E[D(fake)] - E[D(real)] + λ_gp * GP
        d_loss = tf.reduce_mean(D_fake_detached) - tf.reduce_mean(D_real) + beta_gp * gp

        # Generator adversarial term uses non-detached z_soft
        D_fake_for_G = critic(z_soft)
        g_wgan = -tf.reduce_mean(D_fake_for_G)

    # --- Total generator loss (task + reg + adversarial) ---
    total_gen_loss = ce + lambda_kl * kl_q_r + g_loss_scale * g_wgan

    # --- Variable partitions ---
    # Generator = every trainable variable that is not a critic weight
    # (compared by object identity).
    critic_ids = {id(v) for v in d_vars}
    g_vars = [v for v in tf.compat.v1.trainable_variables() if id(v) not in critic_ids]

    # --- Optimizers ---
    train_op_d = tf.compat.v1.train.AdamOptimizer(lr_d, beta1=0.5, beta2=0.9).minimize(
        d_loss, var_list=d_vars
    )
    train_op_g = tf.compat.v1.train.AdamOptimizer(lr_g, beta1=0.5, beta2=0.9).minimize(
        total_gen_loss, var_list=g_vars
    )

    # Diagnostics: exp of the mean per-sample entropy of q(z|x).
    entropy_q = -tf.reduce_mean(tf.reduce_sum(qz * tf.math.log(qz + 1e-20), axis=1))
    perplexity = tf.exp(entropy_q)

    return (x, y, is_tr, tau_ph, lr_g, lr_d,
            logits_x, ce, kl_q_r, d_loss, g_wgan, total_gen_loss,
            train_op_g, train_op_d, perplexity)

# ---------------- Training & Evaluation ----------------
def run():
    """Train at each SNR in ``TRAIN_SNRs`` and measure SER over ``TEST_SNRs``.

    Every step runs ``critic_steps`` critic updates followed by one generator
    update on the same mini-batch.  Both learning rates follow a cosine decay
    and the Gumbel temperature is interpolated linearly between ``tau_start``
    and ``tau_end``.

    Returns:
        Dict mapping each training SNR to a list of symbol error rates, one
        per entry of ``TEST_SNRs``.
    """
    (x_ph, y_ph, is_tr, tau_ph, lr_g_ph, lr_d_ph,
     logits_x, ce_t, kl_t, dloss_t, gw_t, tot_t,
     train_op_g, train_op_d, perp_t) = build_model()
    # Build the prediction op once instead of adding a new argmax node to the
    # graph for every evaluated batch.
    pred_op = tf.argmax(logits_x, axis=1)

    results = {}

    # The context manager closes the session on exit.
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        # Train at each specified SNR
        for snr in TRAIN_SNRs:
            print(f"\n=== TRAIN @ {snr} dB (WGAN-GP; params match baseline) ===")

            sigma = np.sqrt(1.0 / (2.0 * 10.0**(snr / 10.0)))
            tx = np.random.choice(M, size=N_train, p=p_qpsk)
            x_train = const[tx] + sigma * np.random.randn(N_train, 2).astype(np.float32)

            for ep in range(epochs):
                # Cosine LR decay; τ anneals linearly
                frac = ep / max(epochs - 1, 1)
                lr_g_now = lr_g_end + 0.5 * (lr_g_start - lr_g_end) * (1 + np.cos(np.pi * frac))
                lr_d_now = lr_d_end + 0.5 * (lr_d_start - lr_d_end) * (1 + np.cos(np.pi * frac))
                tau_now = tau_end + (tau_start - tau_end) * (1 - frac)

                idx = np.random.choice(N_train, batch_size, replace=False)
                xb, yb = x_train[idx], tx[idx]
                train_feed = {x_ph: xb, y_ph: yb, is_tr: True, tau_ph: tau_now,
                              lr_g_ph: lr_g_now, lr_d_ph: lr_d_now}

                # Critic updates
                for _ in range(critic_steps):
                    sess.run(train_op_d, train_feed)

                # Generator update
                sess.run(train_op_g, train_feed)

                if ep % 200 == 0:
                    ce_v, kl_v, d_v, gw_v, tot_v, perp_v = sess.run(
                        [ce_t, kl_t, dloss_t, gw_t, tot_t, perp_t],
                        {x_ph: xb, y_ph: yb, is_tr: True, tau_ph: tau_now}
                    )
                    print(f"  ep {ep:4d} | CE {ce_v:.4f}  KL {kl_v:.4f}  "
                          f"D {d_v:.4f}  G_wgan {gw_v:.4f}  Tot {tot_v:.4f}  Perp~{perp_v:.2f}")

            # --- SER evaluation over the test SNR sweep ---
            ser = []
            for s in TEST_SNRs:
                sigma_te = np.sqrt(1.0 / (2.0 * 10.0**(s / 10.0)))
                tx_te = np.random.choice(M, size=N_test, p=p_qpsk)
                x_te = const[tx_te] + sigma_te * np.random.randn(N_test, 2).astype(np.float32)

                preds = np.concatenate([
                    sess.run(pred_op, {x_ph: x_te[i:i + batch_size], is_tr: False})
                    for i in range(0, N_test, batch_size)
                ])
                ser.append(np.mean(preds != tx_te))

            results[snr] = ser

    return results

# ---------------- Run & Plot ----------------
if __name__ == "__main__":
    # `results` maps each training SNR to its SER curve over TEST_SNRs.
    results = run()

    # One semilog-y curve per training SNR, drawn through the Axes API.
    fig, ax = plt.subplots(figsize=(8, 6))
    for snr, ser in results.items():
        ax.semilogy(TEST_SNRs, ser, marker='o', label=f"Train @ {snr} dB")
    ax.set_title("Deep FAVIB + WGAN-GP — Non-Uniform QPSK")
    ax.set_xlabel("Test SNR (dB)")
    ax.set_ylabel("SER")
    ax.grid(True, which='both', ls='--', alpha=0.5)
    ax.legend()
    ax.set_ylim(1e-4, 1)
    plt.tight_layout()
    plt.show()
