In [1]:
import os
import json
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical
from copy import deepcopy
from tqdm import tqdm
from env.mec_offloaing_envs.offloading_env import Resources
from env.mec_offloaing_envs.offloading_env import OffloadingEnvironment
from models import GraphSeq2Seq, BaselineSeq2Seq
from buffer import RolloutBuffer
from train import inner_loop

%load_ext autoreload
%autoreload 2

In [2]:
# Load the run configuration (hyper-parameters, paths, logging switches).
with open('config.json') as f:
    args = json.load(f)

# Seed every RNG source reachable from here so runs are reproducible.
# Python's `random` is seeded too, in case environment/buffer code draws from it.
random.seed(args['seed'])
np.random.seed(args['seed'])
torch.manual_seed(args['seed'])

if args["wandb"]:
    import wandb
    # Prefer the WANDB_API_KEY environment variable so the secret does not have
    # to be stored in config.json; fall back to the config entry for existing setups.
    # If neither is set, key=None lets wandb use its own credential lookup.
    wandb_key = os.environ.get("WANDB_API_KEY") or args.get("wandb_key")
    wandb.login(key=wandb_key)

In [None]:
# The processing capabilities from the config are multiplied by 1024 * 1024 before
# being handed to Resources (presumably a "mega" unit conversion — confirm in offloading_env).
CAPABILITY_SCALE = 1024 * 1024

# Default task-graph dataset; can be overridden with a "graph_file_paths" list in config.json.
DEFAULT_GRAPH_FILE_PATHS = ["./env/mec_offloaing_envs/data/meta_offloading_20/offload_random20_12/random.20."]

resources = Resources(mec_process_capable=args["mec_process_capable"] * CAPABILITY_SCALE,
                      mobile_process_capable=args["mobile_process_capable"] * CAPABILITY_SCALE,
                      bandwidth_up=args["bandwidth_up"],
                      bandwidth_dl=args["bandwidth_down"])

# batch_size is set equal to graph_number (presumably one batch covers every
# graph in the dataset file — TODO confirm).
env = OffloadingEnvironment(resource_cluster=resources,
                            batch_size=args['graph_number'],
                            graph_number=args['graph_number'],
                            graph_file_paths=args.get("graph_file_paths", DEFAULT_GRAPH_FILE_PATHS),
                            time_major=False)

In [None]:
# Start experiment tracking only when it is switched on in config.json.
if args["wandb"]:
    wandb.init(
        project="mec-offloading-phase-1",
        name='experiment',
        config=args,
    )

# Reference latencies from non-learned strategies, printed so the learned
# policy's latency can be compared against them.
greedy_latency = env.greedy_solution()[1]
print(f'Average greedy latency: {np.mean(greedy_latency):.4f}')
local_latency = env.get_all_locally_execute_time()
print(f'Average all local latency: {np.mean(local_latency):.4f}')
mec_latency = env.get_all_mec_execute_time()
print(f'Average all mec latency: {np.mean(mec_latency):.4f}')

# Learning rates (outer_lr is read from config but not used by the optimizer below).
inner_lr, outer_lr = args["inner_lr"], args["outer_lr"]
# Inner-loop adaptation schedule.
adapt_steps, inner_bs = args["adaptation_steps"], args["inner_batch_size"]
# One task is sampled per outer iteration.
meta_batch_size = 1
num_iterations = args["num_iterations"]
# Loss weighting / clipping settings forwarded to inner_loop.
vf_coef, vf_is_clipped, ent_coef, clip_eps = (
    args["vf_coef"], args["vf_is_clipped"], args["ent_coef"], args["clip_eps"]
)
# Model variant switch and compute device.
is_graph, device = args["is_graph"], args["device"]
num_task_episodes = args["num_task_episodes"]
# Passed to RolloutBuffer as discount (gamma) and gae_lambda (tau).
gamma, tau = args["gamma"], args["tau"]
graph_number = args["graph_number"]
# Mean latency per outer iteration, appended by the training loop.
latencies = []

# Both encoder variants are constructed with the same positional arguments.
# NOTE(review): (15, 128, 2, 2) are model hyper-parameters defined by models.py —
# presumably input feature size, hidden size and layer counts; TODO confirm and
# move them into config.json.
policy_cls = GraphSeq2Seq if is_graph else BaselineSeq2Seq
policy = policy_cls(15, 128, 2, 2, device).to(device)

if args["load"]:
    # map_location remaps tensors saved on another device (e.g. a GPU) onto the
    # configured `device`, so a GPU checkpoint can be loaded on a CPU-only machine.
    policy.load_state_dict(torch.load(args["load_path"], map_location=device))

buffer = RolloutBuffer(meta_batch_size, graph_number * num_task_episodes,
                       discount=gamma, gae_lambda=tau, device=device)
# Only inner_lr is applied here; outer_lr is not used in this notebook.
optimizer = torch.optim.Adam(policy.parameters(), lr=inner_lr)

for iteration in tqdm(range(num_iterations), leave=False, disable=True):
    task_policies = []
    # fts_after, all_rewards and all_returns are initialised but never filled by
    # this loop; they are kept so any later cell referencing them still works.
    fts_before, fts_after = [], []
    vf_losses, pg_losses, ent_losses = [], [], []
    all_rewards, all_returns = [], []

    batch_of_tasks = env.sample_tasks(meta_batch_size)

    # Roll out the current policy on every sampled task, then let the buffer
    # post-process the episodes (it exposes `rewards` and `returns` afterwards).
    buffer.reset()
    print('sampling trajectories')
    for i, task_id in enumerate(batch_of_tasks):
        buffer.collect_episodes(env, policy, device, i, task_id, is_graph)
    buffer.process()

    # Update the policy on each task's rollouts.
    for i, task_id in enumerate(batch_of_tasks):
        # NOTE(review): `policy` is rebound to the module returned by inner_loop
        # (and that rebound module is used for the next task), while `optimizer`
        # stays bound to the parameters it was created with. This is only sound
        # if inner_loop returns the same module it received — confirm in train.py.
        vf_loss, pg_loss, ent_loss, fts, policy = inner_loop(policy, optimizer, buffer, i, task_id, inner_bs, adapt_steps, clip_eps, vf_coef, ent_coef, vf_is_clipped, is_graph)
        vf_losses.append(vf_loss)
        pg_losses.append(pg_loss)
        ent_losses.append(ent_loss)
        fts_before.append(fts)
        task_policies.append(policy)

    print('*'*50)
    # Compute each summary statistic once, then reuse it for the latency
    # history, the console report and (optionally) wandb.
    latency_before = np.mean(np.concatenate(fts_before))
    latencies.append(latency_before)
    mean_vf_loss = np.mean(vf_losses)
    mean_pg_loss = np.mean(pg_losses)
    mean_ent_loss = np.mean(ent_losses)
    average_reward = np.mean([reward.sum(-1) for reward in buffer.rewards])
    average_return = np.mean([returns[:, 0].mean().item() for returns in buffer.returns])

    print("Iteration", iteration,
          "| vf_loss: {:.4f}".format(mean_vf_loss),
          "| pg_loss: {:.4f}".format(mean_pg_loss),
          "| ent_loss: {:.4f}".format(mean_ent_loss),
          "| average_reward: {:.4f}".format(average_reward),
          "| average_return: {:.4f}".format(average_return),
          "| latency before adaptation: {:.4f}".format(latency_before))

    # wandb.init() runs when tracking is enabled, but this notebook otherwise
    # never logs anything; record the per-iteration metrics so the run has data.
    # No explicit `step` is passed, so this cannot clash with any logging
    # inner_loop might do internally.
    if args["wandb"]:
        wandb.log({
            "iteration": iteration,
            "vf_loss": float(mean_vf_loss),
            "pg_loss": float(mean_pg_loss),
            "ent_loss": float(mean_ent_loss),
            "average_reward": float(average_reward),
            "average_return": float(average_return),
            "latency_before_adaptation": float(latency_before),
        })
