diff --git a/searchtweets/credentials.py b/searchtweets/credentials.py index f049d8e..902d96c 100644 --- a/searchtweets/credentials.py +++ b/searchtweets/credentials.py @@ -11,8 +11,12 @@ import os import logging import yaml +import requests +import base64 from .utils import merge_dicts +OAUTH_ENDPOINT = 'https://api.twitter.com/oauth2/token' + __all__ = ["load_credentials"] logger = logging.getLogger(__name__) @@ -76,8 +80,16 @@ def _parse_credentials(search_creds, account_type): try: if account_type == "premium": - search_args = {"bearer_token": search_creds["bearer_token"], - "endpoint": search_creds["endpoint"]} + if "bearer_token" not in search_creds: + if "consumer_key" in search_creds \ + and "consumer_secret" in search_creds: + search_creds["bearer_token"] = _generate_bearer_token( + search_creds["consumer_key"], + search_creds["consumer_secret"]) + + search_args = { + "bearer_token": search_creds["bearer_token"], + "endpoint": search_creds["endpoint"]} if account_type == "enterprise": search_args = {"username": search_creds["username"], "password": search_creds["password"], @@ -183,3 +195,25 @@ def load_credentials(filename=None, account_type=None, else merge_dicts(env_vars, yaml_vars)) parsed_vars = _parse_credentials(merged_vars, account_type=account_type) return parsed_vars + + +def _generate_bearer_token(consumer_key, consumer_secret): + """ + Return the bearer token for a given pair of consumer key and secret values. + """ + auth = base64.b64encode("{0}:{1}".format( + consumer_key, + consumer_secret).encode()).decode() + + headers = { + 'Authorization': 'Basic {0}'.format(auth), + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8'} + data = 'grant_type=client_credentials' + resp = requests.post(OAUTH_ENDPOINT, + data=data, + headers=headers) + if resp.status_code >= 400: + logger.error(resp.text) + resp.raise_for_status() + + return resp.json()['access_token']