# In [None]:
# Works well for M=4, N=4 with groups; includes permutation and constellation rotation
# MM-SEFDM: Complete Standalone Working Code
# All functions included - ready to run independently

import numpy as np
import matplotlib.pyplot as plt
from itertools import permutations
from math import factorial, log2, floor
import pickle
import os
import gc

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# ============ Helper Functions ============

def simple_progress(iterable, desc="Progress", leave=True):
    """Yield the items of *iterable* while printing a coarse progress counter.

    The counter is rewritten in place (carriage return, no newline) roughly
    every tenth of the way through and on the final item. Nothing is printed
    for iterables without ``__len__``. When *leave* is true, a newline is
    emitted once the iterable is exhausted.
    """
    count = None
    if hasattr(iterable, '__len__'):
        count = len(iterable)

    for done, element in enumerate(iterable, start=1):
        if count:
            step = max(1, count // 10)
            if done == count or done % step == 0:
                print(f"\r{desc}: {done}/{count}", end='', flush=True)
        yield element

    if leave:
        print()

# ============ Signal Processing (From Working Code) ============

def sefdm_modulate(freq_signal, alpha):
    """Map N frequency-domain symbols to N time samples with compression *alpha*.

    Computes ``x[n] = (1/sqrt(N)) * sum_k X[k] * exp(2j*pi*k*alpha*n/N)`` for
    ``n = 0..N-1``. With ``alpha == 1`` this is the unitary inverse DFT.

    Args:
        freq_signal: 1-D array-like of N complex subcarrier symbols.
        alpha: Bandwidth compression factor applied to the subcarrier spacing.

    Returns:
        Complex ndarray of length N holding the time-domain samples.
    """
    freq_signal = np.asarray(freq_signal, dtype=complex)
    N = len(freq_signal)
    idx = np.arange(N)
    # Build the full (n, k) phase matrix once and apply it as a single
    # matrix-vector product instead of N*N Python-level iterations.
    kernel = np.exp(2j * np.pi * alpha * np.outer(idx, idx) / N)
    return (kernel @ freq_signal) / np.sqrt(N)

def generate_multimode_lut(N, M):
    """Build the label -> frequency-domain signal lookup table.

    Only the first ``2**floor(log2(N!))`` permutations of the N subcarriers are
    used. Mode ``q`` uses the M-PSK constellation rotated by
    ``exp(2j*pi*q/(N*M))``; a permutation decides which subcarrier carries
    which mode, and every combination of per-mode symbol indices is enumerated.

    Returns:
        ``(lut, bits_per_symbol)`` where ``lut[label]['freq_signal']`` is a
        length-N complex array and ``bits_per_symbol = int(log2(len(lut)))``.
    """
    index_bits = floor(log2(factorial(N)))
    usable_perms = list(permutations(range(N)))[:2 ** index_bits]

    psk_points = np.exp(2j * np.pi * np.arange(M) / M)
    constellations = []
    for mode in range(N):
        rotation = np.exp(2j * np.pi * mode / (N * M))
        constellations.append(rotation * psk_points)

    lut = {}
    for perm in usable_perms:
        for choice in np.ndindex(*([M] * N)):
            tones = np.zeros(N, dtype=complex)
            # Position `mode` in the permutation names the subcarrier that
            # carries the symbol drawn from that mode's constellation.
            for mode, subcarrier in enumerate(perm):
                tones[subcarrier] = constellations[mode][choice[mode]]
            # Labels are assigned consecutively in enumeration order.
            lut[len(lut)] = {'freq_signal': tones}

    bits_per_symbol = int(np.log2(len(lut)))
    print(f"LUT: {len(lut)} entries, {bits_per_symbol} bits/group")
    return lut, bits_per_symbol

def generate_rayleigh_channel(N, channel_model='flat'):
    """Draw a length-N vector of unit-variance complex Gaussian channel gains.

    For ``channel_model == 'flat'`` a single gain is drawn and repeated N
    times; any other value draws N independent gains.
    """
    scale = np.sqrt(2)
    if channel_model != 'flat':
        real_part = np.random.randn(N)
        imag_part = np.random.randn(N)
        return (real_part + 1j * imag_part) / scale

    tap_re = np.random.randn()
    tap_im = np.random.randn()
    gain = (tap_re + 1j * tap_im) / scale
    return gain * np.ones(N, dtype=complex)

def add_awgn(signal, snr_db):
    """Return *signal* plus complex white Gaussian noise at the given SNR (dB).

    The noise power is set relative to the measured mean power of *signal*.
    ``snr_db == np.inf`` returns the input object unchanged.
    """
    if snr_db == np.inf:
        return signal

    mean_power = np.mean(np.abs(signal) ** 2)
    noise_var = mean_power / 10 ** (snr_db / 10)
    # Half of the noise power goes to each of the real and imaginary parts.
    sigma = np.sqrt(noise_var / 2)
    noise_re = np.random.randn(*signal.shape)
    noise_im = np.random.randn(*signal.shape)
    return signal + sigma * (noise_re + 1j * noise_im)

# ============ Custom Layers ============

class PositionalEncoding(layers.Layer):
    """Adds a learned per-position embedding to a (batch, seq, d_model) input."""

    def __init__(self, d_model, max_len=10, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.d_model = d_model

    def build(self, input_shape):
        # One trainable row per position, up to max_len positions.
        table_shape = (self.max_len, self.d_model)
        self.pos_embedding = self.add_weight(
            name='pos_embedding',
            shape=table_shape,
            initializer='uniform',
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        dims = tf.shape(x)
        n_batch, n_steps = dims[0], dims[1]
        # Take the rows for the positions present, then repeat over the batch.
        table = tf.expand_dims(self.pos_embedding[:n_steps, :], 0)
        return x + tf.tile(table, [n_batch, 1, 1])

    def get_config(self):
        base = super().get_config()
        base.update({"d_model": self.d_model, "max_len": self.max_len})
        return base

class ChannelAdaptiveGating(layers.Layer):
    """Scales features element-wise by sigmoid gates computed from a condition vector.

    Called on ``[features, channel_conditions]``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # input_shape is a pair of shapes; the gate width follows the
        # channel dimension of the first one (the features).
        features_shape = input_shape[0]
        width = features_shape[-1]
        hidden = layers.Dense(width * 2, activation='relu')
        squash = layers.Dense(width, activation='sigmoid')
        self.gate_net = keras.Sequential([hidden, squash], name='gate_network')
        super().build(input_shape)

    def call(self, inputs):
        features, channel_conditions = inputs
        gates = self.gate_net(channel_conditions)
        is_sequence = len(features.shape) == 3
        if is_sequence:
            # Insert an axis at position 1 so the gates broadcast over it.
            gates = tf.expand_dims(gates, 1)
        return features * gates

    def get_config(self):
        return super().get_config()

class TransformerBlock(layers.Layer):
    """Self-attention followed by a feed-forward net, each with residual + LayerNorm."""

    def __init__(self, d_model, num_heads, ff_dim, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.d_model, self.num_heads = d_model, num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout

    def build(self, input_shape):
        rate = self.dropout_rate
        per_head_dim = self.d_model // self.num_heads

        self.att = layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=per_head_dim, dropout=rate
        )
        # Feed-forward net: expand to ff_dim, then project back to d_model.
        self.ffn = keras.Sequential([
            layers.Dense(self.ff_dim, activation='gelu'),
            layers.Dropout(rate),
            layers.Dense(self.d_model),
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)
        super().build(input_shape)

    def call(self, x, training=False):
        # Attention sub-layer; x is passed as both query and value.
        attended = self.dropout1(self.att(x, x, training=training),
                                 training=training)
        hidden = self.layernorm1(x + attended)

        # Feed-forward sub-layer with its own residual connection.
        projected = self.dropout2(self.ffn(hidden, training=training),
                                  training=training)
        return self.layernorm2(hidden + projected)

    def get_config(self):
        base = super().get_config()
        base.update({
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
        })
        return base

# ============ Architecture ============

def build_advanced_cnn_transformer_hybrid(input_shape, num_classes,
                                         d_model=64, num_heads=4, num_layers=2,
                                         ff_dim=128, dropout=0.15):
    """Build the four-input CNN + Transformer classifier.

    Pipeline: three stacked Conv1D layers with a 1x1-conv residual from the
    raw input, channel-adaptive gating driven by embedded (alpha, SNR,
    channel-type) conditions, trainable positional encoding, `num_layers`
    transformer blocks, average + max pooling, and a dense softmax head.

    Args:
        input_shape: Shape of one signal sample, ``(sequence_length, channels)``.
            ``input_shape[0]`` is also used as the positional-encoding length.
        num_classes: Size of the softmax output.
        d_model: Feature width of the CNN output and the transformer stack.
        num_heads: Attention heads per transformer block.
        num_layers: Number of transformer blocks.
        ff_dim: Hidden width of each transformer feed-forward network.
        dropout: Base dropout rate (the first head layer uses 1.5x this rate).

    Returns:
        An uncompiled ``keras.Model`` with inputs
        ``[signal_input, alpha_input, snr_input, channel_type_input]``.
    """
    signal_input = layers.Input(shape=input_shape, name='signal_input')
    alpha_input = layers.Input(shape=(1,), name='alpha_input')
    snr_input = layers.Input(shape=(1,), name='snr_input')
    channel_type_input = layers.Input(shape=(1,), name='channel_type_input')

    # CNN with residual connections
    conv1 = layers.Conv1D(32, 3, padding='same', activation='relu', name='conv1_k3')(signal_input)
    conv1 = layers.LayerNormalization(name='conv1_ln')(conv1)
    conv1 = layers.Dropout(dropout, name='conv1_dropout')(conv1)

    conv2 = layers.Conv1D(48, 5, padding='same', activation='relu', name='conv2_k5')(conv1)
    conv2 = layers.LayerNormalization(name='conv2_ln')(conv2)
    conv2 = layers.Dropout(dropout, name='conv2_dropout')(conv2)

    # The last conv layer and the residual projection must both emit d_model
    # channels: the positional encoding and transformer blocks below are sized
    # by d_model, so a fixed width here would break for any d_model != 64.
    conv3 = layers.Conv1D(d_model, 7, padding='same', activation='relu', name='conv3_k7')(conv2)
    conv3 = layers.LayerNormalization(name='conv3_ln')(conv3)

    residual = layers.Conv1D(d_model, 1, name='residual_proj')(signal_input)
    cnn_features = layers.Add(name='residual_add')([conv3, residual])
    cnn_features = layers.Dropout(dropout, name='conv_final_dropout')(cnn_features)

    # Channel embeddings (alpha, SNR, channel type)
    alpha_embed = layers.Dense(64, activation='relu')(alpha_input)
    alpha_embed = layers.LayerNormalization()(alpha_embed)
    alpha_embed = layers.Dense(64, activation='relu')(alpha_embed)
    alpha_embed = layers.LayerNormalization()(alpha_embed)
    alpha_embed = layers.Dense(d_model, activation='tanh')(alpha_embed)

    snr_embed = layers.Dense(64, activation='relu')(snr_input)
    snr_embed = layers.LayerNormalization()(snr_embed)
    snr_embed = layers.Dense(d_model, activation='tanh')(snr_embed)

    channel_type_embed = layers.Dense(32, activation='relu')(channel_type_input)
    channel_type_embed = layers.Dense(d_model//2, activation='tanh')(channel_type_embed)

    # Fuse the three condition embeddings into one d_model-wide vector.
    channel_conditions = layers.Concatenate()([alpha_embed, snr_embed, channel_type_embed])
    channel_conditions = layers.Dense(d_model*2, activation='gelu')(channel_conditions)
    channel_conditions = layers.Dropout(dropout)(channel_conditions)
    channel_conditions = layers.Dense(d_model, activation='tanh')(channel_conditions)

    # Channel-adaptive gating
    cnn_features_gated = ChannelAdaptiveGating()([cnn_features, channel_conditions])

    # Positional encoding + Transformer
    x = PositionalEncoding(d_model, max_len=input_shape[0])(cnn_features_gated)

    for _ in range(num_layers):
        x = TransformerBlock(d_model, num_heads, ff_dim, dropout)(x)

    # Dual pooling
    global_avg = layers.GlobalAveragePooling1D()(x)
    global_max = layers.GlobalMaxPooling1D()(x)
    pooled = layers.Concatenate()([global_avg, global_max])

    # Fusion with channel conditions
    combined = layers.Concatenate()([pooled, channel_conditions])

    # Classification head
    x = layers.Dense(256, activation='gelu')(combined)
    x = layers.LayerNormalization()(x)
    x = layers.Dropout(dropout*1.5)(x)

    x = layers.Dense(128, activation='gelu')(x)
    x = layers.LayerNormalization()(x)
    x = layers.Dropout(dropout)(x)

    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = Model(
        inputs=[signal_input, alpha_input, snr_input, channel_type_input],
        outputs=outputs,
        name='Advanced_CNN_Transformer_Hybrid'
    )

    return model

# ============ Data Generation ============

def generate_curriculum_data(lut, group_size, alpha_range, snr_range, num_samples,
                            channel_type='awgn', curriculum_level='easy'):
    """Generate a stratified, standardised dataset of equalised received groups.

    The (alpha, SNR) plane is split on a grid of edges (5x6 for
    ``curriculum_level == 'easy'``, 7x8 otherwise) and the same number of
    samples is drawn for every grid index pair. Each sample picks a random LUT
    label, modulates it, applies the channel and noise, and divides by the
    known channel.

    Args:
        lut: Mapping ``label -> {'freq_signal': ...}`` with labels ``0..len-1``.
        group_size: Number of samples per group (length of the channel vector).
        alpha_range: ``(low, high)`` compression-factor range.
        snr_range: ``(low, high)`` SNR range in dB.
        num_samples: Requested sample count. The number actually generated is
            ``strata * max(1, num_samples // strata)`` and may differ.
        channel_type: ``'awgn'`` for a unit channel; anything else uses a flat
            Rayleigh channel.
        curriculum_level: ``'easy'`` selects the coarser grid.

    Returns:
        ``(X, alpha_normalized, snr_normalized, channel_types, y, X_mean,
        X_std, alpha_low, alpha_high, snr_low, snr_high)``. ``X`` is float32
        with shape ``(samples, time, 2)`` (real, imag), standardised with the
        returned ``X_mean`` / ``X_std``. The normalised conditions are scaled
        to the given ranges; a zero-width range yields zeros.
    """
    print(f"\nGenerating {num_samples} samples (Curriculum: {curriculum_level})")
    print(f"  Channel: {channel_type}")
    print(f"  Alpha: {alpha_range[0]}-{alpha_range[1]}")
    print(f"  SNR: {snr_range[0]}-{snr_range[1]} dB")

    alpha_low, alpha_high = alpha_range
    snr_low, snr_high = snr_range

    if curriculum_level == 'easy':
        alpha_bins = np.linspace(alpha_low, alpha_high, 5)
        snr_bins = np.linspace(snr_low, snr_high, 6)
    else:
        alpha_bins = np.linspace(alpha_low, alpha_high, 7)
        snr_bins = np.linspace(snr_low, snr_high, 8)

    num_strata = len(alpha_bins) * len(snr_bins)
    samples_per_stratum = max(1, num_samples // num_strata)
    total_samples = num_strata * samples_per_stratum
    # Guard the modulus so a small total can never produce "% 0".
    progress_step = max(1, total_samples // 10)

    is_awgn = channel_type == 'awgn'
    channel_type_val = 0.0 if is_awgn else 1.0
    num_labels = len(lut)
    last_alpha_idx = len(alpha_bins) - 1
    last_snr_idx = len(snr_bins) - 1

    rx_equalized, y = [], []
    alpha_values, snr_values, channel_types = [], [], []
    sample_count = 0

    # NOTE(review): the loops run over every edge, with the upper edge index
    # clamped. For the last index the interval is zero-width, so those samples
    # sit exactly at alpha_high / snr_high. Confirm this oversampling of the
    # upper bounds is intended.
    for i in range(len(alpha_bins)):
        alpha_from = alpha_bins[i]
        alpha_to = alpha_bins[min(i + 1, last_alpha_idx)]
        for j in range(len(snr_bins)):
            snr_from = snr_bins[j]
            snr_to = snr_bins[min(j + 1, last_snr_idx)]
            for _ in range(samples_per_stratum):
                sample_count += 1
                if sample_count % progress_step == 0:
                    print(f"Progress: {sample_count}/{total_samples}", end='\r', flush=True)

                # Draw order (label, alpha, SNR, channel, noise) is kept fixed
                # so a given seed reproduces the same dataset.
                label = np.random.randint(0, num_labels)
                alpha = np.random.uniform(alpha_from, alpha_to)
                snr_db = np.random.uniform(snr_from, snr_to)

                tx_time = sefdm_modulate(lut[label]['freq_signal'], alpha)

                if is_awgn:
                    channel = np.ones(group_size, dtype=complex)
                else:
                    channel = generate_rayleigh_channel(group_size, 'flat')

                rx_time = add_awgn(tx_time * channel, snr_db)
                # Equalisation with the known channel; the epsilon avoids a
                # division by an exactly-zero gain.
                rx_equalized.append(rx_time / (channel + 1e-10))

                y.append(label)
                alpha_values.append(alpha)
                snr_values.append(snr_db)
                channel_types.append(channel_type_val)

    print()

    rx_equalized = np.array(rx_equalized)
    X = np.stack([np.real(rx_equalized), np.imag(rx_equalized)], axis=-1).astype(np.float32)
    y = np.array(y, dtype=np.int32)
    alpha_values = np.array(alpha_values, dtype=np.float32).reshape(-1, 1)
    snr_values = np.array(snr_values, dtype=np.float32).reshape(-1, 1)
    channel_types = np.array(channel_types, dtype=np.float32).reshape(-1, 1)

    # Standardise with statistics taken over the whole array.
    X_mean, X_std = np.mean(X), np.std(X)
    X = (X - X_mean) / (X_std + 1e-8)

    # A zero-width range (e.g. a single fixed alpha) would give 0/0 = NaN;
    # map it to a constant 0 instead.
    alpha_span = alpha_high - alpha_low
    snr_span = snr_high - snr_low
    if alpha_span != 0:
        alpha_normalized = (alpha_values - alpha_low) / alpha_span
    else:
        alpha_normalized = np.zeros_like(alpha_values)
    if snr_span != 0:
        snr_normalized = (snr_values - snr_low) / snr_span
    else:
        snr_normalized = np.zeros_like(snr_values)

    print(f"Data: X={X.shape}")
    return (X, alpha_normalized, snr_normalized, channel_types, y,
            X_mean, X_std, alpha_low, alpha_high, snr_low, snr_high)

# ============ Testing Functions ============

def transmit_receive_grouped_system(labels, lut, N_total, groups, alpha, snr_db,
                                    channel_type='rayleigh', rayleigh_model='flat'):
    """Simulate one transmission of `groups` independently modulated groups.

    Each label is modulated into its own contiguous slice of an N_total-sample
    frame. The frame is multiplied by the channel (all ones for ``'awgn'``,
    otherwise one Rayleigh draw per group) and noise is added over the whole
    frame at once.

    Returns:
        ``(rx_groups, channel_groups)``: two lists with one array per group,
        holding the received samples and the channel gains applied to them.
    """
    group_size = N_total // groups

    def span(g):
        # Slice of the frame that belongs to group g.
        return slice(g * group_size, (g + 1) * group_size)

    frame = np.zeros(N_total, dtype=complex)
    for g, label in enumerate(labels):
        frame[span(g)] = sefdm_modulate(lut[label]['freq_signal'], alpha)

    if channel_type == 'awgn':
        gains = np.ones(N_total, dtype=complex)
    else:
        gains = np.zeros(N_total, dtype=complex)
        for g in range(groups):
            gains[span(g)] = generate_rayleigh_channel(group_size, rayleigh_model)

    received = add_awgn(frame * gains, snr_db)

    rx_groups = [received[span(g)] for g in range(groups)]
    channel_groups = [gains[span(g)] for g in range(groups)]
    return rx_groups, channel_groups

def test_ber(model, lut, N_total, groups, alpha_test, snr_test,
            X_mean, X_std, alpha_train_range, snr_train_range,
            num_symbols=1000, channel_type='rayleigh'):
    """Estimate the bit error rate of *model* at one (alpha, SNR) point.

    Simulates `num_symbols` grouped transmissions, equalises every group with
    its known channel, classifies all groups in one batched ``predict`` call
    and counts the differing bits between true and detected labels.

    Args:
        model: Keras model taking ``[signal, alpha, snr, channel_type]``.
        lut: Label lookup table; ``int(log2(len(lut)))`` bits are counted per group.
        N_total: Total samples per transmitted frame.
        groups: Number of groups per frame.
        alpha_test: Compression factor used for the simulation.
        snr_test: SNR in dB used for the simulation.
        X_mean, X_std: Standardisation statistics applied to the signal input.
        alpha_train_range, snr_train_range: ``(low, high)`` ranges used to
            scale the alpha / SNR conditioning inputs.
        num_symbols: Number of frames to simulate.
        channel_type: ``'awgn'`` or anything else for flat Rayleigh.

    Returns:
        The measured BER, floored at ``1e-7`` so it can be drawn on a log axis.
    """
    bits_per_group = int(np.log2(len(lut)))

    alpha_low, alpha_high = alpha_train_range
    snr_low, snr_high = snr_train_range

    channel_type_value = 0.0 if channel_type == 'awgn' else 1.0

    # Phase 1: simulate every frame and keep the equalised groups. The model
    # is not involved here, so the random draw order is unaffected by batching.
    true_labels_all = []
    rx_eq_all = []
    for _ in simple_progress(range(num_symbols),
                             desc=f"  α={alpha_test:.2f}, SNR={snr_test}dB",
                             leave=False):
        true_labels = [np.random.randint(0, len(lut)) for _g in range(groups)]

        rx_groups, channel_groups = transmit_receive_grouped_system(
            true_labels, lut, N_total, groups, alpha_test, snr_test,
            channel_type, 'flat'
        )

        for rx, channel in zip(rx_groups, channel_groups):
            rx_eq_all.append(rx / (channel + 1e-10))
        true_labels_all.extend(true_labels)

    if not rx_eq_all:
        # Nothing was simulated: report the floor value.
        return 1e-7

    # Phase 2: one batched inference instead of one predict() call per group,
    # which is dominated by per-call overhead.
    rx_eq = np.asarray(rx_eq_all)
    X_test = np.stack([np.real(rx_eq), np.imag(rx_eq)], axis=-1)
    X_test = (X_test - X_mean) / (X_std + 1e-8)
    num_inputs = X_test.shape[0]

    # A zero-width training range would divide by zero; use 0 in that case.
    alpha_span = alpha_high - alpha_low
    snr_span = snr_high - snr_low
    alpha_value = (alpha_test - alpha_low) / alpha_span if alpha_span != 0 else 0.0
    snr_value = (snr_test - snr_low) / snr_span if snr_span != 0 else 0.0

    alpha_norm = np.full((num_inputs, 1), alpha_value, dtype=np.float64)
    snr_norm = np.full((num_inputs, 1), snr_value, dtype=np.float64)
    channel_type_norm = np.full((num_inputs, 1), channel_type_value, dtype=np.float64)

    pred = model.predict([X_test, alpha_norm, snr_norm, channel_type_norm],
                         batch_size=1024, verbose=0)
    detected_labels = np.argmax(pred, axis=1)

    # Hamming distance between the binary label representations: XOR then
    # count set bits. This also works when a label needs more than
    # bits_per_group bits (LUT size not a power of two).
    total_bit_errors = sum(
        bin(int(true) ^ int(detected)).count('1')
        for true, detected in zip(true_labels_all, detected_labels)
    )
    total_bits = bits_per_group * num_inputs

    ber = total_bit_errors / total_bits if total_bits > 0 else 0
    return max(ber, 1e-7)

# ============ Main Training Script ============

if __name__ == "__main__":
    # Two-phase experiment: train a model on AWGN data (phase 1), then train a
    # separate, freshly initialised model on Rayleigh data (phase 2), and
    # finally measure the BER of the phase-2 model over an (alpha, SNR) grid.

    # System parameters
    N_TOTAL = 16
    GROUPS = 4
    N = N_TOTAL // GROUPS  # N=4
    M = 2

    # Curriculum parameters
    BASE_ALPHA_RANGE = (0.5, 1.0)
    BASE_SNR_RANGE = (0, 35)

    PHASE1_ALPHA_RANGE = (0.7, 1.0)
    PHASE1_SNR_RANGE = (15, 35)

    # Samples
    NUM_TRAIN_P1 = 200000
    NUM_VAL_P1 = 40000
    NUM_TRAIN_P2 = 200000
    NUM_VAL_P2 = 40000

    BATCH_SIZE = 128
    EPOCHS_P1 = 10
    EPOCHS_P2 = 60

    # Test parameters
    ALPHA_TEST_VALUES = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    SNR_TEST_VALUES = np.arange(0, 36, 5)
    NUM_TEST = 1000

    print("="*70)
    print("MM-SEFDM: COMPLETE STANDALONE CODE")
    print("="*70)
    print(f"System: N={N}, Groups={GROUPS}, M={M}-PSK")
    print(f"Architecture: CNN+PosEnc+ChannelGating+Transformer")
    print(f"\nPhase 1 (AWGN, Easy): α∈{PHASE1_ALPHA_RANGE}, SNR∈{PHASE1_SNR_RANGE}")
    print(f"Phase 2 (Rayleigh, Full): α∈{BASE_ALPHA_RANGE}, SNR∈{BASE_SNR_RANGE}")
    print(f"Strategy: Fresh model for each phase")
    print("="*70)

    # Generate LUT
    lut, bits_per_group = generate_multimode_lut(N, M)
    num_classes = len(lut)

    save_path = f"./MM_SEFDM_Standalone_N{N}"
    os.makedirs(save_path, exist_ok=True)

    # Weights-only checkpoints must end in ".weights.h5" under Keras 3
    # (TF >= 2.16); the same name is still recognised as HDF5 by Keras 2.
    weights_file_p1 = os.path.join(save_path, "phase1.weights.h5")
    weights_file_p2 = os.path.join(save_path, "phase2.weights.h5")

    # ============ PHASE 1: AWGN Pretraining ============
    print(f"\n{'='*70}")
    print("PHASE 1: AWGN PRETRAINING (Easy Curriculum)")
    print(f"{'='*70}")

    X_train_p1, alpha_train_p1, snr_train_p1, channel_type_train_p1, y_train_p1, \
    X_mean_p1, X_std_p1, alpha_low_p1, alpha_high_p1, snr_low_p1, snr_high_p1 = \
        generate_curriculum_data(
            lut, N, PHASE1_ALPHA_RANGE, PHASE1_SNR_RANGE, NUM_TRAIN_P1,
            channel_type='awgn', curriculum_level='easy'
        )

    X_val_p1, alpha_val_p1, snr_val_p1, channel_type_val_p1, y_val_p1, \
    X_mean_val_p1, X_std_val_p1, _, _, _, _ = \
        generate_curriculum_data(
            lut, N, PHASE1_ALPHA_RANGE, PHASE1_SNR_RANGE, NUM_VAL_P1,
            channel_type='awgn', curriculum_level='easy'
        )
    # The generator standardised the validation set with its own mean/std.
    # Undo that and apply the training statistics, so validation inputs are
    # scaled the same way as the training inputs.
    X_val_p1 = ((X_val_p1 * (X_std_val_p1 + 1e-8) + X_mean_val_p1 - X_mean_p1)
                / (X_std_p1 + 1e-8)).astype(np.float32)

    print("\n>>> Building Phase 1 model...")
    tf.keras.backend.clear_session()
    gc.collect()

    model_p1 = build_advanced_cnn_transformer_hybrid(
        (N, 2), num_classes, d_model=64, num_heads=4, num_layers=2,
        ff_dim=128, dropout=0.15
    )

    model_p1.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    print(f"Parameters: {model_p1.count_params():,}")

    callbacks_p1 = [
        EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True,
                     mode='max', verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                        min_lr=1e-6, verbose=1),
        ModelCheckpoint(weights_file_p1, monitor='val_accuracy', save_best_only=True,
                       mode='max', save_weights_only=True, verbose=1)
    ]

    print("\n>>> Training Phase 1...")
    history_p1 = model_p1.fit(
        [X_train_p1, alpha_train_p1, snr_train_p1, channel_type_train_p1], y_train_p1,
        validation_data=([X_val_p1, alpha_val_p1, snr_val_p1, channel_type_val_p1], y_val_p1),
        batch_size=BATCH_SIZE,
        epochs=EPOCHS_P1,
        callbacks=callbacks_p1,
        verbose=2
    )

    best_acc_p1 = max(history_p1.history['val_accuracy'])
    print(f"\n✓ Phase 1 Complete! Best val accuracy: {best_acc_p1:.4f}")

    # Release the phase-1 data and model before allocating the phase-2 data.
    # The phase-1 weights stay on disk but are not loaded into phase 2.
    del X_train_p1, alpha_train_p1, snr_train_p1, channel_type_train_p1, y_train_p1
    del X_val_p1, alpha_val_p1, snr_val_p1, channel_type_val_p1, y_val_p1
    del model_p1
    gc.collect()

    # ============ PHASE 2: Rayleigh Training ============
    print(f"\n{'='*70}")
    print("PHASE 2: RAYLEIGH TRAINING (Full Curriculum, FRESH MODEL)")
    print(f"{'='*70}")

    X_train_p2, alpha_train_p2, snr_train_p2, channel_type_train_p2, y_train_p2, \
    X_mean_p2, X_std_p2, alpha_low_p2, alpha_high_p2, snr_low_p2, snr_high_p2 = \
        generate_curriculum_data(
            lut, N, BASE_ALPHA_RANGE, BASE_SNR_RANGE, NUM_TRAIN_P2,
            channel_type='rayleigh', curriculum_level='full'
        )

    X_val_p2, alpha_val_p2, snr_val_p2, channel_type_val_p2, y_val_p2, \
    X_mean_val_p2, X_std_val_p2, _, _, _, _ = \
        generate_curriculum_data(
            lut, N, BASE_ALPHA_RANGE, BASE_SNR_RANGE, NUM_VAL_P2,
            channel_type='rayleigh', curriculum_level='full'
        )
    # Same re-standardisation as in phase 1: the test stage below scales its
    # inputs with the phase-2 training statistics, so validation should too.
    X_val_p2 = ((X_val_p2 * (X_std_val_p2 + 1e-8) + X_mean_val_p2 - X_mean_p2)
                / (X_std_p2 + 1e-8)).astype(np.float32)

    print("\n>>> Building Phase 2 model (fresh)...")
    tf.keras.backend.clear_session()
    gc.collect()

    model_p2 = build_advanced_cnn_transformer_hybrid(
        (N, 2), num_classes, d_model=64, num_heads=4, num_layers=2,
        ff_dim=128, dropout=0.15
    )

    model_p2.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4, clipnorm=1.0),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    print(f"Parameters: {model_p2.count_params():,}")

    callbacks_p2 = [
        EarlyStopping(monitor='val_accuracy', patience=15, restore_best_weights=True,
                     mode='max', verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                        min_lr=1e-6, verbose=1),
        ModelCheckpoint(weights_file_p2, monitor='val_accuracy', save_best_only=True,
                       mode='max', save_weights_only=True, verbose=1)
    ]

    print("\n>>> Training Phase 2...")
    history_p2 = model_p2.fit(
        [X_train_p2, alpha_train_p2, snr_train_p2, channel_type_train_p2], y_train_p2,
        validation_data=([X_val_p2, alpha_val_p2, snr_val_p2, channel_type_val_p2], y_val_p2),
        batch_size=BATCH_SIZE,
        epochs=EPOCHS_P2,
        callbacks=callbacks_p2,
        verbose=2
    )

    best_acc_p2 = max(history_p2.history['val_accuracy'])
    print(f"\n✓ Phase 2 Complete! Best val accuracy: {best_acc_p2:.4f}")

    print(f"\n{'='*70}")
    print("TRAINING SUMMARY")
    print(f"{'='*70}")
    print(f"Phase 1 (AWGN):     {best_acc_p1:.4f}")
    print(f"Phase 2 (Rayleigh): {best_acc_p2:.4f}")
    print(f"{'='*70}")

    # Evaluate with the best checkpoint written during phase 2.
    model_p2.load_weights(weights_file_p2)

    # Persist the normalisation constants needed to feed this model later.
    with open(os.path.join(save_path, 'stats.pkl'), 'wb') as f:
        pickle.dump({
            'X_mean': X_mean_p2, 'X_std': X_std_p2,
            'alpha_low': alpha_low_p2, 'alpha_high': alpha_high_p2,
            'snr_low': snr_low_p2, 'snr_high': snr_high_p2
        }, f)

    del X_train_p2, alpha_train_p2, snr_train_p2, channel_type_train_p2, y_train_p2
    del X_val_p2, alpha_val_p2, snr_val_p2, channel_type_val_p2, y_val_p2
    gc.collect()

    # ============ TESTING ============
    print(f"\n{'='*70}")
    print("TESTING ON RAYLEIGH CHANNEL")
    print(f"{'='*70}")

    # Rows index alpha, columns index SNR.
    ber_grid = np.zeros((len(ALPHA_TEST_VALUES), len(SNR_TEST_VALUES)))

    for i, alpha_test in enumerate(ALPHA_TEST_VALUES):
        print(f"\nTesting Alpha = {alpha_test:.2f}")
        for j, snr_test in enumerate(SNR_TEST_VALUES):
            ber = test_ber(
                model_p2, lut, N_TOTAL, GROUPS, alpha_test, snr_test,
                X_mean_p2, X_std_p2,
                (alpha_low_p2, alpha_high_p2),
                (snr_low_p2, snr_high_p2),
                num_symbols=NUM_TEST,
                channel_type='rayleigh'
            )
            ber_grid[i, j] = ber
            print(f"  SNR={snr_test:2d} dB: BER={ber:.6e}")

    with open(os.path.join(save_path, 'results.pkl'), 'wb') as f:
        pickle.dump({
            'alpha_values': ALPHA_TEST_VALUES,
            'snr_values': list(SNR_TEST_VALUES),
            'ber_grid': ber_grid.tolist(),
            'best_acc_p1': best_acc_p1,
            'best_acc_p2': best_acc_p2
        }, f)

    # ============ PLOTTING ============
    plt.figure(figsize=(14, 8))
    colors = plt.cm.viridis(np.linspace(0, 1, len(ALPHA_TEST_VALUES)))

    # One BER-vs-SNR curve per tested alpha.
    for i, alpha_val in enumerate(ALPHA_TEST_VALUES):
        label = f'α={alpha_val}' + (' (OFDM)' if alpha_val == 1.0 else '')
        plt.semilogy(SNR_TEST_VALUES, ber_grid[i, :],
                    marker='o', linewidth=2.5, markersize=7,
                    color=colors[i], label=label)

    plt.xlabel('SNR (dB)', fontsize=14, fontweight='bold')
    plt.ylabel('BER', fontsize=14, fontweight='bold')
    plt.title(f'MM-SEFDM: Complete Architecture\n'
              f'CNN+PosEnc+ChannelGating+Transformer\n'
              f'N={N}, {num_classes} classes, Groups={GROUPS}',
              fontsize=12, fontweight='bold')
    plt.grid(True, which='both', alpha=0.3)
    plt.legend(loc='best', fontsize=11)
    # NOTE(review): test_ber floors the BER at 1e-7, which is below this
    # lower limit, so error-free points fall outside the visible range.
    plt.ylim([1e-5, 1])
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'ber_standalone.png'), dpi=300)
    plt.show()

    print(f"\n{'='*70}")
    print("✓ TRAINING AND TESTING COMPLETE!")
    print(f"{'='*70}")
    print(f"Results saved to: {save_path}")
    print(f"{'='*70}")