# Decoding with Belief Propagation

We want to iterate the following equations:

$$ \hat{m}_{\mu j} = \tanh ( \beta(\rho) J_\mu ) \prod_{ l \in {\cal L} (\mu) \backslash j } m_{\mu l}  \; \; ,  $$ 

$$ m_{\mu j} = \tanh \left(  \sum_{ \nu \in {\cal M} (j) \backslash \mu } \ \tanh^{-1} \left(  \hat{m}_{\nu j}  \right) + \beta(\rho_\xi)  \right)  \; \; ,  $$

which are Eqs.(96) from the paper [Low-density parity-check codes—A statistical physics perspective](https://www.sciencedirect.com/science/article/pii/S1076567002800180), by R. Vicente, D. Saad and Y. Kabashima.

The function $ \beta(x) $ is the Nishimori temperature associated with a flip probability $x$,

$$ \beta(x) = \frac{1}{2} \log \left(  \frac{1- x}{x}  \right) \; \; .  $$

The quantity $\rho$ is the flip probability of the noisy channel, a binary symmetric channel (BSC),

$$ P ( J | J^{(0)} ) = (1 - \rho) \delta_{J, J^{(0)} } + \rho \delta_{J, -J^{(0)} }   \; \; , $$

whereas the prior distribution for each message bit is assumed to be

$$ P ( S_j ) = (1 - \rho_\xi) \delta_{+1, S_j }  +  \rho_\xi \delta_{-1, S_j }  \; \; .  $$
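As a quick numerical illustration of the definition of $\beta(x)$, the sketch below evaluates it for the channel and checks the identity $\tanh \beta(x) = 1 - 2x$, which follows directly from the formula. The helper name `nishimori_beta` is ours, introduced only for illustration.

```python
import math

def nishimori_beta(x):
    """beta(x) = 0.5 * log((1 - x) / x), defined for 0 < x < 1."""
    return 0.5 * math.log((1.0 - x) / x)

rho = 0.3
beta_rho = nishimori_beta(rho)

# tanh(beta(x)) simplifies algebraically to 1 - 2x
print(beta_rho, math.tanh(beta_rho), 1 - 2 * rho)
```

In particular, the prefactor $\tanh(\beta(\rho) J_\mu)$ of the first update equation has magnitude $1 - 2\rho$ and carries the sign of $J_\mu$.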

The object $ {\cal L} (\mu)$ represents the set of column indices of the $K$ non-zero elements in row $\mu$ of the code generator matrix ${\cal G}$ (the one which adds redundancy),

$$  {\cal L} (\mu) = \langle i_1, i_2, ..., i_K \rangle  \; \; .  $$

There are $C$ non-zero elements per column of the matrix ${\cal G}$, that is, every message bit $j$ belongs to exactly $C$ index sets:

$$ \sum_{\mu : j \in {\cal L}(\mu)} 1 = C \; \; ; \; \; \forall j = 1, ..., N \; \; . $$

The object $ {\cal M} (j)$ represents the set of all index sets (rows $\mu$) that contain $j$.
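To make the two index objects concrete, here is a small sketch on a toy example with 4 index sets of size 2 over 3 bits, using 0-based indices. The tensor `index_sets` and its values are made up for illustration: row $\mu$ plays the role of ${\cal L}(\mu)$, and ${\cal M}(j)$ is read off as the rows that contain $j$.

```python
import torch

# Toy example (hypothetical values): 4 index sets of size K = 2 over N = 3 bits
index_sets = torch.tensor([[0, 1],
                           [1, 2],
                           [0, 2],
                           [0, 1]])

# L(mu): the indices stored in row mu
L_1 = index_sets[1]

# M(j): the rows whose index set contains j
j = 0
M_0 = torch.where((index_sets == j).any(dim=1))[0]

print(L_1)   # tensor([1, 2])
print(M_0)   # tensor([0, 2, 3])
```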

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

Defining the variables of the problem:

In [2]:
# Message length
N = 100

# Codeword length
M = 200

# Non-zero elements per row of the generator matrix
K = 4

# Number of messages
n = 10

# Noisy channel
p = 0.3
beta = 0.5*np.log( (1 - p) / p)

# Message prior
p_prior = 0.1
beta_prior = 0.5*np.log( (1 - p_prior) / p_prior)

## Generating messages

Each message is an $N$-dimensional vector with entries $\pm 1$. We generate a set of $n$ messages.

In [3]:
random = torch.rand([n, N])

message = torch.zeros([n, N])

for j in range(random.shape[0]):
    for k in range(random.shape[1]):

        # -1 with probability p_prior
        if random[j, k] <= p_prior:
            message[j, k] = -1.
        # +1 with probability 1 - p_prior
        else:
            message[j, k] = +1.

In [4]:
message.shape

torch.Size([10, 100])
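The double loop used to draw the messages can also be written as one vectorized expression. This is a sketch that keeps the same convention ($-1$ with probability `p_prior`) and builds the $\pm 1$ values as `1 - 2 * indicator`; the variable values are repeated here only so that the block runs on its own.

```python
import torch

n, N = 10, 100
p_prior = 0.1

random = torch.rand([n, N])

# The indicator is 1 where the uniform draw falls below p_prior, so the
# result is -1 there and +1 elsewhere
message = 1.0 - 2.0 * (random <= p_prior).float()

print(message.shape)
```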

Each message is encoded into a higher-dimensional vector ${\bf J}^{(0)} \in \{ \pm 1  \}^M$ defined as

$$   J^{(0)}_{\langle i_1, i_2, ..., i_K \rangle} = \xi_{i_1} \xi_{i_2} ... \xi_{i_K}  \; \; ,$$

where $\xi_i = \pm 1$ are the message bits and the $M$ index sets, each made of $K$ indices in $\{ 1, ..., N \}$, are chosen at random.

In [112]:
encoding = torch.randint(0, N, [M, K])

In [113]:
encoding.shape

torch.Size([200, 4])
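Note that `torch.randint` draws the $K$ indices of each set independently, so the same index can appear twice in one row. If distinct indices per set are wanted, one possible sketch (an assumption on our part, not something the construction above requires) is to keep the first $K$ entries of a random permutation of $0, ..., N-1$ for each row. The name `encoding_distinct` is hypothetical.

```python
import torch

N, M, K = 100, 200, 4

# Sorting i.i.d. uniform draws gives an independent random permutation per row;
# keeping the first K columns yields K distinct indices in [0, N)
encoding_distinct = torch.rand(M, N).argsort(dim=1)[:, :K]

print(encoding_distinct.shape)
```

This still does not enforce exactly $C$ non-zero elements per column; it only removes repeated indices inside a row.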

From `encoding`, we construct the encoded message ${\bf J}^{(0)}$.

In [7]:
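# Encode each message in turn. Note that J0 is overwritten at every iteration,
# so only the last codeword survives; the next cell stacks all of them.
for j in range(message.shape[0]):
    
    J0 = torch.take(message[j], encoding).prod(dim=1)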

In [8]:
# Initializing
J0 = torch.take(message[0], encoding).prod(dim=1)
J0 = J0.unsqueeze(0)

for j in range(1, message.shape[0]):
    
    J0_ = torch.take(message[j], encoding).prod(dim=1)
    J0_ = J0_.unsqueeze(0)
    
    J0 = torch.cat((J0, J0_), dim= 0)
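The stacking loop above can also be replaced by a single indexing operation. The sketch below regenerates small stand-ins for `message` and `encoding` so that it runs on its own, and compares the result with the row-by-row `torch.take` construction (called `J0_loop` here, a name introduced for the comparison).

```python
import torch

n, N, M, K = 10, 100, 200, 4
message = 1.0 - 2.0 * (torch.rand(n, N) <= 0.1).float()
encoding = torch.randint(0, N, [M, K])

# message[:, encoding] has shape [n, M, K]: for every message,
# the K bits selected by each of the M index sets
J0 = message[:, encoding].prod(dim=2)

# Reference: the row-by-row construction
J0_loop = torch.stack([torch.take(message[j], encoding).prod(dim=1) for j in range(n)])

print(J0.shape, torch.equal(J0, J0_loop))
```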

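Now we build the corrupted version ${\bf J}$: each entry of ${\bf J}^{(0)}$ is flipped independently with probability $\rho$ (`p` in the code).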

In [9]:
J = J0.clone()

In [10]:
random = torch.rand(J.shape)
                      
for j in range(J.shape[0]):
    for k in range(J.shape[1]):
          
        if random[j, k] <= p:
            J[j, k] = -J[j, k]
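The channel loop above can be sketched without explicit loops by building a boolean flip mask. The stand-in `J0` below is random $\pm 1$ data used only to make the block self-contained.

```python
import torch

p = 0.3
J0 = 1.0 - 2.0 * (torch.rand(10, 200) < 0.5).float()  # stand-in codewords with entries +-1

# Flip mask: True with probability p, independently for each entry
flip = torch.rand(J0.shape) <= p

# Negate the flipped entries and leave the others unchanged
J = torch.where(flip, -J0, J0)

# Empirical flip rate, which should be close to p
print((J != J0).float().mean())
```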

Let us focus on one received message to iterate the belief propagation equations.

$$ \hat{m}_{\mu j} = \tanh ( \beta(\rho) J_\mu ) \prod_{ l \in {\cal L} (\mu) \backslash j } m_{\mu l}  \; \; ,  $$ 

$$ m_{\mu j} = \tanh \left(  \sum_{ \nu \in {\cal M} (j) \backslash \mu} \ \tanh^{-1} \left(  \hat{m}_{\nu j}  \right) + \beta(\rho_\xi)  \right)  \; \; ,  $$

with $j = 1, ..., N$ and $\mu = 1, ..., M$.

We call this message `J_`. We will worry later about a loop over all the received messages.

In [11]:
J_ = J[0]
print(J_)
print(J_.shape)

tensor([ 1.,  1., -1., -1., -1.,  1.,  1., -1.,  1.,  1.,  1., -1.,  1.,  1.,
         1.,  1., -1., -1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,
        -1., -1., -1., -1., -1.,  1.,  1.,  1., -1.,  1., -1., -1., -1.,  1.,
         1.,  1., -1., -1.,  1., -1., -1., -1.,  1.,  1., -1.,  1.,  1., -1.,
         1.,  1., -1., -1.,  1.,  1., -1.,  1., -1., -1., -1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1., -1.,  1., -1., -1., -1., -1., -1., -1., -1.,  1.,
         1.,  1., -1.,  1.,  1.,  1.,  1., -1., -1.,  1.,  1., -1., -1., -1.,
         1.,  1., -1., -1.,  1.,  1.,  1.,  1., -1.,  1., -1.,  1., -1.,  1.,
         1.,  1.,  1., -1.,  1., -1.,  1.,  1.,  1., -1.,  1.,  1., -1., -1.,
         1., -1., -1.,  1., -1.,  1., -1.,  1., -1.,  1.,  1., -1.,  1., -1.,
         1., -1.,  1., -1.,  1., -1., -1.,  1.,  1., -1.,  1., -1.,  1.,  1.,
        -1., -1., -1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1., -1.,  1., -1., -1., -1., -1.,  1., -1., -1., -1., -

Random initialization of the beliefs $m_{\mu l}$, drawn uniformly from $[0, 1)$.

In [153]:
m = torch.rand(M, N)

In [154]:
m

tensor([[0.5737, 0.3259, 0.1629,  ..., 0.5391, 0.5564, 0.4237],
        [0.8206, 0.2177, 0.8888,  ..., 0.0783, 0.1224, 0.3776],
        [0.6112, 0.9259, 0.9299,  ..., 0.8411, 0.8729, 0.6649],
        ...,
        [0.5682, 0.4077, 0.6779,  ..., 0.6039, 0.6576, 0.7261],
        [0.2614, 0.4347, 0.6711,  ..., 0.1953, 0.9119, 0.3932],
        [0.4544, 0.2197, 0.0443,  ..., 0.8243, 0.3393, 0.3847]])

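Allocate a tensor to hold $\hat{m}_{\mu l}$. Since `torch.empty` leaves the memory uninitialized, the zeros displayed below are incidental; every entry is overwritten by the update that follows.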

In [155]:
m_hat = torch.empty(M, N)

In [156]:
m_hat

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

We want to calculate $\hat{m}_{\mu j}$.

$$ \hat{m}_{\mu j} = \tanh ( \beta(\rho) J_\mu ) \prod_{ l \in {\cal L} (\mu) \backslash j } m_{\mu l}  \; \; ,  $$ 

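This first implementation uses two nested `for` loops, which is slow for large $M$ and $N$. Efficiency matters, but since we are just beginning, let's go on like this. Note that the loop also fills the entries with $j \notin {\cal L}(\mu)$, for which nothing is excluded from the product; those entries are never read by the second equation, which only sums over $\nu \in {\cal M}(j)$.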

In [157]:
for mu in range(M):
    for j in range(N):

        # Keep only the elements of L(mu) that are different from j
        # (a boolean mask always returns a 1-D tensor)
        L_no_j = encoding[mu][encoding[mu] != j]

        # Message update: tanh(beta * J_mu) times the product of the remaining beliefs
        m_hat[mu, j] = torch.tanh(beta * J_[mu]) * torch.take(m[mu], L_no_j).prod(dim=0)

In [158]:
m_hat

tensor([[ 5.2171e-06,  5.2171e-06,  5.2171e-06,  ...,  5.2171e-06,
          5.2171e-06,  5.2171e-06],
        [ 2.0069e-02,  2.0069e-02,  2.0069e-02,  ...,  2.0069e-02,
          2.0069e-02,  2.0069e-02],
        [-2.1374e-02, -2.1374e-02, -2.1374e-02,  ..., -2.1374e-02,
         -2.1374e-02, -2.1374e-02],
        ...,
        [ 1.1966e-03,  1.1966e-03,  1.1966e-03,  ...,  1.1966e-03,
          1.1966e-03,  1.1966e-03],
        [ 2.8509e-03,  2.8509e-03,  2.8509e-03,  ...,  2.8509e-03,
          2.8509e-03,  2.8509e-03],
        [ 1.1827e-02,  1.1827e-02,  1.1827e-02,  ...,  1.4347e-02,
          1.1827e-02,  1.1827e-02]])
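The same update can be sketched without Python loops if the messages are stored per edge, that is, one value for each of the $K$ slots of every index set (shape `[M, K]`) instead of a dense `[M, N]` tensor. The names `vals`, `same`, `others` and `m_hat_edges` are ours, and the small sizes are chosen only so that the comparison with the loop version runs quickly. Slots holding the same index as the excluded one are replaced by 1, which mirrors what the mask `encoding[mu] != j` does.

```python
import math
import torch

torch.manual_seed(0)
M, N, K = 20, 10, 3
beta = 0.5 * math.log((1 - 0.3) / 0.3)

encoding = torch.randint(0, N, [M, K])
J_ = 1.0 - 2.0 * (torch.rand(M) < 0.5).float()
m = torch.rand(M, N)

# Beliefs attached to the K slots of every index set: vals[mu, k] = m[mu, encoding[mu, k]]
vals = m.gather(1, encoding)

# same[mu, k, k2] is True when slots k and k2 of row mu hold the same index
same = encoding.unsqueeze(2) == encoding.unsqueeze(1)

# Replace the excluded factors by 1, then multiply over the remaining slots
others = torch.where(same, torch.ones(M, K, K), vals.unsqueeze(1).expand(M, K, K))
m_hat_edges = torch.tanh(beta * J_).unsqueeze(1) * others.prod(dim=2)

# Reference: the loop version, evaluated only for j in L(mu)
ref = torch.empty(M, K)
for mu in range(M):
    for k in range(K):
        j = encoding[mu, k]
        L_no_j = encoding[mu][encoding[mu] != j]
        ref[mu, k] = torch.tanh(beta * J_[mu]) * torch.take(m[mu], L_no_j).prod(dim=0)

print(torch.allclose(m_hat_edges, ref, atol=1e-6))
```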

The next step is to implement:

$$ m_{\mu j} = \tanh \left(  \sum_{ \nu \in {\cal M} (j) \backslash \mu} \ \tanh^{-1} \left(  \hat{m}_{\nu j}  \right) + \beta(\rho_\xi)  \right)  \; \; ,  $$

In [159]:
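for j in range(N):

    # M(j): rows (index sets) that contain j; this depends only on j
    M_set = torch.where(encoding == j)[0]

    # tanh^{-1} of the incoming messages for bit j
    atanh_m_hat = torch.atanh(m_hat[:, j])

    for mu in range(M):

        # Exclude mu from M(j)
        M_set_no_mu = M_set[M_set != mu]

        # Message update: the prior field is added inside the tanh
        m[mu, j] = torch.tanh(torch.take(atanh_m_hat, M_set_no_mu).sum() + beta_prior)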

In [160]:
m

tensor([[0.8080, 0.6590, 0.7828,  ..., 0.9041, 0.6813, 0.7584],
        [0.8080, 0.6590, 0.7828,  ..., 0.9041, 0.6813, 0.7584],
        [0.8080, 0.6590, 0.7828,  ..., 0.9041, 0.6813, 0.7584],
        ...,
        [0.8080, 0.6590, 0.7828,  ..., 0.9041, 0.6813, 0.7584],
        [0.8080, 0.6590, 0.7828,  ..., 0.9041, 0.6813, 0.7584],
        [0.8080, 0.6590, 0.7828,  ..., 0.9014, 0.6813, 0.7584]])
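This second update can also be sketched in vectorized form: sum $\tanh^{-1}(\hat{m}_{\nu j})$ over all of ${\cal M}(j)$ once, then subtract the contribution of $\mu$ itself. The names `count`, `total`, `m_vec` and `m_loop` are ours, and `m_hat` is random stand-in data. A row that contains $j$ twice is counted twice, which is also what `torch.where(encoding == j)[0]` does in the loop version.

```python
import math
import torch

torch.manual_seed(0)
M, N, K = 20, 10, 3
beta_prior = 0.5 * math.log((1 - 0.1) / 0.1)

encoding = torch.randint(0, N, [M, K])
m_hat = 0.8 * (2 * torch.rand(M, N) - 1)   # stand-in messages in (-0.8, 0.8)

# count[mu, j] = number of times j appears in row mu of encoding
count = torch.zeros(M, N).scatter_add_(1, encoding, torch.ones(M, K))

# total[j] = sum over nu in M(j) of atanh(m_hat[nu, j])
a = torch.atanh(m_hat)
total = (count * a).sum(dim=0)

# Remove the contribution of mu itself, add the prior field, apply tanh
m_vec = torch.tanh(total.unsqueeze(0) - count * a + beta_prior)

# Reference: the loop version
m_loop = torch.empty(M, N)
for j in range(N):
    M_set = torch.where(encoding == j)[0]
    for mu in range(M):
        M_set_no_mu = M_set[M_set != mu]
        m_loop[mu, j] = torch.tanh(torch.take(a[:, j], M_set_no_mu).sum() + beta_prior)

print(torch.allclose(m_vec, m_loop, atol=1e-5))
```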

In [117]:
torch.where(encoding == 12)[0]

tensor([ 16,  18,  38,  51,  59, 110, 142, 146, 149, 168, 169, 175, 185, 190])

In [143]:
m_hat0 = m[:, 0]
print(m_hat0)
print(m_hat0.shape)

tensor([-1.9573e+00,  1.3855e+00,  1.9028e-01, -1.7844e+00,  5.0827e-01,
         3.8991e-01, -3.3806e-02,  6.4215e-01, -1.0100e-01,  5.0615e-01,
        -3.0457e-01,  4.5092e-01,  6.5737e-01, -2.0985e+00,  8.0717e-01,
         1.1149e+00, -6.9266e-02,  7.7242e-02, -1.3143e+00, -1.1171e+00,
         3.2288e-01, -6.8386e-01, -8.5211e-01, -9.8985e-01,  1.2180e+00,
        -4.7351e-01, -1.2878e-01, -6.0639e-01, -2.6454e-01, -2.1705e+00,
        -1.4315e+00, -2.4212e+00,  6.0035e-01,  1.3754e+00,  6.0277e-01,
        -1.8355e+00,  4.7591e-01,  1.2254e+00,  2.3661e+00, -7.6653e-01,
        -1.5692e+00, -1.0605e+00, -6.7472e-01,  4.5668e-01, -1.6376e+00,
        -3.0720e-01, -5.6713e-01, -1.0771e-01, -1.8200e+00,  5.5905e-01,
         7.9153e-01,  2.1574e+00, -1.1527e+00,  6.0929e-01,  2.0709e+00,
        -6.2848e-01,  1.1242e+00,  6.9097e-01, -9.8604e-01,  3.9571e-01,
        -2.1559e+00,  1.5495e+00,  6.5942e-01, -6.2958e-01,  1.8700e+00,
         6.6519e-01,  1.1562e+00, -2.7643e-01, -6.4

In [127]:
M_ = torch.where(encoding == 0)[0]
M_

tensor([ 19,  25,  41,  77,  81, 110, 118, 135, 135, 142, 177, 192])

In [129]:
torch.nonzero(M_ != 19).squeeze()

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

In [130]:
M_[torch.nonzero(M_ != 19).squeeze()]

tensor([ 25,  41,  77,  81, 110, 118, 135, 135, 142, 177, 192])

In [139]:
M_set_no = M_[torch.nonzero(M_ != 19).squeeze()]

In [140]:
M_set_no

tensor([ 25,  41,  77,  81, 110, 118, 135, 135, 142, 177, 192])

In [146]:
torch.take(m_hat[:, 0], M_set_no_mu).sum()

tensor(1.0440)

In [70]:
prod_ind

tensor([0, 1, 3])

In [72]:
L_less = encoding[0][prod_ind]

In [73]:
L_less

tensor([ 5, 98, 58])

In [76]:
print(m[0][5])
print(m[0][98])
print(m[0][58])

tensor(0.6863)
tensor(0.4078)
tensor(1.2403)


In [78]:
torch.take(m[0], L_less).prod(dim=0)

tensor(0.3471)

In [68]:
torch.nonzero(encoding[0] != 60).squeeze().shape

torch.Size([3])

In [57]:
encoding[0][2]

tensor(60)

In [58]:
encoding[0][torch.nonzero(encoding[0] != 60)]

tensor([[ 5],
        [98],
        [58]])

In [60]:
encoding[0][torch.nonzero(encoding[0] != 60)].shape

torch.Size([3, 1])

In [26]:
T

tensor([1, 2, 4, 5])

In [None]:
print(encoding[0])
print(m[0].shape)

In [23]:
encoding[23]

tensor([66, 31, 39, 22])

In [16]:
prod_full = torch.take(m[0], encoding).prod(dim=1)

In [17]:
encoding == 0

tensor([[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [Fal

In [18]:
torch.nonzero(encoding == 0)

tensor([[ 28,   2],
        [ 91,   2],
        [119,   3],
        [130,   0],
        [168,   1],
        [180,   2],
        [191,   2]])

In [145]:
m_aux = m.clone()

In [120]:
m_aux

tensor([[-0.0793, -0.6300, -0.3786,  ..., -0.2133,  2.8556, -1.1283],
        [ 0.9498, -0.0605, -1.5593,  ..., -0.4778,  0.6944, -0.0537],
        [-0.9065,  1.2505, -0.1733,  ..., -1.1476, -1.1839,  0.6398],
        ...,
        [ 1.5321, -0.4942, -0.8504,  ..., -1.7777,  0.4258,  0.4261],
        [ 0.8070,  0.0048, -0.0686,  ..., -0.1955,  1.2301,  0.4021],
        [-0.1257, -2.2173,  1.3299,  ...,  0.8645, -0.8442,  0.8819]])

In [158]:
m_aux.shape

torch.Size([200, 100])

In [121]:
exclude = torch.nonzero(encoding == 0)

In [122]:
exclude

tensor([[ 43,   1],
        [ 56,   3],
        [ 79,   0],
        [ 79,   3],
        [ 82,   1],
        [ 96,   3],
        [137,   0],
        [152,   0],
        [176,   1],
        [187,   1]])

In [162]:
for j in range(1):
    
    exclude = torch.nonzero(encoding == j)
    
    for k in range(exclude.shape[0]):
        
        exc = exclude[k]
        
        
        z = m_aux[exc[0]].scatter_(0, exc[1:], 1000)
    

In [163]:
z

tensor([ 1.0000e+03,  1.0000e+03,  1.0000e+03,  1.0000e+03,  3.4074e-01,
        -2.6404e-01, -9.2077e-01, -1.4677e-02,  1.0955e+00, -5.1544e-01,
        -8.7062e-01,  7.6919e-01,  2.8699e-01,  9.6886e-01,  1.7844e-02,
        -6.2256e-01, -1.1709e+00,  7.6046e-01,  5.5965e-01, -1.1070e-01,
         6.0074e-01, -4.8264e-01, -1.9182e-01, -3.7910e-01, -5.9526e-01,
        -9.5433e-02, -3.3138e-01,  6.0327e-01, -1.1608e+00, -7.3140e-01,
         1.6031e-02,  9.5363e-01,  1.4151e+00,  8.6919e-01,  1.0650e+00,
         3.5212e-01,  8.8221e-01,  8.4107e-01,  1.1909e+00, -2.7354e-01,
         7.8464e-01,  6.9084e-01,  1.0274e-01, -3.2291e-01, -8.9670e-01,
         3.9512e-01,  2.0630e-02,  7.4143e-01, -1.1805e+00, -1.9506e+00,
         9.7839e-01,  1.3199e+00, -2.1029e-01,  9.7971e-01, -3.9878e-02,
         9.6507e-01,  2.9632e-01, -1.2175e+00,  1.4639e+00,  2.0341e-01,
        -1.4287e+00, -1.2803e-01, -1.4928e+00,  2.0406e+00,  2.4494e-01,
        -7.2996e-01,  3.7852e-01,  1.3450e+00,  6.3

In [156]:
exclude.shape[0]

10

In [157]:
exclude[5]

tensor([96,  3])

In [126]:
exclude.shape

torch.Size([10, 2])

In [142]:
exclude[0][0]

tensor(43)

In [149]:
m_aux[[[exclude[0]]]]

ValueError: only one element tensors can be converted to Python scalars

In [141]:
m_aux[[43,1]]

tensor([[ 0.0302, -0.1636, -1.7418, -0.1757,  0.3791, -0.8397, -0.7856,  1.5271,
         -0.4584,  0.1491,  1.0438,  0.4929,  0.1003,  1.3059,  0.4218,  1.0057,
         -1.2351, -0.1522,  0.9724,  0.4100,  0.3694,  0.3277, -0.5289, -0.4909,
          0.3446, -1.8066,  1.0155,  0.8825, -0.2777, -1.1976,  0.2885,  0.2092,
         -0.8086, -0.3490, -0.1686,  0.1132, -0.7963,  0.5539, -0.6405,  0.7526,
         -1.3016, -0.1355,  0.1975, -1.7613, -0.6621, -1.1961, -0.0089,  0.2253,
          0.0095, -0.6792, -1.4876, -0.4274,  0.0210,  0.0621, -1.9757, -0.9194,
         -0.9840,  0.9646, -0.4086,  0.4658,  1.1966,  0.6169,  0.8855, -0.6732,
          0.3679,  2.0794,  0.0585, -0.2454, -0.5907,  0.7409, -1.5762, -1.0074,
          0.0781,  0.7580,  0.6308,  0.2033, -0.8695, -0.3131, -0.1380,  1.5031,
         -0.4139,  0.3099,  0.8547, -2.8883, -0.2570,  0.2521, -0.8031, -0.0739,
         -1.1229, -0.3560,  0.5832,  2.4520, -1.0634,  1.5417,  0.0282,  1.9118,
          1.1215,  0.4459,  

In [111]:
m[0]

tensor([-7.9326e-02, -6.2997e-01, -3.7863e-01,  2.9347e-01, -4.6815e-01,
        -1.2414e-03, -3.9178e-02,  1.1563e+00,  1.4654e+00, -2.7126e-01,
        -1.8753e+00,  7.5393e-01,  1.6476e-01, -1.5836e+00,  5.6198e-01,
         2.5038e-01,  3.1218e-01,  1.0094e+00,  1.2251e+00,  1.9295e-01,
        -1.5820e-01,  8.2826e-01, -1.5143e+00,  3.7592e-01, -1.2277e-01,
        -1.4465e+00, -1.3537e+00,  7.7615e-01, -5.9863e-01,  2.5851e-03,
        -6.4988e-01, -5.5231e-01,  1.6726e-02,  5.1850e-01,  8.4310e-01,
        -9.1750e-01,  1.5605e-01, -8.0841e-01, -1.0547e+00,  6.9209e-01,
         3.0936e-01,  5.4713e-01, -9.7568e-01, -1.7187e+00, -1.5435e+00,
         4.0454e-02, -7.6023e-01, -2.3817e-01,  1.1795e+00,  1.0296e+00,
         2.0394e-01,  1.1942e+00,  8.0457e-01,  3.0606e-01, -6.8432e-01,
        -4.9388e-03, -8.1257e-01, -1.9065e-01,  6.7450e-02,  7.8054e-01,
         1.5632e-01,  6.7083e-01, -1.8288e+00, -1.3203e+00, -2.0107e+00,
         5.3763e-01, -4.7031e-01,  1.3088e+00, -7.3

In [None]:
m[0, k]

In [108]:
torch.take(m[0], encoding).shape

torch.Size([200, 4])

In [109]:
torch.take(m[0], encoding)

tensor([[ 5.1850e-01, -2.7126e-01,  1.6726e-02, -2.3817e-01],
        [-9.3843e-02,  3.5888e-01,  1.1795e+00, -1.0492e-01],
        [ 5.1850e-01, -9.7568e-01,  8.0457e-01,  6.9209e-01],
        [-7.6023e-01,  9.8000e-01, -1.4465e+00,  1.1942e+00],
        [-1.8753e+00,  3.7592e-01, -1.2233e+00, -7.6023e-01],
        [ 3.0936e-01,  1.1795e+00, -1.3537e+00,  5.5517e-01],
        [-8.1257e-01,  5.3763e-01, -1.6398e-01,  7.5393e-01],
        [-1.7768e+00, -1.3537e+00,  1.2597e-01,  3.7592e-01],
        [-3.9178e-02,  2.9347e-01, -1.8288e+00,  1.0094e+00],
        [ 1.2251e+00,  1.0296e+00,  1.1563e+00,  1.6726e-02],
        [ 3.0936e-01,  1.1563e+00,  7.7615e-01, -1.7768e+00],
        [-9.7568e-01, -1.3203e+00,  1.6476e-01,  1.5530e+00],
        [ 1.3088e+00,  5.5517e-01, -1.2414e-03,  2.5038e-01],
        [-5.5231e-01,  1.0374e+00, -1.2233e+00, -1.2233e+00],
        [-1.6398e-01, -8.0841e-01, -9.3843e-02,  1.2597e-01],
        [-7.3765e-02, -2.7126e-01,  1.4654e+00, -2.0531e-01],
        

In [106]:
encoding

tensor([[33,  9, 32, 47],
        [79, 80, 48, 96],
        [33, 42, 52, 39],
        [46, 71, 25, 51],
        [10, 23, 72, 46],
        [40, 48, 26, 94],
        [56, 65, 75, 11],
        [84, 26, 82, 23],
        [ 6,  3, 62, 17],
        [18, 49,  7, 32],
        [40,  7, 27, 84],
        [42, 63, 12, 83],
        [67, 94,  5, 15],
        [31, 89, 72, 72],
        [75, 37, 79, 82],
        [68,  9,  8, 86],
        [72, 86, 40, 84],
        [95, 60, 30, 14],
        [22, 42,  9, 15],
        [51, 85, 20, 76],
        [92, 14, 12, 78],
        [16, 80, 50, 41],
        [19, 42, 77, 80],
        [62,  4, 77, 42],
        [34, 77, 94, 88],
        [67, 85, 62, 45],
        [37, 86, 88, 39],
        [54, 47, 13, 67],
        [38, 89, 17, 50],
        [86, 47, 96, 23],
        [64, 93, 46,  4],
        [67,  5,  8, 80],
        [ 5, 14, 74, 64],
        [90, 68, 19, 65],
        [22,  3, 68, 89],
        [39, 68, 43, 37],
        [55, 21, 67, 13],
        [69, 46, 51, 91],
        [20,

In [93]:
torch.tanh(beta*J_)

tensor([-0.4000, -0.4000,  0.4000,  0.4000,  0.4000,  0.4000, -0.4000, -0.4000,
         0.4000, -0.4000,  0.4000,  0.4000, -0.4000, -0.4000,  0.4000, -0.4000,
         0.4000, -0.4000, -0.4000, -0.4000,  0.4000,  0.4000,  0.4000,  0.4000,
        -0.4000,  0.4000,  0.4000,  0.4000,  0.4000,  0.4000,  0.4000, -0.4000,
         0.4000, -0.4000, -0.4000, -0.4000,  0.4000,  0.4000,  0.4000,  0.4000,
        -0.4000,  0.4000,  0.4000, -0.4000,  0.4000, -0.4000,  0.4000, -0.4000,
        -0.4000,  0.4000,  0.4000, -0.4000, -0.4000, -0.4000,  0.4000,  0.4000,
         0.4000, -0.4000, -0.4000,  0.4000,  0.4000,  0.4000,  0.4000, -0.4000,
        -0.4000,  0.4000,  0.4000, -0.4000,  0.4000,  0.4000, -0.4000,  0.4000,
        -0.4000,  0.4000,  0.4000,  0.4000, -0.4000, -0.4000,  0.4000, -0.4000,
         0.4000, -0.4000,  0.4000,  0.4000,  0.4000,  0.4000,  0.4000, -0.4000,
         0.4000,  0.4000,  0.4000,  0.4000, -0.4000,  0.4000,  0.4000, -0.4000,
         0.4000, -0.4000,  0.4000, -0.40

In [98]:
# Full product over l in L(mu) of the beliefs in row 0 of m (j not yet excluded)
torch.take(m[0], encoding).prod(dim=1)

tensor([ 5.6028e-04,  4.1675e-03, -2.8169e-01,  1.2870e+00, -6.5559e-01,
        -2.7422e-01,  5.4010e-02,  1.1390e-01,  2.1225e-02,  2.4396e-02,
        -4.9330e-01,  3.2960e-01, -2.2585e-04, -8.5741e-01, -1.5671e-03,
        -6.0202e-03, -1.3805e-01,  5.1098e-03, -1.0035e-01, -1.1268e-02,
         3.0792e-02,  1.2501e-02, -1.6638e-02, -2.0572e-01,  3.9136e-02,
        -8.6092e-02,  3.9000e-02, -3.3781e-01, -2.2525e-01, -1.9286e-03,
        -9.1374e-01, -8.5449e-04, -9.6634e-04, -1.4092e-03,  3.4008e-02,
        -7.0933e-02,  8.4785e-03,  2.5474e+00,  4.1546e-03, -1.5749e-01,
         2.1376e-04,  4.7780e-03,  3.4120e-03,  6.4099e-03, -1.1395e-04,
         3.1736e-02, -2.4003e-01,  3.0513e-01, -5.4056e-02, -2.6736e-02,
         5.5885e-02,  4.2292e-01, -2.8587e-01,  9.4316e-05, -5.8276e-02,
         2.3085e-02,  2.1555e-02,  6.0974e-02, -9.8555e-03, -6.1514e-05,
        -6.6597e-01, -1.3281e+00,  8.4027e-04,  4.0690e-03,  4.8479e-01,
         1.5165e-03,  7.3134e-03,  9.1457e-02, -5.1

In [99]:
torch.take(m[0], encoding).prod(dim=1).shape

torch.Size([200])
