<a href="https://colab.research.google.com/github/shubhamjha-46/OTA_MAC/blob/main/UQ_OTA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Importing Libraries**

In [8]:
import numpy as np
from numpy import linalg as LA
import matplotlib
import matplotlib.pyplot as plt
import sympy as sp
import torch
# Select the compute device once: CUDA when available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Tensors created without an explicit dtype are float64 from here on.
torch.set_default_dtype(torch.float64)
# Print tensors with two decimal places.
torch.set_printoptions(precision=2)

**Functions**

In [3]:
### Uniform quantization
def UQ(flat_grad, B, v, Q_space):
    """Stochastically quantize ``flat_grad / K`` onto the grid ``Q_space``.

    ``K`` is read from the module-level global (the current number of clients).
    Each entry of ``flat_grad / K`` is first clipped to
    ``[Q_space[0], Q_space[-1]]``. It is then rounded to one of its two
    neighbouring grid points, taking the upper one with probability
    ``(x - lower) / step`` where ``step = 2*B/K/(v-1)``. When ``Q_space`` is the
    evenly spaced grid ``-B/K + arange(v)*step`` (as built in the main loop),
    this rounding is unbiased for in-range inputs.

    Args:
        flat_grad: tensor of one client's gradient entries.
        B: norm bound that sets the grid range ``[-B/K, B/K]``.
        v: number of quantization levels (``len(Q_space)``).
        Q_space: increasing tensor of the ``v`` grid points, on the same
            device as ``flat_grad``.

    Returns:
        Integer tensor of grid indices in ``[0, v-1]``, same shape as
        ``flat_grad``.
    """
    norm_flat_grad = flat_grad / K
    # Clip to the grid range. Out-of-range inputs would otherwise yield index -1
    # or v, which the base-w lattice code (w = K*(v-1)+1) cannot carry-free hold.
    norm_flat_grad = torch.minimum(torch.maximum(norm_flat_grad, Q_space[0]), Q_space[-1])
    # Dither drawn on the default generator exactly as before, then moved to the data.
    rand_vect = torch.rand(norm_flat_grad.shape).to(norm_flat_grad.device)
    # bucketize gives f with Q_space[f-1] < x <= Q_space[f]. Clamping to [1, v-1]
    # stops Q_space[f-1] from wrapping to Q_space[-1] when x == Q_space[0].
    f = torch.clamp(torch.bucketize(norm_flat_grad, Q_space), 1, v - 1)
    step = 2 * B / K / (v - 1)
    # Round up with probability (x - lower) / step; strict '<' keeps grid points fixed.
    round_up = step * rand_vect < torch.subtract(norm_flat_grad, Q_space[f - 1])
    return torch.where(round_up, f, f - 1)

### Lattice encoding
def LatE(stat, p, num_blocks, w, d_ext):
    """Pack quantization indices into base-``w`` integers, ``p`` digits per block.

    ``stat`` is zero-padded to length ``d_ext`` and reshaped into
    ``num_blocks`` consecutive blocks of ``p`` entries. Block ``j`` is mapped
    to ``sum_i stat[j*p + i] * w**i``, least-significant digit first.

    Args:
        stat: 1-D tensor of at most ``d_ext`` entries (quantization indices).
        p: digits per block.
        num_blocks: number of blocks.
        w: base of the positional code.
        d_ext: padded length; must equal ``num_blocks * p``.

    Returns:
        1-D tensor with one encoded value per block.
    """
    # Pad from the input's own length and device instead of the globals d/device.
    pad = int(d_ext) - stat.numel()
    stat = torch.cat((stat, torch.zeros(pad).to(stat.device)))
    stat_ext = stat.reshape(int(num_blocks), p)
    coeff = w ** torch.arange(p).to(stat.device)
    return torch.sum(stat_ext * coeff, dim=-1)

### ASK modulation
def ASK(lmbd, r):
	"""Map lattice symbols ``lmbd`` in ``[0, r-1]`` onto ``r`` evenly spaced
	amplitudes spanning ``[-sqrt(P), +sqrt(P)]``.

	``P`` is the module-level per-client power set in the main loop.
	"""
	amplitude = np.sqrt(P)
	return lmbd * 2 * amplitude / (r - 1) - amplitude

### Minimum-Distance decoding
def MD(y_recvd, delt, r):
	"""Minimum-distance decode the superimposed ASK signal to the sum of symbols.

	Each of the ``K`` clients sends ``-sqrt(P) + lmbd * delt`` with ``lmbd`` in
	``[0, r-1]`` (``ASK`` with ``delt = 2*sqrt(P)/(r-1)``). The noiseless channel
	sum is therefore ``-K*sqrt(P) + Lambda * delt`` with ``Lambda`` in
	``[0, K*(r-1)]``. The nearest constellation point is
	``round((y + K*sqrt(P)) / delt)`` clipped to that range.

	``K`` and ``P`` are read from module-level globals.

	Args:
		y_recvd: received samples, one per channel use.
		delt: spacing between adjacent ASK amplitudes.
		r: per-client ASK constellation size.

	Returns:
		Float tensor of decoded sums ``Lambda`` in ``[0, K*(r-1)]``.
	"""
	offset = K * np.sqrt(P) / delt
	# Plain rounding is the nearest-point rule. The former extra "-1/2" shift
	# made this a floor, so any negative noise decoded to Lambda - 1.
	lam = torch.round(y_recvd / delt + offset)
	# Project onto the constellation. round() guards against K*(r-1) being a
	# hair off the integer w**p - 1 in floating point.
	return torch.clamp(lam, 0, round(K * (r - 1)))

### Lattice decoding (Successive-Cancellation)
def LatD(y_hat, b, w, d_):
    """Split decoded lattice sums into base-``w`` digits, least significant first.

    Each pass appends the current lowest digit ``y_hat % w`` as a new row of
    ``b`` and divides the remainder by ``w``. Decoding stops once ``b`` holds
    ``d_`` elements and every remainder is below ``w``. A final remainder in
    ``[1, w)`` is dropped, as in the original recursive version.

    Args:
        y_hat: 1-D tensor of decoded sums, one per block.
        b: digits collected so far; callers pass an empty tensor.
        w: base of the positional code.
        d_: total number of digits to produce (``num_blocks * p``).

    Returns:
        Tensor of shape ``(d_ // len(y_hat), len(y_hat))``; row ``i`` holds
        digit ``i`` of every block.

    Raises:
        ValueError: if ``y_hat`` cannot be represented in ``d_`` digits. The
            recursive version never terminated in that case.
    """
    digits, rem = b, y_hat
    while torch.any(rem >= w) or torch.numel(digits) != d_:
        if torch.numel(digits) >= d_:
            raise ValueError('LatD: input needs more than d_ base-w digits')
        digit = rem % w
        digits = torch.cat((digits, digit.reshape(1, -1)))
        rem = (rem - digit) / w
    return digits


**Main Code**

In [None]:
# -----------------------------------------------------------------------------
# Monte-Carlo simulation of over-the-air (OTA) mean estimation.
#   K clients each hold a noisy copy x = mean + t of a random vector `mean`.
#   Each client quantizes it (UQ), packs the indices into base-w blocks (LatE),
#   and modulates them (ASK). The channel sums all K signals and adds AWGN.
#   The receiver decodes the sum (MD, then LatD) into an estimate `output`.
# Loop nest: b in B_range -> K in Clients -> I iterations -> SNR values.
# Results, indexed [B index][K index] -> numpy array with one entry per SNR:
#   MSE1_B: error of the OTA estimate `output` w.r.t. `mean`.
#   MSE2_B: error of the noiseless quantized average `q_input_avg` w.r.t. `mean`.
# NOTE: despite the names, each stored value is ||error||_2 * sqrt(num_blocks)
#   averaged over the I iterations, not a mean squared error.
# The helper functions read module-level globals that are (re)assigned here
# (e.g. K and P), so they use whatever values are current at call time.
# -----------------------------------------------------------------------------
d = 64  #### Dimension of the vector each client holds
cnt = 0  # Index into MSE1_B / MSE2_B for the current b
B_range = [4096]  # Norm-bound multipliers; the quantizer range uses B = b*sqrt(d)
#B_range = [512, 1024, 2048]
MSE1_B = [0]*len(B_range)
MSE2_B = [0]*len(B_range)

for b in B_range:
  # Seeds are reset for every b. Only torch's generator is drawn from in this cell.
  np.random.seed(10)
  torch.manual_seed(0)
  Clients = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]
  MSE1_K = [0]*len(Clients)
  MSE2_K = [0]*len(Clients)
  I = 20  #### Number of Monte-Carlo iterations
  for i in range(len(Clients)):
    K = Clients[i]
    print('Clients: {}, dimension: {}, MC_iter: {}'.format(K, d, I))
    sigma_range = [0.01732]  #### Noise-scale parameter; only the first entry is used
    sig = sigma_range[0]

    # Placeholders; both are overwritten after the Monte-Carlo loop below.
    MSE1_K_sig = [0.0]
    MSE2_K_sig = [0.0]

    # Per-iteration results: entry k becomes a list with one value per SNR.
    MSE1_K_sig_I = [0.0]*I
    MSE2_K_sig_I = [0.0]*I

    for k in range(I):
      ''' OVER-THE-AIR '''
      snr_range = [180]  # SNR values to simulate, in dB
      MSE1_K_sig_I_snr = [0.0]*len(snr_range)
      MSE2_K_sig_I_snr = [0.0]*len(snr_range)
      mean = 2*(torch.rand(d).to(device)-0.5)  #### True vector, uniform in [-1, 1)

      for l in range(len(snr_range)):
        print('\n Experiment for Iteration = {}, B = {}, SNR = {}dB'.format(k, b, snr_range[l]))
        # NOTE(review): input_avg is accumulated and output_avg created, but
        # neither is read later in this cell.
        input_avg = torch.zeros(d).to(device)
        q_input_avg = torch.zeros(d).to(device)  # Running sum of dequantized x/K = noiseless estimate
        output_avg = torch.zeros(d).to(device)
        psi = []  # Each client's ASK-modulated block symbols
        x_e = []  # Each client's quantization indices (not read later in this cell)
        dbSNR = snr_range[l]  #### SNR in dB
        SNR = 10**(dbSNR/10.0)  #### SNR as a linear power ratio
        npr = 1  #### Noise power
        P = SNR*npr/K  #### Power per client, so the K clients together use SNR*npr

        # v: smallest odd integer greater than int(sqrt(d)); number of quantization levels.
        v = int(np.sqrt(d))+1 if (int(np.sqrt(d)))%2==0 else int(np.sqrt(d))+2
        # w: base of the lattice code. A sum of K indices in [0, v-1] is at most
        # K*(v-1) = w-1, so per-digit sums over clients never carry.
        w = K*(v-1) + 1

        # p: digits (coordinates) packed into one channel use, between 1 and d.
        p = min(d, max(1, int(np.log(np.sqrt(2*K*SNR/np.log(K))+1)/np.log(w))))
        num_blocks = int(np.ceil(d/p))  # Channel uses needed for d coordinates
        d_ext = num_blocks*p  # d rounded up to a whole number of blocks (zero padded)
        print('\n v: {}, w: {}, num Blocks:{}, blocksize: {}'.format(v, w, num_blocks, p))
        r = (w**p-1)/K + 1  #### Per-client ASK constellation size; r-1 = (w**p - 1)/K
        del_ = 2*np.sqrt(P)/(r-1)  # Spacing between adjacent ASK amplitudes (as in ASK)

        for client in range(K):
          t = 6*sig*(torch.rand(d).to(device)-0.5)  #### Noise, uniform in [-3*sig, 3*sig)
          x = mean + t  #### Noisy vector, entries in [-1-3*sig, 1+3*sig)
          # E[x] = mean; per-coordinate var(x) = var(t) = (6*sig)^2/12 = 3*sig^2;
          # ||x||_2 <= sqrt(d)*(1+3*sig).
          input_avg = x + input_avg

          ''' UNIFORM QUANTIZATION '''
          # B and Q_space do not depend on the client; they are recomputed each pass.
          B = b*np.sqrt(d)  #### Norm bound used to size the quantizer range
          Q_space = -B/K + torch.arange(v).to(device)*2*B/K/(v-1)  #### v evenly spaced points in [-B/K, B/K]
          x_uq = UQ(x, B, v, Q_space)
          # Dequantize index -> grid point and accumulate (noiseless-channel estimate).
          q_input_avg = q_input_avg - B/K + 2*B*x_uq/K/(v-1)

          ''' LATTICE ENCODING '''
          x_lenc = LatE(x_uq, p, num_blocks, w, d_ext)

          ''' ASK MODULATION '''
          x_ask  = ASK(x_lenc, r)
          psi.append(x_ask)
          x_e.append(x_uq)


        ''' MAC transmission '''
        y = sum(psi)+ np.sqrt(npr)*torch.randn(num_blocks).to(device)  #### Superposition of all clients + AWGN
        x_hat = MD(y, del_, r)  #### Min-distance decoding: per-block sum of the K lattice symbols
        var = LatD(x_hat, torch.Tensor([]).to(device), w, d_ext)  #### Lattice decoding into a (p, num_blocks) digit array
        # Transpose + flatten restores LatE's coordinate order. With error-free
        # decoding, entry j is the sum over clients of their index for coordinate j.
        var = var.transpose(0, 1).flatten()
        # Same dequantization as q_input_avg applied to the summed indices; [:d] drops padding.
        output = -B + var[:d]*2*B/K/(v-1)
        MSE_1 = output - mean  #### Error vector of the OTA estimate w.r.t. the true vector
        MSE_2 = q_input_avg - mean  #### Error vector of the noiseless quantized average
        print('Done! Total channel-uses = ', num_blocks)
        print('TM:', mean.norm(), 'QEM:', q_input_avg.norm(), 'UQM:', output.norm())
        print('UQM ERROR:', MSE_1.norm()*np.sqrt(num_blocks))

        # Stored metric: ||error||_2 scaled by sqrt(number of channel uses).
        MSE1_K_sig_I_snr[l] = MSE_1.norm()*np.sqrt(num_blocks)
        MSE2_K_sig_I_snr[l] = MSE_2.norm()*np.sqrt(num_blocks)

      MSE1_K_sig_I[k] = MSE1_K_sig_I_snr
      MSE2_K_sig_I[k] = MSE2_K_sig_I_snr


    # Average over the I iterations -> numpy array with one entry per SNR.
    # NOTE(review): torch.Tensor(...) is built from a nested list of 0-d tensors;
    # confirm this still works when those tensors live on CUDA.
    MSE1_K_sig = np.array(sum(torch.Tensor(MSE1_K_sig_I)).to('cpu')/I)
    MSE2_K_sig = np.array(sum(torch.Tensor(MSE2_K_sig_I)).to('cpu')/I)


    MSE1_K[i] = MSE1_K_sig
    MSE2_K[i] = MSE2_K_sig

  MSE1_B[cnt] = MSE1_K
  MSE2_B[cnt] = MSE2_K
  cnt+=1

# NOTE: disabled save; `MSE_CUQ` is not defined in this cell.
#np.savetxt("RMSE_CUQ_v_"+str(v)+"_"+str(d)+".dat", MSE_CUQ, delimiter =", ", fmt ='% s')

**Plotting**

In [None]:
# Number of points plotted per curve. Slicing clips it to the length of the
# sequence being sliced (Clients, snr_range or b_by_sig).
pts = 14
font = {'weight': 'normal', 'size': 14}
matplotlib.rc('font', **font)

# Figure set 1: scaled error vs. number of clients K, one figure per b.
for b, B1 in enumerate(B_range):
  plt.plot(Clients[:pts], MSE1_B[b][:pts], label='MSE_TM_WZ_OTA')
  plt.plot(Clients[:pts], MSE2_B[b][:pts], label='MSE_TM_WZ_OTA (noiseless)')

  plt.legend()
  plt.xlabel('K')
  plt.ylabel('RMSE x l')
  #plt.xscale('log')
  #plt.yscale('log')
  plt.xticks(ticks=Clients[:pts], labels=Clients[:pts])
  # The ', ' separator before SNR was missing. SNR is now shown in dB (the last
  # simulated value), consistent with figure set 3.
  plt.title('d='+str(d)+', iters='+str(I)+', B/sigma='+str(B1/sig)+', SNR='+str(snr_range[-1])+'dB')
  plt.grid()
  plt.show()

# Figure set 2: scaled error vs. SNR, one figure per (b, K) pair.
for b, B1 in enumerate(B_range):
  for i, K in enumerate(Clients):
    plt.plot(snr_range[:pts], MSE1_B[b][i][:pts], label='MSE_TM_CUQ_K='+str(K))
    plt.plot(snr_range[:pts], MSE2_B[b][i][:pts], label='MSE_TM_CUQ_K='+str(K)+' (noiseless)')
    plt.legend()
    plt.xlabel('SNR')
    plt.ylabel('RMSE')
    plt.xticks(ticks=snr_range[:pts], labels=snr_range[:pts])
    plt.title('d='+str(d)+', iters='+str(I)+', B/sigma='+str(B1/sigma_range[0]))
    plt.grid()
    plt.show()

# x-axis for figure set 3: each b in B_range divided by sqrt(3)*sig.
b_by_sig = [B_val/sig/np.sqrt(3) for B_val in B_range]

# Reorder results from [b][K][SNR] to [SNR][K][b] so each curve runs over b.
MSE1_modified = np.array(MSE1_B).transpose()
MSE2_modified = np.array(MSE2_B).transpose()

# Figure set 3: scaled error vs. b/(sqrt(3)*sig), one figure per (K, SNR) pair.
for i, K in enumerate(Clients):
  for k, snr_db in enumerate(snr_range):
    plt.plot(b_by_sig[:pts], MSE1_modified[k][i][:pts], label='MSE_TM_CUQ')
    plt.plot(b_by_sig[:pts], MSE2_modified[k][i][:pts], '.', label='MSE_TM_CUQ'+' (noiseless)')
    plt.legend()
    plt.xlabel('b_by_sig')
    plt.ylabel('RMSE')
    plt.xscale('log')
    plt.xticks(ticks=b_by_sig[:pts], labels=b_by_sig[:pts])
    plt.title('d='+str(d)+', iter='+str(I)+', K='+str(K)+', SNR = '+str(snr_db)+'dB')
    plt.grid()
    plt.show()
