Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyhive/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class DBAPICursor(object):
_STATE_NONE = 0
_STATE_RUNNING = 1
_STATE_FINISHED = 2
_STATE_CANCELLED = 3

def __init__(self, poll_interval=1):
self._poll_interval = poll_interval
Expand Down
20 changes: 18 additions & 2 deletions pyhive/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ def execute(self, operation, parameters=None):
response = requests.post(url, data=sql.encode('utf-8'), headers=headers)
self._process_response(response)

def cancel(self):
if self._state == self._STATE_NONE:
raise ProgrammingError("No query yet")
if self._nextUri is None:
assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None"
return None
response = requests.delete(self._nextUri)
self._process_response(response)
return response.status_code

def poll(self):
"""Poll for and return the raw status data provided by the Presto REST API.

Expand Down Expand Up @@ -195,12 +205,18 @@ def _process_response(self, response):
URI and any data from the response
"""
# TODO handle HTTP 503
if response.status_code != requests.codes.ok:
if response.status_code not in (requests.codes.ok, requests.codes.no_content):
fmt = "Unexpected status code {}\n{}"
raise OperationalError(fmt.format(response.status_code, response.content))

if response.status_code == requests.codes.no_content:
self._state = self._STATE_CANCELLED
return

response_json = response.json()
_logger.debug("Got response %s", response_json)
assert self._state == self._STATE_RUNNING, "Should be running if processing response"
assert self._state in (self._STATE_RUNNING, self._STATE_CANCELLED), \
"Should be running or cancelled if processing response"
self._nextUri = response_json.get('nextUri')
self._columns = response_json.get('columns')
if 'data' in response_json:
Expand Down
15 changes: 13 additions & 2 deletions pyhive/tests/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
These rely on having a Presto+Hadoop cluster set up.
They also require a tables created by make_test_tables.sh.
"""

from __future__ import absolute_import
from __future__ import unicode_literals
from pyhive import exc
Expand All @@ -14,13 +13,14 @@
import unittest

_HOST = 'localhost'
_PORT = '8080'


class TestPresto(unittest.TestCase, DBAPITestCase):
__test__ = True

def connect(self):
return presto.connect(host=_HOST, source=self.id())
return presto.connect(host=_HOST, port=_PORT, source=self.id())

@with_cursor
def test_description(self, cursor):
Expand Down Expand Up @@ -66,6 +66,17 @@ def test_complex(self, cursor):
#0.1,
]])

@with_cursor
def test_cancel(self, cursor):
cursor.execute(
"SELECT a.a * rand(), b.a*rand()"
"FROM many_rows a "
"CROSS JOIN many_rows b "
)
self.assertIn(cursor.poll()['stats']['state'], ('PLANNING', 'RUNNING'))
cursor.cancel()
self.assertRaises(exc.DatabaseError, cursor.poll)

def test_noops(self):
"""The DB-API specification requires that certain actions exist, even though they might not
be applicable."""
Expand Down