In [1]:
import torch.utils.data as data
import torch
import h5py
import scipy.io as scio
import argparse, os
import math, random
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import models
import torch.utils.model_zoo as model_zoo
from tensorboardX import SummaryWriter
import numpy as np
from torchvision import transforms

os.environ["CUDA_VISIBLE_DEVICES"] = '1'


class DatasetFromHdf5(data.Dataset):
    """Dataset over paired ``data`` / ``label`` datasets stored in one HDF5 file.

    Items are ``(input, target)`` float32 tensors taken from the first axis of
    the ``'data'`` and ``'label'`` datasets. Both are indexed as 4-D per sample
    (``[index, :, :, :]``), so the stored datasets must be at least 4-D.
    """

    def __init__(self, file_path):
        super(DatasetFromHdf5, self).__init__()
        # Open explicitly read-only: h5py < 3.0 defaulted to mode 'a', which can
        # create or modify the file. The handle stays open for the dataset's lifetime
        # because self.data / self.target are lazy h5py dataset views into it.
        hf = h5py.File(file_path, 'r')
        self.data = hf.get('data')
        self.target = hf.get('label')

    def __getitem__(self, index):
        """Return ``(input, target)`` for sample ``index`` as float32 tensors."""
        inputs = torch.from_numpy(self.data[index, :, :, :]).float()
        target = torch.from_numpy(self.target[index, :, :, :]).float()
        return inputs, target

    def __len__(self):
        """Number of samples (size of the first axis of ``data``)."""
        return self.data.shape[0]
    

class DatasetFrommat(data.Dataset):
    """In-memory training dataset loaded from two MATLAB ``.mat`` files.

    ``data_file_path`` must hold variables ``fre_LS_GIGI<i>_123`` (network
    inputs) and ``gt_file_path`` variables ``H_gt<i>_123`` (ground truth) for
    ``i = 1 .. num_samples``. All arrays are read eagerly into memory.

    Items are ``(input, label, label_magnitude)``. ``label_magnitude`` is
    ``sqrt(label[0]**2 + label[1]**2)``, i.e. the modulus when channel 0/1 of
    the label are the real/imaginary parts. That is how ``validation`` combines
    them elsewhere in this notebook.
    """

    def __init__(self, data_file_path, gt_file_path, num_samples=29390):
        """
        Args:
            data_file_path: path of the ``.mat`` file with the inputs.
            gt_file_path: path of the ``.mat`` file with the ground truth.
            num_samples: number of training RBs stored in the files. Variables
                are 1-indexed, so ``1 .. num_samples`` are read. The default is
                the value that was previously hard-coded.
        """
        super(DatasetFrommat, self).__init__()
        data_file = scio.loadmat(data_file_path)
        gt_file = scio.loadmat(gt_file_path)
        self.data = []
        self.target = []
        for i in range(1, num_samples + 1):
            self.data.append(data_file[f'fre_LS_GIGI{i}_123'])
            self.target.append(gt_file[f'H_gt{i}_123'])

    def __getitem__(self, index):
        """Return ``(input, label, label_magnitude)`` float32 tensors for ``index``."""
        inputs = torch.from_numpy(self.data[index]).float()
        label = torch.from_numpy(self.target[index]).float()
        # Modulus over the leading (channel) axis: channels 0 and 1 are combined
        # as real and imaginary parts.
        label_magnitude = torch.sqrt(torch.pow(label[0], 2) + torch.pow(label[1], 2))
        return inputs, label, label_magnitude

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)
    
    
class Dataset_eval(data.Dataset):
    """Evaluation dataset read from one MATLAB ``.mat`` file.

    Loads every input ``fre_LS_GIGI<i>_<snr>`` for RB ``i = 1 .. length`` and
    each ``snr`` in ``snrs``, ordered RB-major, then SNR. Items are
    ``(input_tensor, variable_name)`` so predictions can later be matched by
    name to the ground-truth keys ``H_gt<i>_<snr>``.
    """

    def __init__(self, data_file_path, length, snrs=tuple(range(0, 32, 2))):
        """
        Args:
            data_file_path: path of the ``.mat`` file with the test inputs.
            length: number of test RBs to load (1-indexed, ``1 .. length``).
                It was previously ignored in favour of a hard-coded 200.
            snrs: SNR suffixes to load per RB (default 0, 2, ..., 30).
        """
        super(Dataset_eval, self).__init__()
        data_file = scio.loadmat(data_file_path)
        self.data = []
        self.target = []
        for i in range(1, length + 1):
            for snr in snrs:
                key = f'fre_LS_GIGI{i}_{snr}'
                self.data.append(data_file[key])
                self.target.append(key)

    def __getitem__(self, index):
        """Return ``(input float32 tensor, variable name)`` for ``index``."""
        return torch.from_numpy(self.data[index]).float(), self.target[index]

    def __len__(self):
        """Number of loaded (RB, SNR) samples."""
        return len(self.data)


In [2]:
# Training settings
# Help strings use %(default)s so they can never drift from the real defaults again.
parser = argparse.ArgumentParser(description="PyTorch SRResNet")
parser.add_argument("--batchSize", type=int, default=32, help="training batch size")
parser.add_argument("--nEpochs", type=int, default=480, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=0.01, help="Learning Rate. Default=%(default)s")
parser.add_argument("--step", type=int, default=40, help="Decay the learning rate by 10x every n epochs (see adjust_learning_rate), Default: n=%(default)s")
parser.add_argument("--cuda", action="store_true", help="Use cuda?")
parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--threads", type=int, default=0, help="Number of threads for data loader to use, Default: %(default)s")
parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")
parser.add_argument("--gpus", default="1", type=str, help="gpu ids (default: %(default)s)")

# Parse an empty argv so the defaults are used inside the notebook kernel
# (the kernel's own sys.argv would otherwise be parsed).
opt = parser.parse_args(args=[])



def adjust_learning_rate(optimizer, epoch, base_lr=None, step=None):
    """Return the step-decayed learning rate for ``epoch``.

    ``lr = base_lr * 0.1 ** (epoch // step)``: the rate is divided by 10 every
    ``step`` epochs. The optimizer is NOT modified; the caller must write the
    returned value into ``optimizer.param_groups``.

    Args:
        optimizer: unused; kept for call compatibility.
        epoch: epoch index. The (commented-out) caller in the training loop
            passes ``epoch - 1``, i.e. a zero-based index.
        base_lr: initial learning rate; defaults to the global ``opt.lr``.
        step: decay period in epochs; defaults to the global ``opt.step``.

    Returns:
        float: the learning rate for this epoch.
    """
    if base_lr is None:
        base_lr = opt.lr
    if step is None:
        step = opt.step
    return base_lr * (0.1 ** (epoch // step))

def save_checkpoint(model, epoch, check_location, root="saved_models"):
    """Save ``{"epoch": epoch, "model": model}`` with ``torch.save``.

    The file is written to ``<root>/<check_location>/model_epoch_<epoch>.pth``.
    Missing directories (``check_location`` may contain sub-directories) are
    created. The whole ``model`` object is pickled, not just its
    ``state_dict``; loaders in this notebook call ``checkpoint["model"].state_dict()``.

    Args:
        model: object to store (normally an ``nn.Module``).
        epoch: number stored under ``"epoch"`` and used in the file name.
        check_location: run directory relative to ``root``.
        root: base directory; defaults to ``"saved_models"`` as before.
    """
    ckpt_dir = os.path.join(root, check_location)
    model_out_path = os.path.join(ckpt_dir, "model_epoch_{}.pth".format(epoch))
    state = {"epoch": epoch, "model": model}
    # exist_ok avoids the race between an existence check and makedirs.
    os.makedirs(ckpt_dir, exist_ok=True)

    torch.save(state, model_out_path)

    print("Checkpoint saved to {}".format(model_out_path))

    



# NOTE(review): `cuda` is not read anywhere in the visible cells.
cuda = True

# Draw a fresh random seed each run and print it, so the torch (CPU and CUDA)
# seed of a given run is known and can be re-applied.
opt.seed = random.randint(1, 10000)
print("Random Seed: ", opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)

# Let cuDNN benchmark convolution algorithms; this may make runs non-deterministic.
cudnn.benchmark = True

print("===> Loading datasets")
# Training inputs and ground-truth labels. Adjust both paths for your machine.
train_set = DatasetFrommat('/home/great80/data/train_fre_LS_GIGI_mixSNR_1_4_12_3D_3.mat', '/home/great80/data/label_train_H_gt_RB_mixSNR_1_4_12_3D_3.mat') # modify the paths


test_dataset_location = '/home/great80/data/test_fre_LS_GIGI_mixSNR_1_4_12_3D_3.mat'# modify the path
test_dataset_gt_location = '/home/great80/data/label_test_H_gt_RB_mixSNR_1_4_12_3D_3.mat'# modify the path
# Ground-truth test arrays, looked up by validation() under keys 'H_gt<rb>_<snr>'.
labels = scio.loadmat(test_dataset_gt_location)
output = 'out_fre_LS_GIGI_RES_mixSNR_3D_3' # may be left unchanged; NOTE(review): not used in the visible cells

# Second argument: number of test RBs - set it to the actual count of the test set.
evaldataset = Dataset_eval(test_dataset_location, 200)    # 200 test RBs (an older comment said 178)


Random Seed:  4004
===> Loading datasets


FileNotFoundError: [Errno 2] No such file or directory: '/home/great80/data/train_fre_LS_GIGI_mixSNR_1_4_12_3D_3.mat'

In [3]:
# Shuffled training mini-batches of 8 samples.
# NOTE(review): batch_size is hard-coded to 8, so opt.batchSize (default 32) is ignored.
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=8, shuffle=True)#, collate_fn=AlignCollate())
# Evaluation batches. Shuffling is harmless here (validation() stores predictions
# by sample name) but is not needed either.
eval_data_loader = DataLoader(dataset=evaldataset, batch_size=8, shuffle=True)#, collate_fn=AlignCollate_eval())

# Sanity check of one sample: training items are (input, label, label magnitude),
# evaluation items are (input, variable name).
print(train_set.__getitem__(0)[0].shape)
print(train_set.__getitem__(0)[1].shape)
print(train_set.__getitem__(0)[2].shape)
print(evaldataset.__getitem__(0)[0].shape)
print(evaldataset.__getitem__(0)[1])


torch.Size([2, 64, 12, 2])
torch.Size([2, 64, 12, 2])
torch.Size([64, 12, 2])
torch.Size([2, 64, 12, 2])
fre_LS_GIGI1_0


In [4]:
def validation(model, eval_data_loader, labels, snr, num_rb=200,
               input_slice=(slice(None), slice(None), slice(None),
                            slice(0, 12, 4), slice(0, 2, 2))):
    """Return the mean normalised MSE (NMSE) of ``model`` at one SNR.

    Every batch of ``eval_data_loader`` (all SNRs) is run through the model and
    predictions are stored by sample name. Then, for each RB
    ``i = 1 .. num_rb``::

        nmse_i = ||out_i - label_i||^2 / ||label_i||^2

    ``out_i`` and ``label_i`` are the flattened complex arrays built as
    ``x[0] + 1j * x[1]`` (channel 0 real, channel 1 imaginary).

    Args:
        model: module whose forward returns a tuple with the prediction first.
            It is switched to eval mode while predicting; its previous
            train/eval mode is restored afterwards, so calling this from inside
            a training loop no longer leaves BatchNorm layers frozen.
        eval_data_loader: yields ``(inputs, names)`` batches; names are keys
            such as ``'fre_LS_GIGI<i>_<snr>'``.
        labels: mapping (e.g. from ``scio.loadmat``) with arrays under
            ``'H_gt<i>_<snr>'``, channel axis first.
        snr: SNR suffix to evaluate (e.g. 30).
        num_rb: number of test RBs; keys ``1 .. num_rb`` must exist for ``snr``.
            Defaults to the previously hard-coded 200.
        input_slice: index applied to each input batch before the forward pass.
            The default keeps every 4th index on axis 3 and every 2nd on axis 4
            (eCNN-RN/EDSR setting). Pass ``(slice(None),) * 5`` for SRCNN.

    Returns:
        float: NMSE averaged over the ``num_rb`` RBs.

    Raises:
        ValueError: if ``num_rb < 1``.
        KeyError: if a required prediction or label key is missing.
    """
    if num_rb < 1:
        raise ValueError("num_rb must be >= 1")

    # Run on whatever device the model lives on; .cuda() was the old hard-wired choice.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    was_training = model.training
    model.eval()
    result_dict = {}
    try:
        with torch.no_grad():
            for batch in eval_data_loader:
                im_input = batch[0][input_slice].to(device)
                names = batch[1]
                out = model(im_input)[0].cpu().numpy()
                for name, pred in zip(names, out):
                    result_dict[name] = pred
    finally:
        # Restore the caller's mode (the training loop relies on train mode).
        model.train(was_training)

    nmse = np.zeros(num_rb)
    for rb in range(1, num_rb + 1):
        out = result_dict[f'fre_LS_GIGI{rb}_{snr}']
        label = labels[f'H_gt{rb}_{snr}']
        out_complex = (out[0] + out[1] * 1j).flatten()
        label_complex = (label[0] + label[1] * 1j).flatten()
        nmse[rb - 1] = (np.square(np.linalg.norm(out_complex - label_complex))
                        / np.square(np.linalg.norm(label_complex)))
    return np.sum(nmse) / num_rb



In [5]:
def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    """Build an ``nn.Conv3d`` padded by ``(k - 1) // 2`` on every axis.

    For odd kernel sizes and stride 1 this preserves the spatial size.

    Args:
        in_channels, out_channels: channel counts.
        kernel_size: either an int ``k``, giving a ``(k, 3, 3)`` kernel with
            padding ``((k-1)//2, 1, 1)`` (the original behaviour), or a
            3-element sequence used as the full kernel shape. The sequence form
            previously produced an invalid nested tuple and crashed.
        stride: convolution stride.
        bias: whether the convolution has a bias.

    Raises:
        ValueError: if a sequence ``kernel_size`` does not have 3 elements.
    """
    if isinstance(kernel_size, int):
        kernel = (kernel_size, 3, 3)
    else:
        kernel = tuple(kernel_size)
        if len(kernel) != 3:
            raise ValueError("kernel_size must be an int or a 3-element sequence")
    padding = tuple((k - 1) // 2 for k in kernel)
    return nn.Conv3d(
        in_channels, out_channels, kernel, stride,
        padding=padding, bias=bias)

class ResBlock(nn.Module):
    """Residual block computing ``x + res_scale * body(x)``.

    ``body`` is ``conv -> [BatchNorm3d] -> act -> conv -> [BatchNorm3d]``. The
    BatchNorm layers are present only when ``bn`` is true. Each conv is built
    by the ``conv`` factory as ``conv(n_feats, n_feats, kernel_size[k], stride,
    bias=bias)`` for k = 0, 1. For the residual sum to be valid, the factory
    must preserve the tensor shape.
    """

    def __init__(
        self, conv, n_feats, kernel_size, stride,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        layers = [conv(n_feats, n_feats, kernel_size[0], stride, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm3d(n_feats))
        layers.append(act)
        layers.append(conv(n_feats, n_feats, kernel_size[1], stride, bias=bias))
        if bn:
            layers.append(nn.BatchNorm3d(n_feats))

        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``x`` plus the scaled body output."""
        scaled = self.body(x).mul(self.res_scale)
        return scaled + x

class PixelShuffle3d(nn.Module):
    '''
    Sub-pixel shuffle for 5-D tensors that upsamples only the last two axes.

    (B, C * s * s, D, H, W) -> (B, C, D, H * s, W * s); the depth axis D is
    unchanged. Each depth slice is shuffled exactly like nn.PixelShuffle(s).
    '''
    def __init__(self, scale):
        '''
        :param scale: upsample scale applied to the last two axes
        '''
        super().__init__()
        self.scale = scale

    def forward(self, input):
        s = self.scale
        b, c, d, h, w = input.size()
        c_out = c // (s ** 2)

        # Split channels into (c_out, s_h, s_w) sub-pixel groups ...
        grouped = input.contiguous().view(b, c_out, s, s, d, h, w)
        # ... and interleave them: (b, c_out, d, h, s_h, w, s_w).
        interleaved = grouped.permute(0, 1, 4, 5, 2, 6, 3).contiguous()

        return interleaved.view(b, c_out, d, h * s, w * s)


class Net(torch.nn.Module):
    """eCNN-RN: conv head -> residual body -> transposed-conv upsampling tail.

    Input and output are 5-D tensors ``(batch, num_channels, D, H, W)``. The
    tail upsamples H by 4 and W by 2 and keeps D, e.g.
    ``(4, 2, 64, 3, 1) -> (4, 2, 64, 12, 2)``.

    Args:
        num_channels: channels of the input and of the output (default 2).
            The magnitude output combines channels 0 and 1 as real/imaginary
            parts, so it is only meaningful for 2 channels.
        base_filter: feature width of the head and residual body. The tail
            narrows it to ``base_filter // 2`` before the last layer.
        n_resblocks: number of ``ResBlock`` modules in the body.
        upscale_factor: unused; kept for call compatibility.
    """

    def __init__(self, num_channels=2, base_filter=16, n_resblocks=2, upscale_factor=1):
        super(Net, self).__init__()

        kernel_size_res = [3, 3]
        stride = 1
        act = nn.ReLU(True)

        # Tail: (1,2,2) then (1,2,1) transposed convs -> H x4, W x2, D unchanged.
        # Modules are constructed in the original order (tail, body, head) so a
        # given seed yields the same initial weights as before.
        self.tail = nn.Sequential(
            nn.ConvTranspose3d(base_filter, base_filter // 2, (1, 2, 2), stride=(1, 2, 2)),
            nn.BatchNorm3d(base_filter // 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(base_filter // 2, num_channels, (1, 2, 1), stride=(1, 2, 1)),
        )

        # Body: stack of shape-preserving residual blocks.
        self.body = nn.Sequential(*[
            ResBlock(default_conv, base_filter, kernel_size_res, stride, act=act, res_scale=1)
            for _ in range(n_resblocks)
        ])

        # Head: lift the input channels to base_filter features.
        self.head = nn.Sequential(
            default_conv(num_channels, base_filter, 5),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return ``(output, magnitude)``.

        ``magnitude = sqrt(output[:, 0]**2 + output[:, 1]**2)`` has shape
        ``(batch, D', H', W')``.
        """
        x = self.head(x)
        # Global skip connection around the residual body.
        x = self.body(x) + x
        x = self.tail(x)
        x_magnitude = torch.sqrt(torch.pow(x[:, 0], 2) + torch.pow(x[:, 1], 2))

        return x, x_magnitude

    def weight_init(self, mean, std):
        """Re-initialise every (transposed) 2-D/3-D conv: weights ~ N(mean, std), bias = 0.

        The previous version passed only the top-level Sequentials to a
        2-D-only initialiser and therefore changed nothing. BatchNorm layers
        are left untouched.
        """
        conv_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Conv3d, nn.ConvTranspose3d)
        for m in self.modules():
            if isinstance(m, conv_types):
                m.weight.data.normal_(mean, std)
                if m.bias is not None:
                    m.bias.data.zero_()


def normal_init(m, mean, std):
    """Draw a conv layer's weights from N(mean, std) and zero its bias.

    Handles ``Conv2d``, ``ConvTranspose2d``, ``Conv3d`` and ``ConvTranspose3d``.
    The 3-D layers are the ones this notebook's models actually use. Layers
    built with ``bias=False`` are supported. Any other module type is left
    unchanged, and children are not visited.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv3d, nn.ConvTranspose3d)):
        m.weight.data.normal_(mean, std)
        if m.bias is not None:
            m.bias.data.zero_()


In [1]:
import torch 
import torch.nn as nn


class SRCNN(torch.nn.Module):
    """Three-layer 3-D SRCNN-style network that keeps the input shape.

    Each layer is a 5x5x5 ``Conv3d`` with padding 2 (stride 1). Shapes
    therefore go ``(B, num_channels, D, H, W) -> (B, num_channels, D, H, W)``.

    Args:
        num_channels: input/output channels (default 2). The magnitude output
            combines channels 0 and 1.
        base_filter: width of the two hidden layers.
        upscale_factor: unused; kept for call compatibility.
    """

    def __init__(self, num_channels=2, base_filter=16, upscale_factor=1):
        super(SRCNN, self).__init__()

        self.layers = torch.nn.Sequential(
            nn.Conv3d(in_channels=num_channels, out_channels=base_filter, kernel_size=5, stride=1, padding=2, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_channels=base_filter, out_channels=base_filter, kernel_size=5, stride=1, padding=2, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_channels=base_filter, out_channels=num_channels, kernel_size=5, stride=1, padding=2, bias=True),
        )

    def forward(self, x):
        """Return ``(output, magnitude)``.

        ``magnitude = sqrt(output[:, 0]**2 + output[:, 1]**2)``.
        """
        x = self.layers(x)
        x_magnitude = torch.sqrt(torch.pow(x[:, 0], 2) + torch.pow(x[:, 1], 2))

        return x, x_magnitude

    def weight_init(self, mean, std):
        """Re-initialise every (transposed) 2-D/3-D conv: weights ~ N(mean, std), bias = 0.

        The previous version passed only the top-level ``layers`` Sequential to
        a 2-D-only initialiser and therefore changed nothing.
        """
        conv_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Conv3d, nn.ConvTranspose3d)
        for m in self.modules():
            if isinstance(m, conv_types):
                m.weight.data.normal_(mean, std)
                if m.bias is not None:
                    m.bias.data.zero_()


def normal_init(m, mean, std):
    """Draw a conv layer's weights from N(mean, std) and zero its bias.

    Handles ``Conv2d``, ``ConvTranspose2d``, ``Conv3d`` and ``ConvTranspose3d``.
    The 3-D layers are the ones this notebook's models actually use. Layers
    built with ``bias=False`` are supported. Any other module type is left
    unchanged, and children are not visited.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv3d, nn.ConvTranspose3d)):
        m.weight.data.normal_(mean, std)
        if m.bias is not None:
            m.bias.data.zero_()

In [5]:
# Smoke test: SRCNN keeps the input shape, so a (4, 2, 64, 12, 2) input gives a
# same-shaped output and a (4, 64, 12, 2) magnitude map.
model = SRCNN(num_channels=2, base_filter=16, upscale_factor=1) #SRCNN
# print(model)
print(model(torch.ones(4,2,64,12,2))[1].shape)
print(model(torch.ones(4,2,64,12,2))[0].shape)


torch.Size([4, 64, 12, 2])
torch.Size([4, 2, 64, 12, 2])


In [6]:
# Smoke test: Net upsamples the last two axes (x4 and x2), so a (4, 2, 64, 3, 1)
# input gives a (4, 2, 64, 12, 2) output and a (4, 64, 12, 2) magnitude map.
model = Net(num_channels=2, base_filter=16, n_resblocks=4) #EDSR 
# print(model)
print(model(torch.ones(4,2,64,3,1))[1].shape)
print(model(torch.ones(4,2,64,3,1))[0].shape)

torch.Size([4, 64, 12, 2])
torch.Size([4, 2, 64, 12, 2])


In [None]:
expeiment_name = 'SNS/20230809_eCNN_RN' # run name under saved_models/; best to change it for every training run

# Training settings (help strings use %(default)s so they always match the defaults)
parser = argparse.ArgumentParser(description="PyTorch SRResNet")
parser.add_argument("--batchSize", type=int, default=32, help="training batch size")
parser.add_argument("--nEpochs", type=int, default=480, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=0.001, help="Learning Rate. Default=%(default)s")
parser.add_argument("--step", type=int, default=40, help="Decay the learning rate by 10x every n epochs (see adjust_learning_rate), Default: n=%(default)s")
parser.add_argument("--cuda", action="store_true", help="Use cuda?")
parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--threads", type=int, default=0, help="Number of threads for data loader to use, Default: %(default)s")
parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")
parser.add_argument("--gpus", default="4", type=str, help="gpu ids (default: %(default)s)")

# Empty argv: use the defaults inside the notebook kernel.
opt = parser.parse_args(args=[])
print(opt)

# TensorBoard logging
writer = SummaryWriter(f'./saved_models/{expeiment_name}')

print("===> Building model")

model = Net(num_channels=2, base_filter=8, n_resblocks=4) #eCNN-RN
# model = SRCNN(num_channels=2, base_filter=8, upscale_factor=1) #SRCNN (also switch the input slicing below)

criterion = nn.MSELoss(reduction='mean')

print("===> Setting GPU")
model = model.cuda()
criterion = criterion.cuda()

# optionally resume from a checkpoint
# NOTE(review): save_checkpoint() below is called with the global iteration
# count, so checkpoint["epoch"] holds an iteration number, not an epoch.
# Resuming from it would set start_epoch beyond nEpochs - confirm before using --resume.
if opt.resume:
    if os.path.isfile(opt.resume):
        print("=> loading checkpoint '{}'".format(opt.resume))
        checkpoint = torch.load(opt.resume)
        opt.start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["model"].state_dict())
    else:
        print("=> no checkpoint found at '{}'".format(opt.resume))

# optionally copy weights from a checkpoint
if opt.pretrained:
    if os.path.isfile(opt.pretrained):
        print("=> loading model '{}'".format(opt.pretrained))
        weights = torch.load(opt.pretrained)
        model.load_state_dict(weights['model'].state_dict())
    else:
        print("=> no model found at '{}'".format(opt.pretrained))

print("===> Setting Optimizer")
optimizer = optim.Adam(model.parameters(), lr=opt.lr)

print("===> Training")
i = 0              # global iteration counter (not reset between epochs)
num_print = 4      # log running losses every num_print iterations
val_print = 334    # run validation every val_print iterations
mse_best = 100     # best validation NMSE so far; a checkpoint is saved on improvement
running_loss = 0.0
running_loss_mse = 0.0
running_loss_magnitude = 0.0
for epoch in range(opt.start_epoch, opt.nEpochs + 1):

    # Step decay via adjust_learning_rate() is disabled: the lr stays at opt.lr.
    model.train()

    for iteration, batch in enumerate(training_data_loader, 1):

        input, target, target_m = batch[0], batch[1], batch[2]

        # Sub-sample the input grid: every 4th index on axis 3 and every 2nd on
        # axis 4 (eCNN-RN/EDSR). For SRCNN use input[:,:,:,0:12:1,0:2:1] instead.
        input = input[:,:,:,0:12:4,0:2:2].cuda()
        target = target.cuda()
        target_m = target_m.cuda()

        output, output_m = model(input)

        # Objective: MSE on the 2-channel output plus a 0.1-weighted MSE on the magnitude.
        loss_mse = criterion(output, target)
        loss_magnitude = criterion(output_m, target_m)
        loss = loss_mse + 0.1*loss_magnitude

        # Log the objective actually optimised (previously logged the unweighted sum).
        running_loss += loss.item()
        running_loss_mse += loss_mse.item()
        running_loss_magnitude += loss_magnitude.item()

        optimizer.zero_grad()

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        i += 1

        if i % num_print == 0:
            writer.add_scalar('train/Loss',  running_loss / num_print, i)
            writer.add_scalar('train/Loss_mse',  running_loss_mse / num_print, i)
            writer.add_scalar('train/Loss_magnitude',  running_loss_magnitude / num_print, i)
            print("Epoch[{}]({}/{}): Loss: {:.5} Loss_x: {:.5} Loss_m: {:.5}".format(epoch, iteration, len(training_data_loader), running_loss/num_print, running_loss_mse/num_print, running_loss_magnitude/num_print))
            running_loss = 0.0
            running_loss_mse = 0.0
            running_loss_magnitude = 0.0
        if i % val_print == 0:
            # Validate on the SNR-30 test samples.
            mse = validation(model, eval_data_loader, labels,30)
            # validation() calls model.eval(); switch back so BatchNorm in the
            # tail keeps training for the rest of the epoch.
            model.train()
            print("Epoch[{}]: val mse: {:.5}".format(epoch, mse))
            writer.add_scalar('train/valmse',  mse, i)
            if mse < mse_best:
                save_checkpoint(model, i, expeiment_name)
                mse_best = mse

In [8]:
# Reload the best checkpoint saved during training. The checkpoint pickles the
# whole model object, so its state_dict is copied into the current `model`.
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which cannot unpickle a full nn.Module; pass weights_only=False for trusted
# files if loading fails - confirm against the installed torch version.
model.load_state_dict(torch.load('saved_models/SNS/20230809_eCNN_RN/model_epoch_1578150.pth')['model'].state_dict())
model.cuda()

# Test-set NMSE for each SNR suffix 0, 2, ..., 30.
val_snr_dict = {}
for i in range(0, 32, 2):
    mse = validation(model, eval_data_loader, labels, i)
    val_snr_dict[i] = mse

In [None]:
# Tab-separated table of SNR -> NMSE (4 decimals) for the 16 evaluated SNRs.
for i in range(0, 16):
    print(f'{2*i} \t {val_snr_dict[2*i]:.4f}')