From 63dde2a6bb12509309569f70bbbd01a408ee7bae Mon Sep 17 00:00:00 2001 From: Klaus Ma Date: Tue, 28 Apr 2026 16:36:01 +0800 Subject: [PATCH 1/4] Add RL example. --- examples/rl/.python-version | 1 + examples/rl/README.md | 248 ++++++++++++++++ examples/rl/main.py | 409 ++++++++++++++++++++++++++ examples/rl/pyproject.toml | 17 ++ executor_manager/src/appmgr/python.rs | 1 + flmadm/src/managers/installation.rs | 28 +- 6 files changed, 698 insertions(+), 6 deletions(-) create mode 100644 examples/rl/.python-version create mode 100644 examples/rl/README.md create mode 100644 examples/rl/main.py create mode 100644 examples/rl/pyproject.toml diff --git a/examples/rl/.python-version b/examples/rl/.python-version new file mode 100644 index 00000000..e4fba218 --- /dev/null +++ b/examples/rl/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/examples/rl/README.md b/examples/rl/README.md new file mode 100644 index 00000000..c5acd4f1 --- /dev/null +++ b/examples/rl/README.md @@ -0,0 +1,248 @@ +# Reinforcement Learning Example (Python) + +## Motivation + +Reinforcement Learning (RL) training is computationally intensive, with episode collection (rollouts) being a major bottleneck. Each episode requires running the environment simulation, which can be slow for complex environments. Fortunately, episode collection is **embarrassingly parallel** — each episode is independent and can run on a separate worker. + +By leveraging the `flamepy.Runner` API, we can distribute episode collection across multiple executors in a Flame cluster, dramatically speeding up training while keeping the policy update logic centralized. This pattern is common in distributed RL systems like IMPALA, Ape-X, and SEED. + +This example illustrates: +- How to parallelize RL episode collection using Flame Runner +- The actor-learner pattern: distributed actors collect experience, centralized learner updates policy +- How to serialize and broadcast PyTorch model weights to remote workers +- Clean separation between distributed data collection and local gradient computation + +## Overview + +This example implements the REINFORCE (policy gradient) algorithm on environments from Gymnasium, with distributed episode collection using Flame. + +### Supported Environments + +| Environment | Type | Observation | Action | Episode Time | +|-------------|------|-------------|--------|--------------| +| `cartpole` | Discrete | 4 | 2 | ~1ms | +| `halfcheetah` | Continuous (MuJoCo) | 17 | 6 | ~20ms | +| `hopper` | Continuous (MuJoCo) | 11 | 3 | ~15ms | +| `walker2d` | Continuous (MuJoCo) | 17 | 6 | ~20ms | +| `ant` | Continuous (MuJoCo) | 105 | 8 | ~50ms | + +### How It Works + +1. **Policy Network**: Neural networks that output action distributions: + - `DiscretePolicy`: For CartPole (categorical distribution) + - `ContinuousPolicy`: For MuJoCo environments (Gaussian distribution with learned std) + +2. **Distributed Episode Collection**: Using `flamepy.Runner`, we create a service from the `collect_episode` function. Each call to this service runs on a remote executor that: + - Creates its own Gymnasium environment instance + - Loads the current policy weights (serialized and sent from the learner) + - Runs one complete episode + - Returns the collected experience (states, actions, rewards) + +3. **Centralized Policy Update**: After collecting experiences from all parallel episodes, the learner: + - Computes discounted rewards + - Calculates policy gradients + - Updates the policy network locally + +4. 
**Iteration**: The process repeats — broadcast new weights, collect more episodes, update again. + +### Files + +- **`main.py`**: REINFORCE training (distributed by default, use `--local` for local mode) +- **`pyproject.toml`**: Package dependencies including `torch`, `gymnasium[mujoco]`, and `flamepy` +- **`README.md`**: This documentation file + +### Key Benefits + +- **Linear Speedup**: Collecting N episodes in parallel takes roughly the same time as collecting 1 episode +- **Minimal Code Changes**: The episode collection function is almost identical to single-threaded code +- **Scalability**: Works with any number of executors — just change `episodes_per_iteration` +- **Flexibility**: Includes local training mode for development and testing without a cluster + +## Usage + +### Prerequisites + +Start the Flame cluster with Docker Compose: + +```shell +$ docker compose up -d +``` + +### Running Distributed Training + +Log into the flame-console and run the example: + +```shell +$ docker compose exec -it flame-console /bin/bash +root@container:/# cd /opt/examples/rl +root@container:/opt/examples/rl# uv run main.py +``` + +### Command Line Options + +```shell +# Distributed training with CartPole (default) +uv run main.py + +# Distributed training with MuJoCo environments +uv run main.py --env ant +uv run main.py --env halfcheetah +uv run main.py --env hopper +uv run main.py --env walker2d + +# Local training (no Flame cluster required) +uv run main.py --local +uv run main.py --env ant --local + +# Custom training configuration +uv run main.py --env ant --iterations 50 --episodes-per-iter 50 + +# Show training plot (requires matplotlib) +uv run main.py --plot +``` + +### Options + +| Flag | Description | Default | +|------|-------------|---------| +| `--env` | Environment: cartpole, halfcheetah, hopper, walker2d, ant | cartpole | +| `--local` | Run local training (no Flame cluster) | Off | +| `--iterations` | Number of training iterations | 100 | +| `--episodes-per-iter` | Parallel episodes per iteration | 100 | +| `--plot` | Show reward plot after training | Off | + +## Example Output + +### Distributed Training (MuJoCo Ant) + +```shell +root@container:/opt/examples/rl# uv run main.py --env ant --iterations 20 +============================================================ +Distributed REINFORCE on Ant-v5 using Flame Runner +============================================================ + +Configuration: + Environment: Ant-v5 + Observation dim: 105 + Action dim: 8 + Continuous actions: True + Training iterations: 20 + Episodes per iteration: 100 + Total episodes: 2000 + +Starting distributed training... +Iteration 0 | Mean Reward: -431.5 | Loss: 0.7285 +Iteration 10 | Mean Reward: -138.8 | Loss: 2.4785 +Iteration 19 | Mean Reward: -122.4 | Loss: -7.4812 + +============================================================ +Training Complete! + Total time: 85.23s + Episodes: 2000 (23.5 episodes/sec) + Final Mean Reward: -122.4 +============================================================ +``` + +### Local Training + +```shell +root@container:/opt/examples/rl# uv run main.py --env ant --iterations 20 --local +============================================================ +Local REINFORCE on Ant-v5 +============================================================ + +Configuration: + Environment: Ant-v5 + Observation dim: 105 + Action dim: 8 + Continuous actions: True + Training iterations: 20 + Episodes per iteration: 100 + Total episodes: 2000 + +Starting local training... 
+Iteration 0 | Mean Reward: -161.2 | Loss: -7.8887 +Iteration 10 | Mean Reward: -120.5 | Loss: -2.9774 +Iteration 19 | Mean Reward: -91.6 | Loss: 0.6673 + +============================================================ +Training Complete! + Total time: 106.45s + Episodes: 2000 (18.8 episodes/sec) + Final Mean Reward: -91.6 +============================================================ +``` + +## Architecture + +``` +┌─────────────────────────────────────────────────────────┐ +│ Learner (Local) │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ Policy │───▶│ Broadcast │───▶│ Collect │ │ +│ │ Update │ │ Weights │ │ Futures │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ │ +│ ▲ │ │ +│ │ ┌──────────────────────┘ │ +│ │ ▼ │ +│ ┌─────────────┐ ┌─────────────┐ │ +│ │ Compute │◀───│ Aggregate │ │ +│ │ Gradients │ │ Episodes │ │ +│ └─────────────┘ └─────────────┘ │ +└─────────────────────────────────────────────────────────┘ + │ + │ Flame Runner API + ▼ +┌────────────────────────────────────────────────────────┐ +│ Flame Cluster │ +│ ┌───────────┐ ┌───────────┐ ┌───────────┐ │ +│ │ Executor 1│ │ Executor 2│ │ Executor N│ ... │ +│ │ ┌───┐ │ │ ┌───┐ │ │ ┌───┐ │ │ +│ │ │Env│ │ │ │Env│ │ │ │Env│ │ │ +│ │ └───┘ │ │ └───┘ │ │ └───┘ │ │ +│ │ Episode │ │ Episode │ │ Episode │ │ +│ │ Collection│ │ Collection│ │ Collection│ │ +│ └───────────┘ └───────────┘ └───────────┘ │ +└────────────────────────────────────────────────────────┘ +``` + +## Performance + +### When Distribution Helps + +Distribution overhead is ~100ms per task. Speedup depends on episode duration: + +| Environment | Episode Time | Distributed Benefit | +|-------------|--------------|---------------------| +| CartPole | ~1ms | ❌ Overhead dominates | +| Hopper | ~15ms | ⚠️ Marginal with high parallelism | +| HalfCheetah | ~20ms | ⚠️ Marginal with high parallelism | +| Ant | ~50ms | ✅ Benefits with 50+ episodes/iter | +| Complex sims | >100ms | ✅✅ Near-linear speedup | +| Real-world/expensive | >1s | ✅✅✅ Essential | + +### Maximizing Distributed Performance + +1. **Increase `--episodes-per-iter`**: More parallel episodes amortizes the per-iteration overhead (weight upload, session management) +2. **Use heavier environments**: MuJoCo environments benefit more than CartPole +3. **Scale executors**: More executors = more parallel episode collection + +### Scaling Behavior + +With N executors collecting episodes in parallel: + +| Executors | Episodes/Iteration | Theoretical Speedup | Actual Speedup* | +|-----------|-------------------|---------------------|-----------------| +| 1 | 100 | 1x | 1x | +| 5 | 100 | 5x | ~4x | +| 10 | 100 | 10x | ~7-8x | +| 20 | 100 | 20x | ~12-15x | + +*Actual speedup limited by: network latency, executor startup, gradient aggregation time. + +### Best Practices + +1. **Use `--episodes-per-iter 100`** (default) or higher for expensive environments +2. **Use local mode** (`--local`) for fast environments or development/debugging +3. **Profile your environment** to determine if distribution is beneficial +4. **Start with MuJoCo** (ant, halfcheetah) to see distributed benefits diff --git a/examples/rl/main.py b/examples/rl/main.py new file mode 100644 index 00000000..87d1be03 --- /dev/null +++ b/examples/rl/main.py @@ -0,0 +1,409 @@ +""" +Distributed REINFORCE on CartPole-v1 or MuJoCo using Flame Runner. +Use --local flag for local training without a Flame cluster. +Use --env to select environment (cartpole, halfcheetah, hopper, walker2d, ant). 
+""" + +import time +from dataclasses import dataclass +from typing import Tuple + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.distributions import Categorical, Normal + + +@dataclass +class EnvConfig: + name: str + obs_dim: int + act_dim: int + continuous: bool + max_episode_steps: int + + +ENV_CONFIGS = { + "cartpole": EnvConfig("CartPole-v1", 4, 2, False, 500), + "halfcheetah": EnvConfig("HalfCheetah-v5", 17, 6, True, 1000), + "hopper": EnvConfig("Hopper-v5", 11, 3, True, 1000), + "walker2d": EnvConfig("Walker2d-v5", 17, 6, True, 1000), + "ant": EnvConfig("Ant-v5", 105, 8, True, 1000), +} + + +class DiscretePolicy(nn.Module): + """Policy network for discrete action spaces (CartPole).""" + + def __init__(self, obs_dim: int, act_dim: int): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(obs_dim, 128), + nn.ReLU(), + nn.Linear(128, act_dim), + nn.Softmax(dim=-1), + ) + + def forward(self, x): + return self.fc(x) + + def get_action(self, state: torch.Tensor) -> Tuple[int, torch.Tensor]: + probs = self(state) + m = Categorical(probs) + action = m.sample() + return action.item(), m.log_prob(action) + + def evaluate(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + probs = self(states) + m = Categorical(probs) + return m.log_prob(actions) + + +class ContinuousPolicy(nn.Module): + """Policy network for continuous action spaces (MuJoCo).""" + + def __init__(self, obs_dim: int, act_dim: int): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(obs_dim, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + ) + self.mean = nn.Linear(256, act_dim) + self.log_std = nn.Parameter(torch.zeros(act_dim)) + + def forward(self, x): + x = self.fc(x) + mean = self.mean(x) + std = self.log_std.exp() + return mean, std + + def get_action(self, state: torch.Tensor) -> Tuple[np.ndarray, torch.Tensor]: + mean, std = self(state) + m = Normal(mean, std) + action = m.sample() + log_prob = m.log_prob(action).sum(dim=-1) + return action.squeeze(0).numpy(), log_prob + + def evaluate(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + mean, std = self(states) + m = Normal(mean, std) + return m.log_prob(actions).sum(dim=-1) + + +def create_policy(env_config: EnvConfig) -> nn.Module: + if env_config.continuous: + return ContinuousPolicy(env_config.obs_dim, env_config.act_dim) + return DiscretePolicy(env_config.obs_dim, env_config.act_dim) + + +def collect_episode(weights, env_name: str) -> dict: + """Runs on distributed executors to collect one episode. 
+ + Args: + weights: Model state_dict (auto-resolved from ObjectRef by Runner) + env_name: Name of the environment to use + """ + import gymnasium as gym + import numpy as np + import torch + + from main import ENV_CONFIGS, create_policy + + env_config = ENV_CONFIGS[env_name] + model = create_policy(env_config) + model.load_state_dict(weights) + model.eval() + + env = gym.make(env_config.name) + states, actions, rewards = [], [], [] + state, _ = env.reset() + done = False + + while not done: + states.append(state) + state_tensor = torch.FloatTensor(state).unsqueeze(0) + with torch.no_grad(): + action, _ = model.get_action(state_tensor) + actions.append(action) + + state, reward, terminated, truncated, _ = env.step(action) + done = terminated or truncated + rewards.append(reward) + + env.close() + return { + "states": states, + "actions": actions, + "rewards": rewards, + "total_reward": sum(rewards), + } + + +def compute_discounted_rewards(rewards: list, gamma: float = 0.99) -> torch.Tensor: + discounted = [] + R = 0 + for r in reversed(rewards): + R = r + gamma * R + discounted.insert(0, R) + discounted = torch.tensor(discounted, dtype=torch.float32) + if len(discounted) > 1: + discounted = (discounted - discounted.mean()) / (discounted.std() + 1e-8) + return discounted + + +def train_distributed( + env_name: str, num_iterations: int = 100, episodes_per_iteration: int = 10 +): + from functools import partial + + from flamepy import put_object + from flamepy.runner import Runner + + env_config = ENV_CONFIGS[env_name] + total_episodes = num_iterations * episodes_per_iteration + + print("=" * 60) + print(f"Distributed REINFORCE on {env_config.name} using Flame Runner") + print("=" * 60) + print(f"\nConfiguration:") + print(f" Environment: {env_config.name}") + print(f" Observation dim: {env_config.obs_dim}") + print(f" Action dim: {env_config.act_dim}") + print(f" Continuous actions: {env_config.continuous}") + print(f" Training iterations: {num_iterations}") + print(f" Episodes per iteration: {episodes_per_iteration}") + print(f" Total episodes: {total_episodes}") + print(f"\nStarting distributed training...") + + policy = create_policy(env_config) + lr = 3e-4 if env_config.continuous else 1e-2 + optimizer = optim.Adam(policy.parameters(), lr=lr) + episode_rewards_history = [] + mean_reward = 0.0 + start_time = time.time() + + collect_fn = partial(collect_episode, env_name=env_name) + + with Runner(f"rl-{env_name}") as rr: + collector = rr.service(collect_fn) + + for iteration in range(num_iterations): + weights_ref = put_object(f"rl-weights-{iteration}", policy.state_dict()) + + futures = [collector(weights_ref) for _ in range(episodes_per_iteration)] + episodes = rr.get(futures) + + iteration_rewards = [ep["total_reward"] for ep in episodes] + mean_reward = np.mean(iteration_rewards) + episode_rewards_history.extend(iteration_rewards) + + policy.train() + optimizer.zero_grad() + total_loss = torch.tensor(0.0) + + for episode in episodes: + discounted_rewards = compute_discounted_rewards(episode["rewards"]) + states_tensor = torch.FloatTensor(np.array(episode["states"])) + actions_tensor = torch.tensor( + np.array(episode["actions"]), + dtype=torch.float32 if env_config.continuous else torch.long, + ) + log_probs = policy.evaluate(states_tensor, actions_tensor) + episode_loss = -(log_probs * discounted_rewards).sum() + total_loss = total_loss + episode_loss + + total_loss = total_loss / len(episodes) + total_loss.backward() + optimizer.step() + + if iteration % 10 == 0 or iteration == 
num_iterations - 1: + print( + f"Iteration {iteration:3d} | " + f"Mean Reward: {mean_reward:8.1f} | " + f"Loss: {total_loss.item():.4f}" + ) + + elapsed = time.time() - start_time + print("\n" + "=" * 60) + print("Training Complete!") + print(f" Total time: {elapsed:.2f}s") + print(f" Episodes: {total_episodes} ({total_episodes/elapsed:.1f} episodes/sec)") + print(f" Final Mean Reward: {mean_reward:.1f}") + print("=" * 60) + + return policy, episode_rewards_history + + +def train_local( + env_name: str, num_iterations: int = 100, episodes_per_iteration: int = 10 +): + env_config = ENV_CONFIGS[env_name] + total_episodes = num_iterations * episodes_per_iteration + + print("=" * 60) + print(f"Local REINFORCE on {env_config.name}") + print("=" * 60) + print(f"\nConfiguration:") + print(f" Environment: {env_config.name}") + print(f" Observation dim: {env_config.obs_dim}") + print(f" Action dim: {env_config.act_dim}") + print(f" Continuous actions: {env_config.continuous}") + print(f" Training iterations: {num_iterations}") + print(f" Episodes per iteration: {episodes_per_iteration}") + print(f" Total episodes: {total_episodes}") + print(f"\nStarting local training...") + + start_time = time.time() + env = gym.make(env_config.name) + policy = create_policy(env_config) + lr = 3e-4 if env_config.continuous else 1e-2 + optimizer = optim.Adam(policy.parameters(), lr=lr) + episode_rewards_history = [] + mean_reward = 0.0 + + for iteration in range(num_iterations): + iteration_episodes = [] + + for _ in range(episodes_per_iteration): + state, _ = env.reset() + log_probs = [] + rewards = [] + states = [] + actions = [] + done = False + + while not done: + states.append(state) + state_tensor = torch.FloatTensor(state).unsqueeze(0) + action, log_prob = policy.get_action(state_tensor) + actions.append(action) + + state, reward, terminated, truncated, _ = env.step(action) + done = terminated or truncated + + log_probs.append(log_prob) + rewards.append(reward) + + iteration_episodes.append({ + "states": states, + "actions": actions, + "rewards": rewards, + "log_probs": log_probs, + "total_reward": sum(rewards), + }) + + iteration_rewards = [ep["total_reward"] for ep in iteration_episodes] + mean_reward = np.mean(iteration_rewards) + episode_rewards_history.extend(iteration_rewards) + + policy.train() + optimizer.zero_grad() + total_loss = torch.tensor(0.0) + + for episode in iteration_episodes: + discounted_rewards = compute_discounted_rewards(episode["rewards"]) + for log_prob, Gt in zip(episode["log_probs"], discounted_rewards): + total_loss = total_loss - log_prob * Gt + + total_loss = total_loss / len(iteration_episodes) + total_loss.backward() + optimizer.step() + + if iteration % 10 == 0 or iteration == num_iterations - 1: + print( + f"Iteration {iteration:3d} | " + f"Mean Reward: {mean_reward:8.1f} | " + f"Loss: {total_loss.item():.4f}" + ) + + env.close() + + elapsed = time.time() - start_time + print("\n" + "=" * 60) + print("Training Complete!") + print(f" Total time: {elapsed:.2f}s") + print(f" Episodes: {total_episodes} ({total_episodes/elapsed:.1f} episodes/sec)") + print(f" Final Mean Reward: {mean_reward:.1f}") + print("=" * 60) + + return policy, episode_rewards_history + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="REINFORCE RL on CartPole or MuJoCo (distributed or local)" + ) + parser.add_argument( + "--env", + type=str, + default="cartpole", + choices=list(ENV_CONFIGS.keys()), + help="Environment to use (default: cartpole)", + ) + 
parser.add_argument( + "--local", action="store_true", help="Run local training (no Flame cluster)" + ) + parser.add_argument( + "--iterations", type=int, default=100, help="Number of training iterations" + ) + parser.add_argument( + "--episodes-per-iter", type=int, default=100, help="Episodes per iteration" + ) + parser.add_argument( + "--plot", action="store_true", help="Show training reward plot" + ) + + args = parser.parse_args() + + if args.local: + policy, rewards = train_local( + args.env, + num_iterations=args.iterations, + episodes_per_iteration=args.episodes_per_iter, + ) + else: + policy, rewards = train_distributed( + args.env, + num_iterations=args.iterations, + episodes_per_iteration=args.episodes_per_iter, + ) + + if args.plot: + try: + import matplotlib.pyplot as plt + + plt.figure(figsize=(10, 5)) + plt.plot(rewards, alpha=0.6, label="Episode Reward") + + window = 50 + if len(rewards) >= window: + moving_avg = np.convolve( + rewards, np.ones(window) / window, mode="valid" + ) + plt.plot( + range(window - 1, len(rewards)), + moving_avg, + color="red", + linewidth=2, + label=f"Moving Avg ({window})", + ) + + plt.title(f"Training Reward Over Episodes ({args.env})") + plt.xlabel("Episode") + plt.ylabel("Total Reward") + plt.legend() + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.show() + except ImportError: + print("\nNote: matplotlib not installed, skipping plot") + + +if __name__ == "__main__": + main() diff --git a/examples/rl/pyproject.toml b/examples/rl/pyproject.toml new file mode 100644 index 00000000..1d1dd5cb --- /dev/null +++ b/examples/rl/pyproject.toml @@ -0,0 +1,17 @@ +[project] +name = "rl" +version = "0.1.0" +description = "Distributed Reinforcement Learning Example using Flame Runner" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "torch", + "numpy", + "gymnasium[mujoco]", +] + +[project.optional-dependencies] +plot = ["matplotlib"] + +[dependency-groups] +dev = ["flamepy"] diff --git a/executor_manager/src/appmgr/python.rs b/executor_manager/src/appmgr/python.rs index d2f70acd..6aaee6ea 100644 --- a/executor_manager/src/appmgr/python.rs +++ b/executor_manager/src/appmgr/python.rs @@ -136,6 +136,7 @@ impl Installer for PythonInstaller { let status = tokio::process::Command::new(&uv_cmd) .arg("pip") .arg("install") + .arg("--link-mode=copy") .arg("--target") .arg(&deps_path) .arg(".") diff --git a/flmadm/src/managers/installation.rs b/flmadm/src/managers/installation.rs index dff80ec4..e37c12ed 100644 --- a/flmadm/src/managers/installation.rs +++ b/flmadm/src/managers/installation.rs @@ -276,20 +276,31 @@ if [[ ":$PATH:" != *":{prefix}/bin:"* ]]; then export PATH="{prefix}/bin:$PATH" fi +# UV and pip cache directories (shared across containers) +export UV_CACHE_DIR="$FLAME_HOME/data/cache/uv" +export PIP_CACHE_DIR="$FLAME_HOME/data/cache/pip" +export UV_LINK_MODE=copy + +# Create cache directories if they don't exist +mkdir -p "$UV_CACHE_DIR" 2>/dev/null +mkdir -p "$PIP_CACHE_DIR" 2>/dev/null + # Python environment for flamepy FLAME_SITE_PACKAGES="{site_packages}" +FLAME_LD_DIRS="" if [ -d "$FLAME_SITE_PACKAGES" ]; then if [[ ":$PYTHONPATH:" != *":$FLAME_SITE_PACKAGES:"* ]]; then export PYTHONPATH="$FLAME_SITE_PACKAGES:$PYTHONPATH" fi - # Find pyarrow lib directory for native extensions - PYARROW_DIR=$(find "$FLAME_SITE_PACKAGES" -name "pyarrow" -type d 2>/dev/null | head -1) - if [ -n "$PYARROW_DIR" ] && [ -d "$PYARROW_DIR" ]; then - if [[ ":$LD_LIBRARY_PATH:" != *":$PYARROW_DIR:"* ]]; then - export 
LD_LIBRARY_PATH="$PYARROW_DIR:$LD_LIBRARY_PATH" + # Find all directories containing shared libraries for native extensions + while IFS= read -r dir; do + abs_dir=$(cd "$dir" 2>/dev/null && pwd) + if [ -n "$abs_dir" ] && [[ ":$LD_LIBRARY_PATH:" != *":$abs_dir:"* ]]; then + export LD_LIBRARY_PATH="$abs_dir:$LD_LIBRARY_PATH" + FLAME_LD_DIRS="$FLAME_LD_DIRS $abs_dir" fi - fi + done < <(find "$FLAME_SITE_PACKAGES" \( -name "*.so" -o -name "*.dylib" \) -type f 2>/dev/null | xargs -n1 dirname | sort -u) fi # Print environment info (only when sourced interactively) @@ -297,9 +308,14 @@ if [[ $- == *i* ]]; then echo "Flame environment loaded:" echo " FLAME_HOME=$FLAME_HOME" echo " PATH includes: {prefix}/bin" + echo " UV_CACHE_DIR=$UV_CACHE_DIR" + echo " PIP_CACHE_DIR=$PIP_CACHE_DIR" if [ -d "$FLAME_SITE_PACKAGES" ]; then echo " PYTHONPATH includes: $FLAME_SITE_PACKAGES" fi + if [ -n "$FLAME_LD_DIRS" ]; then + echo " LD_LIBRARY_PATH includes: $FLAME_LD_DIRS" + fi fi "#, prefix = paths.prefix.display(), From 24a6ba9a3fb3fe7f1f813288ab8fe0ed55bbeb11 Mon Sep 17 00:00:00 2001 From: Klaus Ma Date: Tue, 28 Apr 2026 16:49:16 +0800 Subject: [PATCH 2/4] refactor(examples/rl): extract shared components to model.py and optimize discounted rewards MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract EnvConfig, ENV_CONFIGS, DiscretePolicy, ContinuousPolicy, create_policy into model.py - Update collect_episode to import from model instead of main (fixes fragile distributed imports) - Optimize compute_discounted_rewards from O(n²) to O(n) using pre-allocated list - Update README.md to document new file structure Addresses PR review feedback from gemini-code-assist --- examples/rl/README.md | 1 + examples/rl/main.py | 99 +++---------------------------------------- examples/rl/model.py | 91 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 93 deletions(-) create mode 100644 examples/rl/model.py diff --git a/examples/rl/README.md b/examples/rl/README.md index c5acd4f1..7b6254e8 100644 --- a/examples/rl/README.md +++ b/examples/rl/README.md @@ -48,6 +48,7 @@ This example implements the REINFORCE (policy gradient) algorithm on environment ### Files - **`main.py`**: REINFORCE training (distributed by default, use `--local` for local mode) +- **`model.py`**: Shared components (policy networks, environment configs) - **`pyproject.toml`**: Package dependencies including `torch`, `gymnasium[mujoco]`, and `flamepy` - **`README.md`**: This documentation file diff --git a/examples/rl/main.py b/examples/rl/main.py index 87d1be03..2dc4885d 100644 --- a/examples/rl/main.py +++ b/examples/rl/main.py @@ -5,99 +5,13 @@ """ import time -from dataclasses import dataclass -from typing import Tuple import gymnasium as gym import numpy as np import torch -import torch.nn as nn import torch.optim as optim -from torch.distributions import Categorical, Normal - -@dataclass -class EnvConfig: - name: str - obs_dim: int - act_dim: int - continuous: bool - max_episode_steps: int - - -ENV_CONFIGS = { - "cartpole": EnvConfig("CartPole-v1", 4, 2, False, 500), - "halfcheetah": EnvConfig("HalfCheetah-v5", 17, 6, True, 1000), - "hopper": EnvConfig("Hopper-v5", 11, 3, True, 1000), - "walker2d": EnvConfig("Walker2d-v5", 17, 6, True, 1000), - "ant": EnvConfig("Ant-v5", 105, 8, True, 1000), -} - - -class DiscretePolicy(nn.Module): - """Policy network for discrete action spaces (CartPole).""" - - def __init__(self, obs_dim: int, act_dim: int): - super().__init__() - self.fc = 
nn.Sequential( - nn.Linear(obs_dim, 128), - nn.ReLU(), - nn.Linear(128, act_dim), - nn.Softmax(dim=-1), - ) - - def forward(self, x): - return self.fc(x) - - def get_action(self, state: torch.Tensor) -> Tuple[int, torch.Tensor]: - probs = self(state) - m = Categorical(probs) - action = m.sample() - return action.item(), m.log_prob(action) - - def evaluate(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: - probs = self(states) - m = Categorical(probs) - return m.log_prob(actions) - - -class ContinuousPolicy(nn.Module): - """Policy network for continuous action spaces (MuJoCo).""" - - def __init__(self, obs_dim: int, act_dim: int): - super().__init__() - self.fc = nn.Sequential( - nn.Linear(obs_dim, 256), - nn.ReLU(), - nn.Linear(256, 256), - nn.ReLU(), - ) - self.mean = nn.Linear(256, act_dim) - self.log_std = nn.Parameter(torch.zeros(act_dim)) - - def forward(self, x): - x = self.fc(x) - mean = self.mean(x) - std = self.log_std.exp() - return mean, std - - def get_action(self, state: torch.Tensor) -> Tuple[np.ndarray, torch.Tensor]: - mean, std = self(state) - m = Normal(mean, std) - action = m.sample() - log_prob = m.log_prob(action).sum(dim=-1) - return action.squeeze(0).numpy(), log_prob - - def evaluate(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: - mean, std = self(states) - m = Normal(mean, std) - return m.log_prob(actions).sum(dim=-1) - - -def create_policy(env_config: EnvConfig) -> nn.Module: - if env_config.continuous: - return ContinuousPolicy(env_config.obs_dim, env_config.act_dim) - return DiscretePolicy(env_config.obs_dim, env_config.act_dim) +from model import ENV_CONFIGS, create_policy def collect_episode(weights, env_name: str) -> dict: @@ -108,10 +22,9 @@ def collect_episode(weights, env_name: str) -> dict: env_name: Name of the environment to use """ import gymnasium as gym - import numpy as np import torch - from main import ENV_CONFIGS, create_policy + from model import ENV_CONFIGS, create_policy env_config = ENV_CONFIGS[env_name] model = create_policy(env_config) @@ -144,11 +57,11 @@ def collect_episode(weights, env_name: str) -> dict: def compute_discounted_rewards(rewards: list, gamma: float = 0.99) -> torch.Tensor: - discounted = [] + discounted = [0.0] * len(rewards) R = 0 - for r in reversed(rewards): - R = r + gamma * R - discounted.insert(0, R) + for i in range(len(rewards) - 1, -1, -1): + R = rewards[i] + gamma * R + discounted[i] = R discounted = torch.tensor(discounted, dtype=torch.float32) if len(discounted) > 1: discounted = (discounted - discounted.mean()) / (discounted.std() + 1e-8) diff --git a/examples/rl/model.py b/examples/rl/model.py new file mode 100644 index 00000000..e99bed88 --- /dev/null +++ b/examples/rl/model.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +from torch.distributions import Categorical, Normal + + +@dataclass +class EnvConfig: + name: str + obs_dim: int + act_dim: int + continuous: bool + max_episode_steps: int + + +ENV_CONFIGS = { + "cartpole": EnvConfig("CartPole-v1", 4, 2, False, 500), + "halfcheetah": EnvConfig("HalfCheetah-v5", 17, 6, True, 1000), + "hopper": EnvConfig("Hopper-v5", 11, 3, True, 1000), + "walker2d": EnvConfig("Walker2d-v5", 17, 6, True, 1000), + "ant": EnvConfig("Ant-v5", 105, 8, True, 1000), +} + + +class DiscretePolicy(nn.Module): + """Policy network for discrete action spaces (CartPole).""" + + def __init__(self, obs_dim: int, act_dim: int): + super().__init__() + self.fc = 
nn.Sequential( + nn.Linear(obs_dim, 128), + nn.ReLU(), + nn.Linear(128, act_dim), + nn.Softmax(dim=-1), + ) + + def forward(self, x): + return self.fc(x) + + def get_action(self, state: torch.Tensor) -> Tuple[int, torch.Tensor]: + probs = self(state) + m = Categorical(probs) + action = m.sample() + return action.item(), m.log_prob(action) + + def evaluate(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + probs = self(states) + m = Categorical(probs) + return m.log_prob(actions) + + +class ContinuousPolicy(nn.Module): + """Policy network for continuous action spaces (MuJoCo).""" + + def __init__(self, obs_dim: int, act_dim: int): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(obs_dim, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + ) + self.mean = nn.Linear(256, act_dim) + self.log_std = nn.Parameter(torch.zeros(act_dim)) + + def forward(self, x): + x = self.fc(x) + mean = self.mean(x) + std = self.log_std.exp() + return mean, std + + def get_action(self, state: torch.Tensor) -> Tuple[np.ndarray, torch.Tensor]: + mean, std = self(state) + m = Normal(mean, std) + action = m.sample() + log_prob = m.log_prob(action).sum(dim=-1) + return action.squeeze(0).numpy(), log_prob + + def evaluate(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + mean, std = self(states) + m = Normal(mean, std) + return m.log_prob(actions).sum(dim=-1) + + +def create_policy(env_config: EnvConfig) -> nn.Module: + if env_config.continuous: + return ContinuousPolicy(env_config.obs_dim, env_config.act_dim) + return DiscretePolicy(env_config.obs_dim, env_config.act_dim) From ae241b72dd1b8742e5447f8d8d41d10db6e711f7 Mon Sep 17 00:00:00 2001 From: Klaus Ma Date: Tue, 28 Apr 2026 16:52:48 +0800 Subject: [PATCH 3/4] fix(examples/rl): explicitly configure py-modules for setuptools Fixes build error: 'Multiple top-level modules discovered in a flat-layout' after adding model.py alongside main.py --- examples/rl/pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/rl/pyproject.toml b/examples/rl/pyproject.toml index 1d1dd5cb..1738c421 100644 --- a/examples/rl/pyproject.toml +++ b/examples/rl/pyproject.toml @@ -15,3 +15,6 @@ plot = ["matplotlib"] [dependency-groups] dev = ["flamepy"] + +[tool.setuptools] +py-modules = ["main", "model"] From 5b93995de4dbf72de2075f74576b0b88a8ea9c88 Mon Sep 17 00:00:00 2001 From: Klaus Ma Date: Tue, 28 Apr 2026 17:02:16 +0800 Subject: [PATCH 4/4] fix(flmadm): make flmenv.sh robust to empty find results and mkdir failures - Add '|| true' to mkdir commands to prevent failure with 'set -e' - Add '-r' flag to xargs to not run if input is empty - Add empty line check in while loop - Redirect xargs stderr to /dev/null --- flmadm/src/managers/installation.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flmadm/src/managers/installation.rs b/flmadm/src/managers/installation.rs index e37c12ed..4940c252 100644 --- a/flmadm/src/managers/installation.rs +++ b/flmadm/src/managers/installation.rs @@ -282,8 +282,8 @@ export PIP_CACHE_DIR="$FLAME_HOME/data/cache/pip" export UV_LINK_MODE=copy # Create cache directories if they don't exist -mkdir -p "$UV_CACHE_DIR" 2>/dev/null -mkdir -p "$PIP_CACHE_DIR" 2>/dev/null +mkdir -p "$UV_CACHE_DIR" 2>/dev/null || true +mkdir -p "$PIP_CACHE_DIR" 2>/dev/null || true # Python environment for flamepy FLAME_SITE_PACKAGES="{site_packages}" @@ -295,12 +295,13 @@ if [ -d "$FLAME_SITE_PACKAGES" ]; then # Find all directories containing shared libraries for native extensions 
while IFS= read -r dir; do + [ -z "$dir" ] && continue abs_dir=$(cd "$dir" 2>/dev/null && pwd) if [ -n "$abs_dir" ] && [[ ":$LD_LIBRARY_PATH:" != *":$abs_dir:"* ]]; then export LD_LIBRARY_PATH="$abs_dir:$LD_LIBRARY_PATH" FLAME_LD_DIRS="$FLAME_LD_DIRS $abs_dir" fi - done < <(find "$FLAME_SITE_PACKAGES" \( -name "*.so" -o -name "*.dylib" \) -type f 2>/dev/null | xargs -n1 dirname | sort -u) + done < <(find "$FLAME_SITE_PACKAGES" \( -name "*.so" -o -name "*.dylib" \) -type f 2>/dev/null | xargs -r -n1 dirname 2>/dev/null | sort -u) fi # Print environment info (only when sourced interactively)
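
A quick way to sanity-check the episode-collection path without a Flame cluster is to call collect_episode directly with a freshly initialized policy's state_dict; in the distributed path the Runner instead ships the weights as an ObjectRef created by put_object. The snippet below is only a local smoke-test sketch, assuming the post-series layout (examples/rl/main.py plus model.py) and that it is run from examples/rl:

    # Hypothetical local smoke test -- not part of this series.
    from main import collect_episode
    from model import ENV_CONFIGS, create_policy

    env_name = "cartpole"
    policy = create_policy(ENV_CONFIGS[env_name])

    # Pass the state_dict directly; the Runner normally resolves it from an ObjectRef.
    episode = collect_episode(policy.state_dict(), env_name=env_name)
    print(f"steps={len(episode['rewards'])}, total_reward={episode['total_reward']:.1f}")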