diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 339dd6986b15..a16baf1039e6 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -11,6 +11,7 @@ import json import logging from multiprocessing.pool import ThreadPool +import re from dateutil.parser import parse as dparse from flask import escape, Markup @@ -107,24 +108,29 @@ def data(self): 'backend': 'druid', } + @staticmethod + def get_base_url(host, port): + if not re.match('http(s)?://', host): + host = 'http://' + host + return '{0}:{1}'.format(host, port) + + def get_base_coordinator_url(self): + base_url = self.get_base_url( + self.coordinator_host, self.coordinator_port) + return '{base_url}/{self.coordinator_endpoint}'.format(**locals()) + def get_pydruid_client(self): cli = PyDruid( - 'http://{0}:{1}/'.format(self.broker_host, self.broker_port), + self.get_base_url(self.broker_host, self.broker_port), self.broker_endpoint) return cli def get_datasources(self): - endpoint = ( - 'http://{obj.coordinator_host}:{obj.coordinator_port}/' - '{obj.coordinator_endpoint}/datasources' - ).format(obj=self) - + endpoint = self.get_base_coordinator_url() + '/datasources' return json.loads(requests.get(endpoint).text) def get_druid_version(self): - endpoint = ( - 'http://{obj.coordinator_host}:{obj.coordinator_port}/status' - ).format(obj=self) + endpoint = self.get_base_coordinator_url() + '/status' return json.loads(requests.get(endpoint).text)['version'] def refresh_datasources( diff --git a/tests/druid_tests.py b/tests/druid_tests.py index d2a44f968d5b..fc360b6656bc 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -77,6 +77,16 @@ class DruidTests(SupersetTestCase): def __init__(self, *args, **kwargs): super(DruidTests, self).__init__(*args, **kwargs) + def get_test_cluster_obj(self): + return DruidCluster( + cluster_name='test_cluster', + coordinator_host='localhost', + coordinator_endpoint='druid/coordinator/v1/metadata', + coordinator_port=7979, + broker_host='localhost', + broker_port=7980, + metadata_last_refreshed=datetime.now()) + @patch('superset.connectors.druid.models.PyDruid') def test_client(self, PyDruid): self.login(username='admin') @@ -95,13 +105,7 @@ def test_client(self, PyDruid): db.session.delete(cluster) db.session.commit() - cluster = DruidCluster( - cluster_name='test_cluster', - coordinator_host='localhost', - coordinator_port=7979, - broker_host='localhost', - broker_port=7980, - metadata_last_refreshed=datetime.now()) + cluster = self.get_test_cluster_obj() db.session.add(cluster) cluster.get_datasources = PickableMock(return_value=['test_datasource']) @@ -323,6 +327,21 @@ def test_sync_druid_perm(self, PyDruid): permission=permission, view_menu=view_menu).first() assert pv is not None + def test_urls(self): + cluster = self.get_test_cluster_obj() + self.assertEquals( + cluster.get_base_url('localhost', '9999'), 'http://localhost:9999') + self.assertEquals( + cluster.get_base_url('http://localhost', '9999'), + 'http://localhost:9999') + self.assertEquals( + cluster.get_base_url('https://localhost', '9999'), + 'https://localhost:9999') + + self.assertEquals( + cluster.get_base_coordinator_url(), + 'http://localhost:7979/druid/coordinator/v1/metadata') + if __name__ == '__main__': unittest.main()