diff --git a/shotgun_api3/lib/mockgun/mockgun.py b/shotgun_api3/lib/mockgun/mockgun.py index e94d92b4..acabd8f8 100644 --- a/shotgun_api3/lib/mockgun/mockgun.py +++ b/shotgun_api3/lib/mockgun/mockgun.py @@ -193,6 +193,8 @@ def __init__(self, # they way they would expect to in the real API. self.config = _Config(self) + self.config.set_server_params(base_url) + # load in the shotgun schema to associate with this Shotgun (schema_path, schema_entity_path) = self.get_schema_paths() diff --git a/shotgun_api3/shotgun.py b/shotgun_api3/shotgun.py index 423e728e..142e445f 100755 --- a/shotgun_api3/shotgun.py +++ b/shotgun_api3/shotgun.py @@ -429,6 +429,30 @@ def __init__(self, sg): self.no_ssl_validation = False self.localized = False + def set_server_params(self, base_url): + """ + Set the different server related fields based on the passed in URL. + + This will impact the following attributes: + + - scheme: http or https + - api_path: usually /api3/json + - server: usually something.shotgunstudio.com + + :param str base_url: The server URL. + + :raises ValueError: Raised if protocol is not http or https. + """ + self.scheme, self.server, api_base, _, _ = \ + urllib.parse.urlsplit(base_url) + if self.scheme not in ("http", "https"): + raise ValueError( + "base_url must use http or https got '%s'" % base_url + ) + self.api_path = urllib.parse.urljoin(urllib.parse.urljoin( + api_base or "/", self.api_ver + "/"), "json" + ) + @property def records_per_page(self): """ @@ -646,17 +670,14 @@ def __init__(self, self.__ca_certs = os.environ.get("SHOTGUN_API_CACERTS") self.base_url = (base_url or "").lower() - self.config.scheme, self.config.server, api_base, _, _ = \ - urllib.parse.urlsplit(self.base_url) - if self.config.scheme not in ("http", "https"): - raise ValueError("base_url must use http or https got '%s'" % - self.base_url) - self.config.api_path = urllib.parse.urljoin(urllib.parse.urljoin( - api_base or "/", self.config.api_ver + "/"), "json") + self.config.set_server_params(self.base_url) # if the service contains user information strip it out # copied from the xmlrpclib which turned the user:password into # and auth header + # Do NOT urlsplit(self.base_url) here, as it contains the lower case version + # of the base_url argument. Doing so would base64-encode the lowercase + # version of the credentials. auth, self.config.server = urllib.parse.splituser(urllib.parse.urlsplit(base_url).netloc) if auth: auth = base64.encodestring(six.ensure_binary(urllib.parse.unquote(auth))).decode("utf-8") @@ -2054,7 +2075,7 @@ def schema_field_update(self, entity_type, field_name, properties, project_entit :param properties: Dictionary with key/value pairs where the key is the property to be updated and the value is the new value. :param dict project_entity: Optional Project entity specifying which project to modify the - ``visible`` property for. If the ``visible`` is present in ``properties`` and + ``visible`` property for. If ``visible`` is present in ``properties`` and ``project_entity`` is not set, an exception will be raised. Example: ``{'type': 'Project', 'id': 3}`` :returns: ``True`` if the field was updated. diff --git a/tests/test_mockgun.py b/tests/test_mockgun.py index 6efa5f27..ce851135 100644 --- a/tests/test_mockgun.py +++ b/tests/test_mockgun.py @@ -433,5 +433,29 @@ def test_invalid_operator(self): ) +class TestConfig(unittest.TestCase): + """ + Tests the shotgun._Config class + """ + + def test_set_server_params_with_regular_url(self): + """ + Make sure it works with a normal URL. + """ + mockgun = Mockgun("https://server.shotgunstudio.com/") + self.assertEqual(mockgun.config.scheme, "https") + self.assertEqual(mockgun.config.server, "server.shotgunstudio.com") + self.assertEqual(mockgun.config.api_path, "/api3/json") + + def test_set_server_params_with_url_with_path(self): + """ + Make sure it works with a URL with a path + """ + mockgun = Mockgun("https://local/something/") + self.assertEqual(mockgun.config.scheme, "https") + self.assertEqual(mockgun.config.server, "local") + self.assertEqual(mockgun.config.api_path, "/something/api3/json") + + if __name__ == '__main__': unittest.main()