## 此jupyter仿真了CAN+模型驱动预编码+基线预编码方案


注意，在论文中我们将噪声功率谱密度固定为-174dBm/Hz，并调整发射功率来绘制仿真曲线。而在此代码中，我们统一将发射功率固定归一化为36dBm，并将噪声功率设置为可调参数，通过更改噪声功率来绘制曲线。事实上，这两者是等价的，只需要保证噪声与发射功率的比值不变就可以得到论文中的仿真结果。
此外，我们并未在信道生成模块里引入发射与接收天线增益，这一部分直接被考虑在发射功率与噪声功率的计算中。由于大尺度衰落较大导致信道幅值过小而不便于神经网络处理，我们对BS-RIS信道的幅度乘上1e6，对RIS-UE信道幅度乘上1e5，相当于能量提升了220dB，相应的我们也需要将噪声功率扩大220dB。这些问题均已在计算发射功率与噪声功率时被考虑。
总的来说，仿真代码与论文中的发射功率与噪声功率的比值是相同的，并且不同方案在仿真中使用的发射功率与噪声功率设置是相同的，因此能够保证不同方案的相对性能与论文一致。


### import

In [2]:
import torch
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 使用第一块GPU（从0开始）
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import random
import torch.nn.functional as F
import torchvision
import numpy as np
from math import *
import matplotlib.pyplot as plt
from torch.autograd import Variable
from IPython import display
import torch.utils.data as Data
import torch.nn as nn
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib import cm
from scipy.linalg import block_diag
import datetime
from torch.nn.utils import *
import cvxpy as cp
from Transformer_model import *
from RIS import *

### 频谱效率计算函数定义

In [3]:
### 定义RSMA的频谱效率损失函数
def SE_RS_BB(batch,Nc,K,Nt,N_RS1,N_RS2,RS_cor,H,F_SDMA,F_RS,sigma2):
    """Compute the max-min RSMA rate for a batch of effective channels.

    Common part (R_SA): for every common stream, the per-subcarrier SINR is
    limited by the weakest user it serves; log2(1+SINR) is averaged over batch
    and subcarriers, summed over common streams and divided by K (the common
    rate is shared by the K users).
    Private part (R_SDMA): each user's private rate is averaged over
    subcarriers, the smallest user rate is taken per sample, then averaged over
    the batch.  Each user can cancel its own common stream (perfect SIC) but not
    the other common streams.

    Args:
        batch: number of samples.
        Nc: number of subcarriers.
        K: number of users / private streams.
        Nt: number of transmit ports (unused, kept for interface compatibility).
        N_RS1: number of common streams.
        N_RS2: users served per common stream (K for broadcast, 1 for a NOMA-like
            split); groups are assumed disjoint with N_RS1 * N_RS2 == K.
        RS_cor: N_RS1 lists, each holding the N_RS2 user indices of that stream.
        H: effective channel [batch, Nc, K, Nt].
        F_SDMA: private precoders, H @ F_SDMA -> [batch, Nc, K, K].
        F_RS: common precoders, H @ F_RS -> [batch, Nc, K, N_RS1].
        sigma2: noise power.

    Returns:
        (R, R_SA, R_SDMA) with R = R_SA + R_SDMA.
    """
    device = H.device
    H_RS = H @ F_RS      # [batch,Nc,K,N_RS1]
    H_SDMA = H @ F_SDMA  # [batch,Nc,K,K]

    # Received power |h|^2 of every stream at every user (|x*x| == |x|^2).
    P_RS = torch.abs(H_RS * H_RS)        # [batch,Nc,K,N_RS1]
    P_SDMA = torch.abs(H_SDMA * H_SDMA)  # [batch,Nc,K,K]

    # ---- common streams: rate limited by the weakest served user ----
    R_SA = 0
    for a in range(N_RS1):
        SINRck = torch.zeros(batch, Nc, N_RS2, device=device)
        for j, b in enumerate(RS_cor[a]):
            interf = torch.zeros(batch, Nc, device=device)
            for c in range(N_RS1):  # the other common streams
                if c != a:
                    interf = interf + P_RS[:, :, b, c]
            for c in range(K):      # all private streams
                interf = interf + P_SDMA[:, :, b, c]
            SINRck[:, :, j] = P_RS[:, :, b, a] / (interf + sigma2)
        SINRc = torch.min(SINRck, 2).values  # [batch,Nc]
        R_SA = R_SA + torch.mean(torch.log2(1 + SINRc))
    R_SA = R_SA / K

    # ---- private streams: own common stream removed by SIC ----
    R_k = torch.zeros(batch, K, device=device)
    for a in range(N_RS1):
        for b in RS_cor[a]:
            interf = torch.zeros(batch, Nc, device=device)
            for c in range(N_RS1):  # common streams the user cannot cancel
                if c != a:
                    interf = interf + P_RS[:, :, b, c]
            for c in range(K):      # the other users' private streams
                if c != b:
                    interf = interf + P_SDMA[:, :, b, c]
            SINR = P_SDMA[:, :, b, b] / (interf + sigma2)  # [batch,Nc]
            R_k[:, b] = torch.mean(torch.log2(1 + SINR), 1)  # [batch]
    # Worst-user rate is taken once, after every user's rate has been filled in.
    R_min = torch.min(R_k, 1).values  # [batch]
    R_SDMA = torch.mean(R_min)

    R = R_SDMA + R_SA
    return R,R_SA,R_SDMA
class SE_RS(torch.nn.Module):   # input: channels and all precoders; output: spectral efficiency
    """Spectral-efficiency "loss" for a single broadcast common stream.

    Builds the K x K effective channel H_RU_i * Phi * H_BR * F_RF on every
    subcarrier and scores it with SE_RS_BB.
    """
    def __init__(self):
        super(SE_RS, self).__init__()

    def forward(self, param_list, H_BR,H_RU,Phi,F_RF,F_BB_SDMA,F_BB_RS):
        """Return (R, R_SA, R_SDMA) as computed by SE_RS_BB.

        Shapes as used by the callers in this notebook:
        H_RU [batch,K,Nc,1,M_ant], Phi [batch,1,M_ant,M_ant], F_RF [1,1,M_ant,K],
        F_BB_SDMA [batch,Nc,K,K], F_BB_RS [batch,Nc,K,1].
        """
        Nc = param_list[2]
        sigma = param_list[12]
        # K is read from param_list (same index as the precoder modules) instead
        # of an undeclared global, and the user grouping is derived from it
        # rather than hard-coded for K == 4.
        K = param_list[14]

        # One broadcast common stream serving all K users.
        N_RS1 = 1
        N_RS2 = K
        RS_cor = [list(range(K))]

        batch = H_RU.shape[0]

        H_equ = torch.zeros(batch, Nc, K, K, dtype=torch.complex64, device=H_RU.device)
        for i in range(K):
            H_equ[:,:,i] = (H_RU[:,i] @ Phi @ H_BR @ F_RF).reshape(-1,Nc,K)

        R,R_SA,R_SDMA = SE_RS_BB(batch,Nc,K,K,N_RS1,N_RS2,RS_cor,H_equ,F_BB_SDMA,F_BB_RS,sigma)
        return R,R_SA,R_SDMA

### 提出的信道获取和预编码方案的网络模块定义

In [6]:
### 定义CAN
class CAN_EN(nn.Module): 
    """CAN encoder: learned pilots/RIS phases -> quantized feedback bits.

    Over L time slots the BS sends a learned hybrid (analog x digital) pilot,
    the RIS applies a learned phase pattern, and the user observes one noisy
    scalar per subcarrier.  The [batch, Nc, L] observations are compressed by
    a Transformer + linear layer into B // B_quan values, then quantized.
    """
    def __init__(self, param_list): 
        super(CAN_EN,self).__init__()
        # param_list layout: [fc,B,Nc,M,N,D_sub,D_ant,R,M_ant,h1,h2,L,sigma,B_feedback,...]
        Nc = param_list[2]
        M_ant = param_list[8]
        L = param_list[11]
        B = param_list[13]   # feedback bits (index 13 overrides the bandwidth at index 1)

        # Learned pilots for L slots; the digital part is per subcarrier (frequency selective).
        self.S_digital = torch.nn.Parameter(torch.randn(L,Nc,4,1)+1j*torch.randn(L,Nc,4,1))
        self.S_analog = torch.nn.Parameter(torch.randn(L,1,M_ant,4)) 
        self.Phase = torch.nn.Parameter(torch.randn(L,M_ant))   # RIS phases for the L slots

        self.B_quan = 2
        self.trans  = TRANS_BLOCK(2*L,2*L,256,6)
        self.linear = nn.Linear(2*L*Nc, B//self.B_quan) 
        self.bn = nn.BatchNorm1d(B//self.B_quan)

        self.QL = QuantizationLayer(self.B_quan)

    def forward(self, param_list, H_BR, H_RU):
        """Encode RIS-UE channels into quantized feedback.

        H_RU: [batch, Nc, M_ant] (as implied by the reshape below).
        H_BR: BS-RIS channel; presumably [Nc, M_ant, M_ant] (the precoders index
            it that way) -- a [M_ant, M_ant] matrix would broadcast as well.
        """
        Nc = param_list[2]
        M_ant = param_list[8]
        L = param_list[11]
        sigma = param_list[12]

        batch = H_RU.shape[0]
        device = H_RU.device

        # Every column is overwritten in the loop below, so start from zeros.
        Y = torch.zeros(batch, Nc, L, dtype=torch.complex64, device=device)  # [batch,Nc,L]
        root_Nc = torch.tensor(sqrt(Nc), device=self.Phase.device)
        for l in range(L):
            s = (torch.exp(1j*self.S_analog[l]) @ self.S_digital[l]).reshape(Nc,M_ant,1)  # pilot [Nc,M_ant,1]

            # Clip the total pilot energy (over all subcarriers) to at most Nc.
            s_norm = torch.sqrt(torch.sum(torch.abs(s*s)))
            s = s/s_norm.reshape(1,1,1)*torch.min(s_norm, root_Nc).reshape(1,1,1)

            # Noise is drawn on the CPU generator and moved, as before, so the
            # random stream is unchanged.
            n = (torch.randn(batch,Nc) + 1j*torch.randn(batch,Nc)).to(device)*sqrt(sigma/2)

            Theta = torch.diag(torch.exp(1j*self.Phase[l]))  # RIS reflection matrix [M_ant,M_ant]
            a = (Theta @ H_BR @ s).permute(0,2,1)            # [Nc,1,M_ant]
            Y[:,:,l] = (a @ H_RU.reshape(-1,Nc,M_ant,1)).reshape(-1,Nc) + n  # [batch,Nc]

        x = torch.cat((torch.real(Y),torch.imag(Y)), 2)  # [batch,Nc,2*L]
        x = self.trans(x)
        x = x.reshape(-1,2*Nc*L)
        x = self.linear(x)
        x = self.bn(x)
        x = torch.sigmoid(x)
        x = self.QL(x)
        return x

class CAN_DE(nn.Module): 
    """CAN decoder: quantized feedback bits -> reconstructed RIS-UE channel.

    Dequantizes the feedback, expands it with a fully connected layer to
    [batch, Nc, 2*L] features, and maps them with a Transformer to the real
    and imaginary parts of the M_ant-element channel on every subcarrier.
    """
    def __init__(self, param_list): 
        super(CAN_DE,self).__init__()
        n_sub = param_list[2]    # Nc
        n_ris = param_list[8]    # M_ant
        n_slot = param_list[11]  # L
        n_bits = param_list[13]  # feedback bits
        feat = 2 * n_sub * n_slot

        self.B_quan = 2
        self.DQL = DequantizationLayer(self.B_quan)
        self.FC1 = nn.Linear(n_bits // self.B_quan, feat)
        self.bn1 = nn.BatchNorm1d(feat)
        self.trans = TRANS_BLOCK(2 * n_slot, 2 * n_ris, 256, 4)

    def forward(self, param_list, x):
        """Return the complex channel estimate of shape [batch, Nc, M_ant]."""
        n_sub = param_list[2]
        n_ris = param_list[8]
        n_slot = param_list[11]

        centred = self.DQL(x) - 0.5
        hidden = torch.relu(self.bn1(self.FC1(centred)))
        out = self.trans(hidden.reshape(-1, n_sub, 2 * n_slot))

        re_part = out[:, :, 0:n_ris]
        im_part = out[:, :, n_ris:2 * n_ris]
        return re_part + 1j * im_part

class CAN(nn.Module): 
    """Channel acquisition network: CAN_EN (pilots -> feedback) then CAN_DE (feedback -> H_RU)."""

    def __init__(self, param_list): 
        super(CAN,self).__init__()
        self.encoder = CAN_EN(param_list)
        self.decoder = CAN_DE(param_list)

    def forward(self, param_list,H_BR, H_RU1):
        """Encode H_RU1 under the fixed BS-RIS channel H_BR and return the decoded estimate."""
        feedback = self.encoder(param_list, H_BR, H_RU1)
        return self.decoder(param_list, feedback)

In [7]:
############################# 定义RSMA预编码网络，RIS数据驱动 RSMA数字预编码模型驱动
class RIS_RSMA_Precoding(nn.Module): 
    """RSMA precoder: data-driven RIS phases plus model-driven digital precoding.

    The RIS phase shifts come from a Transformer branch fed with the
    (estimated) RIS-UE channel.  The analog precoder F_RF is a fixed
    phase-only conjugate of selected rows of H_BR.  A second Transformer
    outputs, per subcarrier, the weights that parameterize regularized-inverse
    private and common digital precoders, which are finally power-normalized.
    """
    def __init__(self, param_list): 
        super(RIS_RSMA_Precoding,self).__init__()
        
#         param_list = [fc,B,Nc,M,N,D_sub,D_ant,R,M_ant,h1,h2,L]
        # Only Nc, M_ant and K are used below; the other unpacked values are unused.

        fc = param_list[0]
        BW  = param_list[1]
        Nc = param_list[2]
        M  = param_list[3]
        N  = param_list[4]
        D_sub = param_list[5]
        D_ant = param_list[6]
        R = param_list[7]
        M_ant = param_list[8]
        h1 = param_list[9]
        h2 = param_list[10]
        L =  param_list[11]
        sigma = param_list[12]
        B = param_list[13]
        K = param_list[14]
        

        self.linear_H  = nn.Linear(M_ant*2, 32)   # embeds [Re, Im] of each user's channel per subcarrier
        
        self.trans1  = TRANS_BLOCK(32*K,32,256,3)
        self.linear_RIS = nn.Linear(Nc*32, M_ant)   # produces the M_ant RIS phases

        
        self.trans2  = TRANS_BLOCK(2*K*K,8*K+2,256,1) # Transformer producing the per-subcarrier precoder parameters
        

        
        
        
    def forward(self, param_list, H_BR, H_RU):
        # H_RU: [batch,K,Nc,1,M_ant] (the callers reshape to this).
        # H_BR: the original note says [1,Nc,M_ant,M_ant], but the H_BR[Nc//2, row, :]
        #   indexing below needs the subcarrier axis first -- presumably [Nc,M_ant,M_ant]; confirm.
        # Returns Phi [batch,1,M_ant,M_ant], F_RF [1,1,M_ant,K],
        #   F_BB_SDMA [batch,Nc,K,K], F_BB_RS [batch,Nc,K,1].
        # param_list = [fc,B,Nc,M,N,D_sub,D_ant,R,M_ant,h1,h2,L,SNR]
        fc = param_list[0]
        B  = param_list[1]
        Nc = param_list[2]
        M  = param_list[3]
        N  = param_list[4]
        D_sub = param_list[5]
        D_ant = param_list[6]
        R = param_list[7]
        M_ant = param_list[8]
        h1 = param_list[9]
        h2 = param_list[10]
        L =  param_list[11]
        sigma = param_list[12]
        B = param_list[13]
        K = param_list[14]
        
        batch = H_RU.shape[0]
        
        ## RIS phases: Transformer + linear head, then mapped to a diagonal reflection matrix
        x = torch.cat((torch.real(H_RU),torch.imag(H_RU)), 4) #[batch,K,Nc,1,2*M_ant]
        x = x.reshape(-1,K,Nc,2*M_ant)
        x = self.linear_H(x)   #[batch,K,Nc,32]
        x = x.permute(0,2,1,3) #[batch,Nc,K,32]
        x = x.reshape(-1,Nc,K*32)#[batch,Nc,K*32]
        
        x = self.trans1(x)   #[batch,Nc,32]
        out_RIS = x.reshape(-1,Nc*32)
        Phi_phase = self.linear_RIS(out_RIS)
        Phi = torch.exp(1j*Phi_phase.reshape(-1,1,M_ant))
        Phi = torch.diag_embed(Phi)  #[batch,1,M_ant,M_ant]



        ## Analog precoder: phase-only conjugate (matched filter) of row i*N*N+6 of the
        ## middle-subcarrier H_BR, one column per user.
        ## NOTE(review): the meaning of the row offset 6 is not documented here -- confirm.
        F_RF = (torch.zeros(M_ant,K) + 0j).cuda()
        for i in range(K):
            F_RF[:,i] = H_BR[Nc//2,(i*N*N+6),:]
        F_RF = F_RF.reshape(1,1,M_ant,K)
        F_RF = torch.real(F_RF) - 1j*torch.imag(F_RF)   # complex conjugate
        F_RF = (F_RF/torch.abs(F_RF))/sqrt(M_ant)       # unit-modulus entries scaled by 1/sqrt(M_ant)
        
        

        ## Effective channel H_equ[:,:,i,:] = H_RU_i * Phi * H_BR * F_RF  -> [batch,Nc,K,K]
        H_equ = (torch.zeros(batch,Nc,K,K) + 0j).cuda()
        H_RF = H_BR @ F_RF
        H_RF = Phi @ H_RF
        for i in range(K):
            H_equ[:,:,i] = (H_RU[:,i] @ H_RF).reshape(-1,Nc,K)
            
        ## Normalize the network input so its total power per subcarrier is K. This only
        ## rescales the input used to build the precoders; the SE is evaluated elsewhere
        ## and the final precoders are power-normalized below.
        power = torch.sum(torch.abs(H_equ*H_equ),[2,3])
        H_equ = H_equ/torch.sqrt(power.reshape(-1,Nc,1,1))*sqrt(K)
        

        ## Transformer outputs the per-subcarrier precoder parameters
        x = H_equ.reshape(-1,Nc,K*K)
        x = torch.cat((torch.real(x),torch.imag(x)), 2) #[batch,Nc,2*K*K]
        x = self.trans2(x)   #[batch,Nc,8*K+2]

        H_hat = H_equ.to(torch.complex128)
        pri1 = (x[:,:,0:K] + 1j*x[:,:,K:2*K]).to(torch.complex128)    #[batch,Nc,K] private numerator weights
        pri2 = (x[:,:,2*K:3*K] + 1j*x[:,:,3*K:4*K]).to(torch.complex128)   #[batch,Nc,K] private covariance weights
        pub1 = (x[:,:,4*K:5*K] + 1j*x[:,:,5*K:6*K]).to(torch.complex128) #[batch,Nc,K] common numerator weights
        pub2 = (x[:,:,6*K:7*K] + 1j*x[:,:,7*K:8*K]).to(torch.complex128) #[batch,Nc,K] common covariance weights

        sigma_pri = (x[:,:,8*K]).to(torch.complex128)  #[batch,Nc] private regularizer
        sigma_pub = (x[:,:,8*K+1]).to(torch.complex128)#[batch,Nc] common regularizer

        ## Digital precoders following eqs. (26)-(27) of the paper:
        ##   B = sum_k w2_k h_k h_k^H + sigma I,   v = B^{-1} (weighted h_k)
        ## Hermitian() is a project helper (presumably conjugate transpose of the last two axes).
        ## NOTE(review): sigma*I is added once per user, i.e. K*sigma*I in total. The learned
        ## sigma absorbs this scale, so changing it requires retraining.
        for k in range(K):
            hk = Hermitian(H_hat[:,:,k,:].reshape(batch,Nc,1,K))
            if(k==0):
                B_pri = pri2[:,:,k].reshape(batch,Nc,1,1) * hk @ Hermitian(hk)
                B_pri = B_pri + (sigma_pri.reshape(batch,Nc,1,1) * torch.eye(K).cuda().to(torch.complex128))

                B_pub = pub2[:,:,k].reshape(batch,Nc,1,1) * hk @ Hermitian(hk)
                B_pub = B_pub + (sigma_pub.reshape(batch,Nc,1,1) * torch.eye(K).cuda().to(torch.complex128))
            else:
                B_pri = B_pri + pri2[:,:,k].reshape(batch,Nc,1,1) * hk @ Hermitian(hk)
                B_pri = B_pri + (sigma_pri.reshape(batch,Nc,1,1) * torch.eye(K).cuda().to(torch.complex128))

                B_pub = B_pub + pub2[:,:,k].reshape(batch,Nc,1,1) * hk @ Hermitian(hk)
                B_pub = B_pub + (sigma_pub.reshape(batch,Nc,1,1) * torch.eye(K).cuda().to(torch.complex128))

        V_pri = (torch.zeros((batch,Nc,K,K)) + 0j).cuda().to(torch.complex128)
        V_pub = (torch.zeros((batch,Nc,K,1)) + 0j).cuda().to(torch.complex128)   # overwritten below
        B_inv_pub = torch.inverse(B_pub)
        B_inv_pri = torch.inverse(B_pri)
        # print(np.sum(np.abs(A_inv)**2))
        # Private precoder of user k: column k = B_pri^{-1} * (pri1_k * h_k)
        for k in range(K):
            hk = Hermitian(H_hat[:,:,k,:].reshape(batch,Nc,1,K))
            V_pri[:,:,:,k] = (B_inv_pri @ (pri1[:,:,k].reshape(batch,Nc,1,1) * hk)).reshape(batch,Nc,K)
        # Common precoder: B_pub^{-1} * sum_k pub1_k * h_k
        for k in range(K):
            hk = Hermitian(H_hat[:,:,k,:].reshape(batch,Nc,1,K))
            if k==0:
                A = pub1[:,:,k].reshape(batch,Nc,1,1) * hk
            else:
                A = A + pub1[:,:,k].reshape(batch,Nc,1,1) * hk
        V_pub = (B_inv_pub @ A)

        w_pub = V_pub
        W_pri = V_pri


        F_BB_RS = (w_pub + 0).to(torch.complex64)

        F_BB_SDMA = (W_pri + 0).to(torch.complex64)


        ## Normalize the total transmit power of F_RF*F_BB per subcarrier to K
        ## (the original note: 4, i.e. 36 dBm, for K = 4).
        F_SDMA = F_RF @ F_BB_SDMA  #[batch,Nc,M_ant,K]
        F_RS = F_RF @ F_BB_RS  #[batch,Nc,M_ant,N_RS1]
        Power = (torch.sum(torch.abs(F_SDMA)**2,[2,3]) + torch.sum(torch.abs(F_RS)**2,[2,3])).reshape(-1,Nc,1,1)
        F_BB_SDMA = F_BB_SDMA / torch.sqrt(Power) * sqrt(K)
        F_BB_RS = F_BB_RS / torch.sqrt(Power) * sqrt(K)


        return Phi,F_RF,F_BB_SDMA,F_BB_RS  #Phi[batch,1,M_ant,M_ant]  F_RF[1,1,M_ant,K]  F_BB_SDMA[batch,Nc,K,K] F_BB_RS[batch,Nc,K,1] 
#         print(H_RU.shape) 

In [8]:
############################# 定义SDMA预编码网络，RIS数据驱动 SDMA数字预编码模型驱动,这种情况下能量全部被分配给私有数据流
class RIS_SDMA_Precoding(nn.Module): # single layer
    """SDMA precoder: data-driven RIS phases plus model-driven private precoding.

    Same structure as the RSMA precoder, but all power goes to the private
    streams: the common-stream precoder F_BB_RS is returned as zeros.
    """
    def __init__(self, param_list): 
        super(RIS_SDMA_Precoding,self).__init__()
        # param_list layout: [fc,BW,Nc,M,N,D_sub,D_ant,R,M_ant,h1,h2,L,sigma,B,K,...]
        Nc = param_list[2]
        M_ant = param_list[8]
        K = param_list[14]

        self.linear_H  = nn.Linear(M_ant*2, 32)        # embeds [Re, Im] of each user's channel
        self.trans1  = TRANS_BLOCK(32*K,32,256,3)
        self.linear_RIS = nn.Linear(Nc*32, M_ant)      # produces the M_ant RIS phases
        self.trans2  = TRANS_BLOCK(2*K*K,4*K+1,256,1)  # per-subcarrier precoder parameters from the KxK effective channel

    def forward(self, param_list, H_BR, H_RU):
        """Return (Phi, F_RF, F_BB_SDMA, F_BB_RS).

        H_RU: [batch,K,Nc,1,M_ant] (the callers reshape to this).
        H_BR: indexed as H_BR[Nc//2, row, :], i.e. presumably [Nc,M_ant,M_ant].
        Outputs: Phi [batch,1,M_ant,M_ant], F_RF [1,1,M_ant,K],
        F_BB_SDMA [batch,Nc,K,K], F_BB_RS [batch,Nc,K,1] (all zeros).
        """
        Nc = param_list[2]
        N  = param_list[4]
        M_ant = param_list[8]
        K = param_list[14]

        batch = H_RU.shape[0]
        device = H_RU.device

        # ---- RIS phases (data driven) ----
        x = torch.cat((torch.real(H_RU),torch.imag(H_RU)), 4) #[batch,K,Nc,1,2*M_ant]
        x = x.reshape(-1,K,Nc,2*M_ant)
        x = self.linear_H(x)                       #[batch,K,Nc,32]
        x = x.permute(0,2,1,3).reshape(-1,Nc,K*32) #[batch,Nc,K*32]
        x = self.trans1(x)                         #[batch,Nc,32]
        Phi_phase = self.linear_RIS(x.reshape(-1,Nc*32))
        Phi = torch.diag_embed(torch.exp(1j*Phi_phase.reshape(-1,1,M_ant)))  #[batch,1,M_ant,M_ant]

        # ---- analog precoder: phase-only conjugate of row i*N*N+6 of the middle-subcarrier H_BR ----
        # NOTE(review): the meaning of the row offset 6 is not documented here -- confirm.
        F_RF = torch.zeros(M_ant, K, dtype=torch.complex64, device=device)
        for i in range(K):
            F_RF[:,i] = H_BR[Nc//2,(i*N*N+6),:]
        F_RF = F_RF.reshape(1,1,M_ant,K)
        F_RF = torch.real(F_RF) - 1j*torch.imag(F_RF)
        F_RF = (F_RF/torch.abs(F_RF))/sqrt(M_ant)

        # ---- effective channel H_equ[:,:,i,:] = H_RU_i * Phi * H_BR * F_RF ----
        H_RF = Phi @ (H_BR @ F_RF)
        H_equ = torch.zeros(batch, Nc, K, K, dtype=torch.complex64, device=device)
        for i in range(K):
            H_equ[:,:,i] = (H_RU[:,i] @ H_RF).reshape(-1,Nc,K)

        # Normalize the network input to total power K per subcarrier (input scaling only).
        power = torch.sum(torch.abs(H_equ*H_equ),[2,3])
        H_equ = H_equ/torch.sqrt(power.reshape(-1,Nc,1,1))*sqrt(K)

        x = H_equ.reshape(-1,Nc,K*K)
        x = torch.cat((torch.real(x),torch.imag(x)), 2) #[batch,Nc,2*K*K]
        x = self.trans2(x)                              #[batch,Nc,4*K+1]

        H_hat = H_equ.to(torch.complex128)
        pri1 = (x[:,:,0:K] + 1j*x[:,:,K:2*K]).to(torch.complex128)       #[batch,Nc,K]
        pri2 = (x[:,:,2*K:3*K] + 1j*x[:,:,3*K:4*K]).to(torch.complex128) #[batch,Nc,K]
        sigma_pri = (x[:,:,4*K]).to(torch.complex128)                    #[batch,Nc]

        # ---- private precoders: B = sum_k pri2_k h_k h_k^H + sigma I; v_k = B^{-1} pri1_k h_k ----
        # NOTE(review): sigma*I is added once per user (K*sigma*I in total); the learned
        # sigma absorbs this scale, so it is kept for checkpoint compatibility.
        eye = torch.eye(K, dtype=torch.complex128, device=device)
        B_pri = None
        for k in range(K):
            hk = Hermitian(H_hat[:,:,k,:].reshape(batch,Nc,1,K))
            term = pri2[:,:,k].reshape(batch,Nc,1,1) * hk @ Hermitian(hk)
            B_pri = term if B_pri is None else B_pri + term
            B_pri = B_pri + (sigma_pri.reshape(batch,Nc,1,1) * eye)

        B_inv_pri = torch.inverse(B_pri)
        V_pri = torch.zeros(batch, Nc, K, K, dtype=torch.complex128, device=device)
        for k in range(K):
            hk = Hermitian(H_hat[:,:,k,:].reshape(batch,Nc,1,K))
            V_pri[:,:,:,k] = (B_inv_pri @ (pri1[:,:,k].reshape(batch,Nc,1,1) * hk)).reshape(batch,Nc,K)

        F_BB_SDMA = V_pri.to(torch.complex64)
        # SDMA: no common stream, so its precoder is identically zero.
        F_BB_RS = torch.zeros(batch, Nc, K, 1, dtype=torch.complex64, device=device)

        # ---- normalize total transmit power per subcarrier to K ----
        F_SDMA = F_RF @ F_BB_SDMA  #[batch,Nc,M_ant,K]
        F_RS = F_RF @ F_BB_RS      #[batch,Nc,M_ant,1]
        Power = (torch.sum(torch.abs(F_SDMA)**2,[2,3]) + torch.sum(torch.abs(F_RS)**2,[2,3])).reshape(-1,Nc,1,1)
        F_BB_SDMA = F_BB_SDMA / torch.sqrt(Power) * sqrt(K)
        F_BB_RS = F_BB_RS / torch.sqrt(Power) * sqrt(K)

        return Phi,F_RF,F_BB_SDMA,F_BB_RS
#         print(H_RU.shape) 

### 网络训练函数定义

In [13]:
def train_CAN(param_list,batch,STEPS,model_name_AE): # train the CAN
    """Train the channel acquisition network (CAN) and save the best model.

    The BS-RIS channel is generated once and kept fixed; a fresh batch of
    RIS-UE channels is drawn every step.  Every 50 steps the mean training NMSE
    and the NMSE on a fresh test batch are printed, and the whole module is
    saved to ``model_name_AE`` whenever the test NMSE improves.

    Args:
        param_list: simulation parameter list shared by the notebook.
        batch: number of channel samples per step.
        STEPS: number of optimizer steps.
        model_name_AE: output path for ``torch.save(net, ...)``.
    """
    net = CAN(param_list).cuda()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),lr=0.001)

    # Fixed BS-RIS channel, amplitude scaled by 1e6 (see the notes at the top).
    H_BR = Channel_BS_RIS(param_list) * 1e6
    alpha = 1e5  # RIS-UE amplitude scaling

    num_train = 0
    train_nmse = 0
    # Start at +inf so the first evaluation always writes a checkpoint that
    # later cells can torch.load, even if the NMSE never drops below 1.
    best_nmse = float('inf')
    start = datetime.datetime.now()
    numStep = 1
    warmup_steps = 4000
    d_model = 256
    for step in range(STEPS):
        # Noam (Transformer) warm-up learning-rate schedule.
        optimizer.param_groups[0]['lr'] = (d_model**(-0.5))*min(numStep**(-0.5),numStep*warmup_steps**(-1.5))
        numStep = numStep+1
        net.train()

        H_RU = Channel_RIS_UE(param_list,batch,alpha)
        H_RU0 = net(param_list, H_BR, H_RU)
        loss = NMSE(H_RU0,H_RU)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        num_train = num_train + 1
        # Accumulate a detached value so no autograd history is kept across steps.
        train_nmse = train_nmse + loss.detach()

        if (step+1)%50==0:
            net.eval()
            with torch.no_grad():
                H_RU = Channel_RIS_UE(param_list,batch,alpha)
                H_RU0 = net(param_list, H_BR, H_RU)
                test_nmse = float(NMSE(H_RU0,H_RU))

            train_nmse = float(train_nmse)/num_train
            time0 =  datetime.datetime.now()-start
            print('step:',step,'time',time0,'train NMSE %.5f' % train_nmse,'test NMSE %.5f' % test_nmse) 
            if test_nmse < best_nmse:
                best_nmse = test_nmse
                torch.save(net, model_name_AE)
                print('Model saved!')
            num_train = 0
            train_nmse = 0
            start = datetime.datetime.now()

In [14]:
def train_Precoding_SDMA(param_list,batch,batch_num,STEPS,model_name_AE,model_name_RS,LR0): # train RIS passive + model-driven SDMA active precoding
    """Train the SDMA precoder on channels estimated by a pretrained CAN.

    The CAN is loaded from ``model_name_AE`` and kept frozen; its estimates feed
    the precoder, while the sum rate is evaluated on the true channels.
    Gradients are accumulated over ``batch_num`` mini-batches per optimizer
    step (a larger effective batch without more GPU memory).  Every 50 steps the
    precoder is evaluated on 10 fresh batches and saved to ``model_name_RS``
    whenever the test SE improves.

    Args:
        param_list: simulation parameter list shared by the notebook.
        batch: number of multi-user samples per mini-batch (each holds K users).
        batch_num: mini-batches accumulated per optimizer step.
        STEPS: number of optimizer steps.
        model_name_AE: path of the pickled CAN module (trusted local file).
        model_name_RS: output path for the trained precoder.
        LR0: multiplier on the Noam learning-rate schedule.
    """
    Nc = param_list[2]
    M_ant = param_list[8]
    K = param_list[14]

    loss_func = SE_RS().cuda()

    # torch.load unpickles the whole module: only use with trusted files.
    net_AE = torch.load(model_name_AE)
    net_RS = RIS_SDMA_Precoding(param_list).cuda()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net_RS.parameters()),lr=0.001)

    # Fixed BS-RIS channel, amplitude scaled by 1e6.
    H_BR = Channel_BS_RIS(param_list) * 1e6
    batch_sum = batch*K
    # NOTE(review): the RSMA trainer and Convention_RSMA scale the RIS-UE channel
    # by 1e5/sqrt(Lp); this SDMA trainer uses 1e5. Confirm which is intended
    # for a fair comparison (identical only when Lp == 1).
    alpha = 1e5

    num_train = 0
    train_nmse = 0
    train_SE = 0
    best_SE = 0
    start = datetime.datetime.now()
    numStep = 1
    warmup_steps = 4000
    d_model = 256

    for step in range(STEPS):
        optimizer.param_groups[0]['lr'] = (d_model**(-0.5))*min(numStep**(-0.5),numStep*warmup_steps**(-1.5))*LR0
        numStep = numStep+1
        net_AE.eval()
        net_RS.train()

        optimizer.zero_grad()
        for i in range(batch_num):  # gradient accumulation
            H_RU = Channel_RIS_UE(param_list,batch_sum,alpha)
            with torch.no_grad():
                H_RU0 = net_AE(param_list, H_BR, H_RU)
                nmse = NMSE(H_RU0,H_RU)
            H_RU = H_RU.reshape(batch,K,Nc,1,M_ant)
            H_RU0 = H_RU0.reshape(batch,K,Nc,1,M_ant)

            Phi,F_RF,F_BB_SDMA,F_BB_RS = net_RS(param_list, H_BR, H_RU0)
            R_sum,R_SA,R_SDMA = loss_func(param_list, H_BR,H_RU,Phi,F_RF,F_BB_SDMA,F_BB_RS)
            loss = -R_sum/batch_num
            loss.backward()

            num_train = num_train + 1
            # Out-of-place detach: the tensors that fed loss are not mutated.
            train_nmse = train_nmse + nmse.detach()
            train_SE = train_SE + R_sum.detach()

        optimizer.step()
        optimizer.zero_grad()

        if (step+1)%50==0:
            net_AE.eval()
            net_RS.eval()
            with torch.no_grad():
                test_SE = 0
                test_SDMA = 0
                for i in range(10):
                    H_RU = Channel_RIS_UE(param_list,batch_sum,alpha)
                    H_RU0 = net_AE(param_list, H_BR, H_RU)
                    H_RU = H_RU.reshape(batch,K,Nc,1,M_ant)
                    H_RU0 = H_RU0.reshape(batch,K,Nc,1,M_ant)

                    Phi,F_RF,F_BB_SDMA,F_BB_RS = net_RS(param_list, H_BR, H_RU0)
                    R_sum,R_SA,R_SDMA = loss_func(param_list, H_BR,H_RU,Phi,F_RF,F_BB_SDMA,F_BB_RS)
                    test_SE = test_SE + R_sum
                    test_SDMA = test_SDMA + R_SDMA
                test_SE = float(test_SE)/10
                test_SDMA = float(test_SDMA)/10

            train_nmse = float(train_nmse)/num_train
            train_SE = float(train_SE)/num_train
            time0 =  datetime.datetime.now()-start
            print('step:',step,'time',time0,' NMSE %.5f' % train_nmse,'train_SE %.5f' % train_SE,'test_SE %.5f' % test_SE,'test_SDMA %.5f' % test_SDMA) 
            if test_SE > best_SE:
                best_SE = test_SE
                torch.save(net_RS, model_name_RS)
                print('Model saved!')
            num_train = 0
            train_nmse = 0
            train_SE = 0
            start = datetime.datetime.now()

In [6]:
def train_Precoding_RSMA(param_list,batch,batch_num,STEPS,model_name_AE,model_name_RS,LR0): # train RIS passive + model-driven RSMA active precoding
    """Train the RSMA precoder on channels estimated by a pretrained CAN.

    The CAN is loaded from ``model_name_AE`` and kept frozen; its estimates feed
    the precoder, while the sum rate is evaluated on the true channels.
    Gradients are accumulated over ``batch_num`` mini-batches per optimizer
    step (a larger effective batch without more GPU memory).  Every 50 steps the
    precoder is evaluated on 10 fresh batches and saved to ``model_name_RS``
    whenever the test SE improves.

    Args:
        param_list: simulation parameter list shared by the notebook.
        batch: number of multi-user samples per mini-batch (each holds K users).
        batch_num: mini-batches accumulated per optimizer step.
        STEPS: number of optimizer steps.
        model_name_AE: path of the pickled CAN module (trusted local file).
        model_name_RS: output path for the trained precoder.
        LR0: multiplier on the Noam learning-rate schedule.
    """
    Nc = param_list[2]
    M_ant = param_list[8]
    K = param_list[14]
    Lp = param_list[15]

    loss_func = SE_RS().cuda()

    # torch.load unpickles the whole module: only use with trusted files.
    net_AE = torch.load(model_name_AE)
    net_RS = RIS_RSMA_Precoding(param_list).cuda()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net_RS.parameters()),lr=0.001)

    # Fixed BS-RIS channel, amplitude scaled by 1e6.
    H_BR = Channel_BS_RIS(param_list) * 1e6
    batch_sum = batch*K
    alpha = 1e5/sqrt(Lp)  # RIS-UE amplitude scaling

    num_train = 0
    train_nmse = 0
    train_SE = 0
    best_SE = 0
    start = datetime.datetime.now()
    numStep = 1
    warmup_steps = 4000
    d_model = 256

    for step in range(STEPS):
        optimizer.param_groups[0]['lr'] = (d_model**(-0.5))*min(numStep**(-0.5),numStep*warmup_steps**(-1.5))*LR0
        numStep = numStep+1
        net_AE.eval()
        net_RS.train()

        optimizer.zero_grad()
        for i in range(batch_num):  # gradient accumulation
            H_RU = Channel_RIS_UE(param_list,batch_sum,alpha)
            with torch.no_grad():
                H_RU0 = net_AE(param_list, H_BR, H_RU)
                nmse = NMSE(H_RU0,H_RU)
            H_RU = H_RU.reshape(batch,K,Nc,1,M_ant)
            H_RU0 = H_RU0.reshape(batch,K,Nc,1,M_ant)

            Phi,F_RF,F_BB_SDMA,F_BB_RS = net_RS(param_list, H_BR, H_RU0)
            R_sum,R_SA,R_SDMA = loss_func(param_list, H_BR,H_RU,Phi,F_RF,F_BB_SDMA,F_BB_RS)
            loss = -R_sum/batch_num
            loss.backward()

            num_train = num_train + 1
            # Out-of-place detach: the tensors that fed loss are not mutated.
            train_nmse = train_nmse + nmse.detach()
            train_SE = train_SE + R_sum.detach()

        optimizer.step()
        optimizer.zero_grad()

        if (step+1)%50==0:
            net_AE.eval()
            net_RS.eval()
            with torch.no_grad():
                test_SE = 0
                test_SDMA = 0
                for i in range(10):
                    H_RU = Channel_RIS_UE(param_list,batch_sum,alpha)
                    H_RU0 = net_AE(param_list, H_BR, H_RU)
                    H_RU = H_RU.reshape(batch,K,Nc,1,M_ant)
                    H_RU0 = H_RU0.reshape(batch,K,Nc,1,M_ant)

                    Phi,F_RF,F_BB_SDMA,F_BB_RS = net_RS(param_list, H_BR, H_RU0)
                    R_sum,R_SA,R_SDMA = loss_func(param_list, H_BR,H_RU,Phi,F_RF,F_BB_SDMA,F_BB_RS)
                    test_SE = test_SE + R_sum
                    test_SDMA = test_SDMA + R_SDMA
                test_SE = float(test_SE)/10
                test_SDMA = float(test_SDMA)/10

            train_nmse = float(train_nmse)/num_train
            train_SE = float(train_SE)/num_train
            time0 =  datetime.datetime.now()-start
            print('step:',step,'time',time0,' NMSE %.5f' % train_nmse,'train_SE %.5f' % train_SE,'test_SE %.5f' % test_SE,'test_SDMA %.5f' % test_SDMA) 
            if test_SE > best_SE:
                best_SE = test_SE
                torch.save(net_RS, model_name_RS)
                print('Model saved!')
            num_train = 0
            train_nmse = 0
            train_SE = 0
            start = datetime.datetime.now()

### 基线基于功率分配的RSMA预编码方案定义，其中最佳功率分配系数用穷搜获取，不完美信道用提出的CAN获取

In [5]:
def _baseline_analog_precoder(H_BR, Nc, N, M_ant, K):
    """Build the fixed phase-only analog precoder F_RF, shape (1, 1, M_ant, K).

    Column i is the conjugated, unit-modulus phase of row ``i*N*N + 6`` of the
    BS-RIS channel at the centre subcarrier, scaled by 1/sqrt(M_ant).
    Assumes H_BR is indexed as (Nc, M_ant, M_ant) -- TODO confirm.
    """
    F_RF = (torch.zeros(M_ant, K) + 0j).cuda()
    for i in range(K):
        # NOTE(review): the row offset 6 inside each N*N sub-array block is hard-coded.
        F_RF[:, i] = H_BR[Nc // 2, i * N * N + 6, :]
    F_RF = torch.real(F_RF) - 1j * torch.imag(F_RF)  # complex conjugate
    F_RF = (F_RF / torch.abs(F_RF)) / sqrt(M_ant)
    return F_RF.reshape(1, 1, M_ant, K)


def _baseline_ris_phase(H_RU, batch, Nc, N, M_ant, K):
    """Build the diagonal RIS phase matrix, shape (batch, 1, M_ant, M_ant).

    RIS sub-array i (elements ``i*N*N`` .. ``(i+1)*N*N - 1``) co-phases user i's
    RIS-UE channel at the centre subcarrier (conjugated unit-modulus phases).
    """
    phase = (torch.zeros(batch, 1, K, N * N) + 0j).cuda()
    for i in range(K):
        h = H_RU[:, i, Nc // 2, 0, i * N * N:(i + 1) * N * N]
        phase[:, 0, i, :] = (h / torch.abs(h)).reshape(batch, N * N)
    phase = torch.real(phase) - 1j * torch.imag(phase)  # complex conjugate
    # K*N*N == M_ant holds because K == M*M in this notebook.
    return torch.diag_embed(phase.reshape(batch, 1, M_ant))


def _baseline_equivalent_channel(H_RU, Phi, H_BR, F_RF, batch, Nc, K):
    """Return the effective channel H_RU @ Phi @ H_BR @ F_RF, shape (batch, Nc, K, K)."""
    H_equ = (torch.zeros(batch, Nc, K, K) + 0j).cuda()
    for i in range(K):
        H_equ[:, :, i] = (H_RU[:, i] @ Phi @ H_BR @ F_RF).reshape(-1, Nc, K)
    return H_equ


def _baseline_precoder_directions(H_equ, sigma, batch, Nc, K):
    """Return unit-norm precoder directions built from the effective channel.

    Returns:
        (F_dir_SDMA, F_dir_RS):
        - F_dir_SDMA: regularized zero-forcing precoder with every column
          normalized to unit norm, shape (batch, Nc, K, K).
        - F_dir_RS: common-stream precoder, the sum of the conjugate-transposed
          channels of all K users, normalized to unit norm, shape (batch, Nc, K, 1).
    """
    H_herm = Hermitian(H_equ)
    F_SDMA = torch.pinverse(H_herm @ H_equ + torch.eye(K).cuda() * sigma / K) @ H_herm
    F_SDMA = F_SDMA / torch.sqrt(torch.sum(torch.abs(F_SDMA * F_SDMA), 2)).reshape(batch, Nc, 1, K)
    # Sum over ALL K users; the previous code hard-coded users 0..3.
    F_RS = sum(Hermitian(H_equ[:, :, k:k + 1, :]) for k in range(K))
    F_RS = F_RS / torch.sqrt(torch.sum(torch.abs(F_RS * F_RS), 2)).reshape(batch, Nc, 1, 1)
    return F_SDMA, F_RS


def _baseline_sum_rate(loss_func, param_list, H_BR, H_RU, Phi, F_RF,
                       F_dir_SDMA, F_dir_RS, t, K):
    """Evaluate the sum SE when a fraction ``t`` of the power goes to private streams.

    Each of the K unit-norm private columns is scaled to power t. The common
    stream is scaled to power K*(1 - t), so the total power is K for every t.
    """
    F_BB_SDMA = F_dir_SDMA * sqrt(t)
    F_BB_RS = F_dir_RS * sqrt(K * (1 - t))
    R_sum, _, _ = loss_func(param_list, H_BR, H_RU, Phi, F_RF, F_BB_SDMA, F_BB_RS)
    return R_sum


def Convention_RSMA(param_list, batch, STEPS, model_name_AE, t_list=None):
    """Evaluate baseline power-allocation RSMA schemes with CAN-reconstructed CSI.

    Each step does the following:
    - Draw ``batch * K`` RIS-UE channels.
    - Reconstruct them with the pretrained CAN.
    - Build the RIS phases and the effective channel from both the true and the
      reconstructed CSI.
    - Always evaluate rates on the TRUE channel for four schemes:
      - RZF with perfect CSI (t = 1),
      - RZF with imperfect CSI (t = 1),
      - multicast with imperfect CSI (t = 0),
      - RSMA with imperfect CSI, taking the best ``t`` in ``t_list``
        (exhaustive search).

    Args:
        param_list: Simulation parameter list. Indices used here are
            2 (Nc), 4 (N), 8 (M_ant), 12 (sigma), 14 (K) and 15 (Lp).
        batch: Number of channel realizations per step.
        STEPS: Number of evaluation steps to average over.
        model_name_AE: Path of the whole CAN module saved with ``torch.save``.
        t_list: Candidate fractions of power given to the private streams.
            Defaults to the original hard-coded grid.

    Returns:
        Tuple ``(perfect_RZF, imperfect_RZF, imperfect_multicast, imperfect_RSMA)``
        of the average sum SEs. These are the same values that are printed.
    """
    Nc = param_list[2]
    N = param_list[4]
    M_ant = param_list[8]
    sigma = param_list[12]
    K = param_list[14]
    Lp = param_list[15]
    if t_list is None:
        t_list = [0.0000001, 0.000001, 0.00001, 0.0001, 0.0003, 0.001,
                  0.003, 0.01, 0.03, 0.1, 0.3, 1]

    loss_func = SE_RS().cuda()
    # The whole pretrained module is loaded; no need to construct a fresh CAN first.
    net_AE = torch.load(model_name_AE)
    net_AE.eval()

    # Amplitude scaling of the BS-RIS channel (see the notebook header).
    H_BR = Channel_BS_RIS(param_list) * 1e6
    batch_sum = batch * K
    alpha = 1e5 / sqrt(Lp)

    R_sum_perfect_ZF = 0
    R_sum_imperfect_ZF = 0
    R_sum_imperfect_MC = 0
    R_sum_imperfect_RSMA = 0
    with torch.no_grad():
        # F_RF depends only on the fixed H_BR, so it is built once.
        F_RF = _baseline_analog_precoder(H_BR, Nc, N, M_ant, K)
        for _ in range(STEPS):
            H_RU = Channel_RIS_UE(param_list, batch_sum, alpha)
            H_RU0 = net_AE(param_list, H_BR, H_RU).detach_()
            H_RU = H_RU.reshape(batch, K, Nc, 1, M_ant)
            H_RU0 = H_RU0.reshape(batch, K, Nc, 1, M_ant)

            Phi = _baseline_ris_phase(H_RU, batch, Nc, N, M_ant, K)
            Phi0 = _baseline_ris_phase(H_RU0, batch, Nc, N, M_ant, K)
            H_equ = _baseline_equivalent_channel(H_RU, Phi, H_BR, F_RF, batch, Nc, K)
            H_equ0 = _baseline_equivalent_channel(H_RU0, Phi0, H_BR, F_RF, batch, Nc, K)

            dir_SDMA, dir_RS = _baseline_precoder_directions(H_equ, sigma, batch, Nc, K)
            dir_SDMA0, dir_RS0 = _baseline_precoder_directions(H_equ0, sigma, batch, Nc, K)

            # RZF with perfect CSI
            R_sum_perfect_ZF = R_sum_perfect_ZF + _baseline_sum_rate(
                loss_func, param_list, H_BR, H_RU, Phi, F_RF, dir_SDMA, dir_RS, 1, K)
            # RZF with imperfect (CAN-reconstructed) CSI
            R_sum_imperfect_ZF = R_sum_imperfect_ZF + _baseline_sum_rate(
                loss_func, param_list, H_BR, H_RU, Phi0, F_RF, dir_SDMA0, dir_RS0, 1, K)
            # Multicast (common stream only) with imperfect CSI
            R_sum_imperfect_MC = R_sum_imperfect_MC + _baseline_sum_rate(
                loss_func, param_list, H_BR, H_RU, Phi0, F_RF, dir_SDMA0, dir_RS0, 0, K)
            # RSMA: exhaustive search over the power-split grid
            rsma_rates = [
                _baseline_sum_rate(loss_func, param_list, H_BR, H_RU, Phi0, F_RF,
                                   dir_SDMA0, dir_RS0, t, K)
                for t in t_list
            ]
            R_sum_imperfect_RSMA = R_sum_imperfect_RSMA + torch.stack(rsma_rates).max(dim=0).values

    avg_perfect_ZF = R_sum_perfect_ZF / STEPS
    avg_imperfect_ZF = R_sum_imperfect_ZF / STEPS
    avg_imperfect_MC = R_sum_imperfect_MC / STEPS
    avg_imperfect_RSMA = R_sum_imperfect_RSMA / STEPS
    print('Perfect CSI RZF 频谱效率：',avg_perfect_ZF)
    print('Imperfect CSI RZF 频谱效率：',avg_imperfect_ZF)
    print('Imperfect CSI Multicast 频谱效率：',avg_imperfect_MC)
    print('Imperfect CSI 功率分配RSMA 频谱效率：',avg_imperfect_RSMA)
    return avg_perfect_ZF, avg_imperfect_ZF, avg_imperfect_MC, avg_imperfect_RSMA

### 参数设置

In [25]:
fc = 1.5e11  # carrier frequency: 150 GHz

# BW = 1e9  # alternative setting: 1 GHz

BW = 1e10  # total bandwidth: 10 GHz

Nc = 64  # number of subcarriers


M = 2  # the array is an M x M grid of sub-arrays
N = 8  # each sub-array has N x N antennas
c = 3e8  # speed of light (m/s)
lambda_c = c / fc  # carrier wavelength: 2 mm
D_sub = 100 * lambda_c  # spacing between sub-array centres: 0.2 m
D_ant = lambda_c / 2  # antenna spacing (half wavelength): 1 mm
R = M * D_sub**2 / lambda_c  # BS-RIS distance: 40 m, chosen to satisfy the LoS-MIMO distance condition

M_ant = M * M * N * N  # total number of antennas at the BS (and elements at the RIS) = 256

h1 = 10  # height of the BS and the RIS
h2 = 1.5  # user height -- deprecated: each scatterer was later placed between 0 and 4 m

L = 8  # number of observation (pilot) time slots
Lp = 1  # number of propagation paths
sigma = 1e5  # noise power (scaled; see the notebook header)
B = 128  # number of feedback bits

K = M * M  # number of users = number of sub-arrays = 4




### CAN 数据驱动训练

In [None]:
# Data-driven CAN training: 8 pilot slots, sigma = 1e4, 32 feedback bits.
L = 8  # number of observation (pilot) time slots
sigma = 1e4
B = 32

batch = 128  # original comment said "number of users"; presumably the training batch size -- TODO confirm in train_CAN
STEPS = 30000
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
train_CAN(param_list, batch, STEPS, model_name_AE)

### 基线RSMA Precoding

In [27]:
# Baseline RSMA evaluation over noise powers; each sigma loads the CAN trained at that sigma.
L = 8  # number of observation (pilot) time slots
sigma = 1e2  # overwritten by the sweep below
B = 32
batch = 64
STEPS = 10
for sigma in [1e1, 1e2, 1e3, 1e4, 1e5, 1e6]:
    model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
    param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
    print('sigma: 1e', str(floor(log10(sigma))), '  结果：')
    Convention_RSMA(param_list, batch, STEPS, model_name_AE)

sigma: 1e 1   结果：




In [53]:
# Baseline RSMA evaluation over feedback-bit budgets at sigma = 1e2.
L = 8  # number of observation (pilot) time slots
sigma = 1e2
B = 32  # overwritten by the sweep below
batch = 64
STEPS = 10
for B in [8, 16, 32, 64, 128, 256]:
    model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
    # model_name_RS is not passed to Convention_RSMA below.
    model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
    param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
    print('B: ', str(B), '  结果：')
    Convention_RSMA(param_list, batch, STEPS, model_name_AE)

B:  8   结果：




Perfect CSI ZF 频谱效率： tensor(17.4732, device='cuda:0')
Imperfect CSI ZF 频谱效率： tensor(0.3898, device='cuda:0')
Imperfect CSI Multicast 频谱效率： tensor(3.2196, device='cuda:0')
Imperfect CSI RSMA 频谱效率： tensor(3.2148, device='cuda:0')
B:  16   结果：
Perfect CSI ZF 频谱效率： tensor(17.4707, device='cuda:0')
Imperfect CSI ZF 频谱效率： tensor(2.2807, device='cuda:0')
Imperfect CSI Multicast 频谱效率： tensor(4.0529, device='cuda:0')
Imperfect CSI RSMA 频谱效率： tensor(4.7419, device='cuda:0')
B:  32   结果：
Perfect CSI ZF 频谱效率： tensor(17.5470, device='cuda:0')
Imperfect CSI ZF 频谱效率： tensor(5.7463, device='cuda:0')
Imperfect CSI Multicast 频谱效率： tensor(4.3592, device='cuda:0')
Imperfect CSI RSMA 频谱效率： tensor(7.6064, device='cuda:0')
B:  64   结果：
Perfect CSI ZF 频谱效率： tensor(17.4853, device='cuda:0')
Imperfect CSI ZF 频谱效率： tensor(5.8832, device='cuda:0')
Imperfect CSI Multicast 频谱效率： tensor(4.3775, device='cuda:0')
Imperfect CSI RSMA 频谱效率： tensor(7.6769, device='cuda:0')
B:  128   结果：
Perfect CSI ZF 频谱效率： tensor(17.4426

### RSMA Precoding 模型驱动训练

In [25]:
# Model-driven RSMA precoder training: sigma = 1e2, B = 32 feedback bits.
# NOTE(review): the other RSMA training cells call train_Precoding rather than
# train_Precoding_RSMA -- confirm which name is currently defined.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e2
B = 32

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding_RSMA(param_list, batch, batch_num, STEPS, model_name_AE, model_name_RS, LR0)



step: 49 time 0:00:35.966248  NMSE 0.04567 train_SE 0.81872 test_SE 0.80479 test_SDMA 0.65822
Model saved!
step: 99 time 0:00:36.414513  NMSE 0.04521 train_SE 0.90394 test_SE 0.94014 test_SDMA 0.69995
Model saved!
step: 149 time 0:00:36.211751  NMSE 0.04549 train_SE 1.17569 test_SE 1.48656 test_SDMA 0.75509
Model saved!
step: 199 time 0:00:36.465235  NMSE 0.04551 train_SE 1.75707 test_SE 2.14604 test_SDMA 0.59489
Model saved!
step: 249 time 0:00:36.092655  NMSE 0.04561 train_SE 2.26782 test_SE 2.51909 test_SDMA 0.39470
Model saved!
step: 299 time 0:00:35.910152  NMSE 0.04545 train_SE 2.74191 test_SE 3.24435 test_SDMA 0.45854
Model saved!
step: 349 time 0:00:35.906192  NMSE 0.04575 train_SE 3.46973 test_SE 3.82887 test_SDMA 0.44742
Model saved!
step: 399 time 0:00:35.815408  NMSE 0.04587 train_SE 3.87364 test_SE 4.00067 test_SDMA 0.40706
Model saved!
step: 449 time 0:00:35.901717  NMSE 0.04534 train_SE 4.02784 test_SE 4.08262 test_SDMA 0.35873
Model saved!
step: 499 time 0:00:35.897278 

In [26]:
# Model-driven RSMA precoder training: sigma = 1e3, B = 32 feedback bits.
# NOTE(review): another RSMA training cell calls train_Precoding_RSMA -- confirm which name is defined.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e3
B = 32

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding(param_list, batch, batch_num, STEPS, model_name_AE, model_name_RS, LR0)



step: 49 time 0:00:35.798757  NMSE 0.04586 train_SE 0.86699 test_SE 0.90234 test_SDMA 0.61850
Model saved!
step: 99 time 0:00:36.050330  NMSE 0.04603 train_SE 0.96712 test_SE 1.06138 test_SDMA 0.62200
Model saved!
step: 149 time 0:00:36.066821  NMSE 0.04577 train_SE 1.20356 test_SE 1.40265 test_SDMA 0.53883
Model saved!
step: 199 time 0:00:36.086837  NMSE 0.04575 train_SE 1.61996 test_SE 1.99850 test_SDMA 0.49557
Model saved!
step: 249 time 0:00:36.024973  NMSE 0.04591 train_SE 2.27832 test_SE 2.65715 test_SDMA 0.66049
Model saved!
step: 299 time 0:00:36.056314  NMSE 0.04580 train_SE 2.80118 test_SE 3.00188 test_SDMA 0.77674
Model saved!
step: 349 time 0:00:36.032389  NMSE 0.04586 train_SE 3.11776 test_SE 3.23149 test_SDMA 0.55339
Model saved!
step: 399 time 0:00:36.043459  NMSE 0.04584 train_SE 3.26485 test_SE 3.31944 test_SDMA 0.59629
Model saved!
step: 449 time 0:00:36.049773  NMSE 0.04552 train_SE 3.39059 test_SE 3.53347 test_SDMA 1.00514
Model saved!
step: 499 time 0:00:36.060807 

In [27]:
# Model-driven RSMA precoder training: sigma = 1e4, B = 32 feedback bits.
# NOTE(review): another RSMA training cell calls train_Precoding_RSMA -- confirm which name is defined.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e4
B = 32

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding(param_list, batch, batch_num, STEPS, model_name_AE, model_name_RS, LR0)



step: 49 time 0:00:35.923667  NMSE 0.05028 train_SE 0.64326 test_SE 0.66028 test_SDMA 0.51445
Model saved!
step: 99 time 0:00:36.068382  NMSE 0.05094 train_SE 0.71422 test_SE 0.80681 test_SDMA 0.58612
Model saved!
step: 149 time 0:00:36.006932  NMSE 0.05090 train_SE 0.99511 test_SE 1.34614 test_SDMA 0.80924
Model saved!
step: 199 time 0:00:35.989921  NMSE 0.05100 train_SE 1.66901 test_SE 2.04973 test_SDMA 0.67705
Model saved!
step: 249 time 0:00:35.965976  NMSE 0.05028 train_SE 2.24808 test_SE 2.45536 test_SDMA 0.78097
Model saved!
step: 299 time 0:00:36.032031  NMSE 0.05097 train_SE 2.56780 test_SE 2.84732 test_SDMA 1.30980
Model saved!
step: 349 time 0:00:36.048786  NMSE 0.05064 train_SE 2.95208 test_SE 3.20060 test_SDMA 1.92478
Model saved!
step: 399 time 0:00:36.029354  NMSE 0.05054 train_SE 3.32122 test_SE 3.56789 test_SDMA 2.44071
Model saved!
step: 449 time 0:00:36.045310  NMSE 0.05001 train_SE 3.53441 test_SE 3.65583 test_SDMA 2.50851
Model saved!
step: 499 time 0:00:36.031763 

In [28]:
# Model-driven RSMA precoder training: sigma = 1e5, B = 32 feedback bits.
# NOTE(review): another RSMA training cell calls train_Precoding_RSMA -- confirm which name is defined.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e5
B = 32

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding(param_list, batch, batch_num, STEPS, model_name_AE, model_name_RS, LR0)



step: 49 time 0:00:35.966597  NMSE 0.06533 train_SE 0.26505 test_SE 0.29440 test_SDMA 0.22232
Model saved!
step: 99 time 0:00:36.012132  NMSE 0.06496 train_SE 0.34018 test_SE 0.48036 test_SDMA 0.41060
Model saved!
step: 149 time 0:00:36.039890  NMSE 0.06477 train_SE 0.77729 test_SE 1.12674 test_SDMA 1.03380
Model saved!
step: 199 time 0:00:36.055350  NMSE 0.06529 train_SE 1.34462 test_SE 1.64849 test_SDMA 1.48090
Model saved!
step: 249 time 0:00:36.078767  NMSE 0.06463 train_SE 1.84821 test_SE 2.24782 test_SDMA 1.89787
Model saved!
step: 299 time 0:00:36.016187  NMSE 0.06509 train_SE 2.29511 test_SE 2.55052 test_SDMA 2.12020
Model saved!
step: 349 time 0:00:35.923342  NMSE 0.06469 train_SE 2.51353 test_SE 2.65105 test_SDMA 2.21180
Model saved!
step: 399 time 0:00:35.976033  NMSE 0.06514 train_SE 2.63354 test_SE 2.83142 test_SDMA 2.42291
Model saved!
step: 449 time 0:00:35.936486  NMSE 0.06587 train_SE 2.71066 test_SE 2.89761 test_SDMA 2.51127
Model saved!
step: 499 time 0:00:35.873247 

In [29]:
# Model-driven RSMA precoder training: sigma = 1e2, B = 8 feedback bits.
# NOTE(review): another RSMA training cell calls train_Precoding_RSMA -- confirm which name is defined.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e2
B = 8

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding(param_list, batch, batch_num, STEPS, model_name_AE, model_name_RS, LR0)



step: 49 time 0:00:35.732709  NMSE 0.69044 train_SE 0.54824 test_SE 0.56149 test_SDMA 0.19906
Model saved!
step: 99 time 0:00:35.891762  NMSE 0.69153 train_SE 0.68014 test_SE 0.84207 test_SDMA 0.20773
Model saved!
step: 149 time 0:00:35.874710  NMSE 0.69106 train_SE 1.00095 test_SE 1.33604 test_SDMA 0.20685
Model saved!
step: 199 time 0:00:35.888311  NMSE 0.68925 train_SE 1.55597 test_SE 2.06431 test_SDMA 0.17339
Model saved!
step: 249 time 0:00:35.925145  NMSE 0.68958 train_SE 2.22586 test_SE 2.63352 test_SDMA 0.09291
Model saved!
step: 299 time 0:00:35.884255  NMSE 0.69091 train_SE 2.64701 test_SE 2.97908 test_SDMA 0.04678
Model saved!
step: 349 time 0:00:35.924239  NMSE 0.68965 train_SE 2.97454 test_SE 3.34742 test_SDMA 0.03710
Model saved!
step: 399 time 0:00:35.823661  NMSE 0.68957 train_SE 3.19896 test_SE 3.38825 test_SDMA 0.06391
Model saved!
step: 449 time 0:00:35.857313  NMSE 0.69111 train_SE 3.28233 test_SE 3.44705 test_SDMA 0.04161
Model saved!
step: 499 time 0:00:35.876233 

In [30]:
# Model-driven RSMA precoder training: sigma = 1e2, B = 16 feedback bits.
# NOTE(review): another RSMA training cell calls train_Precoding_RSMA -- confirm which name is defined.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e2
B = 16

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding(param_list, batch, batch_num, STEPS, model_name_AE, model_name_RS, LR0)



step: 49 time 0:00:35.750353  NMSE 0.27694 train_SE 0.76429 test_SE 0.81712 test_SDMA 0.44443
Model saved!
step: 99 time 0:00:35.896861  NMSE 0.28011 train_SE 0.95554 test_SE 1.16949 test_SDMA 0.44641
Model saved!
step: 149 time 0:00:35.879369  NMSE 0.27958 train_SE 1.50161 test_SE 1.96573 test_SDMA 0.43993
Model saved!
step: 199 time 0:00:35.947170  NMSE 0.27850 train_SE 2.25396 test_SE 2.79323 test_SDMA 0.43159
Model saved!
step: 249 time 0:00:35.908560  NMSE 0.28062 train_SE 3.11168 test_SE 3.65778 test_SDMA 0.28034
Model saved!
step: 299 time 0:00:35.890315  NMSE 0.27941 train_SE 3.71373 test_SE 3.93615 test_SDMA 0.11233
Model saved!
step: 349 time 0:00:35.904459  NMSE 0.27937 train_SE 3.91694 test_SE 4.01907 test_SDMA 0.06952
Model saved!
step: 399 time 0:00:35.882800  NMSE 0.27995 train_SE 4.00486 test_SE 4.05292 test_SDMA 0.07200
Model saved!
step: 449 time 0:00:35.898270  NMSE 0.27961 train_SE 4.06031 test_SE 4.07890 test_SDMA 0.10180
Model saved!
step: 499 time 0:00:35.909603 

In [31]:
# Model-driven RSMA precoder training: sigma = 1e2, B = 64 feedback bits.
# NOTE(review): another RSMA training cell calls train_Precoding_RSMA -- confirm which name is defined.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e2
B = 64

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding(param_list, batch, batch_num, STEPS, model_name_AE, model_name_RS, LR0)



step: 49 time 0:00:35.698497  NMSE 0.03727 train_SE 0.88977 test_SE 0.89608 test_SDMA 0.71474
Model saved!
step: 99 time 0:00:35.784030  NMSE 0.03739 train_SE 0.95149 test_SE 0.98274 test_SDMA 0.73946
Model saved!
step: 149 time 0:00:35.824435  NMSE 0.03743 train_SE 1.13070 test_SE 1.38745 test_SDMA 0.81052
Model saved!
step: 199 time 0:00:35.815950  NMSE 0.03735 train_SE 1.69857 test_SE 2.15893 test_SDMA 0.61775
Model saved!
step: 249 time 0:00:35.852407  NMSE 0.03736 train_SE 2.34787 test_SE 2.72110 test_SDMA 0.41618
Model saved!
step: 299 time 0:00:35.794242  NMSE 0.03739 train_SE 3.03264 test_SE 3.52724 test_SDMA 0.51167
Model saved!
step: 349 time 0:00:35.829061  NMSE 0.03741 train_SE 3.61423 test_SE 3.88135 test_SDMA 0.57052
Model saved!
step: 399 time 0:00:35.826493  NMSE 0.03744 train_SE 3.89878 test_SE 4.01925 test_SDMA 0.64468
Model saved!
step: 449 time 0:00:35.828531  NMSE 0.03764 train_SE 4.13539 test_SE 4.37860 test_SDMA 1.26370
Model saved!
step: 499 time 0:00:35.792631 

In [52]:
# Model-driven RSMA precoder training: sigma = 1e2, B = 128 feedback bits.
# NOTE(review): another RSMA training cell calls train_Precoding_RSMA -- confirm which name is defined.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e2
B = 128

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding(param_list, batch, batch_num, STEPS, model_name_AE, model_name_RS, LR0)



step: 49 time 0:00:36.054053  NMSE 0.01257 train_SE 0.95140 test_SE 0.97660 test_SDMA 0.75841
Model saved!
step: 99 time 0:00:36.248176  NMSE 0.01264 train_SE 1.01648 test_SE 1.06532 test_SDMA 0.77817
Model saved!
step: 149 time 0:00:36.363383  NMSE 0.01264 train_SE 1.14646 test_SE 1.30341 test_SDMA 0.83741
Model saved!
step: 199 time 0:00:36.382939  NMSE 0.01265 train_SE 1.47658 test_SE 1.86963 test_SDMA 0.83268
Model saved!
step: 249 time 0:00:36.257058  NMSE 0.01256 train_SE 2.18995 test_SE 2.69580 test_SDMA 0.57727
Model saved!
step: 299 time 0:00:36.243771  NMSE 0.01259 train_SE 2.95390 test_SE 3.44701 test_SDMA 0.52473
Model saved!
step: 349 time 0:00:36.216306  NMSE 0.01254 train_SE 3.62872 test_SE 3.90747 test_SDMA 0.41103
Model saved!
step: 399 time 0:00:36.256929  NMSE 0.01255 train_SE 3.94113 test_SE 4.04002 test_SDMA 0.40836
Model saved!
step: 449 time 0:00:36.249466  NMSE 0.01261 train_SE 4.13106 test_SE 4.29778 test_SDMA 0.98793
Model saved!
step: 499 time 0:00:36.187069 

In [32]:
# Model-driven RSMA precoder training: sigma = 1e2, B = 256 feedback bits.
# NOTE(review): another RSMA training cell calls train_Precoding_RSMA -- confirm which name is defined.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e2
B = 256

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_RS = f'./models/RIS_RS_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding(param_list, batch, batch_num, STEPS, model_name_AE, model_name_RS, LR0)



step: 49 time 0:00:35.857759  NMSE 0.00914 train_SE 0.90029 test_SE 0.92599 test_SDMA 0.76239
Model saved!
step: 99 time 0:00:35.945063  NMSE 0.00914 train_SE 0.97477 test_SE 1.01713 test_SDMA 0.78874
Model saved!
step: 149 time 0:00:35.929101  NMSE 0.00916 train_SE 1.12643 test_SE 1.27298 test_SDMA 0.84955
Model saved!
step: 199 time 0:00:35.910097  NMSE 0.00912 train_SE 1.47666 test_SE 1.88050 test_SDMA 0.86728
Model saved!
step: 249 time 0:00:35.963956  NMSE 0.00914 train_SE 2.11033 test_SE 2.48909 test_SDMA 0.72953
Model saved!
step: 299 time 0:00:35.961027  NMSE 0.00913 train_SE 2.74665 test_SE 3.24726 test_SDMA 0.85772
Model saved!
step: 349 time 0:00:35.936067  NMSE 0.00914 train_SE 3.57900 test_SE 4.13461 test_SDMA 1.45959
Model saved!
step: 399 time 0:00:35.934073  NMSE 0.00914 train_SE 4.34624 test_SE 4.91228 test_SDMA 2.26157
Model saved!
step: 449 time 0:00:35.926078  NMSE 0.00918 train_SE 4.95265 test_SE 5.55195 test_SDMA 3.01102
Model saved!
step: 499 time 0:00:35.897219 

### SDMA Precoding 模型驱动训练

In [34]:
# Model-driven SDMA precoder training: sigma = 1e2, B = 32 feedback bits.
LR0 = 1
L = 8  # number of observation (pilot) time slots
sigma = 1e2
B = 32

batch = 32
batch_num = 4  # mini-batches per optimizer step (presumably gradient accumulation -- confirm in the trainer)
STEPS = 20000
model_name_AE = f'./models/RIS_AE_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
model_name_SD = f'./models/RIS_SD_{L}pilots_{B}bits_{Lp}paths_{floor(log10(sigma))}sigma.pth'
param_list = [fc, BW, Nc, M, N, D_sub, D_ant, R, M_ant, h1, h2, L, sigma, B, K, Lp]
train_Precoding_SDMA(param_list, batch, batch_num, STEPS, model_name_AE, model_name_SD, LR0)



step: 49 time 0:00:34.589844  NMSE 0.04555 train_SE 0.72469 test_SE 0.72335 test_SDMA 0.72335
Model saved!
step: 99 time 0:00:34.701252  NMSE 0.04503 train_SE 0.76796 test_SE 0.77807 test_SDMA 0.77807
Model saved!
step: 149 time 0:00:34.681373  NMSE 0.04534 train_SE 0.84091 test_SE 0.90190 test_SDMA 0.90190
Model saved!
step: 199 time 0:00:34.740744  NMSE 0.04547 train_SE 0.99899 test_SE 1.19242 test_SDMA 1.19242
Model saved!
step: 249 time 0:00:34.687888  NMSE 0.04508 train_SE 1.29465 test_SE 1.57184 test_SDMA 1.57184
Model saved!
step: 299 time 0:00:34.701362  NMSE 0.04590 train_SE 1.71746 test_SE 2.25509 test_SDMA 2.25509
Model saved!
step: 349 time 0:00:34.691396  NMSE 0.04539 train_SE 2.18641 test_SE 2.74048 test_SDMA 2.74048
Model saved!
step: 399 time 0:00:34.680390  NMSE 0.04523 train_SE 2.59874 test_SE 3.09260 test_SDMA 3.09260
Model saved!
step: 449 time 0:00:34.696814  NMSE 0.04551 train_SE 2.92248 test_SE 3.37634 test_SDMA 3.37634
Model saved!
step: 499 time 0:00:34.696318 