# Coded Demixing

This notebook implements coded Demixing using the CCS-AMP encoder/decoder for multi-class unsourced random access using Hadamard design matrices.


In [1]:
import numpy as np
import math
import matplotlib.pyplot as plt
import FactorGraphGeneration as FGG
import ipdb

# Outer tree code (factor graph) built by the project-local FactorGraphGeneration
# module. The same object is used to encode messages, by the BP-based dynamic
# denoiser inside AMP, and by the final graph decoder in the simulation cell.
OuterCode = FGG.Graph8()

## Fast Hadamard Transforms

The `PyFHT_local` code can all be found in `pyfht`, which uses a C extension to speed up the FHT.
Only one of the two imports below is needed; the latter (`pyfht`) is much faster.

In [2]:
# import PyFHT_local
from pyfht import block_sub_fht

In [3]:
def pa(L, C, P, a, f):
    """Return an exponentially decaying power allocation over L sections.

    Section l gets power proportional to 2**(-2*a*C*l/L). Every section from
    index int(f*L) onwards is held at the level of that section. The result
    is rescaled so that it sums to P.

    Args:
        L (int): Number of sections.
        C (float): Rate parameter of the exponential decay (the notebook
            passes the channel capacity 0.5*log2(1+snr)).
        P (float): Total power; the returned array sums to P.
        a (float): Multiplier on the decay rate.
        f (float): Fraction of sections that follow the exponential profile.
            With f >= 1 no flattening is applied.

    Returns:
        numpy.ndarray: Length-L array of per-section powers.
    """
    powers = 2.0 ** (-2 * a * C * np.arange(L) / L)
    flat_start = int(f * L)
    # The previous version indexed powers[int(f*L)] unconditionally. That
    # raised IndexError for f >= 1, where there is nothing left to flatten.
    if flat_start < L:
        powers[flat_start:] = powers[flat_start]
    powers /= powers.sum() / P
    return powers

# Outer Tree encoder

This function encodes the payloads corresponding to users into codewords from the specified tree code. 

Parity bits in section $i$ are generated from the information sections that section $i$ is connected to.

Computations are done in the ring of integers modulo the section size $2^J$, which enables FFT-based BP on the outer graph.

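This function outputs the index representation of the encoded messages: a $K \times L$ integer array whose entry $(k, i)$ is user $k$'s index in section $i$. The sparse representation is obtained with `convert_indices_to_sparse` below.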

In [4]:
def Tree_encode(tx_message,K,messageBlocks,G,L,J):
    """Map each user's payload to one integer per section of the outer tree code.

    Section 0 always takes the first J bits of the payload. A later section i
    is either an information section (messageBlocks[i] truthy), which copies
    the next J-bit chunk, or a parity section, which is the sum modulo 2**J of
    the sections listed by np.where(G[i])[0]. Sections are filled in order, so
    a parity section only sees sections encoded before it; later ones still
    hold 0 at that point.

    Args:
        tx_message: 2-D array of 0/1 message bits, one row per user.
        K (int): Number of users (rows of tx_message).
        messageBlocks: Per-section flags; truthy marks an information section.
        G: Outer-graph adjacency; row i's nonzero positions are the sections
            a parity section i sums over.
        L (int): Number of sections.
        J (int): Bits per section.

    Returns:
        numpy.ndarray: (K, L) int array of section indices in [0, 2**J).
    """
    bit_weights = 2**np.arange(J)[::-1]  # MSB-first bits -> integer
    codeword = np.zeros((K, L), dtype=int)
    codeword[:, 0] = tx_message[:, 0:J] @ bit_weights

    for section in range(1, L):
        if not messageBlocks[section]:
            # Parity section: modular sum of its neighbours.
            neighbours = np.where(G[section])[0]
            codeword[:, section] = np.mod(codeword[:, neighbours].sum(axis=1), 2**J)
            continue
        # Information section: copy the next J-bit chunk of the payload.
        start = int(np.sum(messageBlocks[:section])) * J
        codeword[:, section] = tx_message[:, start:start + J] @ bit_weights

    return codeword

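This function converts the message indices of all users into a single aggregate vector of length $L 2^J$. Entry $i 2^J + m$ counts how many users chose index $m$ in section $i$; for a single user this is an $L$-sparse vector.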

In [5]:
def convert_indices_to_sparse(encoded_tx_message_indices,L,J,K):
    aggregate_state_s_sparse=np.zeros((L*2**J,1),dtype=int)
    for i in range(L):
        section_indices_vectorized_rows = encoded_tx_message_indices[:,i]
        section_indices_vectorized_cols = section_indices_vectorized_rows.reshape([-1,1])
        np.add.at(aggregate_state_s_sparse, (i*2**J)+section_indices_vectorized_cols, 1)

    return aggregate_state_s_sparse

This function returns the index representation corresponding to a SPARC-like vector: for each section it keeps the `listSize` indices with the largest values.

In [6]:
def convert_sparse_to_indices(cs_decoded_tx_message_sparse,L,J,listSize):
    """Turn a stacked section-wise score vector into per-section index lists.

    For each of the L sections (2**J entries each), the listSize entries with
    the largest scores are kept. Their in-section positions are returned in
    ascending order.

    Args:
        cs_decoded_tx_message_sparse: Array with at least L*2**J entries,
            e.g. an (L*2**J, 1) column.
        L (int): Number of sections.
        J (int): Bits per section.
        listSize (int): Number of indices kept per section.

    Returns:
        numpy.ndarray: (listSize, L) int array; column i holds the kept
        indices of section i.
    """
    section_len = 2**J
    num_dropped = section_len - listSize
    all_positions = np.arange(section_len)
    decoded = np.zeros((listSize, L), dtype=int)
    for section in range(L):
        lo = section * section_len
        scores = cs_decoded_tx_message_sparse[lo:lo + section_len].reshape(section_len,)
        # Positions of the num_dropped smallest scores (same argsort call and
        # tie-breaking as a plain argsort); everything else is kept.
        dropped = scores.argsort()[np.arange(num_dropped)]
        decoded[:, section] = np.setdiff1d(all_positions, dropped)
    return decoded

Extract information bits from retained paths in the tree.

In [7]:
def extract_msg_indices(Paths,cs_decoded_tx_message, L,J):
    """Read off the section indices selected by each retained tree path.

    Row p, column j of Paths names which candidate (row of
    cs_decoded_tx_message) path p uses in section j.

    Args:
        Paths: Array with one row per path; each row is flattened to one
            candidate row index per section.
        cs_decoded_tx_message: (listSize, L) candidate indices per section.
        L, J: Not needed by the computation.

    Returns:
        numpy.ndarray: (num_paths, num_sections) array with
        result[p, j] = cs_decoded_tx_message[Paths[p, j], j]. An empty
        (0, 0) array is returned when there are no paths or the paths are
        empty.
    """
    num_paths = Paths.shape[0]
    if num_paths == 0:
        return np.empty(shape=(0, 0))
    path_matrix = Paths.reshape(num_paths, -1)
    if path_matrix.shape[1] == 0:
        return np.empty(shape=(0, 0))
    sections = np.arange(path_matrix.shape[1])
    # Broadcasting pairs every path entry with its section column.
    return cs_decoded_tx_message[path_matrix, sections]

## SPARC Codebook

We use `block_sub_fht`, which computes the equivalent of $A\beta$ using $L$ separate $M\times M$ Hadamard matrices. However, we want each entry to be divided by $\sqrt{n}$ to get the right variance, and we need to reshape the output into column vectors, so we wrap those operations here.

Returns two functions `Ab` and `Az`, which compute $A\beta$ and $A^T z$ (i.e. $z^T A$ as a column) respectively.

In [8]:
def sparc_codebook(L, M, n, seed=None):
    """Build the SPARC design-matrix operators from block Hadamard transforms.

    Args:
        L (int): Number of sections.
        M (int): Section length.
        n (int): Block length (number of rows of A).
        seed: Forwarded to block_sub_fht. The default None reproduces the
            previous behaviour. Pass an explicit value to rebuild the same
            operators in a later call (presumably this seeds pyfht's random
            row selection; confirm against the pyfht docs).

    Returns:
        (Ab, Az): Ab(b) returns A b / sqrt(n) and Az(z) returns A^T z / sqrt(n),
        both reshaped into column vectors. The encoder and decoder should use
        the pair returned by the same call (or the same seed).
    """
    Ax, Ay, _ = block_sub_fht(n, M, L, seed=seed, ordering=None)

    def Ab(b):
        """Apply A to b, scaled by 1/sqrt(n), as a column vector."""
        return Ax(b).reshape(-1, 1) / np.sqrt(n)

    def Az(z):
        """Apply A^T to z, scaled by 1/sqrt(n), as a column vector."""
        return Ay(z).reshape(-1, 1) / np.sqrt(n)

    return Ab, Az

# Vector Approximation

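This function returns an approximation of the input vector whose L1 norm is 1 and none of whose entries exceeds $1/K$. Entries above $1/K$ are clipped, and the removed mass is redistributed proportionally over the entries still below $1/K$.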

In [9]:
def approximateVector(x, K):
    """Return an L1-normalized copy of x with no entry larger than 1/K.

    x is first scaled to unit L1 norm. While some entry exceeds 1/K, every
    entry is clipped at 1/K. The mass removed by clipping is then handed back
    to the entries still strictly below 1/K, in proportion to their current
    values. Each pass pins at least one more entry at the cap, so the loop
    ends after at most len(x) passes.

    Args:
        x (numpy.ndarray): Input vector (one section of a PME/BP estimate).
        K (int or float): The per-entry cap is 1/K.

    Returns:
        numpy.ndarray: The adjusted vector, same shape as x.

        Sometimes the removed mass cannot be redistributed: every entry left
        below the cap is zero, e.g. when x has fewer than K nonzero entries.
        In that case the clipped vector is returned (its sum is then < 1).
        The previous version divided 0 by 0 here and returned NaNs.
    """
    cap = 1 / K

    # Normalize to unit L1 norm.
    xHt = x / np.linalg.norm(x, ord=1)

    while np.amax(xHt) > cap:
        # Entries strictly below the cap keep their relative proportions;
        # everything else is clipped to exactly the cap.
        below = xHt < cap
        xHt = np.minimum(xHt, cap)

        deficit = 1 - np.linalg.norm(xHt, ord=1)
        if deficit > 0:
            free_mass = np.linalg.norm(xHt * below, ord=1)
            if free_mass == 0:
                # No uncapped mass to scale up: unit norm is unreachable.
                break
            scale = (deficit + free_mass) / free_mass
            xHt = np.where(below, scale * xHt, cap)

    return xHt

## Posterior Mean Estimator (PME)

This function implements the section-wise posterior mean estimator (PME). The prior `q` can be a scalar, such as the uninformative $1/M$, or a vector of per-entry priors produced by the dynamic denoiser.

In [10]:
def pme0(q, r, Pl, tau,n,M):
    """Section-wise posterior mean estimator (PME).

    Within each section of length M, exactly one entry is assumed to be
    active, with amplitude sqrt(n*Pl[section]).

    Args:
        q (float or array): Prior probability of each entry being the active
            one. Either a scalar (e.g. the uninformative 1/M) or an array
            broadcastable to r's shape.
        r (array): Effective observation, an (L*M, 1) column of L sections.
        Pl (array): Per-section powers (length L).
        tau (float): Standard deviation of the effective noise.
        n (int): Block length, used to turn powers into amplitudes.
        M (int): Section length.

    Returns:
        sHat (array): (L*M, 1) float column. Each entry is the amplitude
        sqrt(n*Pl) times the posterior probability that the entry is the
        active one in its section.
    """
    rt_n_Pl = np.sqrt(n * np.asarray(Pl)).repeat(M).reshape(-1, 1)

    # One row per section. The number of sections is inferred from the data;
    # the previous version read an undefined global `L`.
    exponent = (r * rt_n_Pl / tau**2).reshape(-1, M)
    # Subtracting each section's largest exponent cancels in the normalization
    # but keeps np.exp from overflowing to inf. Overflow used to turn whole
    # sections into inf/inf = NaN.
    exponent = exponent - exponent.max(axis=1, keepdims=True)

    weights = (q * np.exp(exponent).reshape(-1, 1)).reshape(-1, M)
    posterior = weights / weights.sum(axis=1, keepdims=True)

    sHat = (rt_n_Pl * posterior.reshape(-1, 1)).astype(float)
    return sHat

# Dynamic Denoiser

The dynamic denoisers below perform belief propagation (BP) on the factor graph of the outer code to form priors for the PME.

In [11]:
def dynamicDenoiser1(r,G,messageBlocks,L,M,K,tau,d,numBPiter):
    """
    FFT-based BP denoiser over the outer tree code for K users.

    Computes a PME local estimate per section and caps every section with
    approximateVector(., K). It then forms, for every section, the message
    implied by the parity constraints (parity = sum of information sections
    modulo the section length M), using circular convolutions computed with FFTs.

    NOTE(review): this function cannot run as written. pme0 is called below
    with 4 arguments, but pme0 takes 6 (q, r, Pl, tau, n, M). The simulation
    uses dynamicDenoiser2 instead.

    Args:
        r (float): Effective observation
        G: Outer-graph adjacency; np.where(G[i])[0] lists the sections
            connected to section i
        messageBlocks: Truthy for information sections, falsy for parity sections
        L (int): Number of sections
        M (int): Section length (FFT length)
        K (int): Number of users; also sets the cap 1/K in approximateVector
        tau (float): Standard deviation of noise
        d (float): Signal amplitude
        numBPiter (int): Number of BP passes (see the note on the loop)

    Returns:
        mu: Array shaped like r holding the capped BP message of every section.
    """
    # Probability that at least one of K uniformly-choosing users picks a given index.
    p0 = 1-(1-1/M)**K
    p1 = p0*np.ones(r.shape,dtype=float)  # NOTE(review): unused
    mu = np.zeros(r.shape,dtype=float)

    # Compute local estimate (lambda) based on effective observation using PME.
    # NOTE(review): wrong arity for pme0 (needs n and M too) -> TypeError.
    localEstimates = pme0(p0, r, d, tau)
    
    # Reshape local estimate (lambda) into an LxM matrix
    Beta = localEstimates.reshape(L,-1)
    for i in range(L):
        Beta[i,:] = approximateVector(Beta[i,:], K)

    # There is an issue BELOW for numBPiter greater than one!
    # (Beta is never updated inside this loop, so every pass recomputes the
    # same mu; extra passes have no effect.)
    for iter in range(numBPiter):    
        # Rotate PME 180deg about y-axis
        # i.e. Betaflipped[:, m] = Beta[:, (-m) mod M]: the distribution of the negated index.
        Betaflipped = np.hstack((Beta[:,0].reshape(-1,1),np.flip(Beta[:,1:],axis=1)))
        # Compute and store all FFTs
        # (a product of FFTs is a circular convolution, i.e. the distribution of a sum mod M)
        BetaFFT = np.fft.fft(Beta)
        BetaflippedFFT = np.fft.fft(Betaflipped)
        for i in range(L):
            if messageBlocks[i]:
                # Parity sections connected to info section i
                parityIndices = np.where(G[i])[0]   # Identities of parity block(s) attached
                BetaIFFTprime = np.empty((0,0)).astype(float)
                for j in parityIndices:  # Compute message for check associated with parity j
                    # Other info blocks connected to this parity block
                    messageIndices = np.setdiff1d(np.where(G[j])[0],i)  ## all indicies attahced to j, other than i
                    # Parity j times the flipped other info sections:
                    # the distribution of (parity_j - sum of the other info sections) mod M.
                    BetaFFTprime = np.vstack((BetaFFT[j],BetaflippedFFT[messageIndices,:]))  ## j is not part of G[j]
                    # Multiply the relevant FFTs
                    BetaFFTprime = np.prod(BetaFFTprime,axis=0)
                    # IFFT
                    BetaIFFTprime1 = np.fft.ifft(BetaFFTprime).real # multiple parity
                    BetaIFFTprime = np.vstack((BetaIFFTprime,BetaIFFTprime1)) if BetaIFFTprime.size else BetaIFFTprime1
                    # need to stack from all parity
                # NOTE(review): with a single attached parity section BetaIFFTprime is still
                # 1-D here, so np.prod(..., axis=0) collapses it to a scalar.
                BetaIFFTprime = np.prod(BetaIFFTprime,axis=0) # pointwise product of distribution
            else:
                BetaIFFTprime = np.empty((0,0)).astype(float)
                # Information sections connected to this parity section (assuming no parity over parity sections)
                Indices = np.where(G[i])[0]
                # FFT
                BetaFFTprime = BetaFFT[Indices,:]
                # Multiply the relevant FFTs (distribution of the sum of the info sections mod M)
                BetaFFTprime = np.prod(BetaFFTprime,axis=0)
                # IFFT
                BetaIFFTprime = np.fft.ifft(BetaFFTprime).real            
            mu[i*M:(i+1)*M] = approximateVector(BetaIFFTprime, K).reshape(-1,1)

    return mu

In [12]:
def dynamicDenoiser2(r,OuterCode,tau,Pl,numBPiter):
    """Compute BP-based priors for the PME using the outer factor-graph code.

    Runs the PME with the uninformative prior 1/M and loads each section's
    estimate as the observation of the matching variable node. It then runs
    one check/variable update round on OuterCode and returns each node's
    extrinsic estimate, passed through approximateVector(., 1), i.e. L1-normalized.

    NOTE(review): the block length `n` is not a parameter. It is read from
    the module-level (notebook) global `n`, which must be defined before
    calling this function.

    Args:
        r: Effective observation; a column whose length is (number of variable
            nodes) * (sparse section length) of OuterCode.
        OuterCode: Project factor-graph object (FactorGraphGeneration).
        tau (float): Standard deviation of the effective noise.
        Pl: Per-section power allocation.
        numBPiter (int): Requested BP iterations. Currently ignored: exactly
            one update round is run (see the CHECK note below).

    Returns:
        mu: Array shaped like r with the per-section prior for the next PME.
    """
    M = OuterCode.getsparseseclength()
    L = OuterCode.getvarcount()

    p0 = 1/M  # uninformative prior
    mu = np.zeros(r.shape,dtype=float)

    # Compute local estimate (lambda) based on effective observation using PME.
    localEstimates = pme0(p0, r, Pl, tau,n,M)

    # Reshape local estimate (lambda) into an LxM matrix, one row per variable node.
    Beta = localEstimates.reshape(L,-1)
    OuterCode.reset()
    for varnodeid in OuterCode.getvarlist():
        i = varnodeid - 1  # variable node ids are 1-based
        OuterCode.setobservation(varnodeid, Beta[i,:]) # CHECK

    for _ in range(1):    # CHECK: Leave at 1 for now (numBPiter is not used yet)
        OuterCode.updatechecks()
        OuterCode.updatevars()

    for varnodeid in OuterCode.getvarlist():
        i = varnodeid - 1
        Beta[i,:] = OuterCode.getextrinsicestimate(varnodeid)
        mu[i*M:(i+1)*M] = approximateVector(Beta[i,:], 1).reshape(-1,1)

    return mu

In [13]:
def dynamicDenoiser3(r,G,messageBlocks,L,M,K,tau,d,numBPiter):
    """
    Variant of dynamicDenoiser1 (FFT-based BP over the outer tree code).

    NOTE(review): this function cannot run as written. pme0 is called below
    with 4 arguments, but pme0 takes 6 (q, r, Pl, tau, n, M).
    NOTE(review): the BP body is indented inside the `for i in range(L)`
    capping loop. The FFTs and all messages are therefore recomputed L times,
    and the inner `for i in range(L)` rebinds `i`. Only the last outer pass,
    when every row has been capped, determines the returned mu.
    NOTE(review): unlike dynamicDenoiser1, the parity term (line with
    BetaflippedFFT[j]) and the parity-section product both use the *flipped*
    FFTs. Confirm which is intended.

    Args:
        r (float): Effective observation
        G: Outer-graph adjacency; np.where(G[i])[0] lists the sections
            connected to section i
        messageBlocks: Truthy for information sections, falsy for parity sections
        L (int): Number of sections
        M (int): Section length (FFT length)
        K (int): Number of users; also sets the cap 1/K in approximateVector
        tau (float): Standard deviation of noise
        d (float): Signal amplitude
        numBPiter (int): Not used by this variant

    Returns:
        mu: Array shaped like r holding the capped BP message of every section.
    """
    # Probability that at least one of K uniformly-choosing users picks a given index.
    p0 = 1-(1-1/M)**K
    p1 = p0*np.ones(r.shape,dtype=float)  # NOTE(review): unused
    mu = np.zeros(r.shape,dtype=float)

    # Compute local estimate (lambda) based on effective observation using PME.
    # NOTE(review): wrong arity for pme0 (needs n and M too) -> TypeError.
    localEstimates = pme0(p0, r, d, tau)
    
    # Reshape local estimate (lambda) into an LxM matrix
    Beta = localEstimates.reshape(L,-1)
    for i in range(L):
        Beta[i,:] = approximateVector(Beta[i,:], K)

        # Rotate PME 180deg about y-axis
        # i.e. Betaflipped[:, m] = Beta[:, (-m) mod M]: the distribution of the negated index.
        Betaflipped = np.hstack((Beta[:,0].reshape(-1,1),np.flip(Beta[:,1:],axis=1)))
        # Compute and store all FFTs
        BetaFFT = np.fft.fft(Beta)
        BetaflippedFFT = np.fft.fft(Betaflipped)
        for i in range(L):
            if messageBlocks[i]:
                # Parity sections connected to info section i
                parityIndices = np.where(G[i])[0]   # Identities of parity block(s) attached
                BetaIFFTprime = np.empty((0,0)).astype(float)
                for j in parityIndices:  # Compute message for check associated with parity j
                    # Other info blocks connected to this parity block
                    messageIndices = np.setdiff1d(np.where(G[j])[0],i)  ## all indicies attahced to j, other than i
                    BetaFFTprime = np.vstack((BetaflippedFFT[j],BetaflippedFFT[messageIndices,:]))  ## j is not part of G[j]
                    # Multiply the relevant FFTs
                    BetaFFTprime = np.prod(BetaFFTprime,axis=0)
                    # IFFT
                    BetaIFFTprime1 = np.fft.ifft(BetaFFTprime).real # multiple parity
                    BetaIFFTprime = np.vstack((BetaIFFTprime,BetaIFFTprime1)) if BetaIFFTprime.size else BetaIFFTprime1
                    # need to stack from all parity
                # NOTE(review): with a single attached parity section BetaIFFTprime is still
                # 1-D here, so np.prod(..., axis=0) collapses it to a scalar.
                BetaIFFTprime = np.prod(BetaIFFTprime,axis=0) # pointwise product of distribution
            else:
                BetaIFFTprime = np.empty((0,0)).astype(float)
                # Information sections connected to this parity section (assuming no parity over parity sections)
                Indices = np.where(G[i])[0]
                # FFT
                BetaFFTprime = BetaflippedFFT[Indices,:]
                # Multiply the relevant FFTs
                BetaFFTprime = np.prod(BetaFFTprime,axis=0)
                # IFFT
                BetaIFFTprime = np.fft.ifft(BetaFFTprime).real
            mu[i*M:(i+1)*M] = approximateVector(BetaIFFTprime, K).reshape(-1,1)

    return mu

## AMP
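This is the actual AMP algorithm. It is a mostly straightforward transcription of the relevant equations. Note that the exponentials in the PME are often too big to fit into a normal `double`. They must be handled with care, e.g. with `longdouble` types or by subtracting each section's largest exponent before exponentiating.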

In [14]:
def amp_state_update(z, s, Pl, L, M, Ab, Az, denoiserType, numBPiter,OuterCode):
    """Perform one AMP denoising step and return the new state estimate.

    The noise level tau is estimated from the residual energy. The effective
    observation r = s + Az(z) is then denoised with the PME.

    Args:
        z: Current residual (length n = z.size).
        s: Current state estimate.
        Pl: Per-section power allocation.
        L: Number of sections (not needed here).
        M: Section length.
        Ab: Forward operator (not needed here).
        Az: Adjoint operator applied to the residual.
        denoiserType (int): 0 uses the uninformative prior 1/M; any other
            value takes priors from dynamicDenoiser2 (BP on OuterCode).
        numBPiter (int): Forwarded to dynamicDenoiser2.
        OuterCode: Outer factor-graph code used by dynamicDenoiser2.

    Returns:
        The updated state estimate s.
    """
    block_len = z.size

    # Online estimate of the effective noise standard deviation.
    noise_std = np.sqrt(np.sum(z**2)/block_len)

    effective_obs = s + Az(z)

    if denoiserType == 0:
        # Uninformative prior (Giuseppe's scheme).
        prior = 1/M
    else:
        prior = dynamicDenoiser2(effective_obs, OuterCode, noise_std, Pl, numBPiter)

    return pme0(prior, effective_obs, Pl, noise_std, block_len, M)

In [15]:
def amp_residual(y, z, s, Pl, Ab):
    """Compute the next AMP residual, including the Onsager correction.

    Args:
        y: Original observation (length n = y.size).
        z: Previous residual.
        s: Current state estimate.
        Pl: Per-section powers; their sum is the total power P.
        Ab: Forward operator returning A s as a column.

    Returns:
        The new residual y - Ab(s) + (z / tau**2) * (P - ||s||^2 / n), where
        tau = sqrt(||z||^2 / n). As a side effect, tau is printed.
    """
    block_len = y.size
    total_power = np.sum(Pl)

    # Noise level estimated online from the previous residual.
    tau = np.sqrt(np.sum(z**2)/block_len)

    onsager_term = (z/tau**2) * (total_power - np.sum(s**2) / block_len)
    new_residual = y - Ab(s) + onsager_term

    print(tau)
    return new_residual

If the tree decoder outputs more than $K$ valid paths, retain the $K-\delta$ of them with the highest log-likelihoods, computed from the AMP estimate.

$\delta$ is currently set to zero

In [16]:
def pick_topKminusdelta_paths(Paths, cs_decoded_tx_message, s, J,K,delta):
    """Keep the K - delta most likely tree-decoder paths.

    The log-likelihood of a path is the sum over sections j of
    log(s[j*2**J + cs_decoded_tx_message[path[j], j]]).

    Args:
        Paths: (num_paths, num_sections) int array. Entry [p, j] is the row of
            cs_decoded_tx_message that path p uses in section j.
        cs_decoded_tx_message: (listSize, num_sections) candidate indices.
        s: Stacked AMP estimate with 2**J entries per section (column or 1-D).
        J (int): Bits per section.
        K (int): Number of users.
        delta (int): Number of paths to drop below K.

    Returns:
        numpy.ndarray: The retained rows of Paths, ordered by increasing
        log-likelihood. Has min(K - delta, num_paths) rows, or 0 rows when
        K - delta <= 0.
    """
    keep = K - delta
    # Previously K == delta sliced with [-0:] (= everything) and then failed
    # to reshape. Keeping nothing is the intended result.
    if keep <= 0:
        return Paths[:0]

    sections = np.arange(Paths.shape[1])
    # Position of each path's chosen index inside the stacked vector s.
    positions = sections * 2**J + cs_decoded_tx_message[Paths, sections]
    log_likelihood = np.sum(np.log(s.reshape(-1)[positions]), axis=1)

    best = np.argsort(log_likelihood)[-keep:]
    # Plain row selection. The old reshape((K-delta, -1)) garbled or crashed
    # the result whenever fewer than K - delta paths were available.
    return Paths[best, :]


# Simulation

In [23]:
# --- System parameters ---
L = 16  # number of sections
M=2**16  # section length; J = log2(M) bits per section
σ_n = 1  # noise standard deviation
EbN0dB = 0
numSim = 1  # NOTE(review): not used below; the loop runs `simCount` times

R = 0.04  # rate; determines P and the block length n
T = 25  # number of AMP iterations
pa_a = 0.7  # decay multiplier `a` passed to pa()
pa_f = 0.7  # fraction `f` of decaying sections passed to pa()
P = (10**(EbN0dB/10))*2*R*σ_n**2  # power: (Eb/N0) * 2R * σ_n^2

numBPiter = 1  # forwarded through amp_state_update to the dynamic denoiser
numParitySections = 8
# Compute the SNR, capacity, and n, from the input parameters
snr = P / σ_n**2
C = 0.5 * np.log2(1 + snr)  # AWGN capacity in bits per channel use
#Tstar = np.ceil(2*C/np.log2(C/R))
#print(C,Tstar)
J = np.log2(M).astype(int)

# Generate the power allocation and set of tau coefficients
Pl = pa(L, C, P, pa_a, pa_f)
#Pl = np.array([P,P,P,P,P,P,P,P,P,P])/L
numInfoSections = L-numParitySections

# NOTE(review): hard-coded positions of the information sections compared at the
# end; assumed to match the information sections of FGG.Graph8 -- confirm.
messageIndices = [1, 2, 4, 5, 7, 8, 10, 11]

B = numInfoSections*J  # number of information bits in a message

n = int(B / R)  # block length; also read as a module-level global by dynamicDenoiser2

ber=0  # NOTE(review): unused while the BER computation below is commented out

simCount=1

for simIndex in range(simCount):
    print('Simulation Number: ' + str(simIndex))
    print(B)
    # Generate user message sequence (a single user: one row of B random bits)
    tx_message = np.random.randint(2, size=(1,B))

    # Outer-encode the message sequences
    #codewords = OuterCode.encodemessages(tx_message).reshape(-1,1)
    codewords = OuterCode.encodemessages(tx_message).reshape(-1,1)
    
    #ipdb.set_trace()
    
    # In-section index of each nonzero codeword entry. Assumes exactly one
    # nonzero per section, in section order (single user).
    encoded_tx_message_indices = np.where(codewords)[0].reshape(-1,1) - M*np.arange(L).reshape(-1,1)
    
    # Multiply with power coefficients (section l is scaled by sqrt(n*Pl[l]))
    sTrue = codewords*np.sqrt(n*Pl).repeat(M).reshape(-1, 1)
    
    #ipdb.set_trace()


    # Generate the binned SPARC codebook
    Ab, Az = sparc_codebook(L, M, n)
    
    # Generate our transmitted signal X
    x = Ab(sTrue)
    
    # Generate random channel noise and thus also received signal y
    noise = np.random.randn(n, 1) * σ_n
    y = (x + noise).reshape(-1, 1)

    z = y.copy()
    s = np.zeros((L*M, 1))

    # T AMP iterations; denoiserType=1 selects the BP-based dynamic denoiser.
    # amp_residual prints the running noise estimate tau on every iteration.
    for t in range(T):
        s = amp_state_update(z, s, Pl, L, M, Ab, Az, 1, numBPiter,OuterCode)

        z = amp_residual(y, z, s, Pl, Ab)

    #print(s.shape)
    print('Graph Decode')
    #ipdb.set_trace()
    # Decoding with the graph. FGG.decoder is project code; its third argument
    # is presumably the list size -- confirm in FactorGraphGeneration.
    originallist = codewords.copy()  # NOTE(review): unused below
    recoveredcodewords = FGG.decoder(OuterCode,s,1)
    
    # Same index extraction as above; assumes one nonzero per section in the decoder output.
    recovered_tx_message_indices = np.where(np.array(recoveredcodewords).reshape(-1,1))[0].reshape(-1,1) - M*np.arange(L).reshape(-1,1)
    #encoded_tx_message_indices = encoded_tx_message_indices.reshape(1,).tolist()
    #recovered_tx_message_indices = recovered_tx_message_indices.reshape(1,).tolist()
    #print(encoded_tx_message_indices)
    #print(recovered_tx_message_indices)
    
    # Number of information sections whose index was recovered correctly
    correct = np.sum(np.array(recovered_tx_message_indices[messageIndices]) == np.array(encoded_tx_message_indices[messageIndices])) 
    print(correct)
    # Compute BER
    #mismatches = (bin(a^b).count('1') for (a, b) in zip(encoded_tx_message_indices[0:numInfoSections], recovered_tx_message_indices[0:numInfoSections]))
    

    #ber = ber + sum(mismatches)
    
#bitErrorRate= ber/(B*simCount)

#print("Biterror rate =  ", bitErrorRate)



Simulation Number: 0
128
1
1.0269666118464065
1.011370308533076
1.0005072587749573
0.9996726521230073
0.9996608736561233
0.9996628794094442
0.9996625810239197
0.9996620622664806
0.9996619450670148
0.9996619507913287
0.9996619574083009
0.9996619582052978
0.9996619581178108
0.9996619581076457
0.9996619581177726
0.999661958120636
0.9996619581206918
0.9996619581205922
0.9996619581205811
0.9996619581205844
0.9996619581205856
0.9996619581205848
0.9996619581205857
0.9996619581205852
0.9996619581205854
Graph Decode
Root section ID: 62172
8
