In [7]:
from scipy.io import loadmat
import torch
from torch.utils.data import Dataset,DataLoader
import h5py
import numpy as np

IMG_SIZE = 64
BATCH_SIZE =16
torch.no_grad()

# with h5py.File('HdLink.mat', 'r') as file:
#     hdlink_data = file['HdLink']
#     hdlink_real = np.array(hdlink_data['real'])  
#     hdlink_imag = np.array(hdlink_data['imag']) 
#     Hdlink = torch.complex(torch.FloatTensor(hdlink_real), torch.FloatTensor(hdlink_imag))
                           
# with h5py.File('HrLink.mat', 'r') as file:
#     hrlink_data = file['HrLink']
#     hrlink_real = np.array(hrlink_data['real'])  
#     hrlink_imag = np.array(hrlink_data['imag']) 
#     Hrlink = torch.complex(torch.FloatTensor(hrlink_real), torch.FloatTensor(hrlink_imag))
    
Htlink_data = loadmat('HtLink.mat')
Htlink = torch.from_numpy(Htlink_data['HtLink'])
Hrlink_data = loadmat('HrLink.mat')
Hrlink = torch.from_numpy(Hrlink_data['HrLink'])
Hdlink_data = loadmat('HdLink.mat')
Hdlink = torch.from_numpy(Hdlink_data['HdLink'])

#del hdlink_data,hdlink_real,hdlink_imag,hrlink_data,hrlink_real,hrlink_imag,Htlink_data

Gchannel = torch.zeros(Hdlink.shape[0],Hdlink.shape[1],2,Hrlink.shape[2])
phi = 2*(torch.rand(Hrlink.shape[1])+1j*torch.rand(Hrlink.shape[1]))-1
diag_phi = torch.diag(phi).to(torch.complex128)
for i in range(Hrlink.shape[2]):
    nchannel = Hdlink[:,:,i] + Hrlink[:,:,i] @ diag_phi @ Htlink
    Gchannel[:,:,0,i] = torch.real(nchannel/torch.abs(nchannel))
    Gchannel[:,:,1,i] = torch.imag(nchannel/torch.abs(nchannel))
#Gchannel = Gchannel/torch.abs(Gchannel)
torch.save(Gchannel, 'Gchannel.pt')
dataloader = DataLoader(Gchannel.view(-1,2,IMG_SIZE,IMG_SIZE), batch_size=BATCH_SIZE, shuffle=True, drop_last=True) 

In [7]:
from torch import nn
import math
import torch

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()
        
    def forward(self, x, t, ):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels = 2
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 2
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )
        
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)           
            x = up(x, t)
        return self.output(x)

model = SimpleUnet()
print("Num params: ", sum(p.numel() for p in model.parameters()))
#model

Num params:  62438242


In [12]:
model_size_bytes = sum(p.numel() for p in model.parameters()) * 4  # numel()返回参数的总元素数，乘以4字节得到总字节数
model_size_mb = model_size_bytes / (1024 ** 2)  # 将字节转换为兆字节
print(model_size_mb)
from fvcore.nn import FlopCountAnalysis

# model是你的网络模型，input是一个典型的输入张量
input_tensor = torch.randn(32, 2, 64, 64)
flops = FlopCountAnalysis(model, input_tensor)
print(flops.total())

238.18299102783203


TypeError: forward() missing 1 required positional argument: 'timestep'

In [9]:

############   Loss function & Sampling

def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.l1_loss(noise, noise_pred)

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :] 
    plt.imshow(reverse_transforms(image))

import torch
@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns 
    the denoised image. 
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    
    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    
    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

@torch.no_grad()
def sample_plot_image():
    # Sample noise
    img_size = IMG_SIZE ###############
    img = torch.randn((1, 2, img_size, img_size), device=device)
    plt.figure(figsize=(15,15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    for i in range(0,T)[::-1]: ####################
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        # Edit: This is to maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i/stepsize)+1)
            show_tensor_image(img.detach().cpu())
    plt.show()     

In [10]:
import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn.functional as F

#### Forward process

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """ 
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """ 
    Takes an image and a timestep as input and 
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)


# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

  warn(f"Failed to load image Python extension: {e}")


In [11]:
from torch.optim import Adam
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
optimizer = Adam(model.parameters(), lr=0.001)
epochs = 100 # Try more!

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()
      t = torch.randint(0, T, (BATCH_SIZE,), device=device).long() # sample random time slot index from T
      loss = get_loss(model, batch, t)
      loss.backward()
      optimizer.step()

      if epoch % 5 == 0 and step == 0:
        print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
        #sample_plot_image()

Epoch 0 | step 000 Loss: 0.8127553462982178 
Epoch 5 | step 000 Loss: 0.3347211480140686 
Epoch 10 | step 000 Loss: 0.37774068117141724 
Epoch 15 | step 000 Loss: 0.2364252805709839 
Epoch 20 | step 000 Loss: 0.2443949431180954 
Epoch 25 | step 000 Loss: 0.16808025538921356 


KeyboardInterrupt: 

In [13]:
torch.save(model.state_dict(), 'model_weights.pth')

In [28]:
from scipy.io import loadmat
import torch
import numpy as np
from tqdm import tqdm

Gchannel = torch.load('Gchannel.pt')
print(Gchannel.shape)
########### pilot signals ############
num_pilots = 100
Nt = 64

# 创建包含给定复数的实部和虚部的张量
constellation_real = torch.tensor([1, 1, -1, -1], dtype=torch.float)
constellation_imag = torch.tensor([1, -1, 1, -1], dtype=torch.float)

# 生成随机索引
p_indices = torch.randint(0, len(constellation_real), (Nt,))
# 使用索引分别从实部和虚部张量中选择元素
random_real_parts = constellation_real[p_indices]
random_imag_parts = constellation_imag[p_indices]

# 将实部和虚部合并成一个复数张量
pilot = torch.complex(random_real_parts, random_imag_parts)
print(pilot.shape)

########### Pick Validation index ##############
val_number = 16
val_index = torch.randint(0, len(Gchannel), (val_number,))
val_channel = torch.zeros([Nt,Nt,2,val_number])
for i in range(len(val_index)):
    val_channel[:,:,:,i] = Gchannel[:,:,:,val_index[i]]
val_channel = torch.complex(torch.FloatTensor(val_channel[:,:,0,:]), torch.FloatTensor(val_channel[:,:,1,:]))
print(val_channel.shape)
HP = torch.zeros_like(val_channel)
for i in range(val_channel.shape[2]):
    HP[:,:,i] = torch.matmul(val_channel[:,:,i], pilot)

########## AWGN noise ##############
noise = torch.rand_like(HP)+1j*torch.rand_like(HP)
snr_range = torch.arange(-10, 30, 5)
noise_power = 10 ** (-snr_range / 10)

torch.Size([64, 64, 2, 96])
torch.Size([64])
torch.Size([64, 64, 16])


In [47]:
T = 300 ### total time slots in forward or backward process
steps = 3 ########## iteration steps at each noise level
random_channel = torch.randn_like(val_channel)
random_channel = torch.view_as_real(random_channel).permute(2, 3, 0, 1)
pilot_real = torch.view_as_real(pilot)
################## For each SNR ###################
for snr_idx, local_noise in tqdm(enumerate(noise_power)):
    val_received = HP + torch.sqrt(noise_power[snr_idx])*noise # received signal through real channel
    channel_t = random_channel.clone()
    testing_index = 0
    val_P_hermitian = torch.conj(torch.transpose(pilot, -1, 0))
    P_real_hermitian = torch.conj(torch.transpose(pilot_real, -1, 0))
    ############### For each time slot in backward process #############
    for time_idx in tqdm(range(T)):
        labels = torch.ones(random_channel.shape[0]) * time_idx
        ######### For each step at that noise level ########
        for step_idx in range(steps):
            with torch.no_grad():
                score = model(channel_t, labels)
                # Compute gradient for measurements in un-normalized space
                HP_t = torch.matmul(channel_t.reshape(val_number,Nt,2*Nt),pilot_real.reshape(2*Nt))
                meas_grad = torch.matmul(HP_t - torch.view_as_real(val_received).reshape(),)
                # Sample noise
                grad_noise = np.sqrt(2 * alpha * beta_noise) * torch.randn_like(channel_t)
                # Apply update
                channel_t = channel_t + alpha * (score - meas_grad /(local_noise/2. + current_sigma ** 2)) + grad_noise
                
                # Store loss
                nmse_log[snr_idx, testing_idx] = \
                (torch.sum(torch.square(torch.abs(current - oracle)), dim=(-1, -2))/\
                torch.sum(torch.square(torch.abs(oracle)), dim=(-1, -2))).cpu().numpy()
                testing_idx = testing_idx + 1

0it [00:00, ?it/s]
  0%|          | 0/300 [00:00<?, ?it/s][A
0it [00:00, ?it/s]


TypeError: reshape() missing 1 required positional arguments: "shape"

In [51]:
score = model(batch,t)
#print(torch.view_as_real(val_received).shape)
print(HP_t.shape)
print(channel_t.shape)
print(torch.view_as_real(val_received).shape)

torch.Size([16, 64])
torch.Size([16, 2, 64, 64])
torch.Size([64, 64, 16, 2])
