Skip to content

Commit

Permalink
Add a way to filter fields (#3274)
Browse files Browse the repository at this point in the history
* Add a way to filter fields

* Update test

* Add release notes

* Docs

* Update pyright tests
  • Loading branch information
patrick91 committed Jan 22, 2024
1 parent c26bb05 commit 85fb58c
Show file tree
Hide file tree
Showing 15 changed files with 185 additions and 50 deletions.
33 changes: 33 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
Release type: minor

This release adds a new method `get_fields` on the `Schema` class.
You can use `get_fields` to hide certain field based on some conditions,
for example:

```python
@strawberry.type
class User:
name: str
email: str = strawberry.field(metadata={"tags": ["internal"]})


@strawberry.type
class Query:
user: User


def public_field_filter(field: StrawberryField) -> bool:
return "internal" not in field.metadata.get("tags", [])


class PublicSchema(strawberry.Schema):
def get_fields(
self, type_definition: StrawberryObjectDefinition
) -> List[StrawberryField]:
return list(filter(public_field_filter, type_definition.fields))


schema = PublicSchema(query=Query)
```

The schema here would only have the `name` field on the `User` type.
40 changes: 40 additions & 0 deletions docs/types/schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,43 @@ class StrawberryLogger:

cls.logger.error(error, exc_info=error.original_error, **logger_kwargs)
```

## Filtering/customising fields

You can customise the fields that are exposed on a schema by subclassing the
`Schema` class and overriding the `get_fields` method, for example you can use
this to create different GraphQL APIs, such as a public and an internal API.
Here's an example of this:

```python
@strawberry.type
class User:
name: str
email: str = strawberry.field(metadata={"tags": ["internal"]})


@strawberry.type
class Query:
user: User


def public_field_filter(field: StrawberryField) -> bool:
return "internal" not in field.metadata.get("tags", [])


class PublicSchema(strawberry.Schema):
def get_fields(
self, type_definition: StrawberryObjectDefinition
) -> List[StrawberryField]:
return list(filter(public_field_filter, type_definition.fields))


schema = PublicSchema(query=Query)
```

<Note>

The `get_fields` method is only called once when creating the schema, this is
not intended to be used to dynamically customise the schema.

</Note>
9 changes: 8 additions & 1 deletion strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def __init__(
# TODO: check that the overrides are valid
scalar_registry.update(cast(SCALAR_OVERRIDES_DICT_TYPE, scalar_overrides))

self.schema_converter = GraphQLCoreConverter(self.config, scalar_registry)
self.schema_converter = GraphQLCoreConverter(
self.config, scalar_registry, self.get_fields
)
self.directives = directives
self.schema_directives = list(schema_directives)

Expand Down Expand Up @@ -231,6 +233,11 @@ def get_directive_by_name(self, graphql_name: str) -> Optional[StrawberryDirecti
None,
)

def get_fields(
self, type_definition: StrawberryObjectDefinition
) -> List[StrawberryField]:
return type_definition.fields

async def execute(
self,
query: Optional[str],
Expand Down
9 changes: 8 additions & 1 deletion strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _get_thunk_mapping(
type_definition: StrawberryObjectDefinition,
name_converter: Callable[[StrawberryField], str],
field_converter: FieldConverterProtocol[FieldType],
get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]],
) -> Dict[str, FieldType]:
"""Create a GraphQL core `ThunkMapping` mapping of field names to field types.
Expand All @@ -123,7 +124,9 @@ def _get_thunk_mapping(
"""
thunk_mapping: Dict[str, FieldType] = {}

for field in type_definition.fields:
fields = get_fields(type_definition)

for field in fields:
field_type = field.type

if field_type is UNRESOLVED:
Expand Down Expand Up @@ -178,10 +181,12 @@ def __init__(
self,
config: StrawberryConfig,
scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]],
get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]],
):
self.type_map: Dict[str, ConcreteType] = {}
self.config = config
self.scalar_registry = scalar_registry
self.get_fields = get_fields

def from_argument(self, argument: StrawberryArgument) -> GraphQLArgument:
argument_type = cast(
Expand Down Expand Up @@ -374,6 +379,7 @@ def get_graphql_fields(
type_definition=type_definition,
name_converter=self.config.name_converter.from_field,
field_converter=self.from_field,
get_fields=self.get_fields,
)

def get_graphql_input_fields(
Expand All @@ -383,6 +389,7 @@ def get_graphql_input_fields(
type_definition=type_definition,
name_converter=self.config.name_converter.from_field,
field_converter=self.from_input_field,
get_fields=self.get_fields,
)

def from_input_object(self, object_type: type) -> GraphQLInputObjectType:
Expand Down
18 changes: 9 additions & 9 deletions tests/pyright/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def test_federation_type():
results = run_pyright(CODE)

assert results == [
Result(type="error", message='No parameter named "n"', line=16, column=6),
Result(
type="error",
message='Argument missing for parameter "name"',
line=16,
column=1,
),
Result(type="error", message='No parameter named "n"', line=16, column=6),
Result(
type="information",
message='Type of "User" is "type[User]"',
Expand Down Expand Up @@ -75,15 +75,15 @@ def test_federation_interface():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=12,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=12,
column=1,
column=6,
),
Result(
type="information",
Expand Down Expand Up @@ -122,15 +122,15 @@ def test_federation_input():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=10,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=10,
column=1,
column=6,
),
Result(
type="information",
Expand Down
14 changes: 7 additions & 7 deletions tests/pyright/test_federation_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def test_pyright():
results = run_pyright(CODE)

assert results == [
Result(
type="error",
message='Argument missing for parameter "name"',
line=24,
column=1,
),
Result(
type="error",
message='No parameter named "n"',
Expand All @@ -51,7 +57,7 @@ def test_pyright():
Result(
type="error",
message='Argument missing for parameter "name"',
line=24,
line=27,
column=1,
),
Result(
Expand All @@ -60,12 +66,6 @@ def test_pyright():
line=27,
column=11,
),
Result(
type="error",
message='Argument missing for parameter "name"',
line=27,
column=1,
),
Result(
type="information",
message='Type of "User" is "type[User]"',
Expand Down
8 changes: 4 additions & 4 deletions tests/pyright/test_federation_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=11,
column=11,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=11,
column=1,
column=11,
),
]
8 changes: 4 additions & 4 deletions tests/pyright/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=11,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=11,
column=1,
column=6,
),
Result(
type="information",
Expand Down
8 changes: 4 additions & 4 deletions tests/pyright/test_fields_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=11,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=11,
column=1,
column=6,
),
Result(
type="information",
Expand Down
8 changes: 4 additions & 4 deletions tests/pyright/test_fields_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=15,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=15,
column=1,
column=6,
),
Result(
type="information",
Expand Down
8 changes: 4 additions & 4 deletions tests/pyright/test_fields_resolver_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=15,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=15,
column=1,
column=6,
),
Result(
type="information",
Expand Down
14 changes: 7 additions & 7 deletions tests/pyright/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ def test_pyright():
results = run_pyright(CODE)

assert results == [
Result(
type="error",
message='No parameter named "n"',
line=16,
column=11,
),
Result(
type="error",
message='Argument missing for parameter "name"',
Expand All @@ -44,7 +38,7 @@ def test_pyright():
Result(
type="error",
message='No parameter named "n"',
line=19,
line=16,
column=11,
),
Result(
Expand All @@ -53,4 +47,10 @@ def test_pyright():
line=19,
column=1,
),
Result(
type="error",
message='No parameter named "n"',
line=19,
column=11,
),
]
8 changes: 4 additions & 4 deletions tests/pyright/test_private.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Arguments missing for parameters "name", "age"',
line=12,
column=6,
column=1,
),
Result(
type="error",
message='Arguments missing for parameters "name", "age"',
message='No parameter named "n"',
line=12,
column=1,
column=6,
),
Result(
type="information",
Expand Down

0 comments on commit 85fb58c

Please sign in to comment.