A PyTorch plugin that integrates with the Jalapeno API to optimize network paths for distributed training using SRv6 (Segment Routing over IPv6).
This plugin enhances PyTorch's distributed training by:
- Intercepting NCCL communication setup
- Querying Jalapeno API for optimized SRv6 paths
- Programming local SRv6 routes for optimal network paths (see the sketch below)
- Enabling distributed training with network-aware routing
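The route-programming step boils down to installing SRv6 encapsulation routes toward each peer before the collective backend opens its connections. As a rough illustration of what that means on Linux, the sketch below drives the standard iproute2 seg6 syntax from Python; the destination prefix, segment IDs, and interface name are placeholders rather than values used by the plugin, and the command requires root privileges and a kernel with SRv6 support:

```python
import subprocess

def program_srv6_route(dst_prefix, segments, dev="eth1"):
    """Install an SRv6 encap route so traffic to dst_prefix follows the given segment list."""
    subprocess.run(
        ["ip", "-6", "route", "replace", dst_prefix,
         "encap", "seg6", "mode", "encap",
         "segs", ",".join(segments),
         "dev", dev],
        check=True,
    )

# Hypothetical destination prefix and SIDs, purely for illustration
program_srv6_route("2001:db8:2::/64", ["fc00:0:3::1", "fc00:0:5::1"])
```

In the plugin itself this step is handled by `route_programmer.py`, which supports both Linux and VPP.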
- `srv6_plugin.py`: Main plugin that wraps PyTorch's distributed functionality
- `route_programmer.py`: Platform-specific route programming (Linux/VPP)
- `controller.py`: Network controller for managing routes and API interactions
- `dist_setup.py`: Distributed training setup utilities
- `demo/test_dist.py`: Full demo application using containerlab
- Python 3.8+
- PyTorch
- Access to Jalapeno API
- Linux kernel with SRv6 support (for route programming)
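If you are unsure whether the kernel has SRv6 (seg6) processing enabled, a quick check of the per-interface sysctl knob (a standard kernel setting, not a plugin option) looks like this:

```python
from pathlib import Path

def srv6_enabled(iface="all"):
    """True if the kernel's seg6_enabled knob is set for the given interface."""
    knob = Path(f"/proc/sys/net/ipv6/conf/{iface}/seg6_enabled")
    return knob.exists() and knob.read_text().strip() != "0"

print("SRv6 (seg6) enabled:", srv6_enabled())
```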
- Clone the repository:

  ```bash
  git clone https://github.com/segment-routing/srv6-pytorch-plugin.git
  cd srv6-pytorch-plugin
  ```
- Create and activate a virtual environment (recommended):

  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows: venv\Scripts\activate
  ```
- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```
- Create a `.env` file with your configuration:

  ```ini
  JALAPENO_API_ENDPOINT=http://jalapeno-api:8000
  TOPOLOGY_COLLECTION=your-collection-name
  BACKEND_INTERFACE=eth1
  ROUTE_PLATFORM=linux
  ROUTE_TABLE_ID=254
  HOSTS=host00,host01,host02  # Comma-separated list of hostnames
  ```
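The plugin reads its settings from the process environment (assuming the `.env` file is exported or loaded with a tool such as python-dotenv). Below is a minimal sketch of how these values might be consumed; the dictionary keys are illustrative, and only the environment variable names and defaults come from this README:

```python
import os

# Defaults mirror the documented values; the plugin's actual parsing may differ.
config = {
    "api_endpoint": os.getenv("JALAPENO_API_ENDPOINT", "http://jalapeno-api:8000"),
    "collection": os.getenv("TOPOLOGY_COLLECTION"),
    "interface": os.getenv("BACKEND_INTERFACE", "eth1"),
    "platform": os.getenv("ROUTE_PLATFORM", "linux"),
    "table_id": int(os.getenv("ROUTE_TABLE_ID", "254")),
    "hosts": os.getenv("HOSTS", "").split(","),
}
```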
Basic usage:

```python
import os

from srv6_plugin import DemoPlugin

# Initialize with the Jalapeno API endpoint
plugin = DemoPlugin(
    api_endpoint=os.getenv('JALAPENO_API_ENDPOINT')
)

# Initialize distributed training with network optimization
plugin.init_process_group()
```
Set the backend in `dist_setup.py`:

```python
# The demo uses gloo:
dist.init_process_group(backend="gloo")

# For NCCL (GPU) training:
dist.init_process_group(backend="nccl")
```
- `JALAPENO_API_ENDPOINT`: URL of the Jalapeno API
- `TOPOLOGY_COLLECTION`: Name of the topology collection in Jalapeno
- `BACKEND_INTERFACE`: Network interface for SRv6 routes (default: eth1)
- `ROUTE_PLATFORM`: Route programming platform (linux/vpp)
- `ROUTE_TABLE_ID`: Routing table ID (default: 254)
- `HOSTS`: Comma-separated list of hostnames for distributed training
- `RANK`: Node rank in distributed training (0-based; see the sketch after this list)
- `WORLD_SIZE`: Total number of nodes in distributed training
- `MASTER_ADDR`: IP address of the master node
- `MASTER_PORT`: Port for distributed training communication
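`RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT` can be exported explicitly on each node. The sketch below shows one plausible way to derive them from `HOSTS` and the local hostname; it is an illustration, not necessarily what `dist_setup.py` does, and the port number is arbitrary:

```python
import os
import socket

hosts = os.environ["HOSTS"].split(",")
os.environ.setdefault("RANK", str(hosts.index(socket.gethostname())))  # 0-based rank
os.environ.setdefault("WORLD_SIZE", str(len(hosts)))
os.environ.setdefault("MASTER_ADDR", hosts[0])  # assumes the first host is the master
os.environ.setdefault("MASTER_PORT", "29500")   # any free port agreed on by all nodes
```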
The `demo/` directory contains a complete example using containerlab to simulate a network topology with SONiC switches. See `demo/readme.md` for detailed instructions.
```
[PyTorch Distributed Training]
        ↓
[DemoPlugin]
  - Initializes distributed process group
  - Collects node information from environment
        ↓
[Network Controller]
  - Queries Jalapeno API for optimized paths
  - Gets back SRv6 path information
        ↓
[Route Programmer]
  - Programs local SRv6 routes
        ↓
[Distributed Training Communication]
```
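Put together, the controller step is essentially "fetch a segment list from the Jalapeno API, then hand it to the route programmer." A hedged sketch of that interaction follows; the endpoint path, query parameters, and response field are assumptions for illustration, not the actual Jalapeno API:

```python
import os
import requests

API = os.getenv("JALAPENO_API_ENDPOINT", "http://jalapeno-api:8000")

def fetch_srv6_segments(src, dst):
    """Ask the API for a segment list from src to dst (illustrative endpoint and schema)."""
    resp = requests.get(f"{API}/paths", params={"source": src, "destination": dst}, timeout=5)
    resp.raise_for_status()
    return resp.json()["srv6_sid_list"]  # e.g. ["fc00:0:3::1", "fc00:0:5::1"]

# For each remote peer, the route programmer then installs a route toward that peer
# using the returned segment list (see the program_srv6_route() sketch near the top).
```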
- Fork the repository
- Create a feature branch
- Commit your changes
- Push to the branch
- Create a Pull Request