Skip to content

Commit

Permalink
custom api transform support (#843)
Browse files Browse the repository at this point in the history
* custom api transform support

* params
  • Loading branch information
Abhinav-Naikawadi committed Jun 4, 2024
1 parent 0a0d2b3 commit 26ccd92
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/autolabel/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .pdf import PDFTransform
from .serp_api import SerpApi
from .serper_api import SerperApi
from .custom_api import CustomApi
from .webpage_transform import WebpageTransform
from .webpage_scrape import WebpageScrape
from .image import ImageTransform
Expand All @@ -20,6 +21,7 @@
TransformType.IMAGE: ImageTransform,
TransformType.WEB_SEARCH_SERP_API: SerpApi,
TransformType.WEB_SEARCH_SERPER: SerperApi,
TransformType.CUSTOM_API: CustomApi,
}


Expand Down
134 changes: 134 additions & 0 deletions src/autolabel/transforms/custom_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import asyncio
from collections import defaultdict
import json
from urllib.parse import urlparse
from autolabel.cache import BaseCache
from autolabel.transforms import BaseTransform
from langchain_community.utilities import GoogleSerperAPIWrapper
from typing import Dict, Any, List
import logging
import pandas as pd
import ssl

from autolabel.transforms.schema import (
TransformError,
TransformErrorType,
TransformType,
)

logger = logging.getLogger(__name__)

MAX_RETRIES = 5
MAX_KEEPALIVE_CONNECTIONS = 20
CONNECTION_TIMEOUT = 10
MAX_CONNECTIONS = 100
BACKOFF = 2
HEADERS = {}


class CustomApi(BaseTransform):
    """Transform that fetches each row's result from a user-configured HTTP API.

    For every input row, the values of ``request_columns`` are sent as query
    parameters in a GET request to ``base_url``; the raw response body text is
    stored under the configured ``result`` output column.
    """

    COLUMN_NAMES = ["result"]

    def __init__(
        self,
        cache: BaseCache,
        output_columns: Dict[str, Any],
        base_url: str,
        request_columns: List[str],
        headers: Dict[str, str] = HEADERS,
        timeout: int = 60,
    ) -> None:
        """Initialize the transform.

        Args:
            cache: Cache backend shared with the rest of the pipeline.
            output_columns: Mapping of logical column names (here: "result")
                to the dataset column names to write.
            base_url: API endpoint. If no scheme is given, "https://" is assumed.
            request_columns: Row columns forwarded as query parameters.
            headers: Extra HTTP headers; a random User-Agent is added if absent.
            timeout: Total request timeout in seconds.

        Raises:
            ImportError: If httpx or fake_useragent is not installed.
        """
        super().__init__(cache, output_columns)
        self.request_columns = request_columns
        # Default to https when the caller omitted the scheme.
        if not urlparse(base_url).scheme:
            base_url = f"https://{base_url}"
        self.base_url = base_url
        # Copy so we never mutate the caller's dict or the shared module-level
        # HEADERS default (classic mutable-default-argument pitfall: the
        # original code wrote the User-Agent into the shared global).
        self.headers = dict(headers)
        self.max_retries = MAX_RETRIES
        try:
            import httpx

            if not self.headers.get("User-Agent"):
                from fake_useragent import UserAgent

                self.headers["User-Agent"] = UserAgent().random

            self.httpx = httpx
            self.timeout_time = timeout
            self.timeout = httpx.Timeout(connect=CONNECTION_TIMEOUT, timeout=timeout)
            limits = httpx.Limits(
                max_keepalive_connections=MAX_KEEPALIVE_CONNECTIONS,
                max_connections=MAX_CONNECTIONS,
                keepalive_expiry=timeout,
            )
            self.client = httpx.AsyncClient(
                timeout=self.timeout, limits=limits, follow_redirects=True
            )
            # Fallback client used when certificate verification fails.
            self.client_with_no_verify = httpx.AsyncClient(
                timeout=self.timeout, limits=limits, follow_redirects=True, verify=False
            )
        except ImportError:
            raise ImportError(
                "httpx and fake_useragent are required to use the custom API transform. Please install them with the following command: pip install httpx fake_useragent"
            )

    def name(self) -> str:
        """Return the transform's type identifier."""
        return TransformType.CUSTOM_API

    async def _get_result(
        self, url: str, params: Dict, verify=True, headers=None, retry_count=0
    ) -> str:
        """GET ``url`` with ``params`` and return the response body text.

        Retries with verification disabled (exponential backoff) on SSL
        certificate errors, up to ``self.max_retries`` attempts.

        Raises:
            TransformError: On timeout or when retries are exhausted.
        """
        if headers is None:
            # Fall back to the instance headers (which include the User-Agent
            # set in __init__) instead of the empty module-level default.
            headers = self.headers
        if retry_count >= self.max_retries:
            logger.warning(f"Max retries reached for URL: {url}")
            raise TransformError(
                TransformErrorType.MAX_RETRIES_REACHED,
                f"Max retries reached for URL: {url}",
            )

        try:
            client = self.client
            if not verify:
                client = self.client_with_no_verify
            response = await client.get(url, headers=headers, params=params)
            # Surface HTTP 4xx/5xx as exceptions rather than returning error bodies.
            response.raise_for_status()
            return response.text
        except self.httpx.ConnectTimeout as e:
            logger.error(f"Timeout when fetching content from URL: {url}")
            raise TransformError(
                TransformErrorType.TRANSFORM_TIMEOUT,
                "Timeout when fetching content from URL",
            )
        except ssl.SSLCertVerificationError as e:
            # NOTE(review): httpx usually wraps SSL failures in
            # httpx.ConnectError; confirm this branch is reachable as written.
            logger.warning(
                f"SSL verification error when fetching content from URL: {url}, retrying with verify=False"
            )
            await asyncio.sleep(BACKOFF**retry_count)
            # Bug fix: the original recursive call omitted the required
            # ``params`` argument, raising TypeError on every SSL retry.
            return await self._get_result(
                url, params, verify=False, headers=headers, retry_count=retry_count + 1
            )
        except Exception as e:
            logger.error(f"Error fetching content from URL: {url}. Exception: {e}")
            raise e

    async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Build query params from ``row`` and fetch the API result for it.

        Missing request columns are logged and skipped rather than failing
        the whole row.
        """
        params = {}
        for col in self.request_columns:
            if col not in row:
                logger.error(
                    f"Missing request column: {col} in row {row}",
                )
            else:
                params[col] = row.get(col)
        result = await self._get_result(self.base_url, params)
        transformed_row = {self.output_columns["result"]: result}
        return self._return_output_row(transformed_row)

    def params(self):
        """Return the constructor parameters used for cache keying."""
        return {
            "output_columns": self.output_columns,
            "base_url": self.base_url,
            "request_columns": self.request_columns,
        }

    def input_columns(self) -> List[str]:
        """Return the row columns this transform reads."""
        return self.request_columns
1 change: 1 addition & 0 deletions src/autolabel/transforms/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class TransformType(str, Enum):
IMAGE = "image"
WEB_SEARCH_SERP_API = "web_search_serp_api"
WEB_SEARCH_SERPER = "web_search"
CUSTOM_API = "custom_api"


class TransformCacheEntry(BaseModel):
Expand Down

0 comments on commit 26ccd92

Please sign in to comment.