# VISUALIZE THE PROPORTIONAL, UNIFORM, AND EXTREME SLICING BASELINES
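We evaluate several slicing baselines alongside the pretrained state-augmented slicing model: mono-slicing (HT and LL variants), proportional slicing, and uniform slicing. Their results help us set feasible and meaningful constraints for the state-augmented slicing algorithm.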

In [1]:
import json
import os
import glob
import torch
import re
import copy
import random
import tqdm
from collections import defaultdict
from matplotlib import pyplot as plt
import seaborn as sns
import time

In [2]:
# Load one pre-trained model and its experiment results.
# Swap the active `experiment_root_path` for the commented one to inspect the other run.
save_path = 'visualization_baselines'
# experiment_root_path = "results/n_20_T_slices_{'train': 100, 'test': 500}_num_samples_{'train': 2, 'test': 1}/1700455656.290245 -- TEST VIOLATION RATES"
experiment_root_path = "results/n_20_T_slices_{'train': 100, 'test': 500}_num_samples_{'train': 16, 'test': 4}/1700457654.2167933 -- TEST VIOLATION RATES"

# Outputs of this notebook mirror the experiment's directory layout.
save_path = f'{save_path}/{experiment_root_path}'

experiment_name_match = re.search('results/(.+?)/', experiment_root_path)
if experiment_name_match is None:
    raise ValueError(
        f'Expected experiment_root_path of the form "results/<name>/...", got {experiment_root_path!r}')
experiment_name = experiment_name_match.group(1)

# Tensors saved from a GPU cannot be deserialized on a CPU-only machine unless remapped.
# When CUDA is available, keep the devices stored in the checkpoint (torch.load's default).
checkpoint_map_location = None if torch.cuda.is_available() else 'cpu'


def load_latest_checkpoint(path_prefix, file_type='*.pt', map_location=None):
    """Load the newest checkpoint whose file path starts with ``path_prefix``.

    Args:
        path_prefix: Literal path prefix (directory plus file-name stem). It is escaped
            before globbing, so glob metacharacters in experiment names (e.g. ``[``)
            are matched literally.
        file_type: Glob pattern appended to the escaped prefix.
        map_location: Forwarded to ``torch.load``.

    Returns:
        Tuple ``(checkpoint_path, checkpoint)``.

    Raises:
        FileNotFoundError: If no file matches ``path_prefix + file_type``.
    """
    candidates = glob.glob(glob.escape(path_prefix) + file_type)
    if not candidates:
        raise FileNotFoundError(f'No checkpoint matches {path_prefix + file_type!r}')
    # "Newest" is decided by os.path.getctime.
    latest_path = max(candidates, key=os.path.getctime)
    # torch.load unpickles arbitrary objects (e.g. the saved config), so only load trusted files.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which may reject the pickled
    # config; pass weights_only=False here for trusted checkpoints if loading fails.
    return latest_path, torch.load(latest_path, map_location=map_location)


# Training checkpoint (holds the training config and the per-epoch results).
train_chkpt_root_path = f'{experiment_root_path}/train_chkpts/train_chkpt_'
latest_train_chkpt_path, train_chkpt = load_latest_checkpoint(
    train_chkpt_root_path, map_location=checkpoint_map_location)

# Model checkpoint (holds the model state dict).
model_chkpt_root_path = f'{experiment_root_path}/train_chkpts/model_chkpt_'
latest_model_chkpt_path, model_chkpt = load_latest_checkpoint(
    model_chkpt_root_path, map_location=checkpoint_map_location)

# All-epoch-results checkpoint.
# NOTE(review): the setup cell reads `train_chkpt['all_epoch_results']` instead of this
# object — confirm which source is intended.
all_epoch_results_chkpt_root_path = f'{experiment_root_path}/train_chkpts/all_epoch_results_chkpt_'
latest_all_epoch_results_chkpt_path, all_epoch_results_chkpt = load_latest_checkpoint(
    all_epoch_results_chkpt_root_path, map_location=checkpoint_map_location)

In [3]:
from core.model import MLP
from core.Slice import Slice
from core.data_gen import create_data
from core.utils import make_test_configs, find_substring_index, seed_everything, make_experiment_name, create_network_configs,\
make_feature_extractor, make_constraint_fncs_and_lambda_samplers, make_eval_fnc, make_logger
from core.StateAugmentation import StateAugmentedSlicingAlgorithm

In [4]:
# How many more epochs/iterations to run the pretrained model for.
# NOTE(review): not referenced by any cell shown in this notebook — confirm it is still needed.
n_epochs_more = 1

In [5]:
# Rebuild the experiment from the configuration stored in the training checkpoint.
args = copy.deepcopy(train_chkpt['config'])
seed_everything(args.random_seed)

# String summarizing the main experiment (hyper)parameters; this replaces the name parsed
# from the results path when the checkpoints were loaded.
experiment_name = make_experiment_name(args)
args.save_dir = experiment_root_path

# Channel/traffic data locations inside the experiment directory.
args.channel_data_save_load_path = f'{args.root}/{args.save_dir}/channel_data'
args.traffic_data_save_load_path = f'{args.root}/{args.save_dir}/traffic_data'

# Output directory of this notebook (logs.txt is written here).
os.makedirs(save_path, exist_ok=True)

# Network configs used to initialize the wireless networks.
network_configs = create_network_configs(args)

# Constraint specifications. By default they equal the trained configuration; edit a
# right-hand side to evaluate under different constraints. Every value that differs from
# the trained configuration is encoded in `subpath`, which is appended to the plot
# directories so results under modified constraints are saved separately.
r_min_violation_rate = args.r_min_violation_rate  # e.g. 0.05
l_max_violation_rate = args.l_max_violation_rate  # e.g. 0.05
r_min = args.r_min
l_max = args.l_max

# (subpath label, value used here, value from the trained configuration)
constraint_overrides = [
    ('rate_violation_rate', r_min_violation_rate, args.r_min_violation_rate),
    ('latency_violation_rate', l_max_violation_rate, args.l_max_violation_rate),
    ('r_min', r_min, args.r_min),
    ('l_max', l_max, args.l_max),
]
subpath = ''.join(f'_{label}_{value}'
                  for label, value, trained_value in constraint_overrides
                  if value != trained_value)
if subpath:
    subpath = '/' + subpath
print('Subpath: ', subpath)

args.r_min = r_min
args.l_max = l_max
args.r_min_violation_rate = r_min_violation_rate
args.l_max_violation_rate = l_max_violation_rate

# Feature extractor, objective (mean rate of the BE slice) and constraint evaluation functions.
feature_extractor, n_features = make_feature_extractor(['slice-weight', 'slice-avg-data-rate'], args)
obj = make_eval_fnc(eval_type = 'obj-mean-rate', eval_slices = [Slice.BE], args=args)

constraints, lambda_samplers = make_constraint_fncs_and_lambda_samplers(args)

# Input width = extracted features + one extra input per constraint
# (presumably the dual multiplier of each constraint — the "state augmentation").
args.num_features_list = [n_features + len(constraints)] + args.num_features_list[1:]

# Set the computation device (CPU fallback when CUDA is unavailable), rebuild the MLP
# parameterization and load the trained weights into it.
args.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
print('Device: ', args.device)
model = MLP(args.num_features_list, lambda_transform = args.lambda_transform, batch_norm = args.batch_norm).to(args.device)
model.load_state_dict(model_chkpt['model_state_dict'])
print('Model state dicts loaded...')

# Log print statements to a logs.txt file
loggers = [make_logger(f'{save_path}/logs.txt')]

# Epoch results are taken from the training checkpoint.
# NOTE(review): `all_epoch_results_chkpt` loaded earlier is not used here — confirm intent.
all_epoch_results = train_chkpt['all_epoch_results']
print('All epoch results loaded...')
print('The state-augmented model has been trained for {} epochs.'.format(len(all_epoch_results['train_state'])))

sa_learner = StateAugmentedSlicingAlgorithm(model=model,
                                            config=args,
                                            network_configs=network_configs,
                                            feature_extractor=feature_extractor,
                                            loggers=loggers,
                                            obj=obj,
                                            constraints=constraints,
                                            lambda_samplers = lambda_samplers,
                                            all_epoch_results=all_epoch_results)

Subpath:  
Device:  cuda:0
Model state dicts loaded...
All epoch results loaded...
The state-augmented model has been trained for 20 epochs.


## Run Test Phase

In [13]:
# Make one test config per slicing strategy to compare. Each config is an independent
# deep copy of `args`, so per-strategy overrides never leak into `args` or each other.
# Attributes are set in the order listed in each spec.
# Other overrides that can be added to a spec: 'test_on_train_data': True,
# 'dual_test_init_strategy': 'zeros'.
test_config_specs = [
    {'name': 'mono-slicing-HT', 'slicing_strategy': 'mono-slicing-HT'},
    {'name': 'mono-slicing-LL', 'slicing_strategy': 'mono-slicing-LL'},
    {'name': 'proportional-slicing', 'slicing_strategy': 'proportional'},
    {'name': 'uniform-slicing', 'slicing_strategy': 'uniform'},
    # 'zeros' presumably initializes the dual variables to zero at test time.
    {'name': 'state-augmented-slicing', 'slicing_strategy': 'state-augmented',
     'dual_test_init_strategy': 'zeros'},
]

test_configs = []
for test_config_spec in test_config_specs:
    test_config = copy.deepcopy(args)
    for attr_name, attr_value in test_config_spec.items():
        setattr(test_config, attr_name, attr_value)
    test_configs.append(test_config)

In [14]:
# Evaluate every test config with the pretrained learner and collect the resulting metrics.
print('Testing sa_learner...')
test_metrics_list = []
test_metrics_over_time_list = []
for test_config in test_configs:
    print(f'Testing {test_config.name}')
    # Continue numbering after the last stored test result for this config; when there is
    # none yet, start right after the last training epoch.
    try:
        test_epoch = sa_learner.all_epoch_results['test_state', test_config.name][-1].epoch + 1
    except (KeyError, IndexError):
        # KeyError: no entry for this config; IndexError: the entry exists but is empty.
        # Any other error is a real problem and is no longer silently swallowed.
        test_epoch = len(sa_learner.all_epoch_results['train_state']) + 1
    test_metrics = sa_learner.test(epoch=test_epoch, test_config=test_config) # test metric
    # Wrapped in a one-element list per config (presumably a per-epoch history for the plotters).
    test_metrics_list.append([test_metrics])
    test_metrics_over_time_list.append(test_metrics.metrics_over_slices) # append test metrics over time

Testing sa_learner...
Testing mono-slicing-HT


  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [00:55<00:00,  8.98it/s]


Testing mono-slicing-LL


100%|██████████| 500/500 [00:57<00:00,  8.64it/s]


Testing proportional-slicing


100%|██████████| 500/500 [00:58<00:00,  8.51it/s]


Testing uniform-slicing


100%|██████████| 500/500 [00:59<00:00,  8.46it/s]


Testing state-augmented-slicing


100%|██████████| 500/500 [00:58<00:00,  8.57it/s]


## Plot test metrics

In [15]:
# Plot the test metrics of every config: first across epochs, then across time slots
# for a random selection of test networks.
config_names = [cfg.name for cfg in test_configs]
plots_dir = f'{save_path}{subpath}/plots'

print('Plotting test evolution over epochs...')
sa_learner.plot_test_evolution_over_epochs(
    test_metrics_list,
    test_config_names=config_names,
    save_path=f'{plots_dir}/test_evolution/',
    plot_actual_metrics=True,
)

# Number of test networks to plot; lower it to plot a random subset.
k_networks = sa_learner.config.num_samples['test']
print('Plotting test evolution over timesteps...')
n_test_networks = sa_learner.config.num_samples['test']
# Random selection (and order) of networks, drawn from the global `random` state.
selected_networks = random.sample(range(n_test_networks), k=min(k_networks, n_test_networks))
sa_learner.plot_test_evolution_over_slices(
    test_metrics_list,
    test_config_names=config_names,
    save_path=f'{plots_dir}/test_evolution_over_slices/',
    network_idx=selected_networks,
    plot_actual_metrics=True,
)

Plotting test evolution over epochs...
Plotting test evolution over timesteps...


<Figure size 640x480 with 0 Axes>