In [1]:
import numpy as np
import torch
import helper_functions
from FPLinQ import FP_optimize, FP
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from dgl.data import DGLDataset
import dgl
import dgl.function as fn
import torch.nn as nn
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Sigmoid, BatchNorm1d as BN, ReLU6 as ReLU6

Using backend: pytorch


## Create Dataset

#### Define system parameters

In [2]:
var_noise = 0.1
c = 1/np.sqrt(2)
train_K = 20
test_K = 20
train_layouts = 20000
test_layouts = 1000

#### Generate Channel
We follow [1] to generate channel matrices

In [3]:
train_channel_losses = np.abs(c * np.random.randn(train_layouts, train_K, train_K) + c * 1j * np.random.randn(train_layouts, train_K, train_K))
test_channel_losses = np.abs(c * np.random.randn(test_layouts, test_K, test_K) + c * 1j * np.random.randn(test_layouts, test_K, test_K))

#### Compute the label for training and test dataset via FPLinQ
The code for FPlinQ is copied from [2] https://github.com/willtop/Spatial_Deep_Learning_for_Wireless_Scheduling

In [4]:
# For training dataset, we do not generate labels
## Get direct channel CSI, i.e., diag part of channel matrix
direct_train = helper_functions.get_directLink_channel_losses(train_channel_losses)
## Get interference channel CSI, i.e., off-diag part of channel matrix
cross_train = helper_functions.get_crossLink_channel_losses(train_channel_losses)

In [5]:
## To provide a test baseline, we run FPlinQ with 100 different initializations, and take the highest sum rate
direct_test = helper_functions.get_directLink_channel_losses(test_channel_losses)
cross_test = helper_functions.get_crossLink_channel_losses(test_channel_losses)

rates_all = []
for i in range(100):
    init_x = np.random.rand(test_layouts, test_K,1)
    Y = FP(np.ones([test_layouts, test_K]), test_channel_losses, var_noise, init_x)
    rates = np.expand_dims(helper_functions.compute_rates(var_noise, 
                Y, direct_test, cross_test), axis = 0)
    rates_all.append(rates)
    
rates_all = np.concatenate(rates_all)
sr = np.mean(np.sum(rates[0,:,:],axis=1))
sr_max = np.mean(np.max(np.sum(rates_all, axis = -1), axis = 0))
y_test = Y
print('Sum rate by FPlinQ:', sr)
print('Sum rate by Best FPlinQ:', sr_max)

Sum rate by FPlinQ: 4.280697775775529
Sum rate by Best FPlinQ: 4.634149557657047


#### Create DGL Dataset
Please refer to https://docs.dgl.ai/guide/data.html for a tutorial for the usage of DGL dataset

In [6]:
class PCDataset(DGLDataset):
    def __init__(self, csi, direct, cross):
        self.data = csi
        
        self.direct = torch.tensor(direct, dtype = torch.float)
        self.cross = torch.tensor(cross, dtype = torch.float)
        self.get_cg()
        super().__init__(name='power_control')
    
    def build_graph(self, idx):
        H = self.data[idx,:,:]
        
        graph = dgl.graph(self.adj, num_nodes=train_K)
        
        node_features = torch.tensor(np.expand_dims(np.diag(H),axis=1), dtype = torch.float)
        node_features = torch.cat([node_features, torch.ones_like(node_features)], axis = 1)
        ## Node feature of the k-th node is the direct link channel of k-th pair
        
        edge_features  = []
        for e in self.adj:
            edge_features.append([H[e[0],e[1]],H[e[1],e[0]],1])
        ## Edge feature between node e[0] and e[1] is the interference channel between e[0]-th pair and e[1]-th pair
        
        graph.ndata['feat'] = node_features
        graph.edata['feat'] = torch.tensor(edge_features, dtype = torch.float)
        return graph
    
    def get_cg(self):
        ## The graph is a complete graph
        self.adj = []
        for i in range(0,train_K):
            for j in range(0,train_K):
                if(not(i==j)):
                    self.adj.append([i,j])
            
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.data)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        return self.graph_list[index], self.direct[index], self.cross[index]

    def process(self):
        n = len(self.data)
        self.graph_list = []
        for i in range(n):
            graph = self.build_graph(i)
            self.graph_list.append(graph)

In [7]:
# Please refer to https://docs.dgl.ai/en/0.2.x/tutorials/basics/4_batch.html for details of collate
def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, direct, cross = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.stack(direct), torch.stack(cross)

In [8]:
train_data = PCDataset(train_channel_losses, direct_train, cross_train)
test_data = PCDataset(test_channel_losses, direct_test, cross_test)

In [9]:
batch_size = 64
train_loader = DataLoader(train_data, batch_size, shuffle=True, collate_fn=collate)
test_loader = DataLoader(test_data, test_layouts, shuffle=False, collate_fn=collate)

## Build Graph Neural Networks

#### Define loss function
Rewrite compute_rates in helper_functions.py via Pytorch functions and take an negative sign

In [10]:
def rate_loss(allocs, directlink_channel_losses, crosslink_channel_losses):
    SINRs_numerators = allocs * directlink_channel_losses
    SINRs_denominators = torch.squeeze(torch.matmul(crosslink_channel_losses, torch.unsqueeze(allocs, axis=-1))) + var_noise
    SINRs = SINRs_numerators / SINRs_denominators
    rates = torch.log2(1 + SINRs)
    return -torch.mean(torch.sum(rates, axis = 1))

#### Message Passing Modules 
Please refer to https://docs.dgl.ai/guide/message-api.html for the usage of DGL message-passing

In [11]:
def MLP(channels, batch_norm=True):
    return Seq(*[
        Seq(Lin(channels[i - 1], channels[i]), ReLU(), BN(channels[i]))
        for i in range(1, len(channels))
    ])
class EdgeConv(nn.Module):
    def __init__(self, mlp, **kwargs):
        super(EdgeConv, self).__init__()
        self.mlp = mlp
        #self.reset_parameters()

    def concat_message_function(self, edges):
        return {'out': torch.cat([edges.src['hid'], edges.dst['hid'], edges.data['feat']], axis=1)}
    
    def forward(self, g):
        g.apply_edges(self.concat_message_function)
        g.edata['out'] = self.mlp(g.edata['out'])
        g.update_all(fn.copy_edge('out', 'm'),
                     fn.mean('m', 'hid'))

#### GNN Modules 
Please refer to https://docs.dgl.ai/guide/nn-construction.html#guide-nn-construction for usage of DGL GNN

In [12]:
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = EdgeConv(MLP([7, 16]))
        self.conv2 = EdgeConv(MLP([2*16+3, 32]))
        self.mlp = MLP([32, 16])
        self.mlp = Seq(*[self.mlp, Seq(Lin(16, 1), Sigmoid())])

    def forward(self, g):
        g.ndata['hid'] = g.ndata['feat']
        self.conv1(g)
        self.conv2(g)
        out = self.mlp(g.ndata['hid'])
        return out

In [13]:
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

## Training and Test
The training is similar to the node regression task, please refer to https://docs.dgl.ai/en/0.6.x/guide/training-node.html for DGL node regression training

In [14]:
def train(epoch):
    """ Train for one epoch. """
    model.train()
    loss_all = 0
    for batch_idx, (g, d_train, c_train) in enumerate(train_loader):
        #data = data.to(device)
        n = len(g.ndata['feat'])
        bs = len(g.ndata['feat'])//train_K
        
        optimizer.zero_grad()
        output = model(g).reshape(bs,-1)
        loss = rate_loss(output, d_train, c_train)
        loss.backward()
        
        loss_all += loss.item() * bs
        optimizer.step()
    return loss_all / len(train_loader.dataset)

In [15]:
def test(loader):
    model.eval()
    correct = 0
    for (g, d_test, c_test) in loader:
        n = len(g.ndata['feat'])
        bs = len(g.ndata['feat'])//train_K
        #data = data.to(device)
        output = model(g).reshape(bs,-1)
        loss = rate_loss(output, d_test, c_test)
        correct += loss.item() * bs
    return correct / len(loader.dataset)

In [16]:
record = []
for epoch in range(0, 50):
    if(epoch % 5 == 0):
        with torch.no_grad():
            train_rate = test(train_loader)
            test_rate = test(test_loader)
        print('Epoch {:03d}, Train Rate: {:.4f}, Test Rate: {:.4f}'.format(
            epoch, train_rate, test_rate))
        record.append([train_rate, test_rate])
    train(epoch)

Epoch 000, Train Rate: -1.4705, Test Rate: -1.4736
Epoch 005, Train Rate: -4.3329, Test Rate: -4.2711
Epoch 010, Train Rate: -4.4226, Test Rate: -4.3682
Epoch 015, Train Rate: -4.4590, Test Rate: -4.4254
Epoch 020, Train Rate: -4.4616, Test Rate: -4.4187
Epoch 025, Train Rate: -4.4980, Test Rate: -4.4837
Epoch 030, Train Rate: -4.4881, Test Rate: -4.4580
Epoch 035, Train Rate: -4.4885, Test Rate: -4.4547
Epoch 040, Train Rate: -4.4120, Test Rate: -4.3869
Epoch 045, Train Rate: -4.4727, Test Rate: -4.3986


## References
[1] H. Sun, X. Chen, Q. Shi, M. Hong, X. Fu, and N. D. Sidiropoulos, “Learning to optimize: Training deep neural networks for interference management,” IEEE Trans. Signal Process., vol. 66, pp. 5438 – 5453, Oct. 2018.
[2]  W. Cui, K. Shen, and W. Yu, “Spatial deep learning for wireless scheduling,” IEEE J. Sel. Areas Commun., vol. 37, Jun. 2019.