In [1]:
import time
from itertools import product

import altair as alt
import numpy as np
import pandas as pd
import torch
from numba import njit, bool_
from prettytable import PrettyTable
from torch_geometric.loader import DataLoader

from loss import create_loss_fn
from model import FlexNet
from utils.algorithms import get_directional_sum_rate, wmmse
from utils.data import flex_graph, gen_rectangular_channel_matrix

In [2]:
# Data generation functions
def generate_data(n, k, seed=899):
    """Generate `n` channel-matrix samples for a network of `k` nodes.

    Thin wrapper around `gen_rectangular_channel_matrix(k, k, n, seed=seed)`.

    Args:
        n: Number of channel realizations (samples) to draw.
        k: Number of nodes; used for both matrix dimensions.
        seed: Random seed forwarded to the generator. Defaults to 899, the value
            previously hard-coded, so existing results stay reproducible.

    Returns:
        Whatever `gen_rectangular_channel_matrix` returns. The rest of this notebook
        indexes it as a 3-D array with samples on axis 0 and nodes on the last axis
        (assumed shape (n, k, k) — TODO confirm in utils.data).
    """
    return gen_rectangular_channel_matrix(k, k, n, seed=seed)


# GNN evaluation
def load_model(path):
    """Return a FlexNet with the state dict stored at `path` loaded onto the CPU.

    The model is switched to evaluation mode before being returned.
    """
    net = FlexNet()
    cpu = torch.device('cpu')
    state_dict = torch.load(path, map_location=cpu)
    net.load_state_dict(state_dict)
    net.eval()
    return net


def eval_model(model, data):
    """Run `model` on every channel matrix in `data` as one batch and binarize the outputs.

    Args:
        model: A FlexNet, normally in eval mode (see `load_model`).
        data: Channel matrices; converted to graphs with `flex_graph`.

    Returns:
        (p, t, batch): `p` and `t` are the model's two outputs thresholded at 0.5 into
        tensors of 0./1. values; `batch` is the collated graph batch that was fed to
        the model.
    """
    graphs = flex_graph(data)
    # One batch holding every sample. `.size` assumes flex_graph returns a numpy
    # array of graphs — TODO confirm in utils.data.
    loader = DataLoader(graphs, batch_size=graphs.size, shuffle=False)
    # next(iter(...)) rather than iter(...).next(): the `.next()` alias on DataLoader
    # iterators was a Python 2 leftover and is not available in current PyTorch.
    batch = next(iter(loader))
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        p, t = model(batch)
    t = torch.where(t >= 0.5, 1., 0.)
    p = torch.where(p >= 0.5, 1., 0.)
    return p, t, batch


# Exhaustive search related functions
@njit
def find_best_sum_rate(h, directions):
    """Exhaustive search: best directional sum rate over all candidate direction patterns.

    Each row of `directions` holds one 0/1 entry per node pair. It is expanded to a
    per-node vector with `build_t_vec`. WMMSE then allocates power to the nodes with
    t == 1 (all other nodes get zero power), and the resulting directional sum rate
    is evaluated. Returns the largest rate found, or 0. if `directions` is empty.
    """
    k = np.shape(h)[-1]
    best = 0.
    for direction in directions:
        t = build_t_vec(direction)
        tx_power = wmmse(np.ones(int(k / 2)), get_h(h, t), 1., 1.)
        # Scatter the WMMSE powers, in node order, onto the nodes with t == 1.
        p_hat = np.zeros(k)
        cursor = 0
        for node in range(k):
            if t[node] == 1:
                p_hat[node] = tx_power[cursor]
                cursor += 1
        rate = get_directional_sum_rate(h, p_hat, t, 1., k)
        if rate >= best:
            best = rate
    return best


@njit
def get_h(h, t):
    """Submatrix of `h`: rows where `t` is nonzero, columns where `1 - t` is nonzero.

    For a 0/1 vector `t` this keeps the rows of nodes with t == 1 and the columns of
    nodes with t == 0.
    """
    col_mask = (1 - t).astype(bool_)
    row_mask = t.astype(bool_)
    return h[:, col_mask][row_mask, :]


def get_direction_combinations(k):
    """Enumerate every binary direction pattern of length `k`.

    Returns a float32 array of shape (2**k, k) whose rows are all combinations of
    0. and 1. in itertools.product order (last column varies fastest).
    """
    patterns = [combo for combo in product((0., 1.), repeat=k)]
    return np.asarray(patterns, dtype=np.float32)


# Heuristic search related functions
@njit
def find_best_pat(h):
    """Heuristic search: alternate greedy direction search and WMMSE power allocation.

    Starts from unit power on every node. Each iteration (at most 100):
      1. picks a per-node direction vector `t` with `pattern_search2` for the current
         powers;
      2. re-allocates power with WMMSE to the nodes with t == 1 (zero elsewhere);
      3. evaluates the directional sum rate.
    Stops as soon as an iteration improves the rate by no more than 1e-6.

    Returns:
        The best sum rate reached by the last two iterations. Previously the rate of
        the final, non-improving iteration was returned even when it was lower than
        the rate the preceding iteration had already achieved.
    """
    k = np.shape(h)[-1]
    p = np.ones(k)
    prev_rate = 0.
    for _iteration in range(100):
        _, t = pattern_search2(h, p, k)
        p_tx = wmmse(np.ones(int(k / 2)), get_h(h, t), 1., 1.)
        # Scatter the WMMSE powers (one per node with t == 1, in node order) into a
        # length-k vector; nodes with t == 0 get zero power.
        p = np.zeros(k)
        cursor = 0
        for node in range(k):
            if t[node] == 1:
                p[node] = p_tx[cursor]
                cursor += 1
        sum_rate = get_directional_sum_rate(h, p, t, 1., k)
        if sum_rate - prev_rate <= 1e-6:
            # No meaningful improvement: stop, but never report a rate below the one
            # the previous iteration already reached.
            if sum_rate >= prev_rate:
                return sum_rate
            return prev_rate
        prev_rate = sum_rate
    return prev_rate


@njit
def build_t_vec(t):
    """Expand a per-pair direction vector into a per-node vector.

    Each entry d of `t` becomes the consecutive pair (d, 1 - d), so a length-m input
    yields a length-2m output: [t0, 1 - t0, t1, 1 - t1, ...].
    """
    stacked = np.vstack((t, 1 - t))
    interleaved = stacked.T
    return interleaved.ravel()


@njit
def pattern_search2(h, p, k):
    """Greedy search for a per-pair direction pattern that maximizes the sum rate.

    Tries num_of_pairs + 1 starting patterns: all pairs set to 0, plus every pattern
    with exactly one pair set to 1. From each start, it repeatedly sets to 1 the single
    pair whose change gives the highest directional sum rate for the fixed powers `p`.
    It stops when no such change beats the current pattern. Pairs are only ever
    switched from 0 to 1, never back.

    Args:
        h: channel matrix passed to get_directional_sum_rate (assumed k x k — TODO confirm).
        p: per-node power vector; not modified during the search.
        k: number of nodes; num_of_pairs = int(k / 2) (assumes k is even — TODO confirm).

    Returns:
        (best_rate, best_t): the highest sum rate over all starting patterns and the
        matching per-node vector produced by build_t_vec.
    """
    num_of_pairs = int(k / 2)
    # Best rate and per-node pattern reached from each starting pattern.
    max_rate = np.empty(num_of_pairs + 1)
    max_t = np.empty((num_of_pairs + 1, k))
    # Starting patterns: row 0 has every pair at 0; row i + 1 has only pair i set to 1.
    t_pattern = np.zeros((num_of_pairs + 1, num_of_pairs))
    for i in range(0, num_of_pairs):
        t_pattern[i + 1][i] = 1
    for j, pat in enumerate(t_pattern):
        # For start row j >= 1, candidate index j (setting pair j - 1) is the pair the
        # start pattern already has set; for j == 0 this coincides with "no change".
        max_idx = j
        t_pat = pat.copy()
        while True:
            # sum_rates[0] is the current pattern; sum_rates[i + 1] is the current
            # pattern with pair i additionally set to 1.
            sum_rates = np.empty(num_of_pairs + 1)
            sum_rates[0] = get_directional_sum_rate(h, p, build_t_vec(t_pat), 1., k)
            for i in range(num_of_pairs):
                t = t_pat.copy()
                t[i] = 1
                sum_rates[i + 1] = get_directional_sum_rate(h, p, build_t_vec(t), 1., k)
            # np.argmax returns the first maximum, so a tie with the current pattern
            # (index 0) ends the search.
            index = np.argmax(sum_rates)
            # Stop when keeping the current pattern is best, or when the best candidate
            # is the pair set most recently. Re-setting an already-set pair reproduces
            # the current pattern, so with deterministic rates that case should already
            # be caught by index == 0; presumably a safety guard against looping.
            if index == 0 or index == max_idx:
                break
            else:
                t_pat[index - 1] = 1
                max_idx = index
        max_rate[j] = get_directional_sum_rate(h, p, build_t_vec(t_pat), 1., k)
        max_t[j] = build_t_vec(t_pat)
    # Best result over all starting patterns.
    max_rate_idx = np.argmax(max_rate)
    return max_rate[max_rate_idx], max_t[max_rate_idx]


# Max-channel heuristic: each node pair transmits in the direction with the larger channel entry
@njit
def wmmse_with_max_channel(h):
    """Max-channel directions followed by WMMSE power allocation.

    The direction vector `t` comes from `max_channel`. WMMSE is run on `get_h(h, t)`,
    the submatrix of `h` with rows of nodes where t == 1 and columns of nodes where
    t == 0. The resulting powers are then scattered, in node order, onto the nodes
    with t == 1; every other node gets zero power.

    Returns:
        (p_hat, t): the length-k power vector and the direction vector.
    """
    num_nodes = np.shape(h)[-1]
    t = max_channel(h)
    tx_power = wmmse(np.ones(int(num_nodes / 2)), get_h(h, t), 1., 1.)
    p_hat = np.zeros(num_nodes)
    cursor = 0
    for node in range(num_nodes):
        if t[node] == 1:
            p_hat[node] = tx_power[cursor]
            cursor += 1
    return p_hat, t


@njit
def max_channel(h):
    k = np.shape(h)[-1]
    d = np.ones(int(k / 2))
    d_k = np.vstack((d, 1 - d)).T.ravel()
    left_ro_right = np.diag(get_h(h, d_k))
    right_to_left = np.diag(get_h(h, 1 - d_k))
    t = np.zeros((2, int(k / 2)))
    for i in range(int(k / 2)):
        if left_ro_right[i] >= right_to_left[i]:
            t[0, i] = 1.
        else:
            t[1, i] = 1.
    t = t.T.ravel()
    return t


@njit
def find_average_rate(h, algorithm, *args):
    avg_rate = 0.
    for mat in h:
        avg_rate += algorithm(mat, *args)
    return avg_rate / np.shape(h)[0]


@njit
def get_max_directions(h):
    n = np.shape(h)[0]
    k = np.shape(h)[-1]
    max_directions = np.empty((n, k))
    for i in range(n):
        max_directions[i] = max_channel(h[i])
    return max_directions


@njit
def get_max_dir_with_power(h):
    n = np.shape(h)[0]
    k = np.shape(h)[-1]
    max_directions = np.empty((n, k))
    power = np.empty((n, k))
    for i in range(n):
        power[i], max_directions[i] = wmmse_with_max_channel(h[i])
    return power, max_directions


@njit
def calculate_rates_and_average(h, p, t):
    n = np.shape(h)[0]
    k = np.shape(h)[-1]
    sum_rate = 0.
    for i in range(n):
        sum_rate += get_directional_sum_rate(h[i], p[i], t[i], 1., k)
    return sum_rate / n


def time_code(func, *args):
    """Call `func(*args)` once and time it with `time.perf_counter`.

    Returns:
        (elapsed_seconds, result)
    """
    t0 = time.perf_counter()
    out = func(*args)
    elapsed = time.perf_counter() - t0
    return elapsed, out


def format_float(number):
    """Render `number` in fixed-point notation with exactly four decimal places."""
    return format(number, '.4f')

In [3]:
def _record_result(results, table, key, label, elapsed, rate, n):
    """Store (per-sample time, average rate) under `key` and add the matching table row."""
    per_sample_time = elapsed / n
    results[key] = (per_sample_time, rate)
    table.add_row([label, format_float(per_sample_time), format_float(rate)])


def evaluate_algorithms(h):
    """Benchmark every direction/power-allocation method on the channel batch `h`.

    Before each timed run over the full batch, the method is called once on a single
    sample. Presumably this keeps numba JIT compilation and other first-call overhead
    out of the measured time.

    Args:
        h: n = h.shape[0] channel-matrix samples for a network of k = h.shape[-1]
            nodes. Requires a trained model at
            ./experiments/flexible_experiment_{k}_nodes.pth.

    Returns:
        dict mapping 'heuristic', 'exhaustive', 'max_power', 'wmmse_max' and 'gnn' to
        (average time per sample in seconds, average sum rate). The same numbers are
        also printed as a PrettyTable.
    """
    n = h.shape[0]
    k = h.shape[-1]
    results = {}
    table = PrettyTable()
    table.field_names = ['Algorithm', 'Time', 'Performance']
    warmup = h[:1, :, :]

    # Heuristic search: alternating greedy direction search and WMMSE.
    find_average_rate(warmup, find_best_pat)
    elapsed, rate = time_code(find_average_rate, h, find_best_pat)
    _record_result(results, table, 'heuristic', 'Heuristic Search', elapsed, rate, n)

    # Exhaustive search over all 2**(k/2) pair-direction patterns.
    dirs = get_direction_combinations(int(k / 2))
    find_average_rate(warmup, find_best_sum_rate, dirs)
    elapsed, rate = time_code(find_average_rate, h, find_best_sum_rate, dirs)
    _record_result(results, table, 'exhaustive', 'Exhaustive Search', elapsed, rate, n)

    # Max-channel directions with unit power on every node; only the direction
    # selection is timed, not the rate evaluation.
    get_max_directions(warmup)
    elapsed, max_directions = time_code(get_max_directions, h)
    rate = calculate_rates_and_average(h, np.ones((n, k)), max_directions)
    _record_result(results, table, 'max_power', 'Max Power', elapsed, rate, n)

    # Max-channel directions with WMMSE power allocation.
    get_max_dir_with_power(warmup)
    elapsed, (p_wmmse, d_wmmse) = time_code(get_max_dir_with_power, h)
    rate = calculate_rates_and_average(h, p_wmmse, d_wmmse)
    _record_result(results, table, 'wmmse_max', 'WMMSE with max channel', elapsed, rate, n)

    # Flex-Net GNN: the timed call covers graph construction, batching and the forward pass.
    model = load_model(f'./experiments/flexible_experiment_{k}_nodes.pth')
    eval_model(model, warmup)
    elapsed, (p, t, data) = time_code(eval_model, model, h)
    loss_fn = create_loss_fn(k, 1.)
    # The loss is presumably the negative sum rate; negate it back into a rate.
    rate = -loss_fn((p, t), data.y).item()
    _record_result(results, table, 'gnn', 'Flex-Net', elapsed, rate, n)

    print('\n', k, '\033[1m' + '- nodes' + '\033[0m')
    print(table)
    return results


def evaluate_and_save_results(nodes=(4, 8, 12, 16, 20, 24, 28, 32), n_samples=10,
                              path='./experiments/perf_summary.pkl'):
    """Benchmark all algorithms for several network sizes and pickle a summary table.

    Args:
        nodes: Network sizes (number of nodes) to evaluate.
        n_samples: Channel realizations generated per network size.
        path: Destination of the pickled DataFrame; plot_results reads it from the
            same default path.

    Each row of the saved DataFrame holds the node count, followed by
    (per-sample time, average sum rate) for every algorithm in column order.
    """
    columns = [
        'Nodes',
        'GNN Time',
        'GNN Performance',
        'Heuristic Time',
        'Heuristic Performance',
        'Exhaustive Time',
        'Exhaustive Performance',
        'Max Power Time',
        'Max Power Performance',
        'WMMSE with Max Power Time',
        'WMMSE with Max Power Performance',
    ]
    # Result keys in the order their (time, rate) pairs appear in `columns`.
    order = ('gnn', 'heuristic', 'exhaustive', 'max_power', 'wmmse_max')
    rows = []
    for num_nodes in nodes:
        results = evaluate_algorithms(generate_data(n_samples, num_nodes))
        rows.append([num_nodes] + [value for key in order for value in results[key]])
    # Build the frame once. Growing it row by row with .loc enlargement is quadratic
    # and upcast the Nodes column to float.
    pd.DataFrame(rows, columns=columns).to_pickle(path)


def plot_results(path='./experiments/perf_summary.pkl'):
    """Load the pickled benchmark summary, print it, and build two Altair line charts.

    Args:
        path: Pickle written by evaluate_and_save_results.

    Returns:
        (chart_perf, chart_time):
          * chart_perf shows each algorithm's average sum rate as a fraction of the
            best average sum rate among all algorithms at the same node count.
          * chart_time shows average time per sample on a log scale.
    """
    # pd.read_pickle can execute arbitrary code: only load files written by
    # evaluate_and_save_results.
    df = pd.read_pickle(path)
    # Print the full table without permanently changing pandas' global display options.
    with pd.option_context('display.max_columns', None, 'display.max_rows', None):
        print(df.head(10))

    algorithms = ['GNN', 'Heuristic', 'Exhaustive', 'Max Power', 'WMMSE with Max Power']
    color = alt.Color('category:N',
                      legend=alt.Legend(orient='top-left', title=None),
                      scale=alt.Scale(scheme="dark2"))

    def melted(metric):
        # Long format: one row per (Nodes, "<algorithm> <metric>" column), value under `metric`.
        selected = ['Nodes'] + [f'{name} {metric}' for name in algorithms]
        return df[selected].melt('Nodes', var_name='category', value_name=metric)

    chart_perf = alt.Chart(melted('Performance')).mark_line(point=True, interpolate='monotone').transform_joinaggregate(
        # Normalize by the best algorithm at the same node count. Without `groupby`,
        # the maximum was taken over all node counts, so the comparison mixed the
        # growth of sum rate with network size into the algorithm comparison.
        max='max(Performance)',
        groupby=['Nodes'],
    ).transform_calculate(
        percent="datum.Performance / datum.max"
    ).encode(
        alt.X('Nodes:O', scale=alt.Scale(zero=False)),
        alt.Y('percent:Q', title='Performance', scale=alt.Scale(zero=False), axis=alt.Axis(format='.1%')),
        color,
    ).properties(
        width=400
    ).configure_point(
        size=50
    )

    chart_time = alt.Chart(melted('Time')).mark_line(point=True, interpolate='monotone').encode(
        alt.X('Nodes:O', scale=alt.Scale(zero=False), axis=alt.Axis(tickCount=8)),
        alt.Y('Time:Q', title='Time (s)', scale=alt.Scale(type='log')),
        color,
    ).properties(
        width=400
    ).configure_point(
        size=50
    )
    return chart_perf, chart_time

In [4]:
# Re-running the benchmark is slow: the exhaustive search evaluates 2**(k/2) direction
# patterns per sample. By default, the cached ./experiments/perf_summary.pkl is
# plotted below; set this to True to regenerate it.
RUN_BENCHMARK = False
if RUN_BENCHMARK:
    evaluate_and_save_results()

In [5]:
# Load the cached benchmark summary, print it, and build both charts (displayed in the cells below).
chart_perf, chart_time = plot_results()

   Nodes  GNN Time  GNN Performance  Heuristic Time  Heuristic Performance  \
0    4.0  0.000229         1.791022        0.000113               1.771709   
1    8.0  0.000251         2.495414        0.000550               2.460771   
2   12.0  0.000283         2.920749        0.001467               2.884021   
3   16.0  0.000343         3.265862        0.003362               3.225337   
4   20.0  0.000401         3.519804        0.006781               3.489375   
5   24.0  0.000491         3.745244        0.012289               3.707267   
6   28.0  0.000552         3.931401        0.022320               3.901677   
7   32.0  0.000710         4.091971        0.037172               4.076983   

   Exhaustive Time  Exhaustive Performance  Max Power Time  \
0         0.000030                1.798028        0.000002   
1         0.000282                2.548500        0.000003   
2         0.001931                3.004615        0.000003   
3         0.012367                3.366040       

In [6]:
# Performance chart built by plot_results().
chart_perf

In [7]:
# Per-sample run-time chart (log scale) built by plot_results().
chart_time

### Static SVG versions of the charts, for viewing on GitHub

<img src="https://user-images.githubusercontent.com/17931435/213001877-d6b2fd74-dd3b-4987-b421-9ae9cdbd3a38.svg" width="200">
<img src="https://user-images.githubusercontent.com/17931435/213001470-8ef72cdc-4b39-4af6-9325-4ed355dbb843.svg" width="200">