**Importing** **Libraries**

In [1]:
import numpy as np
from numpy import linalg as LA
import matplotlib
import matplotlib.pyplot as plt
import sympy as sp
import torch
torch.set_default_dtype(torch.float64)
torch.set_printoptions(precision=15)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

**Functions**

In [2]:
### Fast hadamard transform. Recreated from sympy for GPU support
def fwht(seq, inverse):
    n = len(seq)
    if n < 2:
        return a
    if n&(n-1):
        n = 2**n.bit_length()
    a = torch.concat((seq, torch.zeros(n-len(seq)).to(device))) # append zeros to make it power of 2
    h = 2
    while h<=n:
        hf = h // 2
        i = torch.arange(0,n,h).to(device)
        i = i.reshape(len(i), 1)
        j = torch.arange(0, hf, 1).to(device)
        j = j.repeat(len(i),1)
        u, v = a[i+j], a[i+j+hf]
        a[i+j], a[i+j+hf] = u+v, u-v
        h*=2

    if inverse:
        a /= n
    return a

### Uniform quantization
def UQ(flat_grad, B, v, Q_space):
  norm_flat_grad = flat_grad/K
  rand_vect = torch.rand(d).to(device)
  f = torch.bucketize(norm_flat_grad, Q_space)
  q = torch.where(2*B/K/(v-1)*rand_vect<=torch.subtract(norm_flat_grad, Q_space[f-1]), f, f-1)
  return(q)

### Lattice encoding
def LatE(stat, p, num_blocks, w, d_ext):
	stat = torch.concatenate((stat, torch.zeros(int(d_ext-d)).to(device)))
	stat_ext = stat.reshape(int(num_blocks), p)
	coeff = w**torch.arange(p).to(device)
	lamda  = torch.multiply(stat_ext, coeff)
	return(torch.sum(lamda, axis = -1))

### ASK modulation
def ASK(lmbd, r):
	return(-np.sqrt(P)+lmbd*2*np.sqrt(P)/(r-1))

### Minimum-Distance decoding
def MD(y_recvd, delt, r):
	return(torch.round((y_recvd/delt+K*np.sqrt(P)/delt-1/2)))

### Lattice decoding (Successive-Cancelllation)
def LatD(y_hat, b, w, d_):
  if torch.sum(y_hat>= w) == 0 and torch.numel(b) == d_:
    return(b)
  else:
    b = torch.concatenate((b, (y_hat%w).reshape(1,len(y_hat))))
    return(LatD((y_hat-y_hat%w)/w, b, w, d_))


**Main** **code**

In [3]:

cx = 0
d = 64			                                                                    #### Dimension
cnt = 0
B_range = [4096]
#B_range = [512, 1024, 2048]
MSE1_B = [0]*len(B_range)
MSE4_B = [0]*len(B_range)

for b in B_range:
  np.random.seed(10)
  torch.manual_seed(0)
  Clients = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]
  MSE1_K = [0]*len(Clients)
  MSE2_K = [0]*len(Clients)
  MSE3_K = [0]*len(Clients)
  MSE4_K = [0]*len(Clients)
  I = 20		                                                                    #### monte-carlo iterations
  for i in range(len(Clients)):
    K = Clients[i]
    print('Clients: {}, dimension: {}, MC_iter: {}'.format(K, d, I))
    sigma_range = [0.01732]                      #### Sigma range
    sig = sigma_range[0]

    MSE1_K_sig = [0.0]
    MSE2_K_sig = [0.0]
    MSE3_K_sig = [0.0]
    MSE4_K_sig = [0.0]


    MSE1_K_sig_I = [0.0]*I
    MSE2_K_sig_I = [0.0]*I
    MSE3_K_sig_I = [0.0]*I
    MSE4_K_sig_I = [0.0]*I

    for k in range(I):
      ''' OVER-THE-AIR '''
      snr_range = [180]
      MSE1_K_sig_I_snr = [0.0]*len(snr_range)
      MSE2_K_sig_I_snr = [0.0]*len(snr_range)
      MSE3_K_sig_I_snr = [0.0]*len(snr_range)
      MSE4_K_sig_I_snr = [0.0]*len(snr_range)
      mean = 2*(torch.rand(d).to(device)-0.5)                                   #### True gradients in [-1,1]

      for l in range(len(snr_range)):
        print('\n Experiment for Iteration = {}, B = {}, SNR = {}dB'.format(k, b, snr_range[l]))
        input_avg = torch.zeros(d).to(device)
        q_input_avg = torch.zeros(d).to(device)
        output_avg = torch.zeros(d).to(device)
        psi = []
        x_e = []
        dbSNR = snr_range[l]                                                                     #### SNR ratio in dB
        SNR = 10**(dbSNR/10.0)                                                          #### SNR in units/units
        npr = 1                                                                     #### Absolute noise power
        P = SNR*npr/K                                                                 #### Absolute signal power allowed per client

        #v = int(2*b)+1
        v = int(np.sqrt(d))+1 if (int(np.sqrt(d)))%2==0 else int(np.sqrt(d))+2
        w = K*(v-1) + 1 	                                            #### No. of quantization levels

        p = min(d, max(1, int(np.log(np.sqrt(2*K*SNR/np.log(K))+1)/np.log(w))))                               #### Block size
        num_blocks = int(np.ceil(d/p))
        d_ext = num_blocks*p
        print('\n v: {}, w: {}, num Blocks:{}, blocksize: {}'.format(v, w, num_blocks, p))
        r = (w**p-1)/K + 1                                                          #### Number of ASK codewords in transmission
        del_ = 2*np.sqrt(P)/(r-1)

        for client in range(K):
          t = 6*sig*(torch.rand(d).to(device)-0.5)                              #### Noise in [-3\sigma, +3\sigma].
          x = mean + t                                                          #### Noisy gradient in [-1-3sigma, 1+3sigma]
                                                                                #### E[x] = mean; var(x) = var(t) = 3sig^2 = sig_paper^2/d; norm(x) <= B^2 = B_paper^2 = d(1+3sig)^2
          input_avg = x + input_avg

          ''' UNIFORM QUANTIZATION '''
          B = b*np.sqrt(d)                                              #### Maximum norm of noisy grad

          Q_space = -B/K + torch.arange(v).to(device)*2*B/K/(v-1)               #### Quantization points in [-B/K,B/K]

          x_uq = UQ(x, B, v, Q_space)
          q_input_avg = q_input_avg - B/K + 2*B*x_uq/K/(v-1)

          ''' LATTICE ENCODING '''
          x_lenc = LatE(x_uq, p, num_blocks, w, d_ext)
          ''' ASK MODULATION '''
          x_ask  = ASK(x_lenc, r)
          #print('ASK codeword', x_ask)
          psi.append(x_ask)
          x_e.append(x_uq)
          #print('Done for client', client)

        ''' MAC transmission '''
        y = sum(psi)+ np.sqrt(npr)*torch.randn(num_blocks).to(device)           #### Adding awgn noise
        #print('Sum of ASK enc:', sum(psi))
        #print('Channel output', y)
        #print(y)
        x_hat = MD(y, del_, r)                                                    #### Min-distance Decoding
        #print('sum of ASK dec', -K*np.sqrt(P)+x_hat*2*np.sqrt(P)/(r-1))
        #print('Index decoded', x_hat)
        var = LatD(x_hat, torch.Tensor([]).to(device), w, d_ext)                  #### Lattice Decoding
        var = var.transpose(0, 1).flatten()
        #print('Sum of UQ enc', sum(x_e))
        #print('Sum of UQ dec', var)
        output = -B + var[:d]*2*B/K/(v-1)                                         #### Mean estimate of gradients in C_1 group
        #print('First transmission decoded')
        #print(output, mean)
        MSE_1 = output - mean                                                     #### MSE w.r.t true grad (over MAC)
        MSE_2 = output - input_avg/K                                              #### MSE w.r.t avg. of noisy grad (over MAC)
        MSE_3 = input_avg/K - mean                                                #### MSE w.r.t empirical mean (Benchmark, noiseless channel)
        MSE_4 = q_input_avg - mean                                                #### MSE w.r.t avg. of quant. noisy grad (noiseless channel)
        print('Done! Total channel-uses = ', num_blocks)
        print('TM:', mean.norm(), 'QEM:', q_input_avg.norm(), 'UQM:', output.norm())
        print('UQM ERROR:', MSE_1.norm()*np.sqrt(num_blocks))

        MSE1_K_sig_I_snr[l] = MSE_1.norm()*np.sqrt(num_blocks)
        MSE2_K_sig_I_snr[l] = MSE_2.norm()*np.sqrt(num_blocks)
        MSE3_K_sig_I_snr[l] = MSE_3.norm()*np.sqrt(num_blocks)
        MSE4_K_sig_I_snr[l] = MSE_4.norm()*np.sqrt(num_blocks)
      MSE1_K_sig_I[k] = MSE1_K_sig_I_snr
      MSE2_K_sig_I[k] = MSE2_K_sig_I_snr
      MSE3_K_sig_I[k] = MSE3_K_sig_I_snr
      MSE4_K_sig_I[k] = MSE4_K_sig_I_snr


    MSE1_K_sig = np.array(sum(torch.Tensor(MSE1_K_sig_I)).to('cpu')/I)
    MSE2_K_sig = np.array(sum(torch.Tensor(MSE2_K_sig_I)).to('cpu')/I)
    MSE3_K_sig = np.array(sum(torch.Tensor(MSE3_K_sig_I)).to('cpu')/I)
    MSE4_K_sig = np.array(sum(torch.Tensor(MSE4_K_sig_I)).to('cpu')/I)

    MSE1_K[i] = MSE1_K_sig
    MSE2_K[i] = MSE2_K_sig
    MSE3_K[i] = MSE3_K_sig
    MSE4_K[i] = MSE4_K_sig

  MSE1_B[cnt] = MSE1_K
  MSE4_B[cnt] = MSE4_K
  cnt+=1

#np.savetxt("RMSE_CUQ_v_"+str(v)+"_"+str(d)+".dat", MSE_CUQ, delimiter =", ", fmt ='% s')

Clients: 200, dimension: 64, MC_iter: 20

 Experiment for Iteration = 0, B = 4096, SNR = 180dB

 v: 9, w: 1601, num Blocks:22, blocksize: 3
Done! Total channel-uses =  22
TM: tensor(4.354756285278644) QEM: tensor(0.) UQM: tensor(115.852375029601475)
UQM ERROR: tensor(544.417288651236277)

 Experiment for Iteration = 1, B = 4096, SNR = 180dB

 v: 9, w: 1601, num Blocks:22, blocksize: 3
Done! Total channel-uses =  22
TM: tensor(4.445313023823938) QEM: tensor(70.944801078021229) UQM: tensor(147.683380243001864)
UQM ERROR: tensor(689.159984830089456)

 Experiment for Iteration = 2, B = 4096, SNR = 180dB

 v: 9, w: 1601, num Blocks:22, blocksize: 3
Done! Total channel-uses =  22
TM: tensor(4.293361265333729) QEM: tensor(40.960000000000008) UQM: tensor(115.852375029601475)
UQM ERROR: tensor(541.344942031067603)

 Experiment for Iteration = 3, B = 4096, SNR = 180dB

 v: 9, w: 1601, num Blocks:22, blocksize: 3
Done! Total channel-uses =  22
TM: tensor(4.902292985899986) QEM: tensor(0.) UQM: te

**Plotting**

In [1]:
pts = 14 ## Number of points for plotting. Maximum = len(sigma_range)
font = {'weight': 'normal', 'size': 14}
matplotlib.rc('font', **font)

####### The below is for WZ with d^2 levels in CUQ ######
'''MSE1_B = [[np.array([3.17405946, 0.35217659, 0.36060385, 0.34892212])], [np.array([6.5057385 , 0.53404811, 0.53430262, 0.52615879])], [np.array([12.82181643,  0.76855412,  0.76843619,  0.75659643])], [np.array([25.25224325,  1.11188592,  1.10644111,  1.07614049])], [np.array([49.98421408,  1.6672498 ,  1.68239821,  1.58654196])], [np.array([101.14558496,   2.61247943,   2.50976243,   2.32345564])], [np.array([201.87061386,   4.27989231,   3.98053124,   3.5174812 ])], [np.array([415.87833972,   7.11806425,   6.7142643 ,   5.42661781])], [np.array([822.70202898,  12.77635324,  11.29004116,   8.99446504])], [np.array([1648.0022715 ,   24.53414021,   21.18104624,   16.14522054])], [np.array([3200.21330874,   46.13691111,   39.91412021,   30.53211794])]]
MSE4_B = [[np.array([0.35440561, 0.34938205, 0.35880823, 0.34851481])], [np.array([0.5221536 , 0.52683578, 0.53019908, 0.52345472])], [np.array([0.76942668, 0.75088376, 0.75719615, 0.75145968])], [np.array([1.07513441, 1.05497651, 1.06576003, 1.06191561])], [np.array([1.53086658, 1.50840045, 1.56732525, 1.52751654])], [np.array([2.14527163, 2.16600343, 2.22881454, 2.16277094])], [np.array([3.19422718, 3.02482801, 3.05920336, 3.10496264])], [np.array([4.32939589, 4.3161013 , 4.48234544, 4.22329875])], [np.array([6.31527585, 6.09199437, 6.26251484, 5.71020263])], [np.array([9.24613809, 8.58495586, 7.87638433, 8.02889029])], [np.array([11.54683883, 11.94345642, 11.59205749, 12.62681684])]]
MSE_WZ = [[np.array([5.78897253, 5.38065371, 5.38258418, 5.29397665])], [np.array([6.57783493, 5.17140159, 4.98366062, 4.55779069])], [np.array([11.07042354,  7.64618362,  6.13125953,  5.78847083])], [np.array([20.32392915, 13.46582667, 10.50547807,  8.95491307])], [np.array([41.104201  , 25.80426851, 18.43607257, 16.22035121])], [np.array([83.67904936, 53.34692381, 37.23331967, 31.17285751])], [np.array([164.6343602 , 100.50493833,  74.43839158,  64.45778997])], [np.array([309.61812762, 207.15305647, 156.63401743, 130.05954001])]]
MSE_WZ_ideal = [[np.array([0.600277  , 0.58114297, 0.60906026, 0.58306563])], [np.array([1.24458024, 1.25956621, 1.24563784, 1.26578862])], [np.array([2.52578236, 2.64831355, 2.59470439, 2.59958956])], [np.array([5.05012974, 5.14785544, 5.29787991, 5.01841629])], [np.array([10.15053377, 10.12965803,  9.94118585, 10.10078418])], [np.array([19.81244993, 20.41454499, 20.0994908 , 20.14080771])], [np.array([41.63053575, 40.67922181, 40.81665609, 41.48666244])], [np.array([81.26760234, 82.23856563, 83.99207257, 81.56505437])]]
B_range = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
Clients = [400]
sigma_range = [0.0176]
sig = sigma_range[0]
I=50
d=32
snr_range = [0, 40, 100, 180]'''
###########################################


########## The below is for limited B_range and SNR. Good for 0dB and 5 dB ##############################
'''B_range = [2, 4, 8, 16, 32, 64, 128, 256]
snr_range = [0, 5, 10, 15]
MSE_WZ = [[np.array([5.78897253, 5.38065371, 5.38258418, 5.29397665])], [np.array([6.57783493, 5.17140159, 4.98366062, 4.55779069])], [np.array([11.07042354,  7.64618362,  6.13125953,  5.78847083])], [np.array([20.32392915, 13.46582667, 10.50547807,  8.95491307])], [np.array([41.104201  , 25.80426851, 18.43607257, 16.22035121])], [np.array([83.67904936, 53.34692381, 37.23331967, 31.17285751])], [np.array([164.6343602 , 100.50493833,  74.43839158,  64.45778997])], [np.array([309.61812762, 207.15305647, 156.63401743, 130.05954001])]]
MSE_WZ_ideal = [[np.array([0.600277, 0.58114297, 0.60906026, 0.58306563])], [np.array([1.24458024, 1.25956621, 1.24563784, 1.26578862])], [np.array([2.52578236, 2.64831355, 2.59470439, 2.59958956])], [np.array([5.05012974, 5.14785544, 5.29787991, 5.01841629])], [np.array([10.15053377, 10.12965803,  9.94118585, 10.10078418])], [np.array([19.81244993, 20.41454499, 20.0994908 , 20.14080771])], [np.array([41.63053575, 40.67922181, 40.81665609, 41.48666244])], [np.array([81.26760234, 82.23856563, 83.99207257, 81.56505437])]]'''
#############################################################

####### We increased the range of B_range for 10dB and 15dB to see the crossover. ########
'''snr_range = [10, 15]
MSE_WZ =  [[np.array([5.38258418, 5.29397665])], [np.array([4.98366062, 4.55779069])], [np.array([6.13125953,  5.78847083])], [np.array([10.50547807,  8.95491307])], [np.array([18.43607257, 16.22035121])], [np.array([37.23331967, 31.17285751])], [np.array([74.43839158,  64.45778997])], [np.array([156.63401743, 130.05954001])], [np.array([307.77397708, 243.17379068])], [np.array([596.02682429, 521.90510531])], [np.array([1167.35568421,  984.36180697])]]
MSE_WZ_ideal = [[np.array([0.60906026, 0.58306563])], [np.array([1.24563784, 1.26578862])], [np.array([2.59470439, 2.59958956])], [np.array([5.29787991, 5.01841629])], [np.array([9.94118585, 10.10078418])], [np.array([20.0994908 , 20.14080771])], [np.array([40.81665609, 41.48666244])], [np.array([83.99207257, 81.56505437])], [np.array([165.59505735, 162.45435766])], [np.array([325.22210522, 327.86261155])], [np.array([641.76164581, 641.61219751])]]
MSE1_B = [[np.array([1.07111144, 0.67983617])], [np.array([2.08988936, 1.24128847])], [np.array([4.11848309,  2.41355012])], [np.array([7.97437698,  4.63475183])], [np.array([16.27528335,  9.15725285])], [np.array([32.22020926,  18.80348619])], [np.array([63.47255104,  35.32803634])], [np.array([131.65154164,  73.9646764])], [np.array([251.64001385, 149.12191649])], [np.array([510.18942432,  292.55891468])], [np.array([1012.06446919,  581.00854461])]]
MSE4_B = [[np.array([0.34117339, 0.35603351])], [np.array([0.52063096, 0.52471865])], [np.array([0.75062592, 0.75563973])], [np.array([1.09735321, 1.06081474])], [np.array([1.5459673 , 1.49297159])], [np.array([2.14641154, 2.19631585])], [np.array([3.08457669, 3.04669675])], [np.array([4.36506297, 4.28374433])], [np.array([5.65264386, 6.19822862])], [np.array([7.64623657, 9.21385072])], [np.array([11.68064844, 12.66663568])]]

B_range = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]'''
#################################################

for b in range(len(B_range)):
  B1 = B_range[b]
  plt.plot(Clients[:pts], MSE1_B[b][:pts], label='MSE_TM_WZ_OTA')
  #plt.plot(Tot_clnts[:pts], MSE2_K[i][j][:pts], label='MSE_EM_WZ_OTA_K='+str(K)+' SNR='+str(snr_range[j])+' sigma='+str(sigma_range[j]))
  #plt.plot(Tot_clnts[:pts], MSE3_K[i][j][:pts], label='MSE_TM_EM_K='+str(K)+' (noiseless)'+' sigma='+str(sigma_range[j]))
  plt.plot(Clients[:pts], MSE4_B[b][:pts], label='MSE_TM_WZ_OTA (noiseless)')

    #plt.plot(snr_range[:pts], SMSE1_K[i][j][:pts], label='MSE_TM_UQ_OTA_K='+str(2*K))
    #plt.plot(sigma_range[:pts], MSE2_K[i][j][:pts], label='MSE_EM_CUQ_K='+str(K)+' SNR='+str(snr_range[j])+' sigma='+str(sigma_range[j]))
    #plt.plot(sigma_range[:pts], MSE3_K[i][j][:pts], label='MSE_TM_EM_K='+str(K)+' (noiseless)'+' sigma='+str(sigma_range[j]))
    #plt.plot(snr_range[:pts], SMSE4_K[i][j][:pts], label='MSE_TM_UQ_OTA_K='+str(2*K)+' (noiseless)')


  plt.legend()
  plt.xlabel('K')
  plt.ylabel('RMSE x l')
  #plt.xscale('log')
  #plt.yscale('log')
  plt.xticks(ticks= Clients[:pts], labels=Clients[:pts])
  plt.title('d='+str(d)+', iters='+str(I)+', B/sigma='+str(B1/sig)+'SNR='+str(SNR))
  plt.grid()
  plt.show()




for b in range(len(B_range)):
  B1 = B_range[b]
  for i in range(len(Clients)):
    K = Clients[i]
    plt.plot(snr_range[:pts], MSE1_B[b][i][:pts], label='MSE_TM_CUQ_K='+str(K))
    plt.plot(snr_range[:pts], MSE4_B[b][i][:pts], label='MSE_TM_CUQ_K='+str(K)+' (noiseless)')
    #plt.plot(snr_range[:pts], MSE_WZ[b][i][:pts], label='MSE_TM_WZ_K='+str(K))
    #plt.plot(snr_range[:pts], MSE_WZ_ideal[b][i][:pts], label='MSE_TM_WZ_K='+str(K)+' (noiseless)')
    plt.legend()
    plt.xlabel('SNR')
    plt.ylabel('RMSE')
    #plt.xscale('log')
    #plt.yscale('log')
    plt.xticks(ticks= snr_range[:pts], labels=snr_range[:pts])
    plt.title('d='+str(d)+', iters='+str(I)+', B/sigma='+str(B1/sigma_range[0]))
    plt.grid()
    plt.show()

b_by_sig = [B_range[b]/sig/np.sqrt(3) for b in range(len(B_range))]


MSE1_modified = np.array(MSE1_B).transpose()
MSE4_modified = np.array(MSE4_B).transpose()
#MSE_WZ_modified = np.array(MSE_WZ).transpose()
#MSE_WZ_ideal_modified = np.array(MSE_WZ_ideal).transpose()

for i in range(len(Clients)):
  K = Clients[i]

  for k in range(len(snr_range)):
    plt.plot(b_by_sig[:pts], MSE1_modified[k][i][:pts], label='MSE_TM_CUQ')
    #plt.plot(b_by_sig[:pts], MSE_WZ_modified[k][i][:pts], label='MSE_TM_WZ')
    #plt.plot(b_by_sig[:pts], MSE_WZ_ideal_modified[k][i][:pts], label='MSE_TM_WZ (noiseless)')
    plt.plot(b_by_sig[:pts], MSE4_modified[k][i][:pts], '.', label='MSE_TM_CUQ'+' (noiseless)')
    plt.legend()
    plt.xlabel('b_by_sig')
    plt.ylabel('RMSE')
    #plt.xscale('log')
    plt.xscale('log')

    plt.xticks(ticks= b_by_sig[:pts], labels=b_by_sig[:pts])
    plt.title('d='+str(d)+', iter='+str(I)+', K='+str(K)+', SNR = '+str(snr_range[k])+'dB')
    plt.grid()
    plt.show()


NameError: ignored

In [5]:
print(MSE1_B)
print(MSE4_B)


[[array([654.55306616]), array([482.70947378]), array([337.94316881]), array([259.82566544]), array([201.4417586]), array([172.40378392]), array([153.44723694]), array([140.9917854]), array([125.6709342]), array([111.70492671])]]
[[array([105.49777476]), array([144.76118523]), array([101.62699821]), array([96.6094582]), array([80.83069935]), array([77.78865838]), array([76.13351325]), array([73.63022669]), array([67.80976343]), array([62.88444011])]]
