<a href="https://colab.research.google.com/github/pollyjuice74/REU-LDPC-Project/blob/main/Comp_Diffusion_to_GNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import print_function
import argparse
import random
import sys
import os
import time
from datetime import datetime
import torch
from torch.utils.data import DataLoader
from torch.utils import data
from torch.optim.lr_scheduler import CosineAnnealingLR
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense

!git clone https://github.com/pollyjuice74/gnn-decoder
!pip install sionna

if os.path.exists('gnn-decoder'):
  os.rename('gnn-decoder', 'gnn_decoder')

from gnn_decoder.gnn import GNN_BP, UpdateEmbeddings

if not os.path.exists('DDECC'):
  !git clone https://github.com/pollyjuice74/DDECC
os.chdir('DDECC')

from Codes import *
from DDECC import *
from utils import *
from args import *


# The modifications are made so that the model trains on BCH codes of length n=63, k=45, where it says:

"### IMPORTANT ###"

# The modifications made are seen in the **FEC_Dataset** where it says:

"### IMPORTANT ###"

In [2]:

class FEC_Dataset(data.Dataset):                 ####
    def __init__(self, code, sigma, len, zero_cw=True):
        self.code = code
        self.sigma = sigma
        self.len = len
        self.generator_matrix = code.generator_matrix.transpose(0, 1)
        self.pc_matrix = code.pc_matrix.transpose(0, 1)

        self.zero_word = torch.zeros((self.code.k)).long() if zero_cw else None
        self.zero_cw = torch.zeros((self.code.n)).long() if zero_cw else None

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        if self.zero_cw is None:
            m = torch.randint(0, 2, (1, self.code.k)).squeeze()
            x = torch.matmul(m, self.generator_matrix) % 2
        else: # SET TO TRUE
            m = self.zero_word
            x = self.zero_cw

        #h = torch.from_numpy(np.random.rayleigh(1,self.code.n)).float()
        # y = x.clone()

        # # Random bit flipping error
        # ix = torch.tensor(random.sample(range(self.code.n), 3))
        # y[ix] = 1 - y[ix] # flip bits
        # y = bin_to_sign(y)

        ### IMPORTANT ###
        #######################################################################
        # Make noise
        std_noise = random.choice(self.sigma)
        z = torch.randn(self.code.n) * std_noise
        # Convert y to sign and add noise
        h=1
        y = h*bin_to_sign(x) + z

        # Sign to LLR conversion
        var = std_noise ** 2
        def sign_to_llr(bpsk_vect, noise_variance):
          return 2*bpsk_vect / noise_variance

        # x,y to llrs
        x = bin_to_sign(x)
        x_llr = sign_to_llr(x, var)
        y_llr = sign_to_llr(y, var)
        #######################################################################

        magnitude = torch.abs(y)
        syndrome = torch.matmul(sign_to_bin(torch.sign(y)).long(),
                                self.pc_matrix) % 2
        syndrome = bin_to_sign(syndrome)
        return m.float(), x.float(), z.float(), y.float(), x_llr.float(), y_llr.float(), magnitude.float(), syndrome.float()


### DIFFUSION FUNCTIONS ###

def train_dif(model, device, train_loader, optimizer, epoch, LR):
    model.train()
    cum_loss = cum_samples = 0
    t = time.time()
    for batch_idx, (m, x, z, y, x_llr, y_llr, magnitude, syndrome) in enumerate(train_loader):
        # stop at batch 25
        if batch_idx==25:
          break

        loss = model.loss(bin_to_sign(x))
        model.zero_grad()
        loss.backward()
        optimizer.step()
        model.ema.update(model)

        cum_loss += loss.item() * x.shape[0]
        cum_samples += x.shape[0]
        if (batch_idx+1) % 25 == 0 or batch_idx == len(train_loader) - 1:
            print(f'Training epoch {epoch}, Batch {batch_idx + 1}/{len(train_loader)}: LR={LR:.2e}, Loss={cum_loss / cum_samples:.5e}')
            break
    print(f'Epoch {epoch} Train Time {time.time() - t}s\n')
    return cum_loss / cum_samples


def test_dif(model, device, test_loader_list, EbNo_range_test, min_FER=100, max_cum_count=1e7, min_cum_count=1e5):
    model.eval()
    test_loss_ber_list, test_loss_fer_list, cum_samples_all = [], [], []
    t = time.time()
    with torch.no_grad():
        for ii, test_loader in enumerate(test_loader_list): # just the first three
            # stop at batch 5
            if ii==5:
              break

            test_ber = test_fer = cum_count = 0.
            _, x_pred_list, _, _ = model.p_sample_loop(next(iter(test_loader))[3])
            test_ber_ddpm , test_fer_ddpm = [0]*len(x_pred_list), [0]*len(x_pred_list)
            idx_conv_all = []
            printed = False # Flag for printing x, x_pred


            for batch_ix, (m, x, z, y, x_llr, y_llr, magnitude, syndrome) in enumerate(test_loader):
                # stop at batch 5
                if batch_ix==5:
                  break

                x_pred, x_pred_list, idx_conv,synd_all = model.p_sample_loop(y)
                # convert to binary
                x_pred = sign_to_bin(torch.sign(x_pred))
                x = sign_to_bin(x)

                idx_conv_all.append(idx_conv)
                for kk, x_pred_tmp in enumerate(x_pred_list):
                    x_pred_tmp = sign_to_bin(torch.sign(x_pred_tmp))

                    test_ber_ddpm[kk] += BER(x_pred_tmp, x) * x.shape[0]
                    test_fer_ddpm[kk] += FER(x_pred_tmp, x) * x.shape[0]

                test_ber += BER(x_pred, x) * x.shape[0]
                test_fer += FER(x_pred, x) * x.shape[0]
                cum_count += x.shape[0]

                if not printed:
                  print("x: ", x)
                  print("x_pred: ", x_pred)
                  printed = True

                break # from while loop

            idx_conv_all = torch.stack(idx_conv_all).float()
            cum_samples_all.append(cum_count)
            test_loss_ber_list.append(test_ber / cum_count)
            test_loss_fer_list.append(test_fer / cum_count)
            for kk in range(len(test_ber_ddpm)):
                test_ber_ddpm[kk] /= cum_count
                test_fer_ddpm[kk] /= cum_count
            print(f'Test EbN0={EbNo_range_test[ii]}, BER={test_loss_ber_list}')
            print(f'Test EbN0={EbNo_range_test[ii]}, BER_DDPM={test_ber_ddpm}')
            print(f'Test EbN0={EbNo_range_test[ii]}, -ln(BER)_DDPM={[-np.log(elem) for elem in test_ber_ddpm]}')
            print(f'Test EbN0={EbNo_range_test[ii]}, FER_DDPM={test_fer_ddpm}')
            print(f'#It. to zero syndrome: Mean={idx_conv_all.mean()}, Std={idx_conv_all.std()}, Min={idx_conv_all.min()}, Max={idx_conv_all.max()}')

        print('Test FER ' + ' '.join(
            ['{}: {:.2e}'.format(ebno, elem) for (elem, ebno)
             in
             (zip(test_loss_fer_list, EbNo_range_test))]))
        print('Test BER ' + ' '.join(
            ['{}: {:.2e}'.format(ebno, elem) for (elem, ebno)
             in
             (zip(test_loss_ber_list, EbNo_range_test))]))
        print('Test -ln(BER) ' + ' '.join(
            ['{}: {:.2e}'.format(ebno, -np.log(elem)) for (elem, ebno)
             in
             (zip(test_loss_ber_list, EbNo_range_test))]))
    print(f'# of testing samples: {cum_samples_all}\n Test Time {time.time() - t} s\n')
    return test_loss_ber_list, test_loss_fer_list


### GNN FUNCTIONS ###

def train_gnn(model, device, train_loader, optimizer, epoch, LR):
    # model.train()
    cum_loss = cum_samples = 0
    t = time.time()
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    for batch_idx, (m, x, z, y, x_llr, y_llr, magnitude, syndrome) in enumerate(train_loader): # train_loader size 1000
        # stop at batch 250
        if batch_idx==250:
          break
        # convert to tf for GNN_BP
        y_llr = tf.convert_to_tensor(y_llr.numpy(), dtype=tf.float32)

        with tf.GradientTape() as tape:
          # model prediction
          x_llr_hat = model(y_llr)

          loss = loss_fn(x_llr, x_llr_hat)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        cum_loss += loss * x.shape[0]
        cum_samples += x.shape[0]
        if (batch_idx+1) % 5 == 0 or batch_idx == len(train_loader) - 1:
            print(f'Training epoch {epoch}, Batch {batch_idx + 1}/{len(train_loader)}: LR={LR:.2e}, Loss={cum_loss / cum_samples:.5e}')

    print(f'Epoch {epoch} Train Time {time.time() - t}s\n')
    return cum_loss / cum_samples


def test_gnn(model, device, test_loader_list, EbNo_range_test, min_FER=100, max_cum_count=1e7, min_cum_count=1e5):
    # model.eval()
    test_loss_ber_list, test_loss_fer_list, cum_samples_all = [], [], []
    t = time.time()
    printed = False

    with torch.no_grad():
        for ii, test_loader in enumerate(test_loader_list):
            # stop at batch 5
            if ii==5:
              break

            test_ber = test_fer = cum_count = 0.
            # idx_conv_all = []
            # Iterate over first 5 batches
            for batch_ix, (m, x, z, y, x_llr, y_llr, magnitude, syndrome) in enumerate(test_loader):
                # stop at batch 5
                if batch_ix==5:
                  break

                # convert llr from pytorch to tensorflow
                y_llr = tf.convert_to_tensor(y_llr.numpy(), dtype=tf.float32)

                # Model prediction
                x_llr_hat = model(y_llr)
                x_hat = torch.tensor(x_llr_hat.numpy() > 0, dtype=torch.float32)

                # Convert x from bpsk to binary
                x = sign_to_bin(x)

                # FER, BER
                test_ber += BER(x_hat, x) * x.shape[0]
                test_fer += FER(x_hat, x) * x.shape[0]
                cum_count += x.shape[0]

                # Show prediction and actual x
                if not printed:
                  print("GNN x_hat: ", x_hat)
                  print("Actual x: ", x)
                  printed = True

                if (min_FER > 0 and test_fer > min_FER and cum_count > min_cum_count) or cum_count >= max_cum_count:
                    if cum_count >= 1e9:
                        print(f'Cum count reached EbN0:{EbNo_range_test[ii]}')
                    else:
                        print(f'FER count treshold reached EbN0:{EbNo_range_test[ii]}')
                    break


            # idx_conv_all = torch.stack(idx_conv_all).float()
            cum_samples_all.append(cum_count)
            test_loss_ber_list.append(test_ber / cum_count)
            test_loss_fer_list.append(test_fer / cum_count)

            print(f'Test EbN0={EbNo_range_test[ii]}, BER={test_loss_ber_list}')
            # print(f'#It. to zero syndrome: Mean={idx_conv_all.mean()}, Std={idx_conv_all.std()}, Min={idx_conv_all.min()}, Max={idx_conv_all.max()}')

        print('Test FER ' + ' '.join(
            ['{}: {:.2e}'.format(ebno, elem) for (elem, ebno)
             in
             (zip(test_loss_fer_list, EbNo_range_test))]))
        print('Test BER ' + ' '.join(
            ['{}: {:.2e}'.format(ebno, elem) for (elem, ebno)
             in
             (zip(test_loss_ber_list, EbNo_range_test))]))
        print('Test -ln(BER) ' + ' '.join(
            ['{}: {:.2e}'.format(ebno, -np.log(elem)) for (elem, ebno)
             in
             (zip(test_loss_ber_list, EbNo_range_test))]))
    print(f'# of testing samples: {cum_samples_all}\n Test Time {time.time() - t} s\n')
    return test_loss_ber_list, test_loss_fer_list


In [None]:

args = pass_args_ddecc()

code = args.code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models
ddecct = DDECCT(args, device=device,dropout=0).to(device)
ddecct.ema.register(ddecct)
gnn = GNN_BP(code)

# Optimizers and schedulers
optimizer_dif = torch.optim.Adam(ddecct.parameters(), lr=args.lr)
scheduler_dif = CosineAnnealingLR(optimizer_dif, T_max=args.epochs, eta_min=5e-6)
print(f'Diffusion # of Parameters: {np.sum([np.prod(p.shape) for p in ddecct.parameters()])}')

scheduler_gnn = tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate=args.lr, decay_steps=args.epochs*1000) # 1000 is size of trainloader
optimizer_gnn = tf.keras.optimizers.Adam(learning_rate=scheduler_gnn)
print(f'GNN # of Parameters: {len(gnn.trainable_variables)}\n')

print("Creating data...")
EbNo_range_test = range(4, 7)
EbNo_range_train = range(2, 8)
std_train = [EbN0_to_std(ii, code.k / code.n) for ii in EbNo_range_train]
std_test = [EbN0_to_std(ii, code.k / code.n) for ii in EbNo_range_test]
train_dataloader = DataLoader(FEC_Dataset(code, std_train, len=args.batch_size * 1000, zero_cw=True), batch_size=int(args.batch_size),
                              shuffle=True, num_workers=args.workers)
test_dataloader_list = [DataLoader(FEC_Dataset(code, [std_test[ii]], len=int(args.test_batch_size), zero_cw=False),
                                    batch_size=int(args.test_batch_size), shuffle=False, num_workers=args.workers) for ii in range(len(std_test))]

print(f"Training model with code type: {args.code_type}\n\n")



# Train for 25 data loader batches
for epoch in range(1, args.epochs + 1):
    print("Training GNN...")
    train_gnn(gnn, device, train_dataloader,
              optimizer_gnn, epoch, LR=scheduler_gnn(tf.Variable(0, dtype=tf.float32)).numpy())
    print("Training DDECCT...")
    train_dif(ddecct, device, train_dataloader,
              optimizer_dif, epoch, LR=scheduler_dif.get_last_lr()[0])

    # print comparison
    if epoch % 1 == 0 or epoch in [1,5]:
        print("Testing GNN...")
        test_gnn(gnn, device, test_dataloader_list,
                 EbNo_range_test,min_FER=50,max_cum_count=1e6,min_cum_count=1e4)
        print("Testing DDECCT...")
        test_dif(ddecct, device, test_dataloader_list,
                 EbNo_range_test,min_FER=50,max_cum_count=1e6,min_cum_count=1e4)
    break # from for loop




Path to model/logs: DDECCT_Results/LDPC__Code_n_121_k_80__26_06_2024_21_42_02
Diffusion # of Parameters: 52503
GNN # of Parameters: 0

Creating data...
Training model with code type: LDPC


Training DDECCT...


  self.pid = os.fork()
  self.pid = os.fork()


Training epoch 1, Batch 25/1000: LR=5.00e-04, Loss=4.59557e-01
Epoch 1 Train Time 58.65085005760193s

Testing DDECCT...
