In [None]:
from python_code.channel.modulator import BPSKModulator
import torch
import numpy as np
import math

def create_transition_table(n_states: int) -> np.ndarray:
    """
    Build the [n_states, 2] trellis transition table.

    Cell [i, b] holds (2 * i + b) % n_states, i.e. the state index obtained by
    shifting bit b into state i as the new least-significant bit. With the
    opposite bit ordering the same cell is instead a predecessor of state i.
    This notebook reads row i as the two successors of state i (one per bit b).
    """
    current = np.arange(n_states).reshape(-1, 1)
    shifted_in_bit = np.arange(2).reshape(1, -1)
    return (2 * current + shifted_in_bit) % n_states


device = "cpu"
block_length = 10
memory_length = 2
n_states = 2 ** memory_length
transmission_length = block_length
batch_size = 1

h = np.array([[1.0, 0.5]])  # Example channel coefficients
y = torch.zeros(batch_size, transmission_length)

# Transmitted bits in {0, 1}, shape (batch_size, block_length).
c = np.array([1,1,2,2,1,2,2,1,1,1]).reshape(batch_size,block_length)-1
# x is the reference word the BER is computed against at the end of the
# notebook, so it must hold the transmitted bits c. An all-zero x would make
# the "BER" just the fraction of ones in the decoded word.
x = torch.tensor(c, dtype=torch.float32)

# Append memory_length zero bits, then map bit 0 -> +1, bit 1 -> -1.
padded_c = np.concatenate([c, np.zeros([c.shape[0], memory_length])], axis=1)
s = 1 - 2 * padded_c
# Row k of blockwise_s is the symbol stream shifted by k: blockwise_s[k, t] = s[0, t + k].
blockwise_s = np.concatenate([s[:, i:-memory_length + i] for i in range(memory_length)], axis=0)
# Noiseless ISI output: conv[0, t] = sum_k h[0, -1 - k] * s[0, t + k]
# (with the example taps: 0.5 * s[t] + 1.0 * s[t + 1]). No noise is added to y.
conv = np.dot(h[:, ::-1], blockwise_s)
[row, col] = conv.shape
y[0, :] = torch.tensor(conv)


transition_table_array = create_transition_table(n_states)
transition_table = torch.Tensor(transition_table_array).to(device)

snr = 0
# Enumerate every trellis state as bits. np.unpackbits operates on uint8 (MSB first),
# so this enumeration only supports memory_length <= 8.
all_states_decimal = np.arange(n_states).astype(np.uint8).reshape(-1, 1)
all_states_binary = np.unpackbits(all_states_decimal, axis=1).astype(int)
# Keep the last memory_length bits of each state. BPSKModulator.modulate presumably
# maps them to +/-1 like `s = 1 - 2 * padded_c` above -- confirm.
all_states_symbols = BPSKModulator.modulate(all_states_binary[:, -memory_length:])
# Expected noiseless channel output for each state, shape (n_states, 1).
state_priors = np.dot(all_states_symbols, h[:,::-1].T)
state_priors = torch.Tensor(state_priors).to(device)

# Residual between every observation and every state's expected output,
# shape (batch, transmission_length, n_states). Plain broadcasting yields the
# same tensor the former `.repeat(...)` construction produced.
priors = y.unsqueeze(dim=2) - state_priors.T.unsqueeze(dim=1)
# Convert to a negative log-likelihood per state: -log N(y; mean_state, sigma^2).
# snr is taken as dB relative to unit symbol power, so the noise standard
# deviation is 10 ** (-snr / 20) and shrinks as snr grows. The former
# 1 / 10 ** (-snr / 10) grew with snr; both equal 1 at snr = 0.
sigma = 10 ** (-snr / 20)
priors = priors ** 2 / (2 * sigma ** 2) + math.log(math.sqrt(2 * math.pi) * sigma)

#### BCJR (sum product) ####
# Forward recursion: alpha[:, i, state] is the per-step normalized forward metric
# of `state` after consuming observations y[:, :i]; the state at alpha index i
# (i >= 1) is scored against y[:, i - 1].
# A state is reachable from every table row that lists it in column 0 or 1.
# These predecessor lists do not depend on time, so they are built once
# (same ordering as before: column-0 matches first, then column-1 matches).
predecessors_of = [
    np.where(transition_table_array[:, 0] == dst)[0].tolist()
    + np.where(transition_table_array[:, 1] == dst)[0].tolist()
    for dst in range(n_states)
]
alpha = torch.zeros([y.shape[0], transmission_length + 1, n_states]).to(device)
alpha[:, 0, 0] = 1  # the trellis starts in state 0
for i in range(1, transmission_length + 1):
    for state in range(n_states):
        incoming_states = predecessors_of[state]
        # likelihood of observation y[:, i - 1] under the destination state
        gamma = torch.exp(-priors[:, i - 1, state]).unsqueeze(dim=1)
        alpha[:, i, state] = (alpha[:, i - 1, incoming_states] * gamma).sum(dim=1)
    # rescale every step so the metrics stay within floating-point range
    alpha[:, i, :] = alpha[:, i, :] / alpha[:, i, :].sum(dim=1, keepdim=True)


# compute backward probabilities
# Index convention (differs from alpha): beta[:, i, state] accumulates the
# likelihoods of observations y[:, i:], INCLUDING y[:, i] under `state` itself.
# So `state` at beta index i is the same trellis position as `state` at alpha
# index i + 1 (both are scored against y[:, i]). When combining alpha and beta,
# the y[:, i] likelihood is therefore already contained in beta[:, i, :].
beta = torch.zeros([y.shape[0], transmission_length+1, n_states]).to(device)
# Termination: all mass goes to state 0, i.e. the trellis is forced to end in
# state 0 (this is NOT an "equally likely" end state). Presumably this relies
# on padded_c appending memory_length zero bits -- confirm.
beta[:, -1, 0] = 1
for i in range(transmission_length-1, -1, -1):
    for state in range(n_states):
        # successors of `state`, one per input bit (row of the transition table)
        outgoing_states = transition_table_array[state]
        # likelihood of y[:, i] under the current state
        gamma = torch.exp(-priors[:, i, state]).unsqueeze(dim=1)
        beta[:, i, state] = torch.sum(beta[:, i + 1, outgoing_states] * gamma, dim=1)
    beta[:, i, :] /= torch.sum(beta[:, i, :], dim=1, keepdim=True)  # Normalize

# compute MAP v1
# For trellis step i, every transition state -> next_state (driven by input bit
# jj) contributes alpha[:, i, state] * beta[:, i, next_state] to the mass of bit jj.
# beta[:, i, next_state] already contains exp(-priors[:, i, next_state]) (the
# backward recursion multiplies it in), so that likelihood must not be applied
# again here; doing so counted observation y[:, i] twice.
decoded_word = torch.zeros([y.shape[0], transmission_length])
for i in range(transmission_length):
    up = torch.zeros(y.shape[0])    # accumulated mass for input bit 0
    down = torch.zeros(y.shape[0])  # accumulated mass for input bit 1
    for state in range(n_states):
        for jj in range(2):
            next_state = transition_table_array[state, jj]
            _alpha = alpha[:, i, state]
            _beta = beta[:, i, next_state]
            if jj == 0:
                up += _alpha * _beta
            else:
                down += _alpha * _beta
    # decide bit 1 only when its mass is strictly larger (ties resolve to 0)
    decoded_word[:, i] = torch.where(up < down, 1, 0)

# Shift the decisions right by memory_length - 1 positions (zero-filled) and cut
# back to transmission_length, presumably to line each decision up with the
# transmitted bit it estimates -- confirm against the trellis convention.
prepend_word = torch.zeros([y.shape[0], memory_length - 1]).to(device)
decoded_word = torch.cat([prepend_word, decoded_word], dim=1)[:, :transmission_length]

# Bit error rate: fraction of positions where the decision differs from x.
n_bit_errors = (decoded_word - x).abs().sum().item()
ber = n_bit_errors / (batch_size * block_length)


tensor([1.]) tensor([0.5766]) tensor([0.3989])
tensor([1.]) tensor([0.0680]) tensor([0.0540])
tensor([0.]) tensor([0.3497]) tensor([0.2420])
tensor([0.]) tensor([0.0056]) tensor([0.0044])
tensor([0.]) tensor([0.5766]) tensor([0.3989])
tensor([0.]) tensor([0.0680]) tensor([0.0540])
tensor([0.]) tensor([0.3497]) tensor([0.2420])
tensor([0.]) tensor([0.0056]) tensor([0.0044])
tensor([0.8808]) tensor([0.0366]) tensor([0.0540])
tensor([0.8808]) tensor([0.4977]) tensor([0.3989])
tensor([0.1192]) tensor([0.1639]) tensor([0.2420])
tensor([0.1192]) tensor([0.3018]) tensor([0.2420])
tensor([0.]) tensor([0.0366]) tensor([0.0540])
tensor([0.]) tensor([0.4977]) tensor([0.3989])
tensor([0.]) tensor([0.1639]) tensor([0.2420])
tensor([0.]) tensor([0.3018]) tensor([0.2420])
tensor([0.1041]) tensor([0.0064]) tensor([0.0044])
tensor([0.1041]) tensor([0.3455]) tensor([0.2420])
tensor([0.7695]) tensor([0.0784]) tensor([0.0540])
tensor([0.7695]) tensor([0.5697]) tensor([0.3989])
tensor([0.0632]) tensor([0.0