Skip to content

Commit

Permalink
fix(user auth context): do not overwrite provided client options Auth…
Browse files Browse the repository at this point in the history
…orization header (#766)
  • Loading branch information
Garee committed Apr 17, 2024
1 parent bb24ce0 commit 4214c43
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 58 deletions.
39 changes: 11 additions & 28 deletions supabase/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,8 @@ def __init__(

self.supabase_url = supabase_url
self.supabase_key = supabase_key
self._auth_token = {
"Authorization": f"Bearer {supabase_key}",
}
options.headers.update(self._get_auth_headers())
self.options = options
options.headers.update(self._get_auth_headers())
self.rest_url = f"{supabase_url}/rest/v1"
self.realtime_url = f"{supabase_url}/realtime/v1".replace("http", "ws")
self.auth_url = f"{supabase_url}/auth/v1"
Expand Down Expand Up @@ -102,9 +99,7 @@ async def create(
supabase_key: str,
options: Union[ClientOptions, None] = None,
):
client = cls(supabase_url, supabase_key, options)
client._auth_token = await client._get_token_header()
return client
return cls(supabase_url, supabase_key, options)

def table(self, table_name: str) -> AsyncRequestBuilder:
"""Perform a table operation.
Expand Down Expand Up @@ -147,7 +142,6 @@ def rpc(
@property
def postgrest(self):
if self._postgrest is None:
self.options.headers.update(self._auth_token)
self._postgrest = self._init_postgrest_client(
rest_url=self.rest_url,
headers=self.options.headers,
Expand All @@ -160,21 +154,19 @@ def postgrest(self):
@property
def storage(self):
if self._storage is None:
headers = self._get_auth_headers()
headers.update(self._auth_token)
self._storage = self._init_storage_client(
storage_url=self.storage_url,
headers=headers,
headers=self.options.headers,
storage_client_timeout=self.options.storage_client_timeout,
)
return self._storage

@property
def functions(self):
if self._functions is None:
headers = self._get_auth_headers()
headers.update(self._auth_token)
self._functions = AsyncFunctionsClient(self.functions_url, headers)
self._functions = AsyncFunctionsClient(
self.functions_url, self.options.headers
)
return self._functions

# async def remove_subscription_helper(resolve):
Expand Down Expand Up @@ -248,26 +240,17 @@ def _init_postgrest_client(
)

def _create_auth_header(self, token: str):
return {
"Authorization": f"Bearer {token}",
}
return f"Bearer {token}"

def _get_auth_headers(self) -> Dict[str, str]:
"""Helper method to get auth headers."""
return {
"apiKey": self.supabase_key,
"Authorization": f"Bearer {self.supabase_key}",
"Authorization": self.options.headers.get(
"Authorization", self._create_auth_header(self.supabase_key)
),
}

async def _get_token_header(self):
try:
session = await self.auth.get_session()
access_token = session.access_token
except Exception as err:
access_token = self.supabase_key

return self._create_auth_header(access_token)

def _listen_to_auth_events(
self, event: AuthChangeEvent, session: Union[Session, None]
):
Expand All @@ -279,7 +262,7 @@ def _listen_to_auth_events(
self._functions = None
access_token = session.access_token if session else self.supabase_key

self._auth_token = self._create_auth_header(access_token)
self.options.headers["Authorization"] = self._create_auth_header(access_token)


async def create_client(
Expand Down
39 changes: 11 additions & 28 deletions supabase/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,8 @@ def __init__(

self.supabase_url = supabase_url
self.supabase_key = supabase_key
self._auth_token = {
"Authorization": f"Bearer {supabase_key}",
}
options.headers.update(self._get_auth_headers())
self.options = options
options.headers.update(self._get_auth_headers())
self.rest_url = f"{supabase_url}/rest/v1"
self.realtime_url = f"{supabase_url}/realtime/v1".replace("http", "ws")
self.auth_url = f"{supabase_url}/auth/v1"
Expand Down Expand Up @@ -102,9 +99,7 @@ def create(
supabase_key: str,
options: Union[ClientOptions, None] = None,
):
client = cls(supabase_url, supabase_key, options)
client._auth_token = client._get_token_header()
return client
return cls(supabase_url, supabase_key, options)

def table(self, table_name: str) -> SyncRequestBuilder:
"""Perform a table operation.
Expand Down Expand Up @@ -147,7 +142,6 @@ def rpc(
@property
def postgrest(self):
if self._postgrest is None:
self.options.headers.update(self._auth_token)
self._postgrest = self._init_postgrest_client(
rest_url=self.rest_url,
headers=self.options.headers,
Expand All @@ -160,21 +154,19 @@ def postgrest(self):
@property
def storage(self):
if self._storage is None:
headers = self._get_auth_headers()
headers.update(self._auth_token)
self._storage = self._init_storage_client(
storage_url=self.storage_url,
headers=headers,
headers=self.options.headers,
storage_client_timeout=self.options.storage_client_timeout,
)
return self._storage

@property
def functions(self):
if self._functions is None:
headers = self._get_auth_headers()
headers.update(self._auth_token)
self._functions = SyncFunctionsClient(self.functions_url, headers)
self._functions = SyncFunctionsClient(
self.functions_url, self.options.headers
)
return self._functions

# async def remove_subscription_helper(resolve):
Expand Down Expand Up @@ -248,26 +240,17 @@ def _init_postgrest_client(
)

def _create_auth_header(self, token: str):
return {
"Authorization": f"Bearer {token}",
}
return f"Bearer {token}"

def _get_auth_headers(self) -> Dict[str, str]:
"""Helper method to get auth headers."""
return {
"apiKey": self.supabase_key,
"Authorization": f"Bearer {self.supabase_key}",
"Authorization": self.options.headers.get(
"Authorization", self._create_auth_header(self.supabase_key)
),
}

def _get_token_header(self):
try:
session = self.auth.get_session()
access_token = session.access_token
except Exception as err:
access_token = self.supabase_key

return self._create_auth_header(access_token)

def _listen_to_auth_events(
self, event: AuthChangeEvent, session: Union[Session, None]
):
Expand All @@ -279,7 +262,7 @@ def _listen_to_auth_events(
self._functions = None
access_token = session.access_token if session else self.supabase_key

self._auth_token = self._create_auth_header(access_token)
self.options.headers["Authorization"] = self._create_auth_header(access_token)


def create_client(
Expand Down
78 changes: 76 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

import os
from typing import Any
from unittest.mock import MagicMock

import pytest

from supabase import Client, create_client
from supabase.lib.client_options import ClientOptions


@pytest.mark.xfail(
reason="None of these values should be able to instantiate a client object"
Expand All @@ -12,6 +17,75 @@
@pytest.mark.parametrize("key", ["", None, "valeefgpoqwjgpj", 139, -1, {}, []])
def test_incorrect_values_dont_instantiate_client(url: Any, key: Any) -> None:
"""Ensure we can't instantiate client with invalid values."""
from supabase import Client, create_client

_: Client = create_client(url, key)


def test_uses_key_as_authorization_header_by_default() -> None:
url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")

client = create_client(url, key)

assert client.options.headers.get("apiKey") == key
assert client.options.headers.get("Authorization") == f"Bearer {key}"

assert client.postgrest.session.headers.get("apiKey") == key
assert client.postgrest.session.headers.get("Authorization") == f"Bearer {key}"

assert client.auth._headers.get("apiKey") == key
assert client.auth._headers.get("Authorization") == f"Bearer {key}"

assert client.storage.session.headers.get("apiKey") == key
assert client.storage.session.headers.get("Authorization") == f"Bearer {key}"


def test_supports_setting_a_global_authorization_header() -> None:
url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")

authorization = f"Bearer secretuserjwt"

options = ClientOptions(headers={"Authorization": authorization})

client = create_client(url, key, options)

assert client.options.headers.get("apiKey") == key
assert client.options.headers.get("Authorization") == authorization

assert client.postgrest.session.headers.get("apiKey") == key
assert client.postgrest.session.headers.get("Authorization") == authorization

assert client.auth._headers.get("apiKey") == key
assert client.auth._headers.get("Authorization") == authorization

assert client.storage.session.headers.get("apiKey") == key
assert client.storage.session.headers.get("Authorization") == authorization


def test_updates_the_authorization_header_on_auth_events() -> None:
url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")

client = create_client(url, key)

assert client.options.headers.get("apiKey") == key
assert client.options.headers.get("Authorization") == f"Bearer {key}"

mock_session = MagicMock(access_token="secretuserjwt")
client._listen_to_auth_events("SIGNED_IN", mock_session)

updated_authorization = f"Bearer {mock_session.access_token}"

assert client.options.headers.get("apiKey") == key
assert client.options.headers.get("Authorization") == updated_authorization

assert client.postgrest.session.headers.get("apiKey") == key
assert (
client.postgrest.session.headers.get("Authorization") == updated_authorization
)

assert client.auth._headers.get("apiKey") == key
assert client.auth._headers.get("Authorization") == updated_authorization

assert client.storage.session.headers.get("apiKey") == key
assert client.storage.session.headers.get("Authorization") == updated_authorization

0 comments on commit 4214c43

Please sign in to comment.