diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..a6668e4 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,32 @@ +name: Ruff + +on: + push: + branches: ["master"] + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + +jobs: + lint: + if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }} + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + + - name: Install Ruff + run: pip install ruff + + - name: Format code + run: ruff --config pyproject.toml format . + + - name: Fix lint issues + run: ruff --config pyproject.toml check --fix --exit-zero . + + - name: Ensure no remaining lint issues + run: ruff --config pyproject.toml check . diff --git a/netpulse-client/examples/sdk_test.py b/netpulse-client/examples/sdk_test.py index 83bd96d..e23ca5a 100644 --- a/netpulse-client/examples/sdk_test.py +++ b/netpulse-client/examples/sdk_test.py @@ -1,5 +1,4 @@ -import asyncio -from netpulse_client import NetPulseClient, Device +from netpulse_client import Device, NetPulseClient # 配置信息 ENDPOINT = "http://localhost:9000" @@ -19,7 +18,7 @@ def basic_operations(): print("=== 基础操作示例 ===") - + with NetPulseClient(ENDPOINT, API_KEY) as np_client: # 1. 执行命令 print("\n1. 执行命令") @@ -35,6 +34,5 @@ def basic_operations(): print(result.request_id) - if __name__ == "__main__": - basic_operations() \ No newline at end of file + basic_operations() diff --git a/netpulse-client/netpulse_client/__init__.py b/netpulse-client/netpulse_client/__init__.py index 5cc33be..c682b9e 100644 --- a/netpulse-client/netpulse_client/__init__.py +++ b/netpulse-client/netpulse_client/__init__.py @@ -1,11 +1,11 @@ """ NetPulse Client - 网络设备自动化客户端 -提供同步和异步的网络设备操作接口,支持命令执行和配置推送。 +提供同步和异步的网络设备操作接口, 支持命令执行和配置推送。 -核心方法(与API端点对应): +核心方法 (与API端点对应) : - exec_command(): 同步执行命令 -> /device/execute -- exec_config(): 同步推送配置 -> /device/execute +- exec_config(): 同步推送配置 -> /device/execute - bulk_command(): 同步批量执行命令 -> /device/bulk - bulk_config(): 同步批量推送配置 -> /device/bulk - aexec_command(): 异步执行命令 -> /device/execute @@ -23,34 +23,34 @@ """ # 核心客户端类 +from .async_client import AsyncJobHandle, AsyncNetPulseClient from .client import NetPulseClient -from .async_client import AsyncNetPulseClient, AsyncJobHandle - -# 数据模型 -from .models import ( - ConnectionArgs, # 主要的连接参数模型 - Device, # 向后兼容的Device别名 - CommandResult, - ConfigResult, - BatchResult, - JobInfo, - WorkerInfo, - HealthCheckResult, - ConnectionTestResult, -) - -# 工具函数 -from .models import create_device_request, create_batch_device_request # 异常类 from .exceptions import ( - NetPulseError, AuthenticationError, ConnectionError, JobError, + NetPulseError, + SDKValidationError, TimeoutError, ValidationError, - SDKValidationError, +) + +# 数据模型 +# 工具函数 +from .models import ( + BatchResult, + CommandResult, + ConfigResult, + ConnectionArgs, # 主要的连接参数模型 + ConnectionTestResult, + Device, # 向后兼容的Device别名 + HealthCheckResult, + JobInfo, + WorkerInfo, + create_batch_device_request, + create_device_request, ) __version__ = "0.1.0" @@ -60,7 +60,6 @@ "NetPulseClient", "AsyncNetPulseClient", "AsyncJobHandle", - # 数据模型 "ConnectionArgs", # 主要的连接参数模型 "Device", # 向后兼容的Device别名 @@ -71,11 +70,9 @@ "WorkerInfo", "HealthCheckResult", "ConnectionTestResult", - # 工具函数 "create_device_request", "create_batch_device_request", - # 异常类 "NetPulseError", "AuthenticationError", @@ -84,4 +81,4 @@ "TimeoutError", "ValidationError", "SDKValidationError", -] \ No newline at end of file +] diff --git a/netpulse-client/netpulse_client/async_client.py b/netpulse-client/netpulse_client/async_client.py index 0e12b13..d6d6dac 100644 --- a/netpulse-client/netpulse_client/async_client.py +++ b/netpulse-client/netpulse_client/async_client.py @@ -6,13 +6,12 @@ """ import asyncio -import time import logging -from typing import Any, Dict, List, Optional, Union, Callable +import time +from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urljoin import httpx -from pydantic import ValidationError from .exceptions import ( AuthenticationError, @@ -20,32 +19,32 @@ JobError, NetPulseError, TimeoutError, +) +from .exceptions import ( ValidationError as SDKValidationError, ) from .models import ( - ConnectionConfig, - ConnectionArgs, + AsyncJobHandle, CommandResult, ConfigResult, - BatchResult, - JobInfo, - WorkerInfo, - HealthCheckResult, + ConnectionArgs, + ConnectionConfig, ConnectionTestResult, + HealthCheckResult, + JobInfo, OperationStatus, - JobStatus, - WorkerState, - AsyncJobHandle, - create_device_request, + WorkerInfo, create_batch_device_request, + create_device_request, ) # 配置日志 logger = logging.getLogger(__name__) + class AsyncNetPulseClient: """NetPulse 异步客户端""" - + def __init__( self, endpoint: str, @@ -57,7 +56,7 @@ def __init__( ): """ 初始化异步客户端 - + Args: endpoint: NetPulse API端点 api_key: API密钥 @@ -67,7 +66,7 @@ def __init__( verify_ssl: 是否验证SSL证书 """ self.config = ConnectionConfig( - endpoint=endpoint.rstrip('/'), + endpoint=endpoint.rstrip("/"), api_key=api_key, timeout=timeout, max_retries=3, @@ -80,27 +79,27 @@ def __init__( "Content-Type": "application/json", "User-Agent": "NetPulse-Client/0.1.0", } - + # 异步客户端配置 self.limits = httpx.Limits( max_connections=max_concurrent, max_keepalive_connections=20, ) - + # 任务管理 self._active_jobs: Dict[str, AsyncJobHandle] = {} self._job_callbacks: Dict[str, List[Callable]] = {} self._client: Optional[httpx.AsyncClient] = None - + async def __aenter__(self): """异步上下文管理器入口""" await self._ensure_client() return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): """异步上下文管理器出口""" await self.close() - + async def _ensure_client(self): """确保客户端已初始化""" if self._client is None: @@ -111,42 +110,37 @@ async def _ensure_client(self): verify=self.config.verify_ssl, headers=self.headers, ) - + async def close(self): """关闭客户端""" if self._client: await self._client.aclose() self._client = None - + def _build_url(self, path: str) -> str: """构建完整URL""" return urljoin(self.config.endpoint, path) - + def _handle_response(self, response: httpx.Response) -> Dict[str, Any]: """处理HTTP响应""" try: data = response.json() except Exception as e: raise SDKValidationError(f"Invalid JSON response: {e}") - + if response.status_code == 401: raise AuthenticationError("Invalid API key") elif response.status_code == 403: raise AuthenticationError("Insufficient permissions") elif response.status_code >= 500: raise ConnectionError(f"Server error: {response.status_code}") - + return data - - async def _make_request( - self, - method: str, - url: str, - **kwargs - ) -> Dict[str, Any]: + + async def _make_request(self, method: str, url: str, **kwargs) -> Dict[str, Any]: """发送异步HTTP请求""" await self._ensure_client() - + try: response = await self._client.request(method, url, **kwargs) response.raise_for_status() @@ -162,18 +156,18 @@ async def _make_request( raise ConnectionError(f"HTTP error: {e.response.status_code}") except httpx.ConnectError as e: raise ConnectionError(f"Request failed: {e}") - + async def _submit_job(self, url: str, payload: Dict[str, Any]) -> str: """提交任务并返回job_id""" data = await self._make_request("POST", url, json=payload) - + if data.get("code") != 0: raise JobError(f"API error: {data.get('message')}") - + return data["data"]["id"] - + # ========== 核心方法 ========== - + async def execute( self, device: ConnectionArgs, @@ -187,19 +181,19 @@ async def execute( callback: Optional[Callable] = None, ) -> Union[CommandResult, ConfigResult, AsyncJobHandle]: """ - 在设备上执行操作(命令或配置)- 对应 /device/execute 端点 - + 在设备上执行操作 (命令或配置) - 对应 /device/execute 端点 + Args: device: 目标设备 - command: 要执行的命令(Pull操作) - config: 要推送的配置(Push操作) + command: 要执行的命令 (Pull操作) + config: 要推送的配置 (Push操作) parse_with: 解析器类型 save: 是否保存配置 dry_run: 是否为干运行模式 wait: 是否等待完成 timeout: 超时时间 callback: 完成回调函数 - + Returns: CommandResult/ConfigResult: 如果wait=True AsyncJobHandle: 如果wait=False @@ -208,7 +202,7 @@ async def execute( raise ValueError("必须指定 command 或 config 参数") if command and config: raise ValueError("command 和 config 参数不能同时指定") - + # 构建设备连接参数 connection_args = { "host": device.host, @@ -220,26 +214,26 @@ async def execute( connection_args["port"] = device.port if device.timeout: connection_args["timeout"] = device.timeout - + # 构建驱动参数 driver_args = {} if save: driver_args["save"] = save if dry_run: driver_args["dry_run"] = dry_run - + # 构建请求选项 options = { "queue_strategy": "pinned", "ttl": timeout, } - + if parse_with: options["parsing"] = { "name": parse_with, "template": f"file:///templates/{(command or config).replace(' ', '_') if isinstance((command or config), str) else (command or config)[0].replace(' ', '_')}.{parse_with}", } - + # 创建请求 if command: request_data = create_device_request( @@ -259,13 +253,12 @@ async def execute( options=options, ) operation_type = "config" - + # 提交任务 job_id = await self._submit_job( - self._build_url("/device/execute"), - request_data.model_dump(exclude_none=True) + self._build_url("/device/execute"), request_data.model_dump(exclude_none=True) ) - + if not wait: # 创建任务句柄 handle = AsyncJobHandle( @@ -273,22 +266,24 @@ async def execute( task_type=f"single_{operation_type}", device_hosts=[device.host], submitted_at=time.time(), - timeout=timeout + timeout=timeout, ) - + # 注册任务 self._active_jobs[job_id] = handle if callback: self._job_callbacks[job_id] = [callback] - + return handle - + # 等待任务完成 result = await self._wait_for_job(job_id, timeout) - + if command: return CommandResult( - status=OperationStatus.SUCCESS if result["status"] == "completed" else OperationStatus.FAILED, + status=OperationStatus.SUCCESS + if result["status"] == "completed" + else OperationStatus.FAILED, data=result.get("result", {}).get("retval"), job_id=job_id, device_host=device.host, @@ -297,13 +292,15 @@ async def execute( ) else: return ConfigResult( - status=OperationStatus.SUCCESS if result["status"] == "completed" else OperationStatus.FAILED, + status=OperationStatus.SUCCESS + if result["status"] == "completed" + else OperationStatus.FAILED, job_id=job_id, device_host=device.host, error=result.get("error"), execution_time=result.get("execution_time"), ) - + async def bulk( self, devices: List[ConnectionArgs], @@ -317,17 +314,17 @@ async def bulk( ) -> AsyncJobHandle: """ 批量设备操作 - 对应 /device/bulk 端点 - + Args: devices: 设备列表 - command: 要执行的命令(Pull操作) - config: 要推送的配置(Push操作) + command: 要执行的命令 (Pull操作) + config: 要推送的配置 (Push操作) parse_with: 解析器类型 save: 是否保存配置 dry_run: 是否为干运行模式 timeout: 超时时间 callback: 完成回调函数 - + Returns: AsyncJobHandle: 异步任务句柄 """ @@ -335,10 +332,10 @@ async def bulk( raise ValueError("必须指定 command 或 config 参数") if command and config: raise ValueError("command 和 config 参数不能同时指定") - + if not devices: raise ValueError("设备列表不能为空") - + # 构建设备列表 device_list = [] device_hosts = [] @@ -354,33 +351,33 @@ async def bulk( device_dict["timeout"] = device.timeout device_list.append(device_dict) device_hosts.append(device.host) - + # 构建连接参数 connection_args = { "device_type": devices[0].device_type, "timeout": 30, "keepalive": 120, } - + # 构建驱动参数 driver_args = {} if save: driver_args["save"] = save if dry_run: driver_args["dry_run"] = dry_run - + # 构建请求选项 options = { "queue_strategy": "pinned", "ttl": timeout, } - + if parse_with: options["parsing"] = { "name": parse_with, "template": f"file:///templates/{(command or config).replace(' ', '_') if isinstance((command or config), str) else (command or config)[0].replace(' ', '_')}.{parse_with}", } - + # 创建请求 if command: request_data = create_batch_device_request( @@ -402,47 +399,42 @@ async def bulk( options=options, ) operation_type = "batch_config" - + # 提交任务 job_id = await self._submit_job( - self._build_url("/device/bulk"), - request_data.model_dump(exclude_none=True) + self._build_url("/device/bulk"), request_data.model_dump(exclude_none=True) ) - + # 创建任务句柄 handle = AsyncJobHandle( job_id=job_id, task_type=operation_type, device_hosts=device_hosts, submitted_at=time.time(), - timeout=timeout + timeout=timeout, ) - + # 注册任务 self._active_jobs[job_id] = handle if callback: self._job_callbacks[job_id] = [callback] - + return handle - + # ========== 任务管理方法 ========== - + async def get_job_status(self, job_id: str) -> Optional[JobInfo]: """获取任务状态""" try: - data = await self._make_request( - "GET", - self._build_url("/job"), - params={"id": job_id} - ) - + data = await self._make_request("GET", self._build_url("/job"), params={"id": job_id}) + if data.get("code") != 0: raise JobError(f"API error: {data.get('message')}") - + job_list = data.get("data", []) if not job_list: return None - + job_data = job_list[0] return JobInfo( job_id=job_data.get("id", ""), @@ -454,47 +446,49 @@ async def get_job_status(self, job_id: str) -> Optional[JobInfo]: finished_at=job_data.get("ended_at"), result=job_data.get("result"), ) - + except Exception as e: raise JobError(f"Failed to get job status: {e}") - + async def wait_for_job( - self, - handle: AsyncJobHandle, - poll_interval: float = 0.4 + self, handle: AsyncJobHandle, poll_interval: float = 0.4 ) -> Union[CommandResult, ConfigResult]: """等待任务完成""" start_time = time.time() - + while time.time() - start_time < handle.timeout: try: job_info = await self.get_job_status(handle.job_id) - + if not job_info: raise JobError(f"Job {handle.job_id} not found") - + if job_info.status in ["finished", "failed"]: - # 任务完成,执行回调 + # 任务完成, 执行回调 if handle.job_id in self._job_callbacks: for callback in self._job_callbacks[handle.job_id]: try: - await callback(job_info) if asyncio.iscoroutinefunction(callback) else callback(job_info) + await callback(job_info) if asyncio.iscoroutinefunction( + callback + ) else callback(job_info) except Exception as e: print(f"Callback error: {e}") - + # 从活动任务中移除 if handle.job_id in self._active_jobs: del self._active_jobs[handle.job_id] if handle.job_id in self._job_callbacks: del self._job_callbacks[handle.job_id] - + # 根据任务类型返回结果 if handle.task_type == "command": return CommandResult( status=job_info.status, data=job_info.result.get("retval") if job_info.result else None, job_id=handle.job_id, - device_host=handle.device_hosts[0] if handle.device_hosts else "unknown", + device_host=handle.device_hosts[0] + if handle.device_hosts + else "unknown", error=job_info.result.get("error") if job_info.result else None, execution_time=time.time() - start_time, ) @@ -502,40 +496,40 @@ async def wait_for_job( return ConfigResult( status=job_info.status, job_id=handle.job_id, - device_host=handle.device_hosts[0] if handle.device_hosts else "unknown", + device_host=handle.device_hosts[0] + if handle.device_hosts + else "unknown", error=job_info.result.get("error") if job_info.result else None, execution_time=time.time() - start_time, ) - + await asyncio.sleep(poll_interval) - + except Exception as e: raise JobError(f"Failed to wait for job: {e}") - + raise TimeoutError(f"Job {handle.job_id} timed out after {handle.timeout} seconds") - + async def cancel_job(self, job_id: str) -> bool: """取消任务""" try: data = await self._make_request( - "DELETE", - self._build_url("/job"), - params={"id": job_id} + "DELETE", self._build_url("/job"), params={"id": job_id} ) - + # 从活动任务中移除 if job_id in self._active_jobs: del self._active_jobs[job_id] if job_id in self._job_callbacks: del self._job_callbacks[job_id] - + return data.get("code") == 0 - + except Exception as e: raise JobError(f"Failed to cancel job: {e}") - + # ========== 工作节点管理方法 ========== - + async def get_workers( self, queue: Optional[str] = None, @@ -547,52 +541,50 @@ async def get_workers( params["q_name"] = queue if node: params["node"] = node - + try: - data = await self._make_request( - "GET", - self._build_url("/worker"), - params=params - ) - + data = await self._make_request("GET", self._build_url("/worker"), params=params) + if data.get("code") != 0: raise JobError(f"API error: {data.get('message')}") - + workers = [] for worker_data in data.get("data", []): - workers.append(WorkerInfo( - name=worker_data.get("name", ""), - status=worker_data.get("status", ""), - pid=worker_data.get("pid"), - hostname=worker_data.get("hostname"), - queues=worker_data.get("queues"), - last_heartbeat=worker_data.get("last_heartbeat"), - birth_at=worker_data.get("birth_at"), - successful_job_count=worker_data.get("successful_job_count"), - failed_job_count=worker_data.get("failed_job_count"), - )) - + workers.append( + WorkerInfo( + name=worker_data.get("name", ""), + status=worker_data.get("status", ""), + pid=worker_data.get("pid"), + hostname=worker_data.get("hostname"), + queues=worker_data.get("queues"), + last_heartbeat=worker_data.get("last_heartbeat"), + birth_at=worker_data.get("birth_at"), + successful_job_count=worker_data.get("successful_job_count"), + failed_job_count=worker_data.get("failed_job_count"), + ) + ) + return workers - + except Exception as e: raise JobError(f"Failed to get workers: {e}") - + # ========== 系统管理方法 ========== - + async def health_check(self) -> HealthCheckResult: """健康检查""" try: data = await self._make_request("GET", self._build_url("/health")) - + return HealthCheckResult( status=data.get("status", "unknown"), version=data.get("version"), uptime=data.get("uptime"), ) - + except Exception as e: raise ConnectionError(f"Health check failed: {e}") - + async def test_connection( self, device: ConnectionArgs, @@ -611,7 +603,7 @@ async def test_connection( connection_args["port"] = device.port if device.timeout: connection_args["timeout"] = device.timeout - + # 创建连接测试请求 request_data = create_device_request( driver="netmiko", @@ -619,118 +611,128 @@ async def test_connection( command="show version", # 使用简单命令测试连接 options={"ttl": timeout}, ) - + data = await self._make_request( "POST", self._build_url("/device/test"), - json=request_data.model_dump(exclude_none=True) + json=request_data.model_dump(exclude_none=True), ) - + if data.get("code") != 0: return ConnectionTestResult( success=False, message=data.get("message", "Connection test failed"), details=data.get("data"), ) - + return ConnectionTestResult( success=True, message="Connection successful", details=data.get("data"), ) - + except Exception as e: return ConnectionTestResult( success=False, message=f"Connection test failed: {e}", details=None, ) - + # ========== 任务监控方法 ========== - + def get_active_jobs(self) -> List[AsyncJobHandle]: """获取活动任务列表""" return list(self._active_jobs.values()) - + def get_job_count(self) -> int: """获取活动任务数量""" return len(self._active_jobs) - + async def monitor_jobs(self, interval: float = 5.0): """监控所有活动任务""" while self._active_jobs: completed_jobs = [] - + for job_id, handle in self._active_jobs.items(): if handle.is_expired: completed_jobs.append(job_id) continue - + try: job_info = await self.get_job_status(job_id) if job_info and job_info.status in ["finished", "failed"]: completed_jobs.append(job_id) - + # 执行回调 if job_id in self._job_callbacks: for callback in self._job_callbacks[job_id]: try: - await callback(job_info) if asyncio.iscoroutinefunction(callback) else callback(job_info) + await callback(job_info) if asyncio.iscoroutinefunction( + callback + ) else callback(job_info) except Exception as e: print(f"Callback error: {e}") except Exception as e: print(f"Error monitoring job {job_id}: {e}") - + # 移除已完成的任务 for job_id in completed_jobs: if job_id in self._active_jobs: del self._active_jobs[job_id] if job_id in self._job_callbacks: del self._job_callbacks[job_id] - + await asyncio.sleep(interval) - + def add_job_callback(self, job_id: str, callback: Callable): """添加任务回调函数""" if job_id not in self._job_callbacks: self._job_callbacks[job_id] = [] self._job_callbacks[job_id].append(callback) - + # ========== 兼容性方法 ========== - - async def execute_command(self, device: ConnectionArgs, command: Union[str, List[str]], **kwargs) -> Union[CommandResult, AsyncJobHandle]: - """兼容性方法:执行命令""" + + async def execute_command( + self, device: ConnectionArgs, command: Union[str, List[str]], **kwargs + ) -> Union[CommandResult, AsyncJobHandle]: + """兼容性方法: 执行命令""" return await self.execute(device, command=command, **kwargs) - - async def push_config(self, device: ConnectionArgs, config: Union[str, List[str]], **kwargs) -> Union[ConfigResult, AsyncJobHandle]: - """兼容性方法:推送配置""" + + async def push_config( + self, device: ConnectionArgs, config: Union[str, List[str]], **kwargs + ) -> Union[ConfigResult, AsyncJobHandle]: + """兼容性方法: 推送配置""" return await self.execute(device, config=config, **kwargs) - - async def batch_execute(self, devices: List[ConnectionArgs], command: Union[str, List[str]], **kwargs) -> AsyncJobHandle: - """兼容性方法:批量执行命令""" + + async def batch_execute( + self, devices: List[ConnectionArgs], command: Union[str, List[str]], **kwargs + ) -> AsyncJobHandle: + """兼容性方法: 批量执行命令""" return await self.bulk(devices, command=command, **kwargs) - - async def batch_config(self, devices: List[ConnectionArgs], config: Union[str, List[str]], **kwargs) -> AsyncJobHandle: - """兼容性方法:批量推送配置""" + + async def batch_config( + self, devices: List[ConnectionArgs], config: Union[str, List[str]], **kwargs + ) -> AsyncJobHandle: + """兼容性方法: 批量推送配置""" return await self.bulk(devices, config=config, **kwargs) - async def _wait_for_result_with_progressive_retry(self, job_id: str, - initial_delay: float = 0.4, - max_total_time: float = 120.0) -> Dict: + async def _wait_for_result_with_progressive_retry( + self, job_id: str, initial_delay: float = 0.4, max_total_time: float = 120.0 + ) -> Dict: """ 异步激进递进式重试等待任务结果 - - 轮询策略: + + 轮询策略: - 初始延迟: 0.4秒 - 递增步长序列: 0.1s -> 0.2s -> 0.3s -> 0.5s -> 1.5s -> 2.5s -> 4.0s -> 6.0s -> 9.0s -> 13.5s -> 20.0s -> 30.0s - 轮询间隔: 0.4s -> 0.5s -> 0.7s -> 1.0s -> 1.5s -> 3.0s -> 5.5s -> 9.5s -> 15.5s -> 24.5s -> 37.5s -> 57.5s - 最大间隔: 30秒 - 最大总时长: 120秒 - - 示例轮询序列: + + 示例轮询序列: 第1次: 0.4s (初始) 第2次: 0.5s (0.4s + 0.1s) - 第3次: 0.7s (0.5s + 0.2s) + 第3次: 0.7s (0.5s + 0.2s) 第4次: 1.0s (0.7s + 0.3s) 第5次: 1.5s (1.0s + 0.5s) 第6次: 3.0s (1.5s + 1.5s) @@ -745,125 +747,160 @@ async def _wait_for_result_with_progressive_retry(self, job_id: str, delay = initial_delay total_elapsed = 0.0 attempt = 0 - + # 定义递增步长序列 step_sequence = [0.1, 0.2, 0.3, 0.5, 1.5, 2.5, 4.0, 6.0, 9.0, 13.5, 20.0, 30.0] step_index = 0 - + while total_elapsed < max_total_time: try: - result = await self._make_request("GET", self._build_url("/job"), params={"id": job_id}) - + result = await self._make_request( + "GET", self._build_url("/job"), params={"id": job_id} + ) + if result.get("code") == 0 and result.get("data"): - job_data = result["data"][0] if isinstance(result["data"], list) else result["data"] + job_data = ( + result["data"][0] if isinstance(result["data"], list) else result["data"] + ) status = job_data.get("status") - + if status in ["finished", "failed"]: - logger.info(f"任务 {job_id} 完成,总耗时: {total_elapsed:.2f}秒,轮询次数: {attempt}") + logger.info( + f"任务 {job_id} 完成, 总耗时: {total_elapsed:.2f}秒, 轮询次数: {attempt}" + ) return result elif status == "queued": - logger.info(f"任务 {job_id} 排队中... (第{attempt}次轮询,已耗时{total_elapsed:.2f}秒)") + logger.info( + f"任务 {job_id} 排队中... (第{attempt}次轮询, 已耗时{total_elapsed:.2f}秒)" + ) elif status == "started": - logger.info(f"任务 {job_id} 执行中... (第{attempt}次轮询,已耗时{total_elapsed:.2f}秒)") - + logger.info( + f"任务 {job_id} 执行中... (第{attempt}次轮询, 已耗时{total_elapsed:.2f}秒)" + ) + # 激进递进式延迟 await asyncio.sleep(delay) total_elapsed += delay attempt += 1 - + # 计算下一次延迟 if step_index < len(step_sequence): next_step = step_sequence[step_index] step_index += 1 else: - # 如果步长序列用完,使用最后一个步长 + # 如果步长序列用完, 使用最后一个步长 next_step = step_sequence[-1] - + delay = min(delay + next_step, 30.0) # 限制最大间隔为30秒 - + # 记录轮询信息 if attempt % 3 == 0: # 每3次轮询记录一次详细信息 - logger.info(f"任务 {job_id} 轮询中... 第{attempt}次,当前延迟: {delay:.2f}s,总耗时: {total_elapsed:.2f}s") - + logger.info( + f"任务 {job_id} 轮询中... 第{attempt}次, 当前延迟: {delay:.2f}s, 总耗时: {total_elapsed:.2f}s" + ) + except Exception as e: logger.warning(f"查询任务状态失败 (尝试 {attempt + 1}): {e}") await asyncio.sleep(delay) total_elapsed += delay attempt += 1 - + if step_index < len(step_sequence): next_step = step_sequence[step_index] step_index += 1 else: next_step = step_sequence[-1] - + delay = min(delay + next_step, 30.0) - - raise TimeoutError(f"等待任务 {job_id} 完成超时,总耗时: {total_elapsed:.2f}秒,轮询次数: {attempt}") - + + raise TimeoutError( + f"等待任务 {job_id} 完成超时, 总耗时: {total_elapsed:.2f}秒, 轮询次数: {attempt}" + ) + # ==================== 异步方法 ==================== - - async def aexec_command(self, device: Union[ConnectionArgs, Dict], command: Union[str, List[str]], driver: Optional[str] = None, **kwargs) -> CommandResult: + + async def aexec_command( + self, + device: Union[ConnectionArgs, Dict], + command: Union[str, List[str]], + driver: Optional[str] = None, + **kwargs, + ) -> CommandResult: """ - 异步执行命令,支持ConnectionArgs实例或dict - driver: 可选,临时覆盖实例driver + 异步执行命令, 支持ConnectionArgs实例或dict + driver: 可选, 临时覆盖实例driver """ - if hasattr(device, 'model_dump'): + if hasattr(device, "model_dump"): connection_args = device.model_dump() else: connection_args = device use_driver = driver or self.driver data = create_device_request( - driver=use_driver, - connection_args=connection_args, - command=command, - **kwargs + driver=use_driver, connection_args=connection_args, command=command, **kwargs ) result = await self._make_request("POST", self._build_url("/device/execute"), json=data) if result.get("code") == 0 and result.get("data"): job_id = result["data"]["id"] job_result = await self._wait_for_result_with_progressive_retry(job_id) - job_data = job_result["data"][0] if isinstance(job_result["data"], list) else job_result["data"] - + job_data = ( + job_result["data"][0] + if isinstance(job_result["data"], list) + else job_result["data"] + ) + # 直接返回API的完整结构 return CommandResult(**job_data) raise NetPulseError(f"Command execution failed: {result}") - async def aexec_config(self, device: Union[ConnectionArgs, Dict], config: Union[str, List[str], Dict], driver: Optional[str] = None, **kwargs) -> ConfigResult: + async def aexec_config( + self, + device: Union[ConnectionArgs, Dict], + config: Union[str, List[str], Dict], + driver: Optional[str] = None, + **kwargs, + ) -> ConfigResult: """ - 异步推送配置,支持ConnectionArgs实例或dict - driver: 可选,临时覆盖实例driver + 异步推送配置, 支持ConnectionArgs实例或dict + driver: 可选, 临时覆盖实例driver """ - if hasattr(device, 'model_dump'): + if hasattr(device, "model_dump"): connection_args = device.model_dump() else: connection_args = device use_driver = driver or self.driver data = create_device_request( - driver=use_driver, - connection_args=connection_args, - config=config, - **kwargs + driver=use_driver, connection_args=connection_args, config=config, **kwargs ) result = await self._make_request("POST", self._build_url("/device/execute"), json=data) if result.get("code") == 0 and result.get("data"): job_id = result["data"]["id"] job_result = await self._wait_for_result_with_progressive_retry(job_id) - job_data = job_result["data"][0] if isinstance(job_result["data"], list) else job_result["data"] - + job_data = ( + job_result["data"][0] + if isinstance(job_result["data"], list) + else job_result["data"] + ) + # 直接返回API的完整结构 return ConfigResult(**job_data) raise NetPulseError(f"Config execution failed: {result}") - - async def abulk_command(self, driver: Optional[str] = None, devices: List[Dict] = None, connection_args: Dict = None, command: Union[str, List[str]] = None, **kwargs) -> Dict: - """异步批量执行命令,driver可选,默认self.driver""" + + async def abulk_command( + self, + driver: Optional[str] = None, + devices: Optional[List[Dict]] = None, + connection_args: Optional[Dict] = None, + command: Union[None, str, List[str]] = None, + **kwargs, + ) -> Dict: + """异步批量执行命令, driver可选, 默认self.driver""" use_driver = driver or self.driver data = create_batch_device_request( driver=use_driver, devices=devices, connection_args=connection_args, command=command, - **kwargs + **kwargs, ) result = await self._make_request("POST", self._build_url("/device/bulk"), json=data) if result.get("code") == 0 and result.get("data"): @@ -875,16 +912,23 @@ async def abulk_command(self, driver: Optional[str] = None, devices: List[Dict] tasks.append(self._wait_for_result_with_progressive_retry(job_id)) await asyncio.gather(*tasks, return_exceptions=True) return result - - async def abulk_config(self, driver: Optional[str] = None, devices: List[Dict] = None, connection_args: Dict = None, config: Union[str, List[str], Dict] = None, **kwargs) -> Dict: - """异步批量推送配置,driver可选,默认self.driver""" + + async def abulk_config( + self, + driver: Optional[str] = None, + devices: Optional[List[Dict]] = None, + connection_args: Optional[Dict] = None, + config: Union[None, str, List[str], Dict] = None, + **kwargs, + ) -> Dict: + """异步批量推送配置, driver可选, 默认self.driver""" use_driver = driver or self.driver data = create_batch_device_request( driver=use_driver, devices=devices, connection_args=connection_args, config=config, - **kwargs + **kwargs, ) result = await self._make_request("POST", self._build_url("/device/bulk"), json=data) if result.get("code") == 0 and result.get("data"): @@ -896,12 +940,17 @@ async def abulk_config(self, driver: Optional[str] = None, devices: List[Dict] = tasks.append(self._wait_for_result_with_progressive_retry(job_id)) await asyncio.gather(*tasks, return_exceptions=True) return result - + # ==================== 异步任务管理 ==================== - - async def aget_jobs(self, job_id: Optional[str] = None, queue: Optional[str] = None, - status: Optional[str] = None, node: Optional[str] = None, - host: Optional[str] = None) -> Dict: + + async def aget_jobs( + self, + job_id: Optional[str] = None, + queue: Optional[str] = None, + status: Optional[str] = None, + node: Optional[str] = None, + host: Optional[str] = None, + ) -> Dict: """异步获取任务列表""" params = {} if job_id: @@ -914,11 +963,12 @@ async def aget_jobs(self, job_id: Optional[str] = None, queue: Optional[str] = N params["node"] = node if host: params["host"] = host - + return await self._make_request("GET", self._build_url("/job"), params=params) - - async def adelete_jobs(self, job_id: Optional[str] = None, queue: Optional[str] = None, - host: Optional[str] = None) -> Dict: + + async def adelete_jobs( + self, job_id: Optional[str] = None, queue: Optional[str] = None, host: Optional[str] = None + ) -> Dict: """异步删除任务""" params = {} if job_id: @@ -927,13 +977,14 @@ async def adelete_jobs(self, job_id: Optional[str] = None, queue: Optional[str] params["queue"] = queue if host: params["host"] = host - + return await self._make_request("DELETE", self._build_url("/job"), params=params) - + # ==================== 异步Worker管理 ==================== - - async def aget_workers(self, queue: Optional[str] = None, node: Optional[str] = None, - host: Optional[str] = None) -> Dict: + + async def aget_workers( + self, queue: Optional[str] = None, node: Optional[str] = None, host: Optional[str] = None + ) -> Dict: """异步获取Worker列表""" params = {} if queue: @@ -942,11 +993,16 @@ async def aget_workers(self, queue: Optional[str] = None, node: Optional[str] = params["node"] = node if host: params["host"] = host - + return await self._make_request("GET", self._build_url("/worker"), params=params) - - async def adelete_workers(self, name: Optional[str] = None, queue: Optional[str] = None, - node: Optional[str] = None, host: Optional[str] = None) -> Dict: + + async def adelete_workers( + self, + name: Optional[str] = None, + queue: Optional[str] = None, + node: Optional[str] = None, + host: Optional[str] = None, + ) -> Dict: """异步删除Worker""" params = {} if name: @@ -957,62 +1013,55 @@ async def adelete_workers(self, name: Optional[str] = None, queue: Optional[str] params["node"] = node if host: params["host"] = host - + return await self._make_request("DELETE", self._build_url("/worker"), params=params) - + # ==================== 异步健康检测 ==================== - + async def ahealth_check(self) -> HealthCheckResult: """健康检测""" result = await self._make_request("GET", self._build_url("/health")) if result.get("code") == 0 and result.get("data"): return HealthCheckResult(**result["data"]) raise NetPulseError(f"Health check failed: {result}") - + # ==================== 异步连接测试 ==================== - - async def atest_connection(self, device: Union[ConnectionArgs, Dict], driver: Optional[str] = None) -> ConnectionTestResult: - """测试设备连接,支持ConnectionArgs实例或dict""" - if hasattr(device, 'model_dump'): + + async def atest_connection( + self, device: Union[ConnectionArgs, Dict], driver: Optional[str] = None + ) -> ConnectionTestResult: + """测试设备连接, 支持ConnectionArgs实例或dict""" + if hasattr(device, "model_dump"): connection_args = device.model_dump() else: connection_args = device use_driver = driver or self.driver - data = { - "driver": use_driver, - "connection_args": connection_args - } - result = await self._make_request("POST", self._build_url("/device/test-connection"), json=data) + data = {"driver": use_driver, "connection_args": connection_args} + result = await self._make_request( + "POST", self._build_url("/device/test-connection"), json=data + ) if result.get("code") == 0 and result.get("data"): return ConnectionTestResult(**result["data"]) raise NetPulseError(f"Connection test failed: {result}") - + # ==================== 异步模板管理 ==================== - - async def arender_template(self, name: str, template: str, context: Optional[Dict] = None) -> Dict: + + async def arender_template( + self, name: str, template: str, context: Optional[Dict] = None + ) -> Dict: """异步渲染模板""" - data = { - "name": name, - "template": template, - "context": context or {} - } + data = {"name": name, "template": template, "context": context or {}} return await self._make_request("POST", self._build_url("/template/render"), json=data) - - async def aparse_template(self, name: str, template: str, context: Optional[str] = None) -> Dict: + + async def aparse_template( + self, name: str, template: str, context: Optional[str] = None + ) -> Dict: """异步解析模板""" - data = { - "name": name, - "template": template, - "context": context - } + data = {"name": name, "template": template, "context": context} return await self._make_request("POST", self._build_url("/template/parse"), json=data) # 工厂函数 -async def create_async_client( - endpoint: str, - api_key: str, - **kwargs -) -> AsyncNetPulseClient: +async def create_async_client(endpoint: str, api_key: str, **kwargs) -> AsyncNetPulseClient: """创建异步客户端实例""" - return AsyncNetPulseClient(endpoint, api_key, **kwargs) \ No newline at end of file + return AsyncNetPulseClient(endpoint, api_key, **kwargs) diff --git a/netpulse-client/netpulse_client/client.py b/netpulse-client/netpulse_client/client.py index d0db16a..dd9b5a8 100644 --- a/netpulse-client/netpulse_client/client.py +++ b/netpulse-client/netpulse_client/client.py @@ -1,50 +1,50 @@ +import logging +import time from typing import Any, Dict, List, Optional, Union from urllib.parse import urljoin -import time + import httpx -import logging +from .exceptions import ( + AuthenticationError, + ConnectionError, + NetPulseError, + TimeoutError, +) from .models import ( - ConnectionArgs, + BatchResult, CommandResult, ConfigResult, - BatchResult, + ConnectionArgs, + ConnectionTestResult, + HealthCheckResult, JobInfo, WorkerInfo, - HealthCheckResult, - ConnectionTestResult, + create_batch_device_request, create_device_request, - create_batch_device_request -) -from .exceptions import ( - AuthenticationError, - ConnectionError, - JobError, - NetPulseError, - TimeoutError, - ValidationError ) # 配置日志 logger = logging.getLogger(__name__) + class NetPulseClient: """NetPulse API客户端""" - + def __init__( self, endpoint: str, - api_key: str = None, + api_key: Optional[str] = None, driver: str = "netmiko", timeout: int = 300, - verify_ssl: bool = True + verify_ssl: bool = True, ): """初始化客户端 - + Args: endpoint: API端点URL - api_key: API密钥(可选) - driver: 默认驱动类型(如netmiko/napalm/pyeapi) + api_key: API密钥 (可选) + driver: 默认驱动类型 (如netmiko/napalm/pyeapi) timeout: 请求超时时间(秒) verify_ssl: 是否验证SSL证书 """ @@ -53,66 +53,58 @@ def __init__( self.driver = driver self.timeout = timeout self.verify_ssl = verify_ssl - + self._session = httpx.Client( timeout=httpx.Timeout(timeout), verify=verify_ssl, - headers={ - "Content-Type": "application/json", - "User-Agent": "NetPulse-Client/0.1.0" - } + headers={"Content-Type": "application/json", "User-Agent": "NetPulse-Client/0.1.0"}, ) - + if api_key: self._session.headers["X-API-KEY"] = api_key - + def _build_url(self, path: str) -> str: """构建完整的API URL""" return urljoin(self.endpoint, path) - - def _make_request( - self, - method: str, - url: str, - **kwargs - ) -> Dict[str, Any]: + + def _make_request(self, method: str, url: str, **kwargs) -> Dict[str, Any]: """发送HTTP请求""" try: response = self._session.request(method, url, **kwargs) response.raise_for_status() - + return response.json() except httpx.HTTPStatusError as e: if e.response.status_code == 401: - raise AuthenticationError("Invalid API key") + raise AuthenticationError("Invalid API key") from e elif e.response.status_code == 404: - raise ConnectionError(f"API endpoint not found: {url}") + raise ConnectionError(f"API endpoint not found: {url}") from e else: - raise NetPulseError(f"HTTP error: {e}") + raise NetPulseError(f"HTTP error: {e}") from e except httpx.ConnectError: raise ConnectionError(f"Failed to connect to {url}") except httpx.TimeoutException: raise TimeoutError(f"Request timeout: {url}") except Exception as e: raise NetPulseError(f"Request failed: {e}") - - def _wait_for_result_with_progressive_retry(self, job_id: str, - initial_delay: float = 0.4, - max_total_time: float = 120.0) -> Dict: + + def _wait_for_result_with_progressive_retry( + self, job_id: str, initial_delay: float = 0.4, max_total_time: float = 120.0 + ) -> Dict: """ 使用激进递进式重试等待任务结果 - - 轮询策略: + + 轮询策略: - 初始延迟: 0.4秒 - 递增步长序列: 0.1s -> 0.2s -> 0.3s -> 0.5s -> 1.5s -> 2.5s -> 4.0s -> 6.0s -> 9.0s -> 13.5s -> 20.0s -> 30.0s - 轮询间隔: 0.4s -> 0.5s -> 0.7s -> 1.0s -> 1.5s -> 3.0s -> 5.5s -> 9.5s -> 15.5s -> 24.5s -> 37.5s -> 57.5s - 最大间隔: 30秒 - 最大总时长: 120秒 - - 示例轮询序列: + + 示例轮询序列: 第1次: 0.4s (初始) 第2次: 0.5s (0.4s + 0.1s) - 第3次: 0.7s (0.5s + 0.2s) + 第3次: 0.7s (0.5s + 0.2s) 第4次: 1.0s (0.7s + 0.3s) 第5次: 1.5s (1.0s + 0.5s) 第6次: 3.0s (1.5s + 1.5s) @@ -127,132 +119,169 @@ def _wait_for_result_with_progressive_retry(self, job_id: str, delay = initial_delay total_elapsed = 0.0 attempt = 0 - + # 定义递增步长序列 step_sequence = [0.1, 0.2, 0.3, 0.5, 1.5, 2.5, 4.0, 6.0, 9.0, 13.5, 20.0, 30.0] step_index = 0 - + while total_elapsed < max_total_time: try: result = self._make_request("GET", self._build_url("/job"), params={"id": job_id}) - + if result.get("code") == 0 and result.get("data"): - job_data = result["data"][0] if isinstance(result["data"], list) else result["data"] + job_data = ( + result["data"][0] if isinstance(result["data"], list) else result["data"] + ) status = job_data.get("status") - + if status in ["finished", "failed"]: - logger.info(f"任务 {job_id} 完成,总耗时: {total_elapsed:.2f}秒,轮询次数: {attempt}") + logger.info( + f"任务 {job_id} 完成, " + f"总耗时: {total_elapsed:.2f}秒, 轮询次数: {attempt}" + ) return result elif status == "queued": - logger.info(f"任务 {job_id} 排队中... (第{attempt}次轮询,已耗时{total_elapsed:.2f}秒)") + logger.info( + f"任务 {job_id} 排队中... " + f"(第{attempt}次轮询, 已耗时{total_elapsed:.2f}秒)" + ) elif status == "started": - logger.info(f"任务 {job_id} 执行中... (第{attempt}次轮询,已耗时{total_elapsed:.2f}秒)") - + logger.info( + f"任务 {job_id} 执行中... " + f"(第{attempt}次轮询, 已耗时{total_elapsed:.2f}秒)" + ) + # 激进递进式延迟 time.sleep(delay) total_elapsed += delay attempt += 1 - + # 计算下一次延迟 if step_index < len(step_sequence): next_step = step_sequence[step_index] step_index += 1 else: - # 如果步长序列用完,使用最后一个步长 + # 如果步长序列用完, 使用最后一个步长 next_step = step_sequence[-1] - + delay = min(delay + next_step, 30.0) # 限制最大间隔为30秒 - + # 记录轮询信息 if attempt % 3 == 0: # 每3次轮询记录一次详细信息 - logger.info(f"任务 {job_id} 轮询中... 第{attempt}次,当前延迟: {delay:.2f}s,总耗时: {total_elapsed:.2f}s") - + logger.info( + f"任务 {job_id} 轮询中... 第{attempt}次, 当前延迟: {delay:.2f}s, " + f"总耗时: {total_elapsed:.2f}s" + ) + except Exception as e: logger.warning(f"查询任务状态失败 (尝试 {attempt + 1}): {e}") time.sleep(delay) total_elapsed += delay attempt += 1 - + if step_index < len(step_sequence): next_step = step_sequence[step_index] step_index += 1 else: next_step = step_sequence[-1] - + delay = min(delay + next_step, 30.0) - - raise TimeoutError(f"等待任务 {job_id} 完成超时,总耗时: {total_elapsed:.2f}秒,轮询次数: {attempt}") - + + raise TimeoutError( + f"等待任务 {job_id} 完成超时, 总耗时: {total_elapsed:.2f}秒, 轮询次数: {attempt}" + ) + # ==================== 同步方法 ==================== - - def exec_command(self, device: Union[ConnectionArgs, Dict], command: Union[str, List[str]], driver: Optional[str] = None, **kwargs) -> CommandResult: + + def exec_command( + self, + device: Union[ConnectionArgs, Dict], + command: Union[str, List[str]], + driver: Optional[str] = None, + **kwargs, + ) -> CommandResult: """ - 同步执行命令,支持ConnectionArgs实例或dict - driver: 可选,临时覆盖实例driver + 同步执行命令, 支持ConnectionArgs实例或dict + driver: 可选, 临时覆盖实例driver """ - if hasattr(device, 'model_dump'): + if hasattr(device, "model_dump"): connection_args = device.model_dump() else: connection_args = device use_driver = driver or self.driver data = create_device_request( - driver=use_driver, - connection_args=connection_args, - command=command, - **kwargs + driver=use_driver, connection_args=connection_args, command=command, **kwargs ) result = self._make_request("POST", self._build_url("/device/execute"), json=data) - print("[DEBUG] API原始返回:", result) # 调试用 + # print("[DEBUG] API原始返回:", result) # 调试用 if result.get("code") == 0 and result.get("data"): job_id = result["data"]["id"] job_result = self._wait_for_result_with_progressive_retry(job_id) - print("[DEBUG] job_result:", job_result) # 调试用 - job_data = job_result["data"][0] if isinstance(job_result["data"], list) else job_result["data"] - print("[DEBUG] job_data:", job_data) # 调试用 - + # print("[DEBUG] job_result:", job_result) # 调试用 + job_data = ( + job_result["data"][0] + if isinstance(job_result["data"], list) + else job_result["data"] + ) + # print("[DEBUG] job_data:", job_data) # 调试用 + # 直接返回API的完整结构 return CommandResult(**job_data) raise NetPulseError(f"Command execution failed: {result}") - def exec_config(self, device: Union[ConnectionArgs, Dict], config: Union[str, List[str], Dict], driver: Optional[str] = None, **kwargs) -> ConfigResult: + def exec_config( + self, + device: Union[ConnectionArgs, Dict], + config: Union[str, List[str], Dict], + driver: Optional[str] = None, + **kwargs, + ) -> ConfigResult: """ - 同步推送配置,支持ConnectionArgs实例或dict - driver: 可选,临时覆盖实例driver + 同步推送配置, 支持ConnectionArgs实例或dict + driver: 可选, 临时覆盖实例driver """ - if hasattr(device, 'model_dump'): + if hasattr(device, "model_dump"): connection_args = device.model_dump() else: connection_args = device use_driver = driver or self.driver data = create_device_request( - driver=use_driver, - connection_args=connection_args, - config=config, - **kwargs + driver=use_driver, connection_args=connection_args, config=config, **kwargs ) result = self._make_request("POST", self._build_url("/device/execute"), json=data) if result.get("code") == 0 and result.get("data"): job_id = result["data"]["id"] job_result = self._wait_for_result_with_progressive_retry(job_id) - job_data = job_result["data"][0] if isinstance(job_result["data"], list) else job_result["data"] - + job_data = ( + job_result["data"][0] + if isinstance(job_result["data"], list) + else job_result["data"] + ) + # 直接返回API的完整结构 return ConfigResult(**job_data) raise NetPulseError(f"Config execution failed: {result}") - def bulk_command(self, devices: List[Union[ConnectionArgs, Dict]], connection_args: Dict, command: Union[str, List[str]], driver: Optional[str] = None, **kwargs) -> BatchResult: + def bulk_command( + self, + devices: List[Union[ConnectionArgs, Dict]], + connection_args: Dict, + command: Union[str, List[str]], + driver: Optional[str] = None, + **kwargs, + ) -> BatchResult: """ - 同步批量执行命令,支持ConnectionArgs实例或dict列表 - driver: 可选,临时覆盖实例driver + 同步批量执行命令, 支持ConnectionArgs实例或dict列表 + driver: 可选, 临时覆盖实例driver """ - devices_args = [d.model_dump() if hasattr(d, 'model_dump') else d for d in devices] + devices_args = [d.model_dump() if hasattr(d, "model_dump") else d for d in devices] use_driver = driver or self.driver data = create_batch_device_request( driver=use_driver, devices=devices_args, connection_args=connection_args, command=command, - **kwargs + **kwargs, ) result = self._make_request("POST", self._build_url("/device/bulk"), json=data) if result.get("code") == 0 and result.get("data"): @@ -262,29 +291,40 @@ def bulk_command(self, devices: List[Union[ConnectionArgs, Dict]], connection_ar for job in batch_data["succeeded"]: job_id = job["id"] job_result = self._wait_for_result_with_progressive_retry(job_id) - job_data = job_result["data"][0] if isinstance(job_result["data"], list) else job_result["data"] + job_data = ( + job_result["data"][0] + if isinstance(job_result["data"], list) + else job_result["data"] + ) results.append(job_data) return BatchResult( status="finished", job_id=batch_data.get("batch_id", ""), results=results, - error=None + error=None, ) raise NetPulseError(f"Batch command execution failed: {result}") - def bulk_config(self, devices: List[Union[ConnectionArgs, Dict]], connection_args: Dict, config: Union[str, List[str], Dict], driver: Optional[str] = None, **kwargs) -> BatchResult: + def bulk_config( + self, + devices: List[Union[ConnectionArgs, Dict]], + connection_args: Dict, + config: Union[str, List[str], Dict], + driver: Optional[str] = None, + **kwargs, + ) -> BatchResult: """ - 同步批量推送配置,支持ConnectionArgs实例或dict列表 - driver: 可选,临时覆盖实例driver + 同步批量推送配置, 支持ConnectionArgs实例或dict列表 + driver: 可选, 临时覆盖实例driver """ - devices_args = [d.model_dump() if hasattr(d, 'model_dump') else d for d in devices] + devices_args = [d.model_dump() if hasattr(d, "model_dump") else d for d in devices] use_driver = driver or self.driver data = create_batch_device_request( driver=use_driver, devices=devices_args, connection_args=connection_args, config=config, - **kwargs + **kwargs, ) result = self._make_request("POST", self._build_url("/device/bulk"), json=data) if result.get("code") == 0 and result.get("data"): @@ -294,28 +334,37 @@ def bulk_config(self, devices: List[Union[ConnectionArgs, Dict]], connection_arg for job in batch_data["succeeded"]: job_id = job["id"] job_result = self._wait_for_result_with_progressive_retry(job_id) - job_data = job_result["data"][0] if isinstance(job_result["data"], list) else job_result["data"] + job_data = ( + job_result["data"][0] + if isinstance(job_result["data"], list) + else job_result["data"] + ) results.append(job_data) return BatchResult( status="finished", job_id=batch_data.get("batch_id", ""), results=results, - error=None + error=None, ) raise NetPulseError(f"Batch config execution failed: {result}") - + # ==================== 任务管理 ==================== - + def get_job_info(self, job_id: str) -> JobInfo: """获取任务信息""" result = self._make_request("GET", self._build_url(f"/job/{job_id}")) if result.get("code") == 0 and result.get("data"): return JobInfo(**result["data"]) raise NetPulseError(f"Get job info failed: {result}") - - def get_jobs(self, job_id: Optional[str] = None, queue: Optional[str] = None, - status: Optional[str] = None, node: Optional[str] = None, - host: Optional[str] = None) -> List[JobInfo]: + + def get_jobs( + self, + job_id: Optional[str] = None, + queue: Optional[str] = None, + status: Optional[str] = None, + node: Optional[str] = None, + host: Optional[str] = None, + ) -> List[JobInfo]: """获取任务列表""" params = {} if job_id: @@ -328,14 +377,15 @@ def get_jobs(self, job_id: Optional[str] = None, queue: Optional[str] = None, params["node"] = node if host: params["host"] = host - + result = self._make_request("GET", self._build_url("/job"), params=params) if result.get("code") == 0 and result.get("data"): return [JobInfo(**job) for job in result["data"]] raise NetPulseError(f"Get jobs failed: {result}") - - def delete_jobs(self, job_id: Optional[str] = None, queue: Optional[str] = None, - host: Optional[str] = None) -> Dict: + + def delete_jobs( + self, job_id: Optional[str] = None, queue: Optional[str] = None, host: Optional[str] = None + ) -> Dict: """删除任务""" params = {} if job_id: @@ -344,13 +394,14 @@ def delete_jobs(self, job_id: Optional[str] = None, queue: Optional[str] = None, params["queue"] = queue if host: params["host"] = host - + return self._make_request("DELETE", self._build_url("/job"), params=params) - + # ==================== Worker管理 ==================== - - def get_workers(self, queue: Optional[str] = None, node: Optional[str] = None, - host: Optional[str] = None) -> List[WorkerInfo]: + + def get_workers( + self, queue: Optional[str] = None, node: Optional[str] = None, host: Optional[str] = None + ) -> List[WorkerInfo]: """获取Worker列表""" params = {} if queue: @@ -359,14 +410,19 @@ def get_workers(self, queue: Optional[str] = None, node: Optional[str] = None, params["node"] = node if host: params["host"] = host - + result = self._make_request("GET", self._build_url("/worker"), params=params) if result.get("code") == 0 and result.get("data"): return [WorkerInfo(**worker) for worker in result["data"]] raise NetPulseError(f"Get workers failed: {result}") - - def delete_workers(self, name: Optional[str] = None, queue: Optional[str] = None, - node: Optional[str] = None, host: Optional[str] = None) -> Dict: + + def delete_workers( + self, + name: Optional[str] = None, + queue: Optional[str] = None, + node: Optional[str] = None, + host: Optional[str] = None, + ) -> Dict: """删除Worker""" params = {} if name: @@ -377,66 +433,55 @@ def delete_workers(self, name: Optional[str] = None, queue: Optional[str] = None params["node"] = node if host: params["host"] = host - + return self._make_request("DELETE", self._build_url("/worker"), params=params) - + # ==================== 健康检测 ==================== - + def health_check(self) -> HealthCheckResult: """健康检测""" result = self._make_request("GET", self._build_url("/health")) if result.get("code") == 0 and result.get("data"): return HealthCheckResult(**result["data"]) raise NetPulseError(f"Health check failed: {result}") - + # ==================== 连接测试 ==================== - - def test_connection(self, device: Union[ConnectionArgs, Dict], driver: Optional[str] = None) -> ConnectionTestResult: - """测试设备连接,支持ConnectionArgs实例或dict""" - if hasattr(device, 'model_dump'): + + def test_connection( + self, device: Union[ConnectionArgs, Dict], driver: Optional[str] = None + ) -> ConnectionTestResult: + """测试设备连接, 支持ConnectionArgs实例或dict""" + if hasattr(device, "model_dump"): connection_args = device.model_dump() else: connection_args = device use_driver = driver or self.driver - data = { - "driver": use_driver, - "connection_args": connection_args - } + data = {"driver": use_driver, "connection_args": connection_args} result = self._make_request("POST", self._build_url("/device/test-connection"), json=data) if result.get("code") == 0 and result.get("data"): return ConnectionTestResult(**result["data"]) raise NetPulseError(f"Connection test failed: {result}") - + # ==================== 模板管理 ==================== - + def render_template(self, name: str, template: str, context: Optional[Dict] = None) -> Dict: """渲染模板""" - data = { - "name": name, - "template": template, - "context": context or {} - } + data = {"name": name, "template": template, "context": context or {}} return self._make_request("POST", self._build_url("/template/render"), json=data) - + def parse_template(self, name: str, template: str, context: Optional[str] = None) -> Dict: """解析模板""" - data = { - "name": name, - "template": template, - "context": context - } + data = {"name": name, "template": template, "context": context} return self._make_request("POST", self._build_url("/template/parse"), json=data) - + # ==================== 向后兼容方法 ==================== - - def execute(self, device: ConnectionArgs, command: Union[str, List[str]], **kwargs) -> CommandResult: + + def execute( + self, device: ConnectionArgs, command: Union[str, List[str]], **kwargs + ) -> CommandResult: """向后兼容的命令执行方法""" - result = self.exec_command( - device=device, - command=command, - **kwargs - ) - + result = self.exec_command(device=device, command=command, **kwargs) + if result.get("code") == 0 and result.get("data"): job_data = result["data"][0] if isinstance(result["data"], list) else result["data"] return CommandResult( @@ -444,37 +489,35 @@ def execute(self, device: ConnectionArgs, command: Union[str, List[str]], **kwar data=job_data.get("result", {}).get("retval"), error=job_data.get("result", {}).get("error"), job_id=job_data.get("id", ""), - device_host=device.host + device_host=device.host, ) - + raise NetPulseError(f"Command execution failed: {result}") - - def configure(self, device: ConnectionArgs, config: Union[str, List[str]], **kwargs) -> ConfigResult: + + def configure( + self, device: ConnectionArgs, config: Union[str, List[str]], **kwargs + ) -> ConfigResult: """向后兼容的配置推送方法""" - result = self.exec_config( - device=device, - config=config, - **kwargs - ) - + result = self.exec_config(device=device, config=config, **kwargs) + if result.get("code") == 0 and result.get("data"): job_data = result["data"][0] if isinstance(result["data"], list) else result["data"] return ConfigResult( status=job_data.get("status", "unknown"), job_id=job_data.get("id", ""), device_host=device.host, - error=job_data.get("result", {}).get("error") + error=job_data.get("result", {}).get("error"), ) - + raise NetPulseError(f"Config execution failed: {result}") - + def close(self): """关闭客户端""" if self._session: self._session.close() - + def __enter__(self): return self - + def __exit__(self, exc_type, exc_val, exc_tb): - self.close() \ No newline at end of file + self.close() diff --git a/netpulse-client/netpulse_client/exceptions.py b/netpulse-client/netpulse_client/exceptions.py index e45cbbd..9701533 100644 --- a/netpulse-client/netpulse_client/exceptions.py +++ b/netpulse-client/netpulse_client/exceptions.py @@ -7,7 +7,7 @@ class NetPulseError(Exception): """NetPulse 基础异常类""" - + def __init__(self, message: str, status_code: Optional[int] = None): super().__init__(message) self.message = message @@ -16,50 +16,59 @@ def __init__(self, message: str, status_code: Optional[int] = None): class AuthenticationError(NetPulseError): """认证错误""" + pass class ConnectionError(NetPulseError): """连接错误""" + pass class JobError(NetPulseError): """任务错误""" - - def __init__(self, message: str, job_id: Optional[str] = None, status_code: Optional[int] = None): + + def __init__( + self, message: str, job_id: Optional[str] = None, status_code: Optional[int] = None + ): super().__init__(message, status_code) self.job_id = job_id class TimeoutError(NetPulseError): """超时错误""" + pass class ValidationError(NetPulseError): """验证错误""" + pass class SDKValidationError(NetPulseError): """SDK验证错误""" + pass class ConfigurationError(NetPulseError): """Raised when configuration is invalid.""" + pass class TemplateError(NetPulseError): """Raised when template processing fails.""" + pass class DeviceError(NetPulseError): """Raised when device operation fails.""" - + def __init__(self, message: str, device_host: Optional[str] = None): super().__init__(message) - self.device_host = device_host \ No newline at end of file + self.device_host = device_host diff --git a/netpulse-client/netpulse_client/models.py b/netpulse-client/netpulse_client/models.py index b1e4c03..a10e264 100644 --- a/netpulse-client/netpulse_client/models.py +++ b/netpulse-client/netpulse_client/models.py @@ -2,15 +2,15 @@ NetPulse Client Models 这个模块定义了SDK使用的数据模型。 -为了保持与主程序的一致性,我们直接使用主程序的模型,而不是重新定义。 +为了保持与主程序的一致性, 我们直接使用主程序的模型, 而不是重新定义。 """ -from typing import Any, Dict, List, Optional, Union -from datetime import datetime from dataclasses import dataclass from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field -from pydantic import BaseModel, Field, ConfigDict # 定义基本的枚举类型 class DriverName(str, Enum): @@ -18,21 +18,24 @@ class DriverName(str, Enum): NAPALM = "napalm" PYEAPI = "pyeapi" + class QueueStrategy(str, Enum): FIFO = "fifo" PINNED = "pinned" + # 定义基本的连接参数模型 - 与API中的DriverConnectionArgs保持一致 class ConnectionArgs(BaseModel): """设备连接参数模型 - 与API中的DriverConnectionArgs保持一致""" + device_type: Optional[str] = Field(None, description="设备类型") host: Optional[str] = Field(None, description="设备IP地址") username: Optional[str] = Field(None, description="设备用户名") password: Optional[str] = Field(None, description="设备密码") - - # 允许额外字段,与API保持一致 + + # 允许额外字段, 与API保持一致 model_config = ConfigDict(extra="allow") - + def enforced_field_check(self): """ ConnectionArgs could be auto-filled in Batch APIs. @@ -42,24 +45,33 @@ def enforced_field_check(self): raise ValueError("host is None") return self + class NetmikoConnectionArgs(ConnectionArgs): """Netmiko专用连接参数""" + pass + class NapalmConnectionArgs(ConnectionArgs): """NAPALM专用连接参数""" + pass + class PyeapiConnectionArg(ConnectionArgs): """PyEAPI专用连接参数""" + pass -# 为了向后兼容,提供Device别名 + +# 为了向后兼容, 提供Device别名 Device = ConnectionArgs + # 客户端专用模型 class ConnectionConfig(BaseModel): """连接配置""" + endpoint: str = Field(..., description="NetPulse API端点") api_key: str = Field(..., description="API密钥") timeout: int = Field(300, description="请求超时时间(秒)") @@ -67,21 +79,27 @@ class ConnectionConfig(BaseModel): retry_delay: float = Field(1.0, description="重试延迟(秒)") verify_ssl: bool = Field(True, description="是否验证SSL证书") + class JobResult(BaseModel): """任务结果""" + type: str = Field(..., description="结果类型") retval: Optional[Any] = Field(None, description="返回值") error: Optional[Dict[str, Any]] = Field(None, description="错误信息") + # 根据API返回结构定义的结果模型 class ResultModel(BaseModel): """API返回的result字段结构""" + type: int retval: Optional[Dict[str, Any]] = None # 多命令时是dict error: Optional[Any] = None + class CommandResult(BaseModel): """命令执行结果 - 严格按照API返回结构""" + id: str status: str created_at: str @@ -93,21 +111,21 @@ class CommandResult(BaseModel): result: Optional[ResultModel] = None duration: Optional[float] = None queue_time: Optional[float] = None - + @property def data(self) -> Dict[str, str]: - """直接获取所有命令输出,格式: {命令: 输出}""" + """直接获取所有命令输出, 格式: {命令: 输出}""" if self.result and self.result.retval: return self.result.retval return {} - + @property def results(self) -> List[str]: - """获取所有命令输出的列表,格式: [输出1, 输出2, ...]""" + """获取所有命令输出的列表, 格式: [输出1, 输出2, ...]""" if self.result and self.result.retval: return list(self.result.retval.values()) return [] - + def __getitem__(self, key): """支持 result[0] 或 result['display version'] 访问""" if self.result and self.result.retval: @@ -119,8 +137,10 @@ def __getitem__(self, key): return self.result.retval.get(key, "") return "" + class ConfigResult(BaseModel): """配置推送结果 - 严格按照API返回结构""" + id: str status: str created_at: str @@ -132,21 +152,21 @@ class ConfigResult(BaseModel): result: Optional[ResultModel] = None duration: Optional[float] = None queue_time: Optional[float] = None - + @property def data(self) -> Dict[str, str]: - """直接获取所有命令输出,格式: {命令: 输出}""" + """直接获取所有命令输出, 格式: {命令: 输出}""" if self.result and self.result.retval: return self.result.retval return {} - + @property def results(self) -> List[str]: - """获取所有命令输出的列表,格式: [输出1, 输出2, ...]""" + """获取所有命令输出的列表, 格式: [输出1, 输出2, ...]""" if self.result and self.result.retval: return list(self.result.retval.values()) return [] - + def __getitem__(self, key): """支持 result[0] 或 result['display version'] 访问""" if self.result and self.result.retval: @@ -158,16 +178,20 @@ def __getitem__(self, key): return self.result.retval.get(key, "") return "" + class BatchResult(BaseModel): """批量操作结果""" + status: str = Field(..., description="执行状态") job_id: str = Field(..., description="任务ID") results: List[Dict[str, Any]] = Field(default_factory=list, description="设备执行结果列表") error: Optional[str] = Field(None, description="错误信息") execution_time: Optional[float] = Field(None, description="执行时间(秒)") + class JobInfo(BaseModel): """任务信息""" + job_id: str = Field(..., description="任务ID") status: str = Field(..., description="任务状态") result: Optional[Dict[str, Any]] = Field(None, description="任务结果") @@ -175,30 +199,38 @@ class JobInfo(BaseModel): created_at: Optional[str] = Field(None, description="创建时间") updated_at: Optional[str] = Field(None, description="更新时间") + class WorkerInfo(BaseModel): """工作节点信息""" + id: str = Field(..., description="节点ID") status: str = Field(..., description="节点状态") queue_size: int = Field(..., description="队列大小") active_jobs: int = Field(..., description="活跃任务数") last_heartbeat: str = Field(..., description="最后心跳时间") + class HealthCheckResult(BaseModel): """健康检查结果""" + status: str = Field(..., description="服务状态") version: str = Field(..., description="服务版本") uptime: float = Field(..., description="运行时间(秒)") + class ConnectionTestResult(BaseModel): """连接测试结果""" + success: bool = Field(..., description="测试结果") message: str = Field(..., description="测试信息") connection_time: Optional[float] = Field(None, description="连接时间(秒)") error: Optional[str] = Field(None, description="错误信息") + # 状态枚举 class OperationStatus(str, Enum): """操作状态""" + SUBMITTED = "submitted" QUEUED = "queued" STARTED = "started" @@ -207,8 +239,10 @@ class OperationStatus(str, Enum): TIMEOUT = "timeout" CANCELLED = "cancelled" + class JobStatus(str, Enum): """任务状态""" + QUEUED = "queued" STARTED = "started" FINISHED = "finished" @@ -218,29 +252,35 @@ class JobStatus(str, Enum): STOPPED = "stopped" CANCELED = "canceled" + class WorkerState(str, Enum): """工作节点状态""" + BUSY = "busy" IDLE = "idle" SUSPENDED = "suspended" DEAD = "dead" + # 异步任务句柄 @dataclass class AsyncJobHandle: """异步任务句柄""" + job_id: str task_type: str # command, config, batch_command, batch_config device_hosts: List[str] submitted_at: float timeout: int - + @property def is_expired(self) -> bool: """检查任务是否已超时""" import time + return time.time() - self.submitted_at > self.timeout + # 工具函数 def create_device_request( driver: str, @@ -263,15 +303,15 @@ def create_device_request( else: conn_args = ConnectionArgs(**connection_args) except ValueError: - # 如果driver不是有效的DriverName,直接使用ConnectionArgs + # 如果driver不是有效的DriverName, 直接使用ConnectionArgs conn_args = ConnectionArgs(**connection_args) - + # 构建请求数据 request = { "driver": driver, "connection_args": conn_args.model_dump(), } - + if command is not None: request["command"] = command if config is not None: @@ -280,9 +320,10 @@ def create_device_request( request["driver_args"] = driver_args if options is not None: request["options"] = options - + return request + def create_batch_device_request( driver: str, devices: List[Dict[str, Any]], @@ -298,10 +339,10 @@ def create_batch_device_request( "driver": driver, "devices": devices, } - + if connection_args is not None: request["connection_args"] = connection_args - + if command is not None: request["command"] = command if config is not None: @@ -310,5 +351,5 @@ def create_batch_device_request( request["driver_args"] = driver_args if options is not None: request["options"] = options - - return request \ No newline at end of file + + return request diff --git a/netpulse-client/pyproject.toml b/netpulse-client/pyproject.toml index 6222016..e0c5563 100644 --- a/netpulse-client/pyproject.toml +++ b/netpulse-client/pyproject.toml @@ -45,11 +45,6 @@ dev = [ "mypy>=1.0.0", "ruff>=0.1.0", ] -docs = [ - "sphinx>=6.0.0", - "sphinx-rtd-theme>=1.3.0", - "myst-parser>=1.0.0", -] [project.urls] Homepage = "https://github.com/netpulse/netpulse-client" @@ -111,20 +106,7 @@ strict_equality = true [tool.ruff] line-length = 100 target-version = "py38" -select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # pyflakes - "I", # isort - "B", # flake8-bugbear - "C4", # flake8-comprehensions - "UP", # pyupgrade -] -ignore = [ - "E501", # line too long, handled by black - "B008", # do not perform function calls in argument defaults - "C901", # too complex -] +select = ["E", "F", "W", "RUF", "I"] [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/netpulse/plugins/drivers/netmiko/model.py b/netpulse/plugins/drivers/netmiko/model.py index 7ee86a9..a127d68 100644 --- a/netpulse/plugins/drivers/netmiko/model.py +++ b/netpulse/plugins/drivers/netmiko/model.py @@ -42,7 +42,7 @@ class NetmikoConnectionArgs(DriverConnectionArgs): auth_timeout: Optional[float] = None blocking_timeout: Optional[int] = 20 banner_timeout: Optional[int] = 15 - keepalive: Optional[int] = 180 # keepalive (3m) differs from netmiko default (0) + keepalive: Optional[int] = 180 # keepalive (3m) differs from netmiko default (0) default_enter: Optional[str] = None response_return: Optional[str] = None serial_settings: Optional[str] = None diff --git a/pyproject.toml b/pyproject.toml index 8cde75c..948f20b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,13 +27,6 @@ classifiers = [ ] keywords = ["network", "api", "netmiko", "napalm", "automation", "fastapi"] -[project.urls] -"Homepage" = "https://github.com/scitix/netpulse" -"Documentation" = "https://netpulse.readthedocs.io" -"Repository" = "https://github.com/scitix/netpulse" -"Bug Tracker" = "https://github.com/scitix/netpulse/issues" -"Changelog" = "https://github.com/scitix/netpulse/blob/master/CHANGELOG.md" - dependencies = [ # Core "colorlog~=6.9.0", @@ -75,6 +68,13 @@ dev = ["ruff>=0.11.0", "mkdocs-material~=9.6.0", "mkdocs-static-i18n~=1.3.0"] [project.scripts] netpulse-cli = "netpulse.cli.main:main" +[project.urls] +"Homepage" = "https://github.com/scitix/netpulse" +"Documentation" = "https://netpulse.readthedocs.io" +"Repository" = "https://github.com/scitix/netpulse" +"Bug Tracker" = "https://github.com/scitix/netpulse/issues" +"Changelog" = "https://github.com/scitix/netpulse/blob/master/CHANGELOG.md" + [tool.setuptools] packages = ["netpulse"] @@ -93,7 +93,7 @@ packages = ["netpulse"] [tool.ruff] line-length = 100 respect-gitignore = true +exclude = ["netpulse-client/*"] -# https://docs.astral.sh/ruff/rules/ [tool.ruff.lint] select = ["E", "F", "W", "RUF", "I"]