In [None]:
!git clone https://github.com/saeedzou/MetaOffload &> /dev/null
%cd /content/MetaOffload
!pip install -q gym==0.14.0 &> /dev/null
!pip install wandb -qU &> /dev/null
!pip install -q pydotplus &> dev\null

In [1]:
import os
import json
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical
from copy import deepcopy
from tqdm import tqdm
from env.mec_offloaing_envs.offloading_env import Resources
from env.mec_offloaing_envs.offloading_env import OffloadingEnvironment
from models import GraphSeq2Seq, BaselineSeq2Seq
from buffer import RolloutBuffer
from train import inner_loop

In [2]:
# Read the experiment settings from the JSON file in the working directory.
with open('baseline_config.json') as f:
    args = json.loads(f.read())

class Config:
    """Attribute-style view over a dictionary of settings."""

    def __init__(self, dictionary):
        """Expose every key of ``dictionary`` as an attribute with its value."""
        for name in dictionary:
            setattr(self, name, dictionary[name])

# Wrap the loaded settings so they can be read as attributes (c.seed, c.device, ...).
c = Config(args)

# Seed NumPy's and PyTorch's global random generators with the configured seed.
np.random.seed(c.seed)
torch.manual_seed(c.seed)

In [3]:
# Compute/bandwidth description handed to the environment.
# NOTE(review): the *1024*1024 factor presumably converts mega-units to base
# units -- confirm against the Resources implementation.
resources = Resources(
    mec_process_capable=c.mec_process_capable*1024*1024,
    mobile_process_capable=c.mobile_process_capable*1024*1024,
    bandwidth_up=c.bandwidth_up,
    bandwidth_dl=c.bandwidth_down,
)

# Offloading environment built from a single graph-file prefix; the batch size
# is set equal to the number of graphs.
env = OffloadingEnvironment(
    resource_cluster=resources,
    batch_size=c.graph_number,
    graph_number=c.graph_number,
    graph_file_paths=["./env/mec_offloaing_envs/data/meta_offloading_20/offload_random20_12/random.20."],
    time_major=False,
    encoding=c.encoding,
)

In [None]:
# Optional experiment tracking with Weights & Biases.
if c.wandb:
    # Imported lazily so the notebook runs without wandb when tracking is off.
    import wandb
    wandb.login(key=c.wandb_key)
    # Log the settings as a plain dict and leave out the API key, so the
    # secret is not stored alongside the run's hyperparameters.
    run_config = {name: value for name, value in vars(c).items() if name != "wandb_key"}
    wandb.init(project=c.wandb_project,
               name=c.wandb_name,
               config=run_config)

# Reference latencies computed by the environment's own helper methods,
# printed for comparison with the trained policy.
greedy_latency = np.mean(env.greedy_solution()[1])
print(f'Average greedy latency: {greedy_latency:.4f}')

all_local_latency = np.mean(env.get_all_locally_execute_time())
print(f'Average all local latency: {all_local_latency:.4f}')

all_mec_latency = np.mean(env.get_all_mec_execute_time())
print(f'Average all mec latency: {all_mec_latency:.4f}')

# Run-level settings.
device = c.device
# Force a single task per iteration, overriding the value from the config file.
c.meta_batch_size = 1
# One mean-latency entry is appended per training iteration.
latencies = []

# Both policy classes are constructed with the same keyword arguments, so pick
# the class first and build it once.
policy_cls = GraphSeq2Seq if c.is_graph else BaselineSeq2Seq
policy = policy_cls(input_dim=c.obs_dim,
                    hidden_dim=c.encoder_units,
                    output_dim=c.action_dim,
                    num_layers=c.num_layers,
                    device=device).to(device)

if args["load"]:
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint saved on a GPU can still be restored on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    policy.load_state_dict(torch.load(args["load_path"], map_location=device))

# Rollout storage sized for every graph times the configured number of
# episodes per task, using the discount and GAE settings from the config.
buffer = RolloutBuffer(
    meta_batch_size=c.meta_batch_size,
    buffer_size=c.graph_number*c.num_task_episodes,
    discount=c.gamma,
    gae_lambda=c.tau,
    device=device,
)

# Adam over all policy parameters at the configured inner learning rate.
optimizer = torch.optim.Adam(policy.parameters(), lr=c.inner_lr)

# Main training loop: each iteration samples tasks, collects rollouts into the
# buffer, runs inner_loop once per task and prints summary statistics.
for iteration in tqdm(range(c.start_iter, c.num_iterations), leave=False, disable=True):
    # Per-iteration accumulators, recreated on every pass.
    task_policies = []
    fts_before, fts_after = [], []
    vf_losses, pg_losses = [], []
    all_rewards, all_returns = [], []

    batch_of_tasks = env.sample_tasks(c.meta_batch_size)

    ### Sample trajectories ###
    buffer.reset()
    sampling_bar = tqdm(enumerate(batch_of_tasks),
                        leave=False,
                        total=c.meta_batch_size,
                        desc='Sampling trajectories')
    for i, task_id in sampling_bar:
        buffer.collect_episodes(env=env,
                                policy=policy,
                                device=device,
                                meta_batch=i,
                                task_id=task_id,
                                is_graph=c.is_graph)
    buffer.process()

    ### Update the policy on each sampled task ###
    for i, task_id in enumerate(batch_of_tasks):
        # inner_loop returns four values; the last one re-binds `policy`.
        step_result = inner_loop(policy=policy,
                                 optimizer=optimizer,
                                 buffer=buffer,
                                 meta_batch=i,
                                 task_id=task_id,
                                 hparams=c)
        vf_loss, pg_loss, fts, policy = step_result
        vf_losses.append(vf_loss)
        pg_losses.append(pg_loss)
        fts_before.append(fts)
        task_policies.append(policy)

    ### Report ###
    print('*'*50)
    # Computed once and reused for both the history list and the log line.
    mean_latency = np.mean(np.concatenate(fts_before))
    latencies.append(mean_latency)
    summary_fields = [
        "| vf_loss: {:.4f}".format(np.mean(vf_losses)),
        "| pg_loss: {:.4f}".format(np.mean(pg_losses)),
        "| average_reward: {:.4f}".format(np.mean([reward.sum(-1) for reward in buffer.rewards])),
        "| average_return: {:.4f}".format(np.mean([returns[:, 0].mean().item() for returns in buffer.returns])),
        "| latency before adaptation: {:.4f}".format(mean_latency),
    ]
    print("Iteration", iteration, *summary_fields)
