In [1]:
import numpy as np
from itertools import product
from typing import List, Tuple

# Demo. Decoder for Reed-Muller Codes

### 1. Generator Matrix

In [2]:
def RM_generator_matrix(r: int, m: int, variation: str) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
    """Build a generator matrix for RM(r, m) or its punctured / shortened variant.

    Each row is the evaluation of one monomial in m binary variables of degree
    at most r. Rows are ordered by increasing degree. Column i corresponds to
    the evaluation point given by the m-bit binary expansion of i (most
    significant bit first), so column 0 is the all-zero point.

    Parameters
    ----------
    r : int
        Maximum monomial degree.
    m : int
        Number of binary variables; the unmodified code has length 2**m.
    variation : str
        "None"      -- the unmodified matrix (Python ``None`` is accepted too).
        "punctured" -- the first column (evaluation point 0...0) is removed.
        "shortened" -- the first column and the first row (the constant
                       monomial) are removed.

    Returns
    -------
    (G, monomials)
        G is an integer 0/1 array with one row per kept monomial; monomials is
        the list of exponent tuples, aligned with the rows of G.

    Raises
    ------
    ValueError
        If ``variation`` is not one of the supported values.
    """
    # Validate first: previously an unknown variation left the result unbound
    # and surfaced as an UnboundLocalError at the return statement.
    if variation is None:
        variation = "None"
    if variation not in ("None", "punctured", "shortened"):
        raise ValueError(
            f"Unknown variation {variation!r}; expected 'None', 'punctured' or 'shortened'"
        )

    n = 2 ** m

    # Exponent vectors of all monomials with degree <= r, grouped by degree.
    monomials = [
        bits
        for deg in range(r + 1)
        for bits in product([0, 1], repeat=m)
        if sum(bits) == deg
    ]

    # product() enumerates points in binary counting order (MSB first),
    # i.e. the i-th point is the m-bit expansion of i.
    points = list(product([0, 1], repeat=m))

    # A monomial evaluates to 1 exactly when every variable it contains is 1.
    rows = [
        [int(all(xi for a, xi in zip(mono, x) if a == 1)) for x in points]
        for mono in monomials
    ]

    # reshape keeps a 2-D (0, n) shape even when no monomial was selected.
    G = np.array(rows, dtype=int).reshape(len(monomials), n)

    if variation == "punctured":
        # Drop the column of evaluation point 0...0.
        return G[:, 1:], monomials

    if variation == "shortened":
        # Drop the column of evaluation point 0...0 and the row of the
        # constant (all-zero exponent) monomial.
        return G[1:, 1:], monomials[1:]

    return G, monomials

In [3]:
# Show the generator matrix and monomial list for a few unmodified RM codes
cases = [("RM(1,3):", 1, 3), ("RM(1,4):", 1, 4), ("RM(2,4):", 2, 4)]
for label, order, n_vars in cases:
    print(label)
    print(RM_generator_matrix(order, n_vars, variation="None"))

RM(1,3):
(array([[1, 1, 1, 1, 1, 1, 1, 1],
       [0, 1, 0, 1, 0, 1, 0, 1],
       [0, 0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 0, 1, 1, 1, 1]]), [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)])
RM(1,4):
(array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
       [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0), (0, 1, 0, 0), (1, 0, 0, 0)])
RM(2,4):
(array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
       [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1],
       [0, 0, 0, 0, 0, 0, 1, 1, 

In [4]:
# Show the generator matrix and monomial list for a few punctured RM codes
cases = [
    ("RM(1,3), punctured:", 1, 3),
    ("RM(1,4), punctured:", 1, 4),
    ("RM(2,4), punctured:", 2, 4),
]
for label, order, n_vars in cases:
    print(label)
    print(RM_generator_matrix(order, n_vars, variation="punctured"))

RM(1,3), punctured:
(array([[1, 1, 1, 1, 1, 1, 1],
       [1, 0, 1, 0, 1, 0, 1],
       [0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 1, 1, 1, 1]]), [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)])
RM(1,4), punctured:
(array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
       [0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0), (0, 1, 0, 0), (1, 0, 0, 0)])
RM(2,4), punctured:
(array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
       [0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
       [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1],
       [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 

In [5]:
# Show the generator matrix and monomial list for a few shortened RM codes
cases = [
    ("RM(1,3), shortened:", 1, 3),
    ("RM(1,4), shortened:", 1, 4),
    ("RM(2,4), shortened:", 2, 4),
]
for label, order, n_vars in cases:
    print(label)
    print(RM_generator_matrix(order, n_vars, variation="shortened"))

RM(1,3), shortened:
(array([[1, 0, 1, 0, 1, 0, 1],
       [0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 1, 1, 1, 1]]), [(0, 0, 1), (0, 1, 0), (1, 0, 0)])
RM(1,4), shortened:
(array([[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
       [0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), [(0, 0, 0, 1), (0, 0, 1, 0), (0, 1, 0, 0), (1, 0, 0, 0)])
RM(2,4), shortened:
(array([[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
       [0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
       [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1],
       [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1

### 2. Recursive List Decoder

In [6]:
def repetition_ml_decoder(y: np.ndarray) -> np.ndarray:
    """
    ML decoder for the repetition code RM(0, m).

    The only codewords are all-0 (all +1 in the ±1 domain) and all-1 (all -1);
    the one with the larger likelihood is returned.

    y: soft input vector; each symbol contributes (1 + factor*y)/2 to the
       likelihood of +1 and (1 - factor*y)/2 to the likelihood of -1
    Returns: np.ones_like(y) or -np.ones_like(y); an exact tie is broken by a
             fair coin flip (np.random.rand)

    The comparison is done on summed log-likelihoods. Multiplying the raw
    per-symbol likelihoods underflows to 0.0 for inputs of a few hundred
    symbols, which made both scores equal and turned clear decisions into
    random ones.
    """
    y = np.asarray(y)
    factor = 0.99  # Avoid extreme values of y (+1 or -1)

    # Clamp to the smallest positive normal float so that log() stays finite
    # if a per-symbol likelihood reaches zero or goes slightly negative.
    tiny = np.finfo(float).tiny
    score0 = np.sum(np.log(np.maximum((1 + factor * y) / 2, tiny)))  # log-likelihood of all 0 codeword (+1)
    score1 = np.sum(np.log(np.maximum((1 - factor * y) / 2, tiny)))  # log-likelihood of all 1 codeword (-1)

    c0 = np.ones_like(y)
    c1 = -np.ones_like(y)

    if score0 > score1:
        return c0
    elif score1 > score0:
        return c1
    else:
        # If scores are equal, return with equal probability, unbiased
        return c0 if np.random.rand() < 0.5 else c1

def rm_list_decode(y, r, m, L=4) -> np.ndarray:
    """
    Recursive decoder Ψ_m^r for RM(r, m) based on the (u, u*v) split.

    y: input vector of length 2**m in the ±1 domain (soft values allowed)
    r, m: RM(r, m), with 0 <= r <= m
    L: list size; it is only forwarded to the recursive calls and does not
       influence the result -- a single candidate is kept at every step
    Returns: one decoded codeword as a ±1 vector of length 2**m
             (no list of candidates and no path cost is produced)
    Raises: ValueError if r is outside [0, m] or len(y) != 2**m. Without this
            check r > m never reaches a base case, and a wrong length is
            silently split into mismatched halves.
    """
    y = np.asarray(y)
    N = len(y)

    if not 0 <= r <= m:
        raise ValueError(f"RM(r, m) requires 0 <= r <= m, got r={r}, m={m}")
    if N != 2 ** m:
        raise ValueError(f"Input length {N} does not match 2**m = {2 ** m}")

    factor = 0.99  # Avoid extreme values of y (+1 or -1)

    # Base cases
    if r == 0:
        return repetition_ml_decoder(y)
    if r == m:
        # RM(m,m) is full space — just use hard-decision.
        # The small random perturbation breaks ties for entries equal to 0.
        return np.sign(y + (1 - factor) * np.random.uniform(-1, 1, size=N))

    yL, yR = y[:N // 2], y[N // 2:]

    # Step 1: estimate v in RM(r-1, m-1) from the componentwise product yL * yR
    y_v = yL * yR
    v_hat = rm_list_decode(y_v, r - 1, m - 1, L)

    # Step 2: given v_hat, estimate u in RM(r, m-1).
    # Multiplying the right half by v_hat gives a second observation of u,
    # which is combined with the left half as (yL + ŷ) / (1 + factor * yL * ŷ).
    y_hat = yR * v_hat
    denom = 1 + factor * yL * y_hat
    y_u = (yL + y_hat) / denom
    u_hat = rm_list_decode(y_u, r, m - 1, L)

    # Reassemble the codeword: left half u, right half u*v (u+v over GF(2)).
    cw = np.concatenate([u_hat, u_hat * v_hat])

    return cw

In [7]:
# Monte-Carlo check of the decoder on random codewords of the unmodified code
r, m = 2, 4
noise = 0.001  # probability of flipping each transmitted bit
sample_num = 100000
G, _ = RM_generator_matrix(r, m, variation="None")

failure = 0
for i in range(sample_num):
    # Random message -> codeword over GF(2)
    true_msg = np.random.randint(0, 2, G.shape[0])
    true_cw = (true_msg @ G) % 2

    # Flip each bit independently, then map {0, 1} -> {+1, -1}
    flips = np.random.binomial(1, noise, len(true_cw))
    noisy_cw = (true_cw + flips) % 2
    y = (-1) ** noisy_cw

    # Decode and map the ±1 result back to {0, 1}
    decoded_cw = (1 - rm_list_decode(y, r, m, L=4)) // 2
    failure += int(not np.array_equal(true_cw, decoded_cw))

failure_rate = failure / sample_num
print("Failure rate: ", failure_rate)

Failure rate:  0.00011


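Next, we extend the decoder to the punctured and shortened codes. The idea: given a noisy codeword, we re-insert the missing 0-th symbol (the punctured position) as both $+1$ and $-1$, decode each candidate with the original RM decoder, and keep the decoded codeword that is closer to its input. For the shortened code only $+1$ is inserted at the 0-th position, so a single decoding pass suffices.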

In [None]:
# Monte-Carlo check of the recursive decoder on the punctured RM code
r, m = 1, 4
noise = 0.01  # probability of flipping each transmitted bit
sample_num = 100
G, _ = RM_generator_matrix(r, m, variation="punctured")

failure = 0
weight_list = []

for i in range(sample_num):
    # Random message -> codeword over GF(2), then independent bit flips
    true_msg = np.random.randint(0, 2, G.shape[0])
    true_cw = (true_msg @ G) % 2
    flips = np.random.binomial(1, noise, len(true_cw))
    noisy_cw = (true_cw + flips) % 2

    # The punctured 0-th symbol is unknown: try both +1 and -1 and decode each
    y = (-1) ** noisy_cw
    y0 = np.insert(y, 0, 1)
    y1 = np.insert(y, 0, -1)
    decoded_cw0 = rm_list_decode(y0, r, m, L=4)
    decoded_cw1 = rm_list_decode(y1, r, m, L=4)

    # Keep the candidate that correlates better with its own input
    overlap0 = np.dot(y0, decoded_cw0)
    overlap1 = np.dot(y1, decoded_cw1)
    decoded_cw = decoded_cw0 if overlap0 > overlap1 else decoded_cw1

    # Drop the inserted symbol and map ±1 back to {0, 1}
    decoded_cw = (1 - decoded_cw[1:]) // 2
    failure += int(not np.array_equal(true_cw, decoded_cw))

    # Track the weight to compare how often it is even versus odd
    weight_list.append(np.sum(decoded_cw))

failure_rate = failure / sample_num
print("Failure rate: ", failure_rate)

weight_list = np.array(weight_list)
even_fraction = np.sum(weight_list % 2 == 0) / sample_num  # should be 0.5 for punctured
print("Even/Odd: ", even_fraction)

Failure rate:  0.0
Even/Odd:  0.49


In [38]:
# Monte-Carlo check of the recursive decoder on the shortened RM code
r, m = 1, 4
noise = 0.01  # probability of flipping each transmitted bit
sample_num = 1000000
G, _ = RM_generator_matrix(r, m, variation="shortened")

failure = 0
weight_list = []

for i in range(sample_num):
    # Random message -> codeword over GF(2), then independent bit flips
    true_msg = np.random.randint(0, 2, G.shape[0])
    true_cw = (true_msg @ G) % 2
    flips = np.random.binomial(1, noise, len(true_cw))
    noisy_cw = (true_cw + flips) % 2

    # For the shortened code only +1 is inserted at the removed 0-th position
    y = (-1) ** noisy_cw
    y0 = np.insert(y, 0, 1)
    decoded_cw0 = rm_list_decode(y0, r, m, L=4)

    # Drop the inserted symbol and map ±1 back to {0, 1}
    decoded_cw = (1 - decoded_cw0[1:]) // 2
    failure += int(not np.array_equal(true_cw, decoded_cw))

    # Track the weight; it should be even for every shortened codeword
    weight_list.append(np.sum(decoded_cw))

failure_rate = failure / sample_num
print("Failure rate: ", failure_rate)

weight_list = np.array(weight_list)
even_fraction = np.sum(weight_list % 2 == 0) / sample_num  # should be 1 for shortened
print("Even/Odd: ", even_fraction)

Failure rate:  6e-06
Even/Odd:  1.0
