In [None]:
import torch
import math

class Hamming:
    """Systematic Hamming(7, 4) encoder / single-error-correcting decoder.

    Each codeword is laid out as ``[d0 d1 d2 d3 | p0 p1 p2]``: the four data
    bits followed by three parity bits ``p = P @ d (mod 2)``. The parity-check
    matrix is ``H = [P | I_3]``, so ``H @ c == 0 (mod 2)`` for every valid
    codeword ``c``.
    """

    _k = 4  # data bits per codeword
    _n = 7  # total bits per codeword (data + parity)

    def __init__(self):
        r = self._n - self._k  # number of parity bits
        # Parity generator: row i selects the data bits that feed parity bit i.
        self._P = torch.tensor([
            [1, 1, 1, 0],
            [1, 1, 0, 1],
            [1, 0, 1, 1]
        ])

        self._H = torch.cat([self._P, torch.eye(r, dtype=torch.long)], dim=1)

        # Big-endian weights that turn an r-bit syndrome into an integer index.
        self._syndrome_weight = torch.pow(2, torch.arange(r, 0, -1) - 1)

        # Syndrome -> data-bit correction table. A single-bit error at position
        # j yields the syndrome H[:, j]. Only data positions (j < k) need to be
        # flipped in the returned message, so every other syndrome (zero, or a
        # parity-bit error) maps to an all-zero row.
        self._C = torch.zeros([2 ** r, self._k], dtype=torch.long)
        for j in range(self._k):
            syndrome_idx = int((self._H[:, j] * self._syndrome_weight).sum())
            self._C[syndrome_idx, j] = 1

    def encode(self, X: torch.Tensor):
        """Encode groups of 4 data bits into 7-bit Hamming codewords.

        Args:
            X: integer tensor of 0/1 bits with ``len(X)`` rows; each row length
               should be a multiple of 4.

        Returns:
            Integer tensor with ``len(X)`` rows in which every 4 data bits are
            replaced by the codeword ``[data | parity]``.
        """
        code_num = len(X)
        data = X.reshape([-1, self._k])
        parity = (data @ self._P.T) % 2
        codewords = torch.cat([data, parity], dim=1)
        return codewords.reshape([code_num, -1])

    def decode(self, X: torch.Tensor):
        """Correct up to one bit error per codeword and return the data bits.

        Args:
            X: integer tensor of received 0/1 bits with ``len(X)`` rows; each
               row length should be a multiple of 7.

        Returns:
            ``torch.long`` tensor with ``len(X)`` rows holding the corrected
            4-bit data groups.
        """
        code_num = len(X)
        words = X.reshape([-1, self._n])
        syndrome = (words @ self._H.T) % 2
        # Read each syndrome as an integer to index the correction table.
        syndrome_idx = syndrome @ self._syndrome_weight
        corrected = torch.logical_xor(words[:, :self._k], self._C[syndrome_idx])
        return corrected.reshape([code_num, -1]).type(torch.long)


class PSK:
    """Gray-coded M-PSK modulator / demodulator on a sampled carrier.

    Every ``log2(m)`` bits (MSB first) form a symbol value ``v``. ``v`` is
    placed at the phase position ``p`` whose Gray code equals ``v``, at angle
    ``theta = (p - (m - 1) / 2) * 2*pi/m``, so adjacent phases differ in exactly
    one bit. A symbol is transmitted as ``res`` samples of
    ``cos(theta) * cos(wt) + sin(theta) * sin(wt)`` over one carrier period.
    """

    def __init__(self, m, res):
        """Build the constellation and the sampled carrier.

        Args:
            m: constellation size; any power of two >= 2 (2, 4, 8, 16, ...).
            res: number of samples per symbol (one carrier period). ``demod``
                needs ``res >= 3``; with fewer samples the sampled cos and sin
                carriers are not orthogonal and the phase cannot be recovered.

        Raises:
            ValueError: if ``m`` is not a power of two >= 2.
        """
        if m < 2 or (m & (m - 1)) != 0:
            raise ValueError(f"PSK order m must be a power of two >= 2, got {m}")
        self._m = m
        # Binary-reflected Gray code; for m = 2 / 4 / 8 this is [0, 1],
        # [0, 1, 3, 2] and [0, 1, 3, 2, 6, 7, 5, 4].
        self._code = torch.tensor([i ^ (i >> 1) for i in range(m)])
        code_len = int(math.log2(m))
        weight = torch.pow(2, torch.arange(code_len, 0, -1) - 1)
        # _code_bin[p] = bits (MSB first) of the Gray code at phase position p.
        self._code_bin = (self._code.unsqueeze(1) // weight) % 2
        # code_inv[v] = phase position p that carries symbol value v.
        code_inv = torch.zeros(m, dtype=torch.long)
        code_inv[self._code] = torch.arange(m)
        self._theta_0 = 2*torch.pi / m
        # Positions are centred on 0 rad, so every theta lies in (-pi, pi).
        theta = (code_inv - ((m-1) / 2)) * self._theta_0
        self._M = torch.stack([torch.cos(theta), torch.sin(theta)]).T
        self._res = res
        wt = torch.linspace(0, 2*torch.pi, res+1)[:-1]
        self._iq_wave = torch.stack([torch.cos(wt), torch.sin(wt)])

    def mod(self, X):
        """Modulate a 0/1 bit tensor into sampled PSK symbols.

        Args:
            X: integer 0/1 tensor with ``len(X)`` rows; each row length should
               be a multiple of ``log2(m)``.

        Returns:
            Float tensor with ``len(X)`` rows of ``res`` samples per symbol.
        """
        code_num = len(X)
        code_len = int(math.log2(self._m))
        weight = torch.pow(2, torch.arange(code_len, 0, -1)-1)
        # Pack every code_len bits (MSB first) into one symbol value.
        x = torch.matmul(X.reshape([-1, code_len]), weight)
        iq_amp = self._M[x]
        iq_signal = iq_amp @ self._iq_wave
        return iq_signal.reshape([code_num, -1])

    def demod(self, X):
        """Hard-decision demodulation of sampled PSK symbols back to bits.

        Args:
            X: float tensor with ``len(X)`` rows; each row length should be a
               multiple of ``res``.

        Returns:
            Integer 0/1 tensor with ``len(X)`` rows, ``log2(m)`` bits per symbol.
        """
        code_num = len(X)
        # Correlate each symbol with the cos / sin carriers -> (I, Q) rows.
        iq_amp = (X.reshape([-1, self._res]) @ self._iq_wave.T).T
        # Shift atan2's (-pi, pi] into (0, 2*pi] so that phase position p owns
        # the decision sector [p * theta_0, (p + 1) * theta_0).
        theta_hat = torch.atan2(iq_amp[1], iq_amp[0]) + torch.pi
        # atan2 returns exactly +pi for Q == +0 and I < 0, giving theta_hat ==
        # 2*pi and sector m; wrap it back into range instead of overflowing.
        sector = torch.floor(theta_hat / self._theta_0).long() % self._m
        code_bin_hat = self._code_bin[sector]
        return code_bin_hat.reshape([code_num, -1])


class QAM16:
    """Gray-coded 16-QAM modulator / demodulator on a sampled carrier.

    Every 4 bits (MSB first) form a symbol value ``v``. The point is taken from
    the square grid ``{-3, -1, 1, 3} x {-3, -1, 1, 3}``: the two high bits of
    ``v`` select the I level and the two low bits the Q level, each through the
    2-bit Gray code ``[0, 1, 3, 2]``, so horizontally or vertically adjacent
    points differ in exactly one bit. A symbol is transmitted as ``res``
    samples of ``I * cos(wt) + Q * sin(wt)`` over one carrier period.
    """

    def __init__(self, res):
        """Build the constellation and the sampled carrier.

        Args:
            res: number of samples per symbol (one carrier period). ``demod``
                needs ``res >= 3`` so the sampled cos and sin carriers are
                orthogonal.
        """
        self._m = 16
        code_len = int(math.log2(self._m))
        # Big-endian bit weights used to pack / unpack 4-bit symbol values.
        self._weight = torch.pow(2, torch.arange(code_len, 0, -1) - 1)
        # 16-symbol Gray code built from the 2-bit Gray code on each axis:
        # _code[4 * i + j] = 4 * gray[i] + gray[j].
        code = torch.tensor([0, 1, 3, 2])
        self._code = torch.cat([4 * row + code for row in code]).type(torch.long)
        code_inv = torch.zeros([16], dtype=torch.long)
        code_inv[self._code] = torch.arange(16)
        # Four evenly spaced amplitude levels per axis: [-3, -1, 1, 3].
        self._tick = torch.arange(-3.0, 4.0, 2.0)
        # Row 4 * i + j is the grid point (tick[i], tick[j]).
        xy_tick = torch.cartesian_prod(self._tick, self._tick)
        self._M = xy_tick[code_inv]
        self._res = res
        wt = torch.linspace(0, 2*torch.pi, res+1)[:-1]
        self._iq_wave = torch.stack([torch.cos(wt), torch.sin(wt)])

    def mod(self, X):
        """Modulate a 0/1 bit tensor into sampled 16-QAM symbols.

        Args:
            X: integer 0/1 tensor with ``len(X)`` rows; each row length should
               be a multiple of 4.

        Returns:
            Float tensor with ``len(X)`` rows of ``res`` samples per symbol.
        """
        code_num = len(X)
        # Pack every 4 bits (MSB first) into one symbol value.
        x = torch.matmul(X.reshape([-1, len(self._weight)]), self._weight)
        iq_amp = self._M[x]
        iq_signal = iq_amp @ self._iq_wave
        return iq_signal.reshape([code_num, -1])

    def demod(self, X):
        """Hard-decision demodulation of sampled 16-QAM symbols back to bits.

        Args:
            X: float tensor with ``len(X)`` rows; each row length should be a
               multiple of ``res``.

        Returns:
            Integer 0/1 tensor with ``len(X)`` rows, 4 bits per symbol.
        """
        code_num = len(X)
        # Project onto the carriers and divide by their energies so the result
        # is directly the transmitted (I, Q) amplitude pair.
        energy = (self._iq_wave ** 2).sum(dim=1)
        iq_amp = (X.reshape([-1, self._res]) @ self._iq_wave.T) / energy
        # Nearest amplitude level on each axis -> grid indices (i, j).
        level = (iq_amp.unsqueeze(-1) - self._tick).abs().argmin(dim=-1)
        position = 4 * level[:, 0] + level[:, 1]
        symbol = self._code[position]
        bits = (symbol.unsqueeze(1) // self._weight) % 2
        return bits.reshape([code_num, -1])
            

In [None]:
# Hamming(7, 4) round trip: encode random bits, flip one random bit per row,
# and check that the decoder corrects every error.
torch.manual_seed(0)  # reproducible demo
arr = torch.randint(0, 2, [32, 32])
print(f"input:{arr}")
hamming = Hamming()
arr_e = hamming.encode(arr)
print(f"encode:{arr_e}")
# Corrupt a copy so `arr_e` keeps the clean codewords. Each row holds
# 32 / 4 = 8 codewords, so one flip per row is at most one error per codeword.
arr_err = arr_e.clone()
rows = torch.arange(len(arr_err))
cols = torch.randint(0, arr_err.shape[1], [len(arr_err)])
arr_err[rows, cols] = 1 - arr_err[rows, cols]
print(f"err_code:{arr_err}")
arr_d = hamming.decode(arr_err)
print(f"accuracy:{(arr_d == arr).sum() / arr.numel()}")
assert torch.equal(arr_d, arr), "single-bit errors should all be corrected"


In [None]:
import torch
import matplotlib.pyplot as plt
# QPSK round trip on a noiseless channel: 8 rows of 4 symbols x 2 bits each,
# sampled at 8 points per carrier period.
torch.manual_seed(0)  # reproducible demo
psk = PSK(4, 8)
bits = torch.randint(0, 2, [8, 2*4])
print(f"bits:{bits}")
sig = psk.mod(bits)
print(f"sig:{sig}")
demod = psk.demod(sig)
print(f"demod:{demod}")
print(f"accuracy:{(demod == bits).sum() / bits.numel()}")
assert torch.equal(demod, bits), "noiseless PSK round trip should be lossless"