<a href="https://colab.research.google.com/github/pollyjuice74/Error-Bit-Decoding/blob/main/BCHDecoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# general imports
import tensorflow as tf
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

# load required Sionna components
!pip install sionna
import sionna as sn
from sionna.fec.utils import load_parity_check_examples, LinearEncoder, GaussianPriorSource
from sionna.utils import BinarySource, ebnodb2no, BitwiseMutualInformation, hard_decisions
from sionna.utils.metrics import compute_ber
from sionna.utils.plotting import PlotBER
from sionna.mapping import Mapper, Demapper
from sionna.channel import AWGN
from sionna.fec.ldpc import LDPCBPDecoder

from tensorflow.keras.layers import Dense, Layer

%load_ext autoreload
%autoreload 2
import sys
from importlib import import_module
import pickle

# Mount Google Drive at /content/drive; the project helper modules are imported
# below from drive/MyDrive/gnn_decoder.
from google.colab import drive
drive.mount('/content/drive')


# Debug toggle: uncomment to execute tf.function-compiled code eagerly.
# NOTE(review): this API is deprecated; tf.config.run_functions_eagerly(True) is the replacement.
#tf.config.experimental_run_functions_eagerly(True)
# NOTE(review): these imports resolve `drive` as a top-level (namespace) package found
# via sys.path -- presumably the Colab working directory /content, TODO confirm -- and
# NOT through the `drive` name bound by the google.colab import above.
# The star imports hide where later names (e.g. E2EModel, evaluate_wbp) come from;
# TODO replace with explicit imports.
from drive.MyDrive.gnn_decoder.gnn import *
from drive.MyDrive.gnn_decoder.wbp import * # load weighted BP functions

In [None]:
#----- BCH -----
# Central configuration for the whole notebook: code selection, GNN architecture,
# training schedule, logging, Monte-Carlo evaluation and the weighted-BP (WBP) baseline.
params = {
    # --- Code Parameters ---
    "code": "BCH",  # (63,45)
    # --- GNN Architecture ----
    "num_embed_dims": 20,
    "num_msg_dims": 20,
    "num_hidden_units": 40,
    "num_mlp_layers": 2,
    "num_iter": 8,
    "reduce_op": "mean",
    "activation": "tanh",
    "clip_llr_to": None,
    "use_attributes": False,
    "node_attribute_dims": 0,
    "msg_attribute_dims": 0,
    "use_bias": False,
    # --- Training ---- #
    # batch_size / train_iter / learning_rate form a staged schedule:
    # one entry per training stage, so all three must have the same length (checked below).
    "batch_size": [256, 256, 256],
    "train_iter": [35000, 300000, 300000],
    "learning_rate": [1e-3, 1e-4, 1e-5],
    "ebno_db_train": [3., 8.],  # presumably the [min, max] training Eb/N0 in dB -- TODO confirm
    "ebno_db_eval": 4.,
    "batch_size_eval": 10000,  # batch size only used for evaluation during training
    "eval_train_steps": 1000,  # evaluate model every N iters
    # --- Log ----
    "save_weights_iter": 10000,  # save weights every X iters
    "run_name": "BCH_01",  # name of the stored weights/logs
    "save_dir": "results/",  # folder to store results
    # --- MC Simulation parameters ----
    "mc_iters": 1000,
    "mc_batch_size": 2000,
    "num_target_block_errors": 500,
    "ebno_db_min": 2.,
    "ebno_db_max": 9.,
    "ebno_db_stepsize": 1.,
    "eval_iters": [2, 3, 4, 6, 8, 10],
    # --- Weighted BP parameters ----
    # wbp_batch_size / wbp_train_iter / wbp_learning_rate are likewise per-stage lists
    # (wbp_ebno_train presumably too -- TODO confirm against evaluate_wbp).
    "simulate_wbp": True,  # simulate weighted BP as baseline
    "wbp_batch_size": [2000, 2000, 2000],
    "wbp_train_iter": [300, 10000, 2000],
    "wbp_learning_rate": [1e-2, 1e-3, 1e-3],
    "wbp_ebno_train": [5., 5., 6.],
    "wbp_ebno_val": 7.,  # validation SNR during training
    "wbp_batch_size_val": 2000,
    "wbp_clip_value_grad": 10,
}

# Fail fast on inconsistent staged schedules instead of hitting an IndexError
# deep inside a long training run.
assert (len(params["batch_size"]) == len(params["train_iter"])
        == len(params["learning_rate"])), \
    "batch_size, train_iter and learning_rate must have the same length"
assert (len(params["wbp_batch_size"]) == len(params["wbp_train_iter"])
        == len(params["wbp_learning_rate"])), \
    "wbp_batch_size, wbp_train_iter and wbp_learning_rate must have the same length"

In [None]:
# all codes must provide an encoder-layer and a pcm (parity-check matrix)
if params["code"] == "BCH":
    print("Loading BCH code")
    # pcm_id=1 selects Sionna's bundled (63,45) BCH example (confirmed by the printout below).
    pcm, k, n, coderate = load_parity_check_examples(pcm_id=1, verbose=True)

    encoder = LinearEncoder(pcm, is_pcm=True)
    params["k"] = k
    params["n"] = n
else:
    raise ValueError(f"Unknown code type: {params['code']!r}")

ber_plot = PlotBER(f"GNN-based Decoding - {params['code']}, (k,n)=({k},{n})")

# Eb/N0 grid from ebno_db_min to ebno_db_max, both inclusive. np.arange excludes its
# stop value, so extend it by half a step: this keeps the endpoint for ANY step size
# (a fixed `+1` only works for 1 dB steps) without risking an extra point from
# floating-point rounding.
ebno_dbs = np.arange(params["ebno_db_min"],
                     params["ebno_db_max"] + 0.5 * params["ebno_db_stepsize"],
                     params["ebno_db_stepsize"])

# "Conventional" (unweighted) BP on the same pcm, used as a reference curve.
bp_num_iter = 20
bp_decoder = LDPCBPDecoder(pcm, num_iter=bp_num_iter, hard_out=False)
e2e_bp = E2EModel(encoder, bp_decoder, k, n)

# Opt-in: set params["simulate_bp"] = True to add the BP curve to ber_plot.
# (The BER table recorded under this cell was produced with the simulation enabled.)
if params.get("simulate_bp", False):
    ber_plot.simulate(e2e_bp,
                      ebno_dbs=ebno_dbs,
                      batch_size=params["mc_batch_size"],
                      num_target_block_errors=params["num_target_block_errors"],
                      legend=f"BP {bp_num_iter} iter.",
                      soft_estimates=True,
                      max_mc_iter=params["mc_iters"],
                      forward_keyboard_interrupt=False,
                      show_fig=False)

Loading BCH code

n: 63, k: 45, coderate: 0.714
EbNo [dB] |        BER |       BLER |  bit errors |    num bits | block errors |  num blocks | runtime [s] |    status
---------------------------------------------------------------------------------------------------------------------------------------
      2.0 | 5.9794e-02 | 8.1200e-01 |        7534 |      126000 |         1624 |        2000 |         6.8 |reached target block errors
      3.0 | 3.1714e-02 | 4.8550e-01 |        3996 |      126000 |          971 |        2000 |         0.3 |reached target block errors
      4.0 | 1.3381e-02 | 2.1050e-01 |        3372 |      252000 |          842 |        4000 |         0.5 |reached target block errors
      5.0 | 3.9508e-03 | 5.8500e-02 |        2489 |      630000 |          585 |       10000 |         1.3 |reached target block errors
      6.0 | 9.2479e-04 | 1.2071e-02 |        2447 |     2646000 |          507 |       42000 |         5.3 |reached target block errors

Simulation stopp

In [None]:
# Train and simulate weighted BP (WBP) as an additional baseline on the same pcm,
# encoder and Eb/N0 grid; results are passed to the shared ber_plot.
# NOTE: WBP training may need tuning -- see the `wbp_*` entries in `params`
# (batch size, iterations, learning rate, training/validation SNR, gradient clipping).
if params["simulate_wbp"]:
    evaluate_wbp(params, pcm, encoder, ebno_dbs, ber_plot)

Note that WBP requires Sionna > v0.11.
Iter: 0 loss: 0.002754 ber: 0.000095 bmi: 0.999
Iter: 50 loss: 0.001551 ber: 0.000119 bmi: 0.999
Iter: 100 loss: 0.002035 ber: 0.000087 bmi: 0.999
Iter: 150 loss: 0.002294 ber: 0.000103 bmi: 0.999
Iter: 200 loss: 0.001316 ber: 0.000024 bmi: 1.000
Iter: 250 loss: 0.001561 ber: 0.000103 bmi: 0.999
Iter: 0 loss: 0.000573 ber: 0.000040 bmi: 1.000
Iter: 50 loss: 0.000833 ber: 0.000032 bmi: 1.000
Iter: 100 loss: 0.001029 ber: 0.000071 bmi: 0.999
Iter: 150 loss: 0.001136 ber: 0.000119 bmi: 0.999
Iter: 200 loss: 0.000692 ber: 0.000048 bmi: 1.000
Iter: 250 loss: 0.000701 ber: 0.000040 bmi: 1.000
Iter: 300 loss: 0.000414 ber: 0.000000 bmi: 1.000
Iter: 350 loss: 0.000756 ber: 0.000071 bmi: 0.999
Iter: 400 loss: 0.000902 ber: 0.000071 bmi: 1.000
Iter: 450 loss: 0.000682 ber: 0.000087 bmi: 0.999
Iter: 500 loss: 0.001183 ber: 0.000040 bmi: 1.000
Iter: 550 loss: 0.001236 ber: 0.000095 bmi: 0.999
Iter: 600 loss: 0.000524 ber: 0.000024 bmi: 1.000
Iter: 650 loss: 0

KeyboardInterrupt: 