Skip to content

Commit

Permalink
Merge pull request #81 from vbalalian/airbyte-integration
Browse files Browse the repository at this point in the history
Fix Airbyte custom source connector state retention
  • Loading branch information
vbalalian authored Jan 6, 2024
2 parents 47f5f6f + 30c2a06 commit ddef598
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deploy-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Deploy API

on:
push:
branches: [ master ]
branches: [master]
paths:
- 'api/**'

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/deploy-custom-airbyte-connector.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Deploy Custom Airbyte Connector

on:
push:
branches: [ master ]
branches: [master]
paths:
- 'custom-airbyte-connector/**'

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/deploy-web-scraper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Deploy Web Scraper

on:
push:
branches: [ master ]
branches: [master]
paths:
- 'web_scraping/**'

Expand Down
16 changes: 5 additions & 11 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,10 @@ async def read_coins(
items_per_page=page_size
)

if coins:
return PaginatedResponse(
data = [dict(row) for row in coins],
pagination=pagination
)
else:
raise HTTPException(status_code=400, detail='No matching coins found')
return PaginatedResponse(
data = [dict(row) for row in coins] if coins else [],
pagination=pagination
)

except psycopg2.Error as e:
print('Database error:', e)
Expand All @@ -272,10 +269,7 @@ async def search_coins(
print('Search error:', e)
finally:
cur.close()
if search_result:
return [dict(row) for row in search_result]
else:
raise HTTPException(status_code=400, detail='No matching coins found')
return [dict(row) for row in search_result] if search_result else []

# Coins by ID endpoint
@app.get('/v1/coins/id/{coin_id}', response_model=Coin, response_model_exclude_none=True)
Expand Down
7 changes: 4 additions & 3 deletions api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def test_read_coins(test_client, test_database):
assert len(response.json()["data"]) == 1

response = test_client.get(r"/v1/coins/?min_diameter=30")
assert response.status_code == 400
assert response.status_code == 200
assert len(response.json()["data"]) == 0

# Coin Search endpoint
def test_search_coins(test_client, test_database):
Expand All @@ -182,8 +183,8 @@ def test_search_coins(test_client, test_database):

# Query with no results
response = test_client.get(r"/v1/coins/search?query=Caligula")
assert response.status_code == 400
assert response.json()["detail"]== "No matching coins found"
assert response.status_code == 200
assert len(response.json()) == 0

# Empty query
response = test_client.get(r"/v1/coins/search?query=")
Expand Down
72 changes: 31 additions & 41 deletions custom-airbyte-connector/source_roman_coin_api/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
from datetime import datetime, timedelta
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple
import requests

from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams import Stream, IncrementalMixin
from airbyte_cdk.sources.streams.http import HttpStream
# from airbyte_cdk.sources.streams.http.auth import TokenAuthenticator # Authentication not currently implemented

url_base = "http://host.docker.internal:8010/v1/"

# Incremental stream
class RomanCoinApiStream(HttpStream):
class RomanCoinApiStream(HttpStream, IncrementalMixin):

# Save the state every 100 records
state_checkpoint_interval = 100

url_base = url_base
cursor_field = "modified"
Expand All @@ -23,35 +25,19 @@ class RomanCoinApiStream(HttpStream):
def __init__(self, config:Mapping[str, Any], start_date:datetime, **kwargs):
super().__init__()
self.start_date = start_date
self._cursor_value = datetime.min

def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]) -> Mapping[str, any]:
latest_state = max(latest_record[self.cursor_field], current_stream_state.get(self.cursor_field, self.start_date.strftime("%Y-%m-%d")))
return {self.cursor_field: latest_state}
self._cursor_value = datetime.strptime(start_date, "%Y-%m-%dT%H:%M:%S.%f") if isinstance(start_date, str) else start_date

def _chunk_date_range(self, start_date: datetime) -> List[Mapping[str, Any]]:
"""
Returns a list of each day between the start date and now.
The return value is a list of dicts {'date': date_string}.
"""
dates = []
while start_date < datetime.now():
dates.append({self.cursor_field: start_date.strftime('%Y-%m-%d')})
start_date += timedelta(days=1)
return dates
@property
def state(self) -> Mapping[str, Any]:
return {self.cursor_field: self._cursor_value.strftime("%Y-%m-%dT%H:%M:%S.%f")}

def stream_slices(self, sync_mode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None) -> Iterable[Optional[Mapping[str, Any]]]:
start_date = datetime.strptime(stream_state[self.cursor_field], '%Y-%m-%d') if stream_state and self.cursor_field in stream_state else self.start_date
return self._chunk_date_range(start_date)
@state.setter
def state(self, value: Mapping[str, Any]):
self._cursor_value = datetime.strptime(value[self.cursor_field], "%Y-%m-%dT%H:%M:%S.%f")

def path(
self,
stream_state: Mapping[str, Any] = None,
stream_slice: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None
) -> str:
def path(self, stream_state: Mapping[str, Any] = None, stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None) -> str:
return "coins/"

def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
json_response = response.json()
pagination_info = json_response.get("pagination", {})
Expand All @@ -65,25 +51,29 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str,

def parse_response(self, response: requests.Response, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None) -> Iterable[Mapping]:
json_response = response.json()
records = json_response.get('data', []) # Extract records from 'data' key

records = json_response.get('data', [])
for record in records:
record_modified = datetime.fromisoformat(record[self.cursor_field])
self._cursor_value = max(self._cursor_value, record_modified)
yield record

if self._cursor_value:
self.state = {self.cursor_field: self._cursor_value.isoformat()}

def read_records(self, *args, **kwargs) -> Iterable[Mapping[str, Any]]:
for record in super().read_records(*args, **kwargs):
record_cursor_value = datetime.strptime(record[self.cursor_field], "%Y-%m-%dT%H:%M:%S.%f")
if record_cursor_value > self._cursor_value:
yield record
self._cursor_value = record_cursor_value

def request_params(self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None) -> MutableMapping[str, Any]:
params = {
"page": next_page_token["page"] if next_page_token else 1,
"page_size": 10,
"sort_by": "modified",
"desc": True
"page_size": 100,
"sort_by": "modified"
}
if stream_state and self.cursor_field in stream_state:
params["start_modified"] = stream_state[self.cursor_field]
if stream_state:
last_synced_time = datetime.strptime(stream_state[self.cursor_field], "%Y-%m-%dT%H:%M:%S.%f")
next_start_time = last_synced_time + timedelta(microseconds=1)
params["start_modified"] = next_start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")
else:
params["start_modified"] = self.start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")
return params

# Source
Expand All @@ -97,5 +87,5 @@ def check_connection(self, logger, config) -> Tuple[bool, any]:
return False, f"Connection check failed: {e}"

def streams(self, config: Mapping[str, Any]) -> List[Stream]:
start_date = datetime.strptime(config["start_date"], '%Y-%m-%d')
start_date = datetime.strptime(config["start_date"], "%Y-%m-%d")
return [RomanCoinApiStream(config=config, start_date=start_date)]

0 comments on commit ddef598

Please sign in to comment.