<a href="https://colab.research.google.com/github/pollyjuice74/REU-LDPC-Project/blob/main/ECCT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/pollyjuice74/ECCT
!git clone https://github.com/pollyjuice74/DDECC
!git clone https://github.com/pollyjuice74/REU-LDPC-Project
!git clone https://github.com/pollyjuice74/gnn-decoder

!pip install sionna
#!pip install torch torch-geometric

In [2]:
import os
import copy
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# Rename folders
if os.path.exists('REU-LDPC-Project'):
  os.rename('REU-LDPC-Project', 'REU_LDPC_Project')
if os.path.exists('gnn-decoder'):
  os.rename('gnn-decoder', 'gnn_decoder')

In [4]:
class E2EModelDDECC(tf.keras.Model):
    def __init__(self, model, decoder,
                       batch_size=1,
                       return_infobits=False,
                       es_no=False,
                       decoder_active=False):
        super().__init__()

        self._n = decoder.encoder._n
        self._k = decoder.encoder._k

        self._binary_source = BinarySource()
        self._num_bits_per_symbol = 4 # QAM16


        # Channel
        ############################
        # Encoding
        self._encoder = model.encoder
        self._mapper = Mapper("qam", self._num_bits_per_symbol) #

        # Channel
        self._channel = AWGN() #
        # Add adversarial channel noise emulator

        # Decoding
        self._demapper = Demapper("app", "qam", self._num_bits_per_symbol) #
        # Decoders
        self._decoder = model # DDECCT
        self._decoder5g = decoder # LDPC5GDecoder
        ############################

        self._return_infobits = return_infobits
        self._es_no = es_no

        self._batch_size = batch_size

        # Channel info
        self.ebno_db = np.arange(0, 0.5, 0.5) #4.5 # ebno_db_min, ebno_db_max, ebno_db_stepsize

    def train(self):
      pass

    def test(self):
      pass

    # @tf.function(jit_compile=True)
    def call(self):
        # Noise Variance
        if self._decoder is not None and self._es_no==False: # no rate-adjustment for uncoded transmission or es_no scenario
            no = ebnodb2no(self.ebno_db, self._num_bits_per_symbol, self._k/self._n) ### LOOK UP EBNODB2NO
        else: #for uncoded transmissions the rate is 1
            no = ebnodb2no(self.ebno_db, self._num_bits_per_symbol, 1) ###
        no = tf.expand_dims(tf.cast(no, tf.float32), axis=-1) # turn to float32, turns shape (9,) -> (9,1)
        print("no, ebno_db: ", no.shape, self.ebno_db.shape)

        b = self._binary_source([self._batch_size, self._encoder._k]) # (batch_size, k), k information bits
        print("bit: ", b.shape) # print(b.shape[-1]==self._k, b.shape, self._k, self._n - self._k)

        # Turns INFO BITS (batch_size, k) -> (batch_size, n) info and parity bits CODEWORD of rate = k/n
        if self._encoder is not None:
            c = self._encoder(b) ##### c = G @ b.T, (n,k) @ (k,1)
        else:
            c = b

        print("n, c: ", self._n, c.shape)
        # check that rate calculations are correct
        assert self._n == c.shape[-1], "Invalid value of n."

        # zero padding to support odd codeword lengths
        if self._n%2 == 1:
            c_pad = tf.concat([c, tf.zeros([self._batch_size, 1])], axis=1)
        else: # no padding
            c_pad = c
        print("c_pad, c: ", c_pad.shape, c.shape)

        # Channel
        ############################
        x = self._mapper(c_pad)
        # y = self._channel([x, no]) ###
        llr = self._demapper([x, no]) # no noise
        ############################
        # print("y, no: ", y.shape, no.shape)

        # remove zero padded bit at the end
        if self._n%2 == 1:
            llr = llr[:,:-1]
        print("llr: ", llr.shape, llr)# b, c, x, y)

        # Run decoder
        llr_nldpc, u_hat, x_hat = self._decoder5g(llr) # Gets reshaped (n_ldpc,1) llrs
        print("llr (n_ldpc,): ", llr_nldpc.shape, " sum positive: ", tf.reduce_sum(tf.boolean_mask(llr, llr > 0)), " n_ldpc: ", self._encoder._n_ldpc)
        print("llr (crude): ", llr_nldpc[:, 54])

        if isinstance(llr, tf.Tensor):
            llr = torch.tensor(llr.numpy())
        if isinstance(llr_nldpc, tf.Tensor):
            llr_nldpc = torch.tensor(llr_nldpc.numpy())

        r_cw = (llr > 0).float()
        print("c == r_cw: ", c.shape, c==r_cw)

        print("c == x_hat: ", c==x_hat)

        llr_ddecc = self._decoder(llr_nldpc, time_step=0) # Outputs decoded llrs (n_ldpc,1)
        print("llr_ddecc: ", llr_ddecc.shape)

        # TODO: How do I turn the decoded llrs of (n_ldpc,1) to c_hat (n,1)?
        c_hat = llr_ddecc

        # codeword, info bits, llr of either cw or info bits
        return c, b, c_hat

        # if self._return_infobits:
        #     return b, llr_ddecc
        # else:
        #     return c, llr_ddecc



from sionna.fec.ldpc.encoding import LDPC5GEncoder
# from sionna.fec.ldpc.decoding import LDPC5GDecoder
from sionna.utils import BitErrorRate, BinarySource, ebnodb2no
from sionna.mapping import Mapper, Demapper
from sionna.channel import AWGN

# from REU_LDPC_Project.channel import E2EModelDDECC
from REU_LDPC_Project.decoder import LDPC5GDecoder

from ECCT.Model import MultiHeadedAttention, PositionwiseFeedForward, Encoder, EncoderLayer, ECC_Transformer
from ECCT.args import pass_args
from ECCT.Codes import FER, BER # x_pred, x_gt

# from DDECC.DDECC import DDECCT
from DDECC.args import pass_args_ddecc
from DDECC.Codes import bin_to_sign, sign_to_bin
# from DDECC.Main import train, test

from gnn_decoder.gnn import LDPC5GGNN


import random
import os
from torch.utils.data import DataLoader
from torch.utils import data
from datetime import datetime
import logging
import time
from torch.optim.lr_scheduler import CosineAnnealingLR


n, k = (100, 90)
# args_ecct = pass_args()
args_ddecc = pass_args_ddecc()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

enc = LDPC5GEncoder(k,n)
dec = LDPC5GDecoder(enc)

# Models
# gnn = LDPC5GGNN(enc)
# ecct = ECC_Transformer(args_ecct, enc, dec)
ddecct = DDECCT(args_ddecc, enc, device)
# e2e Channel
channel = E2EModelDDECC(ddecct, dec)

# Train info
optimizer = torch.optim.Adam(channel._decoder.parameters(), lr=args_ddecc.lr)


def train(channel, model, optimizer, LR, iters=1000): #device, train_loader, optimizer, epoch, LR):

    model.train()
    cum_loss = cum_samples = 0
    t = time.time()

    # for epoch in range(epochs):

    for i in range(iters):
        c, b, b_hat = channel.call() # Created pcm

        # Create channel transmission
        # c_hat = (c_hat > 0.5).float()  # Convert c_hat to binary and then to float

        # Convert c, c_hat to tensors
        if isinstance(c, tf.Tensor):
            c = torch.tensor(c.numpy()).float()
        if isinstance(b_hat, tf.Tensor):
            b_hat = torch.tensor(b_hat.numpy())
        print("c: ", c.shape, " b_hat: ", b_hat.shape)
        print(c, b_hat)

        # Ensure c and c_hat require gradients
        c = c.requires_grad_()
        b_hat = b_hat.requires_grad_()

        if c.shape != b_hat.shape:
            c = c.expand_as(b_hat)  # repeat c to match dims of c_hat with different freqs.

        loss = F.binary_cross_entropy_with_logits(b, b_hat) # compares c and c_hat

        model.zero_grad()
        loss.backward()
        optimizer.step()
        model.ema.update(model) # update EMA

        cum_loss += loss.item() * b_hat.shape[0] # x.shape[0]
        cum_samples += b_hat.shape[0]

        # if (i+1) % 3 == 0:
        print(f'\nBatch {i + 1}/{iters}: LR={LR:.2e}, Loss={cum_loss / cum_samples:.5e}')
        print(f'Iter {i} Train Time {time.time() - t}s\n')
        print(f'BER: ', BER(c, c_hat), ' FER: ', FER(c, c_hat))

    return cum_loss / cum_samples


# Train loop
train(channel, channel._decoder, optimizer, args_ddecc.lr)

c, b, c_hat = channel.call() # cw, info bits, llr_ddecc




Path to model/logs: DDECCT_Results/POLAR__Code_n_64_k_32__25_06_2024_12_40_48


  indices = torch.tensor([pcm_coo.row, pcm_coo.col], dtype=torch.int64)


tensor([[[[False,  True,  True,  ...,  True,  True,  True],
          [ True, False,  True,  ...,  True,  True,  True],
          [ True,  True, False,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ..., False,  True,  True],
          [ True,  True,  True,  ...,  True, False,  True],
          [ True,  True,  True,  ...,  True,  True, False]]]])
no, ebno_db:  (1, 1) (1,)
bit:  (1, 90)
n, c:  100 (1, 100)
c_pad, c:  (1, 100) (1, 100)
llr:  (1, 100) tf.Tensor(
[[ 1.6394187 -1.6394187 -1.6394182 -1.6394184 -5.971885   5.971885
   1.42679    1.42679    1.6394188 -5.971885  -1.6394184  1.4267902
  -1.6394187  1.6394187 -1.6394184 -1.6394181  1.6394187 -1.6394187
  -1.6394182 -1.6394184  1.6394187  1.6394187 -1.6394181 -1.6394181
  -1.6394188  5.971885  -1.6394184  1.4267899 -1.6394188 -5.971885
  -1.6394184  1.4267902 -5.971885   1.6394186  1.4267902 -1.6394184
  -1.6394187 -1.6394187 -1.6394184 -1.6394184  5.971885  -1.6394186
   1.4267899 -1.6394184 -1.639418

RuntimeError: The size of tensor a (1410) must match the size of tensor b (20) at non-singleton dimension 1

In [3]:


class DDECCT(nn.Module):
    def __init__(self, args, encoder,
                 device, dropout=0):

        super(DDECCT, self).__init__()

        self.encoder = encoder
        encoder._m_ldpc = encoder._n_ldpc - encoder._k_ldpc # m_ldpc = n_ldpc-k_ldpc

        self.n_steps = encoder._m_ldpc + 5 # m_ldpc + 5
        self.d_model = args.d_model
        self.sigma = args.sigma

        pcm_coo = encoder.pcm.tocoo() # Convert the SciPy sparse matrix to COO format

        # Create a PyTorch sparse tensor from the COO format
        indices = torch.tensor([pcm_coo.row, pcm_coo.col], dtype=torch.int64)
        values = torch.tensor(pcm_coo.data, dtype=torch.float32)
        shape = torch.Size(pcm_coo.shape)

        # Register the sparse tensor as a buffer
        self.register_buffer('pc_matrix', torch.sparse_coo_tensor(indices, values, shape)) # should be float
        self.device = device

        betas = torch.linspace(1e-3, 1e-2, self.n_steps)
        betas = betas*0+self.sigma
        self.betas = betas.view(-1,1)
        self.betas_bar =  torch.cumsum(self.betas, 0).view(-1,1)

        self.line_search = False

        # code = args.code
        c = copy.deepcopy
        attn = MultiHeadedAttention(args.h, args.d_model)
        ff = PositionwiseFeedForward(args.d_model, args.d_model*4, dropout)

        self.src_embed = torch.nn.Parameter(torch.empty(
            (encoder._n_ldpc + encoder._m_ldpc, args.d_model))) #code.n + code.pc_matrix.size(0), args.d_model)))

        self.decoder = Encoder(EncoderLayer(
            args.d_model, c(attn), c(ff), dropout), args.N_dec)

        self.oned_final_embed = torch.nn.Sequential(
            *[nn.Linear(args.d_model, 1)])

        # want a shape (1,n) original codeword sent #code.n + code.pc_matrix.size(0), code.n)
        self.out_fc = nn.Sequential(
            nn.Linear(encoder._n_ldpc + encoder._m_ldpc, encoder._k),
            nn.Sigmoid(),  # Convert logits to probabilities
        )
        self.time_embed = nn.Embedding(self.n_steps, args.d_model)

        self.get_mask()

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        self.ema = EMA(self, 0.9)


    def forward(self, llrs, time_step):
        print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
        print("\nDDECCT model")
        # Ensure r_cw is a tensor
        if isinstance(llrs, tf.Tensor):
            llrs = torch.tensor(llrs.numpy())
        # Ensure time_step is a tensor
        if isinstance(time_step, int):
            time_step = torch.tensor([time_step], dtype=torch.long)
        elif isinstance(time_step, list):
            time_step = torch.tensor(time_step, dtype=torch.long)

        syndrome = torch.sparse.mm(self.pc_matrix, llrs.T) % 2 # syndrome = (self.pc_matrix @ sign_to_bin(torch.sign(y)).T.float()) % 2
        syndrome = (syndrome > 0).float().T.reshape(-1, 10)
        # syndrome = bin_to_sign(syndrome).T
        magnitude = torch.abs(syndrome) # m = H @ y.T
        print("magnitude: ", magnitude.shape, magnitude)
        print("syndrome: ", syndrome.shape, syndrome)

        emb = torch.cat([magnitude, syndrome], -1).unsqueeze(-1) # (9, n_ldpc + m_ldpc, 1)
        emb = self.src_embed.unsqueeze(0) * emb # (9, n_ldpc + m_ldpc, 32) embeding size
        print("emb: ", emb.shape)

        # Diffusion time steps
        time_emb = self.time_embed(time_step).view(-1, 1, self.d_model) # time_step is the ix
        # d_model shaped nodes 'overseeing' the attn in the network
        # could add a (1, time_embed.size) 'overseeing' attn vector
        print("time_emb: ", time_emb.shape)

        emb = time_emb * emb
        print("emb: ", emb.shape) #, " args.N_dec: ", self.args.N_dec)
        emb = self.decoder(emb, self.src_mask, )#time_emb) # attention
        print("emb: ", emb.shape) #, " args.N_dec: ", self.args.N_dec)

        # removes (d_model, n + m) shaped dims
        out_fc = self.out_fc(self.oned_final_embed(emb).squeeze(-1))
        print("out_fc: ", out_fc.shape, out_fc)
        return out_fc


    def p_sample(self, yt):
        # Single sampling from the real p dist.
        sum_syndrome =  (torch.matmul(sign_to_bin(torch.sign(yt.to(self.device))),self.pc_matrix) % 2).round().long().sum(-1)
        # assert sum_syndrome.max() <= self.pc_matrix.shape[1] and sum_syndrome.min() >= 0
        t = sum_syndrome.cpu()
        # Model output
        noise_mul_pred = self(yt.to(self.device), sum_syndrome.to(self.device)).cpu()# predicted multiplicative noise
        noise_add_pred = yt-torch.sign(-noise_mul_pred * torch.sign(yt)) #predicted additive noise
        factor = (torch.sqrt(self.betas_bar[t])*self.betas[t]/(self.betas_bar[t]+self.betas[t])) #theoretical step size
        alpha_final = 1
        if self.line_search:
            #Perform Step Sizer Line-search # TODO : perform it on GPU for speed
            alpha = torch.linspace(1,20,20).unsqueeze(0).unsqueeze(0)
            new_synd = (torch.matmul(sign_to_bin(torch.sign(yt.unsqueeze(-1) - alpha*(noise_add_pred*factor).unsqueeze(-1))).permute(0,2,1),self.pc_matrix.cpu()) % 2).round().long().sum(-1)
            alpha_final = alpha.squeeze(0)[:,new_synd.argmin(-1).unsqueeze(-1)].squeeze(0)
        yt_1 = yt - alpha_final*noise_add_pred*factor
        yt_1[t==0] = yt[t==0] # if some codeword has 0 synd. keep it as is
        return (yt_1), t


    def p_sample_loop(self, cur_y):
        # Iterative sampling from the real p dist.
        res = []
        synd_all = []
        for it in range(self.pc_matrix.shape[1]):
            cur_y,curr_synd = self.p_sample(cur_y)
            synd_all.append(curr_synd)
            res.append(cur_y)
        synd_all = torch.stack(synd_all).t().cpu()
        # Chose the biggest iteration that reaches 0 synd.
        aa = (synd_all == 0).int()*2-1
        idx = torch.arange(aa.shape[1], 0, -1)
        idx_conv = torch.argmax(aa * idx, 1, keepdim=True)
        return cur_y, res, idx_conv.view(-1), synd_all


  #################################
    def loss(self,x_0):
        print("\nDDECC Loss")
        # Convert NumPy array to PyTorch tensor
        x_0_np = x_0.numpy()
        x_0 = torch.tensor(x_0_np)
        print("x_0: ", x_0.shape)

        t = torch.randint(0, self.n_steps, size=(x_0.shape[0] // 2 + 1,))
        print("t: ", t.shape)
        t = torch.cat([t, self.n_steps - t - 1], dim=0)[:x_0.shape[0]].long()
        print("t: ", t.shape)

        e = torch.randn_like(x_0)
        print("e: ", e.shape)

        noise_factor = torch.sqrt(self.betas_bar[t]).to(x_0.device)
        print("noise_factor: ", noise_factor.shape)

        h = torch.from_numpy(np.random.rayleigh(x_0.size(0),x_0.size(1))).float()
        print("h: ", h.shape)
        h = 1.

        yt = h*x_0 * 1 + e * noise_factor
        print("yt: ", yt.shape)

        sum_syndrome =  (torch.matmul(sign_to_bin(torch.sign(yt.to(self.device))),
        self.pc_matrix) % 2).sum(-1).long()
        print("sum_syndrome: ", sum_syndrome.shape)

        output = self(yt.to(self.device), sum_syndrome.to(self.device))
        print("output: ", output.shape)

        z_mul = (yt *x_0)
        print("z_mul: ", z_mul.shape)

        return F.binary_cross_entropy_with_logits(output, sign_to_bin(torch.sign(z_mul.to(self.device))))
  #################################


    def get_mask(self, no_mask=False):
        if no_mask:
            self.src_mask = None
            return

        src_mask = self.build_mask()
        print(src_mask)
        self.register_buffer('src_mask', src_mask)


    def build_mask(self):
        mask_size = 2*self.encoder._n_ldpc - self.encoder._k_ldpc
        mask = torch.eye(mask_size, mask_size)

        for ii in range(self.encoder._n_ldpc - self.encoder._k_ldpc): # m_ldpc, check node
            idx = self.encoder.pcm[ii].indices #idx = torch.where(self.encoder.pcm[ii].indices > 0)[0]
            for jj in idx:
                for kk in idx:
                    if jj != kk:
                        mask[jj, kk] += 1
                        mask[kk, jj] += 1
                        mask[self.encoder._n_ldpc + ii, jj] += 1
                        mask[jj, self.encoder._n_ldpc + ii] += 1

        src_mask = ~ (mask > 0).unsqueeze(0).unsqueeze(0)
        return src_mask



class EMA(object):
    def __init__(self, module, mu=0.999):
        self.mu = mu
        self.shadow = {}
        self.register(module)

    def register(self, module):
            for name, param in module.named_parameters():
                if param.requires_grad:
                    self.shadow[name] = param.data.clone()


    def update(self, module):
            for name, param in module.named_parameters():
                if param.requires_grad:
                    self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data




In [None]:
enc._n_ldpc, enc._k_ldpc, enc.pcm.shape