In [None]:
# Environment bootstrap: fetch the project, switch into it, install dependencies,
# and create the output directories used by the cells below.
# Output is discarded with the POSIX form `> /dev/null 2>&1`, which behaves the same
# under sh and bash (the bash-only `&>` is parsed as "background + redirect" by plain sh).
!git clone https://github.com/saeedzou/MetaOffload > /dev/null 2>&1
%cd /content/MetaOffload
!pip install -q gym==0.14.0 > /dev/null 2>&1
!pip install wandb -qU > /dev/null 2>&1
# Redirect to the absolute /dev/null; a relative `dev\null` target would write a
# stray file into the working directory instead of discarding the output.
!pip install -q pydotplus > /dev/null 2>&1
# -p makes directory creation idempotent so re-running this cell does not error.
!mkdir -p ./models
!mkdir -p ./logs

In [1]:
import os
import json
import logging

import numpy as np
import torch

from copy import deepcopy
from tqdm import tqdm
from env.mec_offloaing_envs.offloading_env import Resources
from env.mec_offloaing_envs.offloading_env import OffloadingEnvironment
from models import GraphSeq2Seq, BaselineSeq2Seq
from buffer import RolloutBuffer
from train import inner_loop, outer_loop
from utils import log_metrics

In [None]:
# Read the experiment hyperparameters from the JSON file in the working directory.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform locale.
with open('my_config.json', encoding='utf-8') as config_file:
    args = json.load(config_file)

class Config:
    """Attribute-style view of a settings dictionary.

    Each key/value pair of the dictionary passed to the constructor is bound
    as an instance attribute, so ``Config({"seed": 1}).seed`` evaluates to ``1``.
    """

    def __init__(self, dictionary):
        """Bind every entry of ``dictionary`` as an attribute on this instance."""
        entries = dictionary.items()
        for attr_name, attr_value in entries:
            setattr(self, attr_name, attr_value)

# Expose the loaded hyperparameters as attributes (c.seed, c.outer_lr, ...).
c = Config(args)

# Run identifier assembled from the hyperparameters. It names the log file here and
# is also used later in the notebook as the wandb run name and checkpoint sub-directory.
log_path = (f"obs_{c.obs_dim}_"
            f"h_{c.encoder_units}_"
            f"m_{c.adaptation_steps}_"
            f"g_{c.is_graph}_"
            f"att_{c.is_attention}_"
            f"mec_{c.mec_process_capable}_"
            f"mob_{c.mobile_process_capable}_"
            f"ul_{c.bandwidth_up}_"
            f"dl_{c.bandwidth_down}_"
            f"olr_{c.outer_lr}_"
            f"ilr_{c.inner_lr}_"
            f"n_{c.num_iterations}_"
            f"ibs_{c.inner_batch_size}_"
            f"mgn_{c.max_grad_norm}_"
            f"vf_{c.vf_coef}_"
            f"epis_{c.num_task_episodes}")

# basicConfig opens the log file right away and fails with FileNotFoundError if the
# directory is absent, so make sure it exists even when the setup cell was skipped.
os.makedirs("./logs", exist_ok=True)
logger_path = "./logs/" + log_path + '.log'
logger = logging.getLogger(__name__)
# filemode='w' truncates any previous log of the same run name; force=True replaces
# handlers already installed on the root logger (e.g. by an earlier run of this cell).
logging.basicConfig(filename=logger_path,
                    filemode='w',
                    format='%(message)s',
                    level=logging.DEBUG,
                    force=True)

# Seed the NumPy and PyTorch random number generators from the configured seed.
np.random.seed(c.seed)
torch.manual_seed(c.seed)

In [3]:
# Processing capabilities from the config are scaled by 1024*1024 before being
# handed to Resources; bandwidths are passed through unchanged.
mec_capacity = c.mec_process_capable * 1024 * 1024
mobile_capacity = c.mobile_process_capable * 1024 * 1024

resources = Resources(
    mec_process_capable=mec_capacity,
    mobile_process_capable=mobile_capacity,
    bandwidth_up=c.bandwidth_up,
    bandwidth_dl=c.bandwidth_down,
)

# Offloading environment built on the resource cluster above.
# Note that batch_size and graph_number are both taken from c.graph_number.
env_settings = dict(
    resource_cluster=resources,
    batch_size=c.graph_number,
    graph_number=c.graph_number,
    graph_file_paths=c.graph_file_paths,
    time_major=False,
    encoding=c.encoding,
)
env = OffloadingEnvironment(**env_settings)

In [6]:
# Optional Weights & Biases tracking, enabled by the `wandb` flag in the config.
if c.wandb:
    import wandb
    wandb.login(key=c.wandb_key)
    # Log the hyperparameters as the run config, but leave the API key out so the
    # credential is not stored alongside the run.
    run_config = {key: value for key, value in vars(c).items() if key != 'wandb_key'}
    wandb.init(project=c.wandb_project,
               name=log_path,
               config=run_config)
else:
    # The training loop passes `wandb=wandb` to log_metrics unconditionally; bind the
    # name so that call does not raise NameError when tracking is disabled.
    # NOTE(review): assumes log_metrics treats a None wandb as "tracking off" — confirm.
    wandb = None

# Reference latencies reported by the environment, each averaged with np.mean and
# printed to four decimals: the greedy solution (second element of its return value),
# all-local execution, and all-MEC execution.
greedy_latency = np.mean(env.greedy_solution()[1])
print(f'Average greedy latency: {greedy_latency:.4f}')

local_latency = np.mean(env.get_all_locally_execute_time())
print(f'Average all local latency: {local_latency:.4f}')

mec_latency = np.mean(env.get_all_mec_execute_time())
print(f'Average all mec latency: {mec_latency:.4f}')

device = c.device

# Both architectures are constructed with the same arguments, so select the class
# once (graph variant when c.is_graph is truthy) instead of duplicating the call.
policy_cls = GraphSeq2Seq if c.is_graph else BaselineSeq2Seq
policy = policy_cls(input_dim=c.obs_dim,
                    hidden_dim=c.encoder_units,
                    output_dim=c.action_dim,
                    num_layers=c.num_layers,
                    device=device).to(device)

if c.load:
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint written on a GPU machine can still be loaded on a CPU-only one.
    policy.load_state_dict(torch.load(c.load_path, map_location=device))

if c.save:
    # Checkpoints for this run go under <save_path>/<log_path>; exist_ok makes the
    # call safe to repeat and removes the check-then-create race.
    os.makedirs(os.path.join(c.save_path, log_path), exist_ok=True)

# Rollout storage shared by the sampling and evaluation phases of the training loop.
# Its buffer_size is the product of graph_number and num_task_episodes.
rollout_capacity = c.graph_number * c.num_task_episodes
buffer = RolloutBuffer(
    meta_batch_size=c.meta_batch_size,
    buffer_size=rollout_capacity,
    discount=c.gamma,
    gae_lambda=c.tau,
    device=device,
)

# Adam optimizer over the meta-policy's parameters, stepped with the outer learning rate.
meta_parameters = policy.parameters()
outer_optimizer = torch.optim.Adam(meta_parameters, lr=c.outer_lr)

# Adam state handed from one inner-loop optimizer to the next (across tasks and
# iterations). None means there is nothing to restore yet. Initialising it here,
# rather than only when iteration == 0, keeps a resumed run (c.start_iter > 0)
# from hitting a NameError on the first load.
inner_optimizer_state_dict = None

# Meta-training loop. Each iteration: sample a batch of tasks, collect trajectories
# with the meta-policy, adapt a copy of the policy per task, re-collect trajectories
# with the adapted copies, log averaged metrics, then update the meta-policy.
for iteration in tqdm(range(c.start_iter, c.num_iterations), leave=False, disable=True):
    task_policies = []
    fts_before, fts_after = [], []
    vf_losses, pg_losses = [], []
    all_rewards, all_returns = [], []

    batch_of_tasks = env.sample_tasks(c.meta_batch_size)

    ### Sample trajectories ###
    buffer.reset()
    for i, task_id in tqdm(enumerate(batch_of_tasks), leave=False, total=c.meta_batch_size, desc='Sampling trajectories'):
        buffer.collect_episodes(env=env,
                                policy=policy,
                                device=device,
                                meta_batch=i,
                                task_id=task_id,
                                is_graph=c.is_graph)
    buffer.process()

    ### Inner loop ###
    for i, task_id in enumerate(batch_of_tasks):
        # Each task adapts its own deep copy, leaving the meta-policy untouched here.
        clone = deepcopy(policy).to(device)
        inner_optimizer = torch.optim.Adam(clone.parameters(), lr=c.inner_lr)
        # Restore the optimizer state left by the previous inner loop, if there is one;
        # the very first optimizer simply starts from its fresh state.
        if inner_optimizer_state_dict is not None:
            inner_optimizer.load_state_dict(inner_optimizer_state_dict)
        vf_loss, pg_loss, fts, clone = \
            inner_loop(policy=clone,
                       optimizer=inner_optimizer,
                       buffer=buffer,
                       meta_batch=i,
                       task_id=task_id,
                       hparams=c)
        vf_losses.append(vf_loss)
        pg_losses.append(pg_loss)
        # fts feeds the "finish_times_old" metric below (values from before adaptation).
        fts_before.append(fts)
        task_policies.append(clone)
        inner_optimizer_state_dict = inner_optimizer.state_dict()

    ### Evaluate trajectories ###
    buffer.reset()
    for i, task_id in tqdm(enumerate(batch_of_tasks), leave=False, total=c.meta_batch_size, desc='Evaluating trajectories'):
        # Same collection as above, but with the task-adapted policy for slot i.
        buffer.collect_episodes(env=env,
                                policy=task_policies[i],
                                device=device,
                                meta_batch=i, task_id=task_id,
                                is_graph=c.is_graph)
    buffer.process()

    ### Log metrics ###
    # Losses come from the inner loop; rewards, returns and the "new" finish times
    # are read from the buffer filled by the adapted policies.
    avg_vf_losses = np.mean(vf_losses)
    avg_pg_losses = np.mean(pg_losses)
    avg_rewards = np.mean([reward.sum(-1) for reward in buffer.rewards])
    avg_returns = np.mean([returns[:, 0].mean().item() for returns in buffer.returns])
    avg_fts_before = np.mean(np.concatenate(fts_before))
    avg_fts_after = np.mean(np.concatenate(buffer.finish_times))

    log_metrics(logger=logger,
                iteration=iteration,
                vf_losses=avg_vf_losses,
                pg_losses=avg_pg_losses,
                rewards=avg_rewards,
                returns=avg_returns,
                finish_times_old=avg_fts_before,
                finish_times_new=avg_fts_after,
                wandb=wandb)

    # Outer step: hand the adapted task policies and the outer optimizer to outer_loop
    # (presumably the meta-update of `policy` — see train.outer_loop).
    outer_loop(meta_policy=policy,
               task_policies=task_policies,
               outer_optimizer=outer_optimizer,
               hparams=c)
    # Periodically checkpoint the meta-policy weights under <save_path>/<log_path>/.
    if c.save and iteration % c.save_every == 0:
        torch.save(policy.state_dict(),
                   os.path.join(c.save_path, log_path, f'policy_{iteration}.pt'))