<a href="https://colab.research.google.com/github/pollyjuice74/REU-LDPC-Project/blob/main/ECCT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/pollyjuice74/ECCT
!git clone https://github.com/pollyjuice74/DDECC
!git clone https://github.com/pollyjuice74/REU-LDPC-Project
!git clone https://github.com/pollyjuice74/gnn-decoder

!pip install sionna
#!pip install torch torch-geometric

In [2]:
import os
import copy
import tensorflow as tf
import torch
import torch.nn as nn
import numpy as np
# Rename folders
if os.path.exists('REU-LDPC-Project'):
  os.rename('REU-LDPC-Project', 'REU_LDPC_Project')
if os.path.exists('gnn-decoder'):
  os.rename('gnn-decoder', 'gnn_decoder')

In [3]:
from sionna.fec.ldpc.encoding import LDPC5GEncoder
from sionna.utils import BitErrorRate, BinarySource, ebnodb2no
from sionna.mapping import Mapper, Demapper
from sionna.channel import AWGN

from REU_LDPC_Project.channel import E2EModelDDECC
from REU_LDPC_Project.decoder import LDPC5GDecoder

from ECCT.Model import MultiHeadedAttention, PositionwiseFeedForward, Encoder, EncoderLayer, ECC_Transformer
from ECCT.args import pass_args
from ECCT.Codes import FER, BER # x_pred, x_gt

# from DDECC.DDECC import DDECCT
from DDECC.args import pass_args_ddecc

from gnn_decoder.gnn import LDPC5GGNN


n, k = (100, 90)
args_ecct = pass_args()
args_ddecc = pass_args_ddecc()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

enc = LDPC5GEncoder(k,n)
dec = LDPC5GDecoder(enc)


class DDECCT(nn.Module):
    def __init__(self, args, encoder,
                 device, dropout=0):

        super(DDECCT, self).__init__()

        self.encoder = encoder
        encoder._m = encoder._n - encoder._k # m = n-k

        self.n_steps = encoder._m + 5 # m + 5
        self.d_model = args.d_model
        self.sigma = args.sigma

        pcm_coo = encoder.pcm.tocoo() # Convert the SciPy sparse matrix to COO format

        # Create a PyTorch sparse tensor from the COO format
        indices = torch.tensor([pcm_coo.row, pcm_coo.col], dtype=torch.int64)
        values = torch.tensor(pcm_coo.data, dtype=torch.float32)
        shape = torch.Size(pcm_coo.shape)

        # Register the sparse tensor as a buffer
        self.register_buffer('pc_matrix', torch.sparse_coo_tensor(indices, values, shape)) # should be float
        self.device = device

        betas = torch.linspace(1e-3, 1e-2, self.n_steps)
        betas = betas*0+self.sigma
        self.betas = betas.view(-1,1)
        self.betas_bar =  torch.cumsum(self.betas, 0).view(-1,1)

        self.ema = EMA(0.9,flag_run=True)

        self.line_search = False

        # code = args.code
        c = copy.deepcopy
        attn = MultiHeadedAttention(args.h, args.d_model)
        ff = PositionwiseFeedForward(args.d_model, args.d_model*4, dropout)

        self.src_embed = torch.nn.Parameter(torch.empty(
            (encoder._n + encoder._m, args.d_model))) #code.n + code.pc_matrix.size(0), args.d_model)))

        self.decoder = Encoder(EncoderLayer(
            args.d_model, c(attn), c(ff), dropout), args.N_dec)

        self.oned_final_embed = torch.nn.Sequential(
            *[nn.Linear(args.d_model, 1)])

        self.out_fc = nn.Linear(encoder._n + encoder._m, encoder._n) #code.n + code.pc_matrix.size(0), code.n)
        self.time_embed = nn.Embedding(self.n_steps, args.d_model)

        self.get_mask()

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)


    def p_sample(self, yt):
        # Single sampling from the real p dist.
        sum_syndrome =  (torch.matmul(sign_to_bin(torch.sign(yt.to(self.device))),self.pc_matrix) % 2).round().long().sum(-1)
        # assert sum_syndrome.max() <= self.pc_matrix.shape[1] and sum_syndrome.min() >= 0
        t = sum_syndrome.cpu()
        # Model output
        noise_mul_pred = self(yt.to(self.device), sum_syndrome.to(self.device)).cpu()# predicted multiplicative noise
        noise_add_pred = yt-torch.sign(-noise_mul_pred * torch.sign(yt)) #predicted additive noise
        factor = (torch.sqrt(self.betas_bar[t])*self.betas[t]/(self.betas_bar[t]+self.betas[t])) #theoretical step size
        alpha_final = 1
        if self.line_search:
            #Perform Step Sizer Line-search # TODO : perform it on GPU for speed
            alpha = torch.linspace(1,20,20).unsqueeze(0).unsqueeze(0)
            new_synd = (torch.matmul(sign_to_bin(torch.sign(yt.unsqueeze(-1) - alpha*(noise_add_pred*factor).unsqueeze(-1))).permute(0,2,1),self.pc_matrix.cpu()) % 2).round().long().sum(-1)
            alpha_final = alpha.squeeze(0)[:,new_synd.argmin(-1).unsqueeze(-1)].squeeze(0)
        yt_1 = yt - alpha_final*noise_add_pred*factor
        yt_1[t==0] = yt[t==0] # if some codeword has 0 synd. keep it as is
        return (yt_1), t


    def p_sample_loop(self, cur_y):
        # Iterative sampling from the real p dist.
        res = []
        synd_all = []
        for it in range(self.pc_matrix.shape[1]):
            cur_y,curr_synd = self.p_sample(cur_y)
            synd_all.append(curr_synd)
            res.append(cur_y)
        synd_all = torch.stack(synd_all).t().cpu()
        # Chose the biggest iteration that reaches 0 synd.
        aa = (synd_all == 0).int()*2-1
        idx = torch.arange(aa.shape[1], 0, -1)
        idx_conv = torch.argmax(aa * idx, 1, keepdim=True)
        return cur_y, res, idx_conv.view(-1), synd_all


  #################################
    def loss(self,x_0):
        t = torch.randint(0, self.n_steps, size=(x_0.shape[0] // 2 + 1,))
        t = torch.cat([t, self.n_steps - t - 1], dim=0)[:x_0.shape[0]].long()
        e = torch.randn_like(x_0)
        noise_factor = torch.sqrt(self.betas_bar[t]).to(x_0.device)

        h = torch.from_numpy(np.random.rayleigh(x_0.size(0),x_0.size(1))).float()
        h = 1.
        yt = h*x_0 * 1 + e * noise_factor
        sum_syndrome =  (torch.matmul(sign_to_bin(torch.sign(yt.to(self.device))),
        self.pc_matrix) % 2).sum(-1).long()

        output = self(yt.to(self.device), sum_syndrome.to(self.device))
        z_mul = (yt *x_0)
        return F.binary_cross_entropy_with_logits(output, sign_to_bin(torch.sign(z_mul.to(self.device))))
  #################################


    def get_mask(self, no_mask=False):
        if no_mask:
            self.src_mask = None
            return

        src_mask = self.build_mask()
        print(src_mask)
        self.register_buffer('src_mask', src_mask)


    def build_mask(self):
            mask_size = 2*self.encoder._n_ldpc - self.encoder._k_ldpc
            mask = torch.eye(mask_size, mask_size)

            for ii in range(self.encoder._n_ldpc - self.encoder._k_ldpc): # m
                idx = self.encoder.pcm[ii].indices #idx = torch.where(self.encoder.pcm[ii].indices > 0)[0]
                for jj in idx:
                    for kk in idx:
                        if jj != kk:
                            mask[jj, kk] += 1
                            mask[kk, jj] += 1
                            mask[self.encoder._n + ii, jj] += 1
                            mask[jj, self.encoder._n + ii] += 1

            src_mask = ~ (mask > 0).unsqueeze(0).unsqueeze(0)
            return src_mask


    def forward(self, y, time_step):
        syndrome = (self.pc_matrix @ sign_to_bin(torch.sign(y)).T.float()) % 2
        syndrome = bin_to_sign(syndrome)
        magnitude = torch.abs(y) # m = H @ y.T
        print(magnitude, syndrome)

        emb = torch.cat([magnitude, syndrome], -1).unsqueeze(-1)
        emb = self.src_embed.unsqueeze(0) * emb

        # Diffusion time steps
        time_emb = self.time_embed(time_step).view(-1, 1, self.d_model) # time_step is the ix
        # d_model shaped nodes 'overseeing' the attn in the network
        # could add a (1, time_embed.size) 'overseeing' attn vector

        emb = time_emb * emb
        emb = self.decoder(emb, self.src_mask,time_emb) # attention

        # removes (d_model, n + m) shaped dims
        return self.out_fc(self.oned_final_embed(emb).squeeze(-1))



class EMA(object):
    def __init__(self, mu=0.999,flag_run = True):
        self.mu = mu
        self.shadow = {}
        self.flag_run = flag_run


    def register(self, module):
        if self.flag_run:
            for name, param in module.named_parameters():
                if param.requires_grad:
                    self.shadow[name] = param.data.clone()


    def update(self, module):
        if self.flag_run:
            for name, param in module.named_parameters():
                if param.requires_grad:
                    self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data



# Models
# gnn = LDPC5GGNN(enc)
# ecct = ECC_Transformer(args_ecct, enc, dec)
ddecct = DDECCT(args_ddecc, enc, device)
# e2e Channel
channel = E2EModelDDECC(ddecct, dec)

# Channel Info
ebno_dbs = np.arange(0, 4.5, 0.5) # ebno_db_min, ebno_db_max, ebno_db_stepsize
channel.call(ebno_dbs)


Path to model/logs: Results_ECCT/POLAR__Code_n_64_k_32__20_06_2024_02_40_12
Namespace(epochs=1000, workers=4, lr=0.0001, gpus='-1', batch_size=128, test_batch_size=2048, seed=42, code_type='POLAR', code_k=32, code_n=64, standardize=False, N_dec=6, d_model=32, h=8, code=<ECCT.args.pass_args.<locals>.Code object at 0x7c2715421450>, path='Results_ECCT/POLAR__Code_n_64_k_32__20_06_2024_02_40_12')
Path to model/logs: DDECCT_Results/POLAR__Code_n_64_k_32__20_06_2024_02_40_12


  indices = torch.tensor([pcm_coo.row, pcm_coo.col], dtype=torch.int64)


tensor([[[[False,  True,  True,  ...,  True,  True,  True],
          [ True, False,  True,  ...,  True,  True,  True],
          [ True,  True, False,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ..., False,  True,  True],
          [ True,  True,  True,  ...,  True, False,  True],
          [ True,  True,  True,  ...,  True,  True, False]]]])
no, ebno_db:  (9, 1) (9,)
bit:  (1, 90)
n, c:  100 (1, 100)
c_pad, c:  (1, 100) (1, 100)
y, no:  (9, 25) (9, 1)
llr:  torch.Size([9, 100])

 5G Decoding llr_ch (torch.Size([9, 100])) 
batch_size:  ()
llr_ch_reshaped:  (9, 100)
llr_5g:  (9, 130)
k_filler:  60
nb_punc_bits:  590
llr_5g:  (9, 150)
x1:  (9, 90)
nb_par_bits:  60
x2:  (9, 60)
z:  (9, 60)
llr_5g:  (9, 210)  n_ldpc:  780

BP llr ((9, 210)) decoding
x_out:  (9, 210)
x_hat:  (9, 210)
u_hat:  (9, 90)
u_reshaped:  (9, 90)
u_out:  (9, 90)


TypeError: Cannot convert the argument `type_value`: torch.float32 to a TensorFlow DType.

In [None]:
enc._n_ldpc, enc._k_ldpc, enc.pcm.shape

In [None]:
for i in enc.pcm:
  print(i)
  break

In [None]:


class E2EModelDDECC(tf.keras.Model):
    def __init__(self, model, decoder,
                       batch_size=1,
                       return_infobits=False,
                       es_no=False,
                       decoder_active=False):
        super().__init__()

        self._n = decoder.encoder._n
        self._k = decoder.encoder._k

        self._binary_source = BinarySource()
        self._num_bits_per_symbol = 4 # QAM16

        # Channel
        ############################
        # Encoding
        self._encoder = model.encoder
        self._mapper = Mapper("qam", self._num_bits_per_symbol) #

        # Channel
        self._channel = AWGN() #
        # adversarial channel noise emulator

        # Decoding
        self._demapper = Demapper("app", "qam", self._num_bits_per_symbol) #
        # Decoders
        self._decoder = model # DDECCT
        self._decoder5g = decoder # LDPC5GDecoder
        ############################

        self._return_infobits = return_infobits
        self._es_no = es_no

        self._batch_size = batch_size

    def train(self):
      pass

    def test(self):
      pass

    # @tf.function(jit_compile=True)
    def call(self, ebno_db):
        # Noise Variance
        if self._decoder is not None and self._es_no==False: # no rate-adjustment for uncoded transmission or es_no scenario
            no = ebnodb2no(ebno_db, self._num_bits_per_symbol, self._k/self._n) ### LOOK UP EBNODB2NO
        else: #for uncoded transmissions the rate is 1
            no = ebnodb2no(ebno_db, self._num_bits_per_symbol, 1) ###
        no = tf.expand_dims(tf.cast(no, tf.float32), axis=-1) # turn to float32, turns shape (9,) -> (9,1)
        print("no, ebno_db: ", no.shape, ebno_db.shape)

        b = self._binary_source([self._batch_size, self._encoder._k]) # (batch_size, k), k information bits
        print("bit: ", b.shape) # print(b.shape[-1]==self._k, b.shape, self._k, self._n - self._k)

        # Turns INFO BITS (batch_size, k) -> (batch_size, n) info and parity bits CODEWORD of rate = k/n
        if self._encoder is not None:
            c = self._encoder(b) ##### c = G @ b.T, (n,k) @ (k,1)
        else:
            c = b

        print("n, c: ", self._n, c.shape)
        # check that rate calculations are correct
        assert self._n == c.shape[-1], "Invalid value of n."

        # zero padding to support odd codeword lengths
        if self._n%2 == 1:
            c_pad = tf.concat([c, tf.zeros([self._batch_size, 1])], axis=1)
        else: # no padding
            c_pad = c
        print("c_pad, c: ", c_pad.shape, c.shape)

        # Channel
        ############################
        x = self._mapper(c_pad)
        y = self._channel([x, no]) ###
        llr = self._demapper([y, no])
        ############################
        print("y, no: ", y.shape, no.shape)

        # remove zero padded bit at the end
        if self._n%2 == 1:
            llr = llr[:,:-1]
        llr = torch.tensor(llr.numpy())
        print("llr: ", llr.shape,)# b, c, x, y)

        # run the decoder
        if self._decoder is not None:
            llr = self._decoder5g(llr)
            llr_ddecc = self._decoder(llr, time_step=0) # 9 no values, 100 bits of data, time step 0
            print("%%%%%%%%%%%%%%%%%%%")

        if self._return_infobits:
            return b, llr
        else:
            return c, llr
