From bf6a0e6a50d6c7b074ce6c4d737e8f768cca4453 Mon Sep 17 00:00:00 2001 From: Jac Date: Thu, 22 Sep 2022 15:50:32 -0700 Subject: [PATCH 01/18] Jac/headers (#1117) * ssl-verify is an option, not a header --- samples/create_group.py | 2 +- .../server/endpoint/endpoint.py | 22 +++----- tableauserverclient/server/server.py | 30 +++++----- test/http/test_http_requests.py | 56 +++++++++++++++++++ 4 files changed, 80 insertions(+), 30 deletions(-) create mode 100644 test/http/test_http_requests.py diff --git a/samples/create_group.py b/samples/create_group.py index 50d84a187..d5cf712db 100644 --- a/samples/create_group.py +++ b/samples/create_group.py @@ -46,7 +46,7 @@ def main(): logging.basicConfig(level=logging_level) tableau_auth = TSC.PersonalAccessTokenAuth(args.token_name, args.token_value, site_id=args.site) - server = TSC.Server(args.server, use_server_version=True) + server = TSC.Server(args.server, use_server_version=True, http_options={"verify": False}) with server.auth.sign_in(tableau_auth): # this code shows 3 different error codes that mean "resource is already in collection" # 409009: group already exists on server diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index 378c84746..a7b33068b 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -11,9 +11,12 @@ NonXMLResponseError, EndpointUnavailableError, ) -from .. import endpoint from ..query import QuerySet from ... import helpers +from ..._version import get_versions + +__TSC_VERSION__ = get_versions()["version"] +del get_versions logger = logging.getLogger("tableau.endpoint") @@ -22,34 +25,25 @@ XML_CONTENT_TYPE = "text/xml" JSON_CONTENT_TYPE = "application/json" +USERAGENT_HEADER = "User-Agent" + if TYPE_CHECKING: from ..server import Server from requests import Response -_version_header: Optional[str] = None - - class Endpoint(object): def __init__(self, parent_srv: "Server"): - global _version_header self.parent_srv = parent_srv @staticmethod def _make_common_headers(auth_token, content_type): - global _version_header - - if not _version_header: - from ..server import __TSC_VERSION__ - - _version_header = __TSC_VERSION__ - headers = {} if auth_token is not None: headers["x-tableau-auth"] = auth_token if content_type is not None: headers["content-type"] = content_type - headers["User-Agent"] = "Tableau Server Client/{}".format(_version_header) + headers["User-Agent"] = "Tableau Server Client/{}".format(__TSC_VERSION__) return headers def _make_request( @@ -62,9 +56,9 @@ def _make_request( parameters: Optional[Dict[str, Any]] = None, ) -> "Response": parameters = parameters or {} - parameters.update(self.parent_srv.http_options) if "headers" not in parameters: parameters["headers"] = {} + parameters.update(self.parent_srv.http_options) parameters["headers"].update(Endpoint._make_common_headers(auth_token, content_type)) if content is not None: diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index c82f4a6e2..18f5834b1 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -37,11 +37,6 @@ from ..namespace import Namespace -from .._version import get_versions - -__TSC_VERSION__ = get_versions()["version"] -del get_versions - _PRODUCT_TO_REST_VERSION = { "10.0": "2.3", "9.3": "2.2", @@ -51,7 +46,6 @@ } minimum_supported_server_version = "2.3" default_server_version = "2.3" -client_version_header = "X-TableauServerClient-Version" class Server(object): @@ -98,23 +92,29 @@ def __init__(self, server_address, use_server_version=False, http_options=None): # must set this before calling use_server_version, because that's a server call if http_options: self.add_http_options(http_options) - self.add_http_version_header() if use_server_version: self.use_server_version() - def add_http_options(self, options_dict): - self._http_options.update(options_dict) - if options_dict.get("verify") == False: + def add_http_options(self, option_pair: dict): + if not option_pair: + # log debug message + return + if len(option_pair) != 1: + raise ValueError( + "Update headers one at a time. Expected type: ", + {"key": 12}.__class__, + "Actual type: ", + option_pair, + option_pair.__class__, + ) + self._http_options.update(option_pair) + if "verify" in option_pair.keys() and self._http_options.get("verify") is False: urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - def add_http_version_header(self): - if not self._http_options[client_version_header]: - self._http_options.update({client_version_header: __TSC_VERSION__}) + # would be nice if you could turn them back on def clear_http_options(self): self._http_options = dict() - self.add_http_version_header() def _clear_auth(self): self._site_id = None diff --git a/test/http/test_http_requests.py b/test/http/test_http_requests.py new file mode 100644 index 000000000..5759b1c2e --- /dev/null +++ b/test/http/test_http_requests.py @@ -0,0 +1,56 @@ +import tableauserverclient as TSC +import unittest +from requests.exceptions import MissingSchema + + +class ServerTests(unittest.TestCase): + def test_init_server_model_empty_throws(self): + with self.assertRaises(TypeError): + server = TSC.Server() + + def test_init_server_model_bad_server_name_complains(self): + # by default, it will just set the version to 2.3 + server = TSC.Server("fake-url") + + def test_init_server_model_valid_server_name_works(self): + # by default, it will just set the version to 2.3 + server = TSC.Server("http://fake-url") + + def test_init_server_model_valid_https_server_name_works(self): + # by default, it will just set the version to 2.3 + server = TSC.Server("https://fake-url") + + def test_init_server_model_bad_server_name_not_version_check(self): + # by default, it will just set the version to 2.3 + server = TSC.Server("fake-url", use_server_version=False) + + def test_init_server_model_bad_server_name_do_version_check(self): + with self.assertRaises(MissingSchema): + server = TSC.Server("fake-url", use_server_version=True) + + def test_init_server_model_bad_server_name_not_version_check_random_options(self): + # by default, it will just set the version to 2.3 + server = TSC.Server("fake-url", use_server_version=False, http_options={"foo": 1}) + + def test_init_server_model_bad_server_name_not_version_check_real_options(self): + # by default, it will attempt to contact the server to check it's version + server = TSC.Server("fake-url", use_server_version=False, http_options={"verify": False}) + + def test_http_options_skip_ssl_works(self): + http_options = {"verify": False} + server = TSC.Server("http://fake-url") + server.add_http_options(http_options) + + # ValueError: dictionary update sequence element #0 has length 1; 2 is required + def test_http_options_multiple_options_fails(self): + http_options_1 = {"verify": False} + http_options_2 = {"birdname": "Parrot"} + server = TSC.Server("http://fake-url") + with self.assertRaises(ValueError): + server.add_http_options([http_options_1, http_options_2]) + + # TypeError: cannot convert dictionary update sequence element #0 to a sequence + def test_http_options_not_sequence_fails(self): + server = TSC.Server("http://fake-url") + with self.assertRaises(ValueError): + server.add_http_options({1, 2, 3}) From a62ad5a2d7027bd3e78a12c72d220b3c84740b50 Mon Sep 17 00:00:00 2001 From: Marwan Baghdad Date: Fri, 23 Sep 2022 06:01:44 +0200 Subject: [PATCH 02/18] Allow injection of sessions (#1111) * Allow injection of session_factory to allow use of a custom session --- tableauserverclient/server/server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 18f5834b1..1013def96 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -54,12 +54,13 @@ class PublishMode: Overwrite = "Overwrite" CreateNew = "CreateNew" - def __init__(self, server_address, use_server_version=False, http_options=None): + def __init__(self, server_address, use_server_version=False, http_options=None, session_factory=requests.Session): self._server_address = server_address self._auth_token = None self._site_id = None self._user_id = None - self._session = requests.Session() + self._session_factory = session_factory + self._session = session_factory() self._http_options = dict() self.version = default_server_version @@ -120,7 +121,7 @@ def _clear_auth(self): self._site_id = None self._user_id = None self._auth_token = None - self._session = requests.Session() + self._session = self._session_factory() def _set_auth(self, site_id, user_id, auth_token): self._site_id = site_id From d71b9789d9f51514840e487168836ded4f88c57b Mon Sep 17 00:00:00 2001 From: Jac Date: Fri, 23 Sep 2022 15:36:18 -0700 Subject: [PATCH 03/18] Jac/show server info (#1118) --- contributing.md | 10 ++- .../models/server_info_item.py | 9 ++- tableauserverclient/models/site_item.py | 1 - .../server/endpoint/endpoint.py | 12 ++-- .../server/endpoint/server_info_endpoint.py | 21 ++++-- tableauserverclient/server/server.py | 65 +++++++++++-------- test/http/test_http_requests.py | 27 +++++++- 7 files changed, 100 insertions(+), 45 deletions(-) diff --git a/contributing.md b/contributing.md index 90fbdc4f0..41c339cb6 100644 --- a/contributing.md +++ b/contributing.md @@ -66,18 +66,22 @@ pytest pip install . ``` +### Debugging Tools +See what your outgoing requests look like: https://requestbin.net/ (unaffiliated link not under our control) + + ### Before Committing Our CI runs include a Python lint run, so you should run this locally and fix complaints before committing as this will fail your checkin. ```shell # this will run the formatter without making changes -black --line-length 120 tableauserverclient test samples --check +black . --check # this will format the directory and code for you -black --line-length 120 tableauserverclient test samples +black . # this will run type checking pip install mypy -mypy --show-error-codes --disable-error-code misc --disable-error-code import tableauserverclient test +mypy tableauserverclient test samples ``` diff --git a/tableauserverclient/models/server_info_item.py b/tableauserverclient/models/server_info_item.py index d0ac5d292..350ae3a0d 100644 --- a/tableauserverclient/models/server_info_item.py +++ b/tableauserverclient/models/server_info_item.py @@ -1,3 +1,6 @@ +import warnings +import xml + from defusedxml.ElementTree import fromstring @@ -32,7 +35,11 @@ def rest_api_version(self): @classmethod def from_response(cls, resp, ns): - parsed_response = fromstring(resp) + try: + parsed_response = fromstring(resp) + except xml.etree.ElementTree.ParseError as error: + warnings.warn("Unexpected response for ServerInfo: {}".format(resp)) + return cls("Unknown", "Unknown", "Unknown") product_version_tag = parsed_response.find(".//t:productVersion", namespaces=ns) rest_api_version_tag = parsed_response.find(".//t:restApiVersion", namespaces=ns) diff --git a/tableauserverclient/models/site_item.py b/tableauserverclient/models/site_item.py index 3deda03e2..8c9e8fe8e 100644 --- a/tableauserverclient/models/site_item.py +++ b/tableauserverclient/models/site_item.py @@ -1,7 +1,6 @@ import warnings import xml.etree.ElementTree as ET -from distutils.version import Version from defusedxml.ElementTree import fromstring from .property_decorators import ( property_is_enum, diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index a7b33068b..3cdc49322 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -1,6 +1,6 @@ import requests import logging -from distutils.version import LooseVersion as Version +from packaging.version import Version from functools import wraps from xml.etree.ElementTree import ParseError from typing import Any, Callable, Dict, Optional, TYPE_CHECKING @@ -83,14 +83,12 @@ def _check_status(self, server_response, url: str = None): if server_response.status_code >= 500: raise InternalServerError(server_response, url) elif server_response.status_code not in Success_codes: - # todo: is an error reliably of content-type application/xml? try: raise ServerResponseError.from_response(server_response.content, self.parent_srv.namespace, url) except ParseError: - # This will happen if we get a non-success HTTP code that - # doesn't return an xml error object (like metadata endpoints or 503 pages) - # we convert this to a better exception and pass through the raw - # response body + # This will happen if we get a non-success HTTP code that doesn't return an xml error object + # e.g metadata endpoints, 503 pages, totally different servers + # we convert this to a better exception and pass through the raw response body raise NonXMLResponseError(server_response.content) except Exception: # anything else re-raise here @@ -188,7 +186,7 @@ def api(version): def _decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): - self.parent_srv.assert_at_least_version(version, "endpoint") + self.parent_srv.assert_at_least_version(version, self.__class__.__name__) return func(self, *args, **kwargs) return wrapper diff --git a/tableauserverclient/server/endpoint/server_info_endpoint.py b/tableauserverclient/server/endpoint/server_info_endpoint.py index 2036d8d5e..943aabee6 100644 --- a/tableauserverclient/server/endpoint/server_info_endpoint.py +++ b/tableauserverclient/server/endpoint/server_info_endpoint.py @@ -12,6 +12,19 @@ class ServerInfo(Endpoint): + def __init__(self, server): + self.parent_srv = server + self._info = None + + @property + def serverInfo(self): + if not self._info: + self.get() + return self._info + + def __repr__(self): + return "".format(self.serverInfo) + @property def baseurl(self): return "{0}/serverInfo".format(self.parent_srv.baseurl) @@ -23,10 +36,10 @@ def get(self): server_response = self.get_unauthenticated_request(self.baseurl) except ServerResponseError as e: if e.code == "404003": - raise ServerInfoEndpointNotFoundError + raise ServerInfoEndpointNotFoundError(e) if e.code == "404001": - raise EndpointUnavailableError + raise EndpointUnavailableError(e) raise e - server_info = ServerInfoItem.from_response(server_response.content, self.parent_srv.namespace) - return server_info + self._info = ServerInfoItem.from_response(server_response.content, self.parent_srv.namespace) + return self._info diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 1013def96..ebe11dac7 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -1,3 +1,5 @@ +import warnings + import requests import urllib3 @@ -54,16 +56,14 @@ class PublishMode: Overwrite = "Overwrite" CreateNew = "CreateNew" - def __init__(self, server_address, use_server_version=False, http_options=None, session_factory=requests.Session): - self._server_address = server_address + def __init__(self, server_address, use_server_version=False, http_options=None, session_factory=None): self._auth_token = None self._site_id = None self._user_id = None - self._session_factory = session_factory - self._session = session_factory() - self._http_options = dict() - self.version = default_server_version + self._server_address = server_address + self._session_factory = session_factory or requests.session + self.auth = Auth(self) self.views = Views(self) self.users = Users(self) @@ -90,29 +90,39 @@ def __init__(self, server_address, use_server_version=False, http_options=None, self.flow_runs = FlowRuns(self) self.metrics = Metrics(self) - # must set this before calling use_server_version, because that's a server call + self._session = self._session_factory() + self._http_options = dict() # must set this before making a server call if http_options: self.add_http_options(http_options) + self.validate_server_connection() + + self.version = default_server_version if use_server_version: - self.use_server_version() - - def add_http_options(self, option_pair: dict): - if not option_pair: - # log debug message - return - if len(option_pair) != 1: - raise ValueError( - "Update headers one at a time. Expected type: ", - {"key": 12}.__class__, - "Actual type: ", - option_pair, - option_pair.__class__, - ) - self._http_options.update(option_pair) - if "verify" in option_pair.keys() and self._http_options.get("verify") is False: - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - # would be nice if you could turn them back on + self.use_server_version() # this makes a server call + + def validate_server_connection(self): + try: + self._session.prepare_request(requests.Request("GET", url=self._server_address, params=self._http_options)) + except Exception as req_ex: + warnings.warn("Invalid server initialization\n {}".format(req_ex.__str__()), UserWarning) + print("==================") + + def __repr__(self): + return " [Connection: {}, {}]".format(self.baseurl, self.server_info.serverInfo) + + def add_http_options(self, options_dict: dict): + try: + self._http_options.update(options_dict) + if "verify" in options_dict.keys() and self._http_options.get("verify") is False: + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + # would be nice if you could turn them back on + except BaseException as be: + print(be) + # expected errors on invalid input: + # 'set' object has no attribute 'keys', 'list' object has no attribute 'keys' + # TypeError: cannot convert dictionary update sequence element #0 to a sequence (input is a tuple) + raise ValueError("Invalid http options given: {}".format(options_dict)) def clear_http_options(self): self._http_options = dict() @@ -142,9 +152,10 @@ def _determine_highest_version(self): version = self.server_info.get().rest_api_version except ServerInfoEndpointNotFoundError: version = self._get_legacy_version() + except BaseException: + version = self._get_legacy_version() - finally: - self.version = old_version + self.version = old_version return version diff --git a/test/http/test_http_requests.py b/test/http/test_http_requests.py index 5759b1c2e..a5f4f4669 100644 --- a/test/http/test_http_requests.py +++ b/test/http/test_http_requests.py @@ -1,5 +1,8 @@ import tableauserverclient as TSC import unittest +import requests + +from requests_mock import adapter, mock from requests.exceptions import MissingSchema @@ -33,7 +36,6 @@ def test_init_server_model_bad_server_name_not_version_check_random_options(self server = TSC.Server("fake-url", use_server_version=False, http_options={"foo": 1}) def test_init_server_model_bad_server_name_not_version_check_real_options(self): - # by default, it will attempt to contact the server to check it's version server = TSC.Server("fake-url", use_server_version=False, http_options={"verify": False}) def test_http_options_skip_ssl_works(self): @@ -41,8 +43,13 @@ def test_http_options_skip_ssl_works(self): server = TSC.Server("http://fake-url") server.add_http_options(http_options) + def test_http_options_multiple_options_works(self): + http_options = {"verify": False, "birdname": "Parrot"} + server = TSC.Server("http://fake-url") + server.add_http_options(http_options) + # ValueError: dictionary update sequence element #0 has length 1; 2 is required - def test_http_options_multiple_options_fails(self): + def test_http_options_multiple_dicts_fails(self): http_options_1 = {"verify": False} http_options_2 = {"birdname": "Parrot"} server = TSC.Server("http://fake-url") @@ -54,3 +61,19 @@ def test_http_options_not_sequence_fails(self): server = TSC.Server("http://fake-url") with self.assertRaises(ValueError): server.add_http_options({1, 2, 3}) + + +class SessionTests(unittest.TestCase): + test_header = {"x-test": "true"} + + @staticmethod + def session_factory(): + session = requests.session() + session.headers.update(SessionTests.test_header) + return session + + def test_session_factory_adds_headers(self): + test_request_bin = "http://capture-this-with-mock.com" + with mock() as m: + m.get(url="http://capture-this-with-mock.com/api/2.4/serverInfo", request_headers=SessionTests.test_header) + server = TSC.Server(test_request_bin, use_server_version=True, session_factory=SessionTests.session_factory) From a203a04f28ad7970bdf501fd8a4ade80adc1a587 Mon Sep 17 00:00:00 2001 From: jorwoods Date: Mon, 26 Sep 2022 14:33:18 -0500 Subject: [PATCH 04/18] Fix bug in exposing ExcelRequestOptions and test (#1123) --- tableauserverclient/__init__.py | 1 + tableauserverclient/server/__init__.py | 1 + test/test_view.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tableauserverclient/__init__.py b/tableauserverclient/__init__.py index 394184120..7c1e6d705 100644 --- a/tableauserverclient/__init__.py +++ b/tableauserverclient/__init__.py @@ -39,6 +39,7 @@ ) from .server import ( CSVRequestOptions, + ExcelRequestOptions, ImageRequestOptions, PDFRequestOptions, RequestOptions, diff --git a/tableauserverclient/server/__init__.py b/tableauserverclient/server/__init__.py index 25abb3c9a..84d118a2e 100644 --- a/tableauserverclient/server/__init__.py +++ b/tableauserverclient/server/__init__.py @@ -2,6 +2,7 @@ from .request_factory import RequestFactory from .request_options import ( CSVRequestOptions, + ExcelRequestOptions, ImageRequestOptions, PDFRequestOptions, RequestOptions, diff --git a/test/test_view.py b/test/test_view.py index 3562650d1..f5d3db47b 100644 --- a/test/test_view.py +++ b/test/test_view.py @@ -294,7 +294,7 @@ def test_populate_excel(self) -> None: m.get(self.baseurl + "/d79634e1-6063-4ec9-95ff-50acbf609ff5/crosstab/excel?maxAge=1", content=response) single_view = TSC.ViewItem() single_view._id = "d79634e1-6063-4ec9-95ff-50acbf609ff5" - request_option = TSC.CSVRequestOptions(maxage=1) + request_option = TSC.ExcelRequestOptions(maxage=1) self.server.views.populate_excel(single_view, request_option) excel_file = b"".join(single_view.excel) From af80100af28f135c95c3613bc2a632fae925de97 Mon Sep 17 00:00:00 2001 From: Brian Cantoni Date: Tue, 27 Sep 2022 15:02:50 -0700 Subject: [PATCH 05/18] Fix a few pylint errors (#1124) Pylint with "errors only" isn't 100% accurate, but it found a few problems that should be fixed. --- samples/initialize_server.py | 6 +++--- tableauserverclient/models/flow_item.py | 4 ---- tableauserverclient/models/permissions_item.py | 2 +- tableauserverclient/models/revision_item.py | 5 +++-- tableauserverclient/models/tableau_auth.py | 2 +- tableauserverclient/server/endpoint/databases_endpoint.py | 2 +- 6 files changed, 9 insertions(+), 12 deletions(-) diff --git a/samples/initialize_server.py b/samples/initialize_server.py index 586011120..21b243013 100644 --- a/samples/initialize_server.py +++ b/samples/initialize_server.py @@ -56,7 +56,7 @@ def main(): # Create the site if it doesn't exist if existing_site is None: - print("Site not found: {0} Creating it...").format(args.site_id) + print("Site not found: {0} Creating it...".format(args.site_id)) new_site = TSC.SiteItem( name=args.site_id, content_url=args.site_id.replace(" ", ""), @@ -64,7 +64,7 @@ def main(): ) server.sites.create(new_site) else: - print("Site {0} exists. Moving on...").format(args.site_id) + print("Site {0} exists. Moving on...".format(args.site_id)) ################################################################################ # Step 3: Sign-in to our target site @@ -87,7 +87,7 @@ def main(): # Create our project if it doesn't exist if project is None: - print("Project not found: {0} Creating it...").format(args.project) + print("Project not found: {0} Creating it...".format(args.project)) new_project = TSC.ProjectItem(name=args.project) project = server_upload.projects.create(new_project) diff --git a/tableauserverclient/models/flow_item.py b/tableauserverclient/models/flow_item.py index d957f5e14..18f0ecae2 100644 --- a/tableauserverclient/models/flow_item.py +++ b/tableauserverclient/models/flow_item.py @@ -93,10 +93,6 @@ def description(self, value: str) -> None: def project_name(self) -> Optional[str]: return self._project_name - @property - def flow_type(self): # What is this? It doesn't seem to get set anywhere. - return self._flow_type - @property def updated_at(self) -> Optional["datetime.datetime"]: return self._updated_at diff --git a/tableauserverclient/models/permissions_item.py b/tableauserverclient/models/permissions_item.py index 1c1e9db4d..74b167e9d 100644 --- a/tableauserverclient/models/permissions_item.py +++ b/tableauserverclient/models/permissions_item.py @@ -69,7 +69,7 @@ def from_response(cls, resp, ns=None) -> List["PermissionsRule"]: mode = capability_xml.get("mode") if name is None or mode is None: - logger.error("Capability was not valid: ", capability_xml) + logger.error("Capability was not valid: {}".format(capability_xml)) raise UnpopulatedPropertyError() else: capability_dict[name] = mode diff --git a/tableauserverclient/models/revision_item.py b/tableauserverclient/models/revision_item.py index 024d45edd..a49be88a7 100644 --- a/tableauserverclient/models/revision_item.py +++ b/tableauserverclient/models/revision_item.py @@ -53,8 +53,9 @@ def user_name(self) -> Optional[str]: def __repr__(self): return ( - "" - ).format(**self.__dict__) + "".format(**self.__dict__) + ) @classmethod def from_response(cls, resp: bytes, ns, resource_item) -> List["RevisionItem"]: diff --git a/tableauserverclient/models/tableau_auth.py b/tableauserverclient/models/tableau_auth.py index f373a84ab..6ad0fda5a 100644 --- a/tableauserverclient/models/tableau_auth.py +++ b/tableauserverclient/models/tableau_auth.py @@ -9,7 +9,7 @@ def credentials(self): +"This method returns values to set as an attribute on the credentials element of the request" def __repr__(self): - display = "All Credentials types must have a debug display that does not print secrets" + return "All Credentials types must have a debug display that does not print secrets" def deprecate_site_attribute(): diff --git a/tableauserverclient/server/endpoint/databases_endpoint.py b/tableauserverclient/server/endpoint/databases_endpoint.py index 1fab7ac4b..aa9d73f18 100644 --- a/tableauserverclient/server/endpoint/databases_endpoint.py +++ b/tableauserverclient/server/endpoint/databases_endpoint.py @@ -116,7 +116,7 @@ def update_table_default_permissions(self, item): @api(version="3.5") def delete_table_default_permissions(self, item): - self._default_permissions.delete_default_permissions(item, Resource.Table) + self._default_permissions.delete_default_permission(item, Resource.Table) @api(version="3.5") def populate_dqw(self, item): From ca4d79e0f24c06fa1f10e0b36d18e24d5220ebff Mon Sep 17 00:00:00 2001 From: Jac Date: Thu, 6 Oct 2022 10:57:56 -0700 Subject: [PATCH 06/18] fix behavior when url has no protocol (#1125) * fix behavior when url has no protocol --- tableauserverclient/models/tableau_auth.py | 2 +- .../server/endpoint/auth_endpoint.py | 13 ++++- tableauserverclient/server/server.py | 26 ++++++---- test/http/test_http_requests.py | 52 ++++++++++++++++--- 4 files changed, 73 insertions(+), 20 deletions(-) diff --git a/tableauserverclient/models/tableau_auth.py b/tableauserverclient/models/tableau_auth.py index 6ad0fda5a..24ba1d682 100644 --- a/tableauserverclient/models/tableau_auth.py +++ b/tableauserverclient/models/tableau_auth.py @@ -65,6 +65,6 @@ def credentials(self): } def __repr__(self): - return "".format( + return "(site={})".format( self.token_name, self.personal_access_token[:2] + "...", self.site_id ) diff --git a/tableauserverclient/server/endpoint/auth_endpoint.py b/tableauserverclient/server/endpoint/auth_endpoint.py index 6baf399ed..68d75eaa8 100644 --- a/tableauserverclient/server/endpoint/auth_endpoint.py +++ b/tableauserverclient/server/endpoint/auth_endpoint.py @@ -28,7 +28,18 @@ def baseurl(self): def sign_in(self, auth_req): url = "{0}/{1}".format(self.baseurl, "signin") signin_req = RequestFactory.Auth.signin_req(auth_req) - server_response = self.parent_srv.session.post(url, data=signin_req, **self.parent_srv.http_options) + server_response = self.parent_srv.session.post( + url, data=signin_req, **self.parent_srv.http_options, allow_redirects=False + ) + # manually handle a redirect so that we send the correct POST request instead of GET + # this will make e.g http://online.tableau.com work to redirect to http://east.online.tableau.com + if server_response.status_code == 301: + server_response = self.parent_srv.session.post( + server_response.headers["Location"], + data=signin_req, + **self.parent_srv.http_options, + allow_redirects=False, + ) self.parent_srv._namespace.detect(server_response.content) self._check_status(server_response, url) parsed_response = fromstring(server_response.content) diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index ebe11dac7..9623d722d 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -1,9 +1,10 @@ +import logging import warnings import requests import urllib3 -from defusedxml.ElementTree import fromstring +from defusedxml.ElementTree import fromstring, ParseError from packaging.version import Version from .endpoint import ( Sites, @@ -61,7 +62,7 @@ def __init__(self, server_address, use_server_version=False, http_options=None, self._site_id = None self._user_id = None - self._server_address = server_address + self._server_address: str = server_address self._session_factory = session_factory or requests.session self.auth = Auth(self) @@ -103,10 +104,13 @@ def __init__(self, server_address, use_server_version=False, http_options=None, def validate_server_connection(self): try: - self._session.prepare_request(requests.Request("GET", url=self._server_address, params=self._http_options)) + if not self._server_address.startswith("http://") and not self._server_address.startswith("https://"): + self._server_address = "http://" + self._server_address + self._session.prepare_request( + requests.Request("GET", url=self._server_address, params=self._http_options) + ) except Exception as req_ex: - warnings.warn("Invalid server initialization\n {}".format(req_ex.__str__()), UserWarning) - print("==================") + raise ValueError("Invalid server initialization", req_ex) def __repr__(self): return " [Connection: {}, {}]".format(self.baseurl, self.server_info.serverInfo) @@ -140,7 +144,13 @@ def _set_auth(self, site_id, user_id, auth_token): def _get_legacy_version(self): response = self._session.get(self.server_address + "/auth?format=xml") - info_xml = fromstring(response.content) + try: + info_xml = fromstring(response.content) + except ParseError as parseError: + logging.getLogger("TSC.server").info( + "Could not read server version info. The server may not be running or configured." + ) + return self.version prod_version = info_xml.find(".//product_version").text version = _PRODUCT_TO_REST_VERSION.get(prod_version, "2.1") # 2.1 return version @@ -152,8 +162,6 @@ def _determine_highest_version(self): version = self.server_info.get().rest_api_version except ServerInfoEndpointNotFoundError: version = self._get_legacy_version() - except BaseException: - version = self._get_legacy_version() self.version = old_version @@ -164,8 +172,6 @@ def use_server_version(self): def use_highest_version(self): self.use_server_version() - import warnings - warnings.warn("use use_server_version instead", DeprecationWarning) def check_at_least_version(self, target: str): diff --git a/test/http/test_http_requests.py b/test/http/test_http_requests.py index a5f4f4669..e96879277 100644 --- a/test/http/test_http_requests.py +++ b/test/http/test_http_requests.py @@ -1,22 +1,39 @@ import tableauserverclient as TSC import unittest import requests +import requests_mock -from requests_mock import adapter, mock +from unittest import mock from requests.exceptions import MissingSchema +# This method will be used by the mock to replace requests.get +def mocked_requests_get(*args, **kwargs): + class MockResponse: + def __init__(self, status_code): + self.content = ( + "" + "" + "0.31" + "0.31" + "2022.3" + "" + "" + ) + self.status_code = status_code + + return MockResponse(200) + + class ServerTests(unittest.TestCase): def test_init_server_model_empty_throws(self): with self.assertRaises(TypeError): server = TSC.Server() - def test_init_server_model_bad_server_name_complains(self): - # by default, it will just set the version to 2.3 + def test_init_server_model_no_protocol_defaults_htt(self): server = TSC.Server("fake-url") def test_init_server_model_valid_server_name_works(self): - # by default, it will just set the version to 2.3 server = TSC.Server("http://fake-url") def test_init_server_model_valid_https_server_name_works(self): @@ -24,18 +41,18 @@ def test_init_server_model_valid_https_server_name_works(self): server = TSC.Server("https://fake-url") def test_init_server_model_bad_server_name_not_version_check(self): - # by default, it will just set the version to 2.3 server = TSC.Server("fake-url", use_server_version=False) def test_init_server_model_bad_server_name_do_version_check(self): - with self.assertRaises(MissingSchema): + with self.assertRaises(requests.exceptions.ConnectionError): server = TSC.Server("fake-url", use_server_version=True) def test_init_server_model_bad_server_name_not_version_check_random_options(self): - # by default, it will just set the version to 2.3 + # with self.assertRaises(MissingSchema): server = TSC.Server("fake-url", use_server_version=False, http_options={"foo": 1}) def test_init_server_model_bad_server_name_not_version_check_real_options(self): + # with self.assertRaises(ValueError): server = TSC.Server("fake-url", use_server_version=False, http_options={"verify": False}) def test_http_options_skip_ssl_works(self): @@ -62,6 +79,25 @@ def test_http_options_not_sequence_fails(self): with self.assertRaises(ValueError): server.add_http_options({1, 2, 3}) + def test_validate_connection_http(self): + url = "http://cookies.com" + server = TSC.Server(url) + server.validate_server_connection() + self.assertEqual(url, server.server_address) + + def test_validate_connection_https(self): + url = "https://cookies.com" + server = TSC.Server(url) + server.validate_server_connection() + self.assertEqual(url, server.server_address) + + def test_validate_connection_no_protocol(self): + url = "cookies.com" + fixed_url = "http://cookies.com" + server = TSC.Server(url) + server.validate_server_connection() + self.assertEqual(fixed_url, server.server_address) + class SessionTests(unittest.TestCase): test_header = {"x-test": "true"} @@ -74,6 +110,6 @@ def session_factory(): def test_session_factory_adds_headers(self): test_request_bin = "http://capture-this-with-mock.com" - with mock() as m: + with requests_mock.mock() as m: m.get(url="http://capture-this-with-mock.com/api/2.4/serverInfo", request_headers=SessionTests.test_header) server = TSC.Server(test_request_bin, use_server_version=True, session_factory=SessionTests.session_factory) From 24a55187cd9f165071ac53973a1404eb50ba0212 Mon Sep 17 00:00:00 2001 From: Jac Date: Thu, 6 Oct 2022 10:58:45 -0700 Subject: [PATCH 07/18] Jac/smoke tests (#1115) * smoke test for pypi --- .github/workflows/pypi-smoke-tests.yml | 36 ++++++++++++++++++ LICENSE | 2 +- samples/smoke_test.py | 8 ++++ tableauserverclient/__init__.py | 1 + .../server/endpoint/endpoint.py | 38 ++++++++++--------- versioneer.py | 0 6 files changed, 66 insertions(+), 19 deletions(-) create mode 100644 .github/workflows/pypi-smoke-tests.yml create mode 100644 samples/smoke_test.py mode change 100755 => 100644 versioneer.py diff --git a/.github/workflows/pypi-smoke-tests.yml b/.github/workflows/pypi-smoke-tests.yml new file mode 100644 index 000000000..eb6406573 --- /dev/null +++ b/.github/workflows/pypi-smoke-tests.yml @@ -0,0 +1,36 @@ +# This workflow will install TSC from pypi and validate that it runs. For more information see: +# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Pypi smoke tests + +on: + workflow_dispatch: + schedule: + - cron: 0 11 * * * # Every day at 11AM UTC (7AM EST) + +permissions: + contents: read + +jobs: + build: + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ['3.x'] + + runs-on: ${{ matrix.os }} + + steps: + - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: pip install + run: | + pip uninstall tableauserverclient + pip install tableauserverclient + - name: Launch app + run: | + python -c "import tableauserverclient as TSC + server = TSC.Server('http://example.com', use_server_version=False)" diff --git a/LICENSE b/LICENSE index 6222b2e80..22f90640f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2016 Tableau +Copyright (c) 2022 Tableau Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/samples/smoke_test.py b/samples/smoke_test.py new file mode 100644 index 000000000..f2dad1048 --- /dev/null +++ b/samples/smoke_test.py @@ -0,0 +1,8 @@ +# This sample verifies that tableau server client is installed +# and you can run it. It also shows the version of the client. + +import tableauserverclient as TSC + +server = TSC.Server("Fake-Server-Url", use_server_version=False) +print("Client details:") +print(TSC.server.endpoint.Endpoint._make_common_headers("fake-token", "any-content")) diff --git a/tableauserverclient/__init__.py b/tableauserverclient/__init__.py index 7c1e6d705..212540d84 100644 --- a/tableauserverclient/__init__.py +++ b/tableauserverclient/__init__.py @@ -1,3 +1,4 @@ +from ._version import get_versions from .namespace import NEW_NAMESPACE as DEFAULT_NAMESPACE from .models import ( BackgroundJobItem, diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index 3cdc49322..8f02ffd24 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -1,6 +1,6 @@ import requests import logging -from packaging.version import Version +from distutils.version import LooseVersion as Version from functools import wraps from xml.etree.ElementTree import ParseError from typing import Any, Callable, Dict, Optional, TYPE_CHECKING @@ -12,11 +12,11 @@ EndpointUnavailableError, ) from ..query import QuerySet -from ... import helpers -from ..._version import get_versions +from ... import helpers, get_versions -__TSC_VERSION__ = get_versions()["version"] -del get_versions +if TYPE_CHECKING: + from ..server import Server + from requests import Response logger = logging.getLogger("tableau.endpoint") @@ -25,11 +25,10 @@ XML_CONTENT_TYPE = "text/xml" JSON_CONTENT_TYPE = "application/json" -USERAGENT_HEADER = "User-Agent" - -if TYPE_CHECKING: - from ..server import Server - from requests import Response +CONTENT_TYPE_HEADER = "content-type" +TABLEAU_AUTH_HEADER = "x-tableau-auth" +CLIENT_VERSION_HEADER = "X-TableauServerClient-Version" +USER_AGENT_HEADER = "User-Agent" class Endpoint(object): @@ -38,12 +37,13 @@ def __init__(self, parent_srv: "Server"): @staticmethod def _make_common_headers(auth_token, content_type): + _client_version: Optional[str] = get_versions()["version"] headers = {} if auth_token is not None: - headers["x-tableau-auth"] = auth_token + headers[TABLEAU_AUTH_HEADER] = auth_token if content_type is not None: - headers["content-type"] = content_type - headers["User-Agent"] = "Tableau Server Client/{}".format(__TSC_VERSION__) + headers[CONTENT_TYPE_HEADER] = content_type + headers[USER_AGENT_HEADER] = "Tableau Server Client/{}".format(_client_version) return headers def _make_request( @@ -56,9 +56,9 @@ def _make_request( parameters: Optional[Dict[str, Any]] = None, ) -> "Response": parameters = parameters or {} + parameters.update(self.parent_srv.http_options) if "headers" not in parameters: parameters["headers"] = {} - parameters.update(self.parent_srv.http_options) parameters["headers"].update(Endpoint._make_common_headers(auth_token, content_type)) if content is not None: @@ -83,12 +83,14 @@ def _check_status(self, server_response, url: str = None): if server_response.status_code >= 500: raise InternalServerError(server_response, url) elif server_response.status_code not in Success_codes: + # todo: is an error reliably of content-type application/xml? try: raise ServerResponseError.from_response(server_response.content, self.parent_srv.namespace, url) except ParseError: - # This will happen if we get a non-success HTTP code that doesn't return an xml error object - # e.g metadata endpoints, 503 pages, totally different servers - # we convert this to a better exception and pass through the raw response body + # This will happen if we get a non-success HTTP code that + # doesn't return an xml error object (like metadata endpoints or 503 pages) + # we convert this to a better exception and pass through the raw + # response body raise NonXMLResponseError(server_response.content) except Exception: # anything else re-raise here @@ -186,7 +188,7 @@ def api(version): def _decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): - self.parent_srv.assert_at_least_version(version, self.__class__.__name__) + self.parent_srv.assert_at_least_version(version, "endpoint") return func(self, *args, **kwargs) return wrapper diff --git a/versioneer.py b/versioneer.py old mode 100755 new mode 100644 From bad5db9e020909dc7271303613d4f4c9483041d6 Mon Sep 17 00:00:00 2001 From: TrimPeachu <77048868+TrimPeachu@users.noreply.github.com> Date: Thu, 6 Oct 2022 20:03:32 +0200 Subject: [PATCH 08/18] Add permission control for Data Roles and Metrics (Issue #1063) (#1120) * Add permission control for Data Roles and Metrics (#1) * Add functions to control permissions * Add new resource types --- tableauserverclient/models/tableau_types.py | 2 ++ .../server/endpoint/projects_endpoint.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/tableauserverclient/models/tableau_types.py b/tableauserverclient/models/tableau_types.py index feaf02873..6ed77318f 100644 --- a/tableauserverclient/models/tableau_types.py +++ b/tableauserverclient/models/tableau_types.py @@ -11,9 +11,11 @@ class Resource: Database = "database" + Datarole = "datarole" Datasource = "datasource" Flow = "flow" Lens = "lens" + Metric = "metric" Project = "project" Table = "table" View = "view" diff --git a/tableauserverclient/server/endpoint/projects_endpoint.py b/tableauserverclient/server/endpoint/projects_endpoint.py index e268d2011..efdac7b93 100644 --- a/tableauserverclient/server/endpoint/projects_endpoint.py +++ b/tableauserverclient/server/endpoint/projects_endpoint.py @@ -99,6 +99,14 @@ def populate_workbook_default_permissions(self, item): def populate_datasource_default_permissions(self, item): self._default_permissions.populate_default_permissions(item, Resource.Datasource) + @api(version="3.2") + def populate_metric_default_permissions(self, item): + self._default_permissions.populate_default_permissions(item, Resource.Metric) + + @api(version="3.4") + def populate_datarole_default_permissions(self, item): + self._default_permissions.populate_default_permissions(item, Resource.Datarole) + @api(version="3.4") def populate_flow_default_permissions(self, item): self._default_permissions.populate_default_permissions(item, Resource.Flow) @@ -115,6 +123,14 @@ def update_workbook_default_permissions(self, item, rules): def update_datasource_default_permissions(self, item, rules): return self._default_permissions.update_default_permissions(item, rules, Resource.Datasource) + @api(version="3.2") + def update_metric_default_permissions(self, item, rules): + return self._default_permissions.update_default_permissions(item, rules, Resource.Metric) + + @api(version="3.4") + def update_datarole_default_permissions(self, item, rules): + return self._default_permissions.update_default_permissions(item, rules, Resource.Datarole) + @api(version="3.4") def update_flow_default_permissions(self, item, rules): return self._default_permissions.update_default_permissions(item, rules, Resource.Flow) @@ -130,6 +146,14 @@ def delete_workbook_default_permissions(self, item, rule): @api(version="2.1") def delete_datasource_default_permissions(self, item, rule): self._default_permissions.delete_default_permission(item, rule, Resource.Datasource) + + @api(version="3.2") + def delete_metric_default_permissions(self, item, rule): + self._default_permissions.delete_default_permission(item, rule, Resource.Metric) + + @api(version="3.4") + def delete_datarole_default_permissions(self, item, rule): + self._default_permissions.delete_default_permission(item, rule, Resource.Datarole) @api(version="3.4") def delete_flow_default_permissions(self, item, rule): From 14d1af6671adc454757f16cc67d853262a7e1f39 Mon Sep 17 00:00:00 2001 From: Jac Fitzgerald Date: Thu, 6 Oct 2022 11:04:15 -0700 Subject: [PATCH 09/18] run black for formatting --- tableauserverclient/server/endpoint/projects_endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tableauserverclient/server/endpoint/projects_endpoint.py b/tableauserverclient/server/endpoint/projects_endpoint.py index efdac7b93..7ccdcd775 100644 --- a/tableauserverclient/server/endpoint/projects_endpoint.py +++ b/tableauserverclient/server/endpoint/projects_endpoint.py @@ -146,7 +146,7 @@ def delete_workbook_default_permissions(self, item, rule): @api(version="2.1") def delete_datasource_default_permissions(self, item, rule): self._default_permissions.delete_default_permission(item, rule, Resource.Datasource) - + @api(version="3.2") def delete_metric_default_permissions(self, item, rule): self._default_permissions.delete_default_permission(item, rule, Resource.Metric) From 173c22ac292006a01edcab0ec1c49913bf8f1c4c Mon Sep 17 00:00:00 2001 From: Jac Date: Fri, 14 Oct 2022 11:27:28 -0700 Subject: [PATCH 10/18] fix check for being on master bring over fix from test repo --- .github/workflows/publish-pypi.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 33438bed8..467d23879 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -26,13 +26,15 @@ jobs: pip install -e .[test] build python -m build git describe --tag --dirty --always - - name: Publish distribution 📦 to Test PyPI + + - name: Publish distribution 📦 to Test PyPI # always run uses: pypa/gh-action-pypi-publish@release/v1 # license BSD-2 with: password: ${{ secrets.TEST_PYPI_API_TOKEN }} repository_url: https://test.pypi.org/legacy/ + - name: Publish distribution 📦 to PyPI - if: github.ref == 'refs/heads/master' + if: ${{ github.ref.name == 'master' || (github.event.ref.name == 'refs/head/master') }} uses: pypa/gh-action-pypi-publish@release/v1 # license BSD-2 with: password: ${{ secrets.PYPI_API_TOKEN }} From 0bb9dd5d2bdaf4adbca9427deb26c9469e961d9b Mon Sep 17 00:00:00 2001 From: Jac Date: Mon, 12 Dec 2022 20:25:24 -0800 Subject: [PATCH 11/18] mypy no-implicit-optional (#1151) --- tableauserverclient/models/datasource_item.py | 2 +- tableauserverclient/models/site_item.py | 6 +++--- tableauserverclient/models/workbook_item.py | 2 +- .../server/endpoint/datasources_endpoint.py | 8 ++++---- tableauserverclient/server/endpoint/endpoint.py | 2 +- tableauserverclient/server/endpoint/exceptions.py | 3 ++- .../server/endpoint/flows_endpoint.py | 2 +- .../server/endpoint/permissions_endpoint.py | 4 ++-- .../server/endpoint/schedules_endpoint.py | 8 ++++---- .../server/endpoint/users_endpoint.py | 12 ++++++------ .../server/endpoint/workbooks_endpoint.py | 6 +++--- tableauserverclient/server/request_factory.py | 4 ++-- 12 files changed, 30 insertions(+), 29 deletions(-) diff --git a/tableauserverclient/models/datasource_item.py b/tableauserverclient/models/datasource_item.py index 37ec1449a..4a7a74c4b 100644 --- a/tableauserverclient/models/datasource_item.py +++ b/tableauserverclient/models/datasource_item.py @@ -34,7 +34,7 @@ class AskDataEnablement: Disabled = "Disabled" SiteDefault = "SiteDefault" - def __init__(self, project_id: str, name: str = None) -> None: + def __init__(self, project_id: str, name: Optional[str] = None) -> None: self._ask_data_enablement = None self._certified = None self._certification_note = None diff --git a/tableauserverclient/models/site_item.py b/tableauserverclient/models/site_item.py index 8c9e8fe8e..e6bc3af24 100644 --- a/tableauserverclient/models/site_item.py +++ b/tableauserverclient/models/site_item.py @@ -50,9 +50,9 @@ def __init__( self, name: str, content_url: str, - admin_mode: str = None, - user_quota: int = None, - storage_quota: int = None, + admin_mode: Optional[str] = None, + user_quota: Optional[int] = None, + storage_quota: Optional[int] = None, disable_subscriptions: bool = False, subscribe_others_enabled: bool = True, revision_history_enabled: bool = False, diff --git a/tableauserverclient/models/workbook_item.py b/tableauserverclient/models/workbook_item.py index 0d18e770d..6d9a21b6b 100644 --- a/tableauserverclient/models/workbook_item.py +++ b/tableauserverclient/models/workbook_item.py @@ -33,7 +33,7 @@ class WorkbookItem(object): - def __init__(self, project_id: str, name: str = None, show_tabs: bool = False) -> None: + def __init__(self, project_id: str, name: Optional[str] = None, show_tabs: bool = False) -> None: self._connections = None self._content_url = None self._webpage_url = None diff --git a/tableauserverclient/server/endpoint/datasources_endpoint.py b/tableauserverclient/server/endpoint/datasources_endpoint.py index 022523aa4..5cea8fa5c 100644 --- a/tableauserverclient/server/endpoint/datasources_endpoint.py +++ b/tableauserverclient/server/endpoint/datasources_endpoint.py @@ -80,7 +80,7 @@ def baseurl(self) -> str: # Get all datasources @api(version="2.0") - def get(self, req_options: RequestOptions = None) -> Tuple[List[DatasourceItem], PaginationItem]: + def get(self, req_options: Optional[RequestOptions] = None) -> Tuple[List[DatasourceItem], PaginationItem]: logger.info("Querying all datasources on site") url = self.baseurl server_response = self.get_request(url, req_options) @@ -135,7 +135,7 @@ def delete(self, datasource_id: str) -> None: def download( self, datasource_id: str, - filepath: FilePath = None, + filepath: Optional[FilePath] = None, include_extract: bool = True, no_extract: Optional[bool] = None, ) -> str: @@ -234,8 +234,8 @@ def publish( datasource_item: DatasourceItem, file: PathOrFile, mode: str, - connection_credentials: ConnectionCredentials = None, - connections: Sequence[ConnectionItem] = None, + connection_credentials: Optional[ConnectionCredentials] = None, + connections: Optional[Sequence[ConnectionItem]] = None, as_job: bool = False, ) -> Union[DatasourceItem, JobItem]: diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index a836b000d..e04acc595 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -78,7 +78,7 @@ def _make_request( return server_response - def _check_status(self, server_response, url: str = None): + def _check_status(self, server_response, url: Optional[str] = None): if server_response.status_code >= 500: raise InternalServerError(server_response, url) elif server_response.status_code not in Success_codes: diff --git a/tableauserverclient/server/endpoint/exceptions.py b/tableauserverclient/server/endpoint/exceptions.py index 3ce0d5e92..d7b1d5ad2 100644 --- a/tableauserverclient/server/endpoint/exceptions.py +++ b/tableauserverclient/server/endpoint/exceptions.py @@ -1,4 +1,5 @@ from defusedxml.ElementTree import fromstring +from typing import Optional class TableauError(Exception): @@ -33,7 +34,7 @@ def from_response(cls, resp, ns, url=None): class InternalServerError(TableauError): - def __init__(self, server_response, request_url: str = None): + def __init__(self, server_response, request_url: Optional[str] = None): self.code = server_response.status_code self.content = server_response.content self.url = request_url or "server" diff --git a/tableauserverclient/server/endpoint/flows_endpoint.py b/tableauserverclient/server/endpoint/flows_endpoint.py index 2c54d17c4..931c85d06 100644 --- a/tableauserverclient/server/endpoint/flows_endpoint.py +++ b/tableauserverclient/server/endpoint/flows_endpoint.py @@ -94,7 +94,7 @@ def delete(self, flow_id: str) -> None: # Download 1 flow by id @api(version="3.3") - def download(self, flow_id: str, filepath: FilePath = None) -> str: + def download(self, flow_id: str, filepath: Optional[FilePath] = None) -> str: if not flow_id: error = "Flow ID undefined." raise ValueError(error) diff --git a/tableauserverclient/server/endpoint/permissions_endpoint.py b/tableauserverclient/server/endpoint/permissions_endpoint.py index f7c2f9f13..e3e9af2a6 100644 --- a/tableauserverclient/server/endpoint/permissions_endpoint.py +++ b/tableauserverclient/server/endpoint/permissions_endpoint.py @@ -6,7 +6,7 @@ from .exceptions import MissingRequiredFieldError from ...models import TableauItem -from typing import Callable, TYPE_CHECKING, List, Union +from typing import Optional, Callable, TYPE_CHECKING, List, Union logger = logging.getLogger(__name__) @@ -82,7 +82,7 @@ def permission_fetcher(): item._set_permissions(permission_fetcher) logger.info("Populated permissions for item (ID: {0})".format(item.id)) - def _get_permissions(self, item: TableauItem, req_options: "RequestOptions" = None): + def _get_permissions(self, item: TableauItem, req_options: Optional["RequestOptions"] = None): url = "{0}/{1}/permissions".format(self.owner_baseurl(), item.id) server_response = self.get_request(url, req_options) permissions = PermissionsRule.from_response(server_response.content, self.parent_srv.namespace) diff --git a/tableauserverclient/server/endpoint/schedules_endpoint.py b/tableauserverclient/server/endpoint/schedules_endpoint.py index 21c828989..65a55bcb6 100644 --- a/tableauserverclient/server/endpoint/schedules_endpoint.py +++ b/tableauserverclient/server/endpoint/schedules_endpoint.py @@ -85,10 +85,10 @@ def create(self, schedule_item: ScheduleItem) -> ScheduleItem: def add_to_schedule( self, schedule_id: str, - workbook: "WorkbookItem" = None, - datasource: "DatasourceItem" = None, - flow: "FlowItem" = None, - task_type: str = None, + workbook: Optional["WorkbookItem"] = None, + datasource: Optional["DatasourceItem"] = None, + flow: Optional["FlowItem"] = None, + task_type: Optional[str] = None, ) -> List[AddResponse]: # There doesn't seem to be a good reason to allow one item of each type? diff --git a/tableauserverclient/server/endpoint/users_endpoint.py b/tableauserverclient/server/endpoint/users_endpoint.py index 28406ab71..3faf4d173 100644 --- a/tableauserverclient/server/endpoint/users_endpoint.py +++ b/tableauserverclient/server/endpoint/users_endpoint.py @@ -21,7 +21,7 @@ def baseurl(self) -> str: # Gets all users @api(version="2.0") - def get(self, req_options: RequestOptions = None) -> Tuple[List[UserItem], PaginationItem]: + def get(self, req_options: Optional[RequestOptions] = None) -> Tuple[List[UserItem], PaginationItem]: logger.info("Querying all users on site") if req_options is None: @@ -47,7 +47,7 @@ def get_by_id(self, user_id: str) -> UserItem: # Update user @api(version="2.0") - def update(self, user_item: UserItem, password: str = None) -> UserItem: + def update(self, user_item: UserItem, password: Optional[str] = None) -> UserItem: if not user_item.id: error = "User item missing ID." raise MissingRequiredFieldError(error) @@ -122,7 +122,7 @@ def create_from_file(self, filepath: str) -> Tuple[List[UserItem], List[Tuple[Us # Get workbooks for user @api(version="2.0") - def populate_workbooks(self, user_item: UserItem, req_options: RequestOptions = None) -> None: + def populate_workbooks(self, user_item: UserItem, req_options: Optional[RequestOptions] = None) -> None: if not user_item.id: error = "User item missing ID." raise MissingRequiredFieldError(error) @@ -133,7 +133,7 @@ def wb_pager(): user_item._set_workbooks(wb_pager) def _get_wbs_for_user( - self, user_item: UserItem, req_options: RequestOptions = None + self, user_item: UserItem, req_options: Optional[RequestOptions] = None ) -> Tuple[List[WorkbookItem], PaginationItem]: url = "{0}/{1}/workbooks".format(self.baseurl, user_item.id) server_response = self.get_request(url, req_options) @@ -147,7 +147,7 @@ def populate_favorites(self, user_item: UserItem) -> None: # Get groups for user @api(version="3.7") - def populate_groups(self, user_item: UserItem, req_options: RequestOptions = None) -> None: + def populate_groups(self, user_item: UserItem, req_options: Optional[RequestOptions] = None) -> None: if not user_item.id: error = "User item missing ID." raise MissingRequiredFieldError(error) @@ -161,7 +161,7 @@ def groups_for_user_pager(): user_item._set_groups(groups_for_user_pager) def _get_groups_for_user( - self, user_item: UserItem, req_options: RequestOptions = None + self, user_item: UserItem, req_options: Optional[RequestOptions] = None ) -> Tuple[List[GroupItem], PaginationItem]: url = "{0}/{1}/groups".format(self.baseurl, user_item.id) server_response = self.get_request(url, req_options) diff --git a/tableauserverclient/server/endpoint/workbooks_endpoint.py b/tableauserverclient/server/endpoint/workbooks_endpoint.py index 4d7a4a2b5..272a1d05d 100644 --- a/tableauserverclient/server/endpoint/workbooks_endpoint.py +++ b/tableauserverclient/server/endpoint/workbooks_endpoint.py @@ -178,7 +178,7 @@ def update_connection(self, workbook_item: WorkbookItem, connection_item: Connec def download( self, workbook_id: str, - filepath: FilePath = None, + filepath: Optional[FilePath] = None, include_extract: bool = True, no_extract: Optional[bool] = None, ) -> str: @@ -250,7 +250,7 @@ def connection_fetcher(): logger.info("Populated connections for workbook (ID: {0})".format(workbook_item.id)) def _get_workbook_connections( - self, workbook_item: WorkbookItem, req_options: "RequestOptions" = None + self, workbook_item: WorkbookItem, req_options: Optional["RequestOptions"] = None ) -> List[ConnectionItem]: url = "{0}/{1}/connections".format(self.baseurl, workbook_item.id) server_response = self.get_request(url, req_options) @@ -259,7 +259,7 @@ def _get_workbook_connections( # Get the pdf of the entire workbook if its tabs are enabled, pdf of the default view if its tabs are disabled @api(version="3.4") - def populate_pdf(self, workbook_item: WorkbookItem, req_options: "RequestOptions" = None) -> None: + def populate_pdf(self, workbook_item: WorkbookItem, req_options: Optional["RequestOptions"] = None) -> None: if not workbook_item.id: error = "Workbook item missing ID." raise MissingRequiredFieldError(error) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index aad8ca074..142297aa0 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -575,7 +575,7 @@ def add_flow_req(self, id_: Optional[str], task_type: str = TaskItem.Type.RunFlo class SiteRequest(object): - def update_req(self, site_item: "SiteItem", parent_srv: "Server" = None): + def update_req(self, site_item: "SiteItem", parent_srv: Optional["Server"] = None): xml_request = ET.Element("tsRequest") site_element = ET.SubElement(xml_request, "site") if site_item.name: @@ -683,7 +683,7 @@ def update_req(self, site_item: "SiteItem", parent_srv: "Server" = None): return ET.tostring(xml_request) # server: the site request model changes based on api version - def create_req(self, site_item: "SiteItem", parent_srv: "Server" = None): + def create_req(self, site_item: "SiteItem", parent_srv: Optional["Server"] = None): xml_request = ET.Element("tsRequest") site_element = ET.SubElement(xml_request, "site") site_element.attrib["name"] = site_item.name From 504d9d4e26cdc6890b4c524c6d332dc4c8fd49ef Mon Sep 17 00:00:00 2001 From: Jac Date: Wed, 14 Dec 2022 17:31:27 -0800 Subject: [PATCH 12/18] add option to pass specific datasources (#1150) * add option to pass specific datasources * mypy no-implicit-optional --- tableauserverclient/server/endpoint/workbooks_endpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tableauserverclient/server/endpoint/workbooks_endpoint.py b/tableauserverclient/server/endpoint/workbooks_endpoint.py index 272a1d05d..163bb8c71 100644 --- a/tableauserverclient/server/endpoint/workbooks_endpoint.py +++ b/tableauserverclient/server/endpoint/workbooks_endpoint.py @@ -117,12 +117,13 @@ def create_extract( # delete all the extracts on 1 workbook @api(version="3.3") - def delete_extract(self, workbook_item: WorkbookItem, includeAll: bool = True) -> None: + def delete_extract(self, workbook_item: WorkbookItem, includeAll: bool = True, datasources=None) -> JobItem: id_ = getattr(workbook_item, "id", workbook_item) url = "{0}/{1}/deleteExtract".format(self.baseurl, id_) - datasource_req = RequestFactory.Workbook.embedded_extract_req(includeAll, None) + datasource_req = RequestFactory.Workbook.embedded_extract_req(includeAll, datasources) server_response = self.post_request(url, datasource_req) new_job = JobItem.from_response(server_response.content, self.parent_srv.namespace)[0] + return new_job # Delete 1 workbook by id @api(version="2.0") From 16b1bdd7dc5cb190665861f4e0fbf1b054975ea0 Mon Sep 17 00:00:00 2001 From: Jac Date: Fri, 6 Jan 2023 15:39:59 -0800 Subject: [PATCH 13/18] allow user agent to be set by caller (#1166) --- .../server/endpoint/endpoint.py | 52 ++++++++++++------- tableauserverclient/server/server.py | 21 ++++---- test/http/test_http_requests.py | 6 +-- test/test_endpoint.py | 18 +++++++ 4 files changed, 67 insertions(+), 30 deletions(-) diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index e04acc595..b1a42b20c 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -3,7 +3,7 @@ from packaging.version import Version from functools import wraps from xml.etree.ElementTree import ParseError -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Mapping from .exceptions import ( ServerResponseError, @@ -35,15 +35,35 @@ def __init__(self, parent_srv: "Server"): self.parent_srv = parent_srv @staticmethod - def _make_common_headers(auth_token, content_type): - _client_version: Optional[str] = get_versions()["version"] - headers = {} + def set_parameters(http_options, auth_token, content, content_type, parameters) -> Dict[str, Any]: + parameters = parameters or {} + parameters.update(http_options) + if "headers" not in parameters: + parameters["headers"] = {} + if auth_token is not None: - headers[TABLEAU_AUTH_HEADER] = auth_token + parameters["headers"][TABLEAU_AUTH_HEADER] = auth_token if content_type is not None: - headers[CONTENT_TYPE_HEADER] = content_type - headers[USER_AGENT_HEADER] = "Tableau Server Client/{}".format(_client_version) - return headers + parameters["headers"][CONTENT_TYPE_HEADER] = content_type + + Endpoint.set_user_agent(parameters) + if content is not None: + parameters["data"] = content + return parameters or {} + + @staticmethod + def set_user_agent(parameters): + if USER_AGENT_HEADER not in parameters["headers"]: + if USER_AGENT_HEADER in parameters: + parameters["headers"][USER_AGENT_HEADER] = parameters[USER_AGENT_HEADER] + else: + # only set the TSC user agent if not already populated + _client_version: Optional[str] = get_versions()["version"] + parameters["headers"][USER_AGENT_HEADER] = "Tableau Server Client/{}".format(_client_version) + + # result: parameters["headers"]["User-Agent"] is set + # return explicitly for testing only + return parameters def _make_request( self, @@ -54,18 +74,14 @@ def _make_request( content_type: Optional[str] = None, parameters: Optional[Dict[str, Any]] = None, ) -> "Response": - parameters = parameters or {} - if "headers" not in parameters: - parameters["headers"] = {} - parameters.update(self.parent_srv.http_options) - parameters["headers"].update(Endpoint._make_common_headers(auth_token, content_type)) - - if content is not None: - parameters["data"] = content + parameters = Endpoint.set_parameters( + self.parent_srv.http_options, auth_token, content, content_type, parameters + ) - logger.debug("request {}, url: {}".format(method.__name__, url)) + logger.debug("request {}, url: {}".format(method, url)) if content: - logger.debug("request content: {}".format(helpers.strings.redact_xml(content[:1000]))) + redacted = helpers.strings.redact_xml(content[:1000]) + logger.debug("request content: {}".format(redacted)) server_response = method(url, **parameters) self._check_status(server_response, url) diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 5e2dacf33..d2a8b933b 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -31,6 +31,7 @@ Fileuploads, FlowRuns, Metrics, + Endpoint, ) from .endpoint.exceptions import ( ServerInfoEndpointNotFoundError, @@ -62,6 +63,10 @@ def __init__(self, server_address, use_server_version=False, http_options=None, self._site_id = None self._user_id = None + # TODO: this needs to change to default to https, but without breaking existing code + if not server_address.startswith("http://") and not server_address.startswith("https://"): + server_address = "http://" + server_address + self._server_address: str = server_address self._session_factory = session_factory or requests.session @@ -96,21 +101,17 @@ def __init__(self, server_address, use_server_version=False, http_options=None, if http_options: self.add_http_options(http_options) - self.validate_server_connection() + self.validate_connection_settings() # does not make an actual outgoing request self.version = default_server_version if use_server_version: self.use_server_version() # this makes a server call - def validate_server_connection(self): + def validate_connection_settings(self): try: - if not self._server_address.startswith("http://") and not self._server_address.startswith("https://"): - self._server_address = "http://" + self._server_address - self._session.prepare_request( - requests.Request("GET", url=self._server_address, params=self._http_options) - ) + Endpoint(self).set_parameters(self._http_options, None, None, None, None) except Exception as req_ex: - raise ValueError("Invalid server initialization", req_ex) + raise ValueError("Server connection settings not valid", req_ex) def __repr__(self): return " [Connection: {}, {}]".format(self.baseurl, self.server_info.serverInfo) @@ -143,10 +144,12 @@ def _set_auth(self, site_id, user_id, auth_token): self._auth_token = auth_token def _get_legacy_version(self): - response = self._session.get(self.server_address + "/auth?format=xml") + dest = Endpoint(self) + response = dest._make_request(method=self.session.get, url=self.server_address + "/auth?format=xml") try: info_xml = fromstring(response.content) except ParseError as parseError: + logging.getLogger("TSC.server").info(parseError) logging.getLogger("TSC.server").info( "Could not read server version info. The server may not be running or configured." ) diff --git a/test/http/test_http_requests.py b/test/http/test_http_requests.py index e96879277..bf9292dec 100644 --- a/test/http/test_http_requests.py +++ b/test/http/test_http_requests.py @@ -82,20 +82,20 @@ def test_http_options_not_sequence_fails(self): def test_validate_connection_http(self): url = "http://cookies.com" server = TSC.Server(url) - server.validate_server_connection() + server.validate_connection_settings() self.assertEqual(url, server.server_address) def test_validate_connection_https(self): url = "https://cookies.com" server = TSC.Server(url) - server.validate_server_connection() + server.validate_connection_settings() self.assertEqual(url, server.server_address) def test_validate_connection_no_protocol(self): url = "cookies.com" fixed_url = "http://cookies.com" server = TSC.Server(url) - server.validate_server_connection() + server.validate_connection_settings() self.assertEqual(fixed_url, server.server_address) diff --git a/test/test_endpoint.py b/test/test_endpoint.py index e583a9188..5b6324cab 100644 --- a/test/test_endpoint.py +++ b/test/test_endpoint.py @@ -38,3 +38,21 @@ class FakeResponse(object): server_response = FakeResponse() log = endpoint.log_response_safely(server_response) self.assertTrue(log.find("[Truncated File Contents]") > 0, log) + + def test_set_user_agent_from_options_headers(self): + params = {"User-Agent": "1", "headers": {"User-Agent": "2"}} + result = TSC.server.Endpoint.set_user_agent(params) + # it should use the value under 'headers' if more than one is given + print(result) + print(result["headers"]["User-Agent"]) + self.assertTrue(result["headers"]["User-Agent"] == "2") + + def test_set_user_agent_from_options(self): + params = {"headers": {"User-Agent": "2"}} + result = TSC.server.Endpoint.set_user_agent(params) + self.assertTrue(result["headers"]["User-Agent"] == "2") + + def test_set_user_agent_when_blank(self): + params = {"headers": {}} + result = TSC.server.Endpoint.set_user_agent(params) + self.assertTrue(result["headers"]["User-Agent"].startswith("Tableau Server Client")) From 7ceed6c023b33de15ccfc5e42e0a0c13599b8b53 Mon Sep 17 00:00:00 2001 From: Stu Tomlinson Date: Tue, 17 Jan 2023 22:50:09 +0000 Subject: [PATCH 14/18] Fix issues with connections publishing workbooks (#1171) Allow publishing using connection credentials on ConnectionItem class without ConnectionCredentials instance, as documented Accept empty string for username or password in connection credentials Avoid Tableau Server internal server error when publishing with empty connection list by setting connections to None --- tableauserverclient/server/request_factory.py | 14 +++-- test/test_workbook.py | 51 +++++++++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 142297aa0..209626051 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -7,6 +7,7 @@ from tableauserverclient.models.metric_item import MetricItem +from ..models import ConnectionCredentials from ..models import ConnectionItem from ..models import DataAlertItem from ..models import FlowItem @@ -55,6 +56,13 @@ def _add_connections_element(connections_element, connection): connection_element.attrib["serverPort"] = connection.server_port if connection.connection_credentials: connection_credentials = connection.connection_credentials + elif connection.username is not None and connection.password is not None and connection.embed_password is not None: + connection_credentials = ConnectionCredentials( + connection.username, connection.password, embed=connection.embed_password + ) + else: + connection_credentials = None + if connection_credentials: _add_credentials_element(connection_element, connection_credentials) @@ -66,7 +74,7 @@ def _add_hiddenview_element(views_element, view_name): def _add_credentials_element(parent_element, connection_credentials): credentials_element = ET.SubElement(parent_element, "connectionCredentials") - if not connection_credentials.password or not connection_credentials.name: + if connection_credentials.password is None or connection_credentials.name is None: raise ValueError("Connection Credentials must have a name and password") credentials_element.attrib["name"] = connection_credentials.name credentials_element.attrib["password"] = connection_credentials.password @@ -177,7 +185,7 @@ def _generate_xml(self, datasource_item, connection_credentials=None, connection if connection_credentials is not None: _add_credentials_element(datasource_element, connection_credentials) - if connections is not None: + if connections is not None and len(connections) > 0: connections_element = ET.SubElement(datasource_element, "connections") for connection in connections: _add_connections_element(connections_element, connection) @@ -899,7 +907,7 @@ def _generate_xml( if connection_credentials is not None: _add_credentials_element(workbook_element, connection_credentials) - if connections is not None: + if connections is not None and len(connections) > 0: connections_element = ET.SubElement(workbook_element, "connections") for connection in connections: _add_connections_element(connections_element, connection) diff --git a/test/test_workbook.py b/test/test_workbook.py index db7f0723b..ba21dc195 100644 --- a/test/test_workbook.py +++ b/test/test_workbook.py @@ -748,6 +748,30 @@ def test_publish_multi_connection(self) -> None: self.assertEqual(connection_results[1].get("serverAddress", None), "pgsql.test.com") self.assertEqual(connection_results[1].find("connectionCredentials").get("password", None), "secret") # type: ignore[union-attr] + def test_publish_multi_connection_flat(self) -> None: + new_workbook = TSC.WorkbookItem( + name="Sample", show_tabs=False, project_id="ee8c6e70-43b6-11e6-af4f-f7b0d8e20760" + ) + connection1 = TSC.ConnectionItem() + connection1.server_address = "mysql.test.com" + connection1.username = "test" + connection1.password = "secret" + connection1.embed_password = True + connection2 = TSC.ConnectionItem() + connection2.server_address = "pgsql.test.com" + connection2.username = "test" + connection2.password = "secret" + connection2.embed_password = True + + response = RequestFactory.Workbook._generate_xml(new_workbook, connections=[connection1, connection2]) + # Can't use ConnectionItem parser due to xml namespace problems + connection_results = fromstring(response).findall(".//connection") + + self.assertEqual(connection_results[0].get("serverAddress", None), "mysql.test.com") + self.assertEqual(connection_results[0].find("connectionCredentials").get("name", None), "test") # type: ignore[union-attr] + self.assertEqual(connection_results[1].get("serverAddress", None), "pgsql.test.com") + self.assertEqual(connection_results[1].find("connectionCredentials").get("password", None), "secret") # type: ignore[union-attr] + def test_publish_single_connection(self) -> None: new_workbook = TSC.WorkbookItem( name="Sample", show_tabs=False, project_id="ee8c6e70-43b6-11e6-af4f-f7b0d8e20760" @@ -762,6 +786,33 @@ def test_publish_single_connection(self) -> None: self.assertEqual(credentials[0].get("password", None), "secret") self.assertEqual(credentials[0].get("embed", None), "true") + def test_publish_single_connection_username_none(self) -> None: + new_workbook = TSC.WorkbookItem( + name="Sample", show_tabs=False, project_id="ee8c6e70-43b6-11e6-af4f-f7b0d8e20760" + ) + connection_creds = TSC.ConnectionCredentials(None, "secret", True) + + self.assertRaises( + ValueError, + RequestFactory.Workbook._generate_xml, + new_workbook, + connection_credentials=connection_creds, + ) + + def test_publish_single_connection_username_empty(self) -> None: + new_workbook = TSC.WorkbookItem( + name="Sample", show_tabs=False, project_id="ee8c6e70-43b6-11e6-af4f-f7b0d8e20760" + ) + connection_creds = TSC.ConnectionCredentials("", "secret", True) + + response = RequestFactory.Workbook._generate_xml(new_workbook, connection_credentials=connection_creds) + # Can't use ConnectionItem parser due to xml namespace problems + credentials = fromstring(response).findall(".//connectionCredentials") + self.assertEqual(len(credentials), 1) + self.assertEqual(credentials[0].get("name", None), "") + self.assertEqual(credentials[0].get("password", None), "secret") + self.assertEqual(credentials[0].get("embed", None), "true") + def test_credentials_and_multi_connect_raises_exception(self) -> None: new_workbook = TSC.WorkbookItem( name="Sample", show_tabs=False, project_id="ee8c6e70-43b6-11e6-af4f-f7b0d8e20760" From a8c663ea81898273b7a6440659a1788f618e7741 Mon Sep 17 00:00:00 2001 From: Stu Tomlinson Date: Fri, 20 Jan 2023 08:40:20 +0000 Subject: [PATCH 15/18] Allow download to file-like objects (#1172) --- .../server/endpoint/datasources_endpoint.py | 91 +++++---------- .../server/endpoint/flows_endpoint.py | 106 +++++++++++++----- .../server/endpoint/workbooks_endpoint.py | 87 ++++++-------- test/assets/SampleFlow.tfl | Bin 0 -> 1884 bytes test/assets/flow_publish.xml | 10 ++ test/test_datasource.py | 12 ++ test/test_flow.py | 83 +++++++++++++- test/test_workbook.py | 10 ++ 8 files changed, 252 insertions(+), 147 deletions(-) create mode 100644 test/assets/SampleFlow.tfl create mode 100644 test/assets/flow_publish.xml diff --git a/tableauserverclient/server/endpoint/datasources_endpoint.py b/tableauserverclient/server/endpoint/datasources_endpoint.py index 5cea8fa5c..9df7edfc8 100644 --- a/tableauserverclient/server/endpoint/datasources_endpoint.py +++ b/tableauserverclient/server/endpoint/datasources_endpoint.py @@ -31,22 +31,9 @@ ) from ...models import ConnectionCredentials, RevisionItem from ...models.job_item import JobItem -from ...models import ConnectionCredentials -io_types = (io.BytesIO, io.BufferedReader) - -from pathlib import Path -from typing import ( - List, - Mapping, - Optional, - Sequence, - Tuple, - TYPE_CHECKING, - Union, -) - -io_types = (io.BytesIO, io.BufferedReader) +io_types_r = (io.BytesIO, io.BufferedReader) +io_types_w = (io.BytesIO, io.BufferedWriter) # The maximum size of a file that can be published in a single request is 64MB FILESIZE_LIMIT = 1024 * 1024 * 64 # 64MB @@ -61,8 +48,10 @@ from .schedules_endpoint import AddResponse FilePath = Union[str, os.PathLike] -FileObject = Union[io.BufferedReader, io.BytesIO] -PathOrFile = Union[FilePath, FileObject] +FileObjectR = Union[io.BufferedReader, io.BytesIO] +FileObjectW = Union[io.BufferedWriter, io.BytesIO] +PathOrFileR = Union[FilePath, FileObjectR] +PathOrFileW = Union[FilePath, FileObjectW] class Datasources(QuerysetEndpoint): @@ -135,39 +124,11 @@ def delete(self, datasource_id: str) -> None: def download( self, datasource_id: str, - filepath: Optional[FilePath] = None, + filepath: Optional[PathOrFileW] = None, include_extract: bool = True, no_extract: Optional[bool] = None, ) -> str: - if not datasource_id: - error = "Datasource ID undefined." - raise ValueError(error) - url = "{0}/{1}/content".format(self.baseurl, datasource_id) - - if no_extract is False or no_extract is True: - import warnings - - warnings.warn( - "no_extract is deprecated, use include_extract instead.", - DeprecationWarning, - ) - include_extract = not no_extract - - if not include_extract: - url += "?includeExtract=False" - - with closing(self.get_request(url, parameters={"stream": True})) as server_response: - _, params = cgi.parse_header(server_response.headers["Content-Disposition"]) - filename = to_filename(os.path.basename(params["filename"])) - - download_path = make_download_path(filepath, filename) - - with open(download_path, "wb") as f: - for chunk in server_response.iter_content(1024): # 1KB - f.write(chunk) - - logger.info("Downloaded datasource to {0} (ID: {1})".format(download_path, datasource_id)) - return os.path.abspath(download_path) + return self.download_revision(datasource_id, None, filepath, include_extract, no_extract) # Update datasource @api(version="2.0") @@ -232,7 +193,7 @@ def delete_extract(self, datasource_item: DatasourceItem) -> None: def publish( self, datasource_item: DatasourceItem, - file: PathOrFile, + file: PathOrFileR, mode: str, connection_credentials: Optional[ConnectionCredentials] = None, connections: Optional[Sequence[ConnectionItem]] = None, @@ -255,8 +216,7 @@ def publish( error = "Only {} files can be published as datasources.".format(", ".join(ALLOWED_FILE_EXTENSIONS)) raise ValueError(error) - elif isinstance(file, io_types): - + elif isinstance(file, io_types_r): if not datasource_item.name: error = "Datasource item must have a name when passing a file object" raise ValueError(error) @@ -302,7 +262,7 @@ def publish( if isinstance(file, (Path, str)): with open(file, "rb") as f: file_contents = f.read() - elif isinstance(file, io_types): + elif isinstance(file, io_types_r): file_contents = file.read() else: raise TypeError("file should be a filepath or file object.") @@ -433,14 +393,17 @@ def download_revision( self, datasource_id: str, revision_number: str, - filepath: Optional[PathOrFile] = None, + filepath: Optional[PathOrFileW] = None, include_extract: bool = True, no_extract: Optional[bool] = None, - ) -> str: + ) -> PathOrFileW: if not datasource_id: error = "Datasource ID undefined." raise ValueError(error) - url = "{0}/{1}/revisions/{2}/content".format(self.baseurl, datasource_id, revision_number) + if revision_number is None: + url = "{0}/{1}/content".format(self.baseurl, datasource_id) + else: + url = "{0}/{1}/revisions/{2}/content".format(self.baseurl, datasource_id, revision_number) if no_extract is False or no_extract is True: import warnings @@ -455,18 +418,22 @@ def download_revision( with closing(self.get_request(url, parameters={"stream": True})) as server_response: _, params = cgi.parse_header(server_response.headers["Content-Disposition"]) - filename = to_filename(os.path.basename(params["filename"])) - - download_path = make_download_path(filepath, filename) - - with open(download_path, "wb") as f: + if isinstance(filepath, io_types_w): for chunk in server_response.iter_content(1024): # 1KB - f.write(chunk) + filepath.write(chunk) + return_path = filepath + else: + filename = to_filename(os.path.basename(params["filename"])) + download_path = make_download_path(filepath, filename) + with open(download_path, "wb") as f: + for chunk in server_response.iter_content(1024): # 1KB + f.write(chunk) + return_path = os.path.abspath(download_path) logger.info( - "Downloaded datasource revision {0} to {1} (ID: {2})".format(revision_number, download_path, datasource_id) + "Downloaded datasource revision {0} to {1} (ID: {2})".format(revision_number, return_path, datasource_id) ) - return os.path.abspath(download_path) + return return_path @api(version="2.3") def delete_revision(self, datasource_id: str, revision_number: str) -> None: diff --git a/tableauserverclient/server/endpoint/flows_endpoint.py b/tableauserverclient/server/endpoint/flows_endpoint.py index 931c85d06..5b182111b 100644 --- a/tableauserverclient/server/endpoint/flows_endpoint.py +++ b/tableauserverclient/server/endpoint/flows_endpoint.py @@ -1,8 +1,10 @@ import cgi import copy +import io import logging import os from contextlib import closing +from pathlib import Path from typing import Iterable, List, Optional, TYPE_CHECKING, Tuple, Union from .dqw_endpoint import _DataQualityWarningEndpoint @@ -11,9 +13,17 @@ from .permissions_endpoint import _PermissionsEndpoint from .resource_tagger import _ResourceTagger from .. import RequestFactory, FlowItem, PaginationItem, ConnectionItem -from ...filesys_helpers import to_filename, make_download_path +from ...filesys_helpers import ( + to_filename, + make_download_path, + get_file_type, + get_file_object_size, +) from ...models.job_item import JobItem +io_types_r = (io.BytesIO, io.BufferedReader) +io_types_w = (io.BytesIO, io.BufferedWriter) + # The maximum size of a file that can be published in a single request is 64MB FILESIZE_LIMIT = 1024 * 1024 * 64 # 64MB @@ -29,6 +39,10 @@ FilePath = Union[str, os.PathLike] +FileObjectR = Union[io.BufferedReader, io.BytesIO] +FileObjectW = Union[io.BufferedWriter, io.BytesIO] +PathOrFileR = Union[FilePath, FileObjectR] +PathOrFileW = Union[FilePath, FileObjectW] class Flows(QuerysetEndpoint): @@ -94,7 +108,7 @@ def delete(self, flow_id: str) -> None: # Download 1 flow by id @api(version="3.3") - def download(self, flow_id: str, filepath: Optional[FilePath] = None) -> str: + def download(self, flow_id: str, filepath: Optional[PathOrFileW] = None) -> PathOrFileW: if not flow_id: error = "Flow ID undefined." raise ValueError(error) @@ -102,16 +116,20 @@ def download(self, flow_id: str, filepath: Optional[FilePath] = None) -> str: with closing(self.get_request(url, parameters={"stream": True})) as server_response: _, params = cgi.parse_header(server_response.headers["Content-Disposition"]) - filename = to_filename(os.path.basename(params["filename"])) - - download_path = make_download_path(filepath, filename) - - with open(download_path, "wb") as f: + if isinstance(filepath, io_types_w): for chunk in server_response.iter_content(1024): # 1KB - f.write(chunk) - - logger.info("Downloaded flow to {0} (ID: {1})".format(download_path, flow_id)) - return os.path.abspath(download_path) + filepath.write(chunk) + return_path = filepath + else: + filename = to_filename(os.path.basename(params["filename"])) + download_path = make_download_path(filepath, filename) + with open(download_path, "wb") as f: + for chunk in server_response.iter_content(1024): # 1KB + f.write(chunk) + return_path = os.path.abspath(download_path) + + logger.info("Downloaded flow to {0} (ID: {1})".format(return_path, flow_id)) + return return_path # Update flow @api(version="3.3") @@ -153,24 +171,49 @@ def refresh(self, flow_item: FlowItem) -> JobItem: # Publish flow @api(version="3.3") def publish( - self, flow_item: FlowItem, file_path: FilePath, mode: str, connections: Optional[List[ConnectionItem]] = None + self, flow_item: FlowItem, file: PathOrFileR, mode: str, connections: Optional[List[ConnectionItem]] = None ) -> FlowItem: - if not os.path.isfile(file_path): - error = "File path does not lead to an existing file." - raise IOError(error) if not mode or not hasattr(self.parent_srv.PublishMode, mode): error = "Invalid mode defined." raise ValueError(error) - filename = os.path.basename(file_path) - file_extension = os.path.splitext(filename)[1][1:] + if isinstance(file, (str, os.PathLike)): + if not os.path.isfile(file): + error = "File path does not lead to an existing file." + raise IOError(error) + + filename = os.path.basename(file) + file_extension = os.path.splitext(filename)[1][1:] + file_size = os.path.getsize(file) + + # If name is not defined, grab the name from the file to publish + if not flow_item.name: + flow_item.name = os.path.splitext(filename)[0] + if file_extension not in ALLOWED_FILE_EXTENSIONS: + error = "Only {} files can be published as flows.".format(", ".join(ALLOWED_FILE_EXTENSIONS)) + raise ValueError(error) + + elif isinstance(file, io_types_r): + if not flow_item.name: + error = "Flow item must have a name when passing a file object" + raise ValueError(error) + + file_type = get_file_type(file) + if file_type == "zip": + file_extension = "tflx" + elif file_type == "xml": + file_extension = "tfl" + else: + error = "Unsupported file type {}!".format(file_type) + raise ValueError(error) + + # Generate filename for file object. + # This is needed when publishing the flow in a single request + filename = "{}.{}".format(flow_item.name, file_extension) + file_size = get_file_object_size(file) - # If name is not defined, grab the name from the file to publish - if not flow_item.name: - flow_item.name = os.path.splitext(filename)[0] - if file_extension not in ALLOWED_FILE_EXTENSIONS: - error = "Only {} files can be published as flows.".format(", ".join(ALLOWED_FILE_EXTENSIONS)) - raise ValueError(error) + else: + raise TypeError("file should be a filepath or file object.") # Construct the url with the defined mode url = "{0}?flowType={1}".format(self.baseurl, file_extension) @@ -178,15 +221,24 @@ def publish( url += "&{0}=true".format(mode.lower()) # Determine if chunking is required (64MB is the limit for single upload method) - if os.path.getsize(file_path) >= FILESIZE_LIMIT: + if file_size >= FILESIZE_LIMIT: logger.info("Publishing {0} to server with chunking method (flow over 64MB)".format(filename)) - upload_session_id = self.parent_srv.fileuploads.upload(file_path) + upload_session_id = self.parent_srv.fileuploads.upload(file) url = "{0}&uploadSessionId={1}".format(url, upload_session_id) xml_request, content_type = RequestFactory.Flow.publish_req_chunked(flow_item, connections) else: logger.info("Publishing {0} to server".format(filename)) - with open(file_path, "rb") as f: - file_contents = f.read() + + if isinstance(file, (str, Path)): + with open(file, "rb") as f: + file_contents = f.read() + + elif isinstance(file, io_types_r): + file_contents = file.read() + + else: + raise TypeError("file should be a filepath or file object.") + xml_request, content_type = RequestFactory.Flow.publish_req(flow_item, filename, file_contents, connections) # Send the publishing request to server diff --git a/tableauserverclient/server/endpoint/workbooks_endpoint.py b/tableauserverclient/server/endpoint/workbooks_endpoint.py index 163bb8c71..8cca4150a 100644 --- a/tableauserverclient/server/endpoint/workbooks_endpoint.py +++ b/tableauserverclient/server/endpoint/workbooks_endpoint.py @@ -45,6 +45,9 @@ from ...models.connection_credentials import ConnectionCredentials from .schedules_endpoint import AddResponse +io_types_r = (io.BytesIO, io.BufferedReader) +io_types_w = (io.BytesIO, io.BufferedWriter) + # The maximum size of a file that can be published in a single request is 64MB FILESIZE_LIMIT = 1024 * 1024 * 64 # 64MB @@ -53,7 +56,10 @@ logger = logging.getLogger("tableau.endpoint.workbooks") FilePath = Union[str, os.PathLike] FileObject = Union[io.BufferedReader, io.BytesIO] -PathOrFile = Union[FilePath, FileObject] +FileObjectR = Union[io.BufferedReader, io.BytesIO] +FileObjectW = Union[io.BufferedWriter, io.BytesIO] +PathOrFileR = Union[FilePath, FileObjectR] +PathOrFileW = Union[FilePath, FileObjectW] class Workbooks(QuerysetEndpoint): @@ -179,38 +185,11 @@ def update_connection(self, workbook_item: WorkbookItem, connection_item: Connec def download( self, workbook_id: str, - filepath: Optional[FilePath] = None, + filepath: Optional[PathOrFileW] = None, include_extract: bool = True, no_extract: Optional[bool] = None, ) -> str: - if not workbook_id: - error = "Workbook ID undefined." - raise ValueError(error) - url = "{0}/{1}/content".format(self.baseurl, workbook_id) - - if no_extract is False or no_extract is True: - import warnings - - warnings.warn( - "no_extract is deprecated, use include_extract instead.", - DeprecationWarning, - ) - include_extract = not no_extract - - if not include_extract: - url += "?includeExtract=False" - - with closing(self.get_request(url, parameters={"stream": True})) as server_response: - _, params = cgi.parse_header(server_response.headers["Content-Disposition"]) - filename = to_filename(os.path.basename(params["filename"])) - - download_path = make_download_path(filepath, filename) - - with open(download_path, "wb") as f: - for chunk in server_response.iter_content(1024): # 1KB - f.write(chunk) - logger.info("Downloaded workbook to {0} (ID: {1})".format(download_path, workbook_id)) - return os.path.abspath(download_path) + return self.download_revision(workbook_id, None, filepath, include_extract, no_extract) # Get all views of workbook @api(version="2.0") @@ -332,7 +311,7 @@ def delete_permission(self, item, capability_item): def publish( self, workbook_item: WorkbookItem, - file: PathOrFile, + file: PathOrFileR, mode: str, connection_credentials: Optional["ConnectionCredentials"] = None, connections: Optional[Sequence[ConnectionItem]] = None, @@ -350,7 +329,6 @@ def publish( ) if isinstance(file, (str, os.PathLike)): - # Expect file to be a filepath if not os.path.isfile(file): error = "File path does not lead to an existing file." raise IOError(error) @@ -366,12 +344,12 @@ def publish( error = "Only {} files can be published as workbooks.".format(", ".join(ALLOWED_FILE_EXTENSIONS)) raise ValueError(error) - elif isinstance(file, (io.BytesIO, io.BufferedReader)): - # Expect file to be a file object - file_size = get_file_object_size(file) + elif isinstance(file, io_types_r): + if not workbook_item.name: + error = "Workbook item must have a name when passing a file object" + raise ValueError(error) file_type = get_file_type(file) - if file_type == "zip": file_extension = "twbx" elif file_type == "xml": @@ -380,13 +358,10 @@ def publish( error = "Unsupported file type {}!".format(file_type) raise ValueError(error) - if not workbook_item.name: - error = "Workbook item must have a name when passing a file object" - raise ValueError(error) - # Generate filename for file object. # This is needed when publishing the workbook in a single request filename = "{}.{}".format(workbook_item.name, file_extension) + file_size = get_file_object_size(file) else: raise TypeError("file should be a filepath or file object.") @@ -428,7 +403,7 @@ def publish( with open(file, "rb") as f: file_contents = f.read() - elif isinstance(file, (io.BytesIO, io.BufferedReader)): + elif isinstance(file, io_types_r): file_contents = file.read() else: @@ -489,14 +464,17 @@ def download_revision( self, workbook_id: str, revision_number: str, - filepath: Optional[PathOrFile] = None, + filepath: Optional[PathOrFileW] = None, include_extract: bool = True, no_extract: Optional[bool] = None, - ) -> str: + ) -> PathOrFileW: if not workbook_id: error = "Workbook ID undefined." raise ValueError(error) - url = "{0}/{1}/revisions/{2}/content".format(self.baseurl, workbook_id, revision_number) + if revision_number is None: + url = "{0}/{1}/content".format(self.baseurl, workbook_id) + else: + url = "{0}/{1}/revisions/{2}/content".format(self.baseurl, workbook_id, revision_number) if no_extract is False or no_extract is True: import warnings @@ -512,17 +490,22 @@ def download_revision( with closing(self.get_request(url, parameters={"stream": True})) as server_response: _, params = cgi.parse_header(server_response.headers["Content-Disposition"]) - filename = to_filename(os.path.basename(params["filename"])) - - download_path = make_download_path(filepath, filename) - - with open(download_path, "wb") as f: + if isinstance(filepath, io_types_w): for chunk in server_response.iter_content(1024): # 1KB - f.write(chunk) + filepath.write(chunk) + return_path = filepath + else: + filename = to_filename(os.path.basename(params["filename"])) + download_path = make_download_path(filepath, filename) + with open(download_path, "wb") as f: + for chunk in server_response.iter_content(1024): # 1KB + f.write(chunk) + return_path = os.path.abspath(download_path) + logger.info( - "Downloaded workbook revision {0} to {1} (ID: {2})".format(revision_number, download_path, workbook_id) + "Downloaded workbook revision {0} to {1} (ID: {2})".format(revision_number, return_path, workbook_id) ) - return os.path.abspath(download_path) + return return_path @api(version="2.3") def delete_revision(self, workbook_id: str, revision_number: str) -> None: diff --git a/test/assets/SampleFlow.tfl b/test/assets/SampleFlow.tfl new file mode 100644 index 0000000000000000000000000000000000000000..c46d9ced964c70d7601e58f4b4d3002412dc4dc1 GIT binary patch literal 1884 zcmWIWW@Zs#;Nak3Q1jLeV?YA@Kz2%IaY0UEWpHXqNoHPpacciYzC#KkuHWleU*2&H z5BI%uWy!Il76O$~XI?L6l**O*7rWrbg&p&XcYZe%>zS%suUX5OyJXR%+#9>PLM}C} zQ>yiFD7?Z{lCaoaa>|j5F6UxQHSbS4az{|y-ty}O-?uC7S!Y^3apt||RW+;T*zLVx zwdDa!K{-3cgiICg;#i=PDt$jX%sx3u&PBtv@OF#H*N*ze;#C|*yJ}3W zGYc(?KR=h!WVsNQct2=Lc$xMc@c?gjj>t!=@9Y42sfm$+ApozZa}!gGON#P+Q%e$4 z5=#=7hD7HJyNlF45AXStvP9tMN$$JuP0dcME;G6XoK~;=enT$ed`^0rFSE<4y<%L` zPq{HMG9A%1@#J20spV8|V7%EMhJTJrIhSRAb1zBLFB94-+BALYw>j@?o?B11oc>*5 z%Ie3Td*(CgaV)k>S$N1JN42EARA*aQOxZ>c)%mtvuM^H}R2Ba%qhup~ewu@q)9NJk zoax4$3Q26!^&ew-JoT==5nisF{n>N- z{py?bUz3V|8lI}(VRpa%N0G#*eZ_nKC**!UUfr3z`1`8*-^=ULb<=hnc%$DR9afTW z_sZ5a+I-r-xZLO4>+;yYIX4?~1;=bWP9XrRQ}QcfEbN zDAl*d&}`rP^~=4(|Fyn-+qJCv`oxEys@T`A|8sM9tJK3!yZa*I6Q89YkkNY)-&=W` zOMHpro4rD>ManL|IMT%Vq(O4NU!dAUE!WC7S2|VroE^5Va#>r*k`cuSb~j5?W( zDgL+EOYV0GRa+=u`7L0VUidOi+C<*>L{I5$wKvZ$1<$crwfx7LyY;n)eLl9g_RoEE zHtkHn!*5f)7D6@u3Bwtc_gn7VY8MrB$&w%2rfoRYIx9j*q{)&;Gu@^4X%SyQ_AbJL5j%?}UpvLaT$+9HwN?@mLu@-M@5u3(x%%-j{SP zue`gdcCXLY57TGeR4dV|J^5?bq+RjzQ%d>0KTYM{z3yE7oLPtd`P-Vd{*Cf^_}|Z0 z(bDvnb^E!uGk>##^5Vl@^Edv?3=Gn2`0^qPFt?`VJ8JJ@x|K!EMRZhyfdvkO-Y zSMF;R|F}i`rnAvZhiPULO`5!4?vBnD7rFWR`}zOIhr0~x3~!c%XnddZ_NKu_^-YD3 z3jbR08%{~Mm0VA>Mqf)+6@bNpME!e zv9i{0@4J#`>QC3d{CxCjuyakY?P~9ew;I!$t0nF?T;jD-U9(9cVfW-iCZ)UZj}eul&{f>DsNU!t&v`*KcgrUhmm1T1j3D1{F>Q`J1N2)F z-jYZR>U;O1#Da{FjMChsyu{2Lz2dTT-@db4%?1K3@u@e1xsR?{)*<$IVZKtX57Wo? z9TLC4`7UXfON(sFZ+p4z?g7J}mBAS^ueS>r9#cA%#OW5|q9W8NaCi1vC(o9It;W+t zr=68eEVr_o?ySh@=*zMxZ^QehI>Qw4hsF2e9skBp2$Y!noojLFnqa>t3HK!*7@WDi zqs8;#&tHvSU4yE%Z@-%5IC;vlukj#nS?9+aOa?mo5D*7=Gct)VAnFz5!WUGppaQg7 z23<39YC+Zf6=)7xX^pNKIj}(q3IQ5{Tu|~t*MaO0P$VEgJu5^9W|RbYvjUSc0|N^X KegM*8U>*P_VF4)s literal 0 HcmV?d00001 diff --git a/test/assets/flow_publish.xml b/test/assets/flow_publish.xml new file mode 100644 index 000000000..55af88d11 --- /dev/null +++ b/test/assets/flow_publish.xml @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/test/test_datasource.py b/test/test_datasource.py index 46378201f..e486eec33 100644 --- a/test/test_datasource.py +++ b/test/test_datasource.py @@ -470,6 +470,18 @@ def test_download(self) -> None: self.assertTrue(os.path.exists(file_path)) os.remove(file_path) + def test_download_object(self) -> None: + with BytesIO() as file_object: + with requests_mock.mock() as m: + m.get( + self.baseurl + "/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb/content", + headers={"Content-Disposition": 'name="tableau_datasource"; filename="Sample datasource.tds"'}, + ) + file_path = self.server.datasources.download( + "9dbd2263-16b5-46e1-9c43-a76bb8ab65fb", filepath=file_object + ) + self.assertTrue(isinstance(file_path, BytesIO)) + def test_download_sanitizes_name(self) -> None: filename = "Name,With,Commas.tds" disposition = 'name="tableau_workbook"; filename="{}"'.format(filename) diff --git a/test/test_flow.py b/test/test_flow.py index 269bc2f7e..bbd8a39d3 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -1,16 +1,21 @@ +import os +import requests_mock import unittest -import requests_mock +from io import BytesIO import tableauserverclient as TSC from tableauserverclient.datetime_helpers import format_datetime from ._utils import read_xml_asset, asset -GET_XML = "flow_get.xml" -POPULATE_CONNECTIONS_XML = "flow_populate_connections.xml" -POPULATE_PERMISSIONS_XML = "flow_populate_permissions.xml" -UPDATE_XML = "flow_update.xml" -REFRESH_XML = "flow_refresh.xml" +TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets") + +GET_XML = os.path.join(TEST_ASSET_DIR, "flow_get.xml") +POPULATE_CONNECTIONS_XML = os.path.join(TEST_ASSET_DIR, "flow_populate_connections.xml") +POPULATE_PERMISSIONS_XML = os.path.join(TEST_ASSET_DIR, "flow_populate_permissions.xml") +PUBLISH_XML = os.path.join(TEST_ASSET_DIR, "flow_publish.xml") +UPDATE_XML = os.path.join(TEST_ASSET_DIR, "flow_update.xml") +REFRESH_XML = os.path.join(TEST_ASSET_DIR, "flow_refresh.xml") class FlowTests(unittest.TestCase): @@ -24,6 +29,26 @@ def setUp(self) -> None: self.baseurl = self.server.flows.baseurl + def test_download(self) -> None: + with requests_mock.mock() as m: + m.get( + self.baseurl + "/587daa37-b84d-4400-a9a2-aa90e0be7837/content", + headers={"Content-Disposition": 'name="tableau_flow"; filename="FlowOne.tfl"'}, + ) + file_path = self.server.flows.download("587daa37-b84d-4400-a9a2-aa90e0be7837") + self.assertTrue(os.path.exists(file_path)) + os.remove(file_path) + + def test_download_object(self) -> None: + with BytesIO() as file_object: + with requests_mock.mock() as m: + m.get( + self.baseurl + "/587daa37-b84d-4400-a9a2-aa90e0be7837/content", + headers={"Content-Disposition": 'name="tableau_flow"; filename="FlowOne.tfl"'}, + ) + file_path = self.server.flows.download("587daa37-b84d-4400-a9a2-aa90e0be7837", filepath=file_object) + self.assertTrue(isinstance(file_path, BytesIO)) + def test_get(self) -> None: response_xml = read_xml_asset(GET_XML) with requests_mock.mock() as m: @@ -116,6 +141,52 @@ def test_populate_permissions(self) -> None: }, ) + def test_publish(self) -> None: + with open(PUBLISH_XML, "rb") as f: + response_xml = f.read().decode("utf-8") + with requests_mock.mock() as m: + m.post(self.baseurl, text=response_xml) + + new_flow = TSC.FlowItem(name="SampleFlow", project_id="ee8c6e70-43b6-11e6-af4f-f7b0d8e20760") + + sample_flow = os.path.join(TEST_ASSET_DIR, "SampleFlow.tfl") + publish_mode = self.server.PublishMode.CreateNew + + new_flow = self.server.flows.publish(new_flow, sample_flow, publish_mode) + + self.assertEqual("2457c468-1b24-461a-8f95-a461b3209d32", new_flow.id) + self.assertEqual("SampleFlow", new_flow.name) + self.assertEqual("2023-01-13T09:50:55Z", format_datetime(new_flow.created_at)) + self.assertEqual("2023-01-13T09:50:55Z", format_datetime(new_flow.updated_at)) + self.assertEqual("ee8c6e70-43b6-11e6-af4f-f7b0d8e20760", new_flow.project_id) + self.assertEqual("default", new_flow.project_name) + self.assertEqual("5de011f8-5aa9-4d5b-b991-f462c8dd6bb7", new_flow.owner_id) + + def test_publish_file_object(self) -> None: + with open(PUBLISH_XML, "rb") as f: + response_xml = f.read().decode("utf-8") + with requests_mock.mock() as m: + m.post(self.baseurl, text=response_xml) + + new_flow = TSC.FlowItem(name="SampleFlow", project_id="ee8c6e70-43b6-11e6-af4f-f7b0d8e20760") + + sample_flow = os.path.join(TEST_ASSET_DIR, "SampleFlow.tfl") + publish_mode = self.server.PublishMode.CreateNew + + with open(sample_flow, "rb") as fp: + + publish_mode = self.server.PublishMode.CreateNew + + new_flow = self.server.flows.publish(new_flow, fp, publish_mode) + + self.assertEqual("2457c468-1b24-461a-8f95-a461b3209d32", new_flow.id) + self.assertEqual("SampleFlow", new_flow.name) + self.assertEqual("2023-01-13T09:50:55Z", format_datetime(new_flow.created_at)) + self.assertEqual("2023-01-13T09:50:55Z", format_datetime(new_flow.updated_at)) + self.assertEqual("ee8c6e70-43b6-11e6-af4f-f7b0d8e20760", new_flow.project_id) + self.assertEqual("default", new_flow.project_name) + self.assertEqual("5de011f8-5aa9-4d5b-b991-f462c8dd6bb7", new_flow.owner_id) + def test_refresh(self): with open(asset(REFRESH_XML), "rb") as f: response_xml = f.read().decode("utf-8") diff --git a/test/test_workbook.py b/test/test_workbook.py index ba21dc195..2e5de9369 100644 --- a/test/test_workbook.py +++ b/test/test_workbook.py @@ -267,6 +267,16 @@ def test_download(self) -> None: self.assertTrue(os.path.exists(file_path)) os.remove(file_path) + def test_download_object(self) -> None: + with BytesIO() as file_object: + with requests_mock.mock() as m: + m.get( + self.baseurl + "/1f951daf-4061-451a-9df1-69a8062664f2/content", + headers={"Content-Disposition": 'name="tableau_workbook"; filename="RESTAPISample.twbx"'}, + ) + file_path = self.server.workbooks.download("1f951daf-4061-451a-9df1-69a8062664f2", filepath=file_object) + self.assertTrue(isinstance(file_path, BytesIO)) + def test_download_sanitizes_name(self) -> None: filename = "Name,With,Commas.twbx" disposition = 'name="tableau_workbook"; filename="{}"'.format(filename) From d9f64e144cedf5dd31979e19bb33ad5084d770d9 Mon Sep 17 00:00:00 2001 From: Stu Tomlinson Date: Tue, 24 Jan 2023 03:21:59 +0000 Subject: [PATCH 16/18] Add updated_at to JobItem class (#1182) --- tableauserverclient/models/job_item.py | 10 +++++++++- test/test_job.py | 2 ++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tableauserverclient/models/job_item.py b/tableauserverclient/models/job_item.py index 39562cd45..a7490e705 100644 --- a/tableauserverclient/models/job_item.py +++ b/tableauserverclient/models/job_item.py @@ -34,6 +34,7 @@ def __init__( workbook_id: Optional[str] = None, datasource_id: Optional[str] = None, flow_run: Optional[FlowRunItem] = None, + updated_at: Optional["datetime.datetime"] = None, ): self._id = id_ self._type = job_type @@ -47,6 +48,7 @@ def __init__( self._workbook_id = workbook_id self._datasource_id = datasource_id self._flow_run = flow_run + self._updated_at = updated_at @property def id(self) -> str: @@ -113,9 +115,13 @@ def flow_run(self): def flow_run(self, value): self._flow_run = value + @property + def updated_at(self) -> Optional["datetime.datetime"]: + return self._updated_at + def __repr__(self): return ( - "".format(**self.__dict__) ) @@ -144,6 +150,7 @@ def _parse_element(cls, element, ns): datasource = element.find(".//t:datasource[@id]", namespaces=ns) datasource_id = datasource.get("id") if datasource is not None else None flow_run = None + updated_at = parse_datetime(element.get("updatedAt", None)) for flow_job in element.findall(".//t:runFlowJobType", namespaces=ns): flow_run = FlowRunItem() flow_run._id = flow_job.get("flowRunId", None) @@ -163,6 +170,7 @@ def _parse_element(cls, element, ns): workbook_id, datasource_id, flow_run, + updated_at, ) diff --git a/test/test_job.py b/test/test_job.py index 19a93e808..83edadaef 100644 --- a/test/test_job.py +++ b/test/test_job.py @@ -53,8 +53,10 @@ def test_get_by_id(self) -> None: with requests_mock.mock() as m: m.get("{0}/{1}".format(self.baseurl, job_id), text=response_xml) job = self.server.jobs.get_by_id(job_id) + updated_at = datetime(2020, 5, 13, 20, 25, 18, tzinfo=utc) self.assertEqual(job_id, job.id) + self.assertEqual(updated_at, job.updated_at) self.assertListEqual(job.notes, ["Job detail notes"]) def test_get_before_signin(self) -> None: From 47eab0b7c2d3f72e49aaf792110ef090a6297829 Mon Sep 17 00:00:00 2001 From: Jeremy Harris Date: Mon, 13 Feb 2023 19:23:04 -0800 Subject: [PATCH 17/18] fix revision references where xml returned does not match docs (#1176) * fix revision references where xml returned does not match docs --- tableauserverclient/models/revision_item.py | 8 ++++---- test/assets/datasource_revision.xml | 10 +++++----- test/assets/workbook_revision.xml | 10 +++++----- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tableauserverclient/models/revision_item.py b/tableauserverclient/models/revision_item.py index a49be88a7..600d73168 100644 --- a/tableauserverclient/models/revision_item.py +++ b/tableauserverclient/models/revision_item.py @@ -67,10 +67,10 @@ def from_response(cls, resp: bytes, ns, resource_item) -> List["RevisionItem"]: revision_item._resource_id = resource_item.id revision_item._resource_name = resource_item.name revision_item._revision_number = revision_xml.get("revisionNumber", None) - revision_item._current = string_to_bool(revision_xml.get("isCurrent", "")) - revision_item._deleted = string_to_bool(revision_xml.get("isDeleted", "")) - revision_item._created_at = parse_datetime(revision_xml.get("createdAt", None)) - for user in revision_xml.findall(".//t:user", namespaces=ns): + revision_item._current = string_to_bool(revision_xml.get("current", "")) + revision_item._deleted = string_to_bool(revision_xml.get("deleted", "")) + revision_item._created_at = parse_datetime(revision_xml.get("publishedAt", None)) + for user in revision_xml.findall(".//t:publisher", namespaces=ns): revision_item._user_id = user.get("id", None) revision_item._user_name = user.get("name", None) diff --git a/test/assets/datasource_revision.xml b/test/assets/datasource_revision.xml index 598c8ad45..8cadafc8f 100644 --- a/test/assets/datasource_revision.xml +++ b/test/assets/datasource_revision.xml @@ -2,13 +2,13 @@ - - + + - + - - + + \ No newline at end of file diff --git a/test/assets/workbook_revision.xml b/test/assets/workbook_revision.xml index 598c8ad45..8cadafc8f 100644 --- a/test/assets/workbook_revision.xml +++ b/test/assets/workbook_revision.xml @@ -2,13 +2,13 @@ - - + + - + - - + + \ No newline at end of file From 06e33fae632e5f5ba4ae829226d2de9d28771981 Mon Sep 17 00:00:00 2001 From: Jac Date: Mon, 13 Feb 2023 19:25:14 -0800 Subject: [PATCH 18/18] Do not create empty connections list (#1178) This should fix (1) from https://github.com/tableau/server-client-python/issues/1139#issuecomment-1379162364, by preventing us from sending an empty connections element if the list of connections is empty. --- tableauserverclient/server/request_factory.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 209626051..720eb4085 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -182,10 +182,10 @@ def _generate_xml(self, datasource_item, connection_credentials=None, connection if connection_credentials is not None and connections is not None: raise RuntimeError("You cannot set both `connections` and `connection_credentials`") - if connection_credentials is not None: + if connection_credentials is not None and connection_credentials != False: _add_credentials_element(datasource_element, connection_credentials) - if connections is not None and len(connections) > 0: + if connections is not None and connections != False and len(connections) > 0: connections_element = ET.SubElement(datasource_element, "connections") for connection in connections: _add_connections_element(connections_element, connection) @@ -337,7 +337,7 @@ def _generate_xml(self, flow_item: "FlowItem", connections: Optional[List["Conne project_element = ET.SubElement(flow_element, "project") project_element.attrib["id"] = flow_item.project_id - if connections is not None: + if connections is not None and connections != False: connections_element = ET.SubElement(flow_element, "connections") for connection in connections: _add_connections_element(connections_element, connection) @@ -904,10 +904,10 @@ def _generate_xml( if connection_credentials is not None and connections is not None: raise RuntimeError("You cannot set both `connections` and `connection_credentials`") - if connection_credentials is not None: + if connection_credentials is not None and connection_credentials != False: _add_credentials_element(workbook_element, connection_credentials) - if connections is not None and len(connections) > 0: + if connections is not None and connections != False and len(connections) > 0: connections_element = ET.SubElement(workbook_element, "connections") for connection in connections: _add_connections_element(connections_element, connection)