Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
Refactor Python user API (LangStream#495)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Sep 28, 2023
1 parent 9ab16fc commit 728c506
Show file tree
Hide file tree
Showing 16 changed files with 495 additions and 502 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,18 @@

import base64
import io
from typing import List, Optional, Dict, Any
from typing import Dict, Any

import openai
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster
from langstream import Sink, CommitCallback, Record
from langstream import Sink, Record
from llama_index import VectorStoreIndex, Document
from llama_index.vector_stores import CassandraVectorStore


class LlamaIndexCassandraSink(Sink):
def __init__(self):
self.commit_cb: Optional[CommitCallback] = None
self.config = None
self.session = None
self.index = None
Expand Down Expand Up @@ -63,13 +62,8 @@ def start(self):

self.index = VectorStoreIndex.from_vector_store(vector_store)

def write(self, records: List[Record]):
for record in records:
self.index.insert(Document(text=record.value()))
self.commit_cb.commit([record])

def set_commit_callback(self, commit_callback: CommitCallback):
self.commit_cb = commit_callback
def write(self, record: Record):
self.index.insert(Document(text=record.value()))

def close(self):
if self.session:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@ def init(self, config):
print("init", config)
openai.api_key = config["openaiKey"]

def process(self, records):
processed_records = []
for record in records:
embedding = get_embedding(record.value(), engine="text-embedding-ada-002")
result = {"input": str(record.value()), "embedding": embedding}
new_value = json.dumps(result)
processed_records.append((record, [(new_value,)]))
return processed_records
def process(self, record):
embedding = get_embedding(record.value(), engine="text-embedding-ada-002")
result = {"input": str(record.value()), "embedding": embedding}
new_value = json.dumps(result)
return [(new_value,)]
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
# limitations under the License.
#

from langstream import SimpleRecord, SingleRecordProcessor
from langstream import SimpleRecord, Processor


# Example Python processor that adds an exclamation mark to the end of the record value
class Exclamation(SingleRecordProcessor):
def process_record(self, record):
class Exclamation(Processor):
def process(self, record):
return [SimpleRecord(record.value() + "!!", headers=record.headers())]
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
# limitations under the License.
#

from langstream import SimpleRecord, SingleRecordProcessor
from langstream import SimpleRecord, Processor


class Exclamation(SingleRecordProcessor):
class Exclamation(Processor):
def init(self, config):
print("init", config)
self.secret_value = config["secret_value"]

def process_record(self, record):
def process(self, record):
return [
SimpleRecord(
record.value() + "!!" + self.secret_value, headers=record.headers()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,19 @@

class TestSink(object):
def __init__(self):
self.commit_callback = None
self.producer = None

def init(self, config):
logging.info("Init config: " + str(config))
self.producer = Producer({"bootstrap.servers": config["bootstrapServers"]})

def write(self, records):
logging.info("Write records: " + str(records))
def write(self, record):
logging.info("Write record: " + str(record))
try:
for record in records:
self.producer.produce(
"ls-test-output", value=("write: " + record.value()).encode("utf-8")
)
self.producer.produce(
"ls-test-output", value=("write: " + record.value()).encode("utf-8")
)
self.producer.flush()
self.commit_callback.commit(records)
except Exception as e:
logging.error("Error writing records: " + str(e))
raise e

def set_commit_callback(self, commit_callback):
self.commit_callback = commit_callback
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,9 @@ def read(self):
print(f"read {records}")
return records

def set_commit_callback(self, cb):
pass
def process(self, record):
print(f"process {record}")
return [record]

def process(self, records):
print(f"process {records}")
return [(record, [record]) for record in records]

def write(self, records):
print(f"write {records}")
def write(self, record):
print(f"write {record}")
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
Sink,
Source,
Processor,
CommitCallback,
)
from .util import SimpleRecord, SingleRecordProcessor, AvroValue
from .util import SimpleRecord, AvroValue

__all__ = [
"Record",
Expand All @@ -33,8 +32,6 @@
"Source",
"Sink",
"Processor",
"CommitCallback",
"SimpleRecord",
"SingleRecordProcessor",
"AvroValue",
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
#

from abc import ABC, abstractmethod
from typing import Any, List, Tuple, Dict, Union
from concurrent.futures import Future
from typing import Any, List, Tuple, Dict, Union, Optional

__all__ = [
"Record",
Expand All @@ -25,7 +26,6 @@
"Source",
"Sink",
"Processor",
"CommitCallback",
]


Expand Down Expand Up @@ -91,9 +91,9 @@ class Source(Agent):
def read(self) -> List[RecordType]:
"""The Source agent generates records and returns them as list of records.
:returns: the list of records. The records must either respect the Record
API contract (have methods value(), key() and so on) or be a dict or
tuples/list.
:returns: the list of records.
The records must either respect the Record API contract (have methods value(),
key() and so on) or be a dict or tuples/list.
If the records are dict, the keys if present shall be "value", "key",
"headers", "origin" and "timestamp".
Eg:
Expand All @@ -108,15 +108,15 @@ def read(self) -> List[RecordType]:
"""
pass

def commit(self, records: List[Record]):
"""Called by the framework to indicate the records that have been successfully
def commit(self, record: Record):
"""Called by the framework to indicate that a record has been successfully
processed."""
pass

def permanent_failure(self, record: Record, error: Exception):
"""Called by the framework to indicate that the agent has permanently failed to
process the record.
The Source agent may send the records to a dead letter queue or raise an error.
process a record.
The Source agent may send the record to a dead letter queue or raise an error.
"""
raise error

Expand All @@ -129,61 +129,44 @@ class Processor(Agent):

@abstractmethod
def process(
self, records: List[Record]
) -> List[Tuple[Record, Union[List[RecordType], Exception]]]:
"""The agent processes records and returns a list containing the associations of
these records with the result of these record processing.
The result of each record processing is a list of new records or an exception.
The transactionality of the function is guaranteed by the runtime.
:returns: the list of associations between an input record and the output
records processed from it.
Eg: [(input_record, [output_record1, output_record2])]
If an input record cannot be processed, the associated element shall be an
exception.
Eg: [(input_record, RuntimeError("Could not process"))]
self, record: Record
) -> Union[List[RecordType], Future[List[RecordType]]]:
"""The agent processes a record and returns a list of new records.
:returns: the list of records or a concurrent.futures.Future that will complete
with the list of records.
When the processing is successful, the output records must either respect the
Record API contract (have methods value(), key() and so on) or be a dict or
tuples/list.
If the records are dict, the keys if present shall be "value", "key",
"headers", "origin" and "timestamp".
Eg:
* if you return [(input_record, [{"value": "foo"}])] a record
Record(value="foo") will be built.
* if you return {"value": "foo"} a record Record(value="foo") will be built.
If the output records are tuples/list, the framework will automatically
construct Record objects from them with the values in the following order :
value, key, headers, origin, timestamp.
Eg:
* if you return [(input_record, [("foo",)])] a record Record(value="foo") will
be built.
* if you return [(input_record, [("foo", "bar")])] a record
Record(value="foo", key="bar") will be built.
* if you return ("foo",) a record Record(value="foo") will be built.
* if you return ("foo", "bar") a record Record(value="foo", key="bar") will be
built.
"""
pass


class CommitCallback(ABC):
@abstractmethod
def commit(self, records: List[Record]):
"""Called by a Sink to indicate the records that have been successfully
written."""
pass


class Sink(Agent):
"""The Sink agent interface
A Sink agent is used by the runtime to write Records.
"""

@abstractmethod
def write(self, records: List[Record]):
def write(self, record: Record) -> Optional[Future[None]]:
"""The Sink agent receives records from the framework and typically writes them
to an external service."""
pass
to an external service.
For a synchronous result, return None/nothing if successful or otherwise raise
an Exception.
For an asynchronous result, return a concurrent.futures.Future.
@abstractmethod
def set_commit_callback(self, commit_callback: CommitCallback):
"""Called by the framework to specify a CommitCallback that shall be used by the
Sink to indicate the records that have been written."""
:returns: nothing if the write is successful or a concurrent.futures.Future
"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
# limitations under the License.
#

from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, List, Tuple, Union
from typing import Any, List, Tuple

from .api import Record, Processor, RecordType
from .api import Record

__all__ = ["SimpleRecord", "SingleRecordProcessor", "AvroValue"]
__all__ = ["SimpleRecord", "AvroValue"]


class SimpleRecord(Record):
Expand Down Expand Up @@ -57,46 +56,15 @@ def timestamp(self) -> int:

def __str__(self):
return (
f"Record(value={self._value}, key={self._key}, origin={self._origin}, "
f"timestamp={self._timestamp}, headers={self._headers})"
f"SimpleRecord(value={self._value}, key={self._key}, "
f"origin={self._origin},timestamp={self._timestamp}, "
f"headers={self._headers})"
)

def __repr__(self):
return self.__str__()


class SingleRecordProcessor(Processor):
"""A Processor that processes records one-by-one"""

@abstractmethod
def process_record(self, record: Record) -> List[RecordType]:
"""Process one record and return a list of records or raise an exception.
:returns: the list of processed records. The records must either respect the
Record API contract (have methods value(), key() and so on) or be tuples/list.
If the records are tuples/list, the framework will automatically construct
Record objects from them with the values in the following order : value, key,
headers, origin, timestamp.
Eg:
* if you return [("foo",)] a record Record(value="foo") will be built.
* if you return [("foo", "bar")] a record Record(value="foo", key="bar") will
be built.
"""
pass

def process(
self, records: List[Record]
) -> List[Tuple[Record, Union[List[RecordType], Exception]]]:
results = []
for record in records:
try:
processed = self.process_record(record)
results.append((record, processed))
except Exception as e:
results.append((record, e))
return results


@dataclass
class AvroValue(object):
schema: dict
Expand Down
Loading

0 comments on commit 728c506

Please sign in to comment.