8 changes: 7 additions & 1 deletion .github/workflows/ci.yml
@@ -14,7 +14,13 @@ jobs:
python-version: ["3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Checkout github repo
uses: actions/checkout@v4
with:
lfs: true

- name: Checkout LFS objects
run: git lfs pull

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
11 changes: 11 additions & 0 deletions src/codegate/config.py
@@ -18,6 +18,10 @@
class Config:
"""Application configuration with priority resolution."""

# Singleton instance of Config which is set in Config.load().
# All consumers can call: Config.get_config() to get the config.
__config = None

port: int = 8989
host: str = "localhost"
log_level: LogLevel = LogLevel.INFO
@@ -208,4 +212,11 @@ def load(
if prompts_path is not None:
config.prompts = PromptConfig.from_file(prompts_path)

# Set the __config class attribute
Config.__config = config

return config

@classmethod
def get_config(cls):
return cls.__config
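
The singleton access pattern introduced above can be used roughly as follows. This is a minimal sketch: Config.load() is assumed to be callable as shown (its full signature is not visible in this diff), and the attribute read at the end is the port default declared above.

from codegate.config import Config

# During startup, load() builds the Config and registers it on the class
# (keyword arguments here are illustrative; only prompts_path appears in the hunk above).
config = Config.load(prompts_path=None)

# Any later consumer can fetch the same instance without it being passed around.
assert Config.get_config() is config
print(Config.get_config().port)  # 8989 unless overridden
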
60 changes: 39 additions & 21 deletions src/codegate/inference/inference_engine.py
@@ -1,20 +1,41 @@
from llama_cpp import Llama

from codegate.codegate_logging import setup_logging


class LlamaCppInferenceEngine:
_inference_engine = None
"""
A wrapper class for llama.cpp models

Attributes:
__inference_engine: Singleton instance of this class
"""

__inference_engine = None

def __new__(cls):
if cls._inference_engine is None:
cls._inference_engine = super().__new__(cls)
return cls._inference_engine
if cls.__inference_engine is None:
cls.__inference_engine = super().__new__(cls)
return cls.__inference_engine

def __init__(self):
if not hasattr(self, "_LlamaCppInferenceEngine__models"):  # name-mangled form of self.__models; "models" is never set
self.__models = {}
self.__logger = setup_logging()

async def get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers=0):
def __del__(self):
self.__close_models()

async def __get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers=0):
"""
Returns Llama model object from __models if present. Otherwise, the model
is loaded and added to __models and returned.
"""
if model_path not in self.__models:
self.__logger.info(
f"Loading model from {model_path} with parameters "
f"n_gpu_layers={n_gpu_layers} and n_ctx={n_ctx}"
)
self.__models[model_path] = Llama(
model_path=model_path,
n_gpu_layers=n_gpu_layers,
@@ -25,29 +46,26 @@ async def get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers=0

return self.__models[model_path]

async def generate(
self, model_path, prompt, n_ctx=512, n_gpu_layers=0, stream=True
):
model = await self.get_model(
model_path=model_path, n_ctx=n_ctx, n_gpu_layers=n_gpu_layers
)

for chunk in model.create_completion(prompt=prompt, stream=stream):
yield chunk

async def chat(
self, model_path, n_ctx=512, n_gpu_layers=0, **chat_completion_request
):
model = await self.get_model(
async def chat(self, model_path, n_ctx=512, n_gpu_layers=0, **chat_completion_request):
"""
Generates a chat completion using the specified model and request parameters.
"""
model = await self.__get_model(
model_path=model_path, n_ctx=n_ctx, n_gpu_layers=n_gpu_layers
)
return model.create_completion(**chat_completion_request)

async def embed(self, model_path, content):
model = await self.get_model(model_path=model_path, embedding=True)
"""
Generates an embedding for the given content using the specified model.
"""
model = await self.__get_model(model_path=model_path, embedding=True)
return model.embed(content)

async def close_models(self):
async def __close_models(self):
"""
Closes all open models and samplers
"""
for _, model in self.__models.items():
if model._sampler:
model._sampler.close()
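
For context, a minimal usage sketch of the refactored engine. The model paths are the ones used in tests/test_inference.py; running this for real requires llama-cpp-python and the GGUF files (pulled via Git LFS per the CI change above).

import asyncio

from codegate.inference.inference_engine import LlamaCppInferenceEngine


async def main():
    engine = LlamaCppInferenceEngine()
    # __new__ always returns the same instance, so loaded models are shared.
    assert engine is LlamaCppInferenceEngine()

    request = {"prompt": "hello", "stream": True, "max_tokens": 32, "temperature": 0}
    response = await engine.chat(
        "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf", **request
    )
    for chunk in response:
        print(chunk["choices"][0]["text"], end="")

    vector = await engine.embed("./models/all-minilm-L6-v2-q5_k_m.gguf", "hello")
    print(len(vector))  # 384 for all-minilm-L6-v2


asyncio.run(main())
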
4 changes: 3 additions & 1 deletion src/codegate/providers/litellmshim/generators.py
@@ -1,5 +1,6 @@
import json
from typing import Any, AsyncIterator, Iterator
import asyncio

from pydantic import BaseModel

@@ -46,8 +47,9 @@ async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]
if hasattr(chunk, "model_dump_json"):
chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
try:
chunk['content'] = chunk['choices'][0]['text']
chunk["content"] = chunk["choices"][0]["text"]
yield f"data:{json.dumps(chunk)}\n\n"
await asyncio.sleep(0)
except Exception as e:
yield f"data:{str(e)}\n\n"
except Exception as e:
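
The notable addition here is await asyncio.sleep(0): llama.cpp's completion iterator is synchronous, so without an explicit await the async generator never hands control back to the event loop between chunks. A standalone illustration of that effect (not code from this PR):

import asyncio


async def stream_chunks():
    for i in range(3):
        print(f"chunk {i}")
        await asyncio.sleep(0)  # cooperative yield, as added above


async def heartbeat():
    for _ in range(3):
        print("other task runs")
        await asyncio.sleep(0)


async def main():
    # With sleep(0) the two tasks interleave; drop it from stream_chunks and
    # that coroutine runs to completion before heartbeat ever gets a turn.
    await asyncio.gather(stream_chunks(), heartbeat())


asyncio.run(main())
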
7 changes: 3 additions & 4 deletions src/codegate/providers/llamacpp/completion_handler.py
@@ -11,7 +11,6 @@

class LlamaCppCompletionHandler(BaseCompletionHandler):
def __init__(self, adapter: BaseAdapter):
self._config = Config.from_file('./config.yaml')
self._adapter = adapter
self.inference_engine = LlamaCppInferenceEngine()

@@ -53,9 +52,9 @@ async def execute_completion(
"""
Execute the completion request with LiteLLM's API
"""
response = await self.inference_engine.chat(self._config.chat_model_path,
self._config.chat_model_n_ctx,
self._config.chat_model_n_gpu_layers,
response = await self.inference_engine.chat(Config.get_config().chat_model_path,
Config.get_config().chat_model_n_ctx,
Config.get_config().chat_model_n_gpu_layers,
**request)
return response

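
A small sketch of the new lookup pattern (illustrative only; the attribute names are the ones used in the call above, the helper itself is hypothetical):

from codegate.config import Config


def chat_settings():
    # Read at request time from the singleton set by Config.load(), instead of
    # re-parsing ./config.yaml in the handler's __init__.
    cfg = Config.get_config()
    return cfg.chat_model_path, cfg.chat_model_n_ctx, cfg.chat_model_n_gpu_layers
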
1 change: 0 additions & 1 deletion tests/providers/litellmshim/test_litellmshim.py
@@ -90,7 +90,6 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:
"model": "gpt-3.5-turbo",
"stream": True,
}
api_key = "test-key"

# Execute
result_stream = await litellm_shim.execute_completion(data)
29 changes: 15 additions & 14 deletions tests/test_inference.py
@@ -1,4 +1,3 @@

import pytest

# @pytest.mark.asyncio
@@ -20,25 +19,27 @@
@pytest.mark.asyncio
async def test_chat(inference_engine) -> None:
"""Test chat completion."""
pass

# chat_request = {"prompt":
# "<|im_start|>user\\nhello<|im_end|>\\n<|im_start|>assistant\\n",
# "stream": True, "max_tokens": 4096, "top_k": 50, "temperature": 0}
chat_request = {
"prompt": "<|im_start|>user\\nhello<|im_end|>\\n<|im_start|>assistant\\n",
"stream": True,
"max_tokens": 4096,
"top_k": 50,
"temperature": 0,
}

# model_path = "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
# response = await inference_engine.chat(model_path, **chat_request)
model_path = "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
response = await inference_engine.chat(model_path, **chat_request)

# for chunk in response:
# assert chunk['choices'][0]['text'] is not None
for chunk in response:
assert chunk["choices"][0]["text"] is not None


@pytest.mark.asyncio
async def test_embed(inference_engine) -> None:
"""Test content embedding."""
pass

# content = "Can I use invokehttp package in my project?"
# model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
# vector = await inference_engine.embed(model_path, content=content)
# assert len(vector) == 384
content = "Can I use invokehttp package in my project?"
model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
vector = await inference_engine.embed(model_path, content=content)
assert len(vector) == 384
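
These tests rely on an inference_engine fixture that is not part of this diff; it presumably looks roughly like the sketch below (an assumption, shown only so the tests read in context):

import pytest

from codegate.inference.inference_engine import LlamaCppInferenceEngine


@pytest.fixture
def inference_engine() -> LlamaCppInferenceEngine:
    return LlamaCppInferenceEngine()
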