In [3]:
import torch
import os
# Order CUDA devices by PCI bus ID so the index used below refers to a fixed physical card.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only device 0 (indices are zero-based, so this is the first GPU, not the second)
import math 
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from utils import conv, update_registered_buffers, deconv
import datetime
import torch.nn as nn
from compressai.ans import BufferedRansEncoder, RansDecoder
from compressai.layers import conv3x3, subpel_conv3x3
# from compressai.ops import ste_round
from Transformer_model import *
from basic import *
from pytorch_msssim import ms_ssim
import matplotlib.pyplot as plt


In [2]:
import torch.nn.functional as F

def psnr(pred, target, max_val=1.0):
    """
    Compute the PSNR (Peak Signal-to-Noise Ratio) between two images.

    The mean squared error is taken over every element of the inputs, so a
    batch yields one aggregate PSNR rather than one value per image.

    Parameters:
        pred (Tensor): Predicted image or batch of images, e.g. (N, C, H, W).
        target (Tensor): Reference image(s) with the same shape as ``pred``.
        max_val (float): Maximum possible pixel value. Default is 1.0 for
            normalized images.

    Returns:
        Tensor: 0-dim tensor holding the PSNR in dB, located on the same
        device as the inputs (``inf`` when the inputs are identical).
    """
    mse = F.mse_loss(pred, target, reduction='mean')
    # Keep the result on the device of the inputs: forcing `.cuda()` here made
    # the function unusable on CPU-only hosts and for CPU tensors.
    psnr_value = 20 * torch.log10(max_val / torch.sqrt(mse))
    return psnr_value


In [3]:
def kron(a,b):
    """Per-sample Kronecker product of two batches of vectors.

    Both inputs are flattened to [batch, Na] and [batch, Nb]; the result has
    shape [batch, Na*Nb] with the index of ``a`` varying slowest.
    """
    n_batch = a.shape[0]
    column = a.reshape(n_batch, -1, 1)   # [batch, Na, 1]
    row = b.reshape(n_batch, 1, -1)      # [batch, 1, Nb]
    outer = torch.matmul(column, row)    # [batch, Na, Nb]
    return outer.reshape(n_batch, -1)
def DFT_matrix(N):
    """Build the N x N matrix with entries exp(-1j*2*pi*n*k/N)/sqrt(N) on the GPU.

    ``pi`` and ``sqrt`` are not defined in this cell; they are presumably
    provided by one of the star imports at the top of the notebook.
    """
    index = torch.arange(N)
    row_idx = index.reshape(N,1).cuda()
    col_idx = index.reshape(1,N).cuda()
    phase = -1j*2*pi*row_idx*col_idx/N
    return torch.exp(phase)/sqrt(N)
def SV_Channel(batch,Nc,N,Nt,sigma_2_alpha):
    """Draw a batch of random multipath frequency-domain channels on the GPU.

    Args:
        batch: number of channel realisations.
        Nc: number of subcarriers; also used as the maximum path delay.
        N: number of propagation paths.
        Nt: pair [Nx, Ny] with the UPA antenna counts; Nx*Ny transmit antennas.
        sigma_2_alpha: energy of each path gain.

    Returns:
        Complex tensor of shape [batch, Nc, Nx*Ny] (allocated with `.cuda()`).

    Relies on `pi`, `sqrt` and `Hermitian`, which are not defined in this cell
    (presumably supplied by the star imports; `Hermitian` is presumably the
    conjugate transpose -- TODO confirm).
    """
    # Nc: subcarriers, N: paths, sigma_2_alpha: per-path energy, Nt=[Nx,Ny]: UPA antenna counts
    d=0.5  # not used below (the 0.5 spacing is written literally in the steering vectors)
    #print(check)
    Tau = Nc  # maximum path delay
    tau = torch.rand(batch,1,N).cuda()*Tau  # per-path delays, uniform in [0, Tau)

    N_r = 1  # not used below
    N_t = Nt[0]*Nt[1]
    # Angles are drawn uniformly in [0, 2*pi); only fait and theatt are used afterwards.
    fait = torch.rand(batch,1,N).cuda()*2*pi
    fair = torch.rand(batch,1,N).cuda()*2*pi
    theatt = torch.rand(batch,1,N).cuda()*2*pi
    theatr = torch.rand(batch,1,N).cuda()*2*pi
    A_t = torch.zeros(batch,N_t,N).cuda() +0j  # transmit steering vectors, one column per path
    alpha = torch.zeros(batch,1,N).cuda() +0j  # complex path gains
    n_t1 = torch.arange(Nt[0]).reshape(1,Nt[0],1).cuda()
    n_t2 = torch.arange(Nt[1]).reshape(1,Nt[1],1).cuda()
    A_r = torch.ones(batch,1,N).cuda() +0j  # receive response fixed to all ones
    H_f = torch.zeros(batch,Nc,N_t).cuda() +0j
    for i in range(N):
        # Steering vector of path i: Kronecker product of the two per-axis phase ramps.
        at1 = torch.exp(-2j*pi*0.5*n_t1*torch.cos(fait[:,:,i:(i+1)])*torch.sin(theatt[:,:,i:(i+1)]))
        at2 = torch.exp(-2j*pi*0.5*n_t2*torch.sin(fait[:,:,i:(i+1)]))
        A_t[:,:,i] = kron(at1,at2)
        # Complex Gaussian gain; drawn on the CPU and copied into the GPU tensor below.
        aa = (torch.randn(batch,1)+1j*torch.randn(batch,1))*sqrt(sigma_2_alpha/2) 
        # Disabled variant that attenuated every path after the first by 0.1:
        # if i == 0:
        #     aa = (torch.randn(batch,1)+1j*torch.randn(batch,1))*sqrt(sigma_2_alpha/2) 
        # else:
        #     aa = (torch.randn(batch,1)+1j*torch.randn(batch,1))*sqrt(sigma_2_alpha/2)*0.1
        alpha[:,:,i] = aa[:,:]


    for k in range(Nc):
        # Gain of every path on subcarrier k: the delay appears as a linear phase in k.
        P = alpha*torch.exp(-1j*2*pi*tau*k/Nc)
        P = torch.diag_embed(P).reshape(batch,N,N)
        H_f[:,k:(k+1),:] = torch.matmul(torch.matmul(A_r,P),Hermitian(A_t))
        
    return H_f

### GMMV-LAMP

In [5]:
T_num = 10  # number of LAMP layers (10)
def eta(R, sigma_t2, K, G, M, theta):
    """Shrinkage step shared by the LAMP layers.

    Args:
        R: tensor indexed as [batch, G, K] (G: length of the sparse vector,
            K: number of subcarriers / measurement vectors).
        sigma_t2: tensor of shape [batch] with the per-sample residual power.
        K: number of measurement vectors; scales the log term of ``psi_t``.
        G: length of the sparse vector. Not used in the computation.
        M: observation dimension; normalises the returned ``b``.
        theta: two-element learnable tensor; ``abs(theta[0])`` and ``theta[1]``
            are used.

    Returns:
        (H_res, b): ``H_res`` has the shape of ``R`` (each row of ``R`` scaled
        by a row-wise factor), ``b`` has shape [batch] and equals the sum of
        those factors divided by ``M``.
    """
    theta1 = torch.abs(theta[0])
    theta2 = theta[1]
    batch = R.shape[0]

    pi_t = 1 + sigma_t2/theta1
    # Use plain tensor arithmetic here. Wrapping the expression in
    # torch.tensor(...) copy-constructs a detached tensor, which silently cut
    # the gradient from psi_t back to theta[0]; the extra .cuda() also made
    # the function unusable for CPU tensors.
    psi_t = 0.5*K*torch.log(1 + theta1/sigma_t2) + theta2

    # Energy of each row of R across the K measurement vectors: [batch, G, 1].
    row_energy = torch.sum(torch.abs(R*R),[2]).reshape(batch,R.shape[1],1)
    aa = psi_t.reshape(batch,1,1) - row_energy/((2*sigma_t2*pi_t).reshape(batch,1,1))
    frac_div = torch.sigmoid(-aa)/pi_t.reshape(batch,1,1)

    H_res = R*frac_div
    b = torch.sum(frac_div,[1,2])/M
    return H_res,b
class LAMP_Net(nn.Module): # single (first) LAMP layer
    """First LAMP layer: starts from a zero estimate and a residual equal to Y.

    Constructor arguments: ``Theta`` is the sensing matrix [Nc, L, N]
    (Nc: subcarriers / measurement vectors, L: observation dimension,
    N: length of the sparse vector). ``Nc``, ``N`` and ``L`` are accepted but
    not used by the constructor.
    """
    def __init__(self, Theta , Nc, N, L):
        super().__init__()
        # Learnable matched filter, initialised from Hermitian(Theta): [Nc, N, L].
        self.B = torch.nn.Parameter(Hermitian(Theta))
        # Learnable shrinkage parameters consumed by eta().
        self.theta = torch.nn.Parameter(torch.tensor([1.0,2.0]))

    def forward(self, Y, Nc, N, L, Theta):
        """Run one LAMP iteration.

        Y arrives as [batch, Nc, L] and is rearranged to [batch, L, Nc].
        Returns the estimate H_res [batch, N, Nc] and the residual V [batch, L, Nc].
        """
        n_batch = Y.shape[0]
        Y = Y.permute(0,2,1).reshape(n_batch,L,Nc)

        # Initial state: zero estimate, residual equal to the observation.
        H_res = torch.zeros(n_batch,N,Nc).cuda()
        V = Y

        # Matched-filter the residual per subcarrier, then return to [batch, N, Nc].
        residual_cols = V.permute(0,2,1).reshape(n_batch,Nc,L,1)
        BV = torch.matmul(self.B, residual_cols)                       # [batch, Nc, N, 1]
        BV = BV.permute(0,2,1,3).reshape(n_batch,N,Nc)

        R = H_res + BV
        sigma_t2 = torch.sum(torch.abs(V*V),[1,2])/(L*Nc)
        H_res,b = eta(R, sigma_t2, Nc, N, L, self.theta)

        # Re-project the estimate into measurement space: [batch, L, Nc].
        estimate_cols = H_res.permute(0,2,1).reshape(n_batch,Nc,N,1)
        HH = torch.matmul(Theta, estimate_cols)                        # [batch, Nc, L, 1]
        HH = HH.permute(0,2,1,3).reshape(n_batch,L,Nc)

        # New residual including the correction term b*V.
        V = Y-HH+b.reshape(-1,1,1)*V

        return H_res,V

class LAMP_MuNet(nn.Module): # stacked LAMP layers
    """One additional LAMP layer stacked on a frozen, previously trained network.

    Constructor arguments: ``Theta`` is the sensing matrix [Nc, L, N]
    (Nc: subcarriers / measurement vectors, L: observation dimension,
    N: length of the sparse vector); ``net`` is the network made of all
    earlier layers. ``Nc``, ``N`` and ``L`` are accepted but not used by the
    constructor.
    """
    def __init__(self, Theta , Nc, N, L, net):
        super().__init__()
        self.net = net
        # Earlier layers are frozen; only this layer's B and theta are trained.
        for frozen in self.net.parameters():
            frozen.requires_grad = False

        self.B = torch.nn.Parameter(Hermitian(Theta))    # [Nc, N, L]
        self.theta = torch.nn.Parameter(torch.tensor([1.0,2.0]))

    def forward(self, Y, Nc, N, L, Theta):
        """Run the earlier layers, then one more LAMP iteration.

        Y arrives as [batch, Nc, L]. Returns the refined estimate
        H_res [batch, N, Nc] and the residual V [batch, L, Nc].
        """
        # The wrapped network consumes Y in its original [batch, Nc, L] layout.
        H_res,V = self.net(Y, Nc, N, L, Theta)

        n_batch = Y.shape[0]
        Y = Y.permute(0,2,1).reshape(n_batch,L,Nc)

        # Matched-filter the incoming residual per subcarrier.
        residual_cols = V.permute(0,2,1).reshape(n_batch,Nc,L,1)
        BV = torch.matmul(self.B, residual_cols)                       # [batch, Nc, N, 1]
        BV = BV.permute(0,2,1,3).reshape(n_batch,N,Nc)

        R = H_res + BV
        sigma_t2 = torch.sum(torch.abs(V*V),[1,2])/(L*Nc)
        H_res,b = eta(R, sigma_t2, Nc, N, L, self.theta)

        # Re-project the estimate into measurement space: [batch, L, Nc].
        estimate_cols = H_res.permute(0,2,1).reshape(n_batch,Nc,N,1)
        HH = torch.matmul(Theta, estimate_cols)                        # [batch, Nc, L, 1]
        HH = HH.permute(0,2,1,3).reshape(n_batch,L,Nc)

        V = Y-HH+b.reshape(-1,1,1)*V

        return H_res,V

In [None]:
def CE_LAMP_GMMV(parm_set, L, pho):
    """Train the GMMV-LAMP channel estimator one layer at a time.

    For every layer index T in ``range(T_num)`` a new layer is created (on top
    of the checkpoint saved for the previous layer when T > 0), trained for a
    fixed number of steps on freshly simulated channels, evaluated every 100
    steps, and checkpointed whenever the evaluation NMSE improves.

    Args:
        parm_set: sequence ``[Nc, Nt, Nr, snr, K]``: subcarriers, transmit
            antennas, receive antennas (unused here), linear SNR, and the
            inner dimension of the pilot product.
        L: number of pilot slots (observation dimension).
        pho: oversampling factor of the angular dictionary.

    Notes:
        * Reads the module-level names ``N`` (paths), ``SNR_dB`` and ``T_num``
          and the helpers ``pi``, ``sqrt``, ``kron_add``, ``SV_Channel`` and
          ``NMSE``, none of which are parameters.
        * The dictionary and the simulated channels use a hard-coded 8x8
          array, so ``Nt`` is presumably expected to be 64 -- TODO confirm.
        * Checkpoints are written to / read from ``./models/AE_models/``.
    """
    # Phi is the measurement matrix, Theta = Phi @ A_t the sensing matrix.
    # Both Theta and Y are divided by sqrt(factor).
    # NOTE(review): the rationale for 100/32 is not visible in this file.
    factor = 100*L/32
    Step = 3000
    batch = 64

    Nc  = parm_set[0]
    Nt  = parm_set[1]
    snr = parm_set[3]
    K   = parm_set[4]
    sigma = 1/snr            # noise variance
    N_grid = Nt*pho*pho      # number of dictionary atoms

    # Oversampled angular dictionary A_t: [Nt, (8*pho)^2].
    N_2 = pho*8
    A = torch.zeros(Nt,N_2*N_2).cuda() + 0j
    # torch.range is deprecated; arange with the default float dtype gives the
    # same values 0..7.
    antenna_idx = torch.arange(8, dtype=torch.get_default_dtype()).reshape(1,8).cuda()
    for i in range(N_2):
        for j in range(N_2):
            a_h = antenna_idx/8/pho*i*2*pi
            a_v = antenna_idx/8/pho*j*2*pi
            a = kron_add(a_h,a_v).reshape(Nt)
            A[:,i*N_2+j] = torch.exp(1j*a)
    A_t = A

    def simulate(S_analog, S_digital, H=None):
        """Build Theta from the pilots and, when H is given, the noisy observation Y.

        Returns (Y, Theta); Y is None when no channel is supplied.
        """
        Phi = torch.zeros(Nc,L,Nt).cuda() + 0j
        Y = None if H is None else torch.zeros(H.shape[0],Nc,L).cuda() + 0j  # [batch, Nc, L]
        for l in range(L):
            # Pilot of slot l: [Nc, Nt, 1]
            s = (torch.exp(1j*S_analog[l,:,:,:]) @ S_digital[l,:,:,:]).reshape(Nc,Nt,1)
            # Rescale so the pilot norm does not exceed sqrt(Nc).
            s_sigma = torch.sqrt(torch.sum(torch.abs(s*s)))
            s = s/s_sigma.reshape(1,1,1)*torch.min(s_sigma,torch.tensor(sqrt(Nc))).reshape(1,1,1)
            a = s.permute(0,2,1)  # [Nc, 1, Nt]
            Phi[:,l,:] = a[:,0,:]
            if H is not None:
                n = (torch.randn(H.shape[0],Nc) + 1j*torch.randn(H.shape[0],Nc)).cuda()*sqrt(sigma/2)
                Y[:,:,l] = (a @ H.reshape(-1,Nc,Nt,1)).reshape(-1,Nc) + n  # [batch, Nc]
        Theta = (Phi @ A_t)/sqrt(factor)
        if Y is not None:
            Y = Y/sqrt(factor)
        return Y, Theta

    def reconstruct(net, Y, Theta):
        """Run the network and map its angular estimate back to [batch, Nc, Nt]."""
        H_t_hat,_ = net(Y, Nc, N_grid, L, Theta)  # [batch, N_grid, Nc]
        H_t_hat = H_t_hat.permute(0,2,1).reshape(Y.shape[0],Nc,N_grid,1)
        return (A_t @ H_t_hat).reshape(-1,Nc,Nt)

    # Random pilots for L slots (frequency-selective digital part). The analog
    # part must share its last dimension with S_digital's K for the matmul in
    # simulate(), so it is sized by K rather than a literal 4.
    S_digital = torch.randn(L,Nc,K,1).cuda()+1j*torch.randn(L,Nc,K,1).cuda()
    S_analog = torch.rand(L,1,Nt,K).cuda() * 2*pi
    _, Theta = simulate(S_analog, S_digital)

    ckpt_dir = './models/AE_models/'
    tag = str(L)+'pilots_'+str(N)+'paths_'+str(SNR_dB)+'SNR'+'.pth'

    for T in range(T_num):
        if T==0:
            net = LAMP_Net(Theta, Nc, N_grid, L).cuda()
        else:
            # NOTE(review): torch.load unpickles a whole nn.Module here; only
            # load checkpoints produced by this notebook. Recent PyTorch
            # releases may require weights_only=False for this to work.
            net_1 = torch.load(ckpt_dir+'GMMV_LAMP'+str(T)+'_'+tag)
            net = LAMP_MuNet(Theta, Nc, N_grid, L, net_1).cuda()
            S_digital = torch.load(ckpt_dir+'S_digital'+tag)
            S_analog  = torch.load(ckpt_dir+'S_analog'+tag)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),lr=0.001)

        num_train = 0
        train_nmse = 0
        best_nmse = 2
        start = datetime.datetime.now()

        for step in range(Step):
            net.train()
            H = SV_Channel(batch,Nc,N,[8,8],1)
            Y, Theta = simulate(S_analog, S_digital, H)
            loss = NMSE(reconstruct(net, Y, Theta),H)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_train = num_train + 1
            # Accumulate a detached copy so the running sum does not keep
            # autograd-attached tensors alive between steps.
            train_nmse = train_nmse + loss.detach()

            if (step+1)%100==0:
                net.eval()
                with torch.no_grad():
                    H = SV_Channel(batch,Nc,N,[8,8],1)
                    Y, Theta = simulate(S_analog, S_digital, H)
                    loss = NMSE(reconstruct(net, Y, Theta),H)

                    train_nmse = train_nmse/num_train
                    time0 =  datetime.datetime.now()-start
                    print('step:',step,'time',time0,'train NMSE %.5f' % train_nmse.cpu(),'test NMSE %.5f' % loss.cpu())

                    if loss < best_nmse:
                        best_nmse = loss
                        torch.save(net, ckpt_dir+'GMMV_LAMP'+str(T+1)+'_'+tag)
                        torch.save(S_digital, ckpt_dir+'S_digital'+tag)
                        torch.save(S_analog, ckpt_dir+'S_analog'+tag)
                        print('Model saved!')
                    num_train = 0
                    train_nmse = 0
                    start = datetime.datetime.now()

In [None]:
def LAMP_GMMV_test(parm_set, L, pho, H):
    """Estimate channels with the trained ``T_num``-layer GMMV-LAMP network.

    Loads the final checkpoint and the pilots saved during training, simulates
    noisy pilot observations of ``H`` and returns the network's estimate.

    Args:
        parm_set: sequence ``[Nc, Nt, Nr, snr, K]``; only ``Nc``, ``Nt`` and
            the linear ``snr`` are used here.
        L: number of pilot slots (observation dimension).
        pho: oversampling factor of the angular dictionary.
        H: true channels, reshaped internally to [batch, Nc, Nt, 1].

    Returns:
        Estimated channels of shape [batch, Nc, Nt].

    Notes:
        Reads the module-level names ``N``, ``SNR_dB`` and ``T_num`` to build
        the checkpoint paths, and uses ``pi``, ``sqrt`` and ``kron_add`` from
        outside this cell. The dictionary assumes an 8x8 array, so ``Nt`` is
        presumably 64 -- TODO confirm.
    """
    # Phi is the measurement matrix, Theta = Phi @ A_t the sensing matrix.
    factor = 100*L/32  # same scaling as in training
    T = T_num

    Nc  = parm_set[0]
    Nt  = parm_set[1]
    snr = parm_set[3]
    sigma = 1/snr            # noise variance
    N_grid = Nt*pho*pho      # number of dictionary atoms

    # Oversampled angular dictionary A_t: [Nt, (8*pho)^2].
    N_2 = pho*8
    A = torch.zeros(Nt,N_2*N_2).cuda() + 0j
    # torch.range is deprecated; arange with the default float dtype gives the
    # same values 0..7.
    antenna_idx = torch.arange(8, dtype=torch.get_default_dtype()).reshape(1,8).cuda()
    for i in range(N_2):
        for j in range(N_2):
            a_h = antenna_idx/8/pho*i*2*pi
            a_v = antenna_idx/8/pho*j*2*pi
            a = kron_add(a_h,a_v).reshape(Nt)
            A[:,i*N_2+j] = torch.exp(1j*a)
    A_t = A

    with torch.no_grad():
        ckpt_dir = './models/AE_models/'
        tag = str(L)+'pilots_'+str(N)+'paths_'+str(SNR_dB)+'SNR'+'.pth'
        # NOTE(review): torch.load unpickles a whole nn.Module here; only load
        # checkpoints produced by this notebook. Recent PyTorch releases may
        # require weights_only=False for this to work.
        net = torch.load(ckpt_dir+'GMMV_LAMP'+str(T)+'_'+tag)
        net.eval()
        # The pilots always come from the checkpoint, so no random pilots are
        # drawn first.
        S_digital = torch.load(ckpt_dir+'S_digital'+tag)
        S_analog  = torch.load(ckpt_dir+'S_analog'+tag)

        Y = torch.zeros(H.shape[0],Nc,L).cuda() + 0j  # all observations [batch, Nc, L]
        Phi = torch.zeros(Nc,L,Nt).cuda() + 0j
        for l in range(L):
            # Pilot of slot l: [Nc, Nt, 1]
            s = (torch.exp(1j*S_analog[l,:,:,:]) @ S_digital[l,:,:,:]).reshape(Nc,Nt,1)
            # Rescale so the pilot norm does not exceed sqrt(Nc).
            s_sigma = torch.sqrt(torch.sum(torch.abs(s*s)))
            s = s/s_sigma.reshape(1,1,1)*torch.min(s_sigma,torch.tensor(sqrt(Nc))).reshape(1,1,1)
            n = (torch.randn(H.shape[0],Nc) + 1j*torch.randn(H.shape[0],Nc)).cuda()*sqrt(sigma/2)

            a = s.permute(0,2,1)  # [Nc, 1, Nt]
            Y[:,:,l] = (a @ H.reshape(-1,Nc,Nt,1)).reshape(-1,Nc) + n  # [batch, Nc]
            Phi[:,l,:] = a[:,0,:]
        Theta = (Phi @ A_t)/sqrt(factor)
        Y = Y/sqrt(factor)

        H_t_hat,V = net(Y, Nc, N_grid, L, Theta)  # [batch, N_grid, Nc]
        H_t_hat = H_t_hat.permute(0,2,1).reshape(Y.shape[0],Nc,N_grid,1)
        H_hat = A_t @ H_t_hat
        H_hat = H_hat.reshape(-1,Nc,Nt)
    return H_hat

In [None]:
# Experiment configuration. N, SNR_dB and T_num are read as module-level names
# inside CE_LAMP_GMMV / LAMP_GMMV_test, so they must be set before the call below.
Nc = 64  # number of subcarriers
N = 2  # number of propagation paths
Nt = 64  # number of transmit antennas (the functions above hard-code an 8x8 array)
Nr = 1  # number of receive antennas (passed in parm_set but not used by the training code)

T_num = 10  # number of LAMP layers

L = 32  # number of pilot slots (observation dimension)
SNR_dB = 10
K = 4  # inner dimension of the analog/digital pilot product
snr =  10**(SNR_dB/10)  # linear SNR
pho = 2  # oversampling factor of the angular dictionary
parm_set = [Nc,Nt,Nr,snr,K]
CE_LAMP_GMMV(parm_set, L, pho)

### Proposed JSCBF

In [8]:
# Scale-table bounds, taken from Balle's tensorflow compression examples.
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return ``levels`` values spaced geometrically from ``min`` to ``max`` (1-D tensor)."""
    log_low = math.log(min)
    log_high = math.log(max)
    # Uniform spacing in the log domain becomes geometric spacing after exp().
    return torch.linspace(log_low, log_high, levels).exp()

class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are left as None (or any other falsy value).
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # One dropout module is shared by both dropout positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

def window_partition(x, window_size):
    """Cut a (B, H, W, C) map into non-overlapping square windows.

    Returns a tensor of shape (B * num_windows, window_size, window_size, C).
    H and W are expected to be multiples of ``window_size``.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes together before flattening them into the batch.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)


def window_reverse(windows, window_size, H, W):
    """Reassemble (B * num_windows, ws, ws, C) windows into a (B, H, W, C) map."""
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    grid = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave window rows/columns back into image rows/columns.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)

class WindowAttention(nn.Module):
    """Multi-head self-attention inside one window, with a learned relative position bias.

    ``window_size`` is a (Wh, Ww) pair; ``qk_scale`` overrides the default
    ``head_dim ** -0.5`` scaling when it is truthy.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        win_h, win_w = window_size[0], window_size[1]

        # One bias per head for every possible relative offset: (2*Wh-1)*(2*Ww-1), nH
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))

        # Index into the table for every ordered pair of tokens in the window.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # make both offsets non-negative
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1  # row offset becomes the major index
        self.register_buffer("relative_position_index", offsets.sum(-1))  # Wh*Ww, Wh*Ww

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        heads = self.num_heads
        qkv = self.qkv(x).reshape(B_, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = (q * self.scale) @ k.transpose(-2, -1)

        tokens = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        scores = scores + bias.unsqueeze(0)

        if mask is not None:
            # Add the per-window mask, broadcasting over batch and heads.
            nW = mask.shape[0]
            scores = scores.view(B_ // nW, nW, heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, heads, N, N)

        weights = self.attn_drop(self.softmax(scores))

        out = (weights @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(out))


class SwinTransformerBlock(nn.Module):
    """Swin block: (shifted-)window attention followed by an MLP, each with a residual.

    The spatial size is not passed to ``forward``; the caller must set the
    ``H`` and ``W`` attributes beforehand. ``inverse`` is accepted but unused.
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, inverse=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        # Filled in by the owning layer before every forward pass.
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """x: (B, H*W, C); mask_matrix: attention mask used only when shift_size > 0."""
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        ws = self.window_size
        shift = self.shift_size
        use_shift = shift > 0

        residual = x
        x = self.norm1(x).view(B, H, W, C)

        # Pad right/bottom so both spatial sizes are multiples of the window size.
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        Hp, Wp = x.shape[1], x.shape[2]

        # Cyclic shift (SW-MSA) or plain windows (W-MSA).
        if use_shift:
            x = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
        attn_mask = mask_matrix if use_shift else None

        # Attention inside each window: nW*B, ws*ws, C
        windows = window_partition(x, ws).view(-1, ws * ws, C)
        windows = self.attn(windows, mask=attn_mask)

        # Stitch the windows back together: B, Hp, Wp, C
        x = window_reverse(windows.view(-1, ws, ws, C), ws, Hp, Wp)

        if use_shift:
            x = torch.roll(x, shifts=(shift, shift), dims=(1, 2))

        # Drop the padding again.
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # Residual connections around attention and MLP.
        x = residual + self.drop_path(x)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class PatchMerging(nn.Module):
    """Downsample by 2: concatenate each 2x2 neighbourhood, normalise, project 4*dim -> 2*dim."""
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        Returns:
            Tensor of size (B, ceil(H/2)*ceil(W/2), 2*C).
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        fmap = x.view(B, H, W, C)

        # Odd sizes get one extra row/column of zeros so every 2x2 cell is complete.
        if H % 2 == 1 or W % 2 == 1:
            fmap = F.pad(fmap, (0, 0, 0, W % 2, 0, H % 2))

        # Order: (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1).
        quadrants = [fmap[:, r::2, c::2, :] for c in (0, 1) for r in (0, 1)]
        merged = torch.cat(quadrants, -1).view(B, -1, 4 * C)  # B, H/2*W/2, 4*C

        return self.reduction(self.norm(merged))


class PatchSplit(nn.Module):
    """ Patch Splitting Layer (upsampling counterpart of patch merging).

    Normalises the tokens, projects them from ``dim`` to ``2*dim`` channels and
    applies a pixel shuffle with factor 2, giving four times as many tokens
    with ``dim // 2`` channels each.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(dim, dim * 2, bias=False)
        self.norm = norm_layer(dim)
        self.shuffle = nn.PixelShuffle(2)

    def forward(self, x, H, W):
        """x: (B, H*W, C) -> (B, 4*H*W, C//2)."""
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        expanded = self.reduction(self.norm(x))                           # B, L, 2C
        fmap = expanded.permute(0, 2, 1).contiguous().view(B, 2*C, H, W)  # channels first
        upsampled = self.shuffle(fmap)                                    # B, C//2, 2H, 2W
        return upsampled.permute(0, 2, 3, 1).contiguous().view(B, 4 * L, -1)

class BasicLayer(nn.Module):
    """One Swin stage: ``depth`` blocks alternating plain and shifted windows,
    optionally followed by a resampling module (``downsample``).

    ``drop_path`` may be a single rate or a list with one rate per block.
    """
    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 inverse=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Even-indexed blocks use plain windows, odd-indexed ones shifted windows.
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path
            block_shift = 0 if (i % 2 == 0) else window_size // 2
            self.blocks.append(
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=block_shift,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=block_drop_path,
                    norm_layer=norm_layer,
                    inverse=inverse))

        # Optional resampling module applied after the blocks.
        self.downsample = downsample(dim=dim, norm_layer=norm_layer) if downsample is not None else None

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        Returns:
            (features, H_out, W_out) with the resolution after resampling.
        """
        ws = self.window_size
        shift = self.shift_size

        # Label the regions of the padded map so that tokens brought together
        # only by the cyclic shift can be masked out in SW-MSA.
        Hp = int(np.ceil(H / ws)) * ws
        Wp = int(np.ceil(W / ws)) * ws
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
        bands = (slice(0, -ws),
                 slice(-ws, -shift),
                 slice(-shift, None))
        for region, (h, w) in enumerate((h, w) for h in bands for w in bands):
            img_mask[:, h, w, :] = region

        mask_windows = window_partition(img_mask, ws).view(-1, ws * ws)  # nW, ws*ws
        label_diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = label_diff.masked_fill(label_diff != 0, float(-100.0)).masked_fill(label_diff == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            x = blk(x, attn_mask)

        if self.downsample is None:
            return x, H, W

        x_down = self.downsample(x, H, W)
        if isinstance(self.downsample, PatchMerging):
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
        elif isinstance(self.downsample, PatchSplit):
            Wh, Ww = H * 2, W * 2
        return x_down, Wh, Ww


class PatchEmbed(nn.Module):
    """Image-to-patch embedding: a strided convolution with kernel = stride = patch size,
    optionally followed by a normalisation over the embedding dimension."""
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """x: (B, in_chans, H, W) -> (B, embed_dim, Wh, Ww)."""
        _, _, H, W = x.size()
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]

        # Zero-pad right / bottom so both sizes are multiples of the patch size.
        if W % patch_w != 0:
            x = F.pad(x, (0, patch_w - W % patch_w))
        if H % patch_h != 0:
            x = F.pad(x, (0, 0, 0, patch_h - H % patch_h))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is None:
            return x

        # The norm acts on the channel axis, so move it last and restore afterwards.
        Wh, Ww = x.size(2), x.size(3)
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

In [9]:
class SwinEnc(nn.Module):
    """Swin-transformer encoder followed by a small convolutional head.

    The image is patch-embedded, passed through ``len(depths)`` Swin stages
    (every stage but the last halves the resolution and doubles the channel
    count) and finally mapped by ``h_a`` to ``Nc`` channels.

    Notes:
        * The output channel count is read from the module-level name ``Nc``
          when the encoder is constructed; it is not a constructor argument.
        * The width of the last stage is derived from ``embed_dim`` and
          ``len(depths)`` instead of being hard-coded, so non-default
          configurations produce consistent shapes. With the defaults it is 384.
        * Shape comments below assume the default arguments and a
          256x256 input.
    """
    def __init__(self,
                 pretrain_img_size=256,
                 patch_size=2,
                 in_chans=3,
                 embed_dim=48,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=4,
                 num_slices=12,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_slices = num_slices
        self.max_support_slices = num_slices // 2
        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: the rate grows linearly over all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),  # channels entering this stage
                depth=depths[i_layer],              # number of blocks
                num_heads=num_heads[i_layer],       # attention heads
                window_size=window_size,            # window side length
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                # Patch merging runs after the blocks of a stage; the last stage has none.
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                inverse=False)
            self.layers.append(layer)

        # Channel count after the last stage (doubles at every merging step).
        final_dim = int(embed_dim * 2 ** (self.num_layers - 1))

        # Convolutional head reducing the features to Nc channels.
        self.h_a = nn.Sequential(
            conv3x3(final_dim, final_dim),
            nn.GELU(),
            conv3x3(final_dim, 288),
            nn.GELU(),
            conv3x3(288, Nc),
        )

    def forward(self, x):   # x: [batch, 3, 256, 256]
        x = self.patch_embed(x)   # [batch, 48, 128, 128]
        if self.training:
            x = self.pos_drop(x)  # skipped outside training
        Wh, Ww = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # [batch, 128*128, 48]
        for i in range(self.num_layers):  # Swin stages with patch merging
            layer = self.layers[i]
            x, Wh, Ww = layer(x, Wh, Ww)

        # Back to a channels-first feature map: [batch, 384, 16, 16]
        C = int(self.embed_dim * 2 ** (self.num_layers - 1))
        y = x.view(-1, Wh, Ww, C).permute(0, 3, 1, 2).contiguous()

        z = self.h_a(y)
        return z  # [batch, Nc, 16, 16]
    
class SwinDec(nn.Module):
    """Swin-style decoder: latent -> ``h_b`` conv stack -> inverse transformer stages -> image.

    The stages are built from ``BasicLayer`` with ``inverse=True`` and use
    ``PatchSplit`` between stages (none after the last one). ``end_conv``
    performs the final de-embedding with a ``PixelShuffle`` by ``patch_size``.

    Note: ``h_b`` reads the module-level name ``Nc`` for its input channel
    count, so ``Nc`` must be defined before the class is instantiated.
    """

    def __init__(self,
                 pretrain_img_size=256,
                 patch_size=2,
                 in_chans=3,
                 embed_dim=48,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=4,
                 num_slices=12,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_slices = num_slices
        self.max_support_slices = num_slices // 2

        # Stochastic-depth decay rule: one drop-path rate per transformer block.
        dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        # The decoder walks the stage configuration in reverse order.
        stage_depths = depths[::-1]
        stage_heads = num_heads[::-1]
        final_stage = self.num_layers - 1

        self.syn_layers = nn.ModuleList()
        block_begin = 0
        for i_layer in range(self.num_layers):
            block_end = block_begin + stage_depths[i_layer]
            self.syn_layers.append(
                BasicLayer(
                    dim=int(embed_dim * 2 ** (3-i_layer)),
                    depth=stage_depths[i_layer],
                    num_heads=stage_heads[i_layer],
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[block_begin:block_end],
                    norm_layer=norm_layer,
                    # No patch split after the last stage.
                    downsample=PatchSplit if i_layer < final_stage else None,
                    use_checkpoint=use_checkpoint,
                    inverse=True,
                )
            )
            block_begin = block_end

        # Final de-embedding; the original author notes this restores
        # [batch,48,128,128] to [batch,3,256,256].
        self.end_conv = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim * patch_size ** 2, kernel_size=5, stride=1, padding=2),
            nn.PixelShuffle(patch_size),
            nn.Conv2d(embed_dim, 3, kernel_size=3, stride=1, padding=1),
        )

        self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.g_a = None
        self.g_s = None

        # Conv stack that lifts the Nc-channel latent back to 384 channels.
        self.h_b = nn.Sequential(
            conv3x3(Nc, 288),
            nn.GELU(),
            conv3x3(288, 384),
            nn.GELU(),
            conv3x3(384, 384),
        )

    def forward(self, z):
        """Decode a latent (author's note: [batch,64,16,16]) into an image batch.

        In evaluation mode the output is clamped in place to [0, 1]; in
        training mode it is returned unclamped.
        """
        y_hat = self.h_b(z)  # author's note: [batch,384,16,16]
        Wh, Ww = y_hat.shape[2], y_hat.shape[3]
        C = self.embed_dim * 8

        # [B, C, H, W] -> [B, H*W, C] token sequence for the transformer stages.
        tokens = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, Wh * Ww, C)
        for i_layer in range(self.num_layers):
            tokens, Wh, Ww = self.syn_layers[i_layer](tokens, Wh, Ww)

        # Back to channel-first layout, then de-embed to an RGB image.
        feature_map = tokens.view(-1, Wh, Ww, self.embed_dim).permute(0, 3, 1, 2).contiguous()
        x_hat = self.end_conv(feature_map)
        if not self.training:
            # Validation only: squash the reconstruction into [0, 1].
            x_hat = x_hat.clamp_(0, 1)
        return x_hat

class SwinAuto(nn.Module):
    """Swin auto-encoder: ``SwinEnc`` -> linear bottleneck (256 -> 2*Q -> 256) -> ``SwinDec``.

    Encoder and decoder receive the same constructor arguments, in the same
    positional order. The bottleneck width ``2*Q`` and the latent channel
    count ``Nc`` are read from module-level names, so both must be defined
    before the model is built or run.
    """

    def __init__(self,
                 pretrain_img_size=256,
                 patch_size=2,
                 in_chans=3,
                 embed_dim=48,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=4,
                 num_slices=12,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 use_checkpoint=False):
        super().__init__()

        # Shared positional configuration for both halves of the auto-encoder.
        shared_config = (
            pretrain_img_size, patch_size, in_chans, embed_dim, depths, num_heads,
            window_size, num_slices, mlp_ratio, qkv_bias, qk_scale, drop_rate,
            attn_drop_rate, drop_path_rate, norm_layer, patch_norm, use_checkpoint,
        )
        self.enc = SwinEnc(*shared_config)
        self.dec = SwinDec(*shared_config)

        # Per-channel bottleneck over the flattened 16x16 = 256 spatial positions.
        bottleneck_width = 2*Q
        self.embed1 = nn.Linear(256, bottleneck_width)
        self.embed2 = nn.Linear(bottleneck_width, 256)

    def forward(self, x):
        """Reconstruct ``x`` (author's note: [batch,3,256,256]) through the bottleneck."""
        latent = self.enc(x).reshape(-1, Nc, 256)        # [batch,C_dim,256]
        latent = self.embed2(self.embed1(latent))        # 256 -> 2*Q -> 256
        latent = latent.reshape(-1, Nc, 16, 16)          # [batch,C_dim,16,16]
        return self.dec(latent)

In [10]:
import torch.nn.functional as F

class Pre_Enc(nn.Module):
    """Transmit-side network: fuses symbols with a channel embedding, precodes them and simulates the channel.

    forward() combines two precoding branches and passes their sum through the
    channel with additive complex Gaussian noise:

    * a data-driven branch (``DDDL``), scaled by the learnable ``beta1``;
    * a model-driven branch whose precoder is built in closed form from the
      outputs of ``MDDL`` and the channel ``H0``, scaled by ``beta2``.

    The sub-network sizes are taken from the module-level names ``K``, ``Nt``
    at construction time, and forward() additionally reads the module-level
    ``Q``; these must be defined before use.
    """
    def __init__(self):
        """Create the four TRANS_BLOCK sub-networks and the two branch gains."""
        super().__init__()



        # Embeds the real/imag-stacked channel (last dim 2*K*Nt in forward) into
        # 64 features. NOTE(review): the meaning of the last two TRANS_BLOCK
        # arguments is not visible here — confirm against its definition.
        self.HEmbed = TRANS_BLOCK(2*K*Nt,64,256,4)
        # self.crossAttention = CrossAttentionFusion(128,2*Q*K,2*Q*K)
        self.Fuse = TRANS_BLOCK(2*K+64,2*K,128,4)  # used in the "semantic fusion" step of forward (original comment: data-driven precoding)

        
        self.DDDL = TRANS_BLOCK(2*K+64,2*Nt,128,4)  # data-driven precoding

        self.MDDL  = TRANS_BLOCK(64,4*K+1,256,1) # Transformer that generates the required parameters on each subcarrier

        # self.Imagembed = nn.Linear(2*Q*K,8) # model-driven precoding relies mainly on channel semantics; embed the image semantics first

        # self.linear_H  = nn.Linear(Nt*2, 32)

        # beta1 starts at 0, so the data-driven branch initially contributes
        # nothing; beta2 starts at 1 for the model-driven branch.
        self.beta1 = torch.nn.Parameter(torch.zeros(1))
        self.beta2 = torch.nn.Parameter(torch.ones(1))
        # self.beta3 = torch.nn.Parameter(torch.ones(1))
        # self.beta4 = torch.nn.Parameter(torch.zeros(1))
        

  



    def forward(self, z, parm_set, H, H0):   #z[batch*K,C_dim,2*Q] 
        """Precode ``z`` and return the noisy received signal.

        Args:
            z: symbols to transmit; author's shape note: [batch*K, C_dim, 2*Q].
            parm_set: sequence read as [Nc, Nt, Nr, snr, K].
            H: channel applied to the transmitted signal; author's shape note:
                [batch, K, Nc, 1, Nt].
            H0: channel used for the embedding and for building the precoder
                (train_E2E passes the output of LAMP_GMMV_test here).

        Returns:
            Complex tensor; author's shape note: [batch, C_dim, K, Q].
        """
        #H [batch,K,Nc,1,Nt] 
        Nc  = parm_set[0]
        Nt  = parm_set[1]
        Nr  = parm_set[2]  # read but not used below
        snr = parm_set[3]
        K   = parm_set[4]
        sigma = 1/snr  # noise power for unit signal power

        batch = H.shape[0]

        
        
        #################### Build the noise and the equivalent channels, normalising both #######################
        # Re-pack H and H0 so the user index sits on dim 2: [batch,Nc,K,Nt].
        H_equ = (torch.zeros(batch,Nc,K,Nt) + 0j).cuda()
        for i in range(K):
            H_equ[:,:,i] = (H[:,i] ).reshape(-1,Nc,Nt)

        H_equ0 = (torch.zeros(batch,Nc,K,Nt) + 0j).cuda()
        for i in range(K):
            H_equ0[:,:,i] = (H0[:,i] ).reshape(-1,Nc,Nt)

        # Both channels are normalised by the per-subcarrier power of H_equ
        # (not of H_equ0), then scaled by sqrt(Nt).
        power = torch.sum(torch.abs(H_equ*H_equ),[2,3])
        H_equ = H_equ/torch.sqrt(power.reshape(-1,Nc,1,1))*sqrt(Nt)
        H_equ0 = H_equ0/torch.sqrt(power.reshape(-1,Nc,1,1))*sqrt(Nt)
        ## The noise is scaled the same way so the SNR stays consistent. z needs no such
        ## treatment because (per the author) it is power-normalised to K later anyway,
        ## so only the channel and the noise are handled here.
        noise = (torch.randn(batch,Nc,K,Q).cuda() + 1j*torch.randn(batch,Nc,K,Q).cuda())/torch.sqrt(power.reshape(-1,Nc,1,1))*sqrt(sigma/2)*sqrt(Nt)

         #################### Semantic fusion #######################
        z = z.reshape(-1,Nc,2,Q)#[batch,64,2,128]
        
        z = z.reshape(-1,K,Nc,2,Q)#[4,K,64,2,128]
        z = z.permute(0,4,2,1,3) #[4,Q,64,K,2]
        
        z = z.reshape(-1,Nc,2*K) #[4Q,64,2K]
        

        # z = z.reshape(-1,Nc,K*2*Q) #[4,64,K*2*Q]

        # Channel embedding: stack real and imaginary parts of H_equ0 and embed.
        HH = H_equ0.reshape(-1,Nc,K*Nt) #[4,64,K*K]
        HH = torch.cat((torch.real(HH),torch.imag(HH)), 2)#[4,64,2*K*Nt]
        HH = self.HEmbed(HH)#[4,64,64]

        HH_ini = HH + 0  # copy kept for the model-driven branch (MDDL input)

        # Repeat the channel embedding for each of the Q positions of z.
        HH = HH.reshape(-1,1,Nc,64)#[4,1,64,64]
        HH = HH.repeat(1, Q, 1, 1) #[4,Q,64,64]
        HH = HH.reshape(-1,Nc,64)  #[4*Q,64,64]
       
        z = torch.cat((z, HH), 2) #[batch*Q, 64, 2*K+64]

        z_fuse = self.Fuse(z) #[batch*Q, 64, 2*K]



        z_ini = z_fuse + 0 #[batch*Q, 64, 2*K]
        z_fuse = z_fuse.reshape(-1,Q,Nc,K,2)#[4,Q,64,K,2]
        z_fuse = z_fuse.permute(0,2,3,4,1) #[4,64,K,2,Q]
        # Last-but-one axis holds (real, imag); combine into complex symbols.
        z_model = z_fuse[:,:,:,0,:] + 1j*z_fuse[:,:,:,1,:] #[4,64,K,Q]


        ################### Data-driven precoding #######################
        z = torch.cat((z_ini, HH), 2) #[batch*Q, 64, 2*K+64]
        z = self.DDDL(z) * self.beta1[0]  #[batch*Q, 64, 2*Nt]
        z = z.reshape(batch,Q,Nc,Nt,2)#[4,Q,64,Nt,2]
        z = z.permute(0,2,3,4,1) #[4,64,Nt,2,Q]

        z = z[:,:,:,0,:] + 1j*z[:,:,:,1,:] #[4,64,Nt,Q]
        # Transmit power of the data-driven branch, per subcarrier.
        Power_1 = (torch.sum(torch.abs(z)**2,[2,3])).reshape(-1,Nc,1,1)


        ##################### Model-driven precoding ###########################
        # z_embed = self.Imagembed(z_ini) #[batch, 64, 8]
        # z_cat = torch.cat((z_embed, HH), 2) #[batch, 64, 8+128=136]
        # MDDL outputs 4*K+1 values per subcarrier: real/imag of pri1 (K each),
        # real/imag of pri2 (K each) and one scalar sigma_pri.
        x = self.MDDL(HH_ini)   #[batch,Nc,4*K+1]
        # The matrix inverse below is computed in complex128.
        H_hat = H_equ0.to(torch.complex128)
        pri1 = (x[:,:,0:K] + 1j*x[:,:,K:2*K]).to(torch.complex128)    #[batch,Nc,K]
        pri2 = (x[:,:,2*K:3*K] + 1j*x[:,:,3*K:4*K]).to(torch.complex128)   #[batch,Nc,K]
        sigma_pri = (x[:,:,4*K]).to(torch.complex128)  #[batch,Nc]
        # Accumulate B_pri = sum_k pri2_k * h_k h_k^H plus the sigma_pri * I terms.
        # NOTE(review): sigma_pri * I is added on every iteration, i.e. K times
        # in total — confirm this is intended.
        for k in range(K):
            hk = Hermitian(H_hat[:,:,k,:].reshape(batch,Nc,1,Nt))
            if(k==0):
                B_pri = pri2[:,:,k].reshape(batch,Nc,1,1) * hk @ Hermitian(hk)
                B_pri = B_pri + (sigma_pri.reshape(batch,Nc,1,1) * torch.eye(Nt).cuda().to(torch.complex128))
            else:
                B_pri = B_pri + pri2[:,:,k].reshape(batch,Nc,1,1) * hk @ Hermitian(hk)
                B_pri = B_pri + (sigma_pri.reshape(batch,Nc,1,1) * torch.eye(Nt).cuda().to(torch.complex128))
        V_pri = (torch.zeros((batch,Nc,Nt,K)) + 0j).cuda().to(torch.complex128)
        B_inv_pri = torch.inverse(B_pri)
        # Column k of the precoder: B_pri^{-1} (pri1_k * h_k).
        for k in range(K):
            hk = Hermitian(H_hat[:,:,k,:].reshape(batch,Nc,1,Nt))
            V_pri[:,:,:,k] = (B_inv_pri @ (pri1[:,:,k].reshape(batch,Nc,1,1) * hk)).reshape(batch,Nc,Nt)
        W_pri = V_pri
        F_BB = (W_pri + 0).to(torch.complex64)
        # Power = (torch.sum(torch.abs(F_BB)**2,[2,3])).reshape(-1,Nc,1,1)
        # F_BB = F_BB / torch.sqrt(Power) * sqrt(K)

        ##################### Transmission ###########################
        z_model = F_BB @ z_model * self.beta2[0]
        # Transmit power of the model-driven branch, per subcarrier.
        Power_2 = (torch.sum(torch.abs(z_model)**2,[2,3])).reshape(-1,Nc,1,1)


        noise = noise * torch.sqrt(Power_1+Power_2) / sqrt(Q)  # dividing the signal energy is equivalent to multiplying the noise

        # Both branches go through the channel H_equ (not H_equ0).
        z = H_equ @ z_model + H_equ @ z #[4,64,K,Q] complex
        z = z+noise


        # z_model = F_BB @ z_model * self.beta3[0]
        # Power_2 = (torch.sum(torch.abs(z_model)**2,[2,3])).reshape(-1,Nc,1,1)
        # noise = noise * torch.sqrt(Power_1+Power_2) / sqrt(Q)  # dividing the signal energy is equivalent to multiplying the noise
        # z = H_equ @ z_model + H_equ @ z #[4,64,K,Q] complex
        # z = z+noise

        return z #[batch,C_dim,K,32]


In [11]:
class Pre_Dec(nn.Module):
    """Receive-side network: turns the complex received signal into real features.

    The real and imaginary parts are concatenated along the last axis and
    passed through a single ``TRANS_BLOCK`` (``inEmbed``). ``Q`` and ``Nc`` are
    read from module-level names.
    """

    def __init__(self,
                 ):
        super().__init__()
        self.inEmbed = TRANS_BLOCK(2*Q,2*Q,1024,6)

    def forward(self, z):
        """Map the received signal to ``[-1, Nc, 2*Q]`` real features.

        Author's shape note for the input: [batch, C_dim, K, Q], complex.
        """
        feature_width = 2*Q

        # Bring the user axis forward: [batch,Nc,K,Q] -> [batch,K,Nc,Q].
        per_user = z.permute(0,2,1,3)
        # Real part first, imaginary part second, along the last axis.
        real_imag = torch.cat((torch.real(per_user),torch.imag(per_user)), 3)
        # Fold the user axis into the batch axis.
        tokens = real_imag.reshape(-1,Nc,feature_width)

        embedded = self.inEmbed(tokens)
        return embedded.reshape(-1,Nc,feature_width)

In [12]:
class Pre_Auto(nn.Module):
    """Precoding auto-encoder: ``Pre_Enc`` (precode + channel) followed by ``Pre_Dec``."""

    def __init__(self,):
        super().__init__()

        self.enc = Pre_Enc()
        self.dec = Pre_Dec()

    def forward(self, z, parm_set, H, H0):
        """Send ``z`` through the simulated link and return the decoded features.

        All four arguments are forwarded unchanged to ``self.enc``; its output
        is the only input of ``self.dec``.
        """
        return self.dec(self.enc(z, parm_set, H, H0))

In [13]:
class E2E_Sem(nn.Module):
    """End-to-end model: Swin image codec wrapped around the precoding link.

    ``self.swin`` is a ``SwinAuto`` and ``self.pre`` a ``Pre_Auto``. The latent
    channel count ``Nc`` is read from a module-level name.
    """

    def __init__(self,):
        super().__init__()
        self.swin = SwinAuto()
        self.pre = Pre_Auto()

    def forward(self, x, parm_set, H, H0):
        """Encode, transmit and reconstruct an image batch.

        Shape notes from the original author: x is [batch*K,3,256,256] and the
        result has the same shape.
        """
        codec, link = self.swin, self.pre

        # Image -> latent -> per-channel symbols of width 2*Q.
        latent = codec.enc(x).reshape(-1, Nc, 256)            # [batch*K,C_dim,256]
        symbols = codec.embed1(latent)                        # [batch*K,C_dim,2*Q]

        # Precoding + channel, then receive-side processing.
        received = link.enc(symbols, parm_set, H, H0)         # [batch,C_dim,K,Q]
        recovered = link.dec(received)                        # [batch*K,C_dim,2*Q]

        # Symbols -> latent -> image.
        latent_hat = codec.embed2(recovered).reshape(-1, Nc, 16, 16)  # [batch*K,C_dim,16,16]
        return codec.dec(latent_hat)                          # [batch*K,3,256,256]

In [21]:
import torch
import io
from pytorch_msssim import ms_ssim
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

import torch.utils.data as data
import PIL.Image as Image
from glob import glob




class readImageDataset(data.Dataset):
	"""Image-only dataset over the ``train`` / ``validation`` / ``test`` folders of ``./imageNet``.

	Each split is expected at ``<root>/<split>/data/*``. Items are RGB PIL
	images, passed through ``input_transform`` when one is given; there are no
	labels.

	Args:
		state: which split to serve: ``'train'``, ``'val'`` or ``'test'``.
		input_transform: optional callable applied to every loaded image.

	Raises:
		AssertionError: if ``state`` is not one of the three supported values.
	"""

	def __init__(self, state, input_transform = None):
		self.state = state
		# Built with os.path.join so the same layout is found on Windows and POSIX.
		self.root = os.path.join(".", "imageNet")       # contains the train, validation and test folders
		self.input_transform = input_transform
		# Declare the per-split lists first; getDataPath() fills them in.
		self.train_input_paths, self.val_input_paths, self.test_input_paths = None, None, None
		self.inputs = self.getDataPath()

	def _listSplit(self, split_dir):
		"""Return the sorted file paths found under ``<root>/<split_dir>/data``."""
		# Sorted so that indexing is reproducible; glob order is arbitrary.
		return sorted(glob(os.path.join(self.root, split_dir, 'data', '*')))

	def getDataPath(self):
		"""Scan all three split folders and return the path list for ``self.state``."""
		self.train_input_paths = self._listSplit('train')
		self.val_input_paths = self._listSplit('validation')
		self.test_input_paths = self._listSplit('test')
		paths_by_state = {
			'train': self.train_input_paths,
			'val': self.val_input_paths,
			'test': self.test_input_paths,
		}
		assert self.state in paths_by_state, "state must be 'train', 'val' or 'test'"
		# Dict lookup: an invalid state can never silently yield None.
		return paths_by_state[self.state]

	def __getitem__(self, index):
		input_path = self.inputs[index]
		# convert() returns an independent, fully loaded image, so the file
		# handle can be released as soon as the conversion is done.
		with Image.open(input_path) as source_image:
			image = source_image.convert('RGB')
		if self.input_transform is not None:
			image = self.input_transform(image)
		return image

	def __len__(self):
		return len(self.inputs)
import argparse
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms


#定义参数函数getArgs
def getArgs():
	"""Return the notebook's default run options as an ``argparse.Namespace``.

	The parser is always fed an empty argument list, so the result contains
	exactly the defaults declared below (``batch_size``, ``shuffle``,
	``num_workers``, ``dataset``) regardless of ``sys.argv``.
	"""
	option_specs = [
		("--batch_size", {"type": int, "default": 64}),
		("--shuffle", {"default": True}),
		("--num_workers", {"type": int, "default": 0}),
		("--dataset", {"default": 'colorImage', "help": 'colorImage'}),
	]
	parser = argparse.ArgumentParser()
	for flag, options in option_specs:
		parser.add_argument(flag, **options)
	# Explicit empty argv: inside a notebook sys.argv belongs to the kernel.
	return parser.parse_args(args=[])

def getDataset(args):
	"""Build the train / validation / test ``DataLoader`` objects for ``args.dataset``.

	Only ``args.dataset == 'colorImage'`` is supported; for any other value
	the function returns ``(None, None, None)``.

	* train: ``args.batch_size``, shuffled according to ``args.shuffle``
	  (defaults to ``True`` when the attribute is missing), ``drop_last=True``.
	* validation: ``args.batch_size``, fixed order, ``drop_last=True``.
	* test: batch size 1, fixed order.

	Validation and test are not shuffled, so with ``drop_last=True`` the same
	images are evaluated every epoch and the metrics stay comparable.

	Returns:
		Tuple ``(train_dataloaders, val_dataloaders, test_dataloaders)``.
	"""
	train_dataloaders, val_dataloaders, test_dataloaders = None, None, None
	if args.dataset == 'colorImage':
		shuffle_train = getattr(args, 'shuffle', True)

		train_dataset = readImageDataset(r'train', input_transform = input_transform)
		train_dataloaders = DataLoader(train_dataset, batch_size = args.batch_size, shuffle = shuffle_train, num_workers = args.num_workers, drop_last=True)

		val_dataset = readImageDataset(r'val', input_transform = input_transform)
		val_dataloaders = DataLoader(val_dataset, batch_size=args.batch_size, shuffle = False, num_workers = args.num_workers, drop_last=True)

		test_dataset = readImageDataset(r'test', input_transform = input_transform)
		test_dataloaders = DataLoader(test_dataset, batch_size=1, shuffle = False, num_workers = args.num_workers, drop_last=True)
	return train_dataloaders, val_dataloaders, test_dataloaders




# Preprocessing applied to every image: bicubic resize to 256x256, then
# conversion to a tensor. The interpolation is given as an InterpolationMode
# enum; torchvision deprecates passing the integer PIL constant.
input_transform = transforms.Compose([
	# transforms.Grayscale(num_output_channels=3),                   # 1 = grayscale, 3 = colour
	transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),    # interpolation used when resizing
	transforms.ToTensor()
])
# input_transform = None

args = getArgs()      # default run options
train_dataloaders, val_dataloaders, test_dataloaders = getDataset(args)  # loaders yield images only, no labels




In [22]:
import lpips
lpips_fn = lpips.LPIPS(net='vgg').cuda()  # or '.to(tensor1.device)' to use the same device as the tensors
def calculate_lpips(tensor1, tensor2, normalize=True):
    """
    Calculate the mean LPIPS distance between two images or two batches of images.

    Uses the module-level ``lpips_fn`` model and runs under ``torch.no_grad()``,
    so the returned value carries no gradient.

    Parameters:
    tensor1 (torch.Tensor): First image ([3, H, W]) or batch ([N, 3, H, W]), on a CUDA device.
    tensor2 (torch.Tensor): Second image or batch, same layout, on a CUDA device.
    normalize (bool): Forwarded to LPIPS. ``True`` (default) is for images in
        [0, 1], which LPIPS then rescales to its native [-1, 1] range; pass
        ``False`` for images that are already in [-1, 1].

    Returns:
    torch.Tensor: Scalar tensor, the LPIPS distance averaged over the batch.

    Raises:
    ValueError: If either tensor is not on a CUDA device.
    """

    # The shared LPIPS model lives on the GPU, so both inputs must as well.
    if not tensor1.is_cuda or not tensor2.is_cuda:
        raise ValueError("Both tensors need to be on a CUDA device.")

    # Add a batch dimension if not present
    if tensor1.ndim == 3:
        tensor1 = tensor1.unsqueeze(0)
    if tensor2.ndim == 3:
        tensor2 = tensor2.unsqueeze(0)

    # Calculate the LPIPS score
    with torch.no_grad():
        lpips_scores = lpips_fn(tensor1, tensor2, normalize=normalize)

    return lpips_scores.mean()


class RateDistortionLoss(nn.Module):
    """Scaled MSE distortion loss: ``lmbda * 256**2 * MSE(output, target)``.

    Despite the name there is no rate (bit-rate) term; only the distortion
    part is computed, weighted by the Lagrangian parameter ``lmbda``.

    Args:
        lmbda: weight applied to the scaled MSE.
    """

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lmbda = lmbda

    def forward(self, output, target):
        """Return the weighted MSE between ``output`` and ``target`` as a scalar tensor.

        The two tensors only need shapes accepted by ``nn.MSELoss``.
        """
        return self.lmbda * 256 ** 2 * self.mse(output, target)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: d:\Anaconda\envs\torch1_11_0\lib\site-packages\lpips\weights\v0.1\vgg.pth


In [25]:
def train_E2E(LR,EPOCH,model_name_swin,model_name_E2E,parm_set):
    """Train the full ``E2E_Sem`` model, starting from a pre-trained Swin codec.

    Author's note on batching: with 4 users and a batch of 16, each user gets
    4 images. All users share one network, so all images go through the
    encoder together and the output is split per user for precoding. Each
    loader batch is processed in ``K`` slices of ``64 // K`` images, with
    gradients accumulated over the slices and one optimizer step per batch.

    After every epoch the model is validated in eval mode. The best model by
    validation MS-SSIM is saved to ``model_name_E2E``, and a grid of the last
    validation batch is written under ``./logs/...`` and displayed.

    Args:
        LR: Adam learning rate.
        EPOCH: number of epochs.
        model_name_swin: path of a pickled model whose ``state_dict`` is
            loaded into ``net.swin``.
        model_name_E2E: path the best end-to-end model is saved to.
        parm_set: sequence read as [Nc, Nt, Nr, snr, K].

    Relies on module-level names: ``train_dataloaders``, ``val_dataloaders``,
    ``N``, ``L``, ``pho``, ``SNR_dB``, ``Q`` and the helpers ``SV_Channel``,
    ``LAMP_GMMV_test``, ``NMSE``. The slicing assumes loader batches of 64.
    """
    # The notebook only imports torchvision.transforms; this binds the
    # `torchvision` name used for save_image / make_grid below.
    import torchvision.utils

    Nc  = parm_set[0]
    Nt  = parm_set[1]
    K   = parm_set[4]
    batchsize = 64//K
    batch_num = K
    net=E2E_Sem()
    net.swin.load_state_dict(torch.load(model_name_swin).state_dict())
    net=net.cuda()



    optimizer = torch.optim.Adam(net.parameters(),lr=LR)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,  milestones = [25,50,75], gamma = 0.3, last_epoch=-1)
    criterion = RateDistortionLoss(lmbda=1e-2).cuda()
    best_msssim = 0
    num_train = 0
    train_nmse = 0
    for epoch in range(EPOCH):
        start = datetime.datetime.now()
        net.train() # training mode (validation below switches to eval)
        for b_x_set in train_dataloaders: # b_x_set is [64,3,256,256]
            optimizer.zero_grad()
            for i in range(batch_num): # one slice per iteration; gradients accumulate across slices
                # A fresh channel realisation and its estimate for every slice.
                H = SV_Channel(batchsize,Nc,N,[8,8],1)
                H0 = LAMP_GMMV_test(parm_set, L, pho, H)
                nmse = NMSE(H0,H)

                H = H.reshape(batchsize//K,K,Nc,1,Nt)
                H0 = H0.reshape(batchsize//K,K,Nc,1,Nt)

                b_x = b_x_set[i*batchsize:(i+1)*batchsize]
                b_x = b_x.cuda()

                out = net(b_x, parm_set, H, H0)
                ssim_train = ms_ssim(b_x, out, data_range=1, size_average=True )
                # NOTE(review): calculate_lpips runs under torch.no_grad(), so this
                # term is a constant offset in the loss and contributes no gradient.
                lpips_train = calculate_lpips(b_x, out)
                mse_train = criterion(out,b_x)

                loss_train = mse_train - ssim_train + lpips_train
                loss_train.backward()

                num_train = num_train + 1
                train_nmse = train_nmse + nmse.detach_()
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()
        with torch.no_grad():
            mse_val = 0
            ssim_val = 0
            num_val = 0
            psnr_val = 0
            lpips_val = 0
            # Evaluation mode for the whole validation pass: the model is not
            # switched back to training mode inside the loop.
            net.eval()
            for b_x_set in val_dataloaders:
                out_set = torch.zeros(batchsize*K,3,256,256).cuda()
                for i in range(batch_num):
                    H = SV_Channel(batchsize,Nc,N,[8,8],1)
                    H0 = LAMP_GMMV_test(parm_set, L, pho, H)
                    nmse = NMSE(H0,H)

                    H = H.reshape(batchsize//K,K,Nc,1,Nt)
                    H0 = H0.reshape(batchsize//K,K,Nc,1,Nt)

                    b_x = b_x_set[i*batchsize:(i+1)*batchsize]
                    b_x = b_x.cuda()

                    out = net(b_x, parm_set, H, H0)

                    out_set[i*batchsize:(i+1)*batchsize] = out

                    ssim_val = ssim_val + ms_ssim(b_x, out, data_range=1, size_average=True )
                    mse_val = mse_val + criterion(out,b_x)
                    psnr_val = psnr_val + psnr(b_x, out)
                    num_val = num_val + 1
                    lpips_val = lpips_val + calculate_lpips(b_x, out)

            ssim_val = ssim_val/num_val
            mse_val = mse_val/num_val
            psnr_val = psnr_val/num_val
            lpips_val = lpips_val/num_val
            time0 =  datetime.datetime.now()-start
            train_nmse = train_nmse/num_train
            # Train figures are those of the last training slice of the epoch.
            print('Epoch:',epoch,'NMSE:%.8f' % train_nmse.data,'time',time0,'train mse:%.8f' % mse_train.data,'val mse:%.8f' % mse_val.data,'train ms-ssim:%.8f' % ssim_train.data,'val ms-ssim:%.8f' % ssim_val.data, 'val PSNR:%.4fdB' % psnr_val.data,' val LPIPS:%.8f' % lpips_val.data)
            if best_msssim<ssim_val:
                best_msssim = ssim_val
                torch.save(net, model_name_E2E)
                print('model saved')
            num_train = 0
            train_nmse = 0

            workspace_dir = '.'
            # Path components are joined separately so the directory is also
            # created correctly on non-Windows systems.
            save_dir = os.path.join(workspace_dir, 'logs', 'ssim_'+str(N)+'paths_'+str(SNR_dB)+'SNR_'+str(Q)+'Q')
            os.makedirs(save_dir, exist_ok=True)
            filename = os.path.join(save_dir, f'Epoch_{epoch+1:03d}.jpg')
            # out_set holds the reconstructions of the last validation batch.
            torchvision.utils.save_image(out_set, filename, nrow=8)
            print(f' | Save some samples to {filename}.')
            # show generated image
            grid_img = torchvision.utils.make_grid(out_set.cpu(), nrow=8)
            plt.figure(figsize=(10,10))
            plt.imshow(grid_img.permute(1, 2, 0))
            plt.show()
        

### Training

##### Pretrain 1 (Swin Encoder and Swin Decoder)

In [None]:
# Pretrain 1 setup. Q must be assigned before SwinAuto() is built, because
# the model reads it to size its linear bottleneck.
LR = 1e-4
Q = 128

net = SwinAuto().cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
# Learning rate is multiplied by 0.3 at each listed epoch.
lr_milestones = [25,50]
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.3, last_epoch=-1)
criterion = RateDistortionLoss(lmbda=1e-2).cuda()

In [None]:
# Pretrain 1: train the Swin encoder/decoder pair directly (enc -> dec, the
# linear bottleneck of SwinAuto is bypassed) and keep the best model by
# validation loss.
EPOCH = 50
best_loss = 200
# torch.save does not create missing parent directories.
os.makedirs('./models', exist_ok=True)
for epoch in range(EPOCH):
    start = datetime.datetime.now()
    net.train() # training mode
    for b_x in train_dataloaders:
        b_x = b_x.cuda()
        out = net.dec(net.enc(b_x))
        loss_train = criterion(out,b_x)
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
    scheduler.step()
    with torch.no_grad():
        loss_val = 0
        num_val = 0
        net.eval()  # evaluation mode
        for b_x in val_dataloaders:
            num_val += 1
            b_x = b_x.cuda()
            out = net.dec(net.enc(b_x))
            loss_val = loss_val + criterion(out,b_x)
        loss_val = loss_val/num_val
        time0 =  datetime.datetime.now()-start
        # 'train mse' is the loss of the last training batch of the epoch.
        print('Epoch:',epoch,'time',time0,'train mse:%.8f' % loss_train.data,'val mse:%.8f' % loss_val.data)
        if loss_val<best_loss:
            best_loss = loss_val
            torch.save(net, './models/net_1user_64C_256dim.pth')
            print('model saved')
    

##### Pretrain 2 (different Q)

In [None]:
# Pretrain 2: start from the pretrain-1 encoder/decoder weights and train the
# whole SwinAuto (including the 256 -> 2*Q -> 256 bottleneck) for a given Q.
LR = 1e-4
Q = 64
net=SwinAuto()
# Unpickle the pretrain-1 model once and reuse it for both halves.
pretrained = torch.load('./models/net_1user_64C_256dim.pth')
net.enc.load_state_dict(pretrained.enc.state_dict())
net.dec.load_state_dict(pretrained.dec.state_dict())
del pretrained  # the weights are copied; release the loaded model
net=net.cuda()
optimizer = torch.optim.Adam(net.parameters(),lr=LR)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,  milestones = [25,50,75], gamma = 0.3, last_epoch=-1)
criterion = RateDistortionLoss(lmbda=1e-2).cuda()

EPOCH = 150
best_loss = 200
# torch.save does not create missing parent directories.
os.makedirs('./models', exist_ok=True)
for epoch in range(EPOCH):
    start = datetime.datetime.now()
    net.train() # training mode
    for b_x in train_dataloaders:
        b_x = b_x.cuda()
        out = net(b_x)
        loss_train = criterion(out,b_x)
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
    scheduler.step()
    with torch.no_grad():
        loss_val = 0
        num_val = 0
        net.eval()  # evaluation mode
        for b_x in val_dataloaders:
            num_val += 1
            b_x = b_x.cuda()
            out = net(b_x)
            loss_val = loss_val + criterion(out,b_x)
        loss_val = loss_val/num_val
        time0 =  datetime.datetime.now()-start
        # 'train mse' is the loss of the last training batch of the epoch.
        print('Epoch:',epoch,'time',time0,'train mse:%.8f' % loss_train.data,'val mse:%.8f' % loss_val.data)
        if loss_val<best_loss:
            best_loss = loss_val
            torch.save(net, './models/net_1user_64C_'+str(Q)+'Q.pth')
            print('model saved')

##### E2E training

In [None]:
### Author's run note: loss is MSE - MS-SSIM, imperfect H, embedding; Swin compresses
### directly to 2*Q; MDDL takes both channel semantics and image semantics as input,
### with the image semantics embedded first.
Q = 128
LR = 1e-4
EPOCH = 5

# Link configuration. Nc, Nt, Nr, snr and K are packed into parm_set; N, L and
# pho are read as globals by train_E2E (passed to SV_Channel / LAMP_GMMV_test).
Nc, N, Nt, Nr = 64, 2, 64, 1
L = 32
SNR_dB = 20
K = 4
snr = 10**(SNR_dB/10)   # dB -> linear
pho = 2
parm_set = [Nc,Nt,Nr,snr,K]

# Checkpoint to start from and checkpoint to write, both keyed by the settings above.
model_name_swin = f'./models/net_1user_64C_{Q}Q.pth'
model_name_E2E = f'./models/01Double_Merge{L}pilots_{N}paths_{SNR_dB}SNR_{Q}Q.pth'
train_E2E(LR,EPOCH,model_name_swin,model_name_E2E,parm_set)