<a href="https://colab.research.google.com/github/pollyjuice74/REU-LDPC-Project/blob/main/LinearTranDiff_on_5G.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Source code
!git clone https://github.com/pollyjuice74/REU-LDPC-Project
!git clone https://github.com/pollyjuice74/DDECC
# Sionna stuff
!git clone https://github.com/NVlabs/sionna.git
!pip install mitsuba
!pip install pythreejs

In [2]:
import tensorflow as tf
import random
import numpy as np
import scipy as sp
import time
import os
from scipy.sparse import issparse, csr_matrix, coo_matrix


# os.chdir('../..')
os.chdir('sionna')
from sionna.fec.ldpc.encoding import * # 5g encoder

from sionna.utils.metrics import compute_ber, compute_bler
from sionna.utils import BitErrorRate, BinarySource, ebnodb2no

from sionna.mapping import Constellation, Mapper, Demapper
from sionna.channel import AWGN
os.chdir('..')

if os.path.exists('REU-LDPC-Project'):
  os.rename('REU-LDPC-Project', 'REU_LDPC_Project')

os.chdir('REU_LDPC_Project/adv_nn')
from attention import *
from channel import *
from transformer import *
from dataset import *
from args import *
from model_functs import *
from models import *
from decoder5G import * # 5g decoder

os.chdir('../..')
from DDECC.src.codes import EbN0_to_std


In [75]:

class TransformerDiffusion( Layer ):
    def __init__(self, args):
        super().__init__()
        self.model_type = args.model_type
        self.n_steps = args.n_steps

        code = args.code
        assert issparse(code.H), "Code's pcm must be sparse."
        self.pcm = code.H

        self.m, self.n = self.pcm.shape
        self.k = self.n - self.m

        self.mask = self.create_mask(self.pcm)
        self.src_embed = tf.Variable( tf.random.uniform([1, self.n + self.m, args.d_model]), trainable=True )
        self.decoder = Transformer(args.d_model, args.heads, self.mask, args.t_layers)
        self.fc = Dense(1)
        self.to_n = Dense(1)
        self.time_embed = Embedding(args.n_steps, args.d_model)

        self.betas = tf.constant( tf.linspace(1e-3, 1e-2, args.n_steps)*0 + args.sigma )
        self.betas_bar = tf.constant( tf.math.cumsum(self.betas, 0) )
        self.ls_active = args.ls_active

    # Extracts noise estimate z_hat from r
    def tran_call(self, r_t, t):
        # Make sure r_t and t are compatible
        r_t = tf.reshape(r_t, (self.n, -1)) # (n,b)
        t = tf.cast(t, dtype=tf.int32)

        # Compute synd and magn
        syndrome = tf.reshape( self.get_syndrome(llr_to_bin(r_t)), (self.pcm.shape[0], -1) ) # (m,n)@(n,b)->(m,b) check nodes
        magnitude = tf.reshape( tf.abs(r_t), (self.n, -1) ) #(n,b) variable nodes
        # make sure their the same dtype
        magnitude, syndrome = [ tf.cast(tensor, dtype=tf.float32) for tensor in [magnitude, syndrome] ]

        # Concatenate synd and magn
        nodes = tf.concat([magnitude, syndrome], axis=0) # data for vertices
        nodes = tf.reshape(nodes, (1, self.pcm.shape[0]+self.n, -1)) # (1, n+m, b)
        print(nodes.shape)

        # Embedding nodes w/ attn and 'time' (sum syn errs) dims
        nodes_emb = tf.reshape( self.src_embed * nodes, (self.src_embed.shape[-1], self.pcm.shape[0]+self.n, -1) ) # (d,n+m,b)
        time_emb = tf.reshape( self.time_embed(t), (self.src_embed.shape[-1], 1, -1) ) # (d,1,b)

        # Applying embeds
        emb_t = time_emb * nodes_emb # (d, n+m, b)
        logits = self.decoder(emb_t) # (d, n+m, d) # TODO: missing batch dims b
        print(emb_t.shape, logits.shape)

        # Reduce (d,n+m,d)->(d,n+m)
        logits = tf.squeeze( self.fc(logits), axis=-1 )
        node_logits = tf.reshape( logits[:, :self.n], (self.n, -1) ) # (n,d) take the first n logits from the concatenation
        # (n,d)->(n,)
        z_hat = self.to_n(node_logits)
        print(logits.shape, z_hat.shape)
        return z_hat

    # optimal lambda l for theoretical and for error prediction
    def line_search(self, r_t, sigma, err_hat, lin_splits=20):
        l_values =  tf.reshape( tf.linspace(1., 20., lin_splits), (1, 1, lin_splits) )
        r_t, sigma, err_hat = [ tf.expand_dims(tensor, axis=-1) for tensor in [r_t, sigma, err_hat] ]# (n,b, 1)

        # Compute theoretical step size w/ ls splits
        z_hat_values = l_values*(sigma*err_hat) # (n,b, l), l is lin_splits
        r_values = llr_to_bin(r_t - z_hat_values) # (n,b, l)
        # sum of synds (m,n)@(n,b*l)->(m,b*l)->(b*l, 1)
        sum_synds = tf.reduce_sum( self.pcm.dot( tf.squeeze(r_values, axis=1) ) % 2, axis=0 )[:, tf.newaxis]
        print(sum_synds.shape)

        # Pick optimal ls value
        if self.model_type=='dis':
             ixs = tf.math.argmin(sum_synds, axis=-1, output_type=tf.int32) # (b,1) w/ ixs of optimal line search for batch b
        elif self.model_type=='gen':
             ixs = tf.math.argmax(sum_synds, axis=-1, output_type=tf.int32) # (b,1)

        print(r_values.shape, z_hat_values.shape)
        # (b, l, n) for indexing on l
        r_values, z_hat_values = [ tf.transpose(tensor, perm=[1,2,0])
                                            for tensor in [r_values, z_hat_values] ]

        # concat range of batch ixs [0,...,n-1] and optimal line search ixs in gather_nd
        indices = tf.concat([ tf.range(ixs.shape[0])[:, tf.newaxis], ixs], axis=-1) # (b,2)

        # print(r_values, z_hat_values, indices)
        # ix on lin_splits w/ gather_nd st. ix,(b, l, n)->(n,b)
        r_t1, z_hat = [ tf.reshape( tf.gather_nd(tensor, indices), (self.n, -1) )
                                             for tensor in [r_values, z_hat_values] ]
        return r_t1, z_hat # r at t-1

    def train(self, c_0, struct_noise=0, sim_ampl=True):
        t = tf.random.uniform( (c_0.shape[0] // 2 + 1,), minval=0,maxval=self.n_steps, dtype=tf.int32 )
        t = tf.concat([t, self.n_steps - t - 1], axis=0)[:c_0.shape[0]] # reshapes t to size x_0
        t = tf.cast(t, dtype=tf.int32)

        noise_factor = tf.math.sqrt( tf.gather(self.betas_bar, t) )
        noise_factor = tf.reshape(noise_factor, (-1, 1))
        z = tf.random.normal(c_0.shape)
        h = np.random.rayleigh(size=c_0.shape)if sim_ampl else 1.

        # added noise to codeword
        c_t = tf.transpose(h * c_0 + struct_noise + (z*noise_factor))
        # calculate sum of syndrome
        t = tf.math.reduce_sum( self.get_syndrome( llr_to_bin(tf.sign(c_t)) ), axis=0 ) # (batch_size, 1)

        z_hat = self.tran_call(c_t, t) # model prediction

        if self.model_type=='dis':
            z_mul = c_t * tf.transpose(c_0) # actual noise added through the channel

        elif self.model_type=='gen':
            c_t += z_hat # could contain positive or negative values
            z_mul = c_t * tf.transpose(c_0) # moidfied channel noise st. it will fool the discriminator

        z_mul = tf.reshape(z_mul, (z_hat.shape[0], -1))
        return z_hat, llr_to_bin(z_mul), c_t

    def create_mask(self, H):
        m,n = H.shape
        mask = tf.eye(n+m, dtype=tf.float32) # (n+m, n+m)
        cn_con, vn_con, _ = sp.sparse.find(H)

        for cn, vn_i in zip(cn_con, vn_con):
            # cn to vn connections in the mask
            mask = tf.tensor_scatter_nd_update(mask, [[n+cn, vn_i],[vn_i, n+cn]], [1.0,1.0])

            # distance 2 vn neighbors of vn_i
            related_vns = vn_con[cn_con==cn]
            for vn_j in related_vns:
                mask = tf.tensor_scatter_nd_update(mask, [[vn_i, vn_j],[vn_j, vn_i]], [1.0,1.0])

        # -infinity where mask is not set
        mask = tf.where(mask == 1.0,
                        mask, -tf.constant(float('inf'), dtype=tf.float32))
        return mask

    def get_sigma(self, t):
        # make sure t is a positive int
        t = tf.cast( tf.abs(t), tf.int32 )
        # gather betas
        betas_t = tf.gather(self.betas, t)
        betas_bar_t = tf.gather(self.betas_bar, t)

        return betas_bar_t * betas_t / (betas_bar_t + betas_t)

    def get_syndrome(self, r_t):
        # Calculate syndrome (pcm @ r = 0) if r is correct in binary
        r_t = tf.reshape(r_t, (self.n, -1)) # (n,b)
        return self.pcm.dot( llr_to_bin( r_t ).numpy() ) % 2 # (m,n)@(n,b)->(m,b)




class Decoder( TransformerDiffusion ):
    def __init__(self, args):
        super().__init__(args)

    # 'test' function
    def call(self, r_t):
       for i in range(self.m):
           print(r_t.shape)
           r_t, z_hat = self.rev_diff_call(r_t) # both (n,)

           # Check if synd is 0 return r_t
           if tf.reduce_sum( self.get_syndrome(r_t) ) == 0:
               return r_t, z_hat, i

       return r_t, z_hat, i

    # Refines recieved codeword r at time t
    def rev_diff_call(self, r_t):
        print("Rev def call...")
        # Make sure r_t and t are compatible
        r_t = tf.reshape(r_t, (self.n, -1)) # (n,b)
        # 'time step' of diffusion is really ix of abs(sum synd errors)
        t = tf.reduce_sum( self.get_syndrome(llr_to_bin(r_t)), axis=0 ) # (m,n)@(n,b)->(m,b)->(1,b)
        t = tf.cast(tf.abs(t), dtype=tf.int32)

        # Transformer error prediction
        z_hat_crude = self.tran_call(r_t, t) # (n,1)

        # Compute diffusion vars
        sigma = self.get_sigma(t) # theoretical step size
        print(r_t.shape, z_hat_crude.shape, r_t.shape)
        err_hat = r_t - tf.sign(z_hat_crude * r_t) # (n,1)

        # Refined estimate of the codeword for the ls diffusion step
        r_t1, z_hat = self.line_search(r_t, sigma, err_hat) if self.ls_active else 1.
        # r_t1[t==0] = r_t[t==0] # if cw has 0 synd. keep as is

        return r_t1, z_hat # r at t-1, both (n,1)





args = Args(model_type='dis', code_type='LDPC5G') # args for decoder/discriminator

# Define enc/dec layers #
enc5G = LDPC5GEncoder(args.k, args.n)
dec5G = LDPC5GDecoder(enc5G, args)

binary_source = BinarySource()

# initialize mapper and demapper for constellation object
constellation = Constellation("qam", num_bits_per_symbol=4)
mapper = Mapper(constellation=constellation)
demapper = Demapper("app", constellation=constellation)

channel = AWGN() # replace w/ Generator(args)

no = ebnodb2no(1, 11, args.k/args.n) # eb_no=1, bps=4
no = tf.expand_dims(tf.cast(no, tf.float32), axis=-1)


# Simulate #
u = binary_source([1, args.k])
c = enc5G(u) # (1,n)

x = mapper(c) # map c to symbols x
y = channel([x, no]) # transmit over AWGN channel
llr_ch = demapper([y, no]) # demap y to LLRs (1,n)

llr_5g = dec5G(llr_ch) # u_hat = dec5G(llr_ch) # run FEC decoder (incl. rate-recovery)

args.code.H = dec5G.pcm
dec = Decoder(args)

dec(llr_5g)



(1, 512)
Rev def call...
(1, 764, 1)
(128, 764, 1) (128, 764, 128)
(128, 764) (512, 1)
(512, 1) (512, 1) (512, 1)
(20, 1)
(512, 1, 20) (512, 1, 20)


InvalidArgumentError: Exception encountered when calling layer 'decoder_32' (type Decoder).

{{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} ConcatOp : Ranks of all input tensors should match: shape[0] = [20,1] vs. shape[1] = [20] [Op:ConcatV2] name: concat

Call arguments received by layer 'decoder_32' (type Decoder):
  • r_t=tf.Tensor(shape=(1, 512), dtype=float32)

In [53]:
cn_con, vn_con, vals = sp.sparse.find(dec5G.pcm)
cn_con, vn_con

H = [[1,0,1,1,1,0,0],
     [0,1,0,1,1,1,0],
     [0,0,1,0,1,1,1]]
H = csr_matrix(H)

def get_syndrome(pcm, r_t):
        # Calculate syndrome (pcm @ r = 0) if r is correct in binary
        r_t = tf.reshape(r_t, (pcm.shape[1], -1)) # (n,b)
        return pcm.dot(llr_to_bin(r_t).numpy()) % 2

get_syndrome(dec5G.pcm, llr_5g)

(252, 1)

In [None]:
class LDPC5GDecoder(LDPCBPDecoder):
    def __init__(self,
                 encoder,
                 args,
                 trainable=False,
                 cn_type='boxplus-phi',
                 hard_out=True,
                 track_exit=False,
                 return_infobits=True,
                 prune_pcm=True,
                 num_iter=20,
                 stateful=False,
                 output_dtype=tf.float32,
                 **kwargs):

        # needs the 5G Encoder to access all 5G parameters
        assert isinstance(encoder, LDPC5GEncoder), 'encoder must \
                          be of class LDPC5GEncoder.'
        self._encoder = encoder
        pcm = encoder.pcm

        assert isinstance(return_infobits, bool), 'return_info must be bool.'
        self._return_infobits = return_infobits

        assert isinstance(output_dtype, tf.DType), \
                                'output_dtype must be tf.DType.'
        if output_dtype not in (tf.float16, tf.float32, tf.float64):
            raise ValueError(
                'output_dtype must be {tf.float16, tf.float32, tf.float64}.')
        self._output_dtype = output_dtype

        assert isinstance(stateful, bool), 'stateful must be bool.'
        self._stateful = stateful

        assert isinstance(prune_pcm, bool), 'prune_pcm must be bool.'
        # prune punctured degree-1 VNs and connected CNs. A punctured
        # VN-1 node will always "send" llr=0 to the connected CN. Thus, this
        # CN will only send 0 messages to all other VNs, i.e., does not
        # contribute to the decoding process.
        self._prune_pcm = prune_pcm
        if prune_pcm:
            # find index of first position with only degree-1 VN
            dv = np.sum(pcm, axis=0) # VN degree
            last_pos = encoder._n_ldpc
            for idx in range(encoder._n_ldpc-1, 0, -1):
                if dv[0, idx]==1:
                    last_pos = idx
                else:
                    break
            # number of filler bits
            k_filler = self.encoder.k_ldpc - self.encoder.k
            # number of punctured bits
            nb_punc_bits = ((self.encoder.n_ldpc - k_filler)
                                     - self.encoder.n - 2*self.encoder.z)
            # effective codeword length after pruning of vn-1 nodes
            self._n_pruned = np.max((last_pos, encoder._n_ldpc - nb_punc_bits))
            self._nb_pruned_nodes = encoder._n_ldpc - self._n_pruned
            # remove last CNs and VNs from pcm
            pcm = pcm[:-self._nb_pruned_nodes, :-self._nb_pruned_nodes]

            #check for consistency
            assert(self._nb_pruned_nodes>=0), "Internal error: number of \
                        pruned nodes must be positive."
        else:
            self._nb_pruned_nodes = 0
            # no pruning; same length as before
            self._n_pruned = encoder._n_ldpc

        # DECODER
        super().__init__(pcm,
                         trainable,
                         cn_type,
                         hard_out,
                         track_exit,
                         num_iter=num_iter,
                         stateful=stateful,
                         output_dtype=output_dtype,
                         **kwargs)
        # args.code.H = pcm
        # self._decoder = Decoder(args)

    #########################################
    # Public methods and properties
    #########################################

    @property
    def encoder(self):
        """LDPC Encoder used for rate-matching/recovery."""
        return self._encoder

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build model."""
        if self._stateful:
            assert(len(input_shape)==2), \
                "For stateful decoding, a tuple of two inputs is expected."
            input_shape = input_shape[0]

        # check input dimensions for consistency
        assert (input_shape[-1]==self.encoder.n), \
                                'Last dimension must be of length n.'
        assert (len(input_shape)>=2), 'The inputs must have at least rank 2.'

        self._old_shape_5g = input_shape

    def call(self, inputs):
        """Iterative BP decoding function.

        This function performs ``num_iter`` belief propagation decoding
        iterations and returns the estimated codeword.

        Args:
            inputs (tf.float32): Tensor of shape `[...,n]` containing the
                channel logits/llr values.

        Returns:
            `tf.float32`: Tensor of shape `[...,n]` or `[...,k]`
            (``return_infobits`` is True) containing bit-wise soft-estimates
            (or hard-decided bit-values) of all codeword bits (or info
            bits, respectively).

        Raises:
            ValueError: If ``inputs`` is not of shape `[batch_size, n]`.

            ValueError: If ``num_iter`` is not an integer greater (or equal)
                `0`.

            InvalidArgumentError: When rank(``inputs``)<2.
        """

        # Extract inputs
        if self._stateful:
            llr_ch, msg_vn = inputs
        else:
            llr_ch = inputs

        tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.')

        llr_ch_shape = llr_ch.get_shape().as_list()
        new_shape = [-1, llr_ch_shape[-1]]
        llr_ch_reshaped = tf.reshape(llr_ch, new_shape)
        batch_size = tf.shape(llr_ch_reshaped)[0]

        # invert if rate-matching output interleaver was applied as defined in
        # Sec. 5.4.2.2 in 38.212
        if self._encoder.num_bits_per_symbol is not None:
            llr_ch_reshaped = tf.gather(llr_ch_reshaped,
                                        self._encoder.out_int_inv,
                                        axis=-1)


        # undo puncturing of the first 2*Z bit positions
        llr_5g = tf.concat(
            [tf.zeros([batch_size, 2*self.encoder.z], self._output_dtype),
                          llr_ch_reshaped],
                          1)

        # undo puncturing of the last positions
        # total length must be n_ldpc, while llr_ch has length n
        # first 2*z positions are already added
        # -> add n_ldpc - n - 2Z punctured positions
        k_filler = self.encoder.k_ldpc - self.encoder.k # number of filler bits
        nb_punc_bits = ((self.encoder.n_ldpc - k_filler)
                                     - self.encoder.n - 2*self.encoder.z)


        llr_5g = tf.concat([llr_5g,
                   tf.zeros([batch_size, nb_punc_bits - self._nb_pruned_nodes],
                            self._output_dtype)],
                            1)

        # undo shortening (= add 0 positions after k bits, i.e. LLR=LLR_max)
        # the first k positions are the systematic bits
        x1 = tf.slice(llr_5g, [0,0], [batch_size, self.encoder.k])

        # parity part
        nb_par_bits = (self.encoder.n_ldpc - k_filler
                       - self.encoder.k - self._nb_pruned_nodes)
        x2 = tf.slice(llr_5g,
                      [0, self.encoder.k],
                      [batch_size, nb_par_bits])

        # negative sign due to logit definition
        z = -tf.cast(self._llr_max, self._output_dtype) \
            * tf.ones([batch_size, k_filler], self._output_dtype)

        llr_5g = tf.concat([x1, z, x2], 1)

        return llr_5g

        # ############################################################
        # # and execute the decoder
        # print(llr_5g.shape)
        # x_hat = self._decoder(llr_5g) #super().call(llr_5g)
        # ############################################################

        # if self._return_infobits: # return only info bits
        #     # reconstruct u_hat # code is systematic
        #     u_hat = tf.slice(x_hat, [0,0], [batch_size, self.encoder.k])
        #     # Reshape u_hat so that it matches the original input dimensions
        #     output_shape = llr_ch_shape[0:-1] + [self.encoder.k]
        #     # overwrite first dimension as this could be None (Keras)
        #     output_shape[0] = -1
        #     u_reshaped = tf.reshape(u_hat, output_shape)

        #     # enable other output datatypes than tf.float32
        #     u_out = tf.cast(u_reshaped, self._output_dtype)

        #     if not self._stateful:
        #         return u_out
        #     else:
        #         return u_out, msg_vn

        # else: # return all codeword bits
        #     # the transmitted CW bits are not the same as used during decoding
        #     # cf. last parts of 5G encoding function

        #     # remove last dim
        #     x = tf.reshape(x_hat, [batch_size, self._n_pruned])

        #     # remove filler bits at pos (k, k_ldpc)
        #     x_no_filler1 = tf.slice(x, [0, 0], [batch_size, self.encoder.k])

        #     x_no_filler2 = tf.slice(x,
        #                             [0, self.encoder.k_ldpc],
        #                             [batch_size,
        #                             self._n_pruned-self.encoder.k_ldpc])

        #     x_no_filler = tf.concat([x_no_filler1, x_no_filler2], 1)

        #     # shorten the first 2*Z positions and end after n bits
        #     x_short = tf.slice(x_no_filler,
        #                        [0, 2*self.encoder.z],
        #                        [batch_size, self.encoder.n])

        #     # if used, apply rate-matching output interleaver again as
        #     # Sec. 5.4.2.2 in 38.212
        #     if self._encoder.num_bits_per_symbol is not None:
        #         x_short = tf.gather(x_short, self._encoder.out_int, axis=-1)

        #     # Reshape x_short so that it matches the original input dimensions
        #     # overwrite first dimension as this could be None (Keras)
        #     llr_ch_shape[0] = -1
        #     x_short= tf.reshape(x_short, llr_ch_shape)

        #     # enable other output datatypes than tf.float32
        #     x_out = tf.cast(x_short, self._output_dtype)

        #     if not self._stateful:
        #         return x_out
        #     else:
        #         return x_out, msg_vn

In [None]:
## DATA ###
EbNo_range_train = range(2, 8)
EbNo_range_test = range(5, 14)
# Standard deviation for train/test
std_train = [EbN0_to_std(ii, args.k / args.n) for ii in EbNo_range_train]
std_test = [EbN0_to_std(ii, args.k / args.n) for ii in EbNo_range_test]

scheduler = tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate=args.lr, decay_steps=args.epochs) # 1000 is size of trainloader
optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler)

train_loader = FEC_Dataset(args.code, sigma=std_test, zero_cw=True, length=args.traindata_len).batch(args.batch_size).shuffle(buffer_size=args.batch_size)

#                         z_cw   m 1s   1-cw     Should use zero codeword by default
dataset_types = {
              "bin_bits":(False, False, False), # Binary bits sent and recieved with some awgn
              "flip_cw": (True, False, True),   # Zero codeword flipped to a all ones vector [1,1,...,1]
              "zero_cw": (True, False, False),  # Standard zero codeword used for training
              "ones_m":  (False, True, False),  # Makes the message all ones vector and passes it to generator matrix producing codeword and pcm
              }

test_ebnos_datasets = [ [FEC_Dataset(args.code, sigma=std_test, zero_cw=zero_cw, ones_m=ones_m, flip_cw=flip_cw)
                                 for ii in range(len(std_test))]
                                 for (zero_cw, ones_m, flip_cw) in dataset_types.values() ]

In [None]:

for epoch in range(1, args.epochs + 1):
    print("Training Linear Transformer Diffusion Model...")
    # train_dec(dec, train_loader, optimizer, epoch,
    #           LR=scheduler(tf.Variable(0, dtype=tf.float32)).numpy(),
    #           traindata_len=args.traindata_len)

    # print comparison
    if epoch % 1 == 0:
        data = test_models(dec, test_ebnos_datasets, EbNo_range_test)
    break # from for loop

data

In [None]:
import matplotlib.pyplot as plt
def plot_comparison(title, EbNo_range, ber, fer, diff_iters):
    plt.figure(figsize=(12, 6))

    # Plot diff_iters to decoding
    plt.subplot(1, 2, 1)
    plt.plot(EbNo_range, diff_iters, marker='o', label='diff_iters to decode', color='blue')
    plt.yscale('log')
    plt.xlabel('Eb/No (dB)')
    plt.ylabel('diff_iters to decode')
    plt.title('diff_iters to decode vs Eb/No')
    plt.grid(True, which="both", ls="--")
    plt.legend()

    # Plot BER for both models
    plt.subplot(1, 2, 2)
    plt.plot(EbNo_range, ber, marker='x', label='BER', color='green')
    plt.yscale('log')
    plt.xlabel('Eb/No (dB)')
    plt.ylabel('BER')
    plt.title('BER vs Eb/No')
    plt.grid(True, which="both", ls="--")
    plt.legend()

    # Plot FER for both models
    plt.subplot(1, 2, 3)
    plt.plot(EbNo_range, fer, marker='o', label='FER', color='red')
    plt.yscale('log')
    plt.xlabel('Eb/No (dB)')
    plt.ylabel('FER')
    plt.title('FER vs Eb/No')
    plt.grid(True, which="both", ls="--")
    plt.legend()

    # Set the overall title for the figure
    plt.suptitle(title)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


for ix, dataset_type in enumerate(dataset_types.keys()):
    ber = data['LTDM'][ix]['ber']
    fer = data['LTDM'][ix]['bler']
    diff_iters = data['LTDM'][ix]['diff_iters']

    plot_comparison(dataset_type.upper(), EbNo_range_test, ber, fer, diff_iters)