# Necessary imports and setup

In [None]:
!git clone https://github.com/saeedzou/MetaOffload &> /dev/null
%cd /content/MetaOffload
!pip install -q gym==0.14.0 &> /dev/null
!pip install wandb -qU &> /dev/null
!pip install -q pydotplus &> dev\null
!mkdir ./models
!mkdir ./logs

In [1]:
import json
import logging
import os
import random
from copy import deepcopy

import numpy as np
import torch
from tqdm import tqdm

from env.mec_offloaing_envs.offloading_env import Resources
from env.mec_offloaing_envs.offloading_env import OffloadingEnvironment
from models import GraphSeq2SeqDual, BaselineSeq2SeqDual
from buffer import SingleRolloutBufferPPG, SingleAuxBuffer
from train import outer_loop
from utils import log_metrics
%load_ext autoreload
%autoreload 2

# Loading config file and logging

In [None]:
# Raw hyper-parameter dictionary for this PPG run; wrapped into `Config` below.
config_path = 'my_config_ppg_val.json'
with open(config_path) as config_file:
    args = json.load(config_file)

class Config:
    """Attribute-style view of a configuration dictionary.

    Every key of `dictionary` becomes an instance attribute holding the
    corresponding value, e.g. ``Config({'seed': 0}).seed == 0``.
    """

    def __init__(self, dictionary):
        for name in dictionary:
            setattr(self, name, dictionary[name])

c = Config(args)

# Seed every RNG the run may touch (Python's `random` included, in case the
# environment or helpers use it); torch.manual_seed also seeds CUDA devices.
random.seed(c.seed)
np.random.seed(c.seed)
torch.manual_seed(c.seed)

# Fall back to CPU when the config asks for CUDA but none is available, rather
# than failing later at the first `.to(device)` call.
device = c.device
if str(device).startswith('cuda') and not torch.cuda.is_available():
    print(f"Requested device '{device}' is unavailable; falling back to 'cpu'.")
    device = 'cpu'
    c.device = device  # keep the config consistent with the device actually used

In [8]:
# Run identifier built from the key hyper-parameters. It names the log file,
# the wandb run and the checkpoint directory.
run_name_parts = [
    ('ppg_obs', c.obs_dim),
    ('h', c.encoder_units),
    ('nhl', c.num_layers),
    ('mbs', c.meta_batch_size),
    ('Npi', c.N_pi),
    ('E_aux', c.E_aux),
    ('g', c.is_graph),
    ('gt', c.graph_type),
    ('n', c.num_iterations),
    ('ibs', c.inner_batch_size),
    ('mgn', c.max_grad_norm),
    ('vfclip', c.vf_is_clipped),
    ('epis', c.num_task_episodes),
    ('att', c.is_attention),
    ('seed', c.seed),
    ('olr', c.outer_lr),
    ('ilr', c.inner_lr),
    ('mec', c.mec_process_capable),
    ('mob', c.mobile_process_capable),
    ('ul', c.bandwidth_up),
    ('dl', c.bandwidth_down),
]
# Joining with '_' guarantees exactly one separator between every field, so
# adjacent values can never run together (e.g. "E_aux_4g_True").
log_path = '_'.join(f'{key}_{value}' for key, value in run_name_parts)

# basicConfig(filename=...) fails if the directory is missing.
os.makedirs('./logs', exist_ok=True)
logger_path = os.path.join('./logs', log_path + '.log')
logger = logging.getLogger(__name__)
logging.basicConfig(filename=logger_path,
                    filemode='w',
                    format='%(message)s',
                    level=logging.DEBUG,
                    force=True)

if c.wandb:
    import wandb
    # A missing/empty key in the config lets wandb fall back to the
    # WANDB_API_KEY environment variable or ~/.netrc.
    wandb.login(key=getattr(c, 'wandb_key', None) or None)
    # Upload the hyper-parameters as a plain dict, never including the API key.
    run_config = {key: value for key, value in vars(c).items() if key != 'wandb_key'}
    wandb.init(project=c.wandb_project,
               name=log_path,
               config=run_config)
else:
    wandb = None

if c.save:
    os.makedirs(os.path.join(c.save_path, log_path), exist_ok=True)

# Loading environment

In [None]:
# The config stores processing capabilities scaled down by 1024*1024; Resources
# receives the scaled-up values (presumably MB-based units — confirm).
capability_scale = 1024 * 1024
resources = Resources(
    mec_process_capable=c.mec_process_capable * capability_scale,
    mobile_process_capable=c.mobile_process_capable * capability_scale,
    bandwidth_up=c.bandwidth_up,
    bandwidth_dl=c.bandwidth_down,
)

# One batch covers all c.graph_number task graphs loaded from c.graph_file_paths.
env = OffloadingEnvironment(
    resource_cluster=resources,
    batch_size=c.graph_number,
    graph_number=c.graph_number,
    graph_file_paths=c.graph_file_paths,
    time_major=False,
    encoding=c.encoding,
)

# Reference latencies to compare the learned policy against. Each baseline is
# computed lazily, right before its line is printed.
baselines = (
    ('greedy', lambda: env.greedy_solution()[1]),
    ('all local', env.get_all_locally_execute_time),
    ('all mec', env.get_all_mec_execute_time),
)
for label, compute in baselines:
    print(f'Average {label} latency: {np.mean(compute()):.4f}')

# Loading models and initializing buffers and optimizers

# Train Loop

In [10]:
def _build_net(arch):
    """Construct one network on `device` for the given `arch`.

    Uses GraphSeq2SeqDual (which additionally receives `graph=c.graph_type`)
    when `c.is_graph` is set, otherwise BaselineSeq2SeqDual. All other
    constructor arguments come from the config and are shared by both models;
    `arch` is forwarded unchanged ('policy' or 'value' in this notebook).
    """
    common_kwargs = dict(input_dim=c.obs_dim,
                         hidden_dim=c.encoder_units,
                         output_dim=c.action_dim,
                         num_layers=c.num_layers,
                         device=device,
                         is_attention=c.is_attention,
                         arch=arch)
    if c.is_graph:
        return GraphSeq2SeqDual(graph=c.graph_type, **common_kwargs).to(device)
    return BaselineSeq2SeqDual(**common_kwargs).to(device)


# Separate policy and value networks. Build policy first: with a fixed seed the
# construction order determines which initial weights each network receives.
policy_net = _build_net('policy')
value_net = _build_net('value')

# Rollout buffer sized for c.num_task_episodes episodes of each of the
# c.graph_number graphs; advantages use discount c.gamma and GAE lambda c.tau.
buffer = SingleRolloutBufferPPG(buffer_size=c.graph_number * c.num_task_episodes,
                                discount=c.gamma,
                                gae_lambda=c.tau,
                                device=device)
# Collects data across the policy phase for the auxiliary phase.
aux_buffer = SingleAuxBuffer(device=device)

inner_optimizer_pi = torch.optim.Adam(policy_net.parameters(), lr=c.inner_lr)
inner_optimizer_v = torch.optim.Adam(value_net.parameters(), lr=c.inner_lr)

In [None]:
def _forward(net, observation, adj, action, is_graph):
    """Run one forward pass of `net` on a minibatch.

    Graph models take the adjacency `adj` as an extra positional argument;
    baseline models do not. The loop below unpacks the 3-tuple output as
    (_, per-step action probabilities, value predictions) — the second element
    is treated as probabilities because its log is taken.
    """
    if is_graph:
        return net(observation, adj, action)
    return net(observation, action)


def _taken_prob(probs, action_idx):
    """Pick, along the last axis of `probs`, the entry for each taken action."""
    return probs.gather(-1, action_idx.unsqueeze(-1)).squeeze(-1)


def _log_ratio(new_probs, old_probs, action):
    """Return log(pi_new(a) / pi_old(a)) for the taken actions `a`.

    `action` is cast to int64 only for indexing; the caller's tensor is left
    untouched so later network calls see the original dtype. `old_probs` is a
    fixed target, so gradients never flow into it.
    """
    action_idx = action.type(torch.int64)
    return (torch.log(_taken_prob(new_probs, action_idx))
            - torch.log(_taken_prob(old_probs.detach(), action_idx)))


def _value_loss(v_pred, v_old, return_, clip_eps=None):
    """Half mean-squared error between value predictions and returns.

    With `clip_eps` set, PPO-style value clipping is applied: the prediction is
    kept within `clip_eps` of the rollout-time value `v_old`, and the larger of
    the clipped / unclipped squared errors is used.
    """
    unclipped = (v_pred - return_).pow(2)
    if clip_eps is None:
        return 0.5 * unclipped.mean()
    # Clipping must be centred on v_old; centring it on v_pred bounds nothing.
    v_pred_clipped = v_old + (v_pred - v_old).clamp(-clip_eps, clip_eps)
    clipped = (v_pred_clipped - return_).pow(2)
    return 0.5 * torch.max(unclipped, clipped).mean()


def _optimize(optimizer, loss, net, max_grad_norm):
    """Backprop `loss` and take one step, clipping `net`'s grad norm if configured."""
    optimizer.zero_grad()
    loss.backward()
    if max_grad_norm:
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
    optimizer.step()


def _log_metrics(metrics):
    """Append a metrics dict to the run's log file and, if enabled, to wandb."""
    logger.info(json.dumps(metrics))
    if wandb is not None:
        wandb.log(metrics)


vf_clip_eps = c.clip_eps if c.vf_is_clipped else None
all_fts = []
for iteration in range(c.start_iter, c.num_iterations):
    aux_buffer.reset()
    print(f'Iteration {iteration}'.center(80, '-'))
    print('Policy phase'.center(80, ' '))
    for p_iteration in range(c.N_pi):  # policy phase
        # Perform rollouts under the current policy pi.
        buffer.reset()
        buffer.collect_episodes(env=env, policy_net=policy_net, value_net=value_net,
                                device=device, task_id=0, is_graph=c.is_graph)
        buffer.process_task()
        avg_reward = float(np.mean(buffer.rewards.sum(-1)))
        avg_return = float(np.mean(buffer.returns[:, 0]))
        avg_fts = float(np.mean(buffer.finish_times))
        print(f'pi iteration {p_iteration}', end=' | ')
        print(f'Average reward: {avg_reward:.4f}', end=' | ')
        print(f'Average return: {avg_return:.4f}', end=' | ')
        print(f'Average fts: {avg_fts:.4f}')
        all_fts.append(avg_fts)
        # Save state values to the aux buffer for the auxiliary phase.
        aux_buffer.store_data(buffer)
        observations, adjs, actions, logits, v_olds, advantages, rewards, returns, fts = \
            buffer.sample(batch_size=c.inner_batch_size)
        minibatches = list(zip(observations, adjs, actions, logits, v_olds, advantages, returns))

        print('Policy epochs'.center(80, ' '))
        pg_loss_sum, pg_updates = 0.0, 0
        for epoch in range(c.E_pi):  # policy epochs: clipped surrogate objective
            for observation, adj, action, old_prob, v_old, advantage, return_ in minibatches:
                _, new_prob, _ = _forward(policy_net, observation, adj, action, c.is_graph)
                ratio = torch.exp(_log_ratio(new_prob, old_prob, action))
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - c.clip_eps, 1 + c.clip_eps) * advantage
                pi_loss = -torch.min(surr1, surr2).mean()
                _optimize(inner_optimizer_pi, pi_loss, policy_net, c.max_grad_norm)
                pg_loss_sum += pi_loss.item()
                pg_updates += 1
        pg_loss_avg = pg_loss_sum / max(pg_updates, 1)  # mean over all minibatch updates
        print(f'Average pg loss: {pg_loss_avg:.4f}')

        print('Value epochs'.center(80, ' '))
        vf_loss_sum, vf_updates = 0.0, 0
        for epoch in range(c.E_v):  # value epochs
            for observation, adj, action, _, v_old, _, return_ in minibatches:
                _, _, v_pred = _forward(value_net, observation, adj, action, c.is_graph)
                v_loss = _value_loss(v_pred, v_old, return_, vf_clip_eps)
                _optimize(inner_optimizer_v, v_loss, value_net, c.max_grad_norm)
                vf_loss_sum += v_loss.item()
                vf_updates += 1
        vf_loss_avg = vf_loss_sum / max(vf_updates, 1)
        print(f'Average vf loss: {vf_loss_avg:.4f}')
        _log_metrics({'phase': 'policy', 'iteration': iteration, 'pi_iteration': p_iteration,
                      'avg_reward': avg_reward, 'avg_return': avg_return, 'avg_fts': avg_fts,
                      'pg_loss': pg_loss_avg, 'vf_loss': vf_loss_avg})

    # Compute and store updated policy probabilities for all states in the aux buffer.
    aux_buffer.compute_logits(policy_net, c.is_graph)
    aux_buffer.process_task()
    observations, adjs, actions, logits, Vs, returns = aux_buffer.sample(batch_size=c.inner_batch_size)
    aux_minibatches = list(zip(observations, adjs, actions, logits, Vs, returns))
    # Auxiliary phase.
    print('Auxilary phase'.center(80, ' '))
    for epoch in range(c.E_aux):
        kl_loss_sum, aux_loss_sum, aux_vf_loss_sum, aux_updates = 0.0, 0.0, 0.0, 0
        print(f'aux epoch {epoch}', end=' | ')
        for observation, adj, action, old_prob, v_old, return_ in aux_minibatches:
            # Joint step on the policy network: fit its value output to the
            # returns while a KL penalty keeps the policy near old_prob.
            _, new_prob, v_pred = _forward(policy_net, observation, adj, action, c.is_graph)
            aux_loss = _value_loss(v_pred, v_old, return_, vf_clip_eps)
            log_ratio = _log_ratio(new_prob, old_prob, action)
            # (r - 1) - log r with r = new/old: a per-sample non-negative KL estimate.
            loss_kl = (torch.exp(log_ratio) - 1 - log_ratio).mean()
            loss = aux_loss + c.beta_clone * loss_kl
            _optimize(inner_optimizer_pi, loss, policy_net, c.max_grad_norm)
            kl_loss_sum += loss_kl.item()
            aux_loss_sum += aux_loss.item()

            # Extra value-function step on the separate value network.
            _, _, v_pred = _forward(value_net, observation, adj, action, c.is_graph)
            v_loss = _value_loss(v_pred, v_old, return_, vf_clip_eps)
            _optimize(inner_optimizer_v, v_loss, value_net, c.max_grad_norm)
            aux_vf_loss_sum += v_loss.item()
            aux_updates += 1
        n_updates = max(aux_updates, 1)
        print(f'aux loss: {aux_loss_sum / n_updates:.4f}', end=' | ')
        print(f'kl loss: {kl_loss_sum / n_updates:.4f}', end=' | ')
        print(f'vf loss: {aux_vf_loss_sum / n_updates:.4f}')
        _log_metrics({'phase': 'aux', 'iteration': iteration, 'aux_epoch': epoch,
                      'aux_loss': aux_loss_sum / n_updates,
                      'kl_loss': kl_loss_sum / n_updates,
                      'aux_vf_loss': aux_vf_loss_sum / n_updates})

    if c.save:
        # Overwrite the latest checkpoint in this run's directory each iteration.
        run_dir = os.path.join(c.save_path, log_path)
        os.makedirs(run_dir, exist_ok=True)
        torch.save(policy_net.state_dict(), os.path.join(run_dir, 'policy_net.pt'))
        torch.save(value_net.state_dict(), os.path.join(run_dir, 'value_net.pt'))