Skip to content

Commit

Permalink
Extend AbstractLLM for RAG and OpenAI models
Browse files Browse the repository at this point in the history
  • Loading branch information
vmesel committed May 1, 2024
1 parent 9df24e1 commit 64ad092
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 21 deletions.
6 changes: 6 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[run]
source=.
omit=dialog_lib/tests/*,dialog_lib/samples/*

[report]
omit=dialog_lib/tests/*,dialog_lib/samples/*
1 change: 1 addition & 0 deletions dialog_lib/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .openai import DialogOpenAI
45 changes: 29 additions & 16 deletions dialog_lib/agents/abstract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from langchain.prompts import ChatPromptTemplate
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.prompts import ChatPromptTemplate


class AbstractLLM:
Expand Down Expand Up @@ -37,12 +37,6 @@ def __init__(
self.parent_session_id = parent_session_id
self.dbsession = dbsession

def get_prompt(self, input) -> ChatPromptTemplate:
"""
Function that generates the prompt for the LLM.
"""
raise NotImplementedError("Prompt must be implemented")

@property
def memory(self) -> BaseChatMemory:
"""
Expand Down Expand Up @@ -83,6 +77,32 @@ def postprocess(self, output: str) -> str:
"""
return output

def process(self, input: str):
"""
Function that encapsulates the pre-processing, processing and post-processing
of the LLM.
"""
processed_input = self.preprocess(input)
self.generate_prompt(processed_input)
output = self.llm.invoke(
{
"user_message": processed_input,
}
)
processed_output = self.postprocess(output)
return processed_output

@property
def messages(self):
"""
Returns the messages from the memory instance
"""
return self.memory.messages


class AbstractRAG(AbstractLLM):
relevant_contents = []

def process(self, input: str):
"""
Function that encapsulates the pre-processing, processing and post-processing
Expand All @@ -98,17 +118,10 @@ def process(self, input: str):
"fallback_not_found_relevant_contents"
)
}
output = self.llm(
output = self.llm.invoke(
{
"user_message": processed_input,
}
)
processed_output = self.postprocess(output)
return processed_output

@property
def messages(self):
"""
Returns the messages from the memory instance
"""
return self.memory.messages
return processed_output
61 changes: 61 additions & 0 deletions dialog_lib/agents/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from .abstract import AbstractLLM
from langchain.prompts import ChatPromptTemplate
from langchain.memory import ConversationBufferWindowMemory

from langchain.chains.llm import LLMChain
from langchain_openai.chat_models.base import ChatOpenAI


class DialogOpenAI(AbstractLLM):
    """AbstractLLM implementation backed by an OpenAI chat model via LangChain."""

    def __init__(self, *args, **kwargs):
        """
        Accepts the AbstractLLM kwargs plus:
            model: OpenAI model name (default "gpt-3.5-turbo")
            temperature: sampling temperature (default 0.1)
            openai_api_key: API key forwarded to ChatOpenAI
            prompt: assistant prompt text used as the "ai" message
            memory: chat-memory backend exposed through the `memory` property
        """
        model = kwargs.pop("model", "gpt-3.5-turbo")
        temperature = kwargs.pop("temperature", 0.1)
        # AbstractLLM expects a config dict; default to empty when absent.
        kwargs["config"] = kwargs.get("config", {})

        self.memory_instance = kwargs.pop("memory", None)
        self.llm_api_key = kwargs.pop("openai_api_key", None)
        self.prompt_content = kwargs.pop("prompt", None)

        super().__init__(*args, **kwargs)

        self.chat_model = ChatOpenAI(
            openai_api_key=self.llm_api_key,
            model=model,
            temperature=temperature,
        )

    def generate_prompt(self, input_text):
        # Use a "{user_message}" placeholder instead of splicing the raw user
        # text into the template: AbstractLLM.process() invokes the chain with
        # {"user_message": ...}, and embedding the text directly would make
        # any literal "{" or "}" in user input be parsed as template
        # variables, crashing prompt formatting.
        self.prompt = ChatPromptTemplate.from_messages([
            ("ai", self.prompt_content),
            ("human", "{user_message}"),
        ])
        return input_text

    @property
    def memory(self):
        # Memory backend supplied by the caller (may be None).
        return self.memory_instance

    @property
    def llm(self):
        """Build the LLMChain, attaching windowed conversation memory when set."""
        chain_settings = dict(
            llm=self.chat_model,
            prompt=self.prompt,
        )

        if self.memory:
            buffer_config = {
                "chat_memory": self.memory,
                "memory_key": "chat_history",
                "return_messages": True,
                # Window size is configurable; defaults to the last 5 turns.
                "k": self.config.get("memory_size", 5),
            }
            chain_settings["memory"] = ConversationBufferWindowMemory(
                **buffer_config
            )

        return LLMChain(**chain_settings)

    def postprocess(self, output):
        # The chain returns a dict; the generated completion is under "text".
        return output.get("text")
1 change: 1 addition & 0 deletions dialog_lib/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from dialog_lib.db.memory import *
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ def test_abstract_agent_get_prompt():
"temperature": 0.5,
}
agent = AbstractLLM(config=config)
with pytest.raises(NotImplementedError):
agent.get_prompt(input="Hello")

def test_abstract_agent_memory():
config = {
Expand Down Expand Up @@ -61,7 +59,7 @@ def test_abstract_agent_generate_prompt():
prompt = agent.generate_prompt(text="Hello")
assert prompt == "Hello"

def test_abstract_agent_process():
def test_abstract_agent_process(mocker):
config = {
"model": "gpt3.5-turbo",
"temperature": 0.5,
Expand All @@ -70,6 +68,9 @@ def test_abstract_agent_process():
}
}
agent = AbstractLLM(config=config)

agent.relevant_contents = [] # sample mock
mocker.patch('dialog_lib.agents.abstract.AbstractLLM.llm')
mocker.patch('dialog_lib.agents.abstract.AbstractLLM.llm.invoke', return_value={'text': '404 Not Found'})
output = agent.process(input="Hello")
assert output == {'text': '404 Not Found'}
2 changes: 0 additions & 2 deletions docs/abstract-llm.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ Here we will explain more about the abstract llm class.

## Available Methods

### get_prompt(self, input)

### memory(self)

### llm(self)
Expand Down
11 changes: 11 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ filterwarnings =
pythonpath =
/app/dialog_lib/

ignore =
.git
.github
.idea
.vscode
__pycache__
.pytest_cache
samples/
tests/


[tool.pytest.ini_options]
log_cli = true
log_cli_level = "INFO"
Expand Down
32 changes: 32 additions & 0 deletions samples/openai/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os

from sqlalchemy import create_engine
from dialog_lib.agents import DialogOpenAI
from dialog_lib.memory import generate_memory_instance


# Sample chat loop: persists conversation memory in Postgres and chats with
# an OpenAI-backed agent until interrupted.
database_url = "postgresql://talkdai:talkdai@db:5432/test_talkdai"

engine = create_engine(database_url)

dbsession = engine.connect()


memory = generate_memory_instance(
    session_id="test_session",
    dbsession=dbsession,
    database_url=database_url,
)

agent = DialogOpenAI(
    model="gpt-3.5-turbo",
    temperature=0.1,
    # DialogOpenAI.__init__ pops "openai_api_key" (not "llm_api_key"); the
    # old keyword was silently ignored, so the key never reached ChatOpenAI.
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
    prompt="You are a bot called Sara. Be nice to other human beings.",
    memory=memory,
)

try:
    while True:
        input_text = input("You: ")
        output_text = agent.process(input_text)
        print(f"Sara: {output_text}")
except (KeyboardInterrupt, EOFError):
    # Exit cleanly on Ctrl-C / Ctrl-D instead of dumping a traceback.
    pass

0 comments on commit 64ad092

Please sign in to comment.