In [1]:
import scipy.io
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import math

## Load and Normalize Data
To compute the spectral efficiency, both transmitter and receiver should be optimized with NN. So the dataset contains F_opt and W_opt.

In [2]:
train_layouts = 10000
test_layouts = 2000

In [3]:
data = scipy.io.loadmat('hb_train_144_36.mat')
Fopt_train = data['Fopt'].transpose(2,0,1)
Wopt_train = data['Wopt'].transpose(2,0,1)

In [4]:
test_data = scipy.io.loadmat('hb_test_144_36.mat')
Fopt_test = test_data['Fopt'].transpose(2,0,1)
Wopt_test = test_data['Wopt'].transpose(2,0,1)

In [5]:
def normalize_data(train_data,test_data):
    n1, n2 = train_data.shape[0], test_data.shape[0]
    norm_train = train_data.reshape(n1,-1)
    norm_test = test_data.reshape(n2,-1)
    
    norm_train = np.concatenate((norm_train.real, norm_train.imag),axis = 1)
    norm_test = np.concatenate((norm_test.real, norm_test.imag),axis = 1)
    return norm_train, norm_test
norm_train_F, norm_test_F = normalize_data(Fopt_train, Fopt_test)
norm_train_W, norm_test_W = normalize_data(Wopt_train, Wopt_test)

## Create Dataset

In [6]:
class PCDataset(torch.utils.data.Dataset):
    def __init__(self, data_F, F_opt, data_W, W_opt):
        'Initialization'
        self.Fdata = torch.tensor(data_F, dtype = torch.float)
        self.F_opt = torch.tensor(F_opt, dtype = torch.cfloat)
        self.Wdata = torch.tensor(data_W, dtype = torch.float)
        self.W_opt = torch.tensor(W_opt, dtype = torch.cfloat)

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.Fdata)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        F = self.Fdata[index]
        F_opt = self.F_opt[index]
        W = self.Wdata[index]
        W_opt = self.W_opt[index]
        return F, F_opt, W, W_opt

In [7]:
train_data = PCDataset(norm_train_F, Fopt_train, norm_train_W, Wopt_train)
test_data = PCDataset(norm_test_F, Fopt_test, norm_test_W, Wopt_test)

In [8]:
batch_size = 64
train_loader = DataLoader(train_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data, test_layouts, shuffle=False)

## Build Model

#### Loss function
As the phase shifter matrix is block diagonal, we first define a block diagonal mask

In [9]:
N_r, N_t, N_RF, N_s = Wopt_train.shape[1], Fopt_train.shape[1], 18, Fopt_train.shape[2]
Fmask = np.zeros((1,N_t, N_RF) )
Wmask = np.zeros((1,N_r, N_RF) )
for i in range(N_RF):
    Fmask[0,i*N_t//N_RF: (i+1)*N_t//N_RF,i] = np.ones((N_t//N_RF) )
    Wmask[0,i*N_r//N_RF: (i+1)*N_r//N_RF,i] = np.ones((N_r//N_RF) )
Fmask = torch.tensor(Fmask, dtype = torch.cfloat)
Wmask = torch.tensor(Wmask, dtype = torch.cfloat)

The neural network module only needs to output F_BB and W_BB. F_RF and W_RF can be obtained by using (33) in [1].

Note: New version pytorch supports complex valued auto differentiation. Please refer to https://pytorch.org/docs/stable/complex_numbers.html for details. 

In [10]:
def FMF_loss(F_BB, F_opt):
    # Compute F_RF from F_BB
    F_BB = F_BB/torch.norm(F_BB, p = 'fro', dim = [1,2], keepdim = True) * math.sqrt(N_RF * N_s)
    F_RF = F_opt @ F_BB.conj().transpose(1,2)
    F_RF = F_RF/torch.abs(F_RF)
    F_RF = Fmask * F_RF / math.sqrt(N_t)
    # Matrix factorization loss
    return torch.mean(torch.norm(F_opt - F_RF @ F_BB, dim = [1,2])**2)

In [11]:
def WMF_loss(W_BB, W_opt):
    # Compute W_RF from W_BB
    W_BB = W_BB/torch.norm(W_BB, p = 'fro', dim = [1,2], keepdim = True) * math.sqrt(N_RF * N_s)
    W_RF = W_opt @ W_BB.conj().transpose(1,2)
    W_RF = W_RF/torch.abs(W_RF)
    W_RF = Wmask * W_RF / math.sqrt(N_r)
    # Matrix factorization loss
    return torch.mean(torch.norm(W_opt - W_RF @ W_BB, dim = [1,2])**2)

#### Standard MLP modules

In [12]:
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Sigmoid, BatchNorm1d as BN, ReLU6 as ReLU6
def MLP(channels, batch_norm=True):
    return Seq(*[
        Seq(Lin(channels[i - 1], channels[i]), ReLU(), BN(channels[i]))
        for i in range(1, len(channels))
    ])

In [13]:
# FNet maps F_opt to F_BB
class FNet(torch.nn.Module):
    def __init__(self):
        super(FNet, self).__init__()
        self.out_dim = N_RF*N_s
        self.mlp = MLP([N_t * N_s * 2, 100, 100])
        self.mlp = Seq(*[self.mlp,Seq(Lin(100, 2*N_RF*N_s))])

    def forward(self, data):
        bs = data.shape[0]
        out = self.mlp(data)
        out_real = torch.unsqueeze(out[:, :self.out_dim], axis = -1)
        out_imag = torch.unsqueeze(out[:, self.out_dim:self.out_dim*2], axis = -1)
        out = torch.cat((out_real, out_imag), axis = -1)
        out = torch.view_as_complex(out)
        return out.reshape((bs,N_RF,N_s))

In [14]:
# WNet maps W_opt to W_BB
class WNet(torch.nn.Module):
    def __init__(self):
        super(WNet, self).__init__()
        self.out_dim = N_RF*N_s
        self.mlp = MLP([N_r * N_s * 2, 100, 100])
        self.mlp = Seq(*[self.mlp,Seq(Lin(100, 2*N_RF*N_s))])

    def forward(self, data):
        bs = data.shape[0]
        out = self.mlp(data)
        out_real = torch.unsqueeze(out[:, :self.out_dim], axis = -1)
        out_imag = torch.unsqueeze(out[:, self.out_dim:self.out_dim*2], axis = -1)
        out = torch.cat((out_real, out_imag), axis = -1)
        out = torch.view_as_complex(out)
        return out.reshape((bs,N_RF,N_s))

## Train and Test

In [15]:
def train(epoch):
    """ Train for one epoch. """
    Fmodel.train()
    Wmodel.train()
    loss_all = 0
    for batch_idx, (F_train, F_opt_train, W_train, W_opt_train) in enumerate(train_loader):
        #data = data.to(device)
        optimizer.zero_grad()
        Foutput = Fmodel(F_train)
        Woutput = Wmodel(W_train)
        loss = FMF_loss(Foutput, F_opt_train) + WMF_loss(Woutput, W_opt_train)
        loss.backward()
        loss_all += loss.item() * len(F_train)
        optimizer.step()
    return loss_all / len(train_loader.dataset)

In [16]:
def test(loader):
    Fmodel.eval()
    Wmodel.eval()
    correct = 0
    for (F_test, F_opt_test, W_test, W_opt_test) in loader:
        #data = data.to(device)
        Foutput = Fmodel(F_test)
        Woutput = Wmodel(W_test)
        loss = FMF_loss(Foutput, F_opt_test) + WMF_loss(Woutput, W_opt_test)
        correct += loss.item() * len(F_test)
    return correct / len(loader.dataset)

In [17]:
Fmodel = FNet()
Wmodel = WNet()
optimizer = torch.optim.Adam(list(Fmodel.parameters()) + list(Wmodel.parameters()), lr=1e-3)

In [18]:
record = []
for epoch in range(0, 200):
    if(epoch % 10 == 0):
        train_rate = test(train_loader)
        test_rate = test(test_loader)
        print('Epoch {:03d}, Train Rate: {:.4f}, Test Rate: {:.4f}'.format(
            epoch, train_rate, test_rate))
        record.append([train_rate, test_rate])
    train(epoch)

Epoch 000, Train Rate: 3.0085, Test Rate: 3.0117
Epoch 010, Train Rate: 1.6671, Test Rate: 1.8445
Epoch 020, Train Rate: 1.6114, Test Rate: 1.8241
Epoch 030, Train Rate: 1.5894, Test Rate: 1.8211
Epoch 040, Train Rate: 1.5748, Test Rate: 1.8191
Epoch 050, Train Rate: 1.5662, Test Rate: 1.8200
Epoch 060, Train Rate: 1.5588, Test Rate: 1.8196
Epoch 070, Train Rate: 1.5531, Test Rate: 1.8195
Epoch 080, Train Rate: 1.5493, Test Rate: 1.8210
Epoch 090, Train Rate: 1.5451, Test Rate: 1.8217
Epoch 100, Train Rate: 1.5420, Test Rate: 1.8210
Epoch 110, Train Rate: 1.5410, Test Rate: 1.8230
Epoch 120, Train Rate: 1.5372, Test Rate: 1.8216
Epoch 130, Train Rate: 1.5358, Test Rate: 1.8244
Epoch 140, Train Rate: 1.5334, Test Rate: 1.8242
Epoch 150, Train Rate: 1.5318, Test Rate: 1.8217
Epoch 160, Train Rate: 1.5315, Test Rate: 1.8243
Epoch 170, Train Rate: 1.5284, Test Rate: 1.8253
Epoch 180, Train Rate: 1.5271, Test Rate: 1.8251
Epoch 190, Train Rate: 1.5263, Test Rate: 1.8275


## Compute spectral efficiency
The rate computation function is from [1] https://github.com/yuxianghao/Alternating-minimization-algorithms-for-hybrid-precoding-in-millimeter-wave-MIMO-systems/blob/Initial/Narrowband/SDR-AltMin/main_SNR.m

In [19]:
def FBB2FRF(F_BB, F_opt):
    F_BB = F_BB/torch.norm(F_BB, p = 'fro', dim = [1,2], keepdim = True) * math.sqrt(N_RF * N_s)
    F_RF = F_opt @ F_BB.conj().transpose(1,2)
    F_RF = F_RF/torch.abs(F_RF)
    F_RF = Fmask * F_RF / math.sqrt(N_t)
    return F_BB, F_RF

def WBB2WRF(W_BB, W_opt):
    W_BB = W_BB/torch.norm(W_BB, p = 'fro', dim = [1,2], keepdim = True) * math.sqrt(N_RF * N_s)
    W_RF = W_opt @ W_BB.conj().transpose(1,2)
    W_RF = W_RF/torch.abs(W_RF)
    W_RF = Wmask * W_RF / math.sqrt(N_r)
    return W_BB, W_RF

def compute_rate(FBB, FRF, WBB, WRF, H, SNR):
    '''Matlab code: log2(det(eye(Ns) + SNR(s)/Ns * pinv(WRF * WBB) * H(:,:,reali) * FRF * FBB * FBB' * FRF' * H(:,:,reali)' * WRF * WBB))
    '''
    rate = torch.log2(torch.det(torch.eye(N_s) + SNR/N_s * torch.linalg.pinv(WRF @ WBB) @ H @ FRF @ FBB @ FBB.conj().transpose(1,2)
                                         @ FRF.conj().transpose(1,2) @ H.conj().transpose(1,2) @ WRF @ WBB))
    return float(torch.mean(rate).detach().numpy().real)

def rate_test(loader, H):
    Fmodel.eval()
    Wmodel.eval()
    correct = 0
    with torch.no_grad():
        for (F_test, F_opt, W_test, W_opt) in loader:
            #data = data.to(device)
            FBB = Fmodel(F_test)
            FBB, FRF = FBB2FRF(FBB, F_opt)
            WBB = Wmodel(W_test)
            WBB, WRF = WBB2WRF(WBB, W_opt)
            
            print('MF loss:', WMF_loss(WBB, W_opt) + FMF_loss(FBB, F_opt))

        SNR_dBs = np.arange(-15, 15, 5)
        res_mlp = []
        res_opt = []
        res_ran = []
        for SNR_dB in SNR_dBs:
            SNR = 10**(SNR_dB/10)
            res_mlp.append(compute_rate(FBB, FRF, WBB, WRF, H, SNR))
    return res_mlp

In [20]:
test_data = scipy.io.loadmat('hb_test_144_36.mat')
H = torch.tensor(test_data['H'].transpose(2,0,1), dtype = torch.cfloat)
rate_test(test_loader, H)

MF loss: tensor(1.8258)


[5.748510360717773,
 8.554444313049316,
 11.645604133605957,
 14.876102447509766,
 18.164989471435547,
 21.475706100463867]

## References
[1] X. Yu, J.-C. Shen, J. Zhang, and K. B. Letaief, “Alternating minimization algorithms for hybrid precoding in millimeter wave mimo systems,” IEEE J. Sel. Topics Signal Process., vol. 10, no. 3, pp. 485–500, 2016