diff --git a/agent_memory_server/cli.py b/agent_memory_server/cli.py index e8f5af7..c031e47 100644 --- a/agent_memory_server/cli.py +++ b/agent_memory_server/cli.py @@ -314,7 +314,18 @@ def token(): show_default=True, help="Output format.", ) -def add(description: str, expires_days: int | None, output_format: str): +@click.option( + "--token", + "provided_token", + type=str, + help="Use a pre-generated token instead of generating a new one.", +) +def add( + description: str, + expires_days: int | None, + output_format: str, + provided_token: str | None, +) -> None: """Add a new authentication token.""" import asyncio @@ -324,9 +335,9 @@ def add(description: str, expires_days: int | None, output_format: str): async def create_token(): redis = await get_redis_conn() - # Generate token - token = generate_token() - token_hash = hash_token(token) + # Determine token value + token_value = provided_token or generate_token() + token_hash = hash_token(token_value) # Calculate expiration now = datetime.now(UTC) @@ -353,7 +364,7 @@ async def create_token(): list_key = Keys.auth_tokens_list_key() await redis.sadd(list_key, token_hash) - return token, token_info + return token_value, token_info token, token_info = asyncio.run(create_token()) diff --git a/tests/test_token_cli.py b/tests/test_token_cli.py index c135c74..e214042 100644 --- a/tests/test_token_cli.py +++ b/tests/test_token_cli.py @@ -95,6 +95,52 @@ def test_token_add_command_json_output( mock_redis.expire.assert_called_once() mock_redis.sadd.assert_called_once() + @patch("agent_memory_server.auth.generate_token") + @patch("agent_memory_server.cli.get_redis_conn") + def test_token_add_command_with_provided_token( + self, + mock_get_redis, + mock_generate_token, + mock_redis, + cli_runner, + ): + """Test token add command when a token is provided via --token.""" + mock_get_redis.return_value = mock_redis + + import json + + provided_token = "test-token-123" + + result = cli_runner.invoke( + token, + [ + "add", + "--description", + "Test token", + "--expires-days", + "7", + "--token", + provided_token, + "--format", + "json", + ], + ) + + assert result.exit_code == 0 + + data = json.loads(result.output) + assert data["token"] == provided_token + assert data["description"] == "Test token" + assert data["expires_at"] is not None + + # generate_token should not be called when a token is provided + mock_generate_token.assert_not_called() + + # Verify Redis calls + mock_redis.set.assert_called_once() + mock_redis.expire.assert_called_once() + mock_redis.sadd.assert_called_once() + @patch("agent_memory_server.cli.get_redis_conn") def test_token_list_command_empty(self, mock_get_redis, mock_redis, cli_runner): """Test token list command with no tokens."""