# Main Code: Transceiver Simulation and Neural-Network Equaliser Training

In [2]:
import os
import warnings
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
import numpy as np
try:
    import sionna
except ImportError as e:
    import os
    os.system("pip install sionna")
    import sionna

from sionna.channel import utils
from comsys import Transmitter, Receiver, Channel
from phasenoise import phase_noise
from plots import SNRVsLinewidthPlotter, PlotInputOutput, PlotNN, SNRVsLinewidthPlotterNN
# from neural_network import MyModel
from neural_network_new import MyModel

class Transceiver:
    """End-to-end simulation of a transmit / channel / receive chain plus an NN equaliser.

    ``run`` generates QAM symbols, pulse-shapes and pads them, sends them through
    ``Channel``, adds phase noise, then applies matched filtering, dispersion
    compensation, downsampling, normalisation, timing recovery and carrier-phase
    recovery (CPR). ``MyModel`` (``self.nn_equalise``) is trained to map the
    normalised CPR output back onto the transmitted symbols.
    """

    def __init__(self, gpu_num=0):
        """Build the transmitter, channel, receiver and NN equaliser.

        Args:
            gpu_num: index written to ``CUDA_VISIBLE_DEVICES`` by ``set_gpu``.
        """
        self.set_gpu(gpu_num)
        self.dtype = tf.complex64

        # Parameters (units are not stated in this file; presumably times are
        # normalised by t_norm = 1 ps and lengths by z_norm = 1 km -- TODO confirm).
        self.beta = 0.1                 # passed to Transmitter (presumably the pulse roll-off)
        self.span_in_symbols = 128
        self.samples_per_symbol = 10
        self.beta_2 = -21.67
        self.t_norm = 1e-12
        self.z_norm = 1e3
        self.linewidth = 200e3          # initial receiver linewidth (Hz)
        self.f_c = 193.55e12
        self.length_sp = 4000.0         # initial link length passed to Channel
        self.alpha = 0.046
        self.num_bits_per_symbol = 4
        self.batch_size = 1
        self.num_symbols = 11030

        # Length of the moving-average filter handed to Receiver.cpr; forced odd.
        mem = 501
        if mem % 2 == 0:
            warnings.warn("Even number of filter taps for moving average. Expanding by 1.")
            mem += 1
        self.filter = tf.ones((mem, 1, 1, 1), dtype=tf.as_dtype(self.dtype).real_dtype)
        # Number of leading samples that CPR output is offset by relative to its
        # reference (see calculate_snr, which compares x[..., mem_cut:] against the
        # CPR output). None means "no cut", so x[..., self.mem_cut:] is a no-op.
        self.mem_cut = (mem // 2) * 2 or None

        self.transmitter = Transmitter(self.num_bits_per_symbol, self.batch_size, self.num_symbols, self.samples_per_symbol, self.beta, self.span_in_symbols)
        self.channel = Channel(self.alpha, self.beta_2, self.f_c, self.length_sp, self.t_norm, self.dtype)
        self.receiver = Receiver(self.linewidth, self.t_norm, self.samples_per_symbol, self.transmitter.rcf)

        # Neural-network equaliser and its optimiser.
        self.nn_equalise = MyModel()
        self.optimizer = tf.keras.optimizers.Adam()

    def set_gpu(self, gpu_num):
        """Restrict TensorFlow to one GPU and enable on-demand memory growth.

        Note: CUDA_VISIBLE_DEVICES is only honoured if it is set before TensorFlow
        initialises its GPU devices, so calling this again in a long-lived kernel
        with a different ``gpu_num`` may have no effect.
        """
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_num}"
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            try:
                tf.config.experimental.set_memory_growth(gpus[0], True)
            except RuntimeError as e:
                # Raised if the GPU was already initialised; report and continue.
                print(e)
        tf.get_logger().setLevel('ERROR')

    def _symbol_window(self, include_cpr_cut=False):
        """Return the slice of a transmitted frame that aligns with a receiver output.

        In ``run`` the timing-recovered signal is paired with
        ``x[edge:-edge]`` where ``edge = gardner.avglenhalf + gardner.flenhalf``.
        With ``include_cpr_cut=True`` the ``mem_cut`` leading samples that are
        missing from the CPR output are skipped as well.

        Args:
            include_cpr_cut: also skip ``self.mem_cut`` samples at the start.

        Returns:
            A ``slice`` to be applied to the last axis, e.g. ``x[..., window]``.
        """
        edge = self.receiver.gardner.avglenhalf + self.receiver.gardner.flenhalf
        start = (edge + (self.mem_cut or 0)) if include_cpr_cut else edge
        # A stop of -0 would select an empty range, so use None when edge == 0.
        stop = -edge if edge else None
        return slice(start, stop)

    def _set_link_distance(self, length):
        """Rebuild the channel for a new link length.

        The channel is reconstructed rather than patched via ``length_sp`` so that
        its internal model (``ss_fn``) is created with ``length``.
        """
        self.channel = Channel(self.alpha, self.beta_2, self.f_c, length, self.t_norm, self.dtype)

    def cal_EVM(self, x, y_cpr):
        """Return the mean squared error between ``x`` and ``y_cpr`` in dB.

        Args:
            x: full transmitted symbol frame (as returned by ``run``).
            y_cpr: CPR output of the timing-recovered path.

        Returns:
            ``10 * log10(mean |x_window - y_cpr|^2)`` (not normalised by signal power).
        """
        aligned_tx = x[..., self._symbol_window(include_cpr_cut=True)]
        mse = tf.reduce_mean(tf.math.square(tf.abs(aligned_tx - y_cpr)))
        return 10 * tf.math.log(mse) / tf.math.log(10.0)

    def calculate_snr(self, x, y_cpr, y_cpr_wo_tr):
        """Return the SNR (dB) with and without timing recovery.

        Signal and noise power are averaged over the same window of ``x`` in both
        cases, so the ratio compares like with like.

        Args:
            x: transmitted symbol frame(s); trimming is applied on the last axis.
            y_cpr: receiver output after timing recovery and CPR.
            y_cpr_wo_tr: receiver output after CPR with timing recovery bypassed.

        Returns:
            ``(snr, snr_wo_tr)`` as scalar tensors in dB.
        """
        _, snr = self.cal_mse(x[..., self._symbol_window(include_cpr_cut=True)], y_cpr)
        # Without timing recovery only the leading mem_cut samples are missing.
        _, snr_wo_tr = self.cal_mse(x[..., self.mem_cut:], y_cpr_wo_tr)
        return snr, snr_wo_tr

    def evaluate_snr(self, linewidths, link_distances, its=100):
        """Sweep linewidth and link distance and report the SNR for each point.

        For each distance the channel is rebuilt. Then, for each linewidth, ``its``
        frames are simulated and their SNR is computed jointly.

        Returns:
            ``{'with_tr': {length: [snr per linewidth]},
               'without_tr': {length: [snr per linewidth]}}``
        """
        results = {'with_tr': {}, 'without_tr': {}}
        for length in link_distances:
            self._set_link_distance(length)
            assert abs(float(self.channel.ss_fn._length) - float(length)) < 1e-3
            snr_values_with_tr = []
            snr_values_without_tr = []
            for lw in linewidths:
                self.receiver.linewidth = lw
                sent_symbols = tf.TensorArray(dtype=tf.complex64, size=its, dynamic_size=False)
                received_symbols_with_tr = tf.TensorArray(dtype=tf.complex64, size=its, dynamic_size=False)
                received_symbols_wo_tr = tf.TensorArray(dtype=tf.complex64, size=its, dynamic_size=False)
                for i in range(its):
                    y_cpr_wo_tr, _, sent, received = self.run()
                    sent_symbols = sent_symbols.write(i, sent)
                    received_symbols_with_tr = received_symbols_with_tr.write(i, received)
                    received_symbols_wo_tr = received_symbols_wo_tr.write(i, y_cpr_wo_tr)
                snr_with_tr, snr_without_tr = self.calculate_snr(
                    sent_symbols.stack(),
                    received_symbols_with_tr.stack(),
                    received_symbols_wo_tr.stack(),
                )
                snr_values_with_tr.append(snr_with_tr.numpy())
                snr_values_without_tr.append(snr_without_tr.numpy())
            results['with_tr'][length] = snr_values_with_tr
            results['without_tr'][length] = snr_values_without_tr
        return results

    def run(self):
        """Simulate one frame through the full chain.

        Returns:
            ``(y_cpr_wo_tr, y_cpr, x, y_cpr_normalised)``: the CPR output with
            timing recovery bypassed, the CPR output after timing recovery, the
            full transmitted QAM frame, and ``y_cpr`` after ``receiver.normalize``.
        """
        # Transmitter.
        x = self.transmitter.generate_qam_symbols()
        x_us = self.transmitter.upsample(x)
        x_rcf = self.transmitter.apply_rcf(x_us)
        x_rcf_padded, padding_left, padding_right = self.transmitter.pad_signal(x_rcf)

        # Channel and front-end impairments.
        y = self.channel.transmit(x_rcf_padded)
        y_pn = self.receiver.add_phase_noise(y, phase_noise())

        # Receiver DSP.
        y_mf = self.receiver.matched_filter(y_pn)
        y_cdc = self.channel.compensate_dispersion(y_mf)
        # NOTE(review): y_lpf is computed but never used -- downsample() is fed
        # y_cdc. Kept as-is to preserve current results; confirm whether the
        # low-pass-filtered signal was meant to be downsampled instead.
        y_lpf = self.receiver.low_pass_filter(y_cdc)
        y_ds = self.receiver.downsample(y_cdc, padding_left, padding_right)
        y_normalised = self.receiver.normalize(y_ds)

        y_tr, timing_errors = self.receiver.timing_recovery(y_normalised)

        # CPR with and without timing recovery; each gets the matching reference.
        y_cpr = self.receiver.cpr(y_tr, x[..., self._symbol_window()], self.dtype, self.mem_cut, self.filter)
        y_cpr_wo_tr = self.receiver.cpr(y_normalised, x, self.dtype, self.mem_cut, self.filter)

        y_cpr_normalised = self.receiver.normalize(y_cpr)
        return y_cpr_wo_tr, y_cpr, x, y_cpr_normalised

    def cal_mse(self, x, y):
        """Return ``(mse, snr_db)`` where snr_db = 10*log10(mean|x|^2 / mse)."""
        mse = tf.reduce_mean(tf.math.square(tf.abs(x - y)))
        signal_power = tf.reduce_mean(tf.math.square(tf.abs(x)))
        snr = 10 * tf.math.log(signal_power / mse) / tf.math.log(10.0)
        return mse, snr

    def train_and_test(self, iterations, linewidths, link_distances, log_every=1):
        """Train the NN equaliser over a linewidth x distance grid, testing after each point.

        For every (linewidth, distance) pair the receiver linewidth is set and the
        channel rebuilt, the model is trained for ``iterations`` freshly simulated
        frames, its weights are saved to ``final_model_weights.h5``, and it is
        evaluated with ``test`` on 100 frames. Training continues from the current
        weights throughout (the model is never re-initialised).

        Args:
            iterations: gradient steps per (linewidth, distance) pair.
            linewidths: receiver linewidths in Hz.
            link_distances: link lengths, in the unit ``Channel`` expects.
            log_every: print the loss every ``log_every`` iterations (and on the
                last one); the default of 1 logs every iteration.

        Returns:
            ``{lw: {length: {"mse": [{"original_mse", "nn_mse"}],
                             "snr": [{"original_snr", "nn_snr"}]}}}``
        """
        log_every = max(1, int(log_every))
        window = None
        test_results = {lw: {length: {"mse": [], "snr": []} for length in link_distances} for lw in linewidths}

        for idx, lw in enumerate(linewidths):
            print(f"\nTraining with linewidth: {lw} Hz")
            # Actually apply the linewidth under test; otherwise every pass would
            # simulate the linewidth the receiver was constructed with.
            self.receiver.linewidth = lw

            # Reload the last checkpoint when moving to the next linewidth.
            if idx > 0:
                self.nn_equalise.load_weights("final_model_weights.h5")
                print("Loaded model weights for continued training.")

            for length in link_distances:
                # Rebuild the channel so the new distance really takes effect.
                self._set_link_distance(length)
                print(f"\nTraining with linewidth: {lw} Hz and Distance: {length} km")

                for i in range(iterations):
                    _, _, tx_symbols, rx_symbols = self.run()
                    if window is None:
                        window = self._symbol_window(include_cpr_cut=True)
                    # Targets: transmitted symbols aligned with the equaliser input.
                    y_train = tx_symbols[..., window]
                    x_train = rx_symbols

                    with tf.GradientTape() as tape:
                        network_out = self.nn_equalise(x_train, training=True)
                        network_out = self.receiver.normalize(network_out)
                        loss, _ = self.cal_mse(y_train, network_out)

                    gradients = tape.gradient(loss, self.nn_equalise.trainable_variables)
                    self.optimizer.apply_gradients(zip(gradients, self.nn_equalise.trainable_variables))
                    if (i + 1) % log_every == 0 or i + 1 == iterations:
                        print(f"Iteration {i+1}/{iterations}, Loss: {loss.numpy()}")

                # Checkpoint after training on this (linewidth, distance) pair.
                self.nn_equalise.save_weights("final_model_weights.h5")
                print(f"Model weights saved after training with linewidth {lw} Hz and Distance: {length} km.\n")

                print(f"\nTesting with linewidth: {lw} Hz and Distance: {length} km")
                original_mse, original_snr, nn_mse, nn_snr, _, _, _ = self.test(self.nn_equalise, num_symbols=100)
                print(f"Testing MSE - Linewidth: {lw}, Link Distance: {length}, Original: {original_mse.numpy()}, Neural Network: {nn_mse.numpy()}")

                test_results[lw][length]["mse"].append({"original_mse": original_mse.numpy(), "nn_mse": nn_mse.numpy()})
                test_results[lw][length]["snr"].append({"original_snr": original_snr.numpy(), "nn_snr": nn_snr.numpy()})

        return test_results

    def test(self, model, num_symbols):
        """Evaluate ``model`` against the un-equalised receiver output.

        Despite its name, ``num_symbols`` is the number of frames (calls to ``run``).

        Returns:
            ``(org_mse, org_snr, nn_mse, nn_snr, tx_symbols_arr, rx_symbols_arr,
            rx_symbols_nn_arr)``: the MSE/SNR of the receiver output and of the
            model output against the aligned transmitted symbols, plus the stacked
            per-frame tensors.
        """
        tx_symbols_arr = tf.TensorArray(dtype=tf.complex64, size=num_symbols)
        rx_symbols_arr = tf.TensorArray(dtype=tf.complex64, size=num_symbols)
        rx_symbols_nn_arr = tf.TensorArray(dtype=tf.complex64, size=num_symbols)
        window = self._symbol_window(include_cpr_cut=True)
        for i in range(num_symbols):
            _, _, tx_symbols, rx_symbols = self.run()
            # Direct call (inference mode) instead of predict(): one small batch per
            # frame, and it matches how the model is invoked during training.
            network_out = self.receiver.normalize(model(rx_symbols, training=False))
            tx_symbols_arr = tx_symbols_arr.write(i, tx_symbols[..., window])
            rx_symbols_arr = rx_symbols_arr.write(i, rx_symbols)
            rx_symbols_nn_arr = rx_symbols_nn_arr.write(i, network_out)

        tx_symbols_arr = tx_symbols_arr.stack()
        rx_symbols_arr = rx_symbols_arr.stack()
        rx_symbols_nn_arr = rx_symbols_nn_arr.stack()
        tx_flat = tf.reshape(tx_symbols_arr, [-1])
        org_mse, org_snr = self.cal_mse(tx_flat, tf.reshape(rx_symbols_arr, [-1]))
        nn_mse, nn_snr = self.cal_mse(tx_flat, tf.reshape(rx_symbols_nn_arr, [-1]))
        return org_mse, org_snr, nn_mse, nn_snr, tx_symbols_arr, rx_symbols_arr, rx_symbols_nn_arr

if __name__ == "__main__":
    # Smoke test: build the pipeline, push `its` independent frames through it and
    # report the stacked shapes of the transmitted vs. received symbol frames.
    pipeline = Transceiver()
    its = 100

    tx_buffer = tf.TensorArray(dtype=tf.complex64, size=its, dynamic_size=False)
    rx_buffer = tf.TensorArray(dtype=tf.complex64, size=its, dynamic_size=False)

    for frame_idx in range(its):
        frame_outputs = pipeline.run()
        # run() -> (y_cpr_wo_tr, y_cpr, sent, received); only the last two are kept.
        tx_buffer = tx_buffer.write(frame_idx, frame_outputs[2])
        rx_buffer = rx_buffer.write(frame_idx, frame_outputs[3])

    sent_symbols = tx_buffer.stack()
    received_symbols = rx_buffer.stack()

    print("Final shape of received symbols:", received_symbols.shape)
    print("Final shape of sent symbols:", sent_symbols.shape)

Final shape of received symbols: (100, 1, 10000)
Final shape of sent symbols: (100, 1, 11030)


In [None]:
print("Start Train")
# Training budget per grid point, and the sweep grid itself:
# every linewidth (given here in kHz, converted to Hz) is trained/tested at every distance.
iterations = 2000
linewidths = [khz * 1e3 for khz in (100, 200, 300, 400, 500, 750, 1000)]
link_distances = [1e3, 2e3, 4e3, 5e3]
test_results = pipeline.train_and_test(
    iterations=iterations,
    linewidths=linewidths,
    link_distances=link_distances,
)

Start Train

Training with linewidth: 100000.0 Hz

Training with linewidth: 100000.0 Hz and Distance: 1000.0 km
Iteration 1/2000, Loss: 0.0004390437970869243
Iteration 2/2000, Loss: 0.0002537276304792613
Iteration 3/2000, Loss: 0.00027076929109171033
Iteration 4/2000, Loss: 0.0002160911390092224
Iteration 5/2000, Loss: 0.0002765471872407943
Iteration 6/2000, Loss: 0.0002891708572860807
Iteration 7/2000, Loss: 0.00022931273269932717
Iteration 8/2000, Loss: 0.0001633406209293753
Iteration 9/2000, Loss: 0.00018091502715833485
Iteration 10/2000, Loss: 0.00017868807481136173
Iteration 11/2000, Loss: 0.00016352348029613495
Iteration 12/2000, Loss: 0.00023644856992177665
Iteration 13/2000, Loss: 0.00017571345961187035
Iteration 14/2000, Loss: 0.0002465861034579575
Iteration 15/2000, Loss: 0.00018677716434467584
Iteration 16/2000, Loss: 0.00020323078206274658
Iteration 17/2000, Loss: 0.00016260436677839607
Iteration 18/2000, Loss: 0.00015705097757745534
Iteration 19/2000, Loss: 0.0004903390654

In [None]:
# print("Start Test")
# trained_model = load_model('nn_model')
# pipeline.nn_equalise.load_weights('final_model_weights.h5')
# NOTE(review): Transceiver.test returns 7 values
# (org_mse, org_snr, nn_mse, nn_snr, tx, rx, rx_nn); unpacking updated accordingly.
# original_mse, original_snr, nn_mse, nn_snr, tx_symbols_arr, rx_symbols_arr, rx_symbols_nn_arr = pipeline.test(trained_model, num_symbols=100)
# print(f"Original MSE: {original_mse.numpy()}, Neural Network MSE: {nn_mse.numpy()}")

In [None]:
# plot_io = PlotInputOutput(rx_symbols_arr, tx_symbols_arr)
# plot_io.plot_scatter()

In [None]:
# plot_ionn = PlotInputOutput(rx_symbols_nn_arr, tx_symbols_arr)
# plot_ionn.plot_scatter()

In [None]:
# Plot the test results of the linewidth x distance sweep produced by train_and_test.
plotter = SNRVsLinewidthPlotterNN(
    pipeline,
    linewidths,
    link_distances,
)
plotter.plot(test_results)

In [None]:
# Raw sweep results: {linewidth: {distance: {"mse": [{"original_mse", "nn_mse"}], "snr": [{"original_snr", "nn_snr"}]}}}
test_results