Skip to content

Commit

Permalink
[Bug fix] Bad typing in links.py introduced in bcc36dd (#88)
Browse files Browse the repository at this point in the history
[Bug fix] Bad typing in links.py introduced in bcc36dd
  • Loading branch information
BastienTr committed Oct 12, 2020
1 parent 292b4e5 commit ee289ae
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 33 deletions.
16 changes: 8 additions & 8 deletions commpy/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from __future__ import division # Python 2 compatibility

import math
from inspect import getfullargspec
from fractions import Fraction
from inspect import getfullargspec

import numpy as np

Expand Down Expand Up @@ -95,7 +95,7 @@ class LinkModel:
*Default* Es=1.
decoder : function with prototype decoder(array) or decoder(y, H, constellation, noise_var, array) that return a
binary array.
binary ndarray.
*Default* is no process.
rate : float or Fraction in (0,1]
Expand Down Expand Up @@ -127,7 +127,7 @@ class LinkModel:
Es : float
Average energy per symbols.
decoder : function with prototype decoder(binary array) that return a binary array.
decoder : function with prototype decoder(binary array) that return a binary ndarray.
*Default* is no process.
rate : float
Expand Down Expand Up @@ -233,7 +233,7 @@ def link_performance_full_metrics(self, SNRs, tx_max, err_min, send_chunk=None,
# Deals with MIMO channel
if isinstance(self.channel, MIMOFlatChannel):
nb_symb_vector = len(channel_output)
received_msg = np.empty(int(math.ceil(len(msg) / float(self.rate))), dtype=np.int8)
received_msg = np.empty(int(math.ceil(len(msg) / float(self.rate))))
for i in range(nb_symb_vector):
received_msg[receive_size * i:receive_size * (i + 1)] = \
self.receive(channel_output[i], self.channel.channel_gains[i],
Expand All @@ -251,7 +251,7 @@ def link_performance_full_metrics(self, SNRs, tx_max, err_min, send_chunk=None,
# calculate number of error frames
for i in range(number_chunks_per_send):
errors = np.bitwise_xor(msg[send_chunk * i:send_chunk * (i + 1)],
decoded_bits[send_chunk * i:send_chunk * (i + 1)]).sum()
decoded_bits[send_chunk * i:send_chunk * (i + 1)].astype(int)).sum()
bit_err[id_tx] += errors
chunk_loss[id_tx] += 1 if errors > 0 else 0

Expand Down Expand Up @@ -319,7 +319,7 @@ def link_performance(self, SNRs, send_max, err_min, send_chunk=None, code_rate=1
# Deals with MIMO channel
if isinstance(self.channel, MIMOFlatChannel):
nb_symb_vector = len(channel_output)
received_msg = np.empty(int(math.ceil(len(msg) / float(self.rate))), dtype=np.int8)
received_msg = np.empty(int(math.ceil(len(msg) / float(self.rate))))
for i in range(nb_symb_vector):
received_msg[receive_size * i:receive_size * (i + 1)] = \
self.receive(channel_output[i], self.channel.channel_gains[i],
Expand All @@ -332,9 +332,9 @@ def link_performance(self, SNRs, send_max, err_min, send_chunk=None, code_rate=1
decoded_bits = self.decoder(channel_output, self.channel.channel_gains,
self.constellation, self.channel.noise_std ** 2,
received_msg, self.channel.nb_tx * self.num_bits_symbol)
bit_err += np.bitwise_xor(msg, decoded_bits[:len(msg)]).sum()
bit_err += np.bitwise_xor(msg, decoded_bits[:len(msg)].astype(int)).sum()
else:
bit_err += np.bitwise_xor(msg, self.decoder(received_msg)[:len(msg)]).sum()
bit_err += np.bitwise_xor(msg, self.decoder(received_msg)[:len(msg)].astype(int)).sum()
bit_send += send_chunk
BERs[id_SNR] = bit_err / bit_send
if bit_err < err_min:
Expand Down
92 changes: 67 additions & 25 deletions commpy/tests/test_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,52 +8,94 @@
from numpy.testing import run_module_suite, assert_allclose, dec
from scipy.special import erfc

from commpy.channelcoding.ldpc import get_ldpc_code_params, triang_ldpc_systematic_encode, ldpc_bp_decode
from commpy.channels import MIMOFlatChannel, SISOFlatChannel
from commpy.links import link_performance, LinkModel
from commpy.modulation import QAMModem, kbest
from commpy.modulation import QAMModem, kbest, best_first_detector


@dec.slow
def test_link_performance():
# Set seed
seed(17121996)
seed(8071996)
######################################
# Build models & desired solutions
######################################
models = []
desired_bers = []
snr_range = []
labels = []
rtols = []
code_rates = []

# Apply link_performance to SISO QPSK and AWGN channel
# SISO QPSK and AWGN channel
QPSK = QAMModem(4)

def receiver(y, h, constellation, noise_var):
return QPSK.demodulate(y, 'hard')

model = LinkModel(QPSK.modulate, SISOFlatChannel(fading_param=(1 + 0j, 0)), receiver,
QPSK.num_bits_symbol, QPSK.constellation, QPSK.Es)
models.append(LinkModel(QPSK.modulate, SISOFlatChannel(fading_param=(1 + 0j, 0)), receiver,
QPSK.num_bits_symbol, QPSK.constellation, QPSK.Es))
snr_range.append(arange(0, 9, 2))
desired_bers.append(erfc(sqrt(10 ** (snr_range[-1] / 10) / 2)) / 2)
labels.append('SISO QPSK and AWGN channel')
rtols.append(.25)
code_rates.append(1)

BERs = link_performance(model, range(0, 9, 2), 600e4, 600)
desired = erfc(sqrt(10 ** (arange(0, 9, 2) / 10) / 2)) / 2
assert_allclose(BERs, desired, rtol=0.25,
err_msg='Wrong performance for SISO QPSK and AWGN channel')
full_metrics = model.link_performance_full_metrics(range(0, 9, 2), 1000, 600)
assert_allclose(full_metrics[0], desired, rtol=0.25,
err_msg='Wrong performance for SISO QPSK and AWGN channel')

# Apply link_performance to MIMO 16QAM and 4x4 Rayleigh channel
# MIMO 16QAM, 4x4 Rayleigh channel and hard-output K-Best
QAM16 = QAMModem(16)
RayleighChannel = MIMOFlatChannel(4, 4)
RayleighChannel.uncorr_rayleigh_fading(complex)

def receiver(y, h, constellation, noise_var):
return QAM16.demodulate(kbest(y, h, constellation, 16), 'hard')

model = LinkModel(QAM16.modulate, RayleighChannel, receiver,
QAM16.num_bits_symbol, QAM16.constellation, QAM16.Es)
SNRs = arange(0, 21, 5) + 10 * log10(QAM16.num_bits_symbol)

BERs = link_performance(model, SNRs, 600e4, 600)
desired = (2e-1, 1e-1, 3e-2, 2e-3, 4e-5) # From reference
assert_allclose(BERs, desired, rtol=1.25,
err_msg='Wrong performance for MIMO 16QAM and 4x4 Rayleigh channel')
full_metrics = model.link_performance_full_metrics(SNRs, 1000, 600)
assert_allclose(full_metrics[0], desired, rtol=1.25,
err_msg='Wrong performance for MIMO 16QAM and 4x4 Rayleigh channel')
models.append(LinkModel(QAM16.modulate, RayleighChannel, receiver,
QAM16.num_bits_symbol, QAM16.constellation, QAM16.Es))
snr_range.append(arange(0, 21, 5) + 10 * log10(QAM16.num_bits_symbol))
desired_bers.append((2e-1, 1e-1, 3e-2, 2e-3, 4e-5)) # From reference
labels.append('MIMO 16QAM, 4x4 Rayleigh channel and hard-output K-Best')
rtols.append(1.25)
code_rates.append(1)

# MIMO 16QAM, 4x4 Rayleigh channel and soft-output best-first
QAM16 = QAMModem(16)
RayleighChannel = MIMOFlatChannel(4, 4)
RayleighChannel.uncorr_rayleigh_fading(complex)
ldpc_params = get_ldpc_code_params('commpy/channelcoding/designs/ldpc/wimax/1440.720.txt', True)

def modulate(bits):
return QAM16.modulate(triang_ldpc_systematic_encode(bits, ldpc_params, False).reshape(-1, order='F'))

def decoder(llrs):
return ldpc_bp_decode(llrs, ldpc_params, 'MSA', 15)[0][:720].reshape(-1, order='F')

def demode(symbs):
return QAM16.demodulate(symbs, 'hard')

def receiver(y, h, constellation, noise_var):
return best_first_detector(y, h, constellation, (1, 3, 5), noise_var, demode, 500)

models.append(LinkModel(modulate, RayleighChannel, receiver,
QAM16.num_bits_symbol, QAM16.constellation, QAM16.Es,
decoder, 0.5))
snr_range.append(arange(17, 20, 1))
desired_bers.append((1.7e-1, 1e-1, 2.5e-3)) # From reference
labels.append('MIMO 16QAM, 4x4 Rayleigh channel and soft-output best-first')
rtols.append(2)
code_rates.append(.5)

######################################
# Make tests
######################################

for test in range(len(models)):
BERs = link_performance(models[test], snr_range[test], 5e5, 200, 720, models[test].rate)
assert_allclose(BERs, desired_bers[test], rtol=rtols[test],
err_msg='Wrong performance for ' + labels[test])
full_metrics = models[test].link_performance_full_metrics(snr_range[test], 2500, 200, 720, models[test].rate)
assert_allclose(full_metrics[0], desired_bers[test], rtol=rtols[test],
err_msg='Wrong performance for ' + labels[test])


if __name__ == "__main__":
Expand Down

0 comments on commit ee289ae

Please sign in to comment.