# What this tutorial covers
This notebook demonstrates how to run a reinforcement learning workflow using ROSE's reinforcement learner. 
It will show how to:
* Define and register the environment and update tasks
* Use a stop criterion to terminate training when a target reward is met
* Run a sequential reinforcement learning loop with policy gradient methods

## Sequential Reinforcement Learner

In this example, we use the ROSE API to build and submit a single sequential reinforcement learner. It stops either when the performance-metric threshold is met (here, an average evaluation reward above 475) or when the user-specified number of iterations is reached (in this case, 100 iterations: 1000 episodes / 10 episodes per iteration). The example trains a policy with the REINFORCE policy-gradient algorithm on the CartPole environment.


                                                  ┌────────────────────────┐
                                                  │      ENVIRONMENT       │
                                                  │ (Run the policy and    │
                                                  │   gather experience)   │
                                                  └────────────┬───────────┘
                                                               │
                                                               ▼
                                                  ┌────────────────────────┐
                                                  │     POLICY UPDATE      │
                                                  │ (Update the policy     │
                                                  │  using experiences)    │
                                                  └────────────┬───────────┘
                                                               │
                                                               ▼
                                                  ┌────────────────────────┐
                                                  │      POLICY TEST       │
                                                  │ (Test the performance  │
                                                  │   of the new policy)   │
                                                  └────────────┬───────────┘
                                                               │
                                                               ▼
                                                  ┌────────────────────────┐
                                                  │  IMPROVED POLICY LOOP  │
                                                  │(Repeat for N iters     │
                                                  │   or performance goal) │
                                                  └────────────────────────┘

In [None]:
import asyncio
import logging

from typing import List, Tuple
from dataclasses import dataclass

# Task imports
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym

# ROSE top layer imports
from rose.metrics import GREATER_THAN_THRESHOLD
from rose.rl.reinforcement_learner import SequentialReinforcementLearner

# ROSE Bottom layers imports
from radical.asyncflow import WorkflowEngine
from rhapsody.backends import ConcurrentExecutionBackend
from concurrent.futures import ProcessPoolExecutor
from radical.asyncflow.logging import init_default_logger

# Module-level logger; `run` uses it for its start/finish messages.
logger = logging.getLogger(__name__)

In [2]:
@dataclass
class Config:
    """Hyper-parameters and file paths shared by every task in this tutorial."""

    # Gymnasium environment id, and base seed for the training-episode resets.
    env_id: str = "CartPole-v1"
    seed: int = 42
    # Width of both hidden layers of the policy network.
    hidden_size: int = 128
    # Discount factor used when turning rewards into returns.
    gamma: float = 0.99
    # Learning rate handed to the Adam optimizer.
    lr: float = 3e-3
    # Total episode budget; divided by batch_size to obtain the learner's max_iter.
    episodes: int = 1000
    # Episodes per environment task (and per stop-criterion evaluation).
    batch_size: int = 10
    # Stop-criterion threshold for the average evaluation reward.
    reward_solve_threshold: float = 475.0
    # Device for the tensors built in the update task.
    device: str = "cpu"
    # Checkpoint file through which the tasks exchange policy weights.
    model_path: str = "cartpole_policy.pt"


cfg: Config = Config()

In [3]:
class Network(nn.Module):
    """Three-layer MLP policy that maps observations to action logits.

    Args:
        obs_dim: Size of the last observation dimension (input features).
        hidden: Width of both hidden layers.
        n_actions: Number of discrete actions (length of the logit vector).
    """

    def __init__(self, obs_dim: int, hidden: int, n_actions: int):
        super().__init__()
        # Layers are created input-to-output. The attribute name `net` and the
        # Sequential indices fix the state_dict keys ("net.0.weight", ...)
        # that the saved checkpoints are loaded by.
        layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised action logits for `x`."""
        logits = self.net(x)
        return logits

    def action_dist(self, obs: np.ndarray) -> torch.distributions.Categorical:
        """Return a categorical distribution over actions for `obs`.

        `obs` is converted with `torch.as_tensor(..., dtype=float32)`, so
        array-likes and tensors are both accepted.
        """
        as_tensor = torch.as_tensor(obs, dtype=torch.float32)
        return torch.distributions.Categorical(logits=self.forward(as_tensor))

In [4]:
async def run(rl, **kwargs):
    """Register the REINFORCE tasks on `rl` and run the learning loop.

    Three tasks are registered through the learner's decorators:

    * `environment` -- rolls out `cfg.batch_size` episodes with the current
      policy checkpoint. It returns the flattened observations, actions and
      discounted returns.
    * `update` -- applies one REINFORCE gradient step to that batch and
      writes the new weights back to `cfg.model_path`.
    * `check_reward` -- stop criterion. It returns the average reward of
      `cfg.batch_size` greedy (argmax) rollouts, compared against
      `cfg.reward_solve_threshold`.

    The tasks share the policy only through the checkpoint file at
    `cfg.model_path`.

    Args:
        rl: Learner the tasks are registered on and whose `learn` is awaited.
        **kwargs: Forwarded unchanged to `rl.learn` (e.g. `max_iter`).
    """

    # ========================================================================
    # 0. HELPER FUNCTIONS
    # ========================================================================
    def discount_reward(rewards: List[float], gamma: float) -> List[float]:
        """Return G_t = r_t + gamma * G_{t+1} for every step of one episode."""
        g = 0.0
        out = []
        for r in reversed(rewards):
            g = r + gamma * g
            out.append(g)
        return list(reversed(out))  # `out` was accumulated back-to-front

    def build_policy(env) -> Network:
        """Create a freshly initialised `Network` sized for `env`'s spaces."""
        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n
        return Network(obs_dim, cfg.hidden_size, n_actions)

    def run_episode(env, policy: Network, seed: int = None) -> Tuple[List[np.ndarray], List[int], List[float]]:
        """Roll out one episode, sampling actions from `policy`.

        Returns the observation seen before each action, the sampled
        actions, and the per-step rewards.
        """
        obs, _ = env.reset(seed=seed)

        obs_list, act_list, rew_list = [], [], []
        done = False
        while not done:
            # Rollout log-probs are never back-propagated (`update` recomputes
            # them), so skip building an autograd graph here.
            with torch.no_grad():
                dist = policy.action_dist(np.expand_dims(obs, axis=0))
                action = dist.sample().item()
            obs_list.append(obs.copy())
            act_list.append(action)

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            rew_list.append(float(reward))
            obs = next_obs
        return obs_list, act_list, rew_list

    # ========================================================================
    # 1. ENVIRONMENT TASK
    # ========================================================================
    @rl.environment_task(as_executable=False)
    async def environment(*args, **kwargs) -> dict:
        """Collect `cfg.batch_size` episodes with the checkpointed policy."""
        env = gym.make(cfg.env_id)
        try:
            policy = build_policy(env)
            try:
                policy.load_state_dict(torch.load(cfg.model_path, map_location="cpu"))
            except FileNotFoundError:
                # No checkpoint yet: persist the fresh weights so that
                # `update` and `check_reward` have something to load.
                torch.save(policy.state_dict(), cfg.model_path)
            episode_buffer = []
            for i in range(cfg.batch_size):
                # run_episode resets the env itself with this seed.
                observations, actions, rewards = run_episode(env, policy, seed=cfg.seed + i)
                returns = discount_reward(rewards, cfg.gamma)
                episode_buffer.extend(zip(observations, actions, returns))
        finally:
            env.close()
        obs_batch, act_batch, ret_batch = zip(*episode_buffer)
        return {"observations": obs_batch, "actions": act_batch, "returns": ret_batch}

    # ========================================================================
    # 2. UPDATE TASK
    # ========================================================================
    @rl.update_task(as_executable=False)
    async def update(*args, **kwargs) -> dict:
        """Apply one REINFORCE gradient step to the batch from `environment`.

        The batch is taken from the first positional argument, or else from
        `data=`. The updated weights are saved to `cfg.model_path`.

        Returns:
            `{"loss": float}` -- the policy-gradient loss before the step.
        """
        data = args[0] if args else kwargs.get("data", {})
        # The env is only needed for its space sizes; close it right away.
        env = gym.make(cfg.env_id)
        try:
            policy = build_policy(env)
        finally:
            env.close()
        policy.load_state_dict(torch.load(cfg.model_path, map_location="cpu"))
        # NOTE(review): a new Adam instance is created on every call and
        # stepped once, so its moment estimates never carry over between
        # iterations. Persisting optimizer.state_dict() alongside the model
        # would fix this. It is left unchanged here because it alters the
        # training dynamics.
        optimizer = optim.Adam(policy.parameters(), lr=cfg.lr)
        # NOTE(review): the policy is never moved off the CPU, so any
        # cfg.device other than "cpu" would cause a device mismatch below.
        obs_batch = torch.tensor(np.array(data["observations"]), dtype=torch.float32, device=cfg.device)
        act_batch = torch.tensor(data["actions"], dtype=torch.int64, device=cfg.device)
        ret_batch = torch.tensor(data["returns"], dtype=torch.float32, device=cfg.device)
        # Standardise returns (zero mean, unit std); 1e-8 avoids division by zero.
        ret_batch = (ret_batch - ret_batch.mean()) / (ret_batch.std() + 1e-8)

        dists = policy.action_dist(obs_batch)
        logp = dists.log_prob(act_batch)
        loss = -(logp * ret_batch).mean()

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(policy.parameters(), max_norm=5.0)
        optimizer.step()
        torch.save(policy.state_dict(), cfg.model_path)
        return {"loss": loss.item()}

    # ========================================================================
    # 3. STOP CRITERION TASK
    # ========================================================================
    @rl.as_stop_criterion(metric_name='MODEL_REWARD', threshold=cfg.reward_solve_threshold, operator=GREATER_THAN_THRESHOLD, as_executable=False)
    async def check_reward(*args, **kwargs):
        """Return the mean undiscounted reward of `cfg.batch_size` greedy episodes."""
        env = gym.make(cfg.env_id)
        try:
            policy = build_policy(env)
            policy.load_state_dict(torch.load(cfg.model_path, map_location="cpu"))
            policy.eval()
            rewards = []
            for _ in range(cfg.batch_size):
                obs, _ = env.reset()
                done = False
                it_reward = 0.0

                while not done:
                    with torch.no_grad():
                        dist = policy.action_dist(np.expand_dims(obs, axis=0))
                        # Greedy evaluation: take the most probable action.
                        action = torch.argmax(dist.probs).item()

                    obs, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated
                    it_reward += float(reward)
                rewards.append(it_reward)
        finally:
            env.close()
        avg_reward = float(np.mean(rewards))
        return avg_reward

    # Run
    logger.info("Starting Reinforcement Learning with ROSE...")
    await rl.learn(**kwargs)
    logger.info("Reinforcement Learning completed!")

try:
    engine = await ConcurrentExecutionBackend(ProcessPoolExecutor())
    asyncflow = await WorkflowEngine.create(engine)
    rl = SequentialReinforcementLearner(asyncflow)

    init_default_logger(logging.INFO)
    await run(rl, max_iter=int(cfg.episodes/cfg.batch_size))
except Exception as e:
    print(f'Learner Failed with: {e}')
finally:
    await rl.shutdown()
    logging.getLogger().handlers.clear()

[90m2025-11-04 13:32:08.173[0m │ [94mINFO[0m │ [38;5;165m[root][0m │ Logger configured successfully - Console: INFO, File: disabled (N/A), Structured: disabled, Style: modern
[90m2025-11-04 13:32:08.176[0m │ [94mINFO[0m │ [38;5;165m[main][0m │ Starting Reinforcement Learning with ROSE...
Starting Sequential RL Learner
Starting Iteration-0
[90m2025-11-04 13:32:08.180[0m │ [94mINFO[0m │ [38;5;165m[workflow_manager][0m │ Submitting ['environment'] for execution
[90m2025-11-04 13:32:08.325[0m │ [94mINFO[0m │ [38;5;165m[workflow_manager][0m │ task.000001 is in DONE state
[90m2025-11-04 13:32:08.336[0m │ [94mINFO[0m │ [38;5;165m[workflow_manager][0m │ Submitting ['update'] for execution
[90m2025-11-04 13:32:09.516[0m │ [94mINFO[0m │ [38;5;165m[workflow_manager][0m │ task.000002 is in DONE state
[90m2025-11-04 13:32:09.528[0m │ [94mINFO[0m │ [38;5;165m[workflow_manager][0m │ Submitting ['check_reward'] for execution
[90m2025-11-04 13:32:09.657[0m │ 