In [7]:
import math
import random
import numpy as np
import copy
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from scipy.linalg import toeplitz

# Use the GPU when present; fall back to the CPU so the notebook also runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.__version__)


## Model parameters
n = 16  # number of transmit antennas
m = 16  # number of receive antennas
## Real-valued problem sizes (real and imaginary parts stacked, see H_gen)
N = 2*n
M = 2*m

## Training parameters
max_itr = 10  # number of unfolded SB iterations
bs = 2000  # mini-batch size
num_batch = 1000  # number of mini-batches
lr_adam = 5e-3  # learning rate of the Adam optimizer
##

## Parameters for evaluating the generalization error
sample = 1000  # number of random channel draws averaged in eval()
bs_ = 2000  # batch size per channel draw in eval()
##


# SB settings
eps = 1.0  # initial value of the learned time steps (alternative: 0.1)
T_max = max_itr  # number of SB steps (alternative: 2000)
pump_SB = 1.0/(T_max*eps)  # pump coefficient (alternative: 0.01); NOTE(review): DU_dSB computes its own pump via Pump()
D_SB = 1.  # detuning Delta
xi_SB = 0.1  # coupling scale xi_0; NOTE(review): DU_dSB.forward uses the value returned by trans_2_QUBO_LMMSE instead

2.0.1+cu118


In [9]:
def x_gen(bs,n):
    """Draw a (bs, n) batch of i.i.d. equiprobable +/-1 symbols, placed on `device`.

    The uniform numbers come from ``torch.rand`` (default generator) and are then moved to
    `device`. A draw exactly equal to 0.5 used to match neither mask and stay 0.5; every
    entry is now guaranteed to be -1 or +1 (0.5 maps to +1).
    """
    u = torch.rand(bs, n).to(device)
    return torch.where(u < 0.5, -torch.ones_like(u), torch.ones_like(u))
def y_gen(bs,m,x0,H,sigma_std):
    """Noisy observation ``x0 @ H`` plus Gaussian noise of std ``sigma_std`` and shape (bs, m).

    The noise is drawn with ``torch.normal`` and then moved to `device`.
    """
    noise_std = sigma_std * torch.ones(bs, m)
    noise = torch.normal(0.0, noise_std).to(device)
    return x0 @ H + noise

def trans_2_QUBO(H,y):
    """Rewrite ||y - x @ H||^2 over x in {+1,-1}^N as x J x^T + h x^T (+ const).

    Args:
        H: real channel matrix of shape (N, M).
        y: received signals of shape (bs, M).

    Returns:
        J: (N, N) coupling H H^T with its diagonal removed (x_i^2 = 1 makes it a constant).
        h: (bs, N) linear field -2 y H^T.
        xi: heuristic coupling scale 1 / (2 sqrt(N) * rms), where rms is the RMS of the
            off-diagonal entries of J (used as a cheap scale estimate).
    """
    n_vars = H.shape[0]  # N, taken from H instead of a global
    gram = H @ H.t()  # computed once instead of twice
    J = gram - torch.diag(torch.diagonal(gram, 0))
    h = -2*y @ H.t()
    lmax_2 = ((J*J).sum() / (n_vars*(n_vars-1)))**0.5  # estimated max. eig.
    return J, h, 1.0 / (2*n_vars**0.5*lmax_2)

def trans_2_QUBO_LMMSE(H,y,lam):
    """QUBO form built from the regularized projection P = H (H^T H + lam I)^{-1} H^T.

    Args:
        H: real channel matrix of shape (N, M).
        y: received signals of shape (bs, M).
        lam: regularization added to the diagonal of H^T H (scalar or broadcastable tensor).

    Returns:
        J: (N, N) matrix P with its diagonal removed.
        h: (bs, N) field -2 y (H^T H + lam I)^{-1} H^T.
        xi: heuristic coupling scale 1 / (2 sqrt(N) * rms), rms = RMS of the off-diagonal of J.
    """
    n_vars, n_obs = H.shape  # (N, M), taken from H instead of globals
    # Identity on H's own device/dtype (previously the global `device` and the default dtype).
    eye = torch.eye(n_obs, device=H.device, dtype=H.dtype)
    H_inv = torch.linalg.inv(H.t() @ H + lam*eye)  # dim: M*M
    proj = H @ H_inv @ H.t()  # computed once instead of twice
    J = proj - torch.diag(torch.diagonal(proj, 0))
    h = -2*y @ H_inv @ H.t()
    lmax_2 = ((J*J).sum() / (n_vars*(n_vars-1)))**0.5  # estimated max. eig.
    return J, h, 1.0 / (2*n_vars**0.5*lmax_2)

def BER(x,y):
    """Fraction of positions where sign(x) and sign(y) differ, as a 0-d tensor on `device`.

    Signs are compared with ``torch.isclose``, so a zero only matches another zero.
    """
    same_sign = torch.isclose(torch.sign(x), torch.sign(y))
    errors = torch.ones(x.size()).to(device)
    errors[same_sign] = 0.
    total = errors.numel()
    return errors.sum() / total

# Seed every RNG library the notebook imports so runs are reproducible.
# (train() re-seeds torch with 1 after building the network.)
seed_ = 12
random.seed(seed_)
np.random.seed(seed_)
torch.manual_seed(seed_)

# QPSK: complex Gaussian channel written as an equivalent real-valued matrix
def H_gen(m,n):
    """Draw a random channel and return its real-valued block form of shape (2n, 2m).

    The real and imaginary parts are i.i.d. N(0, 1/2) with shape (n, m) each (real part
    drawn first). The returned matrix is [[H_re, -H_im], [H_im, H_re]], moved to `device`.
    """
    std = math.sqrt(0.5) * torch.ones(n, m)
    real_part = torch.normal(0.0, std=std)
    imag_part = torch.normal(0.0, std=std)
    top_rows = torch.cat((real_part, -imag_part), 1)
    bottom_rows = torch.cat((imag_part, real_part), 1)
    return torch.cat((top_rows, bottom_rows), 0).to(device)


def est_SNR(snr,m,n):
    """Noise standard deviation per real dimension for an SNR given in dB.

    sigma^2 = (2n / 10^(snr/10)) / 2. `m` is accepted but not used.
    """
    snr_linear = math.pow(10, snr/10.0)
    noise_var = (2*n/snr_linear)/2.0
    return math.sqrt(noise_var)

### DU: deep-unfolded discrete simulated bifurcation
class DU_dSB(nn.Module):
    """Deep-unfolded discrete simulated bifurcation (dSB) detector.

    Unrolls SB steps on the QUBO built by ``trans_2_QUBO_LMMSE`` and learns:

    * ``eps``  -- per-step time steps, shape (max_itr,), initialized around the global ``eps``;
    * ``lam``  -- regularizer; ``lam**2`` is passed to ``trans_2_QUBO_LMMSE``;
    * ``xi``   -- scalar multiplier on the coupling scale returned by the QUBO map;
    * ``beta`` -- per-step parameter, shape (max_itr,). NOTE(review): nothing in this class
      reads it, so it never receives a gradient; it is kept because ``train`` prints it.
    """

    def __init__(self):
        super(DU_dSB, self).__init__()
        self.eps = nn.Parameter(eps + 0.01 * torch.randn(max_itr, device=device))
        self.lam = nn.Parameter(torch.ones(1, device=device))
        self.xi = nn.Parameter(torch.ones(1, device=device))
        self.beta = nn.Parameter(1.0 * torch.ones(max_itr, device=device))

    def Pump(self, t):
        """Pump amplitude at elapsed time ``t``: linear in ``t``, normalized by the sum of ``eps``."""
        tsum = self.eps.sum()
        return t / tsum

    def Softclamp(self, x):
        """Smooth, differentiable surrogate of ``clamp(x, -1, 1)`` built from two shifted SiLUs."""
        # F.silu is the functional form of nn.SiLU; no need to build a module on every call.
        return F.silu(10 * (x + 1)) / 10 - F.silu(10 * (x - 1)) / 10 - 1

    def dSB(self, k, q, p, t, eps, J, h, D_SB, pump_SB, xi_SB):
        """One SB update of positions ``q`` and momenta ``p`` (both (bs, N)).

        Args:
            k: step index (not used here; kept for interface compatibility).
            q, p: current positions / momenta.
            t: elapsed time, fed to ``Pump``.
            eps: time step of this iteration.
            J, h: QUBO coupling (N, N) and field (bs, N).
            D_SB: detuning.
            pump_SB: not used here (the pump comes from ``self.Pump``); kept for compatibility.
            xi_SB: coupling scale, further multiplied by the learned ``self.xi``.

        Returns:
            (q, p) after the step; ``q`` is soft-clamped to [-1, 1].
        """
        # Gradient term of x J x^T + h x^T, halved (for symmetric J).
        DE_QUBO = 0.5 * h + q @ J
        # Momentum first, then position with the updated momentum.
        p_ = p - eps * ((D_SB - self.Pump(t)) * q + self.xi * xi_SB * DE_QUBO)
        q_ = q + (eps * D_SB) * p_
        q_2 = self.Softclamp(q_)
        # Smoothly zero the momentum where |q_| > 1.01. A hard
        # torch.heaviside(1 - |q_|) would have no usable derivative.
        p_ = p_ - p_ * torch.sigmoid(100 * (torch.abs(q_) - 1.01))
        return q_2, p_

    def forward(self, H, y, itr, bs):
        """Run ``itr`` unfolded dSB steps.

        Returns:
            q, p: final positions / momenta of shape (bs, N); ``q`` is the soft estimate.
            q_traj, p_traj: numpy trajectories of the first batch row with ``max(T_max, itr)``
                rows (rows past ``itr`` stay zero).

        Steps beyond ``max_itr`` reuse the learned time steps cyclically (``k = i % max_itr``).
        """
        J, h, xi_SB = trans_2_QUBO_LMMSE(H, y, self.lam ** 2)
        q = torch.zeros(bs, N, device=device)  # positions start at zero
        p = 0.01 * torch.randn(bs, N, device=device)  # random initial momenta
        # Sized by T_max alone, q_traj[i] raised IndexError whenever itr > T_max even though
        # the step index wraps; keep T_max rows as the minimum so existing shapes are unchanged.
        n_rows = max(T_max, itr)
        q_traj = np.zeros([n_rows, N])
        p_traj = np.zeros([n_rows, N])

        t = 0.0
        for i in range(itr):
            k = i % max_itr
            t = t + self.eps[k]
            q, p = self.dSB(k, q, p, t, self.eps[k], J, h, D_SB, pump_SB, xi_SB)
            q_traj[i] = q[0, :].detach().cpu().numpy()
            p_traj[i] = p[0, :].detach().cpu().numpy()

        return q, p, q_traj, p_traj

def train(snr, detect_anomaly=False):
    """Train a fresh DU_dSB detector at one SNR and return it.

    Args:
        snr: SNR in dB (a float or a 0-d tensor); converted to a noise std by est_SNR.
        detect_anomaly: enable autograd anomaly detection for the training loop only. It is a
            slow debugging aid that used to be switched on unconditionally and left on globally.

    Returns:
        The trained DU_dSB network.
    """
    sigma_std = est_SNR(snr, m, n)

    network = DU_dSB().to(device)
    opt = optim.Adam(network.parameters(), lr=lr_adam)

    # Re-seed after the network is built, so the data stream below does not depend on
    # the RNG state consumed by the parameter initialization.
    torch.manual_seed(1)
    network.train()

    print("-------------------")
    print("snr=", snr)
    with torch.autograd.set_detect_anomaly(detect_anomaly):
        for i in range(num_batch):
            H = H_gen(m, n)
            sol = x_gen(bs, N)
            y = y_gen(bs, M, sol, H, sigma_std)

            opt.zero_grad()
            x_hat, _, _, _ = network(H, y, max_itr, bs)

            loss = F.mse_loss(x_hat, sol)  # squared loss on the soft estimate
            loss.backward()

            if i % 100 == 0:
                print('loss{0}:{1}'.format(i, loss.item()), "BER:", BER(sol, x_hat.sign()))
            opt.step()

    # float() instead of .item() so a plain Python float SNR is accepted as well.
    print("-----\n SNR:", float(snr))
    print("eps=", network.eps)
    print("beta=", network.beta)
    print("xi=", network.xi)
    print("lam=", network.lam**2)
    print("train done")
    return network

# Generalization Error Evaluation
def eval(network, snr, restarts=1):
    """Estimate the generalization BER of a trained detector at one SNR.

    NOTE: this name shadows Python's built-in ``eval`` in the notebook; it is kept so that
    existing calls keep working.

    For each of ``sample`` fresh channel draws (``bs_`` symbol vectors each), the detector is
    run ``restarts`` times. For every row, the hard decision with the smallest residual
    ``||y - sign(x_hat) @ H||`` is kept. Residuals start at 100, so a row whose residual never
    drops below 100 keeps an all-zero estimate and counts as all bit errors. ``restarts=1`` is
    the original single-run behavior.

    Args:
        network: trained DU_dSB detector.
        snr: SNR in dB (float or 0-d tensor).
        restarts: detector runs per channel draw (each run draws new random initial momenta).

    Returns:
        The average BER as a Python float (it is also printed).
    """
    it = max_itr
    sigma_std = est_SNR(snr, m, n)
    ber_ = 0.0

    # Evaluation only: no autograd graph is needed (same values, less memory and time).
    with torch.no_grad():
        for _ in range(sample):
            H = H_gen(m, n)
            sol = x_gen(bs_, N)
            y = y_gen(bs_, M, sol, H, sigma_std)

            xx = torch.zeros(bs_, N, device=device)  # best hard-decision candidate per row
            res = 100 * torch.ones(bs_, N, device=device)  # its residual, repeated along the row
            for _ in range(restarts):
                x_hat, _, _, _ = network(H, y, it, bs_)
                res_ = (y - x_hat.sign() @ H).norm(dim=1).view(bs_, 1).repeat(1, N)
                better = res_ < res
                xx[better] = x_hat[better]
                res[better] = res_[better]
            ber_ += BER(sol, xx.sign())

    ber_ = ber_ / sample
    print("SNR:", float(snr), "BER (generalization):", float(ber_))
    return float(ber_)


# main part: train one detector per SNR point, then report its generalization BER
run_info = ("#_ ", "n=", n, "m=", m, "max_itr=", max_itr, "bs=", bs,
            "num_batch=", num_batch, "learning_rate=", lr_adam)
print(*run_info)

snr_points = torch.arange(15, 16, 2.5)  # with these bounds this is just 15 dB
for snr in snr_points:
    net = train(snr)
    eval(net, snr)


#_  n= 16 m= 16 max_itr= 10 bs= 2000 num_batch= 1000 learning_rate= 0.005
-------------------
snr= tensor(15.)
loss0:0.00400887243449688 BER: tensor(0.0012, device='cuda:0')
loss100:0.0007100640796124935 BER: tensor(0.0002, device='cuda:0')
loss200:0.0015155017608776689 BER: tensor(0.0004, device='cuda:0')
loss300:0.0005173049285076559 BER: tensor(0.0001, device='cuda:0')
loss400:0.0011304032523185015 BER: tensor(0.0003, device='cuda:0')
loss500:9.547865920467302e-05 BER: tensor(3.1250e-05, device='cuda:0')
loss600:0.0005380481597967446 BER: tensor(0.0001, device='cuda:0')
loss700:0.0004004885850008577 BER: tensor(9.3750e-05, device='cuda:0')
loss800:0.00010085271787829697 BER: tensor(3.1250e-05, device='cuda:0')
loss900:0.0007662293501198292 BER: tensor(0.0002, device='cuda:0')
-----
 SNR: 15.0
eps= Parameter containing:
tensor([0.8615, 1.1523, 1.3482, 1.1030, 1.2161, 1.8658, 1.3472, 0.7327, 1.6134,
        2.0544], device='cuda:0', requires_grad=True)
beta= Parameter containing:
tens