Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions extend/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,41 @@ async def patch(self, url: str, data: Dict) -> Any:
response.raise_for_status()
return response.json()

async def post_multipart(
self,
url: str,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any]] = None,
) -> Any:
"""Make a POST request with multipart/form-data payload.

This method is designed to support file uploads along with optional form data.

Args:
url (str): The API endpoint path (e.g., "/receiptattachments")
data (Optional[Dict[str, Any]]): Optional form fields to include in the request.
files (Optional[Dict[str, Any]]): Files to be uploaded. For example,
{"file": file_obj} where file_obj is an open file in binary mode.

Returns:
The JSON response from the API.

Raises:
httpx.HTTPError: If the request fails.
ValueError: If the response is not valid JSON.
"""
# When sending multipart data, we pass `data` (for non-file fields)
# and `files` (for file uploads) separately.
async with httpx.AsyncClient() as client:
response = await client.post(
self.build_full_url(url),
headers=self.headers,
data=data,
files=files,
timeout=httpx.Timeout(30)
)
response.raise_for_status()
return response.json()

def build_full_url(self, url: Optional[str]):
return f"https://{API_HOST}{url or ''}"
2 changes: 2 additions & 0 deletions extend/extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .client import APIClient
from .resources.credit_cards import CreditCards
from .resources.expense_data import ExpenseData
from .resources.receipt_attachments import ReceiptAttachments
from .resources.transactions import Transactions


Expand Down Expand Up @@ -31,3 +32,4 @@ def __init__(self, api_key: str, api_secret: str):
self.virtual_cards = VirtualCards(self._api_client)
self.transactions = Transactions(self._api_client)
self.expense_data = ExpenseData(self._api_client)
self.receipt_attachments = ReceiptAttachments(self._api_client)
43 changes: 43 additions & 0 deletions extend/resources/receipt_attachments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Dict, IO

from extend.client import APIClient
from .resource import Resource


class ReceiptAttachments(Resource):
@property
def _base_url(self) -> str:
return "/receiptattachments"

def __init__(self, api_client: APIClient):
super().__init__(api_client)

async def create_receipt_attachment(
self,
transaction_id: str,
file: IO,
) -> Dict:
"""Create a receipt attachment for a transaction by uploading a file using multipart form data.

Args:
transaction_id (str): The unique identifier of the transaction to attach the receipt to
file (IO): A file-like object opened in binary mode that contains the data
to be uploaded

Returns:
Dict: A dictionary representing the receipt attachment, including:
- id: Unique identifier of the receipt attachment.
- transactionId: The associated transaction ID.
- contentType: The MIME type of the uploaded file.
- urls: A dictionary with URLs for the original image, main image, and thumbnail.
- createdAt: Timestamp when the receipt attachment was created.
- uploadType: A string describing the type of upload (e.g., "TRANSACTION", "VIRTUAL_CARD").

Raises:
httpx.HTTPError: If the request fails
"""

return await self._request(
method="post_multipart",
data={"transaction_id": transaction_id},
files={"file": file})
10 changes: 9 additions & 1 deletion extend/resources/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ async def _request(
self,
method: str,
path: str = None,
params: Optional[Dict] = None
params: Optional[Dict] = None,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any]] = None,
) -> Any:
if params is not None:
params = {k: v for k, v in params.items() if v is not None}
Expand All @@ -33,6 +35,12 @@ async def _request(
return await self._api_client.put(self.build_full_path(path), params)
case "patch":
return await self._api_client.patch(self.build_full_path(path), params)
case "post_multipart":
return await self._api_client.post_multipart(
self.build_full_path(path),
data=data,
files=files
)
case _:
raise ValueError(f"Unsupported HTTP method: {method}")

Expand Down
6 changes: 5 additions & 1 deletion extend/resources/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ async def get_transactions(
min_amount_cents: Optional[int] = None,
max_amount_cents: Optional[int] = None,
search_term: Optional[str] = None,
sort_field: Optional[str] = None,
) -> Dict:
"""Get a list of transactions with optional filtering and pagination.

Expand All @@ -34,6 +35,9 @@ async def get_transactions(
min_amount_cents (int): Minimum clearing amount in cents
max_amount_cents (int): Maximum clearing amount in cents
search_term (Optional[str]): Filter cards by search term (e.g., "Marketing")
sort_field (Optional[str]): Field to sort by, with optional direction
Use "recipientName", "merchantName", "amount", "date" for ASC
Use "-recipientName", "-merchantName", "-amount", "-date" for DESC

Returns:
Dict: A dictionary containing:
Expand All @@ -57,8 +61,8 @@ async def get_transactions(
"minClearingBillingCents": min_amount_cents,
"maxClearingBillingCents": max_amount_cents,
"search": search_term,
"sort": sort_field,
}
params = {k: v for k, v in params.items() if v is not None}

return await self._request(method="get", params=params)

Expand Down
7 changes: 6 additions & 1 deletion extend/resources/virtual_cards.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ async def get_virtual_cards(
status: Optional[str] = None,
recipient: Optional[str] = None,
search_term: Optional[str] = None,
sort_field: Optional[str] = None,
sort_direction: Optional[str] = None,
) -> Dict:
"""Get a list of virtual cards with optional filtering and pagination.

Expand All @@ -30,6 +32,8 @@ async def get_virtual_cards(
status (Optional[str]): Filter cards by status (e.g., "ACTIVE", "CANCELLED")
recipient (Optional[str]): Filter cards by recipient id (e.g., "u_1234")
search_term (Optional[str]): Filter cards by search term (e.g., "Marketing")
sort_field (Optional[str]): Field to sort by "createdAt", "updatedAt", "balanceCents", "displayName", "type", or "status"
sort_direction (Optional[str]): Direction to sort (ASC or DESC)

Returns:
Dict: A dictionary containing:
Expand All @@ -49,8 +53,9 @@ async def get_virtual_cards(
"statuses": status,
"recipient": recipient,
"search": search_term,
"sortField": sort_field,
"sortDirection": sort_direction,
}
params = {k: v for k, v in params.items() if v is not None}

return await self._request(method="get", params=params)

Expand Down
182 changes: 182 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import uuid
from datetime import datetime, timedelta
from io import BytesIO

import pytest
from dotenv import load_dotenv
Expand Down Expand Up @@ -129,6 +130,47 @@ async def test_list_virtual_cards(self, extend):
for card in response["virtualCards"]:
assert card["status"] == "CLOSED"

@pytest.mark.asyncio
async def test_list_virtual_cards_with_sorting(self, extend):
"""Test listing virtual cards with various sorting options"""

# Test sorting by display name ascending
asc_response = await extend.virtual_cards.get_virtual_cards(
sort_field="displayName",
sort_direction="ASC",
per_page=50 # Ensure we get enough cards to compare
)

# Test sorting by display name descending
desc_response = await extend.virtual_cards.get_virtual_cards(
sort_field="displayName",
sort_direction="DESC",
per_page=50 # Ensure we get enough cards to compare
)

# Verify responses contain cards
assert "virtualCards" in asc_response
assert "virtualCards" in desc_response

# If sufficient cards exist, just verify the orders are different
# rather than trying to implement our own sorting logic
if len(asc_response["virtualCards"]) > 1 and len(desc_response["virtualCards"]) > 1:
asc_ids = [card["id"] for card in asc_response["virtualCards"]]
desc_ids = [card["id"] for card in desc_response["virtualCards"]]

# Verify the orders are different for different sort directions
assert asc_ids != desc_ids, "ASC and DESC sorting should produce different results"

# Test other sort fields
for field in ["createdAt", "updatedAt", "balanceCents", "status", "type"]:
# Test both directions for each field
for direction in ["ASC", "DESC"]:
response = await extend.virtual_cards.get_virtual_cards(
sort_field=field,
sort_direction=direction
)
assert "virtualCards" in response, f"Sorting by {field} {direction} should return virtual cards"


@pytest.mark.integration
class TestTransactions:
Expand All @@ -152,6 +194,54 @@ async def test_list_transactions(self, extend):
for field in required_fields:
assert field in transaction, f"Transaction should contain '{field}' field"

@pytest.mark.asyncio
async def test_list_transactions_with_sorting(self, extend):
"""Test listing transactions with various sorting options"""

# Define sort fields - positive for ASC, negative (prefixed with -) for DESC
sort_fields = [
"recipientName", "-recipientName",
"merchantName", "-merchantName",
"amount", "-amount",
"date", "-date"
]

# Test each sort field
for sort_field in sort_fields:
# Get transactions with this sort
response = await extend.transactions.get_transactions(
sort_field=sort_field,
per_page=10
)

# Verify response contains transactions and basic structure
assert isinstance(response, dict), f"Response for sort {sort_field} should be a dictionary"
assert "transactions" in response, f"Response for sort {sort_field} should contain 'transactions' key"

# If we have enough data, test opposite sort direction for comparison
if len(response["transactions"]) > 1:
# Determine the field name and opposite sort field
is_desc = sort_field.startswith("-")
field_name = sort_field[1:] if is_desc else sort_field
opposite_sort = field_name if is_desc else f"-{field_name}"

# Get transactions with opposite sort
opposite_response = await extend.transactions.get_transactions(
sort_field=opposite_sort,
per_page=10
)

# Get IDs in both sort orders for comparison
sorted_ids = [tx["id"] for tx in response["transactions"]]
opposite_sorted_ids = [tx["id"] for tx in opposite_response["transactions"]]

# If we have the same set of transactions in both responses,
# verify that different sort directions produce different orders
if set(sorted_ids) == set(opposite_sorted_ids) and len(sorted_ids) > 1:
assert sorted_ids != opposite_sorted_ids, (
f"Different sort directions for {field_name} should produce different results"
)


@pytest.mark.integration
class TestRecurringCards:
Expand Down Expand Up @@ -303,6 +393,98 @@ async def test_get_expense_categories_and_labels(self, extend):
assert "expenseLabels" in labels


@pytest.mark.integration
class TestTransactionExpenseData:
"""Integration tests for updating transaction expense data using a specific expense category and label"""

@pytest.mark.asyncio
async def test_update_transaction_expense_data_with_specific_category_and_label(self, extend):
"""Test updating the expense data for a transaction using a specific expense category and label."""
# Retrieve available expense categories (active ones)
categories_response = await extend.expense_data.get_expense_categories(active=True)
assert "expenseCategories" in categories_response, "Response should include 'expenseCategories'"
expense_categories = categories_response["expenseCategories"]
assert expense_categories, "No expense categories available for testing"

# For this test, pick the first expense category
category = expense_categories[0]
category_id = category["id"]

# Retrieve the labels for the chosen expense category
labels_response = await extend.expense_data.get_expense_category_labels(
category_id=category_id,
page=0,
per_page=10
)
assert "expenseLabels" in labels_response, "Response should include 'expenseLabels'"
expense_labels = labels_response["expenseLabels"]
assert expense_labels, "No expense labels available for the selected category"

# Pick the first label from the list
label = expense_labels[0]
label_id = label["id"]

# Retrieve at least one transaction to update expense data
transactions_response = await extend.transactions.get_transactions(per_page=1)
assert transactions_response.get("transactions"), "No transactions available for testing expense data update"
transaction = transactions_response["transactions"][0]
transaction_id = transaction["id"]

# Prepare the expense data payload with the specific category and label
update_payload = {
"expenseDetails": [
{
"categoryId": category_id,
"labelId": label_id
}
]
}

# Call the update_transaction_expense_data method
response = await extend.transactions.update_transaction_expense_data(transaction_id, update_payload)

# Verify the response contains the transaction id and expected expense details
assert "id" in response, "Response should include the transaction id"
if "expenseDetails" in response:
# Depending on the API response, the structure might vary; adjust assertions accordingly
assert response["expenseDetails"] == update_payload["expenseDetails"], (
"Expense details in the response should match the update payload"
)


@pytest.mark.integration
class TestReceiptAttachments:
"""Integration tests for receipt attachment operations"""

@pytest.mark.asyncio
async def test_create_receipt_attachment(self, extend):
"""Test creating a receipt attachment via multipart upload."""
# Create a dummy PNG file in memory
# This is a minimal PNG header plus extra bytes to simulate file content.
png_header = b'\x89PNG\r\n\x1a\n'
dummy_content = png_header + b'\x00' * 100
file_obj = BytesIO(dummy_content)
# Optionally set a name attribute for file identification in the upload
file_obj.name = f"test_receipt_{uuid.uuid4()}.png"

# Retrieve a valid transaction id from existing transactions
transactions_response = await extend.transactions.get_transactions(page=0, per_page=1)
assert transactions_response.get("transactions"), "No transactions available for testing receipt attachment"
transaction_id = transactions_response["transactions"][0]["id"]

# Call the receipt attachment upload method
response = await extend.receipt_attachments.create_receipt_attachment(
transaction_id=transaction_id,
file=file_obj
)

# Assert that the response contains expected keys
assert "id" in response, "Receipt attachment should have an id"
assert "urls" in response, "Receipt attachment should include urls"
assert "contentType" in response, "Receipt attachment should include a content type"
assert response["contentType"] == "image/png", "Content type should be 'image/png'"


def test_environment_variables():
"""Test that required environment variables are set"""
assert os.getenv("EXTEND_API_KEY"), "EXTEND_API_KEY environment variable is required"
Expand Down