# Necessary imports and setup

In [None]:
!git clone https://github.com/saeedzou/MetaOffload &> /dev/null
%cd /content/MetaOffload
!pip install -q gym==0.14.0 &> /dev/null
!pip install wandb -qU &> /dev/null
!pip install -q pydotplus &> dev\null
!pip install -q learn2learn
!mkdir ./models
!mkdir ./logs

In [1]:
import json
import logging
import os
import random
from copy import deepcopy

import numpy as np
import torch
import learn2learn as l2l
from tqdm import tqdm

from env.mec_offloaing_envs.offloading_env import Resources
from env.mec_offloaing_envs.offloading_env import OffloadingEnvironment
from models import GraphSeq2Seq, BaselineSeq2Seq
from buffer import SingleRolloutBufferPPO
from train import inner_loop, outer_loop
from utils import log_metrics

# Loading config file and logging

In [None]:
# Read the experiment configuration (a JSON object) into a plain dict.
with open('my_config.json', 'r') as config_file:
    args = json.load(config_file)

class Config:
    """Attribute-style view of a configuration mapping.

    Each key of ``dictionary`` becomes an attribute of the instance,
    e.g. ``Config({'seed': 0}).seed == 0``.
    """

    def __init__(self, dictionary):
        for name in dictionary:
            setattr(self, name, dictionary[name])

# Attribute-style access to the loaded config (c.seed, c.device, ...).
c = Config(args)
device = c.device
# Seed every RNG the run may draw from. Python's `random` is not covered by the
# numpy/torch seeds, and the environment or libraries may use it, so seed it too.
random.seed(c.seed)
np.random.seed(c.seed)
torch.manual_seed(c.seed)

In [None]:
# Run name encoding the key hyperparameters; it names the log file, the wandb
# run and the checkpoint sub-directory.
log_path = (f"MAML_obs_{c.obs_dim}_"
            f"h_{c.encoder_units}_"
            f"m_{c.adaptation_steps}_"
            f"g_{c.is_graph}_"
            f"att_{c.is_attention}_"
            f"mec_{c.mec_process_capable}_"
            f"mob_{c.mobile_process_capable}_"
            f"ul_{c.bandwidth_up}_"
            f"dl_{c.bandwidth_down}_"
            f"olr_{c.outer_lr}_"
            f"ilr_{c.inner_lr}_"
            f"n_{c.num_iterations}_"
            f"ibs_{c.inner_batch_size}_"
            f"mgn_{c.max_grad_norm}_"
            f"vf_{c.vf_coef}_"
            f"epis_{c.num_task_episodes}")

# Plain-text log file, truncated on every run (filemode='w'). basicConfig does not
# create missing directories, so make sure ./logs exists even if setup was skipped.
log_dir = "./logs/"
os.makedirs(log_dir, exist_ok=True)
logger_path = log_dir + log_path + '.log'
logger = logging.getLogger(__name__)
logging.basicConfig(filename=logger_path,
                    filemode='w',
                    format='%(message)s',
                    level=logging.DEBUG,
                    force=True)  # remove handlers left over from a previous run of this cell

if c.wandb:
    import wandb
    # A key in the config file is still honored but optional: with key=None,
    # wandb.login() falls back to WANDB_API_KEY or cached credentials.
    wandb.login(key=getattr(c, 'wandb_key', None) or None)
    # Upload the run config WITHOUT the API key. Passing `c` directly would store
    # the secret in the run's config on the wandb server.
    run_config = {key: value for key, value in vars(c).items() if key != 'wandb_key'}
    wandb.init(project=c.wandb_project,
               name=log_path,
               config=run_config)
else:
    # Passed to log_metrics as wandb=None (presumably disables remote logging).
    wandb = None

if c.save:
    # Per-run checkpoint directory.
    os.makedirs(os.path.join(c.save_path, log_path), exist_ok=True)

# Loading environment

In [3]:
# Compute/communication resources. The processing capabilities from the config are
# multiplied by 1024**2 before reaching Resources (unit conversion presumably;
# TODO confirm the units Resources expects).
resources = Resources(
    mobile_process_capable=c.mobile_process_capable * 1024 ** 2,
    mec_process_capable=c.mec_process_capable * 1024 ** 2,
    bandwidth_dl=c.bandwidth_down,
    bandwidth_up=c.bandwidth_up,
)

# Offloading environment built from the task-graph files; batch_size is set
# equal to graph_number.
env = OffloadingEnvironment(
    resource_cluster=resources,
    graph_number=c.graph_number,
    batch_size=c.graph_number,
    graph_file_paths=c.graph_file_paths,
    encoding=c.encoding,
    time_major=False,
)

# Reference latencies of fixed heuristics, for comparison with the learned policy.
print(f'Average greedy latency: {np.mean(env.greedy_solution()[1]):.4f}')
print(f'Average all local latency: {np.mean(env.get_all_locally_execute_time()):.4f}')
print(f'Average all mec latency: {np.mean(env.get_all_mec_execute_time()):.4f}')

# Loading the model and initializing the buffers and optimizer

In [None]:
# Choose the encoder family once, then build the policy with a single call
# (both constructors take the same arguments).
policy_cls = GraphSeq2Seq if c.is_graph else BaselineSeq2Seq
policy = policy_cls(input_dim=c.obs_dim,
                    hidden_dim=c.encoder_units,
                    output_dim=c.action_dim,
                    num_layers=c.num_layers,
                    device=device,
                    is_attention=c.is_attention).to(device)

# learn2learn MAML wrapper: clone() yields per-task learners whose adapt() steps use
# lr=inner_lr. first_order=False keeps second-order meta-gradients, and
# allow_unused=True tolerates parameters that receive no gradient.
maml = l2l.algorithms.MAML(policy, lr=c.inner_lr, first_order=False, allow_unused=True)

if c.load:
    checkpoint = torch.load(c.load_path, map_location=device)
    # Checkpoints saved from `maml.state_dict()` prefix every key with "module."
    # (MAML keeps the policy as `maml.module`). Load those through the wrapper and
    # plain policy checkpoints directly into the policy.
    if checkpoint and all(key.startswith('module.') for key in checkpoint):
        maml.load_state_dict(checkpoint)
    else:
        policy.load_state_dict(checkpoint)

# Separate rollout buffers for pre-adaptation (train) and post-adaptation
# (validation) data of each task, both sized graph_number * num_task_episodes.
buffer_kwargs = dict(buffer_size=c.graph_number * c.num_task_episodes,
                     discount=c.gamma,
                     gae_lambda=c.tau,
                     device=device)
train_buffer = SingleRolloutBufferPPO(**buffer_kwargs)
val_buffer = SingleRolloutBufferPPO(**buffer_kwargs)

# Outer (meta) optimizer over the shared initial parameters.
outer_optimizer = torch.optim.Adam(maml.parameters(), lr=c.outer_lr)

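# Training loop

Each iteration samples `meta_batch_size` tasks. For each task, a clone of the policy is adapted with PPO updates on freshly collected rollouts (inner loop). The PPO loss of the adapted clone on new validation rollouts is then averaged over tasks to update the shared initialization (outer loop).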

In [None]:
def _ppo_losses(model, observation, adj, action, old_logit, v_old, advantage, return_,
                is_graph, clip_eps, vf_is_clipped):
    """PPO clipped policy loss and value loss of `model` on one sampled minibatch.

    Args:
        model: policy/value network, called as model(observation, adj, action) when
            `is_graph` else model(observation, action). It returns
            (_, new_logit, v_pred); new_logit is treated as per-action
            probabilities (its log is taken).
        observation, adj, action: minibatch inputs from a rollout buffer.
        old_logit: the same per-action quantities recorded at collection time.
        v_old: value predictions recorded at collection time.
        advantage, return_: advantages and returns computed by the buffer.
        is_graph: whether the model takes the adjacency input.
        clip_eps: PPO clipping range, used for both ratio and value clipping.
        vf_is_clipped: whether to clip the value prediction around `v_old`.

    Returns:
        (policy_loss, v_loss): scalar tensors still attached to the autograd graph.
    """
    if is_graph:
        _, new_logit, v_pred = model(observation, adj, action)
    else:
        _, new_logit, v_pred = model(observation, action)
    action = action.type(torch.int64)
    # Probabilities of the actions actually taken, under the new and old policy.
    new_logit_a = new_logit.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    old_logit_a = old_logit.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    ratio = torch.exp(torch.log(new_logit_a) - torch.log(old_logit_a))
    # Clipped surrogate objective, negated to give a loss.
    obj = ratio * advantage
    obj_clip = ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) * advantage
    policy_loss = -torch.min(obj, obj_clip).mean()
    if vf_is_clipped:
        # Clip the new prediction to within clip_eps of the OLD value. The previous
        # code used `v_pred + ...`, which did not constrain the update at all.
        v_pred_clipped = v_old + (v_pred - v_old).clamp(-clip_eps, clip_eps)
        v_loss = 0.5 * torch.max((v_pred - return_).pow(2), (v_pred_clipped - return_).pow(2)).mean()
    else:
        v_loss = 0.5 * (v_pred - return_).pow(2).mean()
    return policy_loss, v_loss


# cuDNN is disabled for the whole loop, presumably because second-order MAML
# differentiates through the RNN backward pass, which cuDNN RNN kernels do not
# support (TODO confirm the policy is RNN-based).
with torch.backends.cudnn.flags(enabled=False):
    for iteration in tqdm(range(c.start_iter, c.num_iterations), leave=False, disable=True):
        fts_before, fts_after = [], []
        vf_losses, pg_losses = [], []
        all_rewards, all_returns = [], []
        meta_loss = 0.0

        batch_of_tasks = env.sample_tasks(c.meta_batch_size)
        ### Inner loop: adapt a clone of the meta-policy to each sampled task ###
        for task_id in tqdm(batch_of_tasks, leave=False, total=c.meta_batch_size,
                            desc=f"Iteration {iteration} inner loop"):
            learner = maml.clone()
            # Pre-adaptation rollouts collected with the un-adapted clone.
            train_buffer.reset()
            train_buffer.collect_episodes(env=env, policy=learner, device=device, task_id=task_id, is_graph=c.is_graph)
            train_buffer.process_task()
            observations, adjs, actions, logits, v_olds, advantages, rewards, returns, fts = train_buffer.sample(batch_size=c.inner_batch_size)
            vf_loss, pg_loss = [], []
            for _ in tqdm(range(c.adaptation_steps), desc=f'Adapting task {task_id}', ascii=True, disable=True):
                # One differentiable adaptation step per minibatch.
                for observation, adj, action, old_logit, v_old, advantage, return_ in zip(
                        observations, adjs, actions, logits, v_olds, advantages, returns):
                    policy_loss, v_loss = _ppo_losses(
                        learner, observation, adj, action, old_logit, v_old, advantage, return_,
                        is_graph=c.is_graph, clip_eps=c.clip_eps, vf_is_clipped=c.vf_is_clipped)
                    vf_loss.append(v_loss.item())
                    pg_loss.append(policy_loss.item())
                    learner.adapt(policy_loss + c.vf_coef * v_loss)

            vf_losses.append(np.mean(vf_loss))
            pg_losses.append(np.mean(pg_loss))
            fts_before.append(fts)

            # Post-adaptation rollouts; their PPO loss is this task's meta-loss term.
            val_buffer.reset()
            val_buffer.collect_episodes(env=env, policy=learner, device=device, task_id=task_id, is_graph=c.is_graph)
            val_buffer.process_task()
            observations, adjs, actions, logits, v_olds, advantages, rewards, returns, fts = val_buffer.sample()
            fts_after.append(fts)
            for observation, adj, action, old_logit, v_old, advantage, return_ in zip(
                    observations, adjs, actions, logits, v_olds, advantages, returns):
                policy_loss, v_loss = _ppo_losses(
                    learner, observation, adj, action, old_logit, v_old, advantage, return_,
                    is_graph=c.is_graph, clip_eps=c.clip_eps, vf_is_clipped=c.vf_is_clipped)
                meta_loss += policy_loss + c.vf_coef * v_loss

            all_rewards.append(val_buffer.rewards.sum(-1))
            all_returns.append(val_buffer.returns[:, 0].mean().item())

        ### Log metrics ###
        avg_vf_losses = np.mean(vf_losses)
        avg_pg_losses = np.mean(pg_losses)
        avg_rewards = np.mean(np.concatenate(all_rewards))
        avg_returns = np.mean(all_returns)
        avg_fts_before = np.mean(np.concatenate(fts_before))
        avg_fts_after = np.mean(np.concatenate(fts_after))

        log_metrics(logger=logger,
                    iteration=iteration,
                    vf_losses=avg_vf_losses,
                    pg_losses=avg_pg_losses,
                    rewards=avg_rewards,
                    returns=avg_returns,
                    finish_times_old=avg_fts_before,
                    finish_times_new=avg_fts_after,
                    wandb=wandb)

        ### Outer loop: meta-update of the shared initialization ###
        meta_loss /= c.meta_batch_size
        outer_optimizer.zero_grad()
        meta_loss.backward()
        outer_optimizer.step()

    if c.save:
        # Save the policy's own state dict rather than maml.state_dict(), whose keys
        # carry MAML's "module." prefix. The checkpoint then loads directly with
        # policy.load_state_dict.
        torch.save(maml.module.state_dict(),
                   os.path.join(c.save_path, log_path, f'policy_{c.num_iterations}.pt'))