In [1]:
import time 
import numpy as np
from numpy import sum, isrealobj, sqrt
from numpy.random import standard_normal

from scipy import signal
from scipy.spatial.distance import cdist

from pynq import (Overlay,
                  allocate)

# some handy functions to use along widgets
from IPython.display import display, Markdown, clear_output
# widget packages
import ipywidgets as widgets
from ipywidgets import (HTML,
                        Label,
                        Button, 
                        Dropdown,
                        IntSlider,
                        Tab,
                        HBox,
                        VBox,
                        GridBox, 
                        Layout, 
                        Output,
                        ButtonStyle)

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import (rc, animation)
rc('animation', html='html5')

import warnings
warnings.filterwarnings("ignore")

# Global figure theme: white lines/labels/ticks on a black figure, large font.
# NOTE(review): plt.style.use('dark_background') below is applied *after* this
# update, so it presumably overrides the overlapping keys (e.g. axes.facecolor
# becomes black, axes.edgecolor/grid.color become white); likely only
# "font.size" survives from this dict. Confirm whether the order is intentional.
plt.rcParams.update({
    "lines.color": "white",
    "patch.edgecolor": "white",
    "axes.facecolor": "white",
    "axes.edgecolor": "lightgray",
    "axes.labelcolor": "white",
    "xtick.color": "white",
    "ytick.color": "white",
    "grid.color": "lightgray",
    "figure.facecolor": "black",
    "figure.edgecolor": "black",
    "savefig.facecolor": "black",
    "savefig.edgecolor": "black",
    "font.size": 22})

# Apply matplotlib's built-in dark theme on top of the settings above.
plt.style.use('dark_background')

In [2]:
"""
    Note: For the GUI, I am using ipywidgets.
    If some part of the GUI is unclear, please consult the ipywidgets documentation.
    If the confusion persists, mail me at jafri1999@gmail.com.
"""

'\n    Note: For the GUI, I am using ipywidgets.\n    If some part of the GUI is unclear, please consult the ipywidgets documentation.\n    If the confusion persists, mail me at jafri1999@gmail.com.\n'

In [3]:
class IP_Overlay():
    """Run the LS, MMSE and LS-DNN IPs through an AXI DMA.

    All state is stored on the class and every method is a classmethod, so
    every ``IP_Overlay()`` instance shares the same state. Call order:

        1. ``load_ip(ip_name, ip_path)`` - load the overlay and bind the DMA.
        2. ``load_ls_input`` / ``load_mmse_input`` / ``load_ls_dnn_input``.
        3. ``run_ip()`` - the IP result is then in ``output_buffer``.
    """
    N_SNR = 7
    N_CH = 20

    # Values per vector: 52 real parts followed by the 52 imaginary parts.
    N_VALUES = 104

    # IP names grouped by family.
    LS_IPS = ("ls_fl_sp", "ls_fp_16_2", "ls_fp_18_2")
    MMSE_IPS = ("mmse_fl_sp", "mmse_fl_dp", "mmse_fl_dp_sp")
    LS_DNN_IPS = ("ls_dnn_fl_sp", "ls_dnn_fp_16_8", "ls_dnn_fp_24_8")

    # Software delay (seconds) the MMSE IPs need between sending the input and
    # starting the output transfer.
    MMSE_WAIT_S = 0.1

    ip_name = None
    overlay = None
    dma = None
    ip_block = None

    input_buffer = None
    output_buffer = None

    def __init__(self):
        pass

    @classmethod
    def _wait(cls):
        """Let an MMSE IP finish processing before its DMA output is read.

        ``time.sleep`` yields the CPU instead of busy-spinning for the same time.
        """
        time.sleep(cls.MMSE_WAIT_S)

    @classmethod
    def load_ip(cls, ip_name, ip_path):
        """Load the overlay at ``ip_path`` and bind its DMA (and LS-DNN control block).

        Must be called before ``run_ip()``.

        Args:
            ip_name: One of the names in LS_IPS, MMSE_IPS or LS_DNN_IPS.
            ip_path: Path of the bitstream to load.
        """
        cls.ip_name = ip_name
        cls.overlay = Overlay(ip_path)
        cls.dma = cls.overlay.axi_dma_0

        if cls.ip_name in cls.LS_DNN_IPS:
            cls.ip_block = cls.overlay.lsDnn_0

    @classmethod
    def run_ip(cls):
        """Send ``input_buffer`` through the IP and receive the result into ``output_buffer``.

        The mmse_fl_dp IP outputs float64; every other IP outputs float32.
        The method blocks until both DMA directions have completed, so
        ``output_buffer`` is fully written when it returns.
        """
        out_dtype = np.float64 if cls.ip_name == "mmse_fl_dp" else np.float32
        cls.output_buffer = allocate(shape=(cls.N_VALUES,), dtype=out_dtype)

        if cls.ip_name in cls.LS_DNN_IPS:
            # Writes 1 to control register 0x00 (presumably ap_start of the
            # HLS control interface - confirm against the IP).
            cls.ip_block.write(0x00, 1)

        cls.dma.sendchannel.transfer(cls.input_buffer)

        if cls.ip_name in cls.MMSE_IPS:
            cls._wait()

        cls.dma.recvchannel.transfer(cls.output_buffer)

        # transfer() only starts a DMA; wait for both directions so callers
        # never read a partially written output buffer.
        cls.dma.sendchannel.wait()
        cls.dma.recvchannel.wait()

    @classmethod
    def _load_input(cls, yin, xin, dtype, snr_val=None):
        """Allocate ``input_buffer`` and fill it as [yin(104), xin(104)(, snr_val)].

        Args:
            yin, xin: Sequences with at least ``N_VALUES`` elements each.
            dtype: NumPy dtype of the allocated buffer.
            snr_val: Optional scalar appended as the last element.
        """
        n = cls.N_VALUES
        size = 2 * n if snr_val is None else 2 * n + 1
        cls.input_buffer = allocate(shape=(size,), dtype=dtype)
        cls.input_buffer[:n] = yin[:n]
        cls.input_buffer[n:2 * n] = xin[:n]
        if snr_val is not None:
            cls.input_buffer[2 * n] = snr_val

    @classmethod
    def load_ls_input(cls, yin, xin):
        """Load the float32 input of an LS IP. Must be run before ``run_ip()``."""
        cls._load_input(yin, xin, np.float32)

    @classmethod
    def load_mmse_input(cls, yin, xin, snr_val):
        """Load the input of an MMSE IP (yin, xin, then ``snr_val``).

        mmse_fl_dp takes a float64 buffer; mmse_fl_sp and mmse_fl_dp_sp take float32.
        Must be run before ``run_ip()``.

        Raises:
            ValueError: if the loaded IP is not an MMSE IP.
        """
        if cls.ip_name == "mmse_fl_dp":
            dtype = np.float64
        elif cls.ip_name in ("mmse_fl_sp", "mmse_fl_dp_sp"):
            dtype = np.float32
        else:
            raise ValueError(f"load_mmse_input called for non-MMSE IP {cls.ip_name!r}")
        cls._load_input(yin, xin, dtype, snr_val=snr_val)

    @classmethod
    def load_ls_dnn_input(cls, yin, xin):
        """Load the float32 input of an LS-DNN IP. Must be run before ``run_ip()``."""
        cls._load_input(yin, xin, np.float32)

In [4]:
class IP_Output():
    """Store one IP's results: the raw DMA output, Err_hw, Phf_hw and Err_normalized_hw.

    ``is_empty`` is True until ``load_output`` stores valid data.
    """
    def __init__(self):
        self.is_empty = True
        self.IP_output = None
        self.Err_hw = None
        self.Phf_hw = None
        self.Err_normalized_hw = None

    def load_output(self, IP_output, Err_hw, Phf_hw, Err_normalized_hw):
        """Store a complete set of results and mark the instance as holding valid data."""
        self.is_empty = False
        self.IP_output = IP_output
        self.Err_hw = Err_hw
        self.Phf_hw = Phf_hw
        self.Err_normalized_hw = Err_normalized_hw

    def is_output_empty(self):
        """Return True if no valid data is stored yet, False once ``load_output`` ran."""
        return self.is_empty

    @property
    def _IP_output(self):
        """Raw IP output array (None until loaded)."""
        return self.IP_output

    @property
    def _Err_normalized_hw(self):
        """Normalized error array (None until loaded)."""
        return self.Err_normalized_hw

    @_Err_normalized_hw.setter
    def _Err_normalized_hw(self, err_normalized):
        self.Err_normalized_hw = err_normalized

In [6]:
"""
    GUI Class.
"""
class GUI():
    N_SNR = 21
    # 200 channels exist in total, but only the first N_CH are processed so an
    # acceptable error plot can be produced in less time.
    N_CH = 20

    # Stored results of every IP, keyed by IP name.
    ips = {
        ip_key: IP_Output()
        for ip_key in (
            "ls_fl_sp", "ls_fp_16_2", "ls_fp_18_2",
            "mmse_fl_sp", "mmse_fl_dp", "mmse_fl_dp_sp",
            "ls_dnn_fl_sp", "ls_dnn_fp_16_8", "ls_dnn_fp_24_8",
        )
    }

    # Bitstream to load for each IP.
    ip_paths = dict(
        ls_fl_sp="./../bitstreams/ls_wl/ls_float/ls_float.bit",
        ls_fp_16_2="./../bitstreams/ls_wl/ls_16/ls_16.bit",
        ls_fp_18_2="./../bitstreams/ls_wl/ls_18/ls_18.bit",
        mmse_fl_sp="./../bitstreams/mmse_wl/mmse_float/mmse_float1.bit",
        mmse_fl_dp="./../bitstreams/mmse_wl/mmsed_new/mmsed1.bit",
        mmse_fl_dp_sp="./../bitstreams/mmse_wl/mmse_d_f/mmse_d_f1.bit",
        ls_dnn_fl_sp="./../bitstreams/ls_dnn/test_ls_dnn_float_32/lsDnn_float_32.bit",
        ls_dnn_fp_16_8="./../bitstreams/ls_dnn/test_ls_dnn_fixed_16/lsDnn_fixed_16.bit",
        ls_dnn_fp_24_8="./../bitstreams/ls_dnn/test_ls_dnn_fixed_24/lsDnn_fixed_24.bit",
    )

    cur_ip = None
    err_fig_size = (12, 7)
    mag_fig_size = (7, 6)

    # Compare functionality: an IP maps to True while its curve is being plotted.
    plot_list = dict.fromkeys(
        (
            "ls_fl_sp", "ls_fp_18_2", "ls_fp_16_2",
            "mmse_fl_dp", "mmse_fl_sp", "mmse_fl_dp_sp",
            "ls_dnn_fl_sp", "ls_dnn_fp_16_8", "ls_dnn_fp_24_8",
        ),
        False,
    )

    # Legend text used when an IP's error curve is plotted.
    legend_dict_ip = dict(
        ls_fl_sp="LS SPFL",
        ls_fp_16_2="LS FP-16",
        ls_fp_18_2="LS FP-18",
        mmse_fl_dp="MMSE DPFL",
        mmse_fl_sp="MMSE SPFL",
        mmse_fl_dp_sp="MMSE DP-SPFL",
        ls_dnn_fl_sp="LS-DNN SPFL",
        ls_dnn_fp_16_8="LS-DNN FP-16",
        ls_dnn_fp_24_8="LS-DNN FP-24",
    )

    # Legend text used when an IP's magnitude curve is plotted.
    legend_dict_mag = dict(
        ls_fl_sp="LS",
        ls_fp_16_2="LS",
        ls_fp_18_2="LS",
        mmse_fl_dp="MMSE",
        mmse_fl_sp="MMSE",
        mmse_fl_dp_sp="MMSE",
        ls_dnn_fl_sp="LS-DNN",
        ls_dnn_fp_16_8="LS-DNN",
        ls_dnn_fp_24_8="LS-DNN",
    )

    # Theoretical NMSE curves of the LS and MMSE IPs (loaded by _class_init).
    ls_err_th = None
    mmse_err_th = None
    
    def __init__(self):
        """No per-instance setup: the GUI keeps its state on the class itself."""
    
    @classmethod
    def _class_init(cls):
        """Load the theoretical LS and MMSE NMSE curves used as plot references."""
        errset_dir = "./../data/errset"
        cls.ls_err_th = cls._load_err_file(f"{errset_dir}/err_ls_ip_th.txt")
        cls.mmse_err_th = cls._load_err_file(f"{errset_dir}/err_mmse_ip_th.txt")
    
    @classmethod
    def _reset_all(cls):
        """Reset All: discard every stored IP result and deselect every IP.

        The class-level dictionaries are updated in place, so every later lookup
        of ``cls.ips[...]`` / ``cls.plot_list[...]`` sees the cleared state.
        Their keys (and key order) are preserved.
        """
        cls.cur_ip = None

        for ip_name in cls.ips:
            cls.ips[ip_name] = IP_Output()

        for ip_name in cls.plot_list:
            cls.plot_list[ip_name] = False

    @classmethod
    def _gui_print(cls, text, color="white", size="15px", family="monospace", decor="none"):
        """Display ``text`` as styled HTML; use this for text shown inside an ipywidgets Output.

        Args:
            text: Content placed inside a <span>; HTML entities such as ``&nbsp;`` are rendered.
            color: "white" or "red" (mapped to hex codes); any other value is
                passed through unchanged as a CSS color, e.g. "#00FF00".
            size: CSS font-size.
            family: CSS font-family.
            decor: CSS text-decoration.
        """
        color_dict = {
            "white": "#FFFFFF",
            "red": "#FF0000"
        }
        css_color = color_dict.get(color, color)

        display(Markdown(f"""<div style="text-align: left; line-height: 0.8">
                                <span style="font-family: {family};
                                color:{css_color}; 
                                font-size: {size};
                                text-decoration: {decor}"
                                >{text}</span>
                                </div>"""))
    
    @staticmethod
    def _iter_data_lines(filename):
        """Yield every non-blank line of ``filename`` with surrounding whitespace removed.

        Stripping, rather than chopping off the last character, keeps the final
        value intact when the file has no trailing newline. It also tolerates
        Windows (CRLF) line endings.
        """
        with open(filename, 'r') as reader:
            for raw_line in reader:
                line = raw_line.strip()
                if line:
                    yield line

    @staticmethod
    def _load_file(filename, shape=(200, 104)):
        """Load a comma-separated matrix file (Xin, yin and act ``.dat`` files).

        Each non-blank line becomes one row of the result.

        Args:
            filename: Path of the file to read.
            shape: Shape of the zero-initialised result. Rows missing from the
                file stay 0. Defaults to 200 rows of 104 values.

        Returns:
            float64 ndarray of ``shape``.
        """
        data = np.zeros(shape=shape)
        for idx, line in enumerate(GUI._iter_data_lines(filename)):
            data[idx] = np.asarray(line.split(","), dtype=np.float64)
        return data

    @staticmethod
    def _load_scalar_file(filename, length=21):
        """Load a file holding one number per line into a float64 array of ``length``.

        Each value is parsed at float32 precision before being stored.
        """
        data = np.zeros(shape=(length,))
        for idx, line in enumerate(GUI._iter_data_lines(filename)):
            data[idx] = np.asarray(line, dtype=np.float32)
        return data

    @staticmethod
    def _load_mmse_snr_val_file(filename):
        """Load the mmse_snr_val file: the 21 per-SNR values passed to the MMSE IPs."""
        return GUI._load_scalar_file(filename)

    @staticmethod
    def _load_err_file(filename):
        """Load a 21-point theoretical error curve (one value per line)."""
        return GUI._load_scalar_file(filename)
    
    @classmethod
    def _run_ip(cls):
        """Run the loaded IP (``cls.cur_ip``) over every SNR index and channel.

        For each SNR index and each of the first ``N_CH`` channels:
            1. Load the IP with the input that matches its family (LS / MMSE / LS-DNN).
            2. Call ``run_ip()`` and copy the DMA output into ``cls.IP_output``.
        Then, once per SNR index, the squared 2-norm errors against the actual
        channel are computed. All results are stored in ``cls.ips[ip_name]``.

        Raises:
            ValueError: if ``cls.cur_ip.ip_name`` is not a known IP.
        """
        ls_ips = ["ls_fl_sp", "ls_fp_16_2", "ls_fp_18_2"]
        mmse_ips = ["mmse_fl_sp", "mmse_fl_dp", "mmse_fl_dp_sp"]
        ls_dnn_ips = ["ls_dnn_fl_sp", "ls_dnn_fp_16_8", "ls_dnn_fp_24_8"]

        ip_name = cls.cur_ip.ip_name
        if ip_name not in ls_ips + mmse_ips + ls_dnn_ips:
            raise ValueError(f"Unknown IP name: {ip_name!r}")

        cls.IP_output = np.zeros([cls.N_SNR, cls.N_CH, 104])
        cls.Err_hw = np.zeros([cls.N_SNR,])
        cls.Phf_hw = np.zeros([cls.N_SNR,])
        cls.Err_normalized_hw = np.zeros([cls.N_SNR,])

        xin = cls._load_file(f"./../data/result_3/Xin.dat")[0]

        # Loop-invariant: read the MMSE SNR table once, not once per channel.
        snr_val = None
        if ip_name in mmse_ips:
            snr_val = cls._load_mmse_snr_val_file("./../data/mmse_snr_val.txt")

        for n_snr in range(cls.N_SNR):
            yin = cls._load_file(f"./../data/result_3/yin/YinC_{n_snr+1}m.dat")
            act = cls._load_file(f"./../data/result_3/act/actC_{n_snr+1}m.dat")

            for n_ch in range(cls.N_CH):
                if ip_name in ls_ips:
                    cls.cur_ip.load_ls_input(yin[n_ch], xin)
                elif ip_name in mmse_ips:
                    cls.cur_ip.load_mmse_input(yin[n_ch], xin, snr_val[n_snr])
                else:
                    cls.cur_ip.load_ls_dnn_input(yin[n_ch], xin)

                cls.cur_ip.run_ip()
                cls.IP_output[n_snr, n_ch, :] = cls.cur_ip.output_buffer

            # Each 104-value row is stored as Real[0..51] followed by Imag[0..51],
            # so element k pairs with element k + 52. Computed once per SNR,
            # after every channel has been run.
            complex_act = act[:cls.N_CH, :52] + 1j * act[:cls.N_CH, 52:104]
            complex_IP = cls.IP_output[n_snr, :, :52] + 1j * cls.IP_output[n_snr, :, 52:104]

            # Matrix 2-norm (largest singular value), squared.
            Err_hw = np.linalg.norm(complex_act - complex_IP, 2)**2
            Phf_hw = np.linalg.norm(complex_act, 2)**2

            # NOTE(review): divides by 200 (total channel count) although only
            # cls.N_CH channels contribute; Err_normalized_hw is unaffected.
            cls.Err_hw[n_snr] = Err_hw / 200
            cls.Phf_hw[n_snr] = Phf_hw / 200
            cls.Err_normalized_hw[n_snr] = Err_hw / Phf_hw

        # Store the IP output, Err_hw, Phf_hw and Err_normalized_hw in the ips dict.
        cls.ips[ip_name].load_output(cls.IP_output, cls.Err_hw, cls.Phf_hw, cls.Err_normalized_hw)

    @classmethod
    def _plot_err_curve(cls):
        """Plot NMSE vs. SNR for every selected IP plus the LS/MMSE theoretical curves.

        An IP is drawn only if it holds results and is switched on in
        ``plot_list``. Entries 4..16 of each curve are plotted against
        SNR = -30, -25, ..., 30 dB on a log y-axis.
        """
        cls.fig = plt.figure(figsize=cls.err_fig_size)
        cls.ax = plt.axes(xlabel="SNR (dB)", ylabel="NMSE")

        snr_axis = np.arange(-30, 31, 5)
        shown = slice(4, 17)

        curves = [
            (cls.legend_dict_ip[name], result._Err_normalized_hw)
            for name, result in cls.ips.items()
            if not result.is_output_empty() and cls.plot_list[name]
        ]
        curves.append(("LS-TH", cls.ls_err_th))
        curves.append(("MMSE-TH", cls.mmse_err_th))

        for label, nmse in curves:
            plt.semilogy(snr_axis, nmse[shown], lw=2, marker='o', label=label)

        cls.legend = cls.ax.legend()
        plt.show()
   
    @classmethod
    def _plot_mag_curve(cls, n_snr):
        """Plot per-subcarrier magnitudes for SNR index ``n_snr``.

        Draws channel row 0 of the actual-channel file ("CIR"). Then, for each of
        ls_fl_sp, mmse_fl_dp and ls_dnn_fl_sp that holds results, it draws the
        magnitude of that IP's channel-0 output at the same SNR index.
        """
        cls.fig = plt.figure(figsize=cls.mag_fig_size)
        cls.ax = plt.axes(xlabel="Preamble Subcarriers", ylabel="Magnitude")

        subcarriers = np.arange(1, 53, 1)

        actual = cls._load_file(f"./../data/result_3/act/actC_{n_snr+1}m.dat")
        # Rows hold 52 real parts followed by the 52 matching imaginary parts.
        cir = actual[0, :52] + 1j * actual[0, 52:104]
        plt.semilogy(subcarriers, abs(cir), lw=2, marker='+', label="CIR")

        for name, result in cls.ips.items():
            if name not in ["ls_fl_sp", "mmse_fl_dp", "ls_dnn_fl_sp"] or result.is_output_empty():
                continue
            row = result._IP_output[n_snr, 0]
            estimate = row[:52] + 1j * row[52:104]
            plt.semilogy(subcarriers, abs(estimate), lw=2, marker='o',
                         label=cls.legend_dict_mag[name])

        cls.legend = cls.ax.legend()
        plt.show()
    
    @classmethod
    def _create_widgets(cls):
        """Create every GUI widget and lay them out in ``cls.gridBox``.

        The GUI has a grid layout:
             ______________________
            |                      |
            |        Header        |
            |______________________|
            |            ||        |
            |   Menu 1   || Menu 2 |
            |____________||________|
            |                      |
            |      Plot-Header     |
            |______________________|
            |            ||        |
            |    Out 1   ||  Out 2 |
            |____________||________|

        Menu 1 (Menu_IP_div) holds the IP buttons. Menu 2 (Menu_NMSE_div) holds
        the SNR slider, the normalized-error readout and a magnitude-plot Output.
        Out 1 is the NMSE plot and Out 2 the magnitude plot. Styling comes from
        the CSS in ``HTML_TEXT``.
        """
        # CSS for the GUI. CSS only supports /* */ comments, so disabled
        # declarations are kept inside /* */ rather than behind // or #.
        HTML_TEXT = """
        <style>
        .header-label {
            font-family: verdana;
            font-size: 20px;
            color: white;
            text-align: center;
            border-style: solid;
            border-color: white;
            border-radius: 4px;
            padding: 10px 10px 40px 10px;
            margin: 0px 0px 0px 0px;
        }

        .header {
            margin: 10px 0px 5px 0px;
            height: auto;
        }

        .body {
            background-color: #000000;
            margin: 5px;
            padding: 5px 10px 0px 10px;
            height: fit-content;
            width: auto;
        }

        /* menu */
        .menu-btn {
            font-family: Sans-Serif;
            font-size: 12px;
            color: white;
            background-color: #4397EC;
            border-radius: 4px;
        }

        .menu-btn-small {
            max-width: 50px;
            min-width: 50px;
            font-family: verdana;
            font-size: 12px;
            color: white;
            background-color: #42ba96;
        }

        .menu-ip-label {
            font-family: verdana;
            font-size: 14px;
            color: white;
            text-align: center;
            border-style: solid;
            border-color: white;
            border-radius: 4px;
        }

        .menu-ip-area {
            display: flex;
            align-items: center;
            justify-content: center;
            margin: -10px 0px 30px 0px;
            height: 150px;
            overflow: hidden;
        }

        .menu-nmse-area {
            display: flex;
            align-items: center;
            justify-content: center;
            margin: -10px 0px 10px 0px;
            height: 150px;
            border-radius: 4px;
            overflow: hidden;
        }

        /* vbox for norm_err */
        .norm-err-label {
            font-family: verdana;
            font-size: 16px;
            color: white;
            text-align: center;
            /* width: 338px; */
        }

        .norm-err-vbox {
            border-style: solid;
            border-radius: 4px;
            border-color: #6B6C72;
            background-color: #9A9DA6;
            margin: 0px 0px 0px 0px;
        }

        /* plot */
        .plot-header {
            margin: -25px 0px 0px 0px;
            height: auto;
        }
        .plot-out-label {
            font-family: verdana;
            font-size: 20px;
            color: white;
            text-align: center;
            border-style: solid;
            border-color: white;
            border-radius: 4px;
            padding: 0px 10px 40px 10px;
            margin: 0px 0px 0px 0px;
        }

        .plot-out-div {
            font-family: verdana;
            font-size: 25px;
            color: white;
            text-align: center;
            align-items: center;
            border-style: solid;
            border-color: white;
            border-radius: 4px;
            width: 100%;
            overflow: hidden;
            /* height: 500px; */
            margin-top: 0px;
        }

        .plot-out-area {
            margin-top: 5px;
        }

        .invalid-btn {
            border-radius: 4px;
            background-color: #A9A9A9;
        }

        .menu-active-btn {
            font-family: Sans-Serif;
            font-size: 12px;
            border-radius: 4px;
            color: white;
            background-color: #F091EC;
        }

        .active-btn {
            border-radius: 4px;
            background-color: #F88379;
        }

        /* SNR Slider */
        .btn-plot-slider {
            font-family: Sans-Serif;
            font-size: 12px;
            border-radius: 4px;
            color: white;
            background-color: #5CB85C;
            width: 80px;
        }

        .snr-slider {
            width: 250px;
        }

        .snr-slider-area {
            border-radius: 4px;
            border-style: solid;
            border-color: white;
            /* text-align: center; */
            align-items: center;
            width: auto;
            background-color: white;
            margin: 0px 30px 12px 30px;
        }
        </style>
        """

        # HTML() renders its value as HTML; placing it in the grid applies the CSS.
        cls.CSS = HTML(HTML_TEXT)

        # ---- Header ----
        cls.HeaderLabel = Label("Deep Neural Network Augmented Wireless Channel Estimation on System on Chip")
        cls.HeaderLabel.add_class("header-label")

        cls.Header = VBox([
            cls.HeaderLabel,
        ],
        layout = Layout(width="auto", grid_area="Header"))
        cls.Header.add_class("header")

        # ---- Menu labels, one per IP family ----
        cls.lbl_ls = Label("LS")
        cls.lbl_mmse = Label("MMSE")
        cls.lbl_ls_dnn = Label("LS DNN")
        for family_label in (cls.lbl_ls, cls.lbl_mmse, cls.lbl_ls_dnn):
            family_label.add_class("menu-ip-label")

        # ---- LS IP buttons ----
        cls.btn_ls = {
            "fl_sp": Button(description="SPFL"),
            "fp_16_2": Button(description="FP-16"),
            "fp_18_2": Button(description="FP-18")
        }
        cls.btn_ls["fl_sp"].on_click(cls._btn_ls_fl_sp_on_click)
        cls.btn_ls["fp_16_2"].on_click(cls._btn_ls_fp_16_2_on_click)
        cls.btn_ls["fp_18_2"].on_click(cls._btn_ls_fp_18_2_on_click)

        # ---- MMSE IP buttons ----
        cls.btn_mmse = {
            "fl_sp": Button(description="SPFL"),
            "fl_dp": Button(description="DPFL"),
            "fl_dp_sp": Button(description="DP-SPFL")
        }
        cls.btn_mmse["fl_sp"].on_click(cls._btn_mmse_fl_sp_on_click)
        cls.btn_mmse["fl_dp"].on_click(cls._btn_mmse_fl_dp_on_click)
        cls.btn_mmse["fl_dp_sp"].on_click(cls._btn_mmse_fl_dp_sp_on_click)

        # ---- LS-DNN IP buttons ----
        cls.btn_ls_dnn = {
            "fl_sp": Button(description="SPFL"),
            "fp_16_8": Button(description="FP-16"),
            "fp_24_8": Button(description="FP-24")
        }
        cls.btn_ls_dnn["fl_sp"].on_click(cls._btn_ls_dnn_fl_sp_on_click)
        cls.btn_ls_dnn["fp_24_8"].on_click(cls._btn_ls_dnn_fp_24_8_on_click)
        cls.btn_ls_dnn["fp_16_8"].on_click(cls._btn_ls_dnn_fp_16_8_on_click)

        for button_group in (cls.btn_ls, cls.btn_mmse, cls.btn_ls_dnn):
            for button in button_group.values():
                button.add_class("menu-btn")

        # Each VBox shows one IP family and its configurations.
        cls.vbox_ls = VBox([
            cls.lbl_ls,
            cls.btn_ls["fl_sp"],
            cls.btn_ls["fp_16_2"],
            cls.btn_ls["fp_18_2"]
        ],
        layout = Layout(width="auto", grid_area="vbox_ls"))

        cls.vbox_mmse = VBox([
            cls.lbl_mmse,
            cls.btn_mmse["fl_dp"],
            cls.btn_mmse["fl_sp"],
            cls.btn_mmse["fl_dp_sp"]
        ],
        layout = Layout(width="auto", grid_area="vbox_mmse"))

        cls.vbox_ls_dnn = VBox([
            cls.lbl_ls_dnn,
            cls.btn_ls_dnn["fl_sp"],
            cls.btn_ls_dnn["fp_24_8"],
            cls.btn_ls_dnn["fp_16_8"]
        ],
        layout = Layout(width="auto", grid_area="vbox_ls_dnn"))

        # The three VBoxes are grouped into one HBox: Menu_IP_div (Menu 1).
        cls.Menu_IP_div = HBox([
            cls.vbox_ls,
            cls.vbox_mmse,
            cls.vbox_ls_dnn
        ],
        layout = Layout(width="auto", grid_area="Menu_IP_div"))
        cls.Menu_IP_div.add_class("menu-ip-area")

        # ---- SNR slider (-50 .. 50 dB in steps of 5) and its Plot button ----
        cls.snr_slider = widgets.IntSlider(
            value=0,
            min=-50,
            max=50,
            step=5,
            description='SNR:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='d'
        )
        cls.snr_slider.add_class("snr-slider")

        cls.btn_snr_plot = Button(description="Plot")
        cls.btn_snr_plot.add_class("btn-plot-slider")
        cls.btn_snr_plot.on_click(cls._btn_snr_plot_on_click)

        cls.OutPlotLabel2 = Output()

        cls.hbox_snr_slider = HBox([
            cls.snr_slider,
            cls.btn_snr_plot])
        cls.hbox_snr_slider.add_class("snr-slider-area")

        # ---- Normalized error per subcarrier readout ----
        cls.norm_err_Label = Label("Normalized Error per Subcarrier")
        cls.norm_err_Label.add_class("norm-err-label")

        # Placeholder zeros shown when the widgets are created.
        op1 = 0
        op2 = 0
        op3 = 0
        op4 = 0

        cls.norm_err_Out = Output()
        with cls.norm_err_Out:
            mid_space = "&nbsp;"*4
            cls._gui_print(f"&nbsp;CIR &nbsp; {op1:.8f} {mid_space} LS&nbsp;&nbsp;   {op2:.8f}")
            cls._gui_print(f"&nbsp;LSDNN {op3:.8f} {mid_space} MMSE {op4:.8f}")

        cls.norm_err_box = VBox([
            cls.norm_err_Label,
            cls.norm_err_Out
        ],
        layout = Layout(width="auto", grid_area="norm_err_box"))
        cls.norm_err_box.add_class("norm-err-vbox")

        cls.mag_plots_Out = Output()

        # The SNR slider, the error readout and mag_plots_Out are stacked in
        # one VBox: Menu_NMSE_div (Menu 2).
        cls.Menu_NMSE_div = VBox([
            cls.hbox_snr_slider,
            cls.norm_err_box,
            cls.mag_plots_Out,
        ],
        layout = Layout(width="auto", grid_area="Menu_NMSE_div"))
        cls.Menu_NMSE_div.add_class("menu-nmse-area")

        # ---- Output plots: section label ----
        cls.PlotHeaderLabel = Label("Comparison of channel estimation schemes")
        cls.PlotHeaderLabel.add_class("plot-out-label")

        cls.PlotHeader = VBox([
            cls.PlotHeaderLabel,
        ],
        layout = Layout(width="auto", grid_area="plot_header"))
        cls.PlotHeader.add_class("plot-header")

        # ---- Out 1: NMSE plots ----
        cls.OutPlot_NMSE = Output()
        cls.OutPlot_NMSE_div = VBox([
            cls.OutPlot_NMSE
        ],
        layout=Layout(width="auto", grid_area="OutPlot_NMSE_div"))
        cls.OutPlot_NMSE.add_class("plot-out-area")
        cls.OutPlot_NMSE_div.add_class("plot-out-div")

        # ---- Out 2: magnitude plots ----
        cls.OutPlot_Mag = Output()
        cls.OutPlot_Mag_div = VBox([
            cls.OutPlot_Mag
        ],
        layout=Layout(width="auto", grid_area="OutPlot_Mag_div"))
        cls.OutPlot_Mag.add_class("plot-out-area")
        cls.OutPlot_Mag_div.add_class("plot-out-div")

        # ---- Grid layout (see the diagram in the docstring) ----
        cls.gridBox = GridBox(
            children = [cls.Header,
                        cls.Menu_IP_div,
                        cls.Menu_NMSE_div,
                        cls.PlotHeader,
                        cls.OutPlot_NMSE_div,
                        cls.OutPlot_Mag_div,
                        cls.CSS],
            layout = Layout(
            grid_template_rows="auto auto auto 350px",
            grid_template_columns="60% 40%",
            grid_template_areas="""
            "Header Header"
            "Menu_IP_div Menu_NMSE_div"
            "plot_header plot_header"
            "OutPlot_NMSE_div OutPlot_Mag_div"
            """)
        )

        cls.gridBox.add_class("body")
    
    # ---- Buttons: LS IP ----

    @classmethod
    def _toggle_ip_plot(cls, ip_name, button):
        """Toggle ``ip_name`` in the NMSE comparison plot and redraw the plot.

        Deselecting only removes the curve from the plot.

        Selecting does three things:
            1. Highlights ``button``.
            2. Loads the IP's bitstream.
            3. Runs the IP over all SNR indices and channels.

        If loading or running raises, the selection is rolled back, so that
        ``plot_list`` and the button highlight stay consistent, and the error
        is re-raised.

        Args:
            ip_name: Key into ``ip_paths`` / ``plot_list``.
            button: The ipywidgets Button that represents this IP.
        """
        if cls.plot_list[ip_name]:
            button.remove_class("active-btn")
            cls.plot_list[ip_name] = False
        else:
            button.add_class("active-btn")
            cls.plot_list[ip_name] = True
            try:
                cls.cur_ip = IP_Overlay()
                cls.cur_ip.load_ip(ip_name, cls.ip_paths[ip_name])
                cls._run_ip()
            except Exception:
                button.remove_class("active-btn")
                cls.plot_list[ip_name] = False
                raise

        with cls.OutPlot_NMSE:
            clear_output()
            cls._plot_err_curve()

    @classmethod
    def _btn_ls_fl_sp_on_click(cls, change):
        """Toggle the LS single-precision float IP (ls_fl_sp)."""
        cls._toggle_ip_plot("ls_fl_sp", cls.btn_ls["fl_sp"])

    @classmethod
    def _btn_ls_fp_16_2_on_click(cls, change):
        """Toggle the LS 16-bit fixed-point IP (ls_fp_16_2)."""
        cls._toggle_ip_plot("ls_fp_16_2", cls.btn_ls["fp_16_2"])

    @classmethod
    def _btn_ls_fp_18_2_on_click(cls, change):
        """Toggle the LS 18-bit fixed-point IP (ls_fp_18_2)."""
        cls._toggle_ip_plot("ls_fp_18_2", cls.btn_ls["fp_18_2"])
    
    """
        Button: MMSE IP
    """
    @classmethod
    def _btn_mmse_fl_sp_on_click(cls, change):
        if cls.plot_list["mmse_fl_sp"]:
            cls.btn_mmse["fl_sp"].remove_class("active-btn")
            cls.plot_list["mmse_fl_sp"] = False
        else:
            cls.btn_mmse["fl_sp"].add_class("active-btn")
            cls.plot_list["mmse_fl_sp"] = True
            
            ip_name = "mmse_fl_sp"
            cls.cur_ip = IP_Overlay()
            cls.cur_ip.load_ip(ip_name, cls.ip_paths[ip_name])
            cls._run_ip()
        
        with cls.OutPlot_NMSE:
            clear_output()
            cls._plot_err_curve()
    
    @classmethod
    def _btn_mmse_fl_dp_on_click(cls, change):
        if cls.plot_list["mmse_fl_dp"]:
            cls.btn_mmse["fl_dp"].remove_class("active-btn")
            cls.plot_list["mmse_fl_dp"] = False
        else:
            cls.btn_mmse["fl_dp"].add_class("active-btn")
            cls.plot_list["mmse_fl_dp"] = True
            
            ip_name = "mmse_fl_dp"
            cls.cur_ip = IP_Overlay()
            cls.cur_ip.load_ip(ip_name, cls.ip_paths[ip_name])
            cls._run_ip()
        
        with cls.OutPlot_NMSE:
            clear_output()
            cls._plot_err_curve()

    @classmethod
    def _btn_mmse_fl_dp_sp_on_click(cls, change):
        if cls.plot_list["mmse_fl_dp_sp"]:
            cls.btn_mmse["fl_dp_sp"].remove_class("active-btn")
            cls.plot_list["mmse_fl_dp_sp"] = False  
        else:
            cls.btn_mmse["fl_dp_sp"].add_class("active-btn")
            cls.plot_list["mmse_fl_dp_sp"] = True
            
            ip_name = "mmse_fl_dp_sp"
            cls.cur_ip = IP_Overlay()
            cls.cur_ip.load_ip(ip_name, cls.ip_paths[ip_name])
            cls._run_ip()
        
        with cls.OutPlot_NMSE:
            clear_output()
            cls._plot_err_curve()
    
    """
        Button: LS-DNN IP
    """
    @classmethod
    def _btn_ls_dnn_fl_sp_on_click(cls, change):
        if cls.plot_list["ls_dnn_fl_sp"]:
            cls.btn_ls_dnn["fl_sp"].remove_class("active-btn")
            cls.plot_list["ls_dnn_fl_sp"] = False
        else:
            cls.btn_ls_dnn["fl_sp"].add_class("active-btn")
            cls.plot_list["ls_dnn_fl_sp"] = True
            
            ip_name = "ls_dnn_fl_sp"
            cls.cur_ip = IP_Overlay()
            cls.cur_ip.load_ip(ip_name, cls.ip_paths[ip_name])
            cls._run_ip()
        
        with cls.OutPlot_NMSE:
            clear_output()
            cls._plot_err_curve()
     
    @classmethod
    def _btn_ls_dnn_fp_16_8_on_click(cls, change):
        if cls.plot_list["ls_dnn_fp_16_8"]:
            cls.btn_ls_dnn["fp_16_8"].remove_class("active-btn")
            cls.plot_list["ls_dnn_fp_16_8"] = False
        else:
            cls.btn_ls_dnn["fp_16_8"].add_class("active-btn")
            cls.plot_list["ls_dnn_fp_16_8"] = True
            
            ip_name = "ls_dnn_fp_16_8"
            cls.cur_ip = IP_Overlay()
            cls.cur_ip.load_ip(ip_name, cls.ip_paths[ip_name])
            cls._run_ip()
        
        with cls.OutPlot_NMSE:
            clear_output()
            cls._plot_err_curve()
    
    @classmethod
    def _btn_ls_dnn_fp_24_8_on_click(cls, change):
        if cls.plot_list["ls_dnn_fp_24_8"]:
            cls.btn_ls_dnn["fp_24_8"].remove_class("active-btn")
            cls.plot_list["ls_dnn_fp_24_8"] = False
        else:
            cls.btn_ls_dnn["fp_24_8"].add_class("active-btn")
            cls.plot_list["ls_dnn_fp_24_8"] = True
            
            ip_name = "ls_dnn_fp_24_8"
            cls.cur_ip = IP_Overlay()
            cls.cur_ip.load_ip(ip_name, cls.ip_paths[ip_name])
            cls._run_ip()
        
        with cls.OutPlot_NMSE:
            clear_output()
            cls._plot_err_curve()
    
    """
        If the SNR value of SNR_Slider is changed then display the error values for that SNR.
    """
    @classmethod
    def _btn_snr_plot_on_click(cls, change):
        """
            Show the normalized error of the reference IPs at the SNR selected
            on cls.snr_slider, then redraw the magnitude plot for that SNR.

            Parameters:
                change : widget callback payload (unused).

            The slider value is mapped to an index into Err_normalized_hw with
            (value + 50) // 5. This assumes the slider spans -50..50 in steps
            of 5, giving indices 0..20 -- TODO confirm against the slider
            definition.

            Only "ls_fl_sp", "mmse_fl_dp" and "ls_dnn_fl_sp" are reported. An
            IP whose output is still empty is shown as 0.
        """
        # Integer floor division; equals the old int(x / 5) for x >= -50.
        n_snr_0_20 = (cls.snr_slider.value + 50) // 5

        def _err_at_snr(ip_name):
            # Normalized HW error of one IP at the selected SNR, or 0 if it has not been run.
            ip = cls.ips[ip_name]
            return 0 if ip.is_output_empty() else ip.Err_normalized_hw[n_snr_0_20]

        ls_out = _err_at_snr("ls_fl_sp")
        mmse_out = _err_at_snr("mmse_fl_dp")
        ls_dnn_out = _err_at_snr("ls_dnn_fl_sp")

        # NOTE(review): the CIR error is always displayed as 0; nothing here
        # compares against the actual channel data
        # (./../data/act/actC_XX.txt). Confirm whether that is intended.
        err_act = 0

        with cls.norm_err_Out:
            clear_output()
            mid_space = "&nbsp;"*5
            cls._gui_print(f"&nbsp;CIR &nbsp; {err_act:.8f} {mid_space} LS&nbsp;&nbsp;   {ls_out:.8f}")
            cls._gui_print(f"&nbsp;LSDNN {ls_dnn_out:.8f} {mid_space} MMSE {mmse_out:.8f}")

        with cls.OutPlot_Mag:
            clear_output()
            cls._plot_mag_curve(n_snr_0_20)
        
    @classmethod
    def display_widgets(cls):
        """
            Entry point of the GUI.

            Runs class initialisation and widget creation, renders the grid
            layout, then simulates a click on the LS "fl_sp" button so an
            initial error curve is shown.
        """
        for setup_step in (cls._class_init, cls._create_widgets):
            setup_step()

        display(cls.gridBox)

        # Other IPs (e.g. mmse_fl_dp, ls_dnn_fl_sp) or the SNR read-out can be
        # preselected the same way by calling their *_on_click handlers here.
        cls._btn_ls_fl_sp_on_click("change")
            
# Build and render the GUI when this cell is executed.
GUI.display_widgets()

GridBox(children=(VBox(children=(Label(value='Deep Neural Network Augmented Wireless Channel Estimation on Sys…