In [695]:
from polar_codes.polar_code import PolarCode
from polar_codes.channels.bpsk_awgn_channel import BpskAwgnChannel
import numpy as np
from joblib import Parallel, delayed
from multiprocessing import cpu_count
from tqdm import tqdm
import pickle

In [669]:
channel = BpskAwgnChannel(3)

In [670]:
ebnodb = np.arange(0,6.5, 0.5)
ebno = 10**(ebnodb/10)
snr = ebno
sinr = ebno/2
sinrdb = 10*np.log10(sinr)


In [671]:
sinrdb

array([-3.01029996, -2.51029996, -2.01029996, -1.51029996, -1.01029996,
       -0.51029996, -0.01029996,  0.48970004,  0.98970004,  1.48970004,
        1.98970004,  2.48970004,  2.98970004])

In [672]:
n = 9
K = 256
code = PolarCode(n=n, K=K,
                 construction_method='PW',
                 channel=channel, CRC_len=24)
R = K / 2**n 

115 информационных бит для передачи

In [673]:
u_message = np.asarray([0 if np.random.random_sample() > 0.5 else 1 for _ in range(0, K)], dtype='uint8')
u_message

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0,
       1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,
       0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0,
       0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1,
       0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1,
       1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,
       0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
       1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
       1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1,
       0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1], dtype=uint8)

In [674]:
x_message = code.encode(u_message)

In [675]:
to_message = channel.modulate(x_message)
from_message = channel.transmit(to_message)
y_message = channel.demodulate(from_message)

In [676]:
%%time
scl_u_est_message = code.decode(y_message, decoding_method='SCL', list_size=32)
#scl_u_est_message

CPU times: user 1.17 s, sys: 6.7 ms, total: 1.18 s
Wall time: 1.19 s


In [678]:
class rec_dec:
    def __init__(self, code, y_message, i, u_est, is_calc_llr, llr_array, pm=0, PM=0):
        self.u_est = u_est
        self.is_calc_llr = is_calc_llr
        self.llr_array = llr_array
        self.y_message = y_message
        self.pm = pm #Iddooose metric
        self.layer = i
        self.code = code
        self.PM = PM #paper_metric
    
    def split(self):
        #Создаем 2 копии этого же instance c разными параметрами
        u_right = self.u_est.copy()
        u_opp = self.u_est.copy()
        
        self.llr = code._fast_llr(self.layer, self.y_message, self.u_est[:self.layer], self.llr_array, self.is_calc_llr) 
        
        u_right[self.layer] = 0 if self.llr > 0 else 1
        u_opp[self.layer] = 1 - u_right[self.layer]
        
        pm_copy_r = self.pm
        pm_copy_opp = self.pm
        llr_copy = self.llr
        
        layer_copy = self.layer
        
        #calc 2 PM and provide them
        PM_right = np.log(1+np.exp(-(1-2*u_right[self.layer])*self.llr))
        new_PM_right = self.PM+PM_right
        
        PM_opp = np.log(1+np.exp(-(1-2*u_opp[self.layer])*self.llr)) 
        new_PM_opp = self.PM+PM_opp
        
        right = rec_dec(self.code, self.y_message, layer_copy+1, u_right, self.is_calc_llr.copy(), self.llr_array.copy(), pm_copy_r, new_PM_right)
        opp = rec_dec(self.code, self.y_message, layer_copy+1, u_opp, self.is_calc_llr.copy(), self.llr_array.copy(), pm_copy_opp+np.abs(llr_copy), new_PM_opp)
        
        return [right, opp]
    
    def meet_frozen(self):
        self.llr = code._fast_llr(self.layer, self.y_message, self.u_est[:self.layer], self.llr_array, self.is_calc_llr) 
        self.u_est[self.layer] = 0
        self.pm += np.abs(self.llr) if (self.llr < 0) else 0 
        self.PM += np.log(1+np.exp(-(1-2*self.u_est[self.layer])*self.llr)) 
        self.layer +=1
    
    def get_pm(self):
        return self.pm
    
    def get_PM(self):
        return self.PM
    
    def get_is_calc(self):
        return self.is_calc_llr
    
    def get_l(self):
        return self.layer
    
    def get_u_est(self):
        return self.u_est
    
    def get_llr_array(self):
        return self.llr_array

In [687]:
def scl_dec(code, y_message, list_size=32, train=None): #we should decide here whether to by Indooose Metric or paper's
    u_est = np.full(code._N, -1)
    is_calc_llr = [False] * code._N * (code._n + 1)
    llr_array = np.full(code._N * (code._n + 1), 0.0, dtype=np.longfloat)
    dec_array = [] # will store L decoders always
    a = rec_dec(code, y_message, 0, u_est, is_calc_llr, llr_array)
    dec_array.append(a)
    M = np.zeros(code._N)
    
    bit_for_flip_train = None
    
    for i in range(code._N): #code._N
        
        if i in code._frozen_bits_positions:
            for elem in dec_array:
                 elem.meet_frozen()
                
        else:
            new_arr = []
            for elem in dec_array:
                temp = elem.split()
                new_arr.append(temp[0])
                new_arr.append(temp[1])
            dec_array = new_arr
            
        
        if len(dec_array) > list_size:
            dec_array.sort(key=lambda x: x.get_PM())
            M[i] = np.sum([x.get_PM() for x in dec_array[list_size:]]) - np.sum([x.get_PM() for x in dec_array[:list_size]])
            
        else:
            M[i] = - np.sum([x.get_PM() for x in dec_array[:list_size]])
        
        if len(dec_array) > list_size:      #pruning from 2L to L decoders
            dec_array.sort(key=lambda x: x.get_PM()) # by paper (mb create new decoder parametr)
            
            if (train is not None) and (bit_for_flip_train is None): #should save i, if right solution among discarded
                
                for inst in dec_array[list_size:]:
                    if (train[:i+1] == inst.get_u_est()[:i+1]).all():
                        bit_for_flip_train = i      
                
            dec_array = dec_array[:list_size] # delete discarded paths
            
    return M, bit_for_flip_train

In [680]:
%%time
c, M, bit = scl_dec(code, y_message, 32, code.extend_info_bits(u_message))

CPU times: user 2.46 s, sys: 667 ms, total: 3.12 s
Wall time: 3.15 s


In [681]:
bit

In [682]:
[x.get_PM() for x in c]

[60.657247999446670922,
 75.6749181383022022,
 75.674918138302777915,
 75.674918138302777915,
 75.67491813830310373,
 83.183753206328610315,
 83.18375320651747228,
 83.18375320651747228,
 83.18375320681818031,
 83.18375320681818031,
 83.18375320684624584,
 83.18375320684624584,
 83.183753207105219904,
 83.183753207105219904,
 83.18375320750011411,
 83.18375320750011411,
 83.183753207667399325,
 83.183753207667399325,
 83.183753207667399325,
 83.183753207667399325,
 83.18375320769525464,
 83.18375320769525464,
 83.18375320769525464,
 83.18375320769525464,
 83.18375320769540287,
 83.18375320769540287,
 83.18375320769813853,
 83.18375320769813853,
 83.18375320769813853,
 83.18375320769813853,
 83.18375320769815924,
 83.18375320769815924]

In [683]:
for d in c:
    print(np.abs(u_message - code.get_message_info_bits(d.get_u_est())).sum())

0
4
8
8
4
4
12
16
4
1
4
4
8
4
4
6
16
10
10
4
16
8
16
8
10
4
10
4
6
4
9
4


In [691]:
sinrdb

array([-3.01029996, -2.51029996, -2.01029996, -1.51029996, -1.01029996,
       -0.51029996, -0.01029996,  0.48970004,  0.98970004,  1.48970004,
        1.98970004,  2.48970004,  2.98970004])

In [None]:
data_M_sinr = []
data_bit_sinr []

for sinrdb_i in sinrdb:
    channel = BpskAwgnChannel(sinrdb_i)
    n = 9
    K = 256
    code = PolarCode(n=n, K=K,
                     construction_method='PW',
                     channel=channel, CRC_len=24)

    
    for _ in range(5000):
        u_message = np.asarray([0 if np.random.random_sample() > 0.5 else 1 for _ in range(0, K)], dtype='uint8') # x100
        x_message = code.encode(u_message)
        to_message = channel.modulate(x_message)
        from_message = channel.transmit(to_message)
        y_message = channel.demodulate(from_message)
        M, bit = scl_dec(code, y_message, 32, code.extend_info_bits(u_message))
        data_M_sinr.append(M)
        data_bit_sinr.append(bit)

In [688]:
def calc_M(sinrdb_i):
    channel = BpskAwgnChannel(sinrdb_i)
    n = 9
    K = 256
    code = PolarCode(n=n, K=K,
                     construction_method='PW',
                     channel=channel, CRC_len=24)

    u_message = np.asarray([0 if np.random.random_sample() > 0.5 else 1 for _ in range(0, K)], dtype='uint8') # x100
    x_message = code.encode(u_message)
    to_message = channel.modulate(x_message)
    from_message = channel.transmit(to_message)
    y_message = channel.demodulate(from_message)
    M, bit = scl_dec(code, y_message, 32, code.extend_info_bits(u_message))
    return M, bit

In [690]:
%%time
calc_M(3)

CPU times: user 2.58 s, sys: 702 ms, total: 3.29 s
Wall time: 3.32 s


(array([-6.93147181e-01, -1.38629436e+00, -2.07944154e+00, -2.77758299e+00,
        -3.47073017e+00, -4.16387735e+00, -4.85702453e+00, -5.45525224e+00,
        -6.14839942e+00, -6.84154660e+00, -7.53469378e+00, -8.42764362e+00,
        -9.12079080e+00, -1.00137406e+01, -1.07068878e+01, -1.10454651e+01,
        -1.17386122e+01, -1.24317594e+01, -1.31249066e+01, -1.38180538e+01,
        -1.45112010e+01, -1.48994634e+01, -1.63382798e+01, -1.65632688e+01,
        -1.72564160e+01, -1.79495632e+01, -1.86427104e+01, -1.87775621e+01,
        -1.90162621e+01, -1.90209316e+01, -1.90372746e+01, -1.90372752e+01,
        -1.97304223e+01, -2.04235695e+01, -2.11167167e+01, -2.18098639e+01,
        -2.25030111e+01, -2.27402680e+01, -2.34334152e+01, -2.34518420e+01,
        -2.41449892e+01, -2.48381364e+01, -2.50266712e+01, -2.50366011e+01,
        -2.51118591e+01, -2.51122260e+01, -2.51125489e+01, -2.51125489e+01,
        -2.58056961e+01, -2.81307557e+01, -2.86032952e+01, -2.86571825e+01,
        -2.8

In [701]:
%%time
data = []
for sinrdb_i in sinrdb:
    data.append(Parallel(n_jobs=cpu_count())(delayed(calc_M)(sinrdb_i) for _ in tqdm(range(5000))))
with open('lstm_train.pickle', 'wb') as handle:
    pickle.dump(data, handle)

100%|██████████| 2/2 [00:00<00:00, 1006.43it/s]
100%|██████████| 2/2 [00:00<00:00, 1174.55it/s]
100%|██████████| 2/2 [00:00<00:00, 1901.32it/s]
100%|██████████| 2/2 [00:00<00:00, 2826.35it/s]
100%|██████████| 2/2 [00:00<00:00, 1693.64it/s]
100%|██████████| 2/2 [00:00<00:00, 3010.99it/s]
100%|██████████| 2/2 [00:00<00:00, 1858.77it/s]
100%|██████████| 2/2 [00:00<00:00, 2367.66it/s]
100%|██████████| 2/2 [00:00<00:00, 1796.66it/s]
100%|██████████| 2/2 [00:00<00:00, 3669.56it/s]
100%|██████████| 2/2 [00:00<00:00, 1976.11it/s]
100%|██████████| 2/2 [00:00<00:00, 1798.97it/s]
100%|██████████| 2/2 [00:00<00:00, 1638.08it/s]


CPU times: user 553 ms, sys: 64.9 ms, total: 618 ms
Wall time: 42.7 s
