Skip to content

Commit

Permalink
Merge pull request #11 from rmyers/refactor-schema-extensions
Browse files Browse the repository at this point in the history
Refactoring schema extension to work with upstream changes
  • Loading branch information
rmyers committed May 8, 2023
2 parents fac04bb + 791e80c commit cf7d560
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 189 deletions.
2 changes: 2 additions & 0 deletions cannula/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
)
from .errors import format_errors
from .utils import gql
from .schema import build_and_extend_schema

__all__ = [
"API",
"Context",
"Resolver",
"format_errors",
"gql",
"build_and_extend_schema",
]

__VERSION__ = "0.0.4"
57 changes: 23 additions & 34 deletions cannula/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,85 +74,74 @@ def my_query(source, info):
:param schema_directory: Directory name to search for schema files.
:param query_directory: Directory name to search for query docs.
"""

# Allow sub-resolvers to apply a base schema before applying custom schema.
base_schema: typing.Dict[str, DocumentNode] = {}
registry: typing.Dict[str, dict]
datasources: typing.Dict[str, typing.Any]
forms: typing.Dict[str, typing.Any]

def __init__(
self,
name: str,
schema: typing.Optional[typing.Union[str, DocumentNode]] = None,
schema_directory: str = 'schema',
query_directory: str = 'queries',
schema: typing.List[typing.Union[str, DocumentNode]] = [],
schema_directory: str = "schema",
query_directory: str = "queries",
):
self.registry = collections.defaultdict(dict)
self.datasources = {}
self.forms = {}
self._schema_directory = schema_directory
self._query_directory = query_directory
self.root_dir = get_root_path(name)
self._schema = schema

@property
def schema_directory(self):
if not hasattr(self, '_schema_dir'):
if not hasattr(self, "_schema_dir"):
if os.path.isabs(self._schema_directory):
setattr(self, '_schema_dir', self._schema_directory)
setattr(self, '_schema_dir', os.path.join(self.root_dir, self._schema_directory))
setattr(self, "_schema_dir", self._schema_directory)
setattr(
self, "_schema_dir", os.path.join(self.root_dir, self._schema_directory)
)
return self._schema_dir

def find_schema(self) -> typing.List[DocumentNode]:
schemas: typing.List[DocumentNode] = []
if os.path.isdir(self.schema_directory):
LOG.debug(f'Searching {self.schema_directory} for schema.')
LOG.debug(f"Searching {self.schema_directory} for schema.")
schemas = load_schema(self.schema_directory)

if self._schema is not None:
schemas.append(maybe_parse(self._schema))
for schema in self._schema:
schemas.append(maybe_parse(schema))

return schemas

@property
def query_directory(self) -> str:
if not hasattr(self, '_query_dir'):
if not hasattr(self, "_query_dir"):
if os.path.isabs(self._query_directory):
self._query_dir: str = self._query_directory
self._query_dir = os.path.join(self.root_dir, self._query_directory)
return self._query_dir

@functools.lru_cache(maxsize=128)
def load_query(self, query_name: str) -> DocumentNode:
path = os.path.join(self.query_directory, f'{query_name}.graphql')
path = os.path.join(self.query_directory, f"{query_name}.graphql")
assert os.path.isfile(path), f"No query found for {query_name}"

with open(path) as query:
return parse(query.read())

def resolver(self, type_name: str = 'Query') -> typing.Any:
def resolver(self, type_name: str = "Query") -> typing.Any:
def decorator(function):
self.registry[type_name][function.__name__] = function

return decorator

def datasource(self):
def decorator(klass):
self.datasources[klass.__name__] = klass
return decorator

def get_form_query(self, name: str, **kwargs) -> DocumentNode:
"""Get registered form query document"""
form = self.forms.get(name)
assert form is not None, f'Form: {name} is not registered!'

return form.get_query(**kwargs)

def get_form_mutation(self, name: str, **kwargs) -> DocumentNode:
"""Get registered form mutation document"""
form = self.forms.get(name)
assert form is not None, f'Form: {name} is not registered!'

return form.get_mutation(**kwargs)
return decorator


class API(Resolver):
Expand Down Expand Up @@ -188,7 +177,7 @@ def __init__(

@property
def schema(self) -> GraphQLSchema:
if not hasattr(self, '_full_schema'):
if not hasattr(self, "_full_schema"):
self._full_schema = self._build_schema()
return self._full_schema

Expand All @@ -200,7 +189,6 @@ def _all_schema(self) -> typing.Iterator[DocumentNode]:
self._merge_registry(resolver.registry)
self.base_schema.update(resolver.base_schema)
self.datasources.update(resolver.datasources)
self.forms.update(resolver.forms)
for document_node in resolver.find_schema():
yield document_node

Expand All @@ -212,7 +200,7 @@ def _build_schema(self) -> GraphQLSchema:

schema_validation_errors = validate_schema(schema)
if schema_validation_errors:
raise Exception(f'Invalid schema: {schema_validation_errors}')
raise Exception(f"Invalid schema: {schema_validation_errors}")

schema = fix_abstract_resolve_type(schema)

Expand All @@ -226,13 +214,14 @@ def _make_executable(self, schema: GraphQLSchema):
for field_name, resolver_fn in fields.items():
field_definition = object_type.fields.get(field_name)
if not field_definition:
raise Exception(f'Invalid field {type_name}.{field_name}')
raise Exception(f"Invalid field {type_name}.{field_name}")

field_definition.resolve = resolver_fn

def context(self):
def decorator(klass):
self._context = klass

return decorator

def get_context(self, request):
Expand All @@ -254,7 +243,7 @@ async def call(
self,
document: GraphQLSchema,
request: typing.Any = None,
variables: typing.Dict[str, typing.Any] = None
variables: typing.Dict[str, typing.Any] = None,
) -> ExecutionResult:
"""Preform a query against the schema.
Expand All @@ -281,7 +270,7 @@ def call_sync(
self,
document: GraphQLSchema,
request: typing.Any = None,
variables: typing.Dict[str, typing.Any] = None
variables: typing.Dict[str, typing.Any] = None,
) -> ExecutionResult:
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.call(document, request, variables))
37 changes: 13 additions & 24 deletions cannula/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pathlib
import typing
import itertools

from graphql import (
GraphQLSchema,
Expand All @@ -11,32 +12,21 @@
build_ast_schema,
extend_schema,
concat_ast,
is_type_extension_node,
is_type_system_extension_node,
is_type_definition_node,
)

LOG = logging.getLogger(__name__)
QUERY_TYPE = parse("type Query { _empty: String }")
MUTATION_TYPE = parse("type Mutation { _empty: String }")

object_definition_kind = "object_type_definition"
object_extension_kind = "object_type_extension"
interface_extension_kind = "interface_type_extension"
input_object_extension_kind = "input_object_type_extension"
union_extension_kind = "union_type_extension"
enum_extension_kind = "enum_type_extension"

extension_kinds = [
object_extension_kind,
interface_extension_kind,
input_object_extension_kind,
union_extension_kind,
enum_extension_kind,
]


def extract_extensions(ast: DocumentNode) -> DocumentNode:
extensions = [node for node in ast.definitions if node.kind in extension_kinds]

return DocumentNode(definitions=extensions)
type_extensions = filter(is_type_extension_node, ast.definitions)
system_extensions = filter(is_type_system_extension_node, ast.definitions)
extensions = itertools.chain(type_extensions, system_extensions)
return DocumentNode(definitions=list(extensions))


def assert_has_query_and_mutation(ast: DocumentNode) -> DocumentNode:
Expand All @@ -45,11 +35,8 @@ def assert_has_query_and_mutation(ast: DocumentNode) -> DocumentNode:
The schema is pretty much useless without them and rather than causing
an error we'll just add in an empty one so they can be extended.
"""
object_definitions = [
node.name.value
for node in ast.definitions
if node.kind == object_definition_kind
]
object_kinds = filter(is_type_definition_node, ast.definitions)
object_definitions = [node.name.value for node in object_kinds]
has_mutation_definition = "Mutation" in object_definitions
has_query_definition = "Query" in object_definitions

Expand Down Expand Up @@ -89,7 +76,9 @@ def build_and_extend_schema(

if extension_ast.definitions:
LOG.debug("Extending schema")
schema = extend_schema(schema, extension_ast)
schema = extend_schema(
schema, extension_ast, assume_valid=True, assume_valid_sdl=True
)

return schema

Expand Down
53 changes: 53 additions & 0 deletions examples/extends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import cannula
from graphql import GraphQLObjectType, GraphQLString, GraphQLSchema

schema = cannula.gql(
"""
type Brocoli {
taste: String
}
type Message {
text: String
number: Int
float: Float
isOn: Boolean
id: ID
brocoli: Brocoli
}
type Query {
mockity: [Message]
boo: String
}
"""
)

extensions = cannula.gql(
"""
extend type Brocoli {
color: String
}
extend type Query {
fancy: [Message]
}
"""
)

# schema = cannula.build_and_extend_schema([schema, extensions])

api = cannula.API(
__name__,
schema=[schema, extensions],
)

SAMPLE_QUERY = cannula.gql(
"""
query HelloWorld {
fancy {
text
}
}
"""
)


print(api.call_sync(SAMPLE_QUERY))

0 comments on commit cf7d560

Please sign in to comment.