In [None]:
%matplotlib inline
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Notebook-level device flag: CUDA is used only when it is both requested here
# and reported as available by torch.
# NOTE(review): the model classes in this notebook pick their device from
# args.no_cuda, not from this flag.
try_cuda = True
use_cuda = torch.cuda.is_available() if try_cuda else False

In [None]:
class SameShapeConv1d(torch.nn.Module):
    """Stack of stride-1 ``Conv1d`` layers applied along axis 1 of the input.

    ``forward`` swaps axes 1 and 2 before the convolutions and swaps them back
    afterwards, so callers pass and receive tensors laid out as
    (batch, length, channels).

    Every layer pads ``kernel_size // 2`` zeros on both sides. With an odd
    ``kernel_size`` the length is preserved; an even ``kernel_size`` makes each
    layer's output one element longer than its input.

    Args:
        num_layer: number of stacked convolution layers.
        in_channels: channel count of the input to the first layer.
        out_channels: channel count produced by every layer.
        kernel_size: kernel width shared by all layers.
        activation: one of 'elu', 'relu', 'selu', 'prelu'; any other value
            falls back to ELU.
        no_act: when True the activation is skipped after every layer.
    """

    def __init__(self, num_layer, in_channels, out_channels, kernel_size, activation = 'elu', no_act = False):
        super(SameShapeConv1d, self).__init__()

        self.num_layer = num_layer
        self.no_act = no_act

        self.cnns = torch.nn.ModuleList()
        for idx in range(num_layer):
            # Only the first layer sees the raw input width; the rest chain
            # out_channels -> out_channels.
            self.cnns.append(torch.nn.Conv1d(in_channels=in_channels if idx == 0 else out_channels,
                                             out_channels=out_channels,
                                             kernel_size=kernel_size, stride=1, padding=(kernel_size // 2),
                                             dilation=1, groups=1, bias=True)
            )

        if activation == 'prelu':
            # F.prelu requires an explicit weight tensor, so calling it with the
            # input alone raises TypeError. Use the module form, which owns a
            # learnable slope parameter.
            self.activation = torch.nn.PReLU()
        else:
            functional_acts = {'elu': F.elu, 'relu': F.relu, 'selu': F.selu}
            # Unknown names fall back to ELU.
            self.activation = functional_acts.get(activation, F.elu)

    def forward(self, inputs):
        """Run all layers on ``inputs`` (batch, length, channels) and return the same layout."""
        # Conv1d convolves over the last axis, so move the length axis there.
        x = torch.transpose(inputs, 1, 2)
        for idx in range(self.num_layer):
            x = self.cnns[idx](x)
            if not self.no_act:
                x = self.activation(x)

        outputs = torch.transpose(x, 1, 2)
        return outputs

class ENCBase(torch.nn.Module):
    """Common encoder base: device selection, output activation and code normalization.

    The ``args`` object is read for ``no_cuda``, ``enc_act``, ``no_code_norm``,
    ``precompute_norm_stats``, ``train_channel_mode`` and ``enc_truncate_limit``.
    """

    def __init__(self, args):
        super(ENCBase, self).__init__()

        use_cuda = not args.no_cuda and torch.cuda.is_available()
        self.this_device = torch.device("cuda" if use_cuda else "cpu")

        self.args = args
        self.reset_precomp()

    def set_parallel(self):
        """No-op in the base class."""
        pass

    def set_precomp(self, mean_scalar, std_scalar):
        """Install externally computed mean/std tensors, moved to this module's device."""
        self.mean_scalar = mean_scalar.to(self.this_device)
        self.std_scalar  = std_scalar.to(self.this_device)

    # not tested yet
    def reset_precomp(self):
        """Reset the running statistics to mean 0, std 1 and clear the block counter."""
        self.mean_scalar = torch.zeros(1).type(torch.FloatTensor).to(self.this_device)
        self.std_scalar  = torch.ones(1).type(torch.FloatTensor).to(self.this_device)
        self.num_test_block= 0.0

    def enc_act(self, inputs):
        """Apply the activation named by ``args.enc_act``.

        'linear' and any unrecognized name return ``inputs`` unchanged.
        """
        if self.args.enc_act == 'tanh':
            # torch.tanh / torch.sigmoid instead of the legacy F.tanh / F.sigmoid
            # wrappers, which emit a deprecation warning on many PyTorch releases.
            return torch.tanh(inputs)
        elif self.args.enc_act == 'elu':
            return F.elu(inputs)
        elif self.args.enc_act == 'relu':
            return F.relu(inputs)
        elif self.args.enc_act == 'selu':
            return F.selu(inputs)
        elif self.args.enc_act == 'sigmoid':
            return torch.sigmoid(inputs)
        elif self.args.enc_act == 'linear':
            return inputs
        else:
            return inputs

    def power_constraint(self, x_input):
        """Normalize ``x_input`` to zero mean and unit std over the whole tensor.

        Returns ``x_input`` untouched when ``args.no_code_norm`` is set.
        Otherwise the statistics come either from this call's tensor or, when
        ``args.precompute_norm_stats`` is set, from a running average over all
        calls since the last ``reset_precomp``. The result is optionally
        quantized and clamped.
        """

        if self.args.no_code_norm:
            return x_input
        else:
            # Statistics are taken over every element (batch, position and channel together).
            this_mean    = torch.mean(x_input)
            this_std     = torch.std(x_input)

            if self.args.precompute_norm_stats:
                # Incremental average: old_avg * (n - 1) + new, divided by n.
                self.num_test_block += 1.0
                self.mean_scalar = (self.mean_scalar*(self.num_test_block-1) + this_mean)/self.num_test_block
                self.std_scalar  = (self.std_scalar*(self.num_test_block-1) + this_std)/self.num_test_block
                x_input_norm = (x_input - self.mean_scalar)/self.std_scalar
            else:
                x_input_norm = (x_input-this_mean)*1.0 / this_std

            if self.args.train_channel_mode == 'block_norm_ste':
                # NOTE(review): STEQuantize is not defined or imported in the visible
                # source; this branch raises NameError unless it is provided elsewhere.
                stequantize = STEQuantize.apply
                x_input_norm = stequantize(x_input_norm, self.args)

            if self.args.enc_truncate_limit>0:
                x_input_norm = torch.clamp(x_input_norm, -self.args.enc_truncate_limit, self.args.enc_truncate_limit)

            return x_input_norm

class ENC_interCNN(ENCBase):
    """Three-branch CNN encoder whose third branch sees an interleaved copy of the input.

    Each branch is a ``SameShapeConv1d`` stack followed by a ``Linear`` layer
    with one output per position. The three outputs are concatenated on the
    last axis and passed through ``power_constraint``.

    Args:
        args: configuration object; ``encoder`` must equal 'TurboAE_rate3_cnn'.
            Also read: ``enc_num_layer``, ``code_rate_k``, ``enc_num_unit``,
            ``enc_kernel_size``, ``is_variable_block_len``, ``is_interleave``
            and the fields used by ``ENCBase``.
        p_array: index array handed to ``Interleaver``.

    Raises:
        AssertionError: if ``args.encoder`` is any other value.
    """

    def __init__(self, args, p_array):
        # turbofy only for code rate 1/3
        super(ENC_interCNN, self).__init__(args)
        self.args             = args

        # Encoder
        if self.args.encoder == 'TurboAE_rate3_cnn':
            self.enc_cnn_1       = SameShapeConv1d(num_layer=args.enc_num_layer, in_channels=args.code_rate_k,
                                                      out_channels= args.enc_num_unit, kernel_size = args.enc_kernel_size)

            self.enc_cnn_2       = SameShapeConv1d(num_layer=args.enc_num_layer, in_channels=args.code_rate_k,
                                                      out_channels= args.enc_num_unit, kernel_size = args.enc_kernel_size)

            self.enc_cnn_3       = SameShapeConv1d(num_layer=args.enc_num_layer, in_channels=args.code_rate_k,
                                                      out_channels= args.enc_num_unit, kernel_size = args.enc_kernel_size)
        else: # Dense
            raise AssertionError("SNH")


        # One scalar per position from each branch.
        self.enc_linear_1    = torch.nn.Linear(args.enc_num_unit, 1)
        self.enc_linear_2    = torch.nn.Linear(args.enc_num_unit, 1)
        self.enc_linear_3    = torch.nn.Linear(args.enc_num_unit, 1)

        self.interleaver      = Interleaver(args, p_array)


    def set_interleaver(self, p_array):
        """Replace the interleaver's index array."""
        self.interleaver.set_parray(p_array)

    def set_parallel(self):
        """Wrap every CNN stack and linear layer in ``torch.nn.DataParallel``."""
        self.enc_cnn_1 = torch.nn.DataParallel(self.enc_cnn_1)
        self.enc_cnn_2 = torch.nn.DataParallel(self.enc_cnn_2)
        self.enc_cnn_3 = torch.nn.DataParallel(self.enc_cnn_3)
        self.enc_linear_1 = torch.nn.DataParallel(self.enc_linear_1)
        self.enc_linear_2 = torch.nn.DataParallel(self.enc_linear_2)
        self.enc_linear_3 = torch.nn.DataParallel(self.enc_linear_3)

    def forward(self, inputs):
        """Encode ``inputs`` (batch, length, code_rate_k) into codes of shape (batch, length, 3)."""

        if self.args.is_variable_block_len:
            block_len = inputs.shape[1]
            # reset interleaver
            if self.args.is_interleave != 0:
                # The seed is drawn from [0, is_interleave), so is_interleave == 1
                # always yields seed 0 and therefore the same permutation.
                seed = np.random.randint(0, self.args.is_interleave)
                # The original referenced bare `mtrand` / `arange`, which are not
                # imported in this notebook; go through the `np` namespace instead.
                rand_gen = np.random.mtrand.RandomState(seed)
                p_array = rand_gen.permutation(np.arange(block_len))
                self.set_interleaver(p_array)

        # Affine rescale x -> 2x - 1 (maps 0/1 to -1/+1 if the input is binary).
        inputs     = 2.0*inputs - 1.0
        x_sys      = self.enc_cnn_1(inputs)
        x_sys      = self.enc_act(self.enc_linear_1(x_sys))

        x_p1       = self.enc_cnn_2(inputs)
        x_p1       = self.enc_act(self.enc_linear_2(x_p1))

        # Third branch encodes the interleaved sequence.
        x_sys_int  = self.interleaver(inputs)
        x_p2       = self.enc_cnn_3(x_sys_int)
        x_p2       = self.enc_act(self.enc_linear_3(x_p2))

        x_tx       = torch.cat([x_sys,x_p1, x_p2], dim = 2)

        codes = self.power_constraint(x_tx)

        return codes

class DEC_LargeCNN(torch.nn.Module):
    """Iterative two-stage CNN decoder with interleaving between the stages.

    For each of ``args.num_iteration`` iterations there is a pair of CNN stacks
    (``dec1_cnns[i]``, ``dec2_cnns[i]``) and a pair of linear heads. Every head
    outputs ``args.num_iter_ft`` features per position, except the second head
    of the last iteration, which outputs one value per position. ``forward``
    passes that value through a sigmoid.

    Args:
        args: configuration object; ``encoder`` must equal 'TurboAE_rate3_cnn'.
            Also read: ``no_cuda``, ``num_iteration``, ``num_iter_ft``,
            ``dec_num_layer``, ``dec_num_unit``, ``dec_kernel_size``,
            ``extrinsic``, ``is_variable_block_len``, ``is_interleave``,
            ``block_len``.
        p_array: index array shared by the interleaver and deinterleaver.

    Raises:
        AssertionError: if ``args.encoder`` is any other value.
    """

    def __init__(self, args, p_array):
        super(DEC_LargeCNN, self).__init__()
        self.args = args

        use_cuda = not args.no_cuda and torch.cuda.is_available()
        self.this_device = torch.device("cuda" if use_cuda else "cpu")

        self.interleaver          = Interleaver(args, p_array)
        self.deinterleaver        = DeInterleaver(args, p_array)

        self.dec1_cnns      = torch.nn.ModuleList()
        self.dec2_cnns      = torch.nn.ModuleList()
        self.dec1_outputs   = torch.nn.ModuleList()
        self.dec2_outputs   = torch.nn.ModuleList()

        if self.args.encoder == 'TurboAE_rate3_cnn':
            CNNLayer = SameShapeConv1d
        else:
            raise AssertionError("SNH")


        for idx in range(args.num_iteration):
            # Each stage input is two received streams plus num_iter_ft prior features.
            self.dec1_cnns.append(CNNLayer(num_layer=args.dec_num_layer, in_channels=2 + args.num_iter_ft,
                                                  out_channels= args.dec_num_unit, kernel_size = args.dec_kernel_size)
            )

            self.dec2_cnns.append(CNNLayer(num_layer=args.dec_num_layer, in_channels=2 + args.num_iter_ft,
                                                  out_channels= args.dec_num_unit, kernel_size = args.dec_kernel_size)
            )
            self.dec1_outputs.append(torch.nn.Linear(args.dec_num_unit, args.num_iter_ft))

            if idx == args.num_iteration -1:
                # The last head produces the single per-position value returned by forward.
                self.dec2_outputs.append(torch.nn.Linear(args.dec_num_unit, 1))
            else:
                self.dec2_outputs.append(torch.nn.Linear(args.dec_num_unit, args.num_iter_ft))

    def set_parallel(self):
        """Wrap every per-iteration CNN stack and linear head in ``torch.nn.DataParallel``."""
        for idx in range(self.args.num_iteration):
            self.dec1_cnns[idx] = torch.nn.DataParallel(self.dec1_cnns[idx])
            self.dec2_cnns[idx] = torch.nn.DataParallel(self.dec2_cnns[idx])
            self.dec1_outputs[idx] = torch.nn.DataParallel(self.dec1_outputs[idx])
            self.dec2_outputs[idx] = torch.nn.DataParallel(self.dec2_outputs[idx])


    def set_interleaver(self, p_array):
        """Replace the index array of both the interleaver and the deinterleaver."""
        self.interleaver.set_parray(p_array)
        self.deinterleaver.set_parray(p_array)

    def _decode_stage_pair(self, idx, r_sys, r_par1, r_sys_int, r_par2, prior):
        """Run stage 1 and stage 2 of iteration ``idx``.

        Returns ``(x_plr, x_plr_int)``: the raw output of the second head and
        the interleaved stage-1 output that was fed into stage 2.
        """
        x_this_dec = torch.cat([r_sys, r_par1, prior], dim = 2)
        x_dec      = self.dec1_cnns[idx](x_this_dec)
        x_plr      = self.dec1_outputs[idx](x_dec)

        if self.args.extrinsic:
            x_plr = x_plr - prior

        x_plr_int  = self.interleaver(x_plr)

        x_this_dec = torch.cat([r_sys_int, r_par2, x_plr_int ], dim = 2)
        x_dec      = self.dec2_cnns[idx](x_this_dec)
        x_plr      = self.dec2_outputs[idx](x_dec)

        return x_plr, x_plr_int

    def forward(self, received):
        """Decode ``received`` (batch, block_len, 3) into values in (0, 1) of shape (batch, block_len, 1)."""

        if self.args.is_variable_block_len:
            block_len = received.shape[1]
            # reset interleaver
            if self.args.is_interleave != 0:
                # The seed is drawn from [0, is_interleave), so is_interleave == 1
                # always yields seed 0 and therefore the same permutation.
                seed = np.random.randint(0, self.args.is_interleave)
                # The original referenced bare `mtrand` / `arange`, which are not
                # imported in this notebook; go through the `np` namespace instead.
                rand_gen = np.random.mtrand.RandomState(seed)
                p_array = rand_gen.permutation(np.arange(block_len))
                self.set_interleaver(p_array)
        else:
            block_len = self.args.block_len

        # Single conversion to float32 on the target device (no forced CPU hop).
        received = received.to(device=self.this_device, dtype=torch.float32)
        # Use the batch size of the tensor actually passed in, so a batch whose
        # size differs from args.batch_size (e.g. a final partial batch) still decodes.
        batch_size = received.shape[0]

        # Turbo Decoder
        r_sys     = received[:,:,0].view((batch_size, block_len, 1))
        r_sys_int = self.interleaver(r_sys)
        r_par1    = received[:,:,1].view((batch_size, block_len, 1))
        r_par2    = received[:,:,2].view((batch_size, block_len, 1))

        #num_iteration,
        prior = torch.zeros((batch_size, block_len, self.args.num_iter_ft)).to(self.this_device)

        last_idx = self.args.num_iteration - 1

        for idx in range(last_idx):
            x_plr, x_plr_int = self._decode_stage_pair(idx, r_sys, r_par1, r_sys_int, r_par2, prior)

            if self.args.extrinsic:
                x_plr = x_plr - x_plr_int

            prior      = self.deinterleaver(x_plr)

        # last round: the second head emits one value per position and no
        # extrinsic subtraction is applied to it.
        x_plr, _   = self._decode_stage_pair(last_idx, r_sys, r_par1, r_sys_int, r_par2, prior)

        final      = torch.sigmoid(self.deinterleaver(x_plr))

        return final

class Interleaver(torch.nn.Module):
    """Reorders a 3-D tensor along axis 1 according to a stored index array.

    Output position ``i`` along axis 1 holds input position ``p_array[i]``.
    """

    def __init__(self, args, p_array):
        super(Interleaver, self).__init__()
        self.args = args
        self.set_parray(p_array)

    def set_parray(self, p_array):
        """Store ``p_array`` as a flat LongTensor of length ``len(p_array)``."""
        num_positions = len(p_array)
        self.p_array = torch.LongTensor(p_array).view(num_positions)

    def forward(self, inputs):
        """Return ``inputs`` with axis 1 reordered by the stored indices."""
        # Tensor indexing selects along axis 0, so bring axis 1 to the front,
        # pick the rows, then restore the original axis order.
        axis1_first = inputs.permute(1, 0, 2)
        reordered = axis1_first[self.p_array]
        return reordered.permute(1, 0, 2)


class DeInterleaver(torch.nn.Module):
    """Reorders a 3-D tensor along axis 1 with the inverse of a given index array.

    For an index array ``p`` the stored array ``r`` satisfies ``r[p[i]] == i``.
    """

    def __init__(self, args, p_array):
        super(DeInterleaver, self).__init__()
        self.args = args
        self.set_parray(p_array)

    def set_parray(self, p_array):
        """Build and store the inverse index array of ``p_array`` as a flat LongTensor."""
        num_positions = len(p_array)
        inverse = [0] * num_positions
        for position, source in enumerate(p_array):
            # p_array sends `source` to `position`; the inverse sends it back.
            inverse[source] = position

        self.reverse_p_array = torch.LongTensor(inverse).view(num_positions)

    def forward(self, inputs):
        """Return ``inputs`` with axis 1 reordered by the stored inverse indices."""
        # Tensor indexing selects along axis 0, so bring axis 1 to the front,
        # pick the rows, then restore the original axis order.
        axis1_first = inputs.permute(1, 0, 2)
        restored = axis1_first[self.reverse_p_array]
        return restored.permute(1, 0, 2)

In [None]:
# Number of positions permuted by the interleaver index array built below.
block_len = 100

# np.random.randint(0, 1) samples from the half-open range [0, 1), so seed is
# always 0 and the permutation is the same on every run. The call still
# advances NumPy's global random state.
seed = np.random.randint(0, 1)
rand_gen = np.random.mtrand.RandomState(seed)
# Permutation of 0..block_len-1, passed to the encoder/decoder and the
# Interleaver/DeInterleaver instances in later cells.
p_array1 = rand_gen.permutation(np.arange(block_len))

print('using random interleaver', p_array1)

Reference listing of argument values as `name=value` pairs (the `args` object used by the cells below is loaded from the file `pickelArgs`):

batch_size=100, bce_lambda=1.0, bec_p=0.0, bec_p_dec=0.0, ber_lambda=1.0, block_len=100, block_len_high=200,
block_len_low=10, bsc_p=0.0, bsc_p_dec=0.0, channel='awgn', code_rate_k=1, code_rate_n=3, dec_act='linear',
dec_kernel_size=5, dec_lr=0.001, dec_num_layer=5, dec_num_unit=100, dec_rnn='gru', decoder='TurboAE_rate3_cnn2d',
demod_lr=0.005, demod_num_layer=1, demod_num_unit=20, dropout=0.0, enc_act='elu', enc_clipping='both', enc_grad_limit=0.01,
enc_kernel_size=5, enc_lr=0.001, enc_num_layer=2, enc_num_unit=100, enc_quantize_level=2, enc_rnn='gru', enc_truncate_limit=0,
enc_value_limit=1.0, encoder='TurboAE_rate3_cnn2d', extrinsic=1, focal_alpha=1.0, focal_gamma=0.0, img_size=10,
init_nw_weight='default', is_interleave=1, is_k_same_code=False, is_parallel=0, is_same_interleaver=1,
is_variable_block_len=False, joint_train=0, k_same_code=2, lambda_maxBCE=0.01, loss='bce', mod_lr=0.005,
mod_num_layer=1, mod_num_unit=20, mod_pc='block_power', mod_rate=2, momentum=0.9, no_code_norm=False,
no_cuda=False, num_ber_puncture=5, num_block=1000, num_epoch=1, num_iter_ft=5, num_iteration=6, num_train_dec=5,
num_train_demod=5, num_train_enc=1, num_train_mod=1, optimizer='adam', precompute_norm_stats=False,
print_pos_ber=False, print_pos_power=False, print_test_traj=False, radar_power=5.0, radar_prob=0.05,
rec_quantize=False, rec_quantize_level=2, rec_quantize_limit=1.0, snr_points=12, snr_test_end=4.0,
snr_test_start=-1.5, test_channel_mode='block_norm', test_ratio=1, train_channel_mode='block_norm',
train_dec_channel_high=2.0, train_dec_channel_low=-1.5, train_enc_channel_high=1.0, train_enc_channel_low=1.0, vv=5

In [None]:
import pickle

# Load the configuration object consumed by the encoder/decoder classes from
# the local file 'pickelArgs' (path is relative to the working directory).
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only open a file that comes from a trusted source.
# Presumably the object is an argparse.Namespace-like container - confirm.
with open('pickelArgs', 'rb') as handle:
    args = pickle.load(handle)

In [None]:
# Show the loaded code_rate_k; ENC_interCNN uses it as in_channels of its CNN stacks.
args.code_rate_k

In [None]:
# Build the encoder and decoder with the same index array so their
# interleavers agree. Both constructors raise AssertionError("SNH") unless
# args.encoder == 'TurboAE_rate3_cnn'.
encoder = ENC_interCNN(args, p_array1)
decoder = DEC_LargeCNN(args, p_array1)

In [None]:
# Test tensor of shape (batch_size, block_len, code_rate_k) in which every
# element is a distinct integer, so any reordering is directly visible.
interleaverTester = torch.arange(0, args.batch_size * args.block_len * args.code_rate_k).view(args.batch_size, args.block_len, args.code_rate_k)
interleaverTester.shape

In [None]:
# Reorder the test tensor along axis 1 with a standalone Interleaver.
inter = Interleaver(args, p_array1)
interleaved = inter(interleaverTester)
# interleaved  (uncomment to display the full reordered tensor)

In [None]:
deinter = DeInterleaver(args, p_array1)
deinterleaved = deinter(interleaved)

# Round-trip check: DeInterleaver built from the same index array must undo the
# Interleaver exactly. Assert it instead of printing the whole tensor and
# comparing by eye.
assert torch.equal(deinterleaved, interleaverTester), "DeInterleaver did not invert Interleaver"
deinterleaved.shape

In [None]:
# define model, check encoder, decoder...