In [2]:
from dataclasses import dataclass
import random
from typing import Callable, List
import pandas as pd
import altair as alt


@dataclass
class TestData:
    """A named list of integer test words.

    Attributes:
        name: Label of the data set; used as the chart title and embedded in
            the name of data sets derived through ``transform``.
        data: The test words, one integer per list entry.
    """

    name: str
    data: List[int]

    def display(self):
        """Render the data as an Altair scatter plot (index on X, value on Y)."""
        df = pd.DataFrame(enumerate(self.data), columns=["index", "value"])
        alt.Chart(df).mark_circle().properties(
            title=self.name, width=1200, height=100
        ).encode(
            alt.X("index:Q"),
            # The Y domain is pinned to 0..0xFFFFFFFF so every chart shares one scale.
            alt.Y("value:Q", scale=alt.Scale(zero=True, domain=[0, 0xFFFFFFFF])),
        ).display()

    def transform(
        self, f: Callable[[List[int]], List[int]], name: str = "transform"
    ) -> "TestData":
        """Return a new data set whose words are ``f(self.data)``.

        Args:
            f: Function that receives this object's ``data`` list (not the
                ``TestData`` object itself) and returns the transformed list.
            name: Label of the transformation; the result is named
                ``"<name>(<self.name>)"``.

        Returns:
            A new ``TestData``; this object is not replaced. Whether
            ``self.data`` is left untouched depends on ``f``.
        """
        return TestData(name=f"{name}({self.name})", data=f(self.data))


# Number of words in every generated data set.
# NOTE(review): the trailing expression is an alternative length left by the
# author; its meaning is not documented here.
TEST_DATA_LEN = 512 // 4 # (512 + 8) * 4 + 12

# A dedicated, seeded generator keeps the "random" data sets identical on every
# Restart & Run All and leaves the global `random` state untouched.
TEST_DATA_SEED = 0
test_data_rng = random.Random(TEST_DATA_SEED)

test_datas = [
    # Constant patterns.
    TestData(name="All 0", data=[0x00000000 for _ in range(TEST_DATA_LEN)]),
    TestData(name="All 1", data=[0xFFFFFFFF for _ in range(TEST_DATA_LEN)]),
    # Counters: step 1, and step 0x10000000 wrapped to 32 bits.
    TestData(name="increment", data=[(i & 0xFFFFFFFF) for i in range(TEST_DATA_LEN)]),
    TestData(
        name="increment2",
        data=[((i * 0x10000000) & 0xFFFFFFFF) for i in range(TEST_DATA_LEN)],
    ),
    # A single set bit walking through the 32 bit positions, and its complement.
    TestData(name="bitwalk", data=[(1 << (i % 32)) for i in range(TEST_DATA_LEN)]),
    TestData(
        name="bitwalk2",
        data=[(~(1 << (i % 32)) & 0xFFFFFFFF) for i in range(TEST_DATA_LEN)],
    ),
    # Uniformly distributed 32-bit values from the seeded generator.
    TestData(
        name="random1",
        data=[test_data_rng.randint(0, 0xFFFFFFFF) for _ in range(TEST_DATA_LEN)],
    ),
    TestData(
        name="random2",
        data=[test_data_rng.randint(0, 0xFFFFFFFF) for _ in range(TEST_DATA_LEN)],
    ),
]

for test_data in test_datas:
    test_data.display()


## Bit Scramble

- 線形帰還シフトレジスタ生成値とのxor
- データの4byte単位での位置オフセットで乱数を進めているので、復号化に使用する値は事前に決定済

In [3]:
from typing import Generator, List, Optional

class Lfsr32:
    """32-bit linear feedback shift register.

    Each ``next()`` shifts the state right by one bit and, when the bit that
    was shifted out is 1, XORs the tap mask into the state (Galois-style
    update).

    Attributes:
        init_value: State the register is constructed with and returns to on
            a plain ``reset()``.
        value: Current state.
        taps: Bit positions that make up the tap mask.
    """

    # Kept immutable so instances never share (and cannot corrupt) one list.
    DEFAULT_TAPS = (31, 30, 28, 10, 2, 1)

    def __init__(self, init_value: int = 1, taps: Optional[List[int]] = None):
        """Create a register.

        Args:
            init_value: Initial state. A state of 0 stays 0 forever, because
                the update only ever XORs the mask in when the low bit is set.
            taps: Tap bit positions; ``None`` selects ``DEFAULT_TAPS``.
        """
        self.init_value = init_value
        self.value = init_value
        self.taps = list(self.DEFAULT_TAPS) if taps is None else taps

    def __tap_bits__(self) -> int:
        """Return the tap positions combined into a single integer mask."""
        return sum(1 << tap for tap in self.taps)

    def reset(self, init_value: Optional[int] = None):
        """Rewind the register.

        Args:
            init_value: State to restart from. ``None`` restarts from the
                value given at construction. ``self.init_value`` itself is
                never modified.
        """
        self.value = self.init_value if init_value is None else init_value

    def next(self) -> int:
        """Advance the register by one step and return the new 32-bit state."""
        tap_bits = self.__tap_bits__()
        # -(bit) is all-ones when the low bit is 1 and zero otherwise, so the
        # mask is applied only when a 1 is shifted out.
        self.value = ((self.value >> 1) ^ (-(self.value & 1) & tap_bits)) & 0xFFFFFFFF
        return self.value

    def __str__(self):
        return f"Lfsr32(init={self.init_value}, taps={self.taps})"


def scrample_datas(
    lfsr: Lfsr32,
    data: List[int],
) -> List[int]:
    """Scramble (or descramble) ``data`` by XOR-ing it with an LFSR sequence.

    The register is reset first, so the same key stream is produced on every
    call; applying the function twice with the same register therefore
    restores the original data.

    Args:
        lfsr: Register that supplies one value per word via ``next()``. Its
            state is modified (reset, then advanced ``len(data)`` times).
        data: Words to scramble.

    Returns:
        A new list with ``data[i] ^ <i-th LFSR output>`` at index ``i``.
    """
    lfsr.reset()
    return [word ^ lfsr.next() for word in data]


# Round-trip check: scramble every data set, plot it, then scramble again
# (XOR with the same key stream undoes itself) and compare with the source.
lfsr = Lfsr32()

# Scramble pass.
scramble_test_datas = []
for original in test_datas:
    scramble_test_datas.append(
        original.transform(lambda x: scrample_datas(lfsr, x), name=f"xor {lfsr}")
    )
for test_data in scramble_test_datas:
    test_data.display()

# Descramble pass: the very same operation applied to the scrambled words.
descramble_test_datas = []
for scrambled in scramble_test_datas:
    descramble_test_datas.append(
        scrambled.transform(lambda x: scrample_datas(lfsr, x), name=f"xor {lfsr}")
    )
for test_data in descramble_test_datas:
    test_data.display()

# Every descrambled data set must equal its source word for word.
for index in range(min(len(test_datas), len(descramble_test_datas))):
    src = test_datas[index]
    dst = descramble_test_datas[index]
    assert src.data == dst.data




## ECC

- https://www.akita-pu.ac.jp/system/elect/ins/kusakari/japanese/teaching/InfoTheo/2009/note/10.pdf
- https://prml.main.ist.hokudai.ac.jp/wp-content/uploads/2020/07/infth2020-13.pdf


- 整数: $m$ (=parity長)
- 符号長: $n=2^m-1$
- 情報数: $k=n-m$
- ハミング符号の検査行列: $H = [A I_m]$
  - H は m行 x n列
  - すべての列要素が0ではない
  - すべての列要素が相違
  - $[A I_m]$ は組織符号と呼ばれる表現。 $A$ は任意の行列、 $I_m$ は m行 x m列の単位行列
    - 生成行列を求める際に $A^T$ が必要なのでこの表現に直しておく
- 生成行列: $GH^T=0$ (転置すると $HG^T=0$) を満たす行列 $G$
  - 組織符号の表現だと $G=[I_kA^T]$
- 符号化: 送信したいデータ$x$ と生成行列$G$の乗算を行う
- 受信符号: エラーがない場合は $Y=xG$、1bit errorの場合は $Y=xG \oplus e_i$、2bit errorの場合は $Y=xG \oplus e_i \oplus e_j$ ($e_i$ は i番目の要素のみ1の誤りベクトル)
- 両辺に検査行列の転置 $H^T$ を掛けてxorを分解する。$xGH^T=0$ なので: $YH^T=e_iH^T \oplus e_jH^T$
  - 受信符号と検査行列の転置の積 $S=YH^T$ を計算し、これが0なら $e_i=0$, $e_j=0$ となっておりエラーベクトルはなし。訂正不要
  - 非ゼロの場合検査行列の列要素で合致する列番号が訂正bitposに相当する
- 拡張ハミング符号を導入する。符号語全体をxorしたParity bitを1bit追加する
  - 復号には検査行列の右列にAll 0を、下行にAll 1を付与したものを使用する。2bit Errorの場合に検査行列に含まれないsyndromeが生成される

In [35]:
from dataclasses import dataclass
from enum import Enum, auto
import math
from typing import Optional
import numpy as np

class ErrorStatus(Enum):
    """Classification of a decoded codeword.

    The members carry the values 1, 2 and 3 in declaration order.
    """

    # Syndrome was all zero: nothing to correct.
    NO_ERROR = 1
    # Exactly one bit position was identified and flipped back.
    SINGLE_ERROR = 2
    # An error was detected but no single bit position explains it.
    DOUBLE_ERROR = 3

@dataclass
class DecodeResult:
    """Outcome of a single decode attempt.

    Attributes:
        status: Whether no error, a corrected single-bit error or an
            uncorrectable error was found.
        corrected_data: The received data, with the erroneous bit flipped back
            when ``status`` is ``SINGLE_ERROR``; otherwise the data as received.
        error_bitpos: Index of the corrected bit, or ``None`` when nothing was
            corrected.
    """

    status: ErrorStatus
    corrected_data: np.ndarray
    error_bitpos: Optional[int]

    def __str__(self):
        # Compact form that leaves the (potentially large) data array out.
        return f"DecodeResult(status={self.status}, error_bitpos={self.error_bitpos})"

    def __repr__(self):
        return self.__str__()

class HammingCode:
    """Systematic Hamming code with an optional extended (overall parity) bit.

    For ``m`` parity bits the codeword length is ``n = 2**m - 1`` and the
    number of information bits is ``k = n - m``.

    Attributes:
        m: Number of parity bits.
        n: Codeword length without the extended parity bit.
        k: Number of information bits.
        H: Parity check matrix ``[A | I_m]`` of shape ``(m, n)``.
        G: Generator matrix ``[I_k | A^T]`` of shape ``(k, n)``.
    """

    def __init__(self, m: int):
        assert m > 0, "m should be greater than 0"

        # parity len
        self.m = m
        # total len
        self.n = 2**m - 1
        # data len
        self.k = self.n - m
        # parity check matrix
        self.H = self.__create_check_matrix()
        # generator matrix
        self.G = self.__create_generator_matrix()

    def __str__(self):
        return f"HammingCode(m={self.m}, n={self.n}, k={self.k})"

    def __create_check_matrix(self) -> np.ndarray:
        """Build the parity check matrix in systematic form ``H = [A | I_m]``.

        Every column of ``H`` must be non-zero and distinct. The columns of
        ``I_m`` are the powers of two (1, 2, 4, ...), so ``A`` takes the
        remaining non-zero m-bit values in ascending order; row ``r`` of a
        column holds bit ``r`` of that value.
        """
        # x & (x - 1) is zero exactly for powers of two, which belong to I_m.
        column_values = [x for x in range(1, self.n + 1) if x & (x - 1)]

        A = np.zeros((self.m, self.k), dtype=int)
        for col, x in enumerate(column_values):
            for row in range(self.m):
                A[row, col] = (x >> row) & 1

        I_m = np.eye(self.m, dtype=int)
        return np.concatenate([A, I_m], axis=1)

    def __create_generator_matrix(self) -> np.ndarray:
        """Derive the generator matrix ``G = [I_k | A^T]`` from ``self.H``."""
        I_k = np.eye(self.k, dtype=int)
        # The first k columns of H are A.
        A_T = self.H[:, :self.k].transpose()
        return np.concatenate([I_k, A_T], axis=1)

    def total_len(self, use_exhamming: bool) -> int:
        """Return the codeword length, counting the extended parity bit if used."""
        return self.n if not use_exhamming else self.n + 1

    def encode(self, data: np.ndarray, use_exhamming: bool) -> np.ndarray:
        """Encode information bits with the generator matrix.

        Args:
            data: Information bits, either one word of shape ``(k,)`` or a
                batch of shape ``(rows, k)``.
            use_exhamming: Append an overall parity bit (XOR of all codeword
                bits) to each codeword.

        Returns:
            The codeword(s); the last axis has length ``total_len(use_exhamming)``.
        """
        encoded_data = np.dot(data, self.G) % 2
        if use_exhamming:
            # Working on the last axis handles both 1-D and 2-D input.
            parity_bit = np.sum(encoded_data, axis=-1, keepdims=True) % 2
            encoded_data = np.concatenate([encoded_data, parity_bit], axis=-1)
        return encoded_data

    def decode(self, data: np.ndarray, use_exhamming: bool = True) -> "DecodeResult":
        """Check a received codeword and correct a single-bit error.

        Args:
            data: Received codeword of shape ``(total_len,)`` or
                ``(1, total_len)``. For multi-row input only the syndrome of
                the first row is used to locate the error and the resulting
                correction is applied to every row, so pass one codeword per
                call.
            use_exhamming: The codeword carries the extended parity bit.

        Returns:
            ``NO_ERROR`` with the data unchanged when the syndrome is zero;
            ``SINGLE_ERROR`` with the located bit flipped back and its index;
            ``DOUBLE_ERROR`` with the data unchanged when the syndrome matches
            no column of the check matrix.
        """
        # Extended code: append an all-zero column on the right, then an
        # all-one row at the bottom (n + 1 wide because of the added column).
        H = self.H.copy()
        if use_exhamming:
            H = np.concatenate([H, np.zeros((self.m, 1), dtype=int)], axis=1)
            H = np.concatenate([H, np.ones((1, self.n + 1), dtype=int)], axis=0)

        # Syndrome S = Y * H^T (mod 2); all zero means no detectable error.
        S = np.dot(data, H.T) % 2
        if np.all(S == 0):
            return DecodeResult(ErrorStatus.NO_ERROR, data, None)

        # atleast_2d lets a 1-D codeword take the same path as a (1, n) one.
        syndrome = np.atleast_2d(S)[0]
        # The column of H equal to the syndrome marks the flipped bit.
        matches = [
            pos
            for pos in range(self.total_len(use_exhamming))
            if np.array_equal(syndrome, H[:, pos])
        ]

        # With the extended code, a syndrome that matches no column means the
        # error position cannot be determined: report a double error.
        if not matches:
            assert use_exhamming, "correct_pos is None, but not use_exhamming"
            return DecodeResult(ErrorStatus.DOUBLE_ERROR, data, None)

        # Error vector with a 1 at the located position; adding it mod 2 flips
        # that bit back.
        e = np.zeros(self.total_len(use_exhamming), dtype=int)
        e[matches] = 1
        corrected_data = (data + e) % 2

        return DecodeResult(ErrorStatus.SINGLE_ERROR, corrected_data, matches[-1])


# Test the functions
# Self-test of HammingCode: encode one word, then check that a clean codeword,
# every single-bit error and a set of double-bit errors are classified as expected.
m = 3
use_exhamming = True
ecc = HammingCode(m=m)  # built from `m` above so changing it takes effect
print("============================================================================")
print(f"{ecc}")
print(f"use_exhamming:{use_exhamming}")
print(f"H:\n{ecc.H}")
print(f"G:\n{ecc.G}")

raw_data = np.array([[0, 1, 0, 0]])
# The sample word is written for k=4 (m=3); fail early with a clear message otherwise.
assert raw_data.shape[1] == ecc.k, f"raw_data needs {ecc.k} bits, got {raw_data.shape[1]}"
encoded_data = ecc.encode(data=raw_data, use_exhamming=use_exhamming)
print("============================================================================")
print(f"raw_data    : {raw_data}")
print(f"encoded_data: {encoded_data}")

# Clean codeword: must decode as NO_ERROR with the information bits intact.
ret = ecc.decode(data=encoded_data, use_exhamming=use_exhamming)
print("============================================================================")
print(f"decoded_data (no error)  : {ret.corrected_data} ({ret})")
assert ret.status == ErrorStatus.NO_ERROR, f"ret.status={ret.status} != ErrorStatus.NO_ERROR"
assert np.array_equal(raw_data, ret.corrected_data[:,:raw_data.shape[1]]), f"{raw_data} != {ret.corrected_data}"

# Single-bit errors: flip each codeword bit in turn; the decoder must locate
# and repair exactly that bit.
for biterror_pos in range(ecc.total_len(use_exhamming)):
    single_error_encoded_data = encoded_data.copy()
    single_error_encoded_data[0, biterror_pos] ^= 1
    ret = ecc.decode(data=single_error_encoded_data, use_exhamming=use_exhamming)
    print("============================================================================")
    print(f"single biterror bitpos={biterror_pos}")
    print(f"single_error_encoded_data: {single_error_encoded_data}")
    print(f"single_error_decoded_data: {ret.corrected_data} ({ret})")
    assert np.array_equal(raw_data, ret.corrected_data[:,:raw_data.shape[1]]), f"{raw_data} != {ret.corrected_data}"
    assert ret.status == ErrorStatus.SINGLE_ERROR, f"ret.status={ret.status} != ErrorStatus.SINGLE_ERROR"
    assert biterror_pos == ret.error_bitpos, f"bitpos={biterror_pos} != correct_pos={ret.error_bitpos}"

# Double-bit errors: only the extended code can detect them, so this part is
# skipped for the plain code (which would mis-correct a third bit instead).
if use_exhamming:
    for biterror_pos in range(ecc.total_len(use_exhamming)):
        # Second flipped bit sits 3 positions further on, wrapping around.
        biterror_pos2 = (biterror_pos+3) % ecc.total_len(use_exhamming)
        double_error_encoded_data = encoded_data.copy()
        double_error_encoded_data[0, biterror_pos] ^= 1
        double_error_encoded_data[0, biterror_pos2] ^= 1
        ret = ecc.decode(data=double_error_encoded_data, use_exhamming=use_exhamming)
        print("============================================================================")
        print(f"double biterror bitpos={biterror_pos}, {biterror_pos2}")
        print(f"double_error_encoded_data: {double_error_encoded_data}")
        print(f"double_error_decoded_data: {ret.corrected_data} ({ret})")
        assert ret.status == ErrorStatus.DOUBLE_ERROR, f"ret.status={ret.status} != ErrorStatus.DOUBLE_ERROR"
        assert ret.error_bitpos is None, f"ret.error_bitpos={ret.error_bitpos} != None"


HammingCode(m=3, n=7, k=4)
use_exhamming:True
H:
[[1 1 0 1 1 0 0]
 [1 0 1 1 0 1 0]
 [0 1 1 1 0 0 1]]
G:
[[1 0 0 0 1 1 0]
 [0 1 0 0 1 0 1]
 [0 0 1 0 0 1 1]
 [0 0 0 1 1 1 1]]
raw_data    : [[0 1 0 0]]
encoded_data: [[0 1 0 0 1 0 1 1]]
decoded_data (no error)  : [[0 1 0 0 1 0 1 1]] (CorrectResult(status=ErrorStatus.NO_ERROR, error_bitpos=None))
single biterror bitpos=0
single_error_encoded_data: [[1 1 0 0 1 0 1 1]]
single_error_decoded_data: [[0 1 0 0 1 0 1 1]] (CorrectResult(status=ErrorStatus.SINGLE_ERROR, error_bitpos=0))
single biterror bitpos=1
single_error_encoded_data: [[0 0 0 0 1 0 1 1]]
single_error_decoded_data: [[0 1 0 0 1 0 1 1]] (CorrectResult(status=ErrorStatus.SINGLE_ERROR, error_bitpos=1))
single biterror bitpos=2
single_error_encoded_data: [[0 1 1 0 1 0 1 1]]
single_error_decoded_data: [[0 1 0 0 1 0 1 1]] (CorrectResult(status=ErrorStatus.SINGLE_ERROR, error_bitpos=2))
single biterror bitpos=3
single_error_encoded_data: [[0 1 0 1 1 0 1 1]]
single_error_decoded_data: [[0 1