Skip to content

Commit

Permalink
Fix renaming in tests that somehow were missed
Browse files Browse the repository at this point in the history
  • Loading branch information
rogeriochaves committed Aug 4, 2023
1 parent 9ec0e69 commit 1ab8a96
Show file tree
Hide file tree
Showing 4 changed files with 1,110 additions and 42 deletions.
1 change: 0 additions & 1 deletion .gitignore
Expand Up @@ -7,7 +7,6 @@ docs/static/reference
.chainlit
chainlit.md
!docs/**/chainlit.md
test*.py
build/
langstream.egg-info/
.chroma
Expand Down
62 changes: 62 additions & 0 deletions tests/contrib/llms/test_gpt4all_stream.py
@@ -0,0 +1,62 @@
import unittest
from typing import Any, AsyncGenerator, List

import pytest

from langstream import Stream, StreamOutput, as_async_generator, debug, join_final_output
from langstream.contrib.llms.gpt4all_stream import GPT4AllStream


class GPT4AllStreamTestCase(unittest.IsolatedAsyncioTestCase):
    """Integration tests for GPT4AllStream.

    These tests run a real local GPT4All model (orca-mini-3b.ggmlv3.q4_0.bin),
    so they are marked with ``@pytest.mark.integration`` and require the model
    weights to be available on the machine running the suite.
    """

    @pytest.mark.integration
    async def test_it_completes_a_simple_prompt(self):
        """A simple greeting prompt produces the expected assistant reply."""
        stream = debug(
            GPT4AllStream[str, str](
                "GreetingStream",
                lambda name: f"### User: Hello, my name is {name}. How is it going?\n\n### Response:",
                model="orca-mini-3b.ggmlv3.q4_0.bin",
                temperature=0,
            )
        )

        result = await join_final_output(stream("Alice"))
        self.assertIn("I'm doing well, thank you for asking!", result)

    @pytest.mark.integration
    @pytest.mark.skip(
        "parallelization will not work with GPT4All because we create only one instance for it in the main thread and don't multiply memory, this is probably what users will want too to use all the cores, so skipping this for now"
    )
    async def test_it_is_non_blocking(self):
        """Four parallel generations should each stream their tokens independently.

        Currently skipped: a single GPT4All instance lives on the main thread,
        so parallel calls serialize instead of interleaving (see skip reason).
        """
        async_stream = GPT4AllStream[str, str](
            "AsyncStream",
            # plain string: the prompt has no interpolated values (was a
            # placeholder-less f-string)
            lambda _: "to make a function asynchronous in js, use the keyword `",
            model="orca-mini-3b.ggmlv3.q4_0.bin",
            max_tokens=2,
            temperature=0,
        )

        parallel_stream: Stream[str, List[List[str]]] = Stream[
            str, AsyncGenerator[StreamOutput[str, Any], None]
        ](
            "ParallelStream",
            lambda input: as_async_generator(
                async_stream(input),
                async_stream(input),
                async_stream(input),
                async_stream(input),
            ),
        ).gather()

        async for output in parallel_stream("Alice"):
            if isinstance(output.data, str):
                print(output.data)
            if output.final:
                self.assertEqual(
                    output.data,
                    [
                        ["as", "ync"],
                        ["as", "ync"],
                        ["as", "ync"],
                        ["as", "ync"],
                    ],
                )
81 changes: 40 additions & 41 deletions tests/contrib/llms/test_open_ai.py
@@ -1,7 +1,6 @@
import json
import unittest
from typing import (
Any,
AsyncGenerator,
List,
Literal,
Expand All @@ -11,60 +10,60 @@

import pytest

from litechain.core.chain import Chain, ChainOutput
from litechain.utils.async_generator import as_async_generator
from litechain.contrib.llms.open_ai import (
from langstream.core.stream import Stream, StreamOutput
from langstream.utils.async_generator import as_async_generator
from langstream.contrib.llms.open_ai import (
OpenAIChatDelta,
OpenAIChatMessage,
OpenAICompletionChain,
OpenAIChatChain,
OpenAICompletionStream,
OpenAIChatStream,
)
from litechain.utils.chain import collect_final_output, debug, join_final_output
from langstream.utils.stream import collect_final_output, debug, join_final_output


class OpenAICompletionChainTestCase(unittest.IsolatedAsyncioTestCase):
class OpenAICompletionStreamTestCase(unittest.IsolatedAsyncioTestCase):
@pytest.mark.integration
async def test_it_completes_a_simple_prompt(self):
chain = debug(
OpenAICompletionChain[str, str](
"GreetingChain",
stream = debug(
OpenAICompletionStream[str, str](
"GreetingStream",
lambda name: f"Human: Hello, my name is {name}\nAssistant: ",
model="text-ada-001",
temperature=0,
)
)

result = await join_final_output(chain("Alice"))
result = await join_final_output(stream("Alice"))
self.assertIn("I am an assistant", result)

@pytest.mark.integration
@pytest.mark.timeout(
1 # if due to some bug it ends up being blocking, then it will break this threshold
)
async def test_it_is_non_blocking(self):
async_chain = debug(
OpenAICompletionChain[str, str](
"AsyncChain",
async_stream = debug(
OpenAICompletionStream[str, str](
"AsyncStream",
lambda _: f"Say async. Assistant: \n",
model="text-ada-001",
max_tokens=2,
temperature=0,
)
)

parallel_chain: Chain[str, List[List[str]]] = Chain[
str, AsyncGenerator[ChainOutput[str], None]
parallel_stream: Stream[str, List[List[str]]] = Stream[
str, AsyncGenerator[StreamOutput[str], None]
](
"ParallelChain",
"ParallelStream",
lambda input: as_async_generator(
async_chain(input),
async_chain(input),
async_chain(input),
async_chain(input),
async_stream(input),
async_stream(input),
async_stream(input),
async_stream(input),
),
).gather()

result = await collect_final_output(parallel_chain("Alice"))
result = await collect_final_output(parallel_stream("Alice"))
self.assertEqual(
result,
[
Expand All @@ -78,11 +77,11 @@ async def test_it_is_non_blocking(self):
)


class OpenAIChatChainTestCase(unittest.IsolatedAsyncioTestCase):
class OpenAIChatStreamTestCase(unittest.IsolatedAsyncioTestCase):
@pytest.mark.integration
async def test_it_completes_a_simple_prompt(self):
chain = OpenAIChatChain[str, OpenAIChatDelta](
"GreetingChain",
stream = OpenAIChatStream[str, OpenAIChatDelta](
"GreetingStream",
lambda name: [
OpenAIChatMessage(role="user", content=f"Hello, my name is {name}")
],
Expand All @@ -91,7 +90,7 @@ async def test_it_completes_a_simple_prompt(self):
)

result = ""
async for output in chain("Alice"):
async for output in stream("Alice"):
print(output.data.content, end="", flush=True)
result += output.data.content
self.assertIn("Hello Alice! How can I assist you today?", result)
Expand All @@ -116,9 +115,9 @@ def update_delta_on_memory(delta: OpenAIChatDelta) -> OpenAIChatDelta:
memory["history"][-1].content += delta.content
return delta

chain = debug(
OpenAIChatChain[str, OpenAIChatDelta](
"EmojiChatChain",
stream = debug(
OpenAIChatStream[str, OpenAIChatDelta](
"EmojiChatStream",
lambda user_message: [
*memory["history"],
save_message_to_memory(
Expand All @@ -133,12 +132,12 @@ def update_delta_on_memory(delta: OpenAIChatDelta) -> OpenAIChatDelta:
).map(update_delta_on_memory)

outputs = await collect_final_output(
chain("Hey there, my name is 馃Ж how is it going?")
stream("Hey there, my name is 馃Ж how is it going?")
)
result = "".join([output.content for output in outputs])
self.assertIn("馃憢馃Ж", result)

outputs = await collect_final_output(chain("What is my name?"))
outputs = await collect_final_output(stream("What is my name?"))
result = "".join([output.content for output in outputs])
self.assertIn("馃Ж", result)

Expand All @@ -160,9 +159,9 @@ def get_current_weather(
temperature="25 C" if format == "celsius" else "77 F",
)

chain: Chain[str, Union[OpenAIChatDelta, WeatherReturn]] = debug(
OpenAIChatChain[str, OpenAIChatDelta](
"WeatherChain",
stream: Stream[str, Union[OpenAIChatDelta, WeatherReturn]] = debug(
OpenAIChatStream[str, OpenAIChatDelta](
"WeatherStream",
lambda user_input: [
OpenAIChatMessage(role="user", content=user_input),
],
Expand Down Expand Up @@ -197,7 +196,7 @@ def get_current_weather(
)

outputs = await collect_final_output(
chain(
stream(
"I'm in my appartment in Amsterdam, thinking... should I take an umbrella for my pet chicken?"
)
)
Expand Down Expand Up @@ -246,10 +245,10 @@ def update_delta_on_memory(delta: OpenAIChatDelta) -> OpenAIChatDelta:
memory["history"][-1].content += delta.content
return delta

chain = (
stream = (
debug(
OpenAIChatChain[str, OpenAIChatDelta](
"WeatherChain",
OpenAIChatStream[str, OpenAIChatDelta](
"WeatherStream",
lambda user_input: [
*memory["history"],
save_message_to_memory(
Expand Down Expand Up @@ -290,7 +289,7 @@ def update_delta_on_memory(delta: OpenAIChatDelta) -> OpenAIChatDelta:
)

outputs = await collect_final_output(
chain("What is the weather today in amsterdam?")
stream("What is the weather today in amsterdam?")
)
self.assertEqual(
list(outputs)[0],
Expand All @@ -307,7 +306,7 @@ def update_delta_on_memory(delta: OpenAIChatDelta) -> OpenAIChatDelta:
),
)

outputs = await collect_final_output(chain("How many degrees again?"))
outputs = await collect_final_output(stream("How many degrees again?"))
result = "".join(
[
output.content
Expand Down

0 comments on commit 1ab8a96

Please sign in to comment.