<a href="https://colab.research.google.com/github/pollyjuice74/Error-Bit-Decoding/blob/main/GNNDecoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
# general imports
import tensorflow as tf
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

# load required Sionna components
#!pip install sionna
import sionna as sn
from sionna.fec.utils import load_parity_check_examples, LinearEncoder, GaussianPriorSource
from sionna.utils import BinarySource, ebnodb2no, BitwiseMutualInformation, hard_decisions
from sionna.utils.metrics import compute_ber
from sionna.utils.plotting import PlotBER
from sionna.mapping import Mapper, Demapper
from sionna.channel import AWGN
from sionna.fec.ldpc import LDPCBPDecoder

from tensorflow.keras.layers import Dense, Layer

!git clone https://github.com/NVlabs/gnn-decoder.git
!
%load_ext autoreload
%autoreload 2
import sys
from importlib import import_module
import pickle

#tf.config.experimental_run_functions_eagerly(True)
#from gnn_decoder.gnn import *
# gnn_decoder.wbp import * # load weighted BP functions

Cloning into 'gnn-decoder'...
remote: Enumerating objects: 46, done.[K
remote: Counting objects: 100% (46/46), done.[K
remote: Compressing objects: 100% (32/32), done.[K
remote: Total 46 (delta 22), reused 38 (delta 14), pack-reused 0[K
Receiving objects: 100% (46/46), 857.22 KiB | 7.14 MiB/s, done.
Resolving deltas: 100% (22/22), done.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
class MLP(Layer):
    """Stack of fully-connected (Dense) layers applied one after another.

    Parameters
    ----------
    units : List of int
        Output width of each Dense layer, in application order.

    activations : List of activations
        Activation handed to each Dense layer; ``activations[i]`` belongs to
        the layer with ``units[i]`` outputs (``None`` means linear).

    use_bias : List of booleans
        Bias switch for each Dense layer; ``use_bias[i]`` belongs to the
        layer with ``units[i]`` outputs.

    Input
    -----
    inputs : tf.Tensor
        Tensor whose last dimension is consumed by the first Dense layer.

    Output
    ------
    outputs : tf.Tensor
        The input after passing through every Dense layer in order.
    """
    def __init__(self, units, activations, use_bias):
        super().__init__()
        self._num_units = units
        self._activations = activations
        self._use_bias = use_bias

    def build(self, input_shape):
        # One Dense layer per entry of ``units``; activation and bias flag are
        # looked up by the same index.
        self._layers = [
            Dense(width,
                  self._activations[idx],
                  use_bias=self._use_bias[idx])
            for idx, width in enumerate(self._num_units)
        ]

    def call(self, inputs):
        x = inputs
        for dense in self._layers:
            x = dense(x)
        return x


class GNN_BP(Layer):
    """GNN-based BP decoder.

    The decoder keeps one embedding per vertex of the Tanner graph defined by
    ``pcm``. Each iteration first updates all check-node (CN) embeddings from
    the variable-node (VN) embeddings, then all VN embeddings from the new CN
    embeddings. Output LLRs are read out of the VN embeddings.

    Parameters
    ----------
    pcm : [num_cn, num_vn], numpy.array
        The parity-check matrix; every non-zero entry defines one edge.

    num_embed_dims: int
        Number of dimensions of the vertex embeddings.

    num_msg_dims: int
        Number of dimensions of a message.

    num_hidden_units: int
        Number of hidden units of the MLPs used to compute
        messages and to update the vertex embeddings.

    num_mlp_layers: int
        Number of layers of the MLPs used to compute
        messages and to update the vertex embeddings.

    num_iter: int
        Number of decoding iterations.

    reduce_op: str
        Vertex aggregation function, either "mean" or "sum".
        Defaults to "sum".

    activation: str
        Activation function of the hidden MLP layers. Defaults to "relu".

    output_all_iter: bool
        If True, the LLRs of every iteration are returned as a list; otherwise
        only the LLRs of the last iteration are returned. Defaults to False.

    clip_llr_to: float or None
        If set, the input LLRs are clipped to ``[-clip_llr_to, clip_llr_to]``.
        Must be non-negative. Defaults to None (no clipping).

    Input
    -----
    llr : [batch_size, num_vn], tf.float32
        Tensor containing the LLRs of all bits.

    Output
    ------
    llr_hat : [batch_size, num_vn], tf.float32
        Tensor containing the LLRs at the decoder output.
        If ``output_all_iter`` is True, a list with one such tensor per
        iteration is returned instead.

    Raises
    ------
    ValueError
        If ``reduce_op`` is neither "sum" nor "mean", or if ``clip_llr_to``
        is negative.
    """
    def __init__(self,
                 pcm,
                 num_embed_dims,
                 num_msg_dims,
                 num_hidden_units,
                 num_mlp_layers,
                 num_iter,
                 reduce_op="sum",
                 activation="relu",
                 output_all_iter=False,
                 clip_llr_to=None):
        super().__init__()

        # Fail at construction time instead of deep inside the first call.
        if reduce_op not in ("sum", "mean"):
            raise ValueError("unknown reduce operation")
        if clip_llr_to is not None and clip_llr_to < 0:
            raise ValueError("clip_llr_to must be non-negative")

        self._pcm = pcm # Parity check matrix
        self._num_cn = pcm.shape[0] # Number of check nodes
        self._num_vn = pcm.shape[1] # Number of variable nodes
        self._num_edges = int(np.sum(pcm)) # Number of edges

        # Array of shape [num_edges, 2]
        # 1st col = CN id, 2nd col = VN id
        # The ith row of this array defines the ith edge.
        self._edges = np.stack(np.where(pcm), axis=1)

        # 2D ragged tensor of shape [num_cn, None]:
        # cn_edges[i] contains the edge ids attached to CN i.
        self._cn_edges = tf.ragged.constant(
            [np.where(self._edges[:, 0] == i)[0] for i in range(self._num_cn)])

        # 2D ragged tensor of shape [num_vn, None]:
        # vn_edges[j] contains the edge ids attached to VN j.
        self._vn_edges = tf.ragged.constant(
            [np.where(self._edges[:, 1] == j)[0] for j in range(self._num_vn)])

        self._num_embed_dims = num_embed_dims # Dimensions of vertex embeddings
        self._num_msg_dims = num_msg_dims # Dimensions of messages
        self._num_hidden_units = num_hidden_units # Hidden units of the MLPs
        self._num_mlp_layers = num_mlp_layers # Layers of the MLPs
        self._num_iter = num_iter # Number of decoding iterations, can be modified

        self._reduce_op = reduce_op # Reduce operation for message aggregation
        self._activation = activation # Activation of the hidden MLP layers

        self._output_all_iter = output_all_iter
        self._clip_llr_to = clip_llr_to

    @property
    def num_iter(self):
        """Number of decoding iterations.

        The iteration loop in ``call`` is a Python loop, so it is unrolled
        when traced. A graph already traced by ``tf.function`` keeps the old
        value until it is retraced.
        """
        return self._num_iter

    @num_iter.setter
    def num_iter(self, value):
        self._num_iter = value

    @property
    def output_all_iter(self):
        """Whether ``call`` returns the LLRs of all iterations (list) or only
        those of the last iteration (tensor). Subject to the same tracing
        caveat as ``num_iter``."""
        return self._output_all_iter

    @output_all_iter.setter
    def output_all_iter(self, value):
        self._output_all_iter = value

    def build(self, input_shape):
        # NN to transform input LLR to VN embedding
        self._llr_embed = Dense(self._num_embed_dims)

        # NN to transform VN embedding to output LLR
        self._llr_inv_embed = Dense(1)

        # CN embedding update function
        self.update_h_cn = UpdateEmbeddings(self._num_msg_dims,
                                            self._num_hidden_units,
                                            self._num_mlp_layers,
                                            np.flip(self._edges, 1), # Flip columns: "from VN to CN"
                                            self._cn_edges,
                                            self._reduce_op,
                                            self._activation)

        # VN embedding update function
        self.update_h_vn = UpdateEmbeddings(self._num_msg_dims,
                                            self._num_hidden_units,
                                            self._num_mlp_layers,
                                            self._edges, # "from CN to VN"
                                            self._vn_edges,
                                            self._reduce_op,
                                            self._activation)

    def llr_to_embed(self, llr):
        """Transform LLRs [batch, num_vn] to VN embeddings [batch, num_vn, num_embed_dims]."""
        return self._llr_embed(tf.expand_dims(llr, -1))

    def embed_to_llr(self, h_vn):
        """Transform VN embeddings [batch, num_vn, num_embed_dims] to LLRs [batch, num_vn]."""
        return tf.squeeze(self._llr_inv_embed(h_vn), axis=-1)

    def call(self, llr):
        batch_size = tf.shape(llr)[0]

        if self._clip_llr_to is not None:
            llr = tf.clip_by_value(llr, -self._clip_llr_to, self._clip_llr_to)

        # Initialize vertex embeddings: VNs from the channel LLRs, CNs with zeros.
        h_vn = self.llr_to_embed(llr)
        h_cn = tf.zeros([batch_size, self._num_cn, self._num_embed_dims],
                        dtype=h_vn.dtype)

        llr_per_iter = []
        for _ in range(self._num_iter):
            # Update CN embeddings from the VNs, then VNs from the new CNs.
            h_cn = self.update_h_cn(h_vn, h_cn)
            h_vn = self.update_h_vn(h_cn, h_vn)

            if self._output_all_iter:
                llr_per_iter.append(self.embed_to_llr(h_vn))

        if self._output_all_iter:
            return llr_per_iter
        return self.embed_to_llr(h_vn)


class UpdateEmbeddings(Layer):
    """Update vertex embeddings of the GNN BP decoder.

    This layer first computes the messages sent across all edges of the
    graph, then aggregates (sums or averages) the incoming messages at each
    receiving vertex, and finally updates the embeddings of the receiving
    vertices.

    Parameters
    ----------
    num_msg_dims: int
        Number of dimensions of a message.

    num_hidden_units: int
        Number of hidden units of MLPs used to compute
        messages and to update the vertex embeddings.

    num_mlp_layers: int
        Number of layers of the MLPs used to compute
        messages and to update the vertex embeddings. Must be at least 1.

    from_to_ind: [num_edges, 2], np.array
        Two dimensional array containing in each row the indices of the
        originating and receiving vertex for an edge.

    gather_ind: [`num_vn` or `num_cn`, None], tf.ragged.constant
        Ragged tensor that contains for each receiving vertex the list of
        edge indices from which to aggregate the incoming messages. As each
        vertex can have a different degree, a ragged tensor is used.

    reduce_op: str
        Vertex aggregation function, either "mean" or "sum".
        Defaults to "sum".

    activation: str
        Activation function of the hidden MLP layers. Defaults to "relu".

    Input
    -----
    h_from : [batch_size, num_cn or num_vn, num_embed_dims], tf.float32
        Tensor containing the embeddings of the "transmitting" vertices.

    h_to : [batch_size, num_vn or num_cn, num_embed_dims], tf.float32
        Tensor containing the embeddings of the "receiving" vertices.

    Output
    ------
    h_to_new : Same shape and type as `h_to`
        Tensor containing the updated embeddings of the "receiving" vertices.

    Raises
    ------
    ValueError
        If ``reduce_op`` is neither "sum" nor "mean", or if
        ``num_mlp_layers`` is smaller than 1.
    """
    def __init__(self,
                 num_msg_dims,
                 num_hidden_units,
                 num_mlp_layers,
                 from_to_ind,
                 gather_ind,
                 reduce_op="sum",
                 activation="relu",
                 ):
        super().__init__()
        # Validate early instead of failing on the first call/build.
        if reduce_op not in ("sum", "mean"):
            raise ValueError("unknown reduce operation")
        if num_mlp_layers < 1:
            raise ValueError("num_mlp_layers must be at least 1")

        self._num_msg_dims = num_msg_dims
        self._num_hidden_units = num_hidden_units
        self._num_mlp_layers = num_mlp_layers
        self._from_ind = from_to_ind[:,0]
        self._to_ind = from_to_ind[:,1]
        self._gather_ind = gather_ind
        self._reduce_op = reduce_op
        self._activation = activation

    def build(self, input_shape):
        # Keras hands build() the shape of the first call argument (h_from).
        # NOTE(review): the embedding MLP's output width is therefore taken
        # from h_from; this equals h_to's width only when both share the same
        # embedding size (as in GNN_BP, where CNs and VNs both use
        # num_embed_dims).
        num_embed_dims = input_shape[-1]

        num_hidden = self._num_mlp_layers - 1
        hidden_units = [self._num_hidden_units] * num_hidden
        hidden_acts = [self._activation] * num_hidden

        # Give each MLP its own freshly built lists. MLP stores references to
        # the lists it receives and only creates its Dense layers later, so
        # editing a shared list afterwards would silently change the message
        # MLP's output width too.
        self._msg_mlp = MLP(hidden_units + [self._num_msg_dims],
                            hidden_acts + [None],
                            [True] * self._num_mlp_layers)

        # MLP to update embeddings from accumulated messages
        self._embed_mlp = MLP(hidden_units + [num_embed_dims],
                              hidden_acts + [None],
                              [True] * self._num_mlp_layers)

    def call(self, h_from, h_to):
        # Concatenate embeddings of the transmitting (from) and receiving (to) vertex for each edge
        features = tf.concat([tf.gather(h_from, self._from_ind, axis=1),
                              tf.gather(h_to, self._to_ind, axis=1)],
                             axis=-1)

        # Compute messages for all edges: [batch, num_edges, num_msg_dims]
        messages = self._msg_mlp(features)

        # Reduce messages at each receiving (to) vertex
        # note: bring batch dim to last dim for improved performance
        # with ragged tensors
        messages = tf.transpose(messages, (1,2,0))
        m_ragged = tf.gather(messages, self._gather_ind, axis=0)
        if self._reduce_op=="sum":
            m = tf.reduce_sum(m_ragged, axis=1)
        elif self._reduce_op=="mean":
            m = tf.reduce_mean(m_ragged, axis=1)
        else:
            raise ValueError("unknown reduce operation")
        m = tf.transpose(m, (2,0,1)) # batch-dim back to first dim

        # Compute new embeddings from aggregated messages and old embeddings
        h_to_new = self._embed_mlp(tf.concat([m, h_to], axis=-1))

        return h_to_new



class E2EModel(tf.keras.Model):
    """End-to-end link: binary source -> linear encoder -> QAM mapper ->
    AWGN channel -> APP demapper -> (optional) decoder.

    Parameters
    ----------
    pcm : [num_cn, n], numpy.array
        Parity-check matrix of the code, also used to build the encoder.
        ``k`` is taken as ``n - num_cn``, i.e. the rows are assumed to be
        linearly independent.

    decoder : Keras layer or None
        Applied to the channel LLRs. If None, the channel LLRs are returned
        directly and Eb/N0 is converted to a noise variance with rate 1.

    num_bits_per_symbol : int
        Bits per QAM symbol. Defaults to 2 (QPSK). The codeword is
        zero-padded to a multiple of this value before mapping, and the
        filler LLRs are removed again after demapping.

    Input
    -----
    batch_size : int
        Number of codewords to simulate.

    ebno_db : float or tf.float32 tensor
        Eb/N0 in dB; a scalar or a tensor that broadcasts against the
        transmitted symbols (e.g. shape [batch_size, 1]).

    Output
    ------
    (c, llr) :
        ``c`` are the transmitted codewords, [batch_size, n].
        ``llr`` is whatever ``decoder`` returns, or the [batch_size, n]
        channel LLRs if ``decoder`` is None.
    """
    def __init__(self, pcm, decoder, num_bits_per_symbol=2):
        super().__init__()
        self._pcm = pcm
        self._n = pcm.shape[1]
        self._k = self._n - pcm.shape[0]
        self._encoder = LinearEncoder(pcm, is_pcm=True)

        self._binary_source = BinarySource()

        self._num_bits_per_symbol = num_bits_per_symbol
        # Number of zero filler bits so that n + pad is a multiple of the
        # number of bits per symbol (for QPSK: 1 if n is odd, else 0).
        self._num_pad = (-self._n) % num_bits_per_symbol
        self._mapper = Mapper("qam", self._num_bits_per_symbol)
        self._demapper = Demapper("app", "qam", self._num_bits_per_symbol)
        self._channel = AWGN()
        self._decoder = decoder

    @tf.function()
    def call(self, batch_size, ebno_db):

        # Noise variance; uncoded transmission (no decoder) uses rate 1.
        if self._decoder is not None:
            no = ebnodb2no(ebno_db, self._num_bits_per_symbol, self._k/self._n)
        else:
            no = ebnodb2no(ebno_db, self._num_bits_per_symbol, 1)

        # draw random info bits to transmit
        b = self._binary_source([batch_size, self._k])
        c = self._encoder(b)

        # zero padding so that the codeword length fits whole symbols
        if self._num_pad > 0:
            c_pad = tf.concat([c, tf.zeros([batch_size, self._num_pad], dtype=c.dtype)],
                              axis=1)
        else: # no padding
            c_pad = c

        # map to symbols
        x = self._mapper(c_pad)

        # transmit over AWGN channel
        y = self._channel([x, no])

        # demap to LLRs
        llr = self._demapper([y, no])

        # remove filler bits
        if self._num_pad > 0:
            llr = llr[:, :self._n]

        # and decode
        if self._decoder is not None:
            llr = self._decoder(llr)

        return c, llr


In [23]:
def train_model(model, params, ebno_range=None, eval_every=10,
                eval_batch_size=10000):
    """Train an end-to-end model with a multi-iteration BCE loss.

    For each training phase a fresh Adam optimizer is created and
    ``num_steps`` gradient steps are taken. The loss is the sum of the binary
    cross-entropies between the transmitted codewords and every LLR tensor
    returned by the model (one per decoder iteration when the decoder outputs
    all iterations). Every ``eval_every`` steps (including step 0) the model
    is evaluated on ``eval_batch_size`` fresh codewords, and the loss and the
    BER of the last LLR tensor are printed.

    Parameters
    ----------
    model : tf.keras.Model
        Called as ``model(batch_size, ebno_db)`` and returning
        ``(c, llr_hat)``, where ``llr_hat`` is one LLR tensor or a list of
        them (logits; a positive LLR is decided as bit 1).

    params : iterable of (batch_size, learning_rate, num_steps)
        Training phases, run in order.

    ebno_range : (float, float) or None
        Eb/N0 range in dB from which SNRs are drawn uniformly for training
        and evaluation. Defaults to the notebook globals
        ``(ebno_db_min, ebno_db_max)``, resolved when this function is called.

    eval_every : int or None
        Evaluation interval in steps. ``0`` or ``None`` disables evaluation.
        Defaults to 10.

    eval_batch_size : int
        Number of codewords per evaluation. Defaults to 10000.
    """
    if ebno_range is None:
        ebno_range = (ebno_db_min, ebno_db_max)
    ebno_lo, ebno_hi = ebno_range

    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def _as_list(llr_hat):
        # A decoder with output_all_iter=False returns a single tensor;
        # iterating over it would walk the batch rows, so wrap it instead.
        if isinstance(llr_hat, (list, tuple)):
            return list(llr_hat)
        return [llr_hat]

    def _multi_loss(c, llr_hat):
        # Sum of the BCE losses over all returned decoder iterations.
        loss_value = 0
        for llr in _as_list(llr_hat):
            loss_value += bce(c, llr)
        return loss_value

    for train_batch_size, lr, train_iter in params:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

        # Redefined per phase so it closes over this phase's optimizer and
        # batch size.
        @tf.function()
        def train_step():
            ebno_db = tf.random.uniform([train_batch_size, 1],
                                        minval=ebno_lo, maxval=ebno_hi)
            with tf.GradientTape() as tape:
                c, llr_hat = model(train_batch_size, ebno_db)
                loss_value = _multi_loss(c, llr_hat)

            weights = model.trainable_weights
            grads = tape.gradient(loss_value, weights)
            optimizer.apply_gradients(zip(grads, weights))
            return c, llr_hat

        for i in range(train_iter):
            train_step()
            if eval_every and i % eval_every == 0:
                ebno_db = tf.random.uniform([eval_batch_size, 1],
                                            minval=ebno_lo,
                                            maxval=ebno_hi)
                c, llr_hat = model(eval_batch_size, ebno_db)
                loss_value = _multi_loss(c, llr_hat)
                c_hat = tf.cast(tf.greater(_as_list(llr_hat)[-1], 0), tf.float32)
                ber = compute_ber(c, c_hat).numpy()
                print(f"Iteration {i}, loss = {loss_value.numpy():.3f}, "
                      f"ber = {ber:.5f}")

def load_weights(system, model_path):
    """Restore the weights of a Keras model from a pickle file.

    The file at ``model_path`` is unpickled and the resulting object is passed
    unchanged to ``system.set_weights`` (presumably a list of arrays as
    produced by ``get_weights`` -- confirm against the file's producer).

    Parameters
    ----------
    system: Keras model
        Model whose weights are overwritten.

    model_path: str
        Path of the pickled weight file.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load weight
    # files that come from a trusted source.
    with open(model_path, 'rb') as weight_file:
        stored_weights = pickle.load(weight_file)
    system.set_weights(stored_weights)

In [13]:
# --- Code: BCH example shipped with Sionna (pcm_id=1) ------------------------
pcm, k, n, coderate = load_parity_check_examples(pcm_id=1, verbose=True)

# --- Eb/N0 range in dB used for training and evaluation ---------------------
ebno_db_min, ebno_db_max = 2.0, 9.0
ebno_dbs = np.arange(ebno_db_min, ebno_db_max + 1)  # 1 dB grid, both ends included

# --- Monte-Carlo simulation settings -----------------------------------------
mc_iters = 100
mc_batch_size = 10000
num_target_block_errors = 2000

# Stand-alone encoder; E2EModel builds its own encoder internally.
encoder = LinearEncoder(pcm, is_pcm=True)

# --- GNN decoder wrapped in the end-to-end link ------------------------------
tf.random.set_seed(2)  # fixed TF global seed for reproducible runs
gnn_decoder = GNN_BP(
    pcm=pcm,
    num_embed_dims=20,
    num_msg_dims=20,
    num_hidden_units=40,
    num_mlp_layers=2,
    num_iter=8,
    reduce_op="mean",
    activation="tanh",
    output_all_iter=True,
    clip_llr_to=None,
)
e2e_gnn = E2EModel(pcm, gnn_decoder)



n: 63, k: 45, coderate: 0.714


In [None]:

# Training schedule: one row per phase.
# NOTE(review): ~635k steps with an evaluation on 10k codewords every 10 steps
# is very slow -- consider fewer steps or a larger evaluation interval.
train_params = [
    # batch_size, learning_rate, num_iter
    [512, 1e-3, 35000],
    [512, 1e-4, 300000],
    [512, 1e-5, 300000],
]

# Set to False to skip training and load precomputed weights instead.
train = True

# Multi-loss training needs the LLRs of every decoder iteration.
e2e_gnn._decoder._output_all_iter = True

if train:
    train_model(e2e_gnn, train_params)
else:
    # One forward pass creates the model's variables; set_weights on an
    # unbuilt subclassed model has no weights to assign to.
    e2e_gnn(1, ebno_db_min)
    # NOTE(review): the repository is cloned into "gnn-decoder" (hyphen), and
    # this file name suggests weights of a regular LDPC code rather than the
    # BCH code loaded above -- point it at weights trained for this pcm.
    load_weights(e2e_gnn, "gnn-decoder/weights/LDPC_reg_precomputed.npy")

Iteration 0, loss = 0.575, ber = 0.03809
Iteration 10, loss = 0.435, ber = 0.02387
Iteration 20, loss = 0.411, ber = 0.02023
Iteration 30, loss = 0.405, ber = 0.01934
Iteration 40, loss = 0.417, ber = 0.02004
Iteration 50, loss = 0.409, ber = 0.01960
Iteration 60, loss = 0.405, ber = 0.01929
Iteration 70, loss = 0.400, ber = 0.01902
Iteration 80, loss = 0.409, ber = 0.01963
Iteration 90, loss = 0.406, ber = 0.01949
Iteration 100, loss = 0.408, ber = 0.01960
Iteration 110, loss = 0.410, ber = 0.01943
Iteration 120, loss = 0.410, ber = 0.01967
