In [1]:
import numpy as np
import torch
import helper_functions
from FPLinQ import FP_optimize, FP
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from dgl.data import DGLDataset
import dgl
import dgl.function as fn
import torch.nn as nn
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Sigmoid, BatchNorm1d as BN, ReLU6 as ReLU6

Using backend: pytorch


## Create Dataset

#### Define system parameters

In [2]:
## System parameters
train_K = 20            # transmitter/receiver pairs (graph nodes) per training layout
test_K = 20             # pairs per test layout
train_layouts = 20000   # number of training layouts (graphs)
test_layouts = 2000     # number of test layouts
var_noise = 1           # noise variance passed to FPLinQ and compute_rates
c = 1/np.sqrt(2)        # per-component scale so that c*(x + 1j*y), x, y ~ N(0, 1), has unit variance

## Fix the seeds so channel draws, model initialisation and data shuffling are
## reproducible under Restart & Run All.
SEED = 2021
np.random.seed(SEED)
torch.manual_seed(SEED)

#### Generate Channel
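Following [1], each channel coefficient is drawn i.i.d. as $h \sim \mathcal{CN}(0, 1)$, and the channel matrices store its magnitude $|h|$ (Rayleigh fading).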

In [23]:
def _rayleigh_channel_losses(num_layouts, num_pairs):
    """Return |h| for h = c*x + c*1j*y with x, y i.i.d. standard normal.

    The result has shape (num_layouts, num_pairs, num_pairs). The real part is drawn
    before the imaginary part, matching the original draw order.
    """
    shape = (num_layouts, num_pairs, num_pairs)
    real_part = c * np.random.randn(*shape)
    imag_part = c * 1j * np.random.randn(*shape)
    return np.abs(real_part + imag_part)

train_channel_losses = _rayleigh_channel_losses(train_layouts, train_K)
test_channel_losses = _rayleigh_channel_losses(test_layouts, test_K)

#### Compute the labels for the training and test datasets via FPLinQ
The FPLinQ implementation is copied from [2]: https://github.com/willtop/Spatial_Deep_Learning_for_Wireless_Scheduling

In [4]:
## FPLinQ labels (target power per link) for the training set.
## The CSI arrays use `train_`-prefixed names so that re-running this cell never
## overwrites `directLink_channel_losses` / `crossLink_channel_losses`. `test()`
## reads those two globals, and they must hold the *test* CSI.
## Direct channel CSI, i.e., the diagonal part of each channel matrix
train_directLink_channel_losses = helper_functions.get_directLink_channel_losses(train_channel_losses)
## Interference channel CSI, i.e., the off-diagonal part of each channel matrix
train_crossLink_channel_losses = helper_functions.get_crossLink_channel_losses(train_channel_losses)
## NOTE(review): the two all-ones arrays are presumably per-link weights and the
## initial allocation expected by FPLinQ.FP — confirm against its signature.
y_train = FP(np.ones([train_layouts, train_K]), train_channel_losses, var_noise,
             np.ones([train_layouts, train_K, 1]))
train_rates = helper_functions.compute_rates(var_noise, y_train,
                                             train_directLink_channel_losses,
                                             train_crossLink_channel_losses)
print('Train sum rate by FPlinQ:', np.mean(np.sum(train_rates, axis=1)))

In [5]:
## FPLinQ labels for the test set, plus the FPLinQ sum rate as a reference score.
## NOTE: after this cell `directLink_channel_losses` / `crossLink_channel_losses`
## hold the *test* CSI; `test()` reads these globals to score its schedules.
directLink_channel_losses, crossLink_channel_losses = (
    helper_functions.get_directLink_channel_losses(test_channel_losses),
    helper_functions.get_crossLink_channel_losses(test_channel_losses),
)
y_test = Y = FP(np.ones([test_layouts, test_K]), test_channel_losses, var_noise,
                np.ones([test_layouts, test_K, 1]))
rates = helper_functions.compute_rates(var_noise, Y, directLink_channel_losses,
                                       crossLink_channel_losses)
sr = np.mean(np.sum(rates, axis=1))
print('Sum rate by FPlinQ:', sr)

Sum rate by FPlinQ: 2.355823374907156


#### Create DGL Dataset
See https://docs.dgl.ai/guide/data.html for a tutorial on DGL datasets.

In [6]:
class PCDataset(DGLDataset):
    """DGL dataset of complete interference graphs for power control.

    One graph is built per layout ``idx`` (with ``H = csi[idx]``):
      * node k: ``ndata['feat'] = [H[k, k]]`` (direct link of pair k) and
        ``ndata['label']`` = FPLinQ power of pair k;
      * directed edge (i, j), i != j: ``edata['feat'] = [H[i, j], H[j, i]]``
        (interference between pairs i and j in both directions).

    Args:
        csi: channel array indexed as ``csi[idx, i, j]``, shape (layouts, K, K);
            the number of nodes K is read from ``csi.shape[1]``.
        label: FPLinQ labels, assumed to be (layouts, K) — TODO confirm FP's output shape.
    """
    def __init__(self, csi, label):
        self.data = csi
        self.label = np.expand_dims(label, axis = -1)
        # Nodes per graph come from the data, not the global `train_K`, so a
        # dataset with a different K (e.g. test_K) gets the right graph size.
        self.num_pairs = csi.shape[1]
        self.get_cg()
        # DGLDataset.__init__ runs process(), so everything process() needs must
        # be set before this call.
        super().__init__(name='power_control')

    def build_graph(self, idx):
        """Return the DGL graph for layout ``idx``."""
        H = self.data[idx, :, :]

        graph = dgl.graph((torch.tensor(self.src), torch.tensor(self.dst)),
                          num_nodes=self.num_pairs)

        ## Node feature of the k-th node is the direct link channel of the k-th pair
        node_features = torch.tensor(np.expand_dims(np.diag(H), axis=1), dtype = torch.float)
        ## Node label is the power obtained by FPLinQ
        node_labels = torch.tensor(self.label[idx, :, :], dtype = torch.float)
        ## Edge feature of edge (i, j) is the interference channel between the i-th
        ## and j-th pairs in both directions, gathered for all edges at once in
        ## the same order as self.adj (i.e. the graph's edge IDs).
        edge_features = np.stack([H[self.src, self.dst], H[self.dst, self.src]], axis=1)

        graph.ndata['feat'] = node_features
        graph.ndata['label'] = node_labels
        graph.edata['feat'] = torch.tensor(edge_features, dtype = torch.float)
        return graph

    def get_cg(self):
        """Build the edge list of the complete directed graph without self-loops.

        Sets ``self.adj`` (list of ``[i, j]`` pairs in row-major order) and the
        matching ``self.src`` / ``self.dst`` integer index arrays.
        """
        num_pairs = self.num_pairs
        self.adj = [[i, j] for i in range(num_pairs) for j in range(num_pairs) if i != j]
        # reshape(-1, 2) keeps a (0, 2) shape when there are no edges (K == 1).
        edges = np.array(self.adj, dtype=np.int64).reshape(-1, 2)
        self.src, self.dst = edges[:, 0], edges[:, 1]

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.data)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Graphs are prebuilt in process(); just look the sample up.
        return self.graph_list[index]

    def process(self):
        """Build one graph per layout and cache them in ``self.graph_list``."""
        self.graph_list = [self.build_graph(i) for i in range(len(self.data))]

In [7]:
# Batching DGL graphs: https://docs.dgl.ai/en/0.2.x/tutorials/basics/4_batch.html
def collate(samples):
    '''Merge a list of DGL graphs into one batched graph (DGL collate function).'''
    return dgl.batch(samples)

In [8]:
## One graph per layout; PCDataset builds all graphs when it is constructed.
train_data, test_data = (PCDataset(train_channel_losses, y_train),
                         PCDataset(test_channel_losses, y_test))

In [9]:
batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate)
## The whole test set is one unshuffled batch: test() scores its predictions
## against the global test CSI arrays, which assumes all test layouts in order.
test_loader = DataLoader(test_data, batch_size=test_layouts, shuffle=False, collate_fn=collate)

## Build Graph Neural Networks

#### Message Passing Modules 
See https://docs.dgl.ai/guide/message-api.html for how message passing is expressed in DGL.

In [10]:
def MLP(channels, batch_norm=True):
    """Stack of Linear -> ReLU (-> BatchNorm1d) layers.

    Args:
        channels: layer widths ``[in, hidden_1, ..., out]``; one Linear layer is
            created per consecutive pair.
        batch_norm: if True (default), append a BatchNorm1d after each ReLU;
            if False, omit it.

    Returns:
        ``nn.Sequential`` holding ``len(channels) - 1`` sub-Sequentials.
    """
    layers = []
    for width_in, width_out in zip(channels[:-1], channels[1:]):
        block = [Lin(width_in, width_out), ReLU()]
        if batch_norm:
            block.append(BN(width_out))
        layers.append(Seq(*block))
    return Seq(*layers)
class EdgeConv(nn.Module):
    """Message-passing layer with per-edge MLP messages and max aggregation.

    For each directed edge u -> v the message is ``mlp([h_u, h_v, e_uv])``, and
    each node's new ``'hid'`` is the element-wise max over its incoming messages.
    ``forward`` reads ``g.ndata['hid']`` and ``g.edata['feat']`` and writes
    ``g.edata['out']`` and ``g.ndata['hid']`` in place; it returns nothing.

    Args:
        mlp: module mapping the concatenated (2 * hid_dim + edge_dim) features to
            the output width.
        **kwargs: accepted and ignored.
    """
    def __init__(self, mlp, **kwargs):
        super(EdgeConv, self).__init__()
        self.mlp = mlp

    def concat_message_function(self, edges):
        """Concatenate source state, destination state and edge feature for each edge."""
        return {'out': torch.cat([edges.src['hid'], edges.dst['hid'], edges.data['feat']], dim=1)}

    def forward(self, g):
        g.apply_edges(self.concat_message_function)
        g.edata['out'] = self.mlp(g.edata['out'])
        # copy_e replaces the deprecated copy_edge alias (no longer provided by
        # newer DGL releases); both copy the edge feature into the message.
        g.update_all(fn.copy_e('out', 'm'),
                     fn.max('m', 'hid'))

#### GNN Modules 
See https://docs.dgl.ai/guide/nn-construction.html#guide-nn-construction for details on building GNN modules in DGL.

In [11]:
class GCN(torch.nn.Module):
    """Two EdgeConv layers followed by a per-node MLP head; outputs powers in [0, 1].

    Input widths follow the dataset: node features have size 1 and edge features
    size 2. conv1 therefore sees 1 + 1 + 2 = 4 inputs and conv2 sees 16 + 16 + 2 = 34.
    """
    def __init__(self):
        super(GCN, self).__init__()
        # Modules are constructed in a fixed order (conv1, conv2, head, output
        # layer), which fixes the sequence of random weight initialisations.
        self.conv1 = EdgeConv(MLP([4, 16, 16]))
        self.conv2 = EdgeConv(MLP([2*16+2, 32]))
        hidden_head = MLP([32, 16])
        output_layer = Seq(Lin(16, 1), ReLU6())
        self.mlp = Seq(hidden_head, output_layer)

    def forward(self, g):
        """Run message passing on ``g`` (overwrites its 'hid' node data); return (num_nodes, 1) powers."""
        g.ndata['hid'] = g.ndata['feat']
        for conv in (self.conv1, self.conv2):
            conv(g)
        # ReLU6 output lies in [0, 6]; dividing by 6 maps it to [0, 1].
        return self.mlp(g.ndata['hid']) / 6

In [12]:
## Model and Adam optimiser with learning rate 1e-3.
model = GCN()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

## Training and Test
Training is set up like a node regression task; see https://docs.dgl.ai/en/0.6.x/guide/training-node.html for how DGL trains node regression models.

In [13]:
def train(epoch):
    """Train the global ``model`` for one epoch over ``train_loader``.

    Args:
        epoch: current epoch index (unused; kept for the caller's signature).

    Returns:
        Per-node MSE averaged over the epoch's graphs, on the same scale as the
        MSE returned by ``test()``.
    """
    model.train()
    total_loss = 0.0
    for g in train_loader:
        optimizer.zero_grad()
        output = model(g)
        loss = F.mse_loss(output, g.ndata['label'])
        loss.backward()
        optimizer.step()
        # mse_loss already averages over all nodes in the batch. Weight it by the
        # number of graphs, because the total is divided by the number of graphs.
        total_loss += loss.item() * g.batch_size
    return total_loss / len(train_loader.dataset)

In [21]:
def test(loader, test_mode = False):
    """Evaluate the global ``model`` on every batch of ``loader``.

    Args:
        loader: DataLoader yielding batched DGL graphs (built with ``collate``).
        test_mode: if True, also compute NMSE and the sum rate of the rounded
            on/off schedule. The sum rate is scored against the global
            ``directLink_channel_losses`` / ``crossLink_channel_losses``, sliced by
            batch position, so ``loader`` must yield those layouts in order
            (shuffle=False).

    Returns:
        Per-node MSE averaged over graphs. With ``test_mode``, returns the tuple
        ``(mse, nmse, sum_rate)``, each averaged over graphs.
    """
    model.eval()
    mse = nmse = sr = 0
    offset = 0  # index of the first layout in the current batch
    with torch.no_grad():  # evaluation only; no autograd graph is needed
        for g in loader:
            bs = g.batch_size
            output = model(g).reshape(bs, -1)
            y_true = g.ndata['label'].reshape(bs, -1)
            mse += F.mse_loss(output, y_true).item() * bs
            if test_mode:
                nmse += (((output - y_true)**2).sum(dim=-1) / (y_true**2).sum(dim=-1)).sum().item()
                # Round the relaxed power to a binary on/off schedule; every
                # output, including exactly 0.5, maps to 0 or 1.
                schedule = (output > 0.5).float()
                rates = helper_functions.compute_rates(var_noise, schedule.numpy(),
                        directLink_channel_losses[offset:offset + bs],
                        crossLink_channel_losses[offset:offset + bs])
                sr += np.mean(np.sum(rates, axis=1)) * bs
            offset += bs
    if test_mode:
        return mse / len(loader.dataset), nmse / len(loader.dataset), sr / len(loader.dataset)
    return mse / len(loader.dataset)

In [22]:
## Evaluate before every epoch (epoch 000 is the untrained baseline), and once
## more after the last epoch so the fully trained model is also measured.
## "Train Loss" is the eval-mode MSE on the training set from test(), not the
## running loss returned by train().
num_epochs = 20
record = []
for epoch in range(num_epochs + 1):
    train_mse = test(train_loader)
    mse, nmse, rate = test(test_loader, True)
    print('Epoch {:03d}, Train Loss: {:.4f}, Val MSE: {:.4f}, Val NMSE: {:.4f}, Val Rate: {:.4f}'.format(
        epoch, train_mse, mse, nmse, rate))
    record.append([train_mse, mse, nmse, rate])
    if epoch < num_epochs:
        train(epoch)

Epoch 000, Train Loss: 0.0475, Val MSE: 0.0478, Val NMSE: 0.1982, Val Rate: 2.3112
Epoch 001, Train Loss: 0.0502, Val MSE: 0.0502, Val NMSE: 0.2018, Val Rate: 2.3204
Epoch 002, Train Loss: 0.0486, Val MSE: 0.0486, Val NMSE: 0.1965, Val Rate: 2.3229
Epoch 003, Train Loss: 0.0472, Val MSE: 0.0475, Val NMSE: 0.1949, Val Rate: 2.3166
Epoch 004, Train Loss: 0.0472, Val MSE: 0.0474, Val NMSE: 0.1947, Val Rate: 2.3163
Epoch 005, Train Loss: 0.0473, Val MSE: 0.0478, Val NMSE: 0.2002, Val Rate: 2.3079
Epoch 006, Train Loss: 0.0501, Val MSE: 0.0510, Val NMSE: 0.2179, Val Rate: 2.2860
Epoch 007, Train Loss: 0.0467, Val MSE: 0.0471, Val NMSE: 0.1961, Val Rate: 2.3107
Epoch 008, Train Loss: 0.0474, Val MSE: 0.0479, Val NMSE: 0.2027, Val Rate: 2.3020
Epoch 009, Train Loss: 0.0480, Val MSE: 0.0486, Val NMSE: 0.2048, Val Rate: 2.2997
Epoch 010, Train Loss: 0.0468, Val MSE: 0.0470, Val NMSE: 0.1924, Val Rate: 2.3180
Epoch 011, Train Loss: 0.0465, Val MSE: 0.0469, Val NMSE: 0.1939, Val Rate: 2.3130
Epoc

## References
[1] H. Sun, X. Chen, Q. Shi, M. Hong, X. Fu, and N. D. Sidiropoulos, “Learning to optimize: Training deep neural networks for interference management,” IEEE Trans. Signal Process., vol. 66, pp. 5438 – 5453, Oct. 2018.
[2]  W. Cui, K. Shen, and W. Yu, “Spatial deep learning for wireless scheduling,” IEEE J. Sel. Areas Commun., vol. 37, Jun. 2019.