Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 491636b

Browse files
committed
feat: allow exposing the models base path and model name from the CLI
1 parent 1f809f2 commit 491636b

File tree

5 files changed

+102
-45
lines changed

5 files changed

+102
-45
lines changed

docs/cli.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ codegate serve [OPTIONS]
6666
- Base URL for Ollama provider (/api path is added automatically)
6767
- Overrides configuration file and environment variables
6868

69+
- `--model-base-path TEXT`: Base path for loading models needed for the system
70+
- Optional
71+
72+
- `--embedding-model TEXT`: Name of the model used for embeddings
73+
- Optional
74+
6975
### show-prompts
7076

7177
Display the loaded system prompts:

scripts/entrypoint.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ exec nginx -g 'daemon off;' &
1515

1616
# Step 3: Start the main application (serve)
1717
echo "Starting the application..."
18-
exec python -m src.codegate.cli serve --port 8989 --host 0.0.0.0 --vllm-url https://inference.codegate.ai
18+
exec python -m src.codegate.cli serve --port 8989 --host 0.0.0.0 --vllm-url https://inference.codegate.ai --model-base-path /app/models

src/codegate/cli.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,18 @@ def show_prompts(prompts: Optional[Path]) -> None:
115115
default=None,
116116
help="Ollama provider URL (default: http://localhost:11434/api)",
117117
)
118+
@click.option(
119+
"--model-base-path",
120+
type=str,
121+
default="./models",
122+
help="Path to the model base directory",
123+
)
124+
@click.option(
125+
"--embedding-model",
126+
type=str,
127+
default="all-minilm-L6-v2-q5_k_m.gguf",
128+
help="Name of the model to use for embeddings",
129+
)
118130
def serve(
119131
port: Optional[int],
120132
host: Optional[str],
@@ -126,6 +138,8 @@ def serve(
126138
openai_url: Optional[str],
127139
anthropic_url: Optional[str],
128140
ollama_url: Optional[str],
141+
model_base_path: Optional[str],
142+
embedding_model: Optional[str],
129143
) -> None:
130144
"""Start the codegate server."""
131145
logger = None
@@ -150,6 +164,8 @@ def serve(
150164
cli_log_level=log_level,
151165
cli_log_format=log_format,
152166
cli_provider_urls=cli_provider_urls,
167+
model_base_path=model_base_path,
168+
embedding_model=embedding_model,
153169
)
154170

155171
setup_logging(cfg.log_level, cfg.log_format)
@@ -163,6 +179,8 @@ def serve(
163179
"log_format": cfg.log_format.value,
164180
"prompts_loaded": len(cfg.prompts.prompts),
165181
"provider_urls": cfg.provider_urls,
182+
"model_base_path": cfg.model_base_path,
183+
"embedding_model": cfg.embedding_model,
166184
},
167185
)
168186

src/codegate/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def load(
176176
cli_log_level: Optional[str] = None,
177177
cli_log_format: Optional[str] = None,
178178
cli_provider_urls: Optional[Dict[str, str]] = None,
179+
model_base_path: Optional[str] = None,
180+
embedding_model: Optional[str] = None,
179181
) -> "Config":
180182
"""Load configuration with priority resolution.
181183
@@ -193,6 +195,8 @@ def load(
193195
cli_log_level: Optional CLI log level override
194196
cli_log_format: Optional CLI log format override
195197
cli_provider_urls: Optional dict of provider URLs from CLI
198+
model_base_path: Optional path to model base directory
199+
embedding_model: Optional name of the model to use for embeddings
196200
197201
Returns:
198202
Config: Resolved configuration
@@ -223,6 +227,10 @@ def load(
223227
config.log_format = env_config.log_format
224228
if "CODEGATE_PROMPTS_FILE" in os.environ:
225229
config.prompts = env_config.prompts
230+
if "CODEGATE_MODEL_BASE_PATH" in os.environ:
231+
config.model_base_path = env_config.model_base_path
232+
if "CODEGATE_EMBEDDING_MODEL" in os.environ:
233+
config.embedding_model = env_config.embedding_model
226234

227235
# Override provider URLs from environment
228236
for provider, url in env_config.provider_urls.items():
@@ -241,6 +249,10 @@ def load(
241249
config.prompts = PromptConfig.from_file(prompts_path)
242250
if cli_provider_urls is not None:
243251
config.provider_urls.update(cli_provider_urls)
252+
if model_base_path is not None:
253+
config.model_base_path = model_base_path
254+
if embedding_model is not None:
255+
config.embedding_model = embedding_model
244256

245257
# Set the __config class attribute
246258
Config.__config = config

tests/test_cli.py

Lines changed: 65 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,23 @@ def test_serve_default_options(
6767
assert result.exit_code == 0
6868
mock_setup_logging.assert_called_once_with(LogLevel.INFO, LogFormat.JSON)
6969
mock_logging.assert_called_with("codegate")
70-
logger_instance.info.assert_any_call(
71-
"Starting server",
72-
extra={
73-
"host": "localhost",
74-
"port": 8989,
75-
"log_level": "INFO",
76-
"log_format": "JSON",
77-
"prompts_loaded": 6, # Default prompts are loaded
78-
"provider_urls": DEFAULT_PROVIDER_URLS,
79-
},
80-
)
70+
71+
# validate only a subset of the expected extra arguments, as image provides more
72+
expected_extra = {
73+
"host": "localhost",
74+
"port": 8989,
75+
"log_level": "INFO",
76+
"log_format": "JSON",
77+
"prompts_loaded": 6,
78+
"provider_urls": DEFAULT_PROVIDER_URLS,
79+
}
80+
81+
# Retrieve the actual call arguments
82+
calls = [call[1]['extra'] for call in logger_instance.info.call_args_list]
83+
84+
# Check if one of the calls matches the expected subset
85+
assert any(all(expected_extra[k] == actual_extra.get(k)
86+
for k in expected_extra) for actual_extra in calls)
8187
mock_run.assert_called_once()
8288

8389

@@ -106,17 +112,22 @@ def test_serve_custom_options(
106112
assert result.exit_code == 0
107113
mock_setup_logging.assert_called_once_with(LogLevel.DEBUG, LogFormat.TEXT)
108114
mock_logging.assert_called_with("codegate")
109-
logger_instance.info.assert_any_call(
110-
"Starting server",
111-
extra={
112-
"host": "localhost",
113-
"port": 8989,
114-
"log_level": "DEBUG",
115-
"log_format": "TEXT",
116-
"prompts_loaded": 6, # Default prompts are loaded
117-
"provider_urls": DEFAULT_PROVIDER_URLS,
118-
},
119-
)
115+
116+
# Retrieve the actual call arguments
117+
calls = [call[1]['extra'] for call in logger_instance.info.call_args_list]
118+
119+
expected_extra = {
120+
"host": "localhost",
121+
"port": 8989,
122+
"log_level": "DEBUG",
123+
"log_format": "TEXT",
124+
"prompts_loaded": 6, # Default prompts are loaded
125+
"provider_urls": DEFAULT_PROVIDER_URLS,
126+
}
127+
128+
# Check if one of the calls matches the expected subset
129+
assert any(all(expected_extra[k] == actual_extra.get(k)
130+
for k in expected_extra) for actual_extra in calls)
120131
mock_run.assert_called_once()
121132

122133

@@ -146,17 +157,22 @@ def test_serve_with_config_file(
146157
assert result.exit_code == 0
147158
mock_setup_logging.assert_called_once_with(LogLevel.DEBUG, LogFormat.JSON)
148159
mock_logging.assert_called_with("codegate")
149-
logger_instance.info.assert_any_call(
150-
"Starting server",
151-
extra={
152-
"host": "localhost",
153-
"port": 8989,
154-
"log_level": "DEBUG",
155-
"log_format": "JSON",
156-
"prompts_loaded": 6, # Default prompts are loaded
157-
"provider_urls": DEFAULT_PROVIDER_URLS,
158-
},
159-
)
160+
161+
# Retrieve the actual call arguments
162+
calls = [call[1]['extra'] for call in logger_instance.info.call_args_list]
163+
164+
expected_extra = {
165+
"host": "localhost",
166+
"port": 8989,
167+
"log_level": "DEBUG",
168+
"log_format": "JSON",
169+
"prompts_loaded": 6, # Default prompts are loaded
170+
"provider_urls": DEFAULT_PROVIDER_URLS,
171+
}
172+
173+
# Check if one of the calls matches the expected subset
174+
assert any(all(expected_extra[k] == actual_extra.get(k)
175+
for k in expected_extra) for actual_extra in calls)
160176
mock_run.assert_called_once()
161177

162178

@@ -198,17 +214,22 @@ def test_serve_priority_resolution(
198214
assert result.exit_code == 0
199215
mock_setup_logging.assert_called_once_with(LogLevel.ERROR, LogFormat.TEXT)
200216
mock_logging.assert_called_with("codegate")
201-
logger_instance.info.assert_any_call(
202-
"Starting server",
203-
extra={
204-
"host": "example.com",
205-
"port": 8080,
206-
"log_level": "ERROR",
207-
"log_format": "TEXT",
208-
"prompts_loaded": 6, # Default prompts are loaded
209-
"provider_urls": DEFAULT_PROVIDER_URLS,
210-
},
211-
)
217+
218+
# Retrieve the actual call arguments
219+
calls = [call[1]['extra'] for call in logger_instance.info.call_args_list]
220+
221+
expected_extra = {
222+
"host": "example.com",
223+
"port": 8080,
224+
"log_level": "ERROR",
225+
"log_format": "TEXT",
226+
"prompts_loaded": 6, # Default prompts are loaded
227+
"provider_urls": DEFAULT_PROVIDER_URLS,
228+
}
229+
230+
# Check if one of the calls matches the expected subset
231+
assert any(all(expected_extra[k] == actual_extra.get(k)
232+
for k in expected_extra) for actual_extra in calls)
212233
mock_run.assert_called_once()
213234

214235

0 commit comments

Comments
 (0)