Skip to content

Commit

Permalink
Merge e48f0e6 into 8094611
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshageman-stripe authored Jan 29, 2019
2 parents 8094611 + e48f0e6 commit 543b967
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 5 deletions.
26 changes: 21 additions & 5 deletions stripe/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import platform
import time
import uuid
import warnings

import stripe
from stripe import error, oauth_error, http_client, version, util, six
Expand Down Expand Up @@ -83,16 +84,31 @@ def __init__(
self.api_version = api_version or stripe.api_version
self.stripe_account = account

self._default_proxy = None

from stripe import verify_ssl_certs as verify
from stripe import proxy

self._client = (
client
or stripe.default_http_client
or http_client.new_default_http_client(
if client:
self._client = client
elif stripe.default_http_client:
self._client = stripe.default_http_client
if proxy != self._default_proxy:
warnings.warn(
"Warning: stripe.proxy was updated after sending a" +
" request - this is a no-op. To use a different proxy," +
" set stripe.default_http_client to a new client" +
" configured with the proxy."
)
else:
# If the stripe.default_http_client has not been set by the user
# yet, we'll set it here. This way, we aren't creating a new
# HttpClient for every request.
stripe.default_http_client = http_client.new_default_http_client(
verify_ssl_certs=verify, proxy=proxy
)
)
self._client = stripe.default_http_client
self._default_proxy = proxy

self._last_request_metrics = None

Expand Down
22 changes: 22 additions & 0 deletions tests/test_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,17 @@ def setup_stripe(self):
orig_attrs = {
"api_key": stripe.api_key,
"api_version": stripe.api_version,
"default_http_client": stripe.default_http_client,
"enable_telemetry": stripe.enable_telemetry,
}
stripe.api_key = "sk_test_123"
stripe.api_version = "2017-12-14"
stripe.default_http_client = None
stripe.enable_telemetry = False
yield
stripe.api_key = orig_attrs["api_key"]
stripe.api_version = orig_attrs["api_version"]
stripe.default_http_client = orig_attrs["default_http_client"]
stripe.enable_telemetry = orig_attrs["enable_telemetry"]

@pytest.fixture
Expand Down Expand Up @@ -485,6 +488,25 @@ def test_uses_instance_account(
),
)

def test_sets_default_http_client(self, http_client):
assert not stripe.default_http_client

stripe.api_requestor.APIRequestor(client=http_client)

# default_http_client is not populated if a client is provided
assert not stripe.default_http_client

stripe.api_requestor.APIRequestor()

# default_http_client is set when no client is specified
assert stripe.default_http_client

new_default_client = stripe.default_http_client
stripe.api_requestor.APIRequestor()

# the newly created client is reused
assert stripe.default_http_client == new_default_client

def test_uses_app_info(self, requestor, mock_response, check_call):
try:
old = stripe.app_info
Expand Down
128 changes: 128 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import sys
from threading import Thread
import json
import warnings

import stripe
import pytest

if sys.version_info[0] < 3:
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
else:
from http.server import BaseHTTPRequestHandler, HTTPServer


class TestIntegration(object):
@pytest.fixture(autouse=True)
def close_mock_server(self):
yield
if self.mock_server:
self.mock_server.shutdown()
self.mock_server.server_close()
self.mock_server_thread.join()

@pytest.fixture(autouse=True)
def setup_stripe(self):
orig_attrs = {
"api_base": stripe.api_base,
"api_key": stripe.api_key,
"default_http_client": stripe.default_http_client,
"proxy": stripe.proxy,
}
stripe.api_base = "http://localhost:12111" # stripe-mock
stripe.api_key = "sk_test_123"
stripe.default_http_client = None
stripe.proxy = None
yield
stripe.api_base = orig_attrs["api_base"]
stripe.api_key = orig_attrs["api_key"]
stripe.default_http_client = orig_attrs["default_http_client"]
stripe.proxy = orig_attrs["proxy"]

def setup_mock_server(self, handler):
# Configure mock server.
# Passing 0 as the port will cause a random free port to be chosen.
self.mock_server = HTTPServer(("localhost", 0), handler)
_, self.mock_server_port = self.mock_server.server_address

# Start running mock server in a separate thread.
# Daemon threads automatically shut down when the main process exits.
self.mock_server_thread = Thread(target=self.mock_server.serve_forever)
self.mock_server_thread.setDaemon(True)
self.mock_server_thread.start()

def test_hits_api_base(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0

def do_GET(self):
self.__class__.num_requests += 1

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return

self.setup_mock_server(MockServerRequestHandler)

stripe.api_base = "http://localhost:%s" % self.mock_server_port
stripe.Balance.retrieve()
assert MockServerRequestHandler.num_requests == 1

def test_hits_proxy_through_default_http_client(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0

def do_GET(self):
self.__class__.num_requests += 1

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return

self.setup_mock_server(MockServerRequestHandler)

stripe.proxy = "http://localhost:%s" % self.mock_server_port
stripe.Balance.retrieve()
assert MockServerRequestHandler.num_requests == 1

stripe.proxy = "http://bad-url"

with warnings.catch_warnings(record=True) as w:
stripe.Balance.retrieve()
assert len(w) == 1
assert "stripe.proxy was updated after sending a request" in str(
w[0].message
)

assert MockServerRequestHandler.num_requests == 2

def test_hits_proxy_through_custom_client(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0

def do_GET(self):
self.__class__.num_requests += 1

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return

self.setup_mock_server(MockServerRequestHandler)

stripe.default_http_client = stripe.http_client.new_default_http_client(
proxy="http://localhost:%s" % self.mock_server_port
)
stripe.Balance.retrieve()
assert MockServerRequestHandler.num_requests == 1

0 comments on commit 543b967

Please sign in to comment.