Skip to content

Commit

Permalink
fix(io): add ensembl batch 429 and 404 error handling
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Apr 3, 2024
1 parent a97feac commit 942a175
Showing 1 changed file with 44 additions and 10 deletions.
54 changes: 44 additions & 10 deletions src/pyrovelocity/io/ensembl_batch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import random
from os import PathLike
from pathlib import Path

Expand All @@ -26,15 +27,23 @@

logger = configure_logging(__name__)

CONCURRENCY_LIMIT = 30
CONCURRENCY_LIMIT = 15

BatchResult = List[Result[dict, str]]


@beartype
async def backoff(retry_count: int) -> None:
    """
    Sleep for an exponentially growing, jittered delay before a retry.

    The delay is ``2**retry_count`` seconds plus a uniform random jitter in
    ``[0, 1]`` seconds, capped at 60 seconds.

    Args:
        retry_count: Number of retries already attempted; 0 for the first
            retry.
    """
    # Jitter spreads out retries from concurrent tasks so they do not all
    # hit the server again at the same instant.
    delay = min(60, (2**retry_count) + random.uniform(0, 1))
    # Use the module logger rather than print so the message follows the
    # same logging configuration as the other retry/error messages.
    logger.warning(f"Waiting for {delay} seconds before retrying...")
    await anyio.sleep(delay)


@beartype
async def fetch_sequences_batch(
gene_ids: List[str],
client: AsyncClient,
retry_count: int = 0,
) -> BatchResult:
"""
Fetch a batch of genomic sequences for a given list of gene or transcript IDs from Ensembl.
Expand All @@ -52,14 +61,39 @@ async def fetch_sequences_batch(
"Accept": "application/json",
}
payload = json.dumps({"ids": gene_ids})
max_retries = 3

response = await client.post(url, data=payload, headers=headers)
if response.status_code == 200:
return [Success(item) for item in response.json()]
else:
error_message = f"Failed to fetch batch: {response.status_code}"
logger.error(error_message)
return [Failure(error_message)] * len(gene_ids)
try:
response = await client.post(url, data=payload, headers=headers)
if response.status_code == 200:
return [Success(item) for item in response.json()]
elif response.status_code == 429 and retry_count < max_retries:
await backoff(retry_count)
return await fetch_sequences_batch(
gene_ids, client, retry_count + 1
)
elif response.status_code == 404:
logger.error(f"404 Not Found for IDs: {', '.join(gene_ids)}")
return [
Failure(f"404 Not Found for ID: {gene_id}")
for gene_id in gene_ids
]
else:
error_message = f"Failed to fetch batch: {response.status_code}"
logger.error(error_message)
return [Failure(error_message)] * len(gene_ids)
except httpx.ReadTimeout:
if retry_count < max_retries:
logger.warning(
f"Timeout encountered. Retrying... Attempt {retry_count+1}/{max_retries}"
)
await backoff(retry_count)
return await fetch_sequences_batch(
gene_ids, client, retry_count + 1
)
else:
logger.error("Max retries reached. Failing...")
return [Failure("ReadTimeout")] * len(gene_ids)


@beartype
Expand All @@ -79,7 +113,7 @@ async def main(
semaphore = Semaphore(CONCURRENCY_LIMIT)

async with httpx.AsyncClient(
timeout=10.0
timeout=30.0
) as client, anyio.create_task_group() as tg:
for i in range(0, len(to_fetch), query_batch_size):
batch_ids = to_fetch[i : i + query_batch_size]
Expand Down Expand Up @@ -166,7 +200,7 @@ def fetch_gene_sequences_batch(
os.makedirs(cache_dir, exist_ok=True)
parquet_file_path = os.path.join(cache_dir, f"{parquet_file_name}.parquet")

cache = Cache(cache_path)
cache = Cache(directory=cache_path, size_limit=int(4e9))
try:
table = anyio.run(
main, gene_id_list, cache, query_batch_size, backend="trio"
Expand Down

0 comments on commit 942a175

Please sign in to comment.