Skip to content

Commit

Permalink
Use standard HTTPMethod StrEnum
Browse files Browse the repository at this point in the history
  • Loading branch information
Shadow-Devil committed May 20, 2023
1 parent f08b2e6 commit d0d1d6e
Showing 1 changed file with 48 additions and 26 deletions.
74 changes: 48 additions & 26 deletions simple_salesforce/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Core classes and exceptions for Simple-Salesforce"""
from datetime import datetime
from http import HTTPMethod

# has to be defined prior to login import
DEFAULT_API_VERSION = '52.0'
Expand Down Expand Up @@ -283,7 +284,11 @@ def describe(self, **kwargs: Any) -> Optional[Any]:
* keyword arguments supported by requests.request (e.g. json, timeout)
"""
url = self.base_url + "sobjects"
result = self._call_salesforce('GET', url, name='describe', **kwargs)
result = self._call_salesforce(
HTTPMethod.GET,
url,
name='describe',
**kwargs)

json_result = self.parse_result_to_json(result)
if len(json_result) == 0:
Expand Down Expand Up @@ -341,7 +346,10 @@ def set_password(self, user: str, password: str) -> Optional[Any]:
url = f'{self.base_url}sobjects/User/{user}/password'
params = {'NewPassword': password}

result = self._call_salesforce('POST', url, data=json.dumps(params))
result = self._call_salesforce(
HTTPMethod.POST,
url,
data=json.dumps(params))

if result.status_code == 204:
return None
Expand All @@ -359,7 +367,7 @@ def restful(
self,
path: str,
params: Optional[Dict[str, Any]] = None,
method: str = 'GET',
method: Union[str, HTTPMethod] = HTTPMethod.GET,
**kwargs: Any) -> Optional[Any]:
"""Allows you to make a direct REST call if you know the path
Expand All @@ -386,7 +394,7 @@ def oauth2(
self,
path: str,
params: Optional[Dict[str, Any]] = None,
method: str = 'GET') -> Optional[Any]:
method: Union[str, HTTPMethod] = HTTPMethod.GET) -> Optional[Any]:
"""Allows you to make a request to OAuth endpoints if you know the path
Arguments:
Expand Down Expand Up @@ -419,7 +427,11 @@ def search(self, search: str) -> Any:

# `requests` will correctly encode the query string passed as `params`
params = {'q': search}
result = self._call_salesforce('GET', url, name='search', params=params)
result = self._call_salesforce(
HTTPMethod.GET,
url,
name='search',
params=params)

json_result = self.parse_result_to_json(result)
if len(json_result) == 0:
Expand All @@ -443,7 +455,7 @@ def limits(self, **kwargs: Any) -> Any:
limits.
"""
url = self.base_url + 'limits/'
result = self._call_salesforce('GET', url, **kwargs)
result = self._call_salesforce(HTTPMethod.GET, url, **kwargs)

if result.status_code != 200:
exception_handler(result)
Expand All @@ -466,7 +478,7 @@ def query(
url = self.base_url + ('queryAll/' if include_deleted else 'query/')
params = {'q': query}
# `requests` will correctly encode the query string passed as `params`
result = self._call_salesforce('GET', url, name='query',
result = self._call_salesforce(HTTPMethod.GET, url, name='query',
params=params, **kwargs)

return self.parse_result_to_json(result)
Expand Down Expand Up @@ -495,7 +507,11 @@ def query_more(
else:
endpoint = 'queryAll' if include_deleted else 'query'
url = f'{self.base_url}{endpoint}/{next_records_identifier}'
result = self._call_salesforce('GET', url, name='query_more', **kwargs)
result = self._call_salesforce(
HTTPMethod.GET,
url,
name='query_more',
**kwargs)

return self.parse_result_to_json(result)

Expand Down Expand Up @@ -560,7 +576,7 @@ def query_all(
def toolingexecute(
self,
action: str,
method: str = 'GET',
method: Union[str, HTTPMethod] = HTTPMethod.GET,
data: Optional[Dict[str, Any]] = None,
**kwargs: Any) -> Any:
"""Makes an HTTP request to an TOOLING REST endpoint
Expand Down Expand Up @@ -590,7 +606,7 @@ def toolingexecute(
def apexecute(
self,
action: str,
method: str = 'GET',
method: Union[str, HTTPMethod] = HTTPMethod.GET,
data: Optional[Dict[str, Any]] = None,
**kwargs: Any) -> Any:
"""Makes an HTTP request to an APEX REST endpoint
Expand Down Expand Up @@ -619,7 +635,7 @@ def apexecute(

def _call_salesforce(
self,
method: str,
method: Union[str, HTTPMethod],
url: str,
name: str = "",
**kwargs: Any) -> requests.Response:
Expand Down Expand Up @@ -804,7 +820,10 @@ def metadata(self, headers: Optional[Headers] = None) -> Any:
Arguments:
* headers -- a dict with additional request headers.
"""
result = self._call_salesforce('GET', self.base_url, headers=headers)
result = self._call_salesforce(
HTTPMethod.GET,
self.base_url,
headers=headers)
return self.parse_result_to_json(result)

def describe(self, headers: Optional[Headers] = None) -> Any:
Expand All @@ -814,7 +833,7 @@ def describe(self, headers: Optional[Headers] = None) -> Any:
* headers -- a dict with additional request headers.
"""
result = self._call_salesforce(
method='GET', url=urljoin(self.base_url, 'describe'),
method=HTTPMethod.GET, url=urljoin(self.base_url, 'describe'),
headers=headers
)
return self.parse_result_to_json(result)
Expand All @@ -833,7 +852,7 @@ def describe_layout(
"""
custom_url_part = f'describe/layouts/{record_id}'
result = self._call_salesforce(
method='GET',
method=HTTPMethod.GET,
url=urljoin(self.base_url, custom_url_part),
headers=headers
)
Expand All @@ -850,7 +869,7 @@ def get(
* headers -- a dict with additional request headers.
"""
result = self._call_salesforce(
method='GET', url=urljoin(self.base_url, record_id),
method=HTTPMethod.GET, url=urljoin(self.base_url, record_id),
headers=headers
)
return self.parse_result_to_json(result)
Expand All @@ -872,7 +891,7 @@ def get_by_custom_id(
"""
custom_url = urljoin(self.base_url, f'{custom_id_field}/{custom_id}')
result = self._call_salesforce(
method='GET', url=custom_url, headers=headers
method=HTTPMethod.GET, url=custom_url, headers=headers
)
return self.parse_result_to_json(result)

Expand All @@ -888,7 +907,7 @@ def create(
* headers -- a dict with additional request headers.
"""
result = self._call_salesforce(
method='POST', url=self.base_url,
method=HTTPMethod.POST, url=self.base_url,
data=json.dumps(data), headers=headers
)
return self.parse_result_to_json(result)
Expand All @@ -914,7 +933,7 @@ def upsert(
* headers -- a dict with additional request headers.
"""
result = self._call_salesforce(
method='PATCH', url=urljoin(self.base_url, record_id),
method=HTTPMethod.PATCH, url=urljoin(self.base_url, record_id),
data=json.dumps(data), headers=headers
)
return self._raw_response(result, raw_response)
Expand All @@ -939,7 +958,7 @@ def update(
* headers -- a dict with additional request headers.
"""
result = self._call_salesforce(
method='PATCH', url=urljoin(self.base_url, record_id),
method=HTTPMethod.PATCH, url=urljoin(self.base_url, record_id),
data=json.dumps(data), headers=headers
)
return self._raw_response(result, raw_response)
Expand All @@ -962,7 +981,7 @@ def delete(
* headers -- a dict with additional request headers.
"""
result = self._call_salesforce(
method='DELETE', url=urljoin(self.base_url, record_id),
method=HTTPMethod.DELETE, url=urljoin(self.base_url, record_id),
headers=headers
)
return self._raw_response(result, raw_response)
Expand All @@ -986,7 +1005,7 @@ def deleted(
self.base_url,
f'deleted/?start={date_to_iso8601(start)}&end={date_to_iso8601(end)}'
)
result = self._call_salesforce(method='GET', url=url, headers=headers)
result = self._call_salesforce(method=HTTPMethod.GET, url=url, headers=headers)
return self.parse_result_to_json(result)

def updated(
Expand All @@ -1008,12 +1027,15 @@ def updated(
self.base_url,
f'updated/?start={date_to_iso8601(start)}&end={date_to_iso8601(end)}'
)
result = self._call_salesforce(method='GET', url=url, headers=headers)
result = self._call_salesforce(
method=HTTPMethod.GET,
url=url,
headers=headers)
return self.parse_result_to_json(result)

def _call_salesforce(
self,
method: str,
method: HTTPMethod,
url: str,
**kwargs: Any) -> requests.Response:
"""Utility method for performing HTTP call to Salesforce.
Expand Down Expand Up @@ -1075,7 +1097,7 @@ def upload_base64(
data = {}
body = base64.b64encode(Path(file_path).read_bytes()).decode()
data[base64_field] = body
result = self._call_salesforce(method='POST', url=self.base_url,
result = self._call_salesforce(HTTPMethod.POST, url=self.base_url,
headers=headers, json=data, **kwargs)

return result
Expand All @@ -1092,7 +1114,7 @@ def update_base64(
data = {}
body = base64.b64encode(Path(file_path).read_bytes()).decode()
data[base64_field] = body
result = self._call_salesforce(method='PATCH',
result = self._call_salesforce(HTTPMethod.PATCH,
url=urljoin(self.base_url, record_id),
json=data,
headers=headers, **kwargs)
Expand All @@ -1114,7 +1136,7 @@ def get_base64(
Example: sobjects/Attachment/ABC123/Body
sobjects/ContentVersion/ABC123/VersionData
"""
result = self._call_salesforce(method='GET', url=urljoin(
result = self._call_salesforce(HTTPMethod.GET, url=urljoin(
self.base_url, f'{record_id}/{base64_field}'),
data=data,
headers=headers, **kwargs)
Expand Down

0 comments on commit d0d1d6e

Please sign in to comment.