Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ ci/
**/*.egg-info/

# Ignore build related stuff
build/
dist/
**/build/
**/dist/
work/
.mypy_cache/
.ruff_cache/

Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ flame.db*
__pycache__/
*.egg-info/
build/
uv.lock
uv.lock

work/
1 change: 1 addition & 0 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ pub fn default_applications() -> HashMap<String, ApplicationAttributes> {
"-m".to_string(),
"flamepy.rl.runpy".to_string(),
],
working_directory: Some("/opt/flame/work".to_string()),
..ApplicationAttributes::default()
},
),
Expand Down
1 change: 1 addition & 0 deletions compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ services:
- ./examples:/opt/examples
- ./e2e:/opt/e2e
- flame-packages:/opt/flame/packages
- ./work:/opt/flame/work

flame-console:
image: xflops/flame-console:latest
Expand Down
166 changes: 166 additions & 0 deletions examples/ps/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Parameter Server Example

This example demonstrates distributed training using the **Parameter Server** pattern with Flame's Python SDK. It trains a simple convolutional neural network on the MNIST dataset using synchronous gradient updates.

## Overview

The parameter server pattern is a classic distributed training architecture where:
- A **Parameter Server** maintains the global model weights and applies gradient updates
- Multiple **Data Workers** compute gradients on different data batches in parallel
- Workers fetch the latest weights, compute gradients, and send them back to the parameter server

This example uses Flame's `flamepy.rl.Runner` to orchestrate the distributed services and handle inter-service communication.

## Architecture

```
┌─────────────────────┐
│ Parameter Server │ - Stores model weights
│ │ - Applies aggregated gradients
└──────────┬──────────┘
┌──────┴──────┐
│ │
┌───▼────┐ ┌───▼────┐
│Worker 1│ │Worker 2│ - Fetch weights
│ │ │ │ - Compute gradients on data batches
└────────┘ └────────┘
```

### Components

1. **ConvNet**: A small convolutional neural network for MNIST digit classification
2. **ParameterServer**: Maintains model state and applies gradient updates using SGD
3. **DataWorker**: Computes gradients on mini-batches from the training dataset
4. **Main Training Loop**: Coordinates synchronous training iterations

## Files

- `main.py`: Entry point that sets up the runner and training loop
- `ps.py`: Implementation of the model, parameter server, and data workers
- `pyproject.toml`: Project dependencies and configuration

## Requirements

- Python >= 3.12
- PyTorch and torchvision
- Flame Python SDK (`flamepy`)
- NumPy
- filelock

## How to Run

1. **Build the Flame cluster** (if not already running):
```bash
docker compose build
docker compose up -d
```

2. **Navigate to the example directory**:
```bash
cd examples/ps
```

3. **Run the example**:
```bash
python main.py
```

The script will automatically download the MNIST dataset on first run and begin training.

## Expected Output

```
100.0%
100.0%
100.0%
100.0%
Running synchronous parameter server training.
Iter 0: accuracy is 16.5
Iter 10: accuracy is 32.9
Final accuracy is 32.9.
```

In this simplified run, the accuracy typically improves from around 10% (random guessing) to roughly 30–35% after 20 training iterations (as shown above). Higher accuracy may require more iterations, a larger model, or full‑dataset evaluation.

## Key Concepts

### 1. Service Creation with Runner

```python
with Runner("ps-example") as rr:
ps_svc = rr.service(ParameterServer(1e-2))
workers_svc = [rr.service(DataWorker) for _ in range(2)]
```

The `Runner` creates and manages distributed services. Services can be instantiated from any Python class.

### 2. Asynchronous Remote Calls

```python
gradients = [worker.compute_gradients(current_weights) for worker in workers_svc]
current_weights = ps_svc.apply_gradients(*gradients).get()
```

Method calls on services return futures. Use `.get()` to block and retrieve the result.

### 3. Synchronous Training

The training loop ensures all workers compute gradients before the parameter server applies updates:

```python
for i in range(20):
# Start all gradient computations in parallel
gradients = [worker.compute_gradients(current_weights) for worker in workers_svc]
# Wait for all gradients and apply update
current_weights = ps_svc.apply_gradients(*gradients).get()
```

This is a **synchronous** parameter server where each iteration waits for all workers.

## Customization

### Adjust Number of Workers

Modify the worker count in `main.py`:

```python
workers_svc = [rr.service(DataWorker) for _ in range(4)] # Use 4 workers
```

### Change Learning Rate

Pass a different learning rate to the ParameterServer:

```python
ps_svc = rr.service(ParameterServer(1e-3)) # Lower learning rate
```

### Increase Training Iterations

Modify the range in the training loop:

```python
for i in range(50): # Train for 50 iterations
```

## Notes

- The example uses **filelock** to safely download MNIST data when multiple workers start simultaneously
- Evaluation is limited to 1024 samples for faster iteration during development
- The model is intentionally small to demonstrate the pattern rather than achieve state-of-the-art accuracy

## Related Examples

- For reinforcement learning examples, see the `examples/rl/` directory
- For more complex distributed patterns, check other examples in `examples/`

## Troubleshooting

If you encounter issues:

1. **Services not starting**: Ensure the Flame cluster is running with `docker compose ps`
2. **Import errors**: Rebuild containers after modifying `sdk/python`: `docker compose build`
3. **Test timeouts**: Check logs with `docker logs flame-executor-manager` and `docker logs flame-session-manager`

For more information, see the [Flame documentation](../../docs/).
Binary file removed examples/ps/dist/ps-example.tar.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion examples/ps/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

with Runner("ps-example") as rr:
ps_svc = rr.service(ParameterServer(1e-2))
workers_svc = [rr.service(DataWorker) for _ in range(4)]
workers_svc = [rr.service(DataWorker) for _ in range(2)]

current_weights = ps_svc.get_weights().get()
for i in range(20):
Expand Down
3 changes: 2 additions & 1 deletion examples/ps/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ dependencies = [
"torch",
"torchvision",
"numpy",
"filelock"
"filelock",
"flamepy",
]

[build-system]
Expand Down
88 changes: 61 additions & 27 deletions executor_manager/src/shims/host_shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use async_trait::async_trait;
use hyper_util::rt::TokioIo;
use nix::sys::signal::{killpg, Signal};
use nix::unistd::Pid;
use std::collections::HashMap;
use stdng::{logs::TraceFn, trace_fn};
use tokio::net::UnixStream;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -73,6 +74,63 @@ impl HostShim {
})))
}

fn create_dir(path: &Path, name: &str) -> Result<(), FlameError> {
create_dir_all(path).map_err(|e| {
FlameError::Internal(format!(
"failed to create {} directory {}: {e}",
name,
path.display()
))
})
}

fn setup_working_directory(work_dir: &Path) -> Result<HashMap<String, String>, FlameError> {
trace_fn!("HostShim::setup_working_directory");

let tmp_dir = work_dir.join("tmp");
let uv_cache_dir = work_dir.join(".uv");
let pip_cache_dir = work_dir.join(".pip");

tracing::debug!(
"Working directory of application instance: {}",
work_dir.display()
);
tracing::debug!(
"Temporary directory of application instance: {}",
tmp_dir.display()
);
tracing::debug!(
"UV cache directory of application instance: {}",
uv_cache_dir.display()
);
tracing::debug!(
"PIP cache directory of application instance: {}",
pip_cache_dir.display()
);

// Create the working, temporary, and cache directories if they don't exist
Self::create_dir(&work_dir, "working")?;
Self::create_dir(&tmp_dir, "temporary")?;
Self::create_dir(&uv_cache_dir, "UV cache")?;
Self::create_dir(&pip_cache_dir, "PIP cache")?;

// Build environment variables for the application instance
let mut envs = HashMap::new();
envs.insert("TMPDIR".to_string(), tmp_dir.to_string_lossy().to_string());
envs.insert("TEMP".to_string(), tmp_dir.to_string_lossy().to_string());
envs.insert("TMP".to_string(), tmp_dir.to_string_lossy().to_string());
envs.insert(
"UV_CACHE_DIR".to_string(),
uv_cache_dir.to_string_lossy().to_string(),
);
envs.insert(
"PIP_CACHE_DIR".to_string(),
pip_cache_dir.to_string_lossy().to_string(),
);

Ok(envs)
}

fn launch_instance(
app: &ApplicationContext,
executor: &Executor,
Expand Down Expand Up @@ -109,33 +167,9 @@ impl HostShim {
};

let work_dir = cur_dir.clone();
let tmp_dir = cur_dir.join("tmp");

tracing::debug!(
"Working directory of application instance: {}",
work_dir.display()
);
tracing::debug!(
"Temporary directory of application instance: {}",
tmp_dir.display()
);

// Create the working & temporary directories if they don't exist
create_dir_all(&work_dir).map_err(|e| {
FlameError::Internal(format!(
"failed to create working directory {}: {e}",
work_dir.display()
))
})?;
create_dir_all(&tmp_dir).map_err(|e| {
FlameError::Internal(format!(
"failed to create temporary directory {}: {e}",
tmp_dir.display()
))
})?;

// Set temporary directory for the application instance
envs.insert("TMPDIR".to_string(), tmp_dir.to_string_lossy().to_string());
// Setup working directory and get environment overrides
let wd_envs = Self::setup_working_directory(&work_dir)?;
envs.extend(wd_envs);

let log_out = OpenOptions::new()
.create(true)
Expand Down
18 changes: 2 additions & 16 deletions sdk/python/src/flamepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,6 @@
limitations under the License.
"""

import logging
import os

_log_level_str = os.getenv("FLAME_LOG_LEVEL", "INFO").upper()
_log_level_map = {
"CRITICAL": logging.CRITICAL,
"ERROR": logging.ERROR,
"WARNING": logging.WARNING,
"INFO": logging.INFO,
"DEBUG": logging.DEBUG,
}

_log_level = _log_level_map[_log_level_str] if _log_level_str in _log_level_map else logging.INFO

logging.basicConfig(level=_log_level)

# Import submodules for rl and agent (only as submodules)
from . import agent, rl

Expand All @@ -46,6 +30,7 @@
Connection,
Event,
FlameContext,
FlameContextRunner,
FlameError,
FlameErrorCode,
FlamePackage,
Expand Down Expand Up @@ -109,6 +94,7 @@
"Task",
"Application",
"FlamePackage",
"FlameContextRunner",
# Context and utility classes
"TaskInformer",
"FlameContext",
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/src/flamepy/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
FlameService,
SessionContext,
TaskContext,
TaskOutput,
run,
)

Expand All @@ -74,6 +73,7 @@
CommonData,
Event,
FlameContext,
FlameContextRunner,
FlameError,
FlameErrorCode,
FlamePackage,
Expand Down Expand Up @@ -120,6 +120,7 @@
"Task",
"Application",
"FlamePackage",
"FlameContextRunner",
# Context and utility classes
"TaskInformer",
"FlameContext",
Expand Down
Loading
Loading