<a href="https://colab.research.google.com/github/pollyjuice74/REU-LDPC-Project/blob/main/DDECCT_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import print_function
import argparse
import random
import sys
import os
import torch
from torch.utils.data import DataLoader
from torch.utils import data
from datetime import datetime
import time
from torch.optim.lr_scheduler import CosineAnnealingLR

!git clone https://github.com/pollyjuice74/gnn-decoder
!pip install sionna

if os.path.exists('gnn-decoder'):
  os.rename('gnn-decoder', 'gnn_decoder')

from gnn_decoder.gnn import GNN_BP

if not os.path.exists('DDECC'):
  !git clone https://github.com/pollyjuice74/DDECC.git
os.chdir('DDECC')

from Codes import *
from DDECC import *
from utils import *


# The modifications make the model train on the BCH code with n=63, k=45; they are marked with:

"### IMPORTANT ###"

# The relevant modifications are in **FEC_Dataset**, marked with:

"### IMPORTANT ###"

In [3]:
############ ZERO CODEWORD SET TO TRUE ###############
class FEC_Dataset(data.Dataset):                 ####
    """Synthetic noisy-codeword dataset for a binary linear code.

    Every item is generated on the fly:
    * a message/codeword (random, or all-zero when ``zero_cw`` is True),
    * bits mapped with ``bin_to_sign`` plus Gaussian noise whose std is drawn
      uniformly from ``sigma``,
    * channel LLRs, magnitudes and the (+/-1) syndrome of the hard decision.

    Args:
        code: object exposing ``n``, ``k``, ``generator_matrix`` and
            ``pc_matrix`` (tensors supporting ``transpose(0, 1)``).
        sigma: sequence of noise standard deviations to sample from.
        len: value reported by ``__len__`` (controls epoch size only, since
            items are generated randomly on every access).
        zero_cw: if True, always transmit the all-zero codeword.
    """

    def __init__(self, code, sigma, len, zero_cw=True):
        self.code = code
        self.sigma = sigma
        self.len = len
        # Stored transposed so a row vector times the matrix gives codeword / syndrome.
        self.generator_matrix = code.generator_matrix.transpose(0, 1)
        self.pc_matrix = code.pc_matrix.transpose(0, 1)

        self.zero_word = torch.zeros((self.code.k)).long() if zero_cw else None
        self.zero_cw = torch.zeros((self.code.n)).long() if zero_cw else None

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return ``(m, x, z, y, llr, magnitude, syndrome)`` as float tensors.

        ``index`` is ignored: each call draws a fresh random sample.
        """
        if self.zero_cw is None:
            # Shape (k,) directly (same RNG draws as (1, k).squeeze(), but never 0-d).
            m = torch.randint(0, 2, (self.code.k,))
            x = torch.matmul(m, self.generator_matrix) % 2
        else:  # all-zero codeword
            m = self.zero_word
            x = self.zero_cw

        ### IMPORTANT ###
        # Draw the noise level ONCE per item: the same sigma must generate the
        # channel noise and scale the LLRs, otherwise the LLRs are mis-scaled
        # whenever ``sigma`` holds more than one value.
        std_noise = random.choice(self.sigma)
        z = torch.randn(self.code.n) * std_noise

        h = 1  # channel gain; fixed to 1 (no fading)
        y = h * bin_to_sign(x) + z

        # y -> LLRs using the variance of the noise actually added above.
        noise_variance = std_noise ** 2
        llr = 2 * y / noise_variance

        magnitude = torch.abs(y)
        # Syndrome of the hard decision on y, mapped back to +/-1.
        syndrome = torch.matmul(sign_to_bin(torch.sign(y)).long(),
                                self.pc_matrix) % 2
        syndrome = bin_to_sign(syndrome)
        return m.float(), x.float(), z.float(), y.float(), llr.float(), magnitude.float(), syndrome.float()


def train(model, device, train_loader, optimizer, epoch, LR):
    """Train the diffusion decoder for one epoch.

    For each batch, the model's own ``loss`` is computed on the +/-1-mapped
    codeword, a gradient step is taken, and the model's EMA copy is updated.
    ``device`` is accepted for interface compatibility but is not used here.

    Returns:
        The sample-weighted mean training loss over the epoch.
    """
    model.train()
    running_loss = 0
    seen = 0
    start = time.time()
    for step, (_, codeword, _, _, _, _, _) in enumerate(train_loader):
        batch_loss = model.loss(bin_to_sign(codeword))
        model.zero_grad()
        batch_loss.backward()
        optimizer.step()
        model.ema.update(model)

        n_in_batch = codeword.shape[0]
        running_loss += batch_loss.item() * n_in_batch
        seen += n_in_batch

        # Log every 500 batches and on the final batch.
        should_log = (step + 1) % 500 == 0 or step == len(train_loader) - 1
        if should_log:
            mean_loss = running_loss / seen
            print(
                f'Training epoch {epoch}, Batch {step + 1}/{len(train_loader)}: LR={LR:.2e}, Loss={mean_loss:.5e}')

    elapsed = time.time() - start
    print(f'Epoch {epoch} Train Time {elapsed}s\n')
    return running_loss / seen


def test(model, device, test_loader_list, EbNo_range_test, min_FER=100, max_cum_count=1e7, min_cum_count=1e5):
    """Evaluate the diffusion decoder's BER/FER at each test Eb/N0.

    For every loader in ``test_loader_list`` (paired with ``EbNo_range_test``),
    batches are decoded with ``model.p_sample_loop`` until more than
    ``min_FER`` frame errors and more than ``min_cum_count`` samples have been
    seen, or until ``max_cum_count`` samples have been decoded.

    Besides the final estimate, BER/FER are tracked for every intermediate
    estimate returned by ``p_sample_loop`` (the "DDPM" lists), and statistics
    of its per-batch ``idx_conv`` output are printed.

    Args:
        model: decoder exposing ``eval()`` and ``p_sample_loop(y)``.
        device: unused; kept for signature compatibility.
        test_loader_list: DataLoaders yielding ``FEC_Dataset`` items.
        EbNo_range_test: Eb/N0 labels, one per loader.
        min_FER, max_cum_count, min_cum_count: stopping criteria (see above).

    Returns:
        ``(ber_list, fer_list)`` with one entry per Eb/N0.
    """
    import numpy as np  # used for -ln(BER); do not rely on star-imports for it

    model.eval()
    test_loss_ber_list, test_loss_fer_list, cum_samples_all = [], [], []
    t = time.time()
    with torch.no_grad():
        for ii, test_loader in enumerate(test_loader_list):
            ebno = EbNo_range_test[ii]
            test_ber = test_fer = cum_count = 0.
            # One dry run to learn how many intermediate estimates
            # p_sample_loop returns (field 3 of a dataset item is y).
            _, x_pred_list, _, _ = model.p_sample_loop(next(iter(test_loader))[3])
            test_ber_ddpm, test_fer_ddpm = [0] * len(x_pred_list), [0] * len(x_pred_list)
            idx_conv_all = []
            while True:
                # FEC_Dataset items have 7 fields (including llr).
                (m, x, z, y, llr, magnitude, syndrome) = next(iter(test_loader))
                x_pred, x_pred_list, idx_conv, synd_all = model.p_sample_loop(y)
                x_pred = sign_to_bin(torch.sign(x_pred))

                idx_conv_all.append(idx_conv)
                for kk, x_pred_tmp in enumerate(x_pred_list):
                    x_pred_tmp = sign_to_bin(torch.sign(x_pred_tmp))
                    test_ber_ddpm[kk] += BER(x_pred_tmp, x) * x.shape[0]
                    test_fer_ddpm[kk] += FER(x_pred_tmp, x) * x.shape[0]

                test_ber += BER(x_pred, x) * x.shape[0]
                test_fer += FER(x_pred, x) * x.shape[0]
                cum_count += x.shape[0]
                # Stop once enough frame errors AND samples were collected,
                # or when the sample budget is exhausted.
                if (min_FER > 0 and test_fer > min_FER and cum_count > min_cum_count) or cum_count >= max_cum_count:
                    # Check the budget first so the message names the criterion that fired.
                    if cum_count >= max_cum_count:
                        print(f'Cum count reached EbN0:{ebno}')
                    else:
                        print(f'FER count treshold reached EbN0:{ebno}')
                    break
            idx_conv_all = torch.stack(idx_conv_all).float()
            cum_samples_all.append(cum_count)
            test_loss_ber_list.append(test_ber / cum_count)
            test_loss_fer_list.append(test_fer / cum_count)
            test_ber_ddpm = [v / cum_count for v in test_ber_ddpm]
            test_fer_ddpm = [v / cum_count for v in test_fer_ddpm]
            print(f'Test EbN0={ebno}, BER={test_loss_ber_list[-1]}')
            print(f'Test EbN0={ebno}, BER_DDPM={test_ber_ddpm}')
            print(f'Test EbN0={ebno}, -ln(BER)_DDPM={[-np.log(elem) for elem in test_ber_ddpm]}')
            print(f'Test EbN0={ebno}, FER_DDPM={test_fer_ddpm}')
            print(f'#It. to zero syndrome: Mean={idx_conv_all.mean()}, Std={idx_conv_all.std()}, Min={idx_conv_all.min()}, Max={idx_conv_all.max()}')
        ###
        print('Test FER ' + ' '.join(
            ['{}: {:.2e}'.format(ebno, elem) for (elem, ebno)
             in zip(test_loss_fer_list, EbNo_range_test)]))
        print('Test BER ' + ' '.join(
            ['{}: {:.2e}'.format(ebno, elem) for (elem, ebno)
             in zip(test_loss_ber_list, EbNo_range_test)]))
        print('Test -ln(BER) ' + ' '.join(
            ['{}: {:.2e}'.format(ebno, -np.log(elem)) for (elem, ebno)
             in zip(test_loss_ber_list, EbNo_range_test)]))
    print(f'# of testing samples: {cum_samples_all}\n Test Time {time.time() - t} s\n')
    return test_loss_ber_list, test_loss_fer_list



# model = DDECCT(args, device=device,dropout=0).to(device)
# model.ema.register(model)


In [6]:
args = pass_args_ddecc() # code_type, k, n

code = args.code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# MODEL #
model = DDECCT(args, device=device,dropout=0).to(device)
# Register the weights with the model's EMA helper (train() updates it after every step).
model.ema.register(model)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Cosine-decay the learning rate from args.lr down to 5e-6 over args.epochs epochs.
scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=5e-6)

print(model)
print(f'# of Parameters: {np.sum([np.prod(p.shape) for p in model.parameters()])}')

#################################
# Eb/N0 grids: train on 2..7, test on 4..6. Each point is converted to a
# noise std using the code rate k/n.
EbNo_range_test = range(4, 7)
EbNo_range_train = range(2, 8)
std_train = [EbN0_to_std(ii, code.k / code.n) for ii in EbNo_range_train]
std_test = [EbN0_to_std(ii, code.k / code.n) for ii in EbNo_range_test]
# Training: all-zero codeword, noise level sampled per item from std_train.
# Testing: one loader per Eb/N0, with random codewords.
train_dataloader = DataLoader(FEC_Dataset(code, std_train, len=args.batch_size * 1000, zero_cw=True), batch_size=int(args.batch_size),
                              shuffle=True, num_workers=args.workers)
test_dataloader_list = [DataLoader(FEC_Dataset(code, [std], len=int(args.test_batch_size), zero_cw=False),
                                   batch_size=int(args.test_batch_size), shuffle=False, num_workers=args.workers)
                        for std in std_test]
#################################

print(f"Training model with code type: {args.code_type}")


best_loss = float('inf')
for epoch in range(1, args.epochs + 1):
    loss = train(model, device, train_dataloader, optimizer,
                 epoch, LR=scheduler.get_last_lr()[0])

    scheduler.step()
    # Checkpoint whenever the epoch's mean training loss improves.
    if loss < best_loss:
        best_loss = loss
        torch.save(model, os.path.join(args.path, 'best_model'))
        print('Model Saved')
    # Quick evaluation (reduced sample budget) on epoch 1 and every 5th epoch.
    if epoch % 5 == 0 or epoch == 1:
        test(model, device, test_dataloader_list, EbNo_range_test, min_FER=50, max_cum_count=1e6, min_cum_count=1e4)
#################################

# Final evaluation with the default sample budget, without and then with line search.
print('Regular Reverse Diffusion')
test(model, device, test_dataloader_list, EbNo_range_test, min_FER=100)
print('Line Search Reverse Diffusion')
model.line_search = True
test(model, device, test_dataloader_list, EbNo_range_test, min_FER=100)

tensor([[[[False,  True,  True,  ...,  True,  True,  True],
          [ True, False,  True,  ...,  True,  True,  True],
          [ True,  True, False,  ...,  True,  True,  True],
          ...,
          [ True,  True,  True,  ..., False,  True,  True],
          [ True,  True,  True,  ...,  True, False,  True],
          [ True,  True,  True,  ...,  True,  True, False]]]])
DDECCT(
  (decoder): Encoder(
    (layers): ModuleList(
      (0-1): 2 x EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=32, out_features=32, bias=True)
          )
          (dropout): Dropout(p=0, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=32, out_features=128, bias=True)
          (w_2): Linear(in_features=128, out_features=32, bias=True)
          (dropout): Dropout(p=0, inplace=False)
        )
        (sublayer): ModuleList(
          (0-1): 2 x SublayerConnect

  self.pid = os.fork()



DDECC Loss
x_0:  torch.Size([128, 63])
t:  torch.Size([65])
t:  torch.Size([128])
e:  torch.Size([128, 63])
noise_factor:  torch.Size([128, 1])
h:  torch.Size([63])
yt:  torch.Size([128, 63])
sum_syndrome:  torch.Size([128])

DDECCT model
magnitude:  torch.Size([128, 63])
syndrome:  torch.Size([128, 18]) tensor([[ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
        [ 1., -1.,  1.,  ..., -1., -1.,  1.],
        [ 1., -1., -1.,  ...,  1., -1., -1.],
        ...,
        [ 1.,  1.,  1.,  ..., -1.,  1.,  1.],
        [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  ...,  1.,  1.,  1.]])
emb:  torch.Size([128, 81, 32])
time_emb:  torch.Size([128, 1, 32])
emb:  torch.Size([128, 81, 32])
emb:  torch.Size([128, 81, 32])
out_fc:  torch.Size([128, 63]) tensor([[ 1.1049,  0.3087,  0.1041,  ...,  0.1831,  1.1800,  0.4571],
        [ 0.8460, -1.2810, -0.3899,  ...,  0.6222,  0.9290, -0.2577],
        [ 0.3626,  1.1882,  0.0508,  ..., -2.5425,  0.4151,  2.8715],
        ...,
        [ 1.0635,

  self.pid = os.fork()


KeyboardInterrupt: 

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense
from gnn_decoder.gnn import UpdateEmbeddings


def train_gnn(model, device, train_loader, epoch, LR, optimizer=None):
    """Train the TensorFlow GNN-BP decoder for one epoch on torch batches.

    Each batch's channel LLRs are converted to TensorFlow and fed to
    ``model``. The output is trained with binary cross-entropy (as logits)
    against the transmitted bits ``x``.

    Args:
        model: Keras model/layer mapping LLRs to per-bit logits.
        device: unused (TensorFlow handles placement); kept for symmetry with ``train``.
        train_loader: DataLoader yielding ``(m, x, z, y, llr, magnitude, syndrome)``.
        epoch: epoch number, used only for logging.
        LR: learning rate used when a new optimizer has to be created.
        optimizer: optional ``tf.keras`` optimizer to reuse across epochs. If
            None, a fresh ``Adam(LR)`` is created for this call. Its moment
            estimates then start from scratch, so pass one in when calling
            this repeatedly.

    Returns:
        Sample-weighted mean training loss over the epoch, as a Python float.
    """
    # model.train()
    cum_loss = cum_samples = 0
    t = time.time()
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=LR)

    for batch_idx, (m, x, z, y, llr, magnitude, syndrome) in enumerate(train_loader):
        # Bridge torch -> TensorFlow explicitly for both inputs and labels.
        # NOTE(review): llr is 2*y/sigma^2 with y = bin_to_sign(x) + noise, while
        # this loss treats the output as the logit of bit == 1. Confirm that
        # GNN_BP's expected LLR sign convention matches.
        llr = tf.convert_to_tensor(llr.numpy(), dtype=tf.float32)
        labels = tf.convert_to_tensor(x.numpy(), dtype=tf.float32)

        with tf.GradientTape() as tape:
            x_hat = model(llr)
            loss = loss_fn(labels, x_hat)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Accumulate as a Python float rather than a growing TF tensor.
        batch_size = x.shape[0]
        cum_loss += float(loss) * batch_size
        cum_samples += batch_size
        if (batch_idx+1) % 10 == 0 or batch_idx == len(train_loader) - 1:
            print(f'Training epoch {epoch}, Batch {batch_idx + 1}/{len(train_loader)}: LR={LR:.2e}, Loss={cum_loss / cum_samples:.5e}')

    print(f'Epoch {epoch} Train Time {time.time() - t}s\n')
    return cum_loss / cum_samples



def test_gnn(model, device, test_loader_list, EbNo_range_test, min_FER=100, max_cum_count=1e7, min_cum_count=1e5):
    """Evaluate the TensorFlow GNN-BP decoder's BER/FER at each test Eb/N0.

    Mirrors ``test`` but the GNN produces a single estimate, so there are no
    per-step (DDPM) statistics and no convergence-iteration report. For every
    loader (paired with ``EbNo_range_test``), batches are decoded until more
    than ``min_FER`` frame errors and more than ``min_cum_count`` samples
    have been seen, or until ``max_cum_count`` samples have been decoded.

    Args:
        model: Keras model/layer mapping LLRs to per-bit logits (as trained by
            ``train_gnn``).
        device: unused; kept for signature compatibility with ``test``.
        test_loader_list: DataLoaders yielding ``(m, x, z, y, llr, magnitude, syndrome)``.
        EbNo_range_test: Eb/N0 labels, one per loader.
        min_FER, max_cum_count, min_cum_count: stopping criteria (see above).

    Returns:
        ``(ber_list, fer_list)`` with one entry per Eb/N0.
    """
    import numpy as np  # used for hard decisions and -ln(BER)

    test_loss_ber_list, test_loss_fer_list, cum_samples_all = [], [], []
    t = time.time()

    for ii, test_loader in enumerate(test_loader_list):
        ebno = EbNo_range_test[ii]
        test_ber = test_fer = cum_count = 0.

        while True:
            (m, x, z, y, llr, magnitude, syndrome) = next(iter(test_loader))

            llr_tf = tf.convert_to_tensor(llr.numpy(), dtype=tf.float32)
            x_hat = model(llr_tf)
            # train_gnn fits the output as the logit of bit == 1, so decide 1 when it is positive.
            # Assumes the output has the same (batch, n) shape as x — TODO confirm.
            x_pred = torch.from_numpy((np.asarray(x_hat) > 0).astype(np.float32))

            test_ber += BER(x_pred, x) * x.shape[0]
            test_fer += FER(x_pred, x) * x.shape[0]
            cum_count += x.shape[0]

            if (min_FER > 0 and test_fer > min_FER and cum_count > min_cum_count) or cum_count >= max_cum_count:
                if cum_count >= max_cum_count:
                    print(f'Cum count reached EbN0:{ebno}')
                else:
                    print(f'FER count treshold reached EbN0:{ebno}')
                break

        cum_samples_all.append(cum_count)
        test_loss_ber_list.append(test_ber / cum_count)
        test_loss_fer_list.append(test_fer / cum_count)
        print(f'Test EbN0={ebno}, BER={test_loss_ber_list[-1]}')
        print(f'Test EbN0={ebno}, FER={test_loss_fer_list[-1]}')

    print('Test FER ' + ' '.join(
        ['{}: {:.2e}'.format(ebno, elem) for (elem, ebno)
         in zip(test_loss_fer_list, EbNo_range_test)]))
    print('Test BER ' + ' '.join(
        ['{}: {:.2e}'.format(ebno, elem) for (elem, ebno)
         in zip(test_loss_ber_list, EbNo_range_test)]))
    print('Test -ln(BER) ' + ' '.join(
        ['{}: {:.2e}'.format(ebno, -np.log(elem)) for (elem, ebno)
         in zip(test_loss_ber_list, EbNo_range_test)]))
    print(f'# of testing samples: {cum_samples_all}\n Test Time {time.time() - t} s\n')
    return test_loss_ber_list, test_loss_fer_list



gnn = GNN_BP(code)

# Standalone GNN run: use an explicit epoch index and the configured base LR
# rather than the loop variable and scheduler state left over from the DDECC
# training cell, whose values depend on how far that cell ran.
gnn_epoch = 1
train_gnn(gnn, device, train_dataloader, gnn_epoch, LR=args.lr)
# test_gnn(gnn, device, test_dataloader_list, EbNo_range_test,min_FER=50,max_cum_count=1e6,min_cum_count=1e4)

Training epoch 1, Batch 10/1000: LR=5.00e-04, Loss=3.24446e-01
Training epoch 1, Batch 20/1000: LR=5.00e-04, Loss=1.93297e-01
Training epoch 1, Batch 30/1000: LR=5.00e-04, Loss=1.36352e-01
Training epoch 1, Batch 40/1000: LR=5.00e-04, Loss=1.05236e-01
Training epoch 1, Batch 50/1000: LR=5.00e-04, Loss=8.57611e-02
Training epoch 1, Batch 60/1000: LR=5.00e-04, Loss=7.24461e-02
Training epoch 1, Batch 70/1000: LR=5.00e-04, Loss=6.27684e-02
Training epoch 1, Batch 80/1000: LR=5.00e-04, Loss=5.54122e-02
Training epoch 1, Batch 90/1000: LR=5.00e-04, Loss=4.96271e-02
Training epoch 1, Batch 100/1000: LR=5.00e-04, Loss=4.49552e-02
Training epoch 1, Batch 110/1000: LR=5.00e-04, Loss=4.11005e-02
Training epoch 1, Batch 120/1000: LR=5.00e-04, Loss=3.78642e-02
Training epoch 1, Batch 130/1000: LR=5.00e-04, Loss=3.51074e-02
Training epoch 1, Batch 140/1000: LR=5.00e-04, Loss=3.27301e-02
Training epoch 1, Batch 150/1000: LR=5.00e-04, Loss=3.06583e-02
Training epoch 1, Batch 160/1000: LR=5.00e-04, Lo