diff --git a/replicate/client.py b/replicate/client.py index 4196d6aa..95c6e0a3 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -5,6 +5,7 @@ import requests from requests.adapters import HTTPAdapter, Retry +from requests.cookies import RequestsCookieJar from replicate.__about__ import __version__ from replicate.exceptions import ModelError, ReplicateError @@ -25,7 +26,7 @@ def __init__(self, api_token: Optional[str] = None) -> None: self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5")) # TODO: make thread safe - self.read_session = requests.Session() + self.read_session = _create_session() read_retries = Retry( total=5, backoff_factor=2, @@ -50,7 +51,7 @@ def __init__(self, api_token: Optional[str] = None) -> None: self.read_session.mount("http://", HTTPAdapter(max_retries=read_retries)) self.read_session.mount("https://", HTTPAdapter(max_retries=read_retries)) - self.write_session = requests.Session() + self.write_session = _create_session() write_retries = Retry( total=5, backoff_factor=2, @@ -138,3 +139,21 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: if prediction.status == "failed": raise ModelError(prediction.error) return prediction.output + + +class _NonpersistentCookieJar(RequestsCookieJar): + """ + A cookie jar that doesn't persist cookies between requests. + """ + + def set(self, name, value, **kwargs) -> None: + return + + def set_cookie(self, cookie, *args, **kwargs) -> None: + return + + +def _create_session() -> requests.Session: + s = requests.Session() + s.cookies = _NonpersistentCookieJar() + return s