Skip to content

Commit

Permalink
Merge pull request #27 from kirlf/master
Browse files Browse the repository at this point in the history
Update conv_encode
  • Loading branch information
veeresht committed May 23, 2019
2 parents f39bc48 + b915f5b commit b6969c2
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 64 deletions.
146 changes: 82 additions & 64 deletions commpy/channelcoding/convcode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


# Authors: Veeresh Taranalli <veeresht@gmail.com>
# License: BSD 3-Clause

Expand All @@ -18,59 +16,44 @@
class Trellis:
"""
Class defining a Trellis corresponding to a k/n - rate convolutional code.
Parameters
----------
memory : 1D ndarray of ints
Number of memory elements per input of the convolutional encoder.
g_matrix : 2D ndarray of ints (octal representation)
Generator matrix G(D) of the convolutional encoder. Each element of
G(D) represents a polynomial.
feedback : int, optional
Feedback polynomial of the convolutional encoder. Default value is 00.
code_type : {'default', 'rsc'}, optional
Use 'rsc' to generate a recursive systematic convolutional code.
If 'rsc' is specified, then the first 'k x k' sub-matrix of
G(D) must represent a identity matrix along with a non-zero
feedback polynomial.
Attributes
----------
k : int
Size of the smallest block of input bits that can be encoded using
the convolutional code.
n : int
Size of the smallest block of output bits generated using
the convolutional code.
total_memory : int
Total number of delay elements needed to implement the convolutional
encoder.
number_states : int
Number of states in the convolutional code trellis.
number_inputs : int
Number of branches from each state in the convolutional code trellis.
next_state_table : 2D ndarray of ints
Table representing the state transition matrix of the
convolutional code trellis. Rows represent current states and
columns represent current inputs in decimal. Elements represent the
corresponding next states in decimal.
output_table : 2D ndarray of ints
Table representing the output matrix of the convolutional code trellis.
Rows represent current states and columns represent current inputs in
decimal. Elements represent corresponding outputs in decimal.
Examples
--------
>>> from numpy import array
Expand Down Expand Up @@ -98,7 +81,6 @@ class Trellis:
[3 0]
[1 2]
[2 1]]
"""
def __init__(self, memory, g_matrix, feedback = 0, code_type = 'default'):

Expand All @@ -107,7 +89,8 @@ def __init__(self, memory, g_matrix, feedback = 0, code_type = 'default'):
if code_type == 'rsc':
for i in range(self.k):
g_matrix[i][i] = feedback

self.code_type = code_type

self.total_memory = memory.sum()
self.number_states = pow(2, self.total_memory)
self.number_inputs = pow(2, self.k)
Expand Down Expand Up @@ -176,7 +159,7 @@ def _generate_grid(self, trellis_length):
""" Private method """

grid = np.mgrid[0.12:0.22*trellis_length:(trellis_length+1)*(0+1j),
0.1:0.1+self.number_states*0.1:self.number_states*(0+1j)].reshape(2, -1)
0.1:0.5+self.number_states*0.1:self.number_states*(0+1j)].reshape(2, -1)

return grid

Expand Down Expand Up @@ -231,27 +214,22 @@ def _generate_labels(self, grid, state_order, state_radius, font):
def visualize(self, trellis_length = 2, state_order = None,
state_radius = 0.04, edge_colors = None):
""" Plot the trellis diagram.
Parameters
----------
trellis_length : int, optional
Specifies the number of time steps in the trellis diagram.
Default value is 2.
state_order : list of ints, optional
Specifies the order in the which the states of the trellis
are to be displayed starting from the top in the plot.
Default order is [0,...,number_states-1]
state_radius : float, optional
Radius of each state (circle) in the plot.
Default value is 0.04
edge_colors = list of hex color codes, optional
A list of length equal to the number_inputs,
containing color codes that represent the edge corresponding
to the input.
"""
if edge_colors is None:
edge_colors = ["#9E1BE0", "#06D65D"]
Expand All @@ -260,7 +238,7 @@ def visualize(self, trellis_length = 2, state_order = None,
state_order = list(range(self.number_states))

font = "sans-serif"
fig = plt.figure()
fig = plt.figure(figsize=(12, 6), dpi=150)
ax = plt.axes([0,0,1,1])
trellis_patches = []

Expand All @@ -281,26 +259,23 @@ def visualize(self, trellis_length = 2, state_order = None,
ax.add_collection(collection)
ax.set_xticks([])
ax.set_yticks([])
#plt.legend([edge_patches[0], edge_patches[1]], ["1-input", "0-input"])
plt.legend([edge_patches[0], edge_patches[1]], ["1-input", "0-input"])
#plt.savefig('trellis')
plt.show()


def conv_encode(message_bits, trellis, code_type = 'default', puncture_matrix=None):
def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=None):
"""
Encode bits using a convolutional code.
Parameters
----------
message_bits : 1D ndarray containing {0, 1}
Stream of bits to be convolutionally encoded.
generator_matrix : 2-D ndarray of ints
Generator matrix G(D) of the convolutional code using which the input
bits are to be encoded.
M : 1D ndarray of ints
Number of memory elements per input of the convolutional encoder.
trellis: pre-initialized Trellis structure.
termination: {'cont', 'term'}, optional
Create ('term') or not ('cont') termination bits.
puncture_matrix: 2D ndarray containing {0, 1}, optional
Matrix used for the puncturing algorithm
Returns
-------
coded_bits : 1D ndarray containing {0, 1}
Expand All @@ -311,26 +286,30 @@ def conv_encode(message_bits, trellis, code_type = 'default', puncture_matrix=No
n = trellis.n
total_memory = trellis.total_memory
rate = float(k)/n

code_type = trellis.code_type

if puncture_matrix is None:
puncture_matrix = np.ones((trellis.k, trellis.n))

number_message_bits = np.size(message_bits)

# Initialize an array to contain the message bits plus the truncation zeros
if code_type == 'default':
inbits = np.zeros(number_message_bits + total_memory + total_memory % k,
'int')
number_inbits = number_message_bits + total_memory + total_memory % k

# Pad the input bits with M zeros (L-th terminated truncation)
inbits[0:number_message_bits] = message_bits
number_outbits = int(number_inbits/rate)

else:

if termination == 'cont':
inbits = message_bits
number_inbits = number_message_bits
number_outbits = int((number_inbits + total_memory)/rate)
number_outbits = int(number_inbits/rate)
else:
# Initialize an array to contain the message bits plus the truncation zeros
if code_type == 'rsc':
inbits = message_bits
number_inbits = number_message_bits
number_outbits = int((number_inbits + total_memory)/rate)
else:
number_inbits = number_message_bits + total_memory + total_memory % k
inbits = np.zeros(number_inbits, 'int')
# Pad the input bits with M zeros (L-th terminated truncation)
inbits[0:number_message_bits] = message_bits
number_outbits = int(number_inbits/rate)

outbits = np.zeros(number_outbits, 'int')
p_outbits = np.zeros(int(number_outbits*
Expand All @@ -349,8 +328,7 @@ def conv_encode(message_bits, trellis, code_type = 'default', puncture_matrix=No
current_state = next_state_table[current_state][current_input]
j += 1

if code_type == 'rsc':

if code_type == 'rsc' and termination == 'term':
term_bits = dec2bitarray(current_state, trellis.total_memory)
term_bits = term_bits[::-1]
for i in range(trellis.total_memory):
Expand All @@ -360,11 +338,12 @@ def conv_encode(message_bits, trellis, code_type = 'default', puncture_matrix=No
current_state = next_state_table[current_state][current_input]
j += 1

j = 0
for i in range(number_outbits):
if puncture_matrix[0][i % np.size(puncture_matrix, 1)] == 1:
p_outbits[j] = outbits[i]
j = j + 1
if puncture_matrix is not None:
j = 0
for i in range(number_outbits):
if puncture_matrix[0][i % np.size(puncture_matrix, 1)] == 1:
p_outbits[j] = outbits[i]
j = j + 1

return p_outbits

Expand Down Expand Up @@ -474,32 +453,25 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'):
"""
Decodes a stream of convolutionally encoded bits using the Viterbi Algorithm
Parameters
----------
coded_bits : 1D ndarray
Stream of convolutionally encoded bits which are to be decoded.
generator_matrix : 2D ndarray of ints
Generator matrix G(D) of the convolutional code using which the
input bits are to be decoded.
M : 1D ndarray of ints
Number of memory elements per input of the convolutional encoder.
tb_length : int
Traceback depth (Typically set to 5*(M+1)).
decoding_type : str {'hard', 'unquantized'}
The type of decoding to be used.
'hard' option is used for hard inputs (bits) to the decoder, e.g., BSC channel.
'unquantized' option is used for soft inputs (real numbers) to the decoder, e.g., BAWGN channel.
Returns
-------
decoded_bits : 1D ndarray
Decoded bit stream.
References
----------
.. [1] Todd K. Moon. Error Correction Coding: Mathematical Methods and
Expand Down Expand Up @@ -571,3 +543,49 @@ def viterbi_decode(coded_bits, trellis, tb_depth=None, decoding_type='hard'):
current_number_states = 1

return decoded_bits[0:len(decoded_bits)-tb_depth-1]

def puncturing(message, punct_vec):
'''
Applying of the punctured procedure.
Parameters
----------
message: input message {0,1}
punct_vec: puncturing vector {0,1}
Returns
-------
punctured: output punctured vector {0,1}
'''
shift = 0
N = len(punct_vec)
punctured = []
for idx, item in enumerate(message):
if punct_vec[idx-shift*N] == 1:
punctured.append(item)
if idx%N == 0:
shift = shift + 1
return np.array(punctured)

def depuncturing(punctured, punct_vec, shouldbe):
'''
Applying of the inserting zeros procedure.
Parameters
----------
punctured: input punctured message {0,1}
punct_vec: puncturing vector {0,1}
shouldbe: length of the initial message (before puncturing)
Returns
-------
depunctured: output vector {0,1}
'''
shift = 0
shift2 = 0
N = len(punct_vec)
depunctured = np.zeros((shouldbe,))
for idx, item in enumerate(depunctured):
if punct_vec[idx - shift*N] == 1:
depunctured[idx] = float(punctured[idx-shift2])
else:
shift2 = shift2 + 1
if idx%N == 0:
shift = shift + 1;
return depunctured
4 changes: 4 additions & 0 deletions commpy/channelcoding/tests/test_convcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def test_conv_encode_viterbi_decode(self):
coded_bits = conv_encode(msg, self.trellis_1)
decoded_bits = viterbi_decode(coded_bits.astype(float), self.trellis_1, 15)
assert_array_equal(decoded_bits[:-2], msg)

coded_bits = conv_encode(msg, self.trellis_1, termination = 'cont')
decoded_bits = viterbi_decode(coded_bits.astype(float), self.trellis_1, 15)
assert_array_equal(decoded_bits, msg)

coded_bits = conv_encode(msg, self.trellis_1)
coded_syms = 2.0*coded_bits - 1
Expand Down

0 comments on commit b6969c2

Please sign in to comment.