<a href="https://colab.research.google.com/github/pollyjuice74/ECCT/blob/main/ECCT_on_5G.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ECCT
!git clone https://github.com/pollyjuice74/ECCT
import os
os.chdir('ECCT')

from args import pass_args_ecct
from Model import *
from Codes import *

# Enc/Dec 5G
!pip install sionna
from sionna.fec.ldpc.encoding import LDPC5GEncoder
from sionna.utils import BitErrorRate, BinarySource, ebnodb2no
from sionna.mapping import Mapper, Demapper
from sionna.channel import AWGN

!wget https://raw.githubusercontent.com/pollyjuice74/REU-LDPC-Project/main/5g_enc_dec/decoder.py
from decoder import LDPC5GDecoder

# Other
import tensorflow as tf
from torch.nn import functional as F
import torch.nn as nn
import torch
import copy
import time


Collecting jedi>=0.16 (from ipython>=6.1.0->ipywidgets>=8.0.4->sionna)
  Downloading jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.6/1.6 MB[0m [31m97.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m41.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: widgetsnbextension, jupyterlab-widgets, jedi, drjit, mitsuba, ipywidgets, ipydatawidgets, pythreejs, sionna
  Attempting uninstall: widgetsnbextension
    Found existing installation: widgetsnbextension 3.6.6
    Uninstalling widgetsnbextension-3.6.6:
      Successfully uninstalled widgetsnbextension-3.6.6
  Attempting uninstall: jupyterlab-widgets
    Found existing installation: jupyterlab_widgets 3.0.11
    Uninstalling jupyterlab_widgets-3.0.11:
      Successfully uninstall

In [8]:
from sionna.utils import BitErrorRate, BinarySource, ebnodb2no

# Code data
k, n = (90, 100)
bps = 4 # bits per symbol
args = pass_args_ecct()
batch_size = 1

# 5G compliant encoder
enc = LDPC5GEncoder(k,n)
# Decoder models
dec = LDPC5GDecoder(enc)
ecct = ECC_Transformer(args, enc)

# Message generation
binary_source = BinarySource()
# Channel objects
mapper = Mapper("qam", bps)
channel = AWGN()
demapper = Demapper("app", "qam", bps)

ebno_range = tf.convert_to_tensor(range(4, 7), dtype=tf.float32)

def llr_to_bin(llr_tf):
    return tf.cast(llr_tf > 0, tf.float32)


def nldpc_to_n(x_hat):
    x_hat = tf.convert_to_tensor(x_hat.detach().numpy())

    # llr to bin
    x_hat = tf.cast(tf.less(0.0, x_hat), dec._output_dtype)

    x = tf.reshape(x_hat, [batch_size, dec._n_pruned])
    # print("x: ", x.shape)

    # remove filler bits at pos (k, k_ldpc)
    x_no_filler1 = tf.slice(x, [0, 0], [batch_size, dec.encoder.k])

    x_no_filler2 = tf.slice(x,
                            [0, dec.encoder.k_ldpc],
                            [batch_size,
                            dec._n_pruned-dec.encoder.k_ldpc])
    # print("x_no_filler1: ", x_no_filler1.shape)
    # print("x_no_filler2: ", x_no_filler2.shape)

    x_no_filler = tf.concat([x_no_filler1, x_no_filler2], 1)
    # print("x_no_filler: ", x_no_filler.shape)

    # shorten the first 2*Z positions and end after n bits
    x_short = tf.slice(x_no_filler,
                        [0, 2*dec.encoder.z],
                        [batch_size, dec.encoder.n])
    # print("x_short: ", x_short.shape)

    # if used, apply rate-matching output interleaver again as
    # Sec. 5.4.2.2 in 38.212
    if dec._encoder.num_bits_per_symbol is not None:
        x_short = tf.gather(x_short, dec._encoder.out_int, axis=-1)
    # print("x_short: ", x_short.shape)

    # Reshape x_short so that it matches the original input dimensions
    # overwrite first dimension as this could be None (Keras)
    llr_ch_shape = [1, enc._n]
    llr_ch_shape[0] = -1
    x_short= tf.reshape(x_short, llr_ch_shape)
    # print("x_short: ", x_short.shape)

    # enable other output datatypes than tf.float32
    x_out = tf.cast(x_short, dec._output_dtype)
    # print("x_out: ", x_out.shape)

    return torch.tensor(x_out.numpy(), dtype=torch.float32)


def train(model, iters=100):
    model.train()
    cum_loss = cum_ber = cum_fer = cum_samples = 0
    t = time.time()

    for eb_no in range(4,7):

        for i in range(iters):
            no = ebnodb2no(eb_no, bps, k/n)
            no = tf.expand_dims(tf.cast(no, tf.float32), axis=-1)

            b = binary_source([batch_size, enc._k]) # (k,1) info bits
            c = enc(b) # (n,1) codeword

            # NOISELESS Channel to get (n_ldpc,1) original llrs
            x = mapper(c)
            # print(x.shape, no.shape)
            llr = demapper([x, no]) # no noise # (n,1)

            llr_nldpc_noiseless, _, _, llr_noiseless = dec(llr) # decoder turns (n,1) to (n_ldpc,1)

            # AWGN Channel
            x = mapper(c)
            y = channel([x, no])
            llr_r = demapper([y, no])

            llr_nldpc, _, _, _ = dec(llr_r) # decoder turns (n,1) to (n_ldpc,1)

            # print(enc.pcm.shape, llr_nldpc.shape)
            syndrome = ( enc.pcm @ llr_to_bin(tf.transpose(llr_nldpc_noiseless)) ) % 2
            magnitude = tf.abs(llr_nldpc)

            # Convert to pytorch tensors
            syndrome = ( torch.tensor(syndrome, dtype=torch.float32) ).T
            magnitude = torch.tensor(magnitude.numpy(), dtype=torch.float32)
            print(f"syndrome: {syndrome}, magnitude: {magnitude.shape}")

            llr_noiseless_hat = model(magnitude, syndrome) # (n_ldpc,1)

            # Convert to pytorch
            llr_noiseless = torch.tensor(llr_noiseless.numpy(), dtype=torch.float32)

            # Compute loss
            loss = F.binary_cross_entropy_with_logits(llr_noiseless_hat, llr_noiseless)

            # turn llrs (n_ldpc,1) into cw_hat (n,1)
            c_hat = nldpc_to_n(llr_noiseless_hat)
            c = torch.tensor(c.numpy(), dtype=torch.float32)

            print(f"c: {c}, c_hat: {c_hat}")

            cum_loss += loss.item() * llr_noiseless.shape[-1]
            cum_ber += BER(c_hat, c) * llr_noiseless.shape[-1]
            cum_fer += FER(c_hat, c) * llr_noiseless.shape[-1]
            cum_samples += llr_noiseless.shape[-1]

            if i%10 == 0:
                print(f'Batch {i + 1}/{iters}: Loss={cum_loss / cum_samples:.2e} BER={cum_ber / cum_samples:.2e} FER={cum_fer / cum_samples:.2e}')

    print(f'Train time: {time.time() - t:.2f}s\n')
    return cum_loss / cum_samples, cum_ber / cum_samples, cum_fer / cum_samples


def test(model):
  pass



train(ecct)

Path to model/logs: Results_ECCT/POLAR__Code_n_64_k_32__28_06_2024_18_00_26
Namespace(epochs=1000, workers=4, lr=0.0001, gpus='-1', batch_size=128, test_batch_size=2048, seed=42, code_type='POLAR', code_k=32, code_n=64, standardize=False, N_dec=6, d_model=32, h=8, code=<args.pass_args_ecct.<locals>.Code object at 0x7f1e895c0f40>, path='Results_ECCT/POLAR__Code_n_64_k_32__28_06_2024_18_00_26')
Self-Attention Sparsity Ratio=99.00%, Self-Attention Complexity Ratio=0.50%
Mask:
 tensor([[[[False,  True,  True,  ...,  True,  True,  True],
          [ True, False,  True,  ...,  True,  True,  True],
          [ True,  True, False,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ..., False,  True,  True],
          [ True,  True,  True,  ...,  True, False,  True],
          [ True,  True,  True,  ...,  True,  True, False]]]])

5G Decoding llr_ch ((1, 100)) 

BP llr ((1, 780)) decoding

5G Decoding llr_ch ((1, 100)) 

BP llr ((1, 780)) decoding
syndrome: tensor([[0., 

KeyboardInterrupt: 