Skip to content

Commit

Permalink
read multiple requests files (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
samos123 committed Dec 14, 2023
1 parent fd9ab5d commit dec3d29
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 11 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ COPY requirements.txt /usr/src/app
# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

COPY main.py /usr/src/app
COPY batchelor /usr/src/app/batchelor

# Run the Python script when the container launches
# Use the environment variables to pass the arguments
CMD python main.py
CMD python batchelor/main.py
File renamed without changes.
11 changes: 2 additions & 9 deletions main.py → batchelor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from aiohttp_retry import RetryClient, ExponentialRetry
from smart_open import open

from batchelor.reader import read_file_and_enqueue

# Default completions endpoint URL, pointing at a server on localhost.
# NOTE(review): presumably the target the workers send requests to — confirm.
url = "http://localhost:8080/v1/completions"
# File name template containing a `{partition}` placeholder.
# NOTE(review): presumably filled in per partition by code outside this view.
filename = "part-{partition}.jsonl"
# Empty by default. NOTE(review): how these fields are used is not visible here.
ignore_fields = []
Expand All @@ -22,15 +24,6 @@
timeout = 1200


async def read_file_and_enqueue(path, queue: asyncio.Queue):
    """
    Read a JSON-lines file and put every parsed request on ``queue``.

    Each non-blank line of the file at ``path`` is parsed with ``json.loads``
    and enqueued in file order. Once the file is exhausted a single ``None``
    is enqueued (presumably the end-of-input marker for consumers).

    Raises ``json.JSONDecodeError`` if a non-blank line is not valid JSON.
    """
    with open(path, mode="r") as file:
        print(f"Sending request to Queue from file {path}")
        # Iterate lazily instead of readlines() so large files are streamed
        # rather than loaded into memory in one go.
        for line in file:
            # A blank line (e.g. a trailing newline) is not a request and
            # would otherwise make json.loads raise.
            if not line.strip():
                continue
            await queue.put(json.loads(line))
    await queue.put(None)


async def worker(
requests: asyncio.Queue,
results: asyncio.Queue,
Expand Down
35 changes: 35 additions & 0 deletions batchelor/reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio
import json

from google.cloud import storage
from smart_open import open


def parse_bucket(path: str) -> str:
    """
    Return the bucket component of a bucket-style URL.

    The path is split on "/" and the third piece is returned, so
    gs://bucket-name/path/to/file yields bucket-name. The scheme itself is
    not inspected.
    """
    # "gs://bucket-name/..." splits into ["gs:", "", "bucket-name", ...].
    segments = path.split("/")
    bucket_name = segments[2]
    return bucket_name


def convert_path_to_list(path: str) -> list[str]:
    """
    Expand a path into the list of file paths that should be read.

    A path starting with gs:// is treated as a prefix: every object in the
    bucket whose name starts with the part of the path after the bucket name
    is returned as a full gs://bucket-name/object-name path. Any other path
    is returned unchanged as a single-element list.
    """
    if not path.startswith("gs://"):
        return [path]

    bucket_name = parse_bucket(path)
    # list_blobs() matches `prefix` against object names, which are relative
    # to the bucket, so the leading "gs://<bucket>/" must be stripped; passing
    # the full URL would never match any object.
    prefix = path[len(f"gs://{bucket_name}/"):]
    client = storage.Client()
    return [
        f"gs://{bucket_name}/{blob.name}"
        for blob in client.list_blobs(bucket_name, prefix=prefix)
    ]


async def read_file_and_enqueue(path, queue: asyncio.Queue):
    """
    Read JSON-lines requests from every file under ``path`` and enqueue them.

    ``path`` is expanded with ``convert_path_to_list``; each non-blank line of
    each resulting file is parsed with ``json.loads`` and put on ``queue`` in
    order. After all files have been read a single ``None`` is enqueued
    (presumably the end-of-input marker for consumers).

    Raises ``json.JSONDecodeError`` if a non-blank line is not valid JSON.
    """
    # Use a distinct loop name so the `path` argument is not shadowed.
    for file_path in convert_path_to_list(path):
        with open(file_path, mode="r") as file:
            print(f"Sending request to Queue from file {file_path}")
            # Iterate lazily instead of readlines() so large files are
            # streamed rather than loaded into memory in one go.
            for line in file:
                # A blank line (e.g. a trailing newline) is not a request
                # and would otherwise make json.loads raise.
                if not line.strip():
                    continue
                await queue.put(json.loads(line))
    await queue.put(None)
1 change: 1 addition & 0 deletions tests/test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest
pytest-asyncio
pytest-httpserver
pytest-mock
52 changes: 52 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import asyncio
from dataclasses import dataclass
from unittest.mock import Mock

from google.cloud import storage
import pytest

from batchelor.reader import parse_bucket, convert_path_to_list


@pytest.fixture
def mock_client(mocker):
    """Patch google.cloud.storage.Client with an autospec mock and return it."""
    patched_client = mocker.patch("google.cloud.storage.Client", autospec=True)
    return patched_client


def test_parse_bucket(mock_client):
    """parse_bucket returns the bucket name regardless of the URL scheme."""
    cases = [
        ("gs://bucket-name/path/to/file", "bucket-name"),
        ("gcss://bucket-name/path/to/file", "bucket-name"),
    ]
    for bucket_path, expected in cases:
        assert parse_bucket(bucket_path) == expected


@dataclass
class Blob:
    """Minimal stand-in for a storage blob in tests; carries only ``name``."""

    # Object name within the bucket, e.g. "path/to/file.json".
    name: str


def test_convert_path_to_list_single(mocker, mock_client):
    """A gs:// path matching one blob expands to exactly that path."""
    path = "gs://bucket-name/path/to/file.json"
    list_blobs = mock_client.return_value.list_blobs
    list_blobs.return_value = [Blob(name="path/to/file.json")]

    output = convert_path_to_list(path)

    assert list_blobs.call_count == 1
    assert output == [path]


def test_convert_path_to_list_multiple(mocker, mock_client):
    """A gs:// prefix matching several blobs expands to one path per blob."""
    path = "gs://bucket-name/path"
    file_names = ("file1.jsonl", "file2.jsonl")
    list_blobs = mock_client.return_value.list_blobs
    list_blobs.return_value = [Blob(name=f"path/{name}") for name in file_names]

    output = convert_path_to_list(path)

    assert list_blobs.call_count == 1
    assert output == [path + "/file1.jsonl", path + "/file2.jsonl"]

0 comments on commit dec3d29

Please sign in to comment.