In [1]:
import torch
import torch.nn as nn
import numpy as np

from polar import PolarCode,get_frozen
from reliability_sequence import Reliability_Sequence
from utils import errors_ber,errors_bler


In [2]:
n = 3
N = 2**n 
K = 3
snr = 1
batch_size = 1
num_samples = 1000


In [3]:

def create_data(num_samples, batch_size, n, K, snr):
    msg_bits_list = []
    bpsk_list = []
    codeword_list = []
    corrupted_codeword_list = []

    # Loop to generate data samples
    for i in range(num_samples):
        msg_bits = (torch.rand(batch_size, K) > 0.5).float()
        bpsk = 1 - 2 * msg_bits

        polar = PolarCode(n, K, Fr=None, use_cuda=True, hard_decision=True)
        codeword = polar.encode(bpsk)
        corrupted_codewords = polar.channel(codeword, snr)

        msg_bits_list.append(msg_bits.cpu().numpy())
        bpsk_list.append(bpsk.cpu().numpy())
        codeword_list.append(codeword.cpu().numpy())
        corrupted_codeword_list.append(corrupted_codewords.cpu().numpy())

    filename = f"polar_dataset_N{2**n}_K{K}_SNR{snr}_bs{batch_size}.npz"
    np.savez(f"polar_dataset_N{2**n}_K{K}_SNR{snr}_bs{batch_size}.npz", msg_bits=msg_bits_list, corrupted_codeword=corrupted_codeword_list,bpsk = bpsk_list,codeword=codeword_list)
    print(f"Dataset saved as {filename}")
    

In [4]:
create_data(num_samples,batch_size,n,K,snr)

Dataset saved as polar_dataset_N8_K3_SNR1_bs1.npz


In [5]:
df = np.load('data\polar_dataset_N8_K3_SNR1.npz')

In [6]:
from tqdm.autonotebook import tqdm
Fr = get_frozen(N, K,rs=Reliability_Sequence)
polar = PolarCode(n, K, Fr = Fr,use_cuda=False,hard_decision=True)
device = 'cpu'
ber_SC_total=0
bler_SC_total=0
x=10000
for bpsk_bits, corrupted_codeword in tqdm(zip(df['bpsk'][:x], df['corrupted_codeword'][:x]),total=len(df['bpsk'][:x])):
    bpsk_tensor = torch.tensor(bpsk_bits, dtype=torch.float32,device=device)
    corrupted_codeword_tensor = torch.tensor(corrupted_codeword, dtype=torch.float32,device=device)

    SC_llrs, decoded_SC_msg_bits = polar.sc_decode_new(corrupted_codeword_tensor, snr=snr)
    ber_SC = errors_ber(bpsk_tensor,decoded_SC_msg_bits.sign()).item()
    bler_SC = errors_bler(bpsk_tensor,decoded_SC_msg_bits.sign()).item()

    ber_SC_total+=ber_SC
    bler_SC_total+=bler_SC

  from tqdm.autonotebook import tqdm
  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:05<00:00, 177.12it/s]


In [7]:
decoded_SC_msg_bits, corrupted_codeword_tensor, bpsk_tensor

(tensor([[ 1.,  1., -1.]]),
 tensor([[-0.5889, -1.5760,  0.5610, -1.0886, -0.2946, -2.1059, -1.6572, -2.8287]]),
 tensor([[ 1.,  1., -1.]]))

In [8]:
corrupted_codeword_tensor.shape

torch.Size([1, 8])

In [9]:
ber_SC_total/len(df['msg_bits'][:x]),bler_SC_total/len(df['msg_bits'][:x])

(0.03466666740179062, 0.057)

### LSTM LOGIC1

In [9]:
class LSTMDecoder(nn.Module):
    def __init__(self,N,K,hidden_size,rs):
        super().__init__()
        self.N = N
        self.K = K
        self.hidden_size = hidden_size

        self.frozen_positions = get_frozen(self.N, self.K, rs)
        self.frozen_mask = torch.full((N,), -1, dtype=torch.int8)
        # self.frozen_mask = torch.zeros(N,dtype=torch.bool)
        self.frozen_mask[self.frozen_positions] = 1
        self.lstms = nn.ModuleList([
            nn.LSTM(input_size=1, hidden_size=hidden_size, num_layers=1, batch_first=True)
            for _ in range(N)
        ])
        self.fcs= nn.ModuleList([
            nn.Linear(hidden_size, 1) for _ in range(N)
        ])
    
    def forward(self,corrupted_codeword):
        batch_size = corrupted_codeword.size(0)
        device = corrupted_codeword.device
        x = corrupted_codeword.unsqueeze(-1) if corrupted_codeword.dim() == 2 else corrupted_codeword

        h0 = [torch.zeros(1, batch_size, self.hidden_size, device=device) for _ in range(self.N)]
        c0 = [torch.zeros(1, batch_size, self.hidden_size, device=device) for _ in range(self.N)]

        decoded_outputs = []
        for i in range(self.N):
            output, (h0[i], c0[i]) = self.lstms[i](x, (h0[i], c0[i]))
            print(output.shape)
            decoded_bits = self.fcs[i](output).squeeze(-1)
            decoded_outputs.append(decoded_bits)
            x = corrupted_codeword + (self.frozen_mask.float().to(device)*decoded_bits.sign()) #current logic
            x = x.unsqueeze(-1)
        
        decoded_outputs = torch.stack(decoded_outputs,dim=1)
        # non_frozen_mask = ~self.frozen_mask
        non_frozen_mask = (self.frozen_mask == -1).nonzero(as_tuple=True)[0]

        final_predictions = decoded_outputs[:,-1,non_frozen_mask]

        return decoded_outputs, final_predictions


### LSTM LOGIC2

In [None]:
class LSTMDecoder2(nn.Module):
    def __init__(self,N,K,hidden_size,rs):
        super().__init__()
        self.N = N
        self.K = K
        self.hidden_size = hidden_size

        self.frozen_positions = get_frozen(self.N, self.K, rs)
        self.frozen_mask = torch.full((N,), 0, dtype=torch.int8)
        # self.frozen_mask = torch.zeros(N,dtype=torch.int8)
        self.frozen_mask[self.frozen_positions] = 1
        self.lstms = nn.ModuleList([
            nn.LSTM(input_size=1, hidden_size=hidden_size, num_layers=1, batch_first=True)
            for _ in range(N)
        ])
        self.fcs= nn.ModuleList([
            nn.Linear(hidden_size+1, 1) for _ in range(N) # 1 for the frozen bit
        ])
    
    def forward(self,corrupted_codeword):
        batch_size = corrupted_codeword.size(0)
        device = corrupted_codeword.device
        x = corrupted_codeword.unsqueeze(-1) if corrupted_codeword.dim() == 2 else corrupted_codeword

        h0 = [torch.zeros(1, batch_size, self.hidden_size, device=device) for _ in range(self.N)]
        c0 = [torch.zeros(1, batch_size, self.hidden_size, device=device) for _ in range(self.N)]

        decoded_outputs = []
        for i in range(self.N):
            lstm_output, (h0[i], c0[i]) = self.lstms[i](x, (h0[i], c0[i]))

            frozen_bit = self.frozen_mask[i].float().to(device) 
            print(frozen_bit)
            frozen_bit = frozen_bit.unsqueeze(0).unsqueeze(0).expand(batch_size, lstm_output.size(1), -1) 
            print(frozen_bit.shape)           
            lstm_output_with_frozen = torch.cat([lstm_output, frozen_bit], dim=-1)
            print(lstm_output_with_frozen.shape)
            predicted_bit = self.fcs[i](lstm_output_with_frozen).squeeze(-1)
            print(predicted_bit.shape)
            decoded_outputs.append(predicted_bit)
            # x = corrupted_codeword + (self.frozen_mask.float().to(device)*decoded_bits.sign()) #current logic
            x = torch.cat([x, predicted_bit.unsqueeze(-1)], dim=-1)[:,:,-1:]
            print(F'x_{x.shape}')
        
        decoded_outputs = torch.stack(decoded_outputs,dim=1)
        # non_frozen_mask = ~self.frozen_mask
        non_frozen_mask = (self.frozen_mask == 0).nonzero(as_tuple=True)[0]

        final_predictions = decoded_outputs[:,-1,non_frozen_mask]

        return decoded_outputs, final_predictions


In [59]:
device = 'cuda'
model = LSTMDecoder2(N=N,K=K,hidden_size=32,rs=Reliability_Sequence).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [60]:
def calculate_ber_bler(predictions, targets):
    """ Calculate BER and BLER using errors_ber and errors_bler functions. """
    ber = errors_ber(targets, predictions.sign()).item()
    bler = errors_bler(targets, predictions.sign()).item()
    return ber, bler

In [61]:
num_epochs=30
data_len = len(df['msg_bits'][:900])
for epoch in range(num_epochs):
    total_loss = 0
    total_ber,total_bler =0,0
    for msg_bits, corrupted_codeword in tqdm(zip(df['msg_bits'][:900], df['corrupted_codeword'][:900]),total=data_len):
        msg_tensor = torch.tensor(msg_bits, dtype=torch.float32).to(device)
        corrupted_codeword_tensor = torch.tensor(corrupted_codeword, dtype=torch.float32).to(device)
        optimizer.zero_grad()
        decoded_outputs, final_predictions = model(corrupted_codeword_tensor)
        loss = criterion(final_predictions,msg_tensor)

        loss.backward()
        optimizer.step()
        total_loss+= loss.item()

        ber, bler = calculate_ber_bler((final_predictions>=0).float(), msg_tensor)
        total_ber += ber
        total_bler += bler
    avg_loss = total_loss / data_len
    avg_ber = total_ber / data_len
    avg_bler = total_bler / data_len
    if ((epoch+1) % 5 == 0):
         print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, BER: {avg_ber:.4f}, BLER: {avg_bler:.4f}")

100%|██████████| 900/900 [00:12<00:00, 74.35it/s]
100%|██████████| 900/900 [00:12<00:00, 73.56it/s]
100%|██████████| 900/900 [00:11<00:00, 79.73it/s]
100%|██████████| 900/900 [00:11<00:00, 81.09it/s]
100%|██████████| 900/900 [00:11<00:00, 79.08it/s]


Epoch 5/30, Loss: 0.6933, BER: 0.4981, BLER: 0.8867


100%|██████████| 900/900 [00:11<00:00, 80.25it/s]
100%|██████████| 900/900 [00:11<00:00, 79.94it/s]
100%|██████████| 900/900 [00:10<00:00, 82.36it/s]
100%|██████████| 900/900 [00:10<00:00, 83.56it/s]
100%|██████████| 900/900 [00:10<00:00, 83.60it/s]


Epoch 10/30, Loss: 0.6933, BER: 0.4930, BLER: 0.8811


100%|██████████| 900/900 [00:11<00:00, 81.28it/s]
100%|██████████| 900/900 [00:10<00:00, 82.32it/s]
100%|██████████| 900/900 [00:11<00:00, 81.55it/s]
100%|██████████| 900/900 [00:10<00:00, 82.09it/s]
100%|██████████| 900/900 [00:10<00:00, 82.65it/s]


Epoch 15/30, Loss: 0.6933, BER: 0.4930, BLER: 0.8811


100%|██████████| 900/900 [00:11<00:00, 81.69it/s]
100%|██████████| 900/900 [00:10<00:00, 83.09it/s]
100%|██████████| 900/900 [00:10<00:00, 83.30it/s]
100%|██████████| 900/900 [00:10<00:00, 83.00it/s]
100%|██████████| 900/900 [00:10<00:00, 83.31it/s]


Epoch 20/30, Loss: 0.6932, BER: 0.4930, BLER: 0.8811


100%|██████████| 900/900 [00:11<00:00, 81.66it/s]
100%|██████████| 900/900 [00:10<00:00, 82.09it/s]
100%|██████████| 900/900 [00:11<00:00, 78.16it/s]
100%|██████████| 900/900 [00:11<00:00, 79.29it/s]
100%|██████████| 900/900 [00:11<00:00, 78.92it/s]


Epoch 25/30, Loss: 0.6932, BER: 0.4930, BLER: 0.8811


100%|██████████| 900/900 [00:11<00:00, 79.10it/s]
100%|██████████| 900/900 [00:11<00:00, 78.89it/s]
100%|██████████| 900/900 [00:11<00:00, 79.57it/s]
100%|██████████| 900/900 [00:11<00:00, 80.91it/s]
100%|██████████| 900/900 [00:11<00:00, 79.93it/s]

Epoch 30/30, Loss: 0.6932, BER: 0.4930, BLER: 0.8811





In [62]:
ber_total,bler_total = 0,0
test_loader = df['msg_bits'][900:]
count = 0
with torch.no_grad():
    for msg_bits, corrupted_codeword in zip(df['msg_bits'][900:], df['corrupted_codeword'][900:]):
        count += 1
        msg_tensor = torch.tensor(msg_bits, dtype=torch.float32).to(device)
        corrupted_codeword_tensor = torch.tensor(corrupted_codeword, dtype=torch.float32).to(device)

        _,final_predictions = model(corrupted_codeword_tensor)
        if (count%20==0):
            print(f'final_predictions {(final_predictions>=0).float()}')
            print(f'msg_tensor {msg_tensor}')
            if((final_predictions>0).float()==msg_tensor).all():
                print('decoded correctly')
            else:
                print('decoded incorrectly')
            print('----------')
        ber,bler = calculate_ber_bler((final_predictions>=0).float(),msg_tensor)
        ber_total += ber
        bler_total += bler
    avg_ber = ber_total / len(test_loader)
    avg_bler = bler_total / len(test_loader)
    print(f"Test Results - BER: {avg_ber:.4f}, BLER: {avg_bler:.4f}")

final_predictions tensor([[0., 0., 0.]], device='cuda:0')
msg_tensor tensor([[1., 0., 1.]], device='cuda:0')
decoded incorrectly
----------
final_predictions tensor([[0., 0., 0.]], device='cuda:0')
msg_tensor tensor([[1., 1., 0.]], device='cuda:0')
decoded incorrectly
----------
final_predictions tensor([[0., 0., 0.]], device='cuda:0')
msg_tensor tensor([[0., 1., 0.]], device='cuda:0')
decoded incorrectly
----------
final_predictions tensor([[0., 0., 0.]], device='cuda:0')
msg_tensor tensor([[0., 1., 0.]], device='cuda:0')
decoded incorrectly
----------
final_predictions tensor([[0., 0., 0.]], device='cuda:0')
msg_tensor tensor([[0., 0., 1.]], device='cuda:0')
decoded incorrectly
----------
Test Results - BER: 0.4967, BLER: 0.8700
