<a href="https://colab.research.google.com/github/vaas-umputer/edge-cloud-workflow-scheduler/blob/main/review_3_apr11_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torch_geometric stable-baselines3 gym numpy shimmy

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.1/63.1 kB 2.2 MB/s eta 0:00:00
Collecting stable-baselines3
  Downloading stable_baselines3-2.6.0-py3-none-any.whl.metadata (4.8 kB)
Collecting shimmy
  Downloading Shimmy-2.0.0-py3-none-any.whl.metadata (3.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl

In [3]:
#v12
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.data import Data
import json
import networkx as nx
import numpy as np
from gym import Env, spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from rich.console import Console
import os

class ContrastiveGAT(nn.Module):
    """Two-layer graph-attention encoder producing per-node and per-graph embeddings.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Per-head output width of each GAT layer. Scalar edge
            attributes are projected to this width before attention.
        embedding_dim: Width of the returned node and graph embeddings.
        heads: Attention heads per GAT layer (head outputs are concatenated).
    """

    def __init__(self, in_channels=9, hidden_channels=16, embedding_dim=32, heads=4):
        super().__init__()
        # edge_dim makes GATConv build an edge projection. Without it, the
        # edge_attr passed in forward() is silently ignored and edge_proj never
        # receives gradients.
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.2,
                            add_self_loops=False, edge_dim=hidden_channels)
        self.gat2 = GATConv(hidden_channels * heads, hidden_channels, heads=heads, dropout=0.2,
                            add_self_loops=False, edge_dim=hidden_channels)
        self.node_to_embed = nn.Linear(hidden_channels * heads, embedding_dim)
        self.edge_proj = nn.Linear(1, hidden_channels)
        self.graph_fc = nn.Linear(hidden_channels * heads, embedding_dim)

    def forward(self, x, edge_index, edge_attr, batch):
        """Encode a (possibly batched) graph.

        Args:
            x: Node feature matrix. NaN/inf entries are replaced before use.
            edge_index: Edge list with shape ``[2, num_edges]``.
            edge_attr: Scalar edge attributes with last dimension 1, or ``None``,
                in which case every edge gets the value 0.1.
            batch: Graph index of each node, used for mean pooling.

        Returns:
            ``(task_embeddings, graph_embedding)``: one ``embedding_dim`` vector
            per node and one per graph in ``batch``.
        """
        x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
        if edge_attr is None:
            edge_attr = torch.ones((edge_index.shape[1], 1), device=x.device) * 0.1
        edge_attr = torch.nan_to_num(edge_attr, nan=0.1)
        edge_attr = self.edge_proj(edge_attr)
        x = F.relu(self.gat1(x, edge_index, edge_attr=edge_attr))
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(self.gat2(x, edge_index, edge_attr=edge_attr))
        task_embeddings = self.node_to_embed(x)
        # The graph embedding is pooled from the GAT output, not from task_embeddings.
        graph_embedding = global_mean_pool(x, batch)
        graph_embedding = self.graph_fc(graph_embedding)
        return task_embeddings, graph_embedding

    def contrastive_loss(self, embeddings, edge_index, num_nodes):
        """Pull connected nodes together and push random node pairs apart.

        The positive term is the mean squared L2 distance between the endpoints
        of every edge. The negative term is a hinge ``relu(1 - d^2)`` over up to
        ``min(10% of nodes, 500)`` random distinct node pairs, weighted by 0.5.

        Args:
            embeddings: Per-node embedding matrix.
            edge_index: Edge list with shape ``[2, num_edges]``.
            num_nodes: Number of nodes to sample negatives from.

        Returns:
            Scalar loss tensor. It is 0 when there are no edges, and only the
            positive term when no negative pair survives sampling.
        """
        embeddings = torch.nan_to_num(embeddings, nan=0.0)
        if edge_index.shape[1] == 0:
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)
        pos_pairs = embeddings[edge_index[0]] - embeddings[edge_index[1]]
        pos_loss = (pos_pairs ** 2).sum(dim=1).mean()
        num_neg = min(int(num_nodes * 0.1), 500)
        neg_idx1 = torch.randint(0, num_nodes, (num_neg,), device=embeddings.device)
        neg_idx2 = torch.randint(0, num_nodes, (num_neg,), device=embeddings.device)
        # Drop self-pairs; they would always have distance 0.
        mask = neg_idx1 != neg_idx2
        neg_idx1, neg_idx2 = neg_idx1[mask], neg_idx2[mask]
        if len(neg_idx1) == 0:
            return pos_loss
        neg_pairs = embeddings[neg_idx1] - embeddings[neg_idx2]
        neg_loss = F.relu(1.0 - (neg_pairs ** 2).sum(dim=1)).mean()
        return pos_loss + 0.5 * neg_loss

class TaskQueue:
    """Simulated execution state of a workflow DAG on cloud and edge server pools.

    Holds per-task metadata, the running and completed sets, per-server loads,
    edge battery levels, network latency, and the accumulated cost and energy
    of completed tasks.

    Args:
        dag: Directed graph exposing ``predecessors()`` and node attributes.
            ``execution_time``, ``arrival_time``, ``monetary_cost`` and
            ``energy_cost`` are read with defaults.
        embeddings: Per-task embedding vectors aligned with ``task_ids``.
            Missing rows become 32 zeros.
        task_ids: Task (node) ids, in the same order as ``embeddings``.
        num_servers: Number of servers in each of the cloud and edge pools.
    """

    def __init__(self, dag, embeddings, task_ids, num_servers=5):
        self.dag = dag
        self.task_ids = task_ids
        self.tasks = {
            node: {
                "embedding": embeddings[i] if i < len(embeddings) else np.zeros(32),
                "dependencies": set(dag.predecessors(node)),
                "exec_time": dag.nodes[node].get("execution_time", 1.0),
                "arrival_time": dag.nodes[node].get("arrival_time", 0.0),
                "monetary_cost": dag.nodes[node].get("monetary_cost", 1.0),
                "energy_cost": dag.nodes[node].get("energy_cost", 1.0),
                "placement": None,
                "server": None,
                "history": []
            } for i, node in enumerate(task_ids)
        }
        self.completed = set()
        # task id -> info about the in-flight execution (see start_task)
        self.running = {}
        self.server_loads = {"cloud": np.zeros(num_servers), "edge": np.zeros(num_servers)}
        self.edge_battery = np.full(num_servers, 100.0)
        self.network_latency = 0.5
        self.current_time = 0.0
        self.total_cost = 0.0
        self.total_energy = 0.0
        self.console = Console()
        self.cloud_base_power = np.random.uniform(10, 15, num_servers)
        self.edge_base_power = np.random.uniform(2, 5, num_servers)
        # Rolling window (last 10 snapshots) of per-server loads, used by predict_load.
        self.load_history = {"cloud": [], "edge": []}
        self.placement_log = {}

    def update_conditions(self):
        """Apply one tick of random environment drift.

        Draws a new network latency, adds random load to every server (clipped
        to 2.0 for cloud and 1.5 for edge), recharges edge batteries by 0.5,
        and appends a load snapshot to the 10-entry history.
        """
        self.network_latency = np.random.uniform(0.3, 1.0)
        self.server_loads["cloud"] += np.random.uniform(0, 0.2, len(self.server_loads["cloud"]))
        self.server_loads["edge"] += np.random.uniform(0, 0.1, len(self.server_loads["edge"]))
        self.server_loads["cloud"] = np.clip(self.server_loads["cloud"], 0, 2.0)
        self.server_loads["edge"] = np.clip(self.server_loads["edge"], 0, 1.5)
        self.edge_battery += 0.5
        self.edge_battery = np.clip(self.edge_battery, 0, 100)
        self.load_history["cloud"].append(self.server_loads["cloud"].copy())
        self.load_history["edge"].append(self.server_loads["edge"].copy())
        if len(self.load_history["cloud"]) > 10:
            self.load_history["cloud"].pop(0)
            self.load_history["edge"].pop(0)

    def get_ready_tasks(self):
        """Return the ids of tasks that are neither running nor completed.

        A task is included only if it has already arrived and all of its DAG
        predecessors have completed.
        """
        return [tid for tid, t in self.tasks.items() if
                tid not in self.running and tid not in self.completed and
                t["arrival_time"] <= self.current_time and
                t["dependencies"].issubset(self.completed)]

    def start_task(self, task_id, placement, server, exec_time):
        """Mark ``task_id`` as running on ``server`` in the ``placement`` pool.

        Reserves load on that server proportional to the task's nominal
        execution time. For edge placements, it also drains the server's
        battery by the estimated energy.

        Args:
            task_id: Id of the task to start.
            placement: ``"cloud"`` or ``"edge"``.
            server: Server index within the placement pool.
            exec_time: Effective (load/latency adjusted) execution time. The
                task completes once ``current_time`` reaches start + exec_time.
        """
        energy = 0
        if placement == "cloud":
            base_power = self.cloud_base_power[server]
            load_factor = self.server_loads[placement][server] / 3.0
            energy = (base_power + load_factor * 5) * exec_time
        elif placement == "edge":
            base_power = self.edge_base_power[server]
            load_factor = self.server_loads[placement][server] / 2.0
            energy = (base_power + load_factor * 5) * exec_time
            self.edge_battery[server] -= energy
            self.edge_battery[server] = max(0.0, self.edge_battery[server])
        load_delta = self.tasks[task_id]["exec_time"] / (3.0 if placement == "cloud" else 2.0)
        self.server_loads[placement][server] += load_delta
        self.running[task_id] = {
            "placement": placement,
            "server": server,
            "expected_end": self.current_time + exec_time,
            "energy": energy,
            "exec_time": exec_time,
            # Exact amount reserved above, so complete_task can release it.
            "load": load_delta,
        }
        self.placement_log[task_id] = {"placement": placement, "server": server}

    def complete_task(self, task_id):
        """Finish a running task: release its load, account cost/energy, record history.

        Does nothing if ``task_id`` is not running. SLA is met when the
        effective execution time is at most 1.5x the task's nominal time.
        """
        if task_id not in self.running:
            return
        task_info = self.running[task_id]
        exec_time = task_info["exec_time"]
        placement = task_info["placement"]
        server = task_info["server"]
        # Release exactly what start_task reserved. The previous release
        # (exec_time * 0.1) was smaller than the reservation, so loads only ever
        # grew. That formula is kept as a fallback for entries lacking "load".
        released = task_info.get("load", exec_time * 0.1)
        self.server_loads[placement][server] = max(0, self.server_loads[placement][server] - released)
        base_power = self.cloud_base_power[server] if placement == "cloud" else self.edge_base_power[server]
        load_factor = self.server_loads[placement][server] / (2.0 if placement == "cloud" else 1.5)
        energy = (base_power + load_factor * 10) * exec_time
        cost = exec_time * (0.15 if placement == "cloud" else 0.05)
        sla_deadline = self.tasks[task_id]["exec_time"] * 1.5
        sla_met = 1 if exec_time <= sla_deadline else 0
        self.total_energy += energy
        self.total_cost += cost
        self.tasks[task_id]["history"].append({
            "exec_time": exec_time,
            "energy": energy,
            "cost": cost,
            "sla": sla_met,
            "placement": placement,
            "server": server,
            "timestamp": self.current_time
        })
        self.completed.add(task_id)
        del self.running[task_id]

    def update_running(self):
        """Complete every running task whose expected end has passed; recharge edge batteries by 0.5."""
        completed = [t for t, info in self.running.items() if self.current_time >= info["expected_end"]]
        for task in completed:
            self.complete_task(task)
        self.edge_battery = np.minimum(self.edge_battery + 0.5, 100.0)

    def get_historical_state(self, task_id):
        """Return ``[exec_time, energy, cost, sla, is_cloud]`` of the task's last run.

        Returns the fixed prior ``[1.0, 0.8, 0.5, 0.0, 1.0]`` when the task
        has no history yet.
        """
        history = self.tasks[task_id]["history"]
        return np.array([1.0, 0.8, 0.5, 0.0, 1.0]) if not history else np.array([
            history[-1]["exec_time"],
            history[-1]["energy"],
            history[-1]["cost"],
            history[-1]["sla"],
            1 if history[-1]["placement"] == "cloud" else 0
        ])

    def predict_load(self, placement):
        """Per-server mean load over the recent history (the current loads if no history)."""
        history = self.load_history[placement]
        return np.mean(history, axis=0) if history else self.server_loads[placement]

class Tier1SchedulerEnv(Env):
    """Gym environment for the first scheduling tier: cloud-vs-edge placement.

    The observation is five 56-value slots, one per ready task (at most five).
    Each slot holds:
      - the task's embedding (32 values expected for the slot to total 56);
      - 19 system-state values: time, #running, #ready, network latency,
        5 cloud loads, 5 edge loads and 5 edge battery levels;
      - the 5 values from ``TaskQueue.get_historical_state``.
    Unused slots are zero. The action is a 5-bit vector (0 = cloud, 1 = edge)
    aligned with the slots.

    ``step`` only scores placements and reports them in ``info``. It does not
    start tasks on the shared ``TaskQueue``.
    """

    def __init__(self, task_queue, embeddings, task_ids):
        super().__init__()
        self.task_queue = task_queue
        self.embeddings = embeddings
        self.task_ids = task_ids
        self.action_space = spaces.MultiBinary(5)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(56 * 5,), dtype=np.float32)
        self.max_exec_time = 10.0
        self.console = Console()
        self.seed_value = None

    def reset(self, seed=None, options=None):
        """Reset the shared task queue's simulation state and return the first observation.

        Clears running/completed tasks, loads (5 servers per pool), batteries,
        latency, totals, load history and the placement log. Per-task history
        is kept. When ``seed`` is given, numpy's global RNG is reseeded.
        """
        if seed is not None:
            self.seed_value = seed
            np.random.seed(seed)
        self.task_queue.current_time = 0.0
        self.task_queue.running.clear()
        self.task_queue.completed.clear()
        self.task_queue.server_loads = {"cloud": np.zeros(5), "edge": np.zeros(5)}
        self.task_queue.edge_battery = np.full(5, 100.0)
        self.task_queue.network_latency = 0.5
        self.task_queue.total_cost = 0.0
        self.task_queue.total_energy = 0.0
        self.task_queue.load_history = {"cloud": [], "edge": []}
        self.task_queue.placement_log = {}
        obs = self._get_obs()
        info = {"reset": True}
        return obs, info

    def _get_obs(self):
        """Build the flat float32 observation for the first five ready tasks."""
        queue = self.task_queue
        ready_tasks = queue.get_ready_tasks()[:5]
        # The system-state part is identical for every occupied slot, so build it once.
        system_state = np.concatenate([
            [queue.current_time, len(queue.running), len(ready_tasks), queue.network_latency],
            queue.server_loads["cloud"],
            queue.server_loads["edge"],
            queue.edge_battery
        ])
        obs = []
        for tid in ready_tasks:
            emb = np.nan_to_num(self.embeddings[self.task_ids.index(tid)], nan=0.0)
            historical_state = queue.get_historical_state(tid)
            obs.append(np.concatenate([emb, system_state, historical_state]))
        obs.extend(np.zeros(56) for _ in range(5 - len(ready_tasks)))
        # Return numpy (not a torch tensor) so observations belong to the declared
        # Box observation_space, as the Gym API requires.
        return np.concatenate(obs).astype(np.float32)

    def step(self, action):
        """Score a cloud/edge placement for each of the first five ready tasks.

        Args:
            action: Per-slot 0/1 values (0 = cloud, 1 = edge). Entries beyond
                the number of ready tasks are ignored.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``info["placements"]``
            maps each scored task id to its placement and its estimated exec
            time, energy and cost.
        """
        self.task_queue.update_running()
        ready_tasks = self.task_queue.get_ready_tasks()[:5]
        if not ready_tasks:
            # Idle tick: advance the simulated clock. This is the only branch in
            # which every task can be complete (ready tasks are never completed),
            # so termination must be reported here; previously it was always False.
            self.task_queue.current_time += 0.1
            terminated = len(self.task_queue.completed) == len(self.task_ids)
            return self._get_obs(), 0, terminated, False, {"placements": {}}
        self.task_queue.update_conditions()
        placements = {}
        total_reward = 0
        for i, tid in enumerate(ready_tasks):
            if i >= len(action):
                break
            placement = "cloud" if action[i] == 0 else "edge"
            base_time = self.task_queue.tasks[tid]["exec_time"]
            exec_time = base_time
            # Long edge tasks are fragmented when the average edge battery is low.
            if placement == "edge" and base_time > 5.0 and np.mean(self.task_queue.edge_battery) < 20:
                fragments = int(np.ceil(base_time / 2.0))
                exec_time = base_time / fragments
            load_factor = np.mean(self.task_queue.server_loads[placement]) / (2.0 if placement == "cloud" else 1.5)
            exec_time *= (1 + load_factor + self.task_queue.network_latency)
            base_power = np.mean(self.task_queue.cloud_base_power) if placement == "cloud" else np.mean(self.task_queue.edge_base_power)
            energy = (base_power + load_factor * 10) * exec_time
            cost = exec_time * (0.15 if placement == "cloud" else 0.05)
            sla_met = 1 if exec_time <= base_time * 1.5 else 0
            edge_bonus = 1.0 if placement == "edge" else 0.0
            reward = 5.0 * sla_met - 0.1 * (exec_time / self.max_exec_time) - 0.2 * (energy / self.max_exec_time) - 0.15 * (cost / 10.0) + edge_bonus
            total_reward += reward
            placements[tid] = {"placement": placement, "exec_time": exec_time, "energy": energy, "cost": cost}
        terminated = len(self.task_queue.completed) == len(self.task_ids)
        truncated = False
        return self._get_obs(), total_reward, terminated, truncated, {"placements": placements}

class Tier2SchedulerEnv(Env):
    """Gym environment for the second scheduling tier: server selection.

    For each task in ``current_tasks`` (populated externally together with the
    Tier-1 ``placements``), the action picks a server index 0-4 inside the
    task's placement pool. ``step`` then starts the task on the shared
    ``TaskQueue``.

    Observations use a 5 x 56 slot layout. The battery block is filled only
    for edge-placed tasks, and the ready count is the full, uncapped count.
    ``step`` expects a ``{task_id: server}`` dict. Any other action type is a
    no-op with reward 0.
    """

    def __init__(self, task_queue, embeddings, task_ids):
        super().__init__()
        self.task_queue = task_queue
        self.embeddings = embeddings
        self.task_ids = task_ids
        self.action_space = spaces.MultiDiscrete([5] * 5)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(56 * 5,), dtype=np.float32)
        self.current_tasks = []
        self.placements = {}
        self.max_exec_time = 10.0
        self.console = Console()
        self.seed_value = None

    def reset(self, seed=None, options=None):
        """Clear the pending Tier-1 hand-off and return the (empty) observation.

        The shared ``TaskQueue`` is not touched. When ``seed`` is given,
        numpy's global RNG is reseeded.
        """
        if seed is not None:
            self.seed_value = seed
            np.random.seed(seed)
        self.current_tasks = []
        self.placements = {}
        obs = self._get_obs()
        info = {"reset": True}
        return obs, info

    def _get_obs(self):
        """Build the flat float32 observation for the first five ``current_tasks``."""
        queue = self.task_queue
        # Same value for every slot; computing it once avoids a DAG scan per slot.
        num_ready = len(queue.get_ready_tasks())
        obs = []
        for tid in self.current_tasks[:5]:
            emb = np.nan_to_num(self.embeddings[self.task_ids.index(tid)], nan=0.0)
            battery = queue.edge_battery if self.placements.get(tid) == "edge" else np.zeros(5)
            system_state = np.concatenate([
                [queue.current_time, len(queue.running), num_ready, queue.network_latency],
                queue.server_loads["cloud"],
                queue.server_loads["edge"],
                battery
            ])
            historical_state = queue.get_historical_state(tid)
            obs.append(np.concatenate([emb, system_state, historical_state]))
        obs.extend(np.zeros(56) for _ in range(5 - len(obs)))
        # Return numpy (not a torch tensor) so observations belong to the declared
        # Box observation_space, as the Gym API requires.
        return np.concatenate(obs).astype(np.float32)

    def step(self, action_dict):
        """Start each task in ``action_dict`` on its chosen server and score it.

        The placement pool comes from ``self.placements`` (default ``"cloud"``).
        At most five tasks are handled. A task whose predicted server load is
        at or above the pool limit (3.0 cloud / 2.0 edge) is skipped with a
        reward of -1.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``info`` maps each
            started task id to its exec time, energy and cost. The hand-off
            state is cleared afterwards.
        """
        if not isinstance(action_dict, dict):
            return self._get_obs(), 0, False, False, {}
        total_reward = 0
        info_dict = {}
        for i, (tid, server) in enumerate(action_dict.items()):
            if i >= 5 or tid not in self.task_queue.tasks:
                continue
            placement = self.placements.get(tid, "cloud")
            base_time = self.task_queue.tasks[tid]["exec_time"]
            base_time = min(base_time, self.max_exec_time)
            exec_time = base_time
            # Long edge tasks are fragmented when the chosen server's battery is low.
            if placement == "edge" and base_time > 5.0 and self.task_queue.edge_battery[server] < 30:
                fragments = int(np.ceil(base_time / 2.0))
                exec_time = base_time / fragments
            predicted_load = self.task_queue.predict_load(placement)
            if predicted_load[server] >= (3.0 if placement == "cloud" else 2.0):
                total_reward -= 1
                continue
            load_factor = self.task_queue.server_loads[placement][server] / (3.0 if placement == "cloud" else 2.0)
            exec_time *= (1 + load_factor * 0.5 + self.task_queue.network_latency * 0.5)
            self.task_queue.start_task(tid, placement, server, exec_time)
            sla_met = 1 if exec_time <= base_time * 1.5 else 0
            base_power = self.task_queue.cloud_base_power[server] if placement == "cloud" else self.task_queue.edge_base_power[server]
            energy = (base_power + load_factor * 5) * exec_time
            cost = exec_time * (0.15 if placement == "cloud" else 0.05)
            # Measured after start_task, so it reflects the newly reserved load.
            load_balance = -0.5 * np.std(self.task_queue.server_loads[placement])
            reward = 10.0 * sla_met - 0.05 * (exec_time / self.max_exec_time) - 0.1 * (energy / self.max_exec_time) - 0.1 * (cost / 10.0) + load_balance + (1.0 if placement == "edge" else 0.0)
            total_reward += reward
            info_dict[tid] = {"task": tid, "exec_time": exec_time, "energy": energy, "cost": cost}
        self.task_queue.update_running()
        terminated = len(self.task_queue.completed) == len(self.task_ids)
        truncated = False
        self.current_tasks = []
        self.placements = {}
        return self._get_obs(), total_reward, terminated, truncated, info_dict

class WorkflowScheduler:
    """Load a workflow DAG, embed its tasks with a GNN, and train/validate a two-tier PPO scheduler.

    Tier 1 (``Tier1SchedulerEnv``) picks cloud vs. edge for up to five ready
    tasks. Tier 2 (``Tier2SchedulerEnv``) picks a server for each task Tier 1
    placed. Both environments share one ``TaskQueue``.

    Args:
        dag_file: Path to a node-link JSON DAG (edges under the ``"links"`` key).
        embeddings_file: Path of the ``torch.save`` file holding
            ``{"task_embeddings": {task_id: vector}}``. Created if missing.
        max_nodes: Keep at most this many nodes (the first ones in graph order).
    """

    def __init__(self, dag_file, embeddings_file, max_nodes=1000):
        self.dag_file = dag_file
        self.embeddings_file = embeddings_file
        self.max_nodes = max_nodes
        self.console = Console()
        self.dag, self.embeddings, self.task_ids = self.load_dag_and_embeddings()
        self.task_queue = TaskQueue(self.dag, self.embeddings, self.task_ids)
        self.tier1_env = Tier1SchedulerEnv(self.task_queue, self.embeddings, self.task_ids)
        self.tier2_env = Tier2SchedulerEnv(self.task_queue, self.embeddings, self.task_ids)
        self.tier1_model = None
        self.tier2_model = None

    def load_dag_and_embeddings(self):
        """Read the DAG and its task embeddings.

        Returns:
            ``(dag, embeddings, task_ids)``, where ``embeddings[i]`` belongs to
            ``task_ids[i]`` and missing embeddings are 32 zeros.

        Raises:
            FileNotFoundError: If ``dag_file`` does not exist.
        """
        if not os.path.exists(self.dag_file):
            self.console.log(f"Error: DAG file {self.dag_file} not found")
            raise FileNotFoundError(f"DAG file {self.dag_file} not found")
        with open(self.dag_file, "r", encoding="utf-8") as f:
            dag_json = json.load(f)
        dag = nx.node_link_graph(dag_json, edges="links")
        if len(dag.nodes) > self.max_nodes:
            nodes = list(dag.nodes)[:self.max_nodes]
            dag = dag.subgraph(nodes).copy()
        if not os.path.exists(self.embeddings_file):
            # Store plain float lists rather than numpy arrays. torch>=2.6 loads
            # with weights_only=True by default, which refuses pickled ndarrays,
            # so a file of np.zeros(...) would fail to load on the next run.
            task_embeddings_dict = {tid: [0.0] * 32 for tid in dag.nodes()}
            torch.save({"task_embeddings": task_embeddings_dict}, self.embeddings_file)
        else:
            embeddings_data = torch.load(self.embeddings_file, map_location=torch.device('cpu'))
            task_embeddings_dict = embeddings_data["task_embeddings"]
        task_ids = list(dag.nodes())
        task_embeddings = np.array([task_embeddings_dict.get(tid, np.zeros(32)) for tid in task_ids])
        return dag, task_embeddings, task_ids

    def prepare_data(self, dag):
        """Convert a DAG into a PyG ``Data`` object for the GNN.

        Node features are the seven numeric attributes in ``feature_keys``
        (missing -> 0.5) plus a two-value ``machine_type`` indicator
        (edge -> [1, 0], otherwise [0, 1]), standardized per column. Edge
        attributes come from ``T_comm`` (missing -> 0.1). An empty DAG yields
        a single placeholder node and no edges.
        """
        if not dag.nodes:
            x = torch.tensor([[0.5] * 9], dtype=torch.float)
            edge_index = torch.tensor([[], []], dtype=torch.long)
            edge_attr = torch.tensor([[0.1]], dtype=torch.float)
            batch = torch.zeros(1, dtype=torch.long)
            return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)
        node_to_idx = {nid: i for i, nid in enumerate(dag.nodes)}
        feature_keys = ["execution_time", "cpu_usage", "memory_usage", "arrival_time", "power_usage", "monetary_cost", "energy_cost"]
        x = [[float(dag.nodes[n].get(k, 0.5)) for k in feature_keys] +
             ([1.0, 0.0] if dag.nodes[n].get("machine_type", "cloud") == "edge" else [0.0, 1.0])
             for n in dag.nodes]
        x = torch.tensor(x, dtype=torch.float)
        x = torch.nan_to_num(x, nan=0.5, posinf=1.0, neginf=0.0)
        mean = x.mean(dim=0)
        std = x.std(dim=0)
        # Constant columns (std 0) and single-node graphs (std NaN) are divided by 1.
        std = torch.where(std > 0, std, torch.ones_like(std))
        x = (x - mean) / std
        edge_index = torch.tensor([[node_to_idx[e[0]], node_to_idx[e[1]]] for e in dag.edges], dtype=torch.long).t() if dag.edges else torch.tensor([[], []], dtype=torch.long)
        edge_attr = torch.tensor([[dag.edges[e].get("T_comm", 0.1)] for e in dag.edges], dtype=torch.float) if dag.edges else torch.tensor([[0.1]], dtype=torch.float)
        edge_attr = torch.nan_to_num(edge_attr, nan=0.1)
        batch = torch.zeros(len(dag.nodes), dtype=torch.long)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)

    def train_gnn(self, epochs=5):
        """Train a ContrastiveGAT on the DAG and publish the resulting task embeddings.

        Embeddings are saved to ``embeddings_file`` and pushed into the task
        queue and both RL environments. Returns early, leaving the embeddings
        unchanged, when the DAG has no edges. Epochs with a NaN loss are skipped.
        """
        self.console.log("GNN training...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = ContrastiveGAT().to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        data = self.prepare_data(self.dag).to(device)
        if data.x.shape[0] == 0 or data.edge_index.shape[1] == 0:
            return
        for epoch in range(1, epochs + 1):
            model.train()
            optimizer.zero_grad()
            task_embeddings, _ = model(data.x, data.edge_index, data.edge_attr, data.batch)
            loss = model.contrastive_loss(task_embeddings, data.edge_index, len(data.x))
            if torch.isnan(loss):
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        model.eval()
        with torch.no_grad():
            task_embeddings, _ = model(data.x, data.edge_index, data.edge_attr, data.batch)
            task_embeddings = torch.nan_to_num(task_embeddings, nan=0.0)
            task_embeddings_dict = {nid: emb.tolist() for nid, emb in zip(self.dag.nodes, task_embeddings)}
            torch.save({"task_embeddings": task_embeddings_dict}, self.embeddings_file)
        self.embeddings = np.array([task_embeddings_dict[tid] for tid in self.task_ids])
        self._publish_embeddings()

    def _publish_embeddings(self):
        """Point the task queue and both environments at the current ``self.embeddings``."""
        # They captured the array at construction time. Rebinding self.embeddings
        # alone would leave the RL tiers observing the stale initial vectors.
        self.tier1_env.embeddings = self.embeddings
        self.tier2_env.embeddings = self.embeddings
        for idx, tid in enumerate(self.task_ids):
            self.task_queue.tasks[tid]["embedding"] = self.embeddings[idx]

    def _sample_tier1_action(self, obs, num_ready):
        """Sample one Tier-1 action. Slots beyond the ready tasks stay 0 (cloud)."""
        # One policy sample per decision, so all slots come from the same joint draw.
        sampled = self.tier1_model.predict(obs, deterministic=False)[0]
        action = np.zeros(5, dtype=int)
        count = min(5, num_ready)
        action[:count] = sampled[:count]
        return action

    def _run_tier2(self, placements):
        """Hand Tier-1 placements to Tier 2, sample servers, and step the Tier-2 env.

        Returns the Tier-2 ``step`` result tuple.
        """
        self.tier2_env.current_tasks = list(placements.keys())
        self.tier2_env.placements = {tid: placements[tid]["placement"] for tid in placements}
        tier2_obs = self.tier2_env._get_obs()
        servers = self.tier2_model.predict(tier2_obs, deterministic=False)[0]
        action_dict = {tid: servers[i] for i, tid in enumerate(self.tier2_env.current_tasks)}
        return self.tier2_env.step(action_dict)

    def train_tier1(self, timesteps=1000):
        """Train the Tier-1 PPO policy and save it as ``tier1_scheduler``."""
        self.console.log("Tier 1 training...")
        self.tier1_model = PPO("MlpPolicy", DummyVecEnv([lambda: self.tier1_env]), verbose=0, n_steps=128, ent_coef=0.01)
        self.tier1_model.learn(total_timesteps=timesteps)
        self.tier1_model.save("tier1_scheduler")

    def train_tier2(self, timesteps=1000):
        """Train the Tier-2 PPO policy while Tier 1 (loaded or freshly trained) drives placements.

        Runs scheduling rounds until ``timesteps // 64`` Tier-2 learning
        updates have been made or all tasks complete, then saves the policy
        as ``tier2_scheduler``.
        """
        self.console.log("Tier 2 training...")
        self.tier2_model = PPO("MlpPolicy", DummyVecEnv([lambda: self.tier2_env]), verbose=0, n_steps=128, ent_coef=0.01)
        try:
            self.tier1_model = PPO.load("tier1_scheduler")
        except FileNotFoundError:
            self.train_tier1(timesteps=500)
        queue = self.tier1_env.task_queue
        obs = self.tier1_env.reset()[0]
        total_steps = timesteps // 64
        steps_taken = 0
        while steps_taken < total_steps and len(queue.completed) < len(self.task_ids):
            ready_tasks = queue.get_ready_tasks()
            if not ready_tasks and not queue.running:
                # Nothing to do now: jump the clock to the next eligible arrival, or stop.
                future_arrivals = [t["arrival_time"] for t in queue.tasks.values()
                                   if t["arrival_time"] > queue.current_time
                                   and t["dependencies"].issubset(queue.completed)]
                if future_arrivals:
                    queue.current_time = min(future_arrivals)
                    self.tier2_env.task_queue.current_time = queue.current_time
                    continue
                break
            action = self._sample_tier1_action(obs, len(ready_tasks))
            obs, _, _, _, info = self.tier1_env.step(action)
            if info["placements"]:
                _, _, _, _, tier2_info = self._run_tier2(info["placements"])
                if tier2_info:
                    self.tier2_model.learn(total_timesteps=64, reset_num_timesteps=False)
                    steps_taken += 1
            # NOTE(review): both envs were built with the same TaskQueue (see
            # __init__), so update_running runs twice here and recharges edge
            # batteries twice per round. Kept as-is to preserve training dynamics.
            self.tier1_env.task_queue.update_running()
            self.tier2_env.task_queue.update_running()
            obs = self.tier1_env._get_obs()
        self.tier2_model.save("tier2_scheduler")

    def validate(self):
        """Roll out both saved tiers over the workflow and log summary metrics.

        Loads ``tier1_scheduler`` and ``tier2_scheduler``; logs and returns if
        either is missing. Then alternates Tier-1 placement and Tier-2 server
        selection until every task completes or ``5 * len(task_ids)`` steps
        have run. Logs every recorded placement, the makespan (final simulated
        time), and per-started-task averages of exec time, energy, cost and
        SLA compliance.
        """
        try:
            self.tier1_model = PPO.load("tier1_scheduler")
            self.tier2_model = PPO.load("tier2_scheduler")
        except FileNotFoundError:
            self.console.log("RL models missing")
            return
        self.console.log("Validating...")
        total_exec_time = total_energy = total_cost = sla_compliance = tasks_started = 0
        max_steps = len(self.task_ids) * 5
        step = 0
        tier1_obs = self.tier1_env.reset()[0]
        self.tier2_env.reset()
        while len(self.task_queue.completed) < len(self.task_ids) and step < max_steps:
            self.task_queue.update_running()
            ready_tasks = self.task_queue.get_ready_tasks()
            if not ready_tasks and not self.task_queue.running:
                future_arrivals = [t["arrival_time"] for t in self.task_queue.tasks.values()
                                   if t["arrival_time"] > self.task_queue.current_time and
                                   t["dependencies"].issubset(self.task_queue.completed)]
                if future_arrivals:
                    self.task_queue.current_time = min(future_arrivals)
                    continue
                break
            action = self._sample_tier1_action(tier1_obs, len(ready_tasks))
            tier1_obs, _, _, _, tier1_info = self.tier1_env.step(action)
            placements = tier1_info.get("placements", {})
            if not placements:
                # Tier1SchedulerEnv.step already advanced the clock by 0.1 when it
                # had nothing to place. Advancing it again here double-counted
                # idle time and inflated the reported makespan.
                step += 1
                continue
            _, _, _, _, tier2_info = self._run_tier2(placements)
            for tid, stats in tier2_info.items():
                tasks_started += 1
                exec_time = stats["exec_time"]
                total_exec_time += exec_time
                total_energy += stats["energy"]
                total_cost += stats["cost"]
                sla_deadline = self.task_queue.tasks[tid]["exec_time"] * 1.5
                sla_compliance += 1 if exec_time <= sla_deadline else 0
            step += 1
        completed_count = len(self.task_queue.completed)
        placement_summary = "\n".join([f"Task {tid}: {info['placement']} server {info['server']}"
                                      for tid, info in self.task_queue.placement_log.items()])
        self.console.log(f"Validation done: {completed_count}/{len(self.task_ids)} tasks completed")
        self.console.log("Placements:\n" + placement_summary)
        avg_exec_time = total_exec_time / tasks_started if tasks_started else 0
        avg_energy = total_energy / tasks_started if tasks_started else 0
        avg_cost = total_cost / tasks_started if tasks_started else 0
        sla_rate = sla_compliance / tasks_started if tasks_started else 0
        self.console.log(f"Makespan: {self.task_queue.current_time:.2f}, Avg Exec: {avg_exec_time:.2f}, "
                         f"Avg Energy: {avg_energy:.2f}, Avg Cost: {avg_cost:.2f}, SLA: {sla_rate:.2%}")

if __name__ == "__main__":
    # Full pipeline, in order: GNN task embeddings, Tier-1 policy,
    # Tier-2 policy, then a validation rollout.
    console = Console()
    console.log("Starting...")
    scheduler = WorkflowScheduler("balanced_cybershake_dag.json", "dag_embeddings.pth")
    pipeline = (
        scheduler.train_gnn,
        scheduler.train_tier1,
        scheduler.train_tier2,
        scheduler.validate,
    )
    for stage in pipeline:
        stage()