# CCS-AMP for unsourced multiple access

This notebook contains a CCS-AMP encoder and decoder for unsourced multiple access, using Hadamard-based design matrices.

The code is based on the following articles:

- A coded compressed sensing scheme for uncoordinated multiple access, available at https://arxiv.org/pdf/1809.04745.pdf
- SPARCs for Unsourced Random Access, available at https://arxiv.org/abs/1901.06234


## Requirements

We only need a recent version of NumPy.

In [1]:
import numpy as np

In [2]:
def fht(u):
    """
    Perform fast Hadamard transform of u, in-place.
    Note len(u) must be a power of two.
    """
    N = len(u)
    i = N>>1
    while i:
        for j in range(N):
            if (i&j) == 0:
                temp = u[j]
                u[j] += u[i|j]
                u[i|j] = temp - u[i|j]
        i>>= 1

def sub_fht(n, m, seed=0, ordering=None, new_embedding=False):
    """
    Returns functions to compute the sub-sampled Walsh-Hadamard transform,
    i.e., operating with a wide rectangular matrix of random +/-1 entries.

    n: number of rows
    m: number of columns

    It is most efficient (but not required) for max(m,n+1) to be a power of 2.

    seed: determines choice of random matrix
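    ordering: optional n-long array of row indices in [1, w-1] to implement
              subsampling, where w is the power of 2 used for the transform;
              generated by seed if not specified, but may be given to speed up
              subsequent runs on the same matrix.
    new_embedding: if True, place x in the last m of the w entries instead of
              the first m, so the all-ones column 0 of the Hadamard matrix is
              never used (w is then sized from max(m+1, n+1)).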

    Returns (Ax, Ay, ordering):
        Ax(x): computes A.x (of length n), with x having length m
        Ay(y): computes A'.y (of length m), with y having length n
        ordering: the ordering in use, which may have been generated from seed
    """
    assert n > 0, "n must be positive"
    assert m > 0, "m must be positive"
    if new_embedding:
        w = 2**int(np.ceil(np.log2(max(m+1, n+1))))
    else:
        w = 2**int(np.ceil(np.log2(max(m, n+1))))

    if ordering is not None:
        assert ordering.shape == (n,)
    else:
        rng = np.random.RandomState(seed)
        idxs = np.arange(1, w, dtype=np.uint32)
        rng.shuffle(idxs)
        ordering = idxs[:n]

    def Ax(x):
        assert x.size == m, "x must be m long"
        y = np.zeros(w)
        if new_embedding:
            y[w-m:] = x.reshape(m)
        else:
            y[:m] = x.reshape(m)
        fht(y)
        return y[ordering]

    def Ay(y):
        assert y.size == n, "input must be n long"
        x = np.zeros(w)
        x[ordering] = y.reshape(n)
        fht(x)
        if new_embedding:
            return x[w-m:]
        else:
            return x[:m]

    return Ax, Ay, ordering

def block_sub_fht(n, m, l, seed=0, ordering=None, new_embedding=False):
    """
    As `sub_fht`, but computes in `l` blocks of size `n` by `m`, potentially
    offering substantial speed improvements.

    n: number of rows
    m: number of columns per block
    l: number of blocks

    It is most efficient (though not required) when max(m,n+1) is a power of 2.

    seed: determines choice of random matrix
    ordering: optional (l, n) shaped array of row indices in [1, w-1] to
              implement subsampling, where w is the power of 2 used for the
              transform; generated by seed if not specified, but may be given
              to speed up subsequent runs on the same matrix.
    new_embedding: as in `sub_fht`.

    Returns (Ax, Ay, ordering):
        Ax(x): computes A.x (of length n), with x having length l*m
        Ay(y): computes A'.y (of length l*m), with y having length n
        ordering: the ordering in use, which may have been generated from seed
    """
    assert n > 0, "n must be positive"
    assert m > 0, "m must be positive"
    assert l > 0, "l must be positive"

    if ordering is not None:
        assert ordering.shape == (l, n)
    else:
        if new_embedding:
            w = 2**int(np.ceil(np.log2(max(m+1, n+1))))
        else:
            w = 2**int(np.ceil(np.log2(max(m, n+1))))
        rng = np.random.RandomState(seed)
        ordering = np.empty((l, n), dtype=np.uint32)
        idxs = np.arange(1, w, dtype=np.uint32)
        for ll in range(l):
            rng.shuffle(idxs)
            ordering[ll] = idxs[:n]

    def Ax(x):
        assert x.size == l*m
        out = np.zeros(n)
        for ll in range(l):
            ax, ay, _ = sub_fht(n, m, ordering=ordering[ll],
                                new_embedding=new_embedding)
            out += ax(x[ll*m:(ll+1)*m])
        return out

    def Ay(y):
        assert y.size == n
        out = np.empty(l*m)
        for ll in range(l):
            ax, ay, _ = sub_fht(n, m, ordering=ordering[ll],
                                new_embedding=new_embedding)
            out[ll*m:(ll+1)*m] = ay(y)
        return out

    return Ax, Ay, ordering

## Fast Hadamard Transforms

The functions in the cell above (`fht`, `sub_fht`, `block_sub_fht`) can all be found in `pyfht`, which uses a C extension to speed up `fht`. To keep this notebook self-contained, they are reproduced here in pure Python, which is quite slow.

If you're not interested in the specific transform implementation, you can skip that cell.
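As a quick sanity check, the sketch below compares `fht` against an explicit Hadamard matrix built with the Sylvester recursion $H_{2N} = \begin{pmatrix} H_N & H_N \\ H_N & -H_N \end{pmatrix}$. The helper `sylvester_hadamard` is hypothetical (it is not part of `pyfht`), and `fht` is repeated so the block runs on its own.

```python
import numpy as np

def fht(u):
    """In-place fast Walsh-Hadamard transform (same butterfly as the cell above)."""
    N = len(u)
    i = N >> 1
    while i:
        for j in range(N):
            if (i & j) == 0:
                temp = u[j]
                u[j] += u[i | j]
                u[i | j] = temp - u[i | j]
        i >>= 1

def sylvester_hadamard(N):
    """Dense N x N Hadamard matrix (N a power of 2) via the Sylvester recursion."""
    H = np.array([[1.0]])
    while H.shape[0] < N:
        H = np.block([[H, H], [H, -H]])
    return H

rng = np.random.default_rng(0)
u = rng.standard_normal(16)
v = u.copy()
fht(v)                      # v now holds H @ u, computed in O(N log N)
H = sylvester_hadamard(16)  # explicit O(N^2) reference
print(np.allclose(v, H @ u))
```

Since the Sylvester matrix is symmetric, the same routine also computes products with $H^\top$; this is why `Ay` in `sub_fht` can simply call `fht` again.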

In [3]:
# If you have pyfht installed (via pip etc) you can import it here
# to use its C-accelerated transform
#from pyfht import block_sub_fht

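## Parity Generator Matrix
This function builds the random binary parity generator matrix for the outer tree code, stored as one block per section $i = 1, \dots, L-1$. The block for section $i$ has one row per information bit in sections $0, \dots, i-1$ and one column per parity bit in section $i$.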
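To see the block shapes concretely, the sketch below recomputes them from the parameters used in the simulation cell further down. The values are copied by hand, so treat them as an assumption if you change that cell.

```python
import numpy as np

# Parameters copied from the simulation cell further down
L, B = 16, 100
parityLengthVector = np.array([0, 7, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 13, 14])
J = (B + parityLengthVector.sum()) // L         # coded bits per section
messageLengthVector = J - parityLengthVector     # information bits per section

# Block for section i (i = 1, ..., L-1):
# rows = information bits in sections 0..i-1, columns = parity bits in section i
shapes = [(int(messageLengthVector[:i].sum()), int(parityLengthVector[i]))
          for i in range(1, L)]
print(J, shapes[:3], shapes[-1])
```

Because the blocks differ in shape, they cannot be stacked into a single rectangular NumPy array; a plain list of 2-D arrays is the natural container.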

In [4]:
def generate_parity_matrix(L,messageLengthVector,parityLengthVector):
    # One random binary block per section i = 1..L-1. The blocks have different
    # shapes, so they are returned as a list of arrays (converting this ragged
    # list with np.asarray raises an error in recent NumPy versions)
    G = []
    for i in range(1,L):
        Gp = np.random.randint(2,size=(np.sum(messageLengthVector[0:i]),parityLengthVector[i]))
        G.append(Gp)
    return G
        

## Outer Tree encoder

This function encodes the payloads of the active users into codewords of the specified tree code.

Parity bits in section $i$ are generated from the information bits of all previous sections $0, \dots, i-1$; section $0$ carries information bits only.
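The parity rule can be illustrated on a toy example. The sketch below reproduces, for a single user and a single section, the way `Tree_encode` combines earlier sections: each product $m_j G_j$ is reduced mod 2, read as a $p$-bit integer, and the integers are summed modulo $2^p$. The function name `section_parity` and the toy sizes are made up for illustration.

```python
import numpy as np

def section_parity(msg_blocks, G_blocks, p):
    """Parity bits for one section, following the rule in `Tree_encode`:
    each product m_j @ G_j is reduced mod 2, read as a p-bit integer
    (most significant bit first), and the integers are summed mod 2**p."""
    weights = 2 ** np.arange(p)[::-1]
    total = 0
    for m_j, G_j in zip(msg_blocks, G_blocks):
        bits = np.mod(m_j @ G_j, 2)
        total = (total + int(bits @ weights)) % 2**p
    return np.array([int(b) for b in np.binary_repr(total, p)])

# Toy example: two earlier sections with 3 and 2 information bits,
# and a current section carrying p = 2 parity bits
m0, G0 = np.array([1, 0, 1]), np.array([[1, 0], [1, 1], [0, 1]])
m1, G1 = np.array([1, 1]), np.array([[1, 1], [0, 1]])
print(section_parity([m0, m1], [G0, G1], p=2))
```

Summing integers modulo $2^p$ differs from a bitwise XOR in general (for example $3 + 1 \equiv 0 \pmod 4$, while $3 \oplus 1 = 2$). The decoder's `compute_permissible_parity` applies the same integer rule, so encoder and decoder stay consistent.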

In [5]:
def Tree_encode(tx_message,K,G,L,J,P,Ml,messageLengthVector,parityLengthVector):
    encoded_tx_message = np.zeros((K,Ml+P),dtype=int)
    encoded_tx_message[:,0:messageLengthVector[0]] = tx_message[:,0:messageLengthVector[0]]
    for i in range(1,L):
        ParityInteger=np.zeros((K,1),dtype='int')
        G1=G[i-1]
        for j in range(1,i+1):
            ParityBinary = np.mod(np.matmul(tx_message[:,np.sum(messageLengthVector[0:j-1]):np.sum(messageLengthVector[0:j])],
                                G1[np.sum(messageLengthVector[0:j-1]):np.sum(messageLengthVector[0:j])]),2)
            # Convert into decimal equivalent
            ParityInteger1 = ParityBinary.dot(2**np.arange(ParityBinary.shape[1])[::-1]).reshape([K,1])
            ParityInteger = np.mod(ParityInteger+ParityInteger1,2**parityLengthVector[i])
        # Convert integer parity back into bits
        Parity = np.array([list(np.binary_repr(int(x),parityLengthVector[i])) for x in ParityInteger.ravel()], dtype=int)
        encoded_tx_message[:,i*J:i*J+messageLengthVector[i]] = tx_message[:,np.sum(messageLengthVector[0:i]):np.sum(messageLengthVector[0:i+1])]
        # Embed parity check bits
        encoded_tx_message[:,i*J+messageLengthVector[i]:(i+1)*J] = Parity
    
    return encoded_tx_message

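Each user's coded message becomes an $L$-sparse vector of length $L 2^J$: section $i$ of the message, read as a $J$-bit integer, selects one index in block $i$. This function returns the sum of these vectors over the $K$ users, so each entry counts how many users selected that index.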
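The sketch below shows the same mapping on a toy example with two users whose first sections collide. It returns a 1-D vector rather than the notebook's column vector, and `bits_to_sparse` is a hypothetical name.

```python
import numpy as np

def bits_to_sparse(coded, L, J):
    """Same mapping as `convert_bits_to_sparse`, returned as a 1-D vector:
    section i of each row is read as a J-bit index (MSB first) into block i."""
    beta = np.zeros(L * 2**J, dtype=int)
    weights = 2 ** np.arange(J)[::-1]
    for i in range(L):
        idx = coded[:, i*J:(i+1)*J] @ weights    # one index per user
        np.add.at(beta, i * 2**J + idx, 1)        # repeated indices accumulate
    return beta

# Toy example: K = 2 users, L = 2 sections of J = 2 bits each
coded = np.array([[1, 0, 0, 1],    # user 0: indices 2 and 1
                  [1, 0, 1, 1]])   # user 1: indices 2 and 3
print(bits_to_sparse(coded, L=2, J=2))
```

`np.add.at` performs an unbuffered in-place addition, so two users hitting the same index produce a count of 2; a plain `beta[idx] += 1` would count the collision only once.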

In [6]:
def convert_bits_to_sparse(encoded_tx_message,L,J,K):
    encoded_tx_message_sparse=np.zeros((L*2**J,1),dtype=int)
    for i in range(L):
        A = encoded_tx_message[:,i*J:(i+1)*J]
        B = A.dot(2**np.arange(J)[::-1]).reshape([K,1])
        np.add.at(encoded_tx_message_sparse, i*2**J+B, 1)        
    return encoded_tx_message_sparse

This function returns the bit representations of the `listSize` largest entries in each section of the AMP output, i.e. a list of `listSize` candidate $J$-bit chunks per section.
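For a single section, the selection amounts to a top-`listSize` search followed by a binary expansion of each index. A minimal sketch on made-up AMP output (the name `top_indices_to_bits` is hypothetical):

```python
import numpy as np

def top_indices_to_bits(beta_section, list_size, J):
    """Keep the `list_size` largest entries of one section (in increasing
    index order, as in `convert_sparse_to_bits`) and write each index as
    J bits, most significant bit first."""
    top = np.sort(np.argsort(beta_section)[-list_size:])
    return np.array([[int(b) for b in np.binary_repr(int(k), J)] for k in top])

beta_section = np.array([0.01, 0.9, 0.02, 0.7])   # made-up AMP output, J = 2
print(top_indices_to_bits(beta_section, list_size=2, J=2))
```

Taking the last `list_size` entries of `argsort` selects the same set as discarding the `2**J - list_size` smallest entries, which is how the notebook phrases it.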

In [7]:
def convert_sparse_to_bits(cs_decoded_tx_message_sparse,L,J,listSize):
    cs_decoded_tx_message = np.zeros((listSize,L*J),dtype=int)
    for i in range(L):
        A = cs_decoded_tx_message_sparse[i*2**J:(i+1)*2**J]
        # Keep the listSize largest entries of this section, in increasing index order
        C = np.sort(A.reshape(2**J,).argsort()[2**J-listSize:])
        cs_decoded_tx_message[:,i*J:(i+1)*J]=np.array([list(np.binary_repr(int(x),J)) for x in C], dtype=int)    
    return cs_decoded_tx_message

Extract information bits from retained paths in the tree.

In [8]:
def extract_msg_bits(Paths,cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector):
    # Each row of Paths gives, for each section j along the path, the index of
    # the chosen row of cs_decoded_tx_message; keep only that section's
    # information bits and concatenate them
    msg_bits = []
    for path in Paths.reshape(Paths.shape[0],-1):
        msg_bits.append(np.concatenate([cs_decoded_tx_message[path[j],J*j:J*j+messageLengthvector[j]]
                                        for j in range(path.size)]))
    return np.vstack(msg_bits) if msg_bits else np.empty(shape=(0,0))

This function computes the parity bits that the next section must carry, given a partial path through the tree.

In [9]:
def compute_permissible_parity(Path,cs_decoded_tx_message,G1,L,J,parityLengthVector,messageLengthvector):
    msg_bits = extract_msg_bits(Path,cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector)
    Lpath = Path.shape[1]
    Parity_computed_integer = 0
    for i in range(Lpath):
        ParityBinary = np.mod(np.matmul(msg_bits[:,np.sum(messageLengthvector[0:i]):np.sum(messageLengthvector[0:i+1])],
                            G1[np.sum(messageLengthvector[0:i]):np.sum(messageLengthvector[0:i+1])]),2)
        ParityBinary=ParityBinary.reshape(1,-1)
        # Convert into decimal equivalent
        ParityInteger1 = ParityBinary.dot(2**np.arange(ParityBinary.shape[1])[::-1])
        Parity_computed_integer = np.mod(Parity_computed_integer+ParityInteger1,2**parityLengthVector[Lpath])

    Parity_computed = np.array([list(np.binary_repr(int(x),parityLengthVector[Lpath])) for x in Parity_computed_integer], dtype=int)
    return Parity_computed

This function checks whether candidate `k` in the next section carries the parity bits computed for the current path.

In [10]:
def parity_check(Parity_computed,Path,k,cs_decoded_tx_message,L,J,parityLengthVector,messageLengthvector):
    index=0
    Lpath = Path.shape[1]
    Parity = cs_decoded_tx_message[k,Lpath*J+messageLengthvector[Lpath]:(Lpath+1)*J]
    if (np.sum(np.absolute(Parity_computed-Parity)) == 0):
        index = 1
    
    return index

This function checks whether multiple paths output by the tree decoder are all composed of the same information bits.

In [11]:
def check_if_identical_msgs(Paths, cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector):   
    msg_bits = extract_msg_bits(Paths,cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector)
    flag = (msg_bits == msg_bits[0]).all()    
    return flag

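## Outer Tree decoder

This function implements the tree decoder proposed in the paper "A coded compressed sensing scheme for uncoordinated multiple access". Starting from each candidate in the first section's list, it extends paths one section at a time and keeps a child only if its parity bits match those computed from the information bits along the path.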
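The sketch below strips the tree search down to its core on a toy instance with two users and three sections. It is a simplification, not the notebook's decoder: candidates are `(msg_bits, parity_bits)` pairs, `toy_parity` is a made-up parity rule (not the random generator matrices used here), and the notebook's extra step of collapsing surviving paths that carry identical information bits is omitted.

```python
import numpy as np

def toy_tree_decode(lists, parity_fn):
    """lists[l] holds (msg_bits, parity_bits) pairs for section l; section 0
    carries no parity. parity_fn(prefix) returns the parity bits a child must
    carry, given the concatenated message bits of its ancestors."""
    decoded = []
    for root_msg, _ in lists[0]:
        paths = [[root_msg]]                         # each root starts a tree
        for l in range(1, len(lists)):
            survivors = []
            for path in paths:
                expected = parity_fn(np.concatenate(path))
                for msg, par in lists[l]:
                    if np.array_equal(par, expected):  # child passes the check
                        survivors.append(path + [msg])
            paths = survivors                        # paths without children die
        decoded.extend(np.concatenate(p) for p in paths)
    return decoded

def toy_parity(prefix):
    # Hypothetical parity rule: 2-bit binary of (number of ones) mod 4
    v = int(prefix.sum()) % 4
    return np.array([v >> 1, v & 1])

# Candidate lists for 2 users, 3 sections of 2 information bits each
lists = [
    [(np.array([1, 0]), np.array([], dtype=int)), (np.array([0, 0]), np.array([], dtype=int))],
    [(np.array([1, 1]), np.array([0, 1])), (np.array([0, 0]), np.array([0, 0]))],
    [(np.array([0, 0]), np.array([1, 1])), (np.array([1, 1]), np.array([0, 0]))],
]
for msg in toy_tree_decode(lists, toy_parity):
    print(msg)
```

Each root ends with exactly one surviving path here, so the two users' messages are recovered without mixing sections from different users.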

In [12]:
def Tree_decoder(cs_decoded_tx_message,G,L,J,B,parityLengthVector,messageLengthvector,listSize):
    tree_decoded_tx_message = np.empty(shape=(0,0))
    for i in range(listSize):
        Paths = np.array([[i]])
        for l in range(1,L):
            # Grab the parity generator matrix corresponding to this section
            G1 = G[l-1]
            new=np.empty( shape=(0,0))
            for j in range(Paths.shape[0]):
                Path=Paths[j].reshape(1,-1)
                # Compute the permissible parity check bits for the section
                Parity_computed = compute_permissible_parity(Path,cs_decoded_tx_message,G1,L,J,parityLengthVector,messageLengthvector)
                for k in range(listSize):
                    # Verify parity constraints for the children of surviving path
                    index = parity_check(Parity_computed,Path,k,cs_decoded_tx_message,L,J,parityLengthVector,messageLengthvector)
                    # If parity constraints are satisfied, update the path
                    if index:
                        new = np.vstack((new,np.hstack((Path.reshape(1,-1),np.array([[k]]))))) if new.size else np.hstack((Path.reshape(1,-1),np.array([[k]])))
            Paths = new 
        if Paths.shape[0] >= 2:
            # Several paths survived from this root: if they all carry the same
            # information bits keep only the first one, otherwise keep them all
            flag = check_if_identical_msgs(Paths, cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector)
            if flag:
                tree_decoded_tx_message = np.vstack((tree_decoded_tx_message,extract_msg_bits(Paths[0].reshape(1,-1),cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector))) if tree_decoded_tx_message.size else extract_msg_bits(Paths[0].reshape(1,-1),cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector)
            else:
                tree_decoded_tx_message = np.vstack((tree_decoded_tx_message,extract_msg_bits(Paths.reshape(Paths.shape[0],-1),cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector))) if tree_decoded_tx_message.size else extract_msg_bits(Paths.reshape(Paths.shape[0],-1),cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector)
        elif Paths.shape[0] == 1:
            tree_decoded_tx_message = np.vstack((tree_decoded_tx_message,extract_msg_bits(Paths.reshape(1,-1),cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector))) if tree_decoded_tx_message.size else extract_msg_bits(Paths.reshape(1,-1),cs_decoded_tx_message, L,J,parityLengthVector,messageLengthvector)
    return tree_decoded_tx_message

## SPARC Codebook

We use `block_sub_fht`, which computes the equivalent of $A\beta$ using $L$ separate row-subsampled Hadamard matrices (each $M\times M$ here, since $M \ge n+1$ in the simulation below). We also want each entry divided by $\sqrt{n}$, so that every column of $A$ has unit norm, and we need to reshape the outputs into column vectors; both operations are wrapped here.

Returns two functions `Ab` and `Az`, which compute $A\beta$ and $A^\top z$ respectively.
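For small sizes, the operator can be written out as a dense matrix, which makes the normalisation easy to check. The sketch below is a hypothetical helper, `dense_sparc_codebook`, that assumes the default embedding of `block_sub_fht`: block $l$ takes rows `ordering[l]` (drawn from $1, \dots, w-1$, skipping the all-ones row 0) and the first $M$ columns of a $w \times w$ Sylvester Hadamard matrix, scaled by $1/\sqrt{n}$. With the same seed it should match the fast operator up to rounding, though its `Ab` returns a 1-D vector rather than a column.

```python
import numpy as np

def dense_sparc_codebook(L, M, n, seed=0):
    """Dense analogue of `sparc_codebook` for small L, M, n (hypothetical helper)."""
    w = 2 ** int(np.ceil(np.log2(max(M, n + 1))))
    H = np.array([[1.0]])
    while H.shape[0] < w:                       # Sylvester construction
        H = np.block([[H, H], [H, -H]])
    rng = np.random.RandomState(seed)           # same row draws as block_sub_fht
    idxs = np.arange(1, w, dtype=np.uint32)
    blocks = []
    for _ in range(L):
        rng.shuffle(idxs)
        blocks.append(H[idxs[:n], :M])          # n subsampled rows, first M columns
    A = np.hstack(blocks) / np.sqrt(n)
    return A, (lambda b: A @ b), (lambda z: A.T @ z)

A, Ab, Az = dense_sparc_codebook(L=4, M=16, n=12)
print(A.shape, np.allclose(np.linalg.norm(A, axis=0), 1.0))
```

Every entry is $\pm 1/\sqrt{n}$, so each column of $A$ has squared norm $n \cdot (1/n) = 1$.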

In [13]:
def sparc_codebook(L, M, n):
    Ax, Ay, _ = block_sub_fht(n, M, L, ordering=None)
    def Ab(b):
        return Ax(b).reshape(-1, 1)/ np.sqrt(n)
    def Az(z):
        return Ay(z).reshape(-1, 1)/ np.sqrt(n)
    return Ab, Az

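## AMP
This is the AMP algorithm itself. It is mostly a straightforward transcription of the relevant equations, but note that we use `longdouble` because the Gaussian exponentials in the denoiser are often too small to be represented in a normal `double` (they underflow to zero). The denoiser returns, for each entry, the posterior probability that it is active, assuming each entry is independently active with probability $p_0$.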
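For reference, the recursion implemented by `amp` can be written as follows, read directly off the code: $\hat P$ is `Phat`, $\tau_t$ is `τ`, $s^t$ is `s`, and the iteration starts from $\beta^0 = 0$, $z^0 = y$.

$$
\begin{aligned}
\tau_t^2 &= \frac{\lVert z^t \rVert^2}{n}, \qquad
s^t = \sqrt{\hat P}\,\beta^t + A^\top z^t, \\
\beta^{t+1}_i &= \frac{p_0 \exp\!\Big(-\frac{(s^t_i - \sqrt{\hat P})^2}{2\tau_t^2}\Big)}
{p_0 \exp\!\Big(-\frac{(s^t_i - \sqrt{\hat P})^2}{2\tau_t^2}\Big) + (1-p_0) \exp\!\Big(-\frac{(s^t_i)^2}{2\tau_t^2}\Big)}, \\
z^{t+1} &= y - \sqrt{\hat P}\,A\beta^{t+1} + \frac{z^t}{n\tau_t^2}\,\hat P\Big(\sum_i \beta^{t+1}_i - \sum_i \big(\beta^{t+1}_i\big)^2\Big).
\end{aligned}
$$

The last term is the Onsager correction; $\sum_i \beta_i - \sum_i \beta_i^2 = \sum_i \beta_i(1-\beta_i)$ is the sum of the per-entry posterior variances of the $\{0,1\}$-valued entries.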

In [14]:
def amp(y, σ_n, P, L, M, T, Ab, Az, p0, K):
    n = y.size
    β = np.zeros((L*M, 1))
    z = y
    Phat = n*P/L
    
    for t in range(T):
        
        τ = np.sqrt(np.sum(z**2)/n)
        # effective observation
        s = (np.sqrt(Phat)*β + Az(z)).astype(np.longdouble)
        # denoiser: posterior probability that each entry is active, computed
        # entirely in longdouble and only then cast back to float
        num = p0*np.exp(-(s-np.sqrt(Phat))**2/(2*τ**2))
        den = num + (1-p0)*np.exp(-s**2/(2*τ**2))
        β = (num/den).astype(float).reshape(-1, 1)
        # residual (including the Onsager correction term)
        z = y - np.sqrt(Phat)*Ab(β) + (z/(n*τ**2)) * (Phat*np.sum(β) - Phat*np.sum(β**2))
        #print(t,τ)

    return β
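An alternative that avoids `longdouble` altogether: dividing numerator and denominator of the denoiser by the numerator turns it into a logistic function of the log-odds $\log\frac{p_0}{1-p_0} + \frac{2 s\sqrt{\hat P} - \hat P}{2\tau^2}$ (with $\hat P$ = `Phat`, $\tau$ = `τ`), so no Gaussian exponential is ever evaluated on its own. This is a reformulation we introduce, not the notebook's code; `denoiser_stable` and `denoiser_direct` are hypothetical names, and the parameter values are rough stand-ins for the simulation below.

```python
import numpy as np

def denoiser_stable(s, tau, p0, Phat):
    """Posterior probability that each entry is active, as a logistic function
    of the log-odds (same value as the ratio used in `amp`)."""
    logodds = np.log(p0 / (1 - p0)) + (2 * s * np.sqrt(Phat) - Phat) / (2 * tau**2)
    # logistic 1/(1+exp(-x)) written as 0.5*(1+tanh(x/2)), which never overflows
    return 0.5 * (1.0 + np.tanh(0.5 * logodds))

def denoiser_direct(s, tau, p0, Phat):
    """Direct transcription of the ratio used in `amp`, in plain float64."""
    a = p0 * np.exp(-(s - np.sqrt(Phat))**2 / (2 * tau**2))
    b = (1 - p0) * np.exp(-s**2 / (2 * tau**2))
    return a / (a + b)

p0, Phat, tau = 7.6e-4, 31.4, 1.0
s_moderate = np.linspace(-3, 8, 12)
s_extreme = np.array([-60.0, 60.0])
print(np.allclose(denoiser_stable(s_moderate, tau, p0, Phat),
                  denoiser_direct(s_moderate, tau, p0, Phat)))
print(denoiser_stable(s_extreme, tau, p0, Phat))
```

With these parameters, the direct ratio in float64 becomes $0/0$ (NaN) at $s = \pm 60$ because both exponentials underflow, while the logistic form returns values in $[0, 1]$.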

In [79]:
K=25 # Number of active users
B=100 # Payload size of each active user
L=16 # Number of sections/sub-blocks
n=30000 # Total number of channel uses (real d.o.f)
T=10 # Number of AMP iterations
listSize = 100  # List size output by the tree decoder for each section
parityLengthVector = np.array([0,7,8,8,9,9,9,9,9,9,9,9,9,9,13,14],dtype=int) # Parity bits distribution
J=((B+np.sum(parityLengthVector))/L).astype(int) # Length of each coded sub-block
M=2**J # Length of each section
messageLengthVector = np.subtract(J*np.ones(L, dtype = 'int'), parityLengthVector).astype(int)
Pa = np.sum(parityLengthVector) # Total number of parity check bits
Ml = np.sum(messageLengthVector) # Total number of information bits
G = generate_parity_matrix(L,messageLengthVector,parityLengthVector)
EbNodB = 4 
p0 = 1-(1-1/M)**K
maxSim=2 # number of simulations
msgDetected=0

# EbN0 in linear scale
EbNo = 10**(EbNodB/10)
P = 2*B*EbNo/n
σ_n = 1
# We assume equal power allocation across sections. The code needs small modifications to accommodate different power allocations
Phat = n*P/L
    

for s in range(maxSim):
    

    #G = generate_parity_matrix(L,messageLengthVector,parityLengthVector)
    
    # Generate active users message sequences
    tx_message = np.random.randint(2, size=(K,B))
    
    # Outer-encode the message sequences
    encoded_tx_message = Tree_encode(tx_message,K,G,L,J,Pa,Ml,messageLengthVector,parityLengthVector)
    
    # Convert bits to sparse representation
    β_0 = convert_bits_to_sparse(encoded_tx_message,L,J,K)
    
    # Generate the binned SPARC codebook
    Ab, Az = sparc_codebook(L, M, n)
    
    x = np.sqrt(Phat)*Ab(β_0)
    
    # Generate random channel noise and thus also received signal y
    z = np.random.randn(n, 1) * σ_n
    y = (x + z).reshape(-1, 1)

    # Run AMP decoding
    β = amp(y, σ_n, P, L, M, T, Ab, Az,p0,K).reshape(-1)
    
    
    # Convert decoded beta back to a message   
    cs_decoded_tx_message = convert_sparse_to_bits(β,L,J,listSize)
    
    # Tree decoder to decode individual messages from lists output by AMP
    tree_decoded_tx_message = Tree_decoder(cs_decoded_tx_message,G,L,J,B,parityLengthVector,messageLengthVector,listSize)
    
    # If the tree decoder outputs more than K valid paths, retain only the first K
    if tree_decoded_tx_message.shape[0]>K:
        tree_decoded_tx_message = tree_decoded_tx_message[np.arange(K)]
    
    for i in range(tx_message.shape[0]):
        msgDetected = msgDetected + np.equal(tx_message[i,:],tree_decoded_tx_message).all(axis=1).any()
        
errorRate= (K*maxSim - msgDetected)/(K*maxSim)

print(errorRate)
    

0.04
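The derived quantities in the simulation cell can be checked by hand: $J = (100 + 140)/16 = 15$, so $M = 2^{15}$. With unit noise variance, $N_0 = 2\sigma_n^2 = 2$, so the total energy $B E_b = 2B\,(E_b/N_0)$ spread over $n$ channel uses gives the per-use power $P = 2B(E_b/N_0)/n$ used in the code. The printed error rate is the fraction of the $K \cdot \texttt{maxSim} = 50$ transmitted messages missing from the decoded lists, so $0.04$ corresponds to 2 missed messages. The sketch below recomputes these values, with parameters copied from that cell.

```python
import numpy as np

B, L, n, K, EbNodB = 100, 16, 30000, 25, 4
parityLengthVector = np.array([0, 7, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 13, 14])
J = (B + parityLengthVector.sum()) // L   # 240 coded bits over 16 sections
M = 2**J                                  # entries per section
p0 = 1 - (1 - 1/M)**K                     # P(an index is chosen by >= 1 user)
EbNo = 10**(EbNodB / 10)                  # Eb/N0 in linear scale
P = 2 * B * EbNo / n                      # power per channel use (sigma_n = 1)
Phat = n * P / L                          # energy per section
print(J, M, round(p0, 6), round(Phat, 2))
```

Since $p_0 \approx K/M$ is tiny, almost all of the $L M$ entries are zero, which is what makes the AMP denoiser's sparse prior appropriate.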
