Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic query complexity #612

Merged
merged 8 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/Annotations/Cost.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<?php

declare(strict_types=1);

namespace TheCodingMachine\GraphQLite\Annotations;

use Attribute;

#[Attribute(Attribute::TARGET_PROPERTY | Attribute::TARGET_METHOD)]
class Cost implements MiddlewareAnnotationInterface
oprypkhantc marked this conversation as resolved.
Show resolved Hide resolved
{
/**
* @param int $complexity Complexity for that field
* @param string[] $multipliers Names of fields by value of which complexity will be multiplied
* @param ?int $defaultMultiplier Default multiplier value if all multipliers are missing/null
*/
public function __construct(
public readonly int $complexity = 1,
public readonly array $multipliers = [],
public readonly int|null $defaultMultiplier = null,
) {
}
}
12 changes: 11 additions & 1 deletion src/Annotations/MiddlewareAnnotations.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ public function __construct(private array $annotations)
/**
* Return annotations of the $className type
*
* @return array<int, MiddlewareAnnotationInterface>
* @param class-string<TAnnotation> $className
*
* @return array<int, TAnnotation>
*
* @template TAnnotation of MiddlewareAnnotationInterface
*/
public function getAnnotationsByType(string $className): array
{
Expand All @@ -32,6 +36,12 @@ public function getAnnotationsByType(string $className): array

/**
* Returns at most 1 annotation of the $className type.
*
* @param class-string<TAnnotation> $className
*
* @return TAnnotation|null
*
* @template TAnnotation of MiddlewareAnnotationInterface
*/
public function getAnnotationByType(string $className): MiddlewareAnnotationInterface|null
{
Expand Down
34 changes: 34 additions & 0 deletions src/Http/Psr15GraphQLMiddlewareBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
use GraphQL\Error\DebugFlag;
use GraphQL\Server\ServerConfig;
use GraphQL\Type\Schema;
use GraphQL\Validator\DocumentValidator;
use GraphQL\Validator\Rules\QueryComplexity;
use GraphQL\Validator\Rules\ValidationRule;
use Laminas\Diactoros\ResponseFactory;
use Laminas\Diactoros\StreamFactory;
use Psr\Http\Message\ResponseFactoryInterface;
Expand All @@ -21,6 +24,7 @@
use TheCodingMachine\GraphQLite\Server\PersistedQuery\NotSupportedPersistedQueryLoader;

use function class_exists;
use function is_callable;

/**
* A factory generating a PSR-15 middleware tailored for GraphQLite.
Expand All @@ -38,6 +42,9 @@ class Psr15GraphQLMiddlewareBuilder

private HttpCodeDeciderInterface $httpCodeDecider;

/** @var ValidationRule[] */
private array $addedValidationRules = [];

public function __construct(Schema $schema)
{
$this->config = new ServerConfig();
Expand Down Expand Up @@ -97,6 +104,18 @@ public function useAutomaticPersistedQueries(CacheInterface $cache, DateInterval
return $this;
}

public function limitQueryComplexity(int $complexity): self
{
return $this->addValidationRule(new QueryComplexity($complexity));
}

public function addValidationRule(ValidationRule $rule): self
{
$this->addedValidationRules[] = $rule;

return $this;
}

public function createMiddleware(): MiddlewareInterface
{
if ($this->responseFactory === null && ! class_exists(ResponseFactory::class)) {
Expand All @@ -109,6 +128,21 @@ public function createMiddleware(): MiddlewareInterface
}
$this->streamFactory = $this->streamFactory ?: new StreamFactory();

// If getValidationRules() is null in the config, DocumentValidator will default to DocumentValidator::allRules().
// So if we only added given rule, all of the default rules would not be validated, so we must also provide them.
$originalValidationRules = $this->config->getValidationRules() ?? DocumentValidator::allRules();

$this->config->setValidationRules(function (...$args) use ($originalValidationRules) {
if (is_callable($originalValidationRules)) {
$originalValidationRules = $originalValidationRules(...$args);
}

return [
...$originalValidationRules,
...$this->addedValidationRules,
];
});

return new WebonyxGraphqlMiddleware($this->config, $this->responseFactory, $this->streamFactory, $this->httpCodeDecider, $this->url);
}
}
4 changes: 0 additions & 4 deletions src/Middlewares/AuthorizationFieldMiddleware.php
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,15 @@ public function process(QueryFieldDescriptor $queryFieldDescriptor, FieldHandler
$annotations = $queryFieldDescriptor->getMiddlewareAnnotations();

$loggedAnnotation = $annotations->getAnnotationByType(Logged::class);
assert($loggedAnnotation === null || $loggedAnnotation instanceof Logged);
$rightAnnotation = $annotations->getAnnotationByType(Right::class);
assert($rightAnnotation === null || $rightAnnotation instanceof Right);

// Avoid wrapping resolver callback when no annotations are specified.
if (! $loggedAnnotation && ! $rightAnnotation) {
return $fieldHandler->handle($queryFieldDescriptor);
}

$failWith = $annotations->getAnnotationByType(FailWith::class);
assert($failWith === null || $failWith instanceof FailWith);
$hideIfUnauthorized = $annotations->getAnnotationByType(HideIfUnauthorized::class);
assert($hideIfUnauthorized instanceof HideIfUnauthorized || $hideIfUnauthorized === null);

if ($failWith !== null && $hideIfUnauthorized !== null) {
throw IncompatibleAnnotationsException::cannotUseFailWithAndHide();
Expand Down
3 changes: 0 additions & 3 deletions src/Middlewares/AuthorizationInputFieldMiddleware.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,14 @@ public function process(InputFieldDescriptor $inputFieldDescriptor, InputFieldHa
$annotations = $inputFieldDescriptor->getMiddlewareAnnotations();

$loggedAnnotation = $annotations->getAnnotationByType(Logged::class);
assert($loggedAnnotation === null || $loggedAnnotation instanceof Logged);
$rightAnnotation = $annotations->getAnnotationByType(Right::class);
assert($rightAnnotation === null || $rightAnnotation instanceof Right);

// Avoid wrapping resolver callback when no annotations are specified.
if (! $loggedAnnotation && ! $rightAnnotation) {
return $inputFieldHandler->handle($inputFieldDescriptor);
}

$hideIfUnauthorized = $annotations->getAnnotationByType(HideIfUnauthorized::class);
assert($hideIfUnauthorized instanceof HideIfUnauthorized || $hideIfUnauthorized === null);

if ($hideIfUnauthorized !== null && ! $this->isAuthorized($loggedAnnotation, $rightAnnotation)) {
return null;
Expand Down
73 changes: 73 additions & 0 deletions src/Middlewares/CostFieldMiddleware.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
<?php

declare(strict_types=1);

namespace TheCodingMachine\GraphQLite\Middlewares;

use GraphQL\Type\Definition\FieldDefinition;
use TheCodingMachine\GraphQLite\Annotations\Cost;
use TheCodingMachine\GraphQLite\QueryFieldDescriptor;

use function implode;
use function is_int;

/**
* Reference implementation: https://github.com/ChilliCream/graphql-platform/blob/388f5c988bbb806e46e2315f1844ea5bb63096f2/src/HotChocolate/Core/src/Execution/Options/ComplexityAnalyzerSettings.cs#L58
*/
class CostFieldMiddleware implements FieldMiddlewareInterface
{
public function process(QueryFieldDescriptor $queryFieldDescriptor, FieldHandlerInterface $fieldHandler): FieldDefinition|null
{
$costAttribute = $queryFieldDescriptor->getMiddlewareAnnotations()->getAnnotationByType(Cost::class);

if (! $costAttribute) {
return $fieldHandler->handle($queryFieldDescriptor);
}

$field = $fieldHandler->handle(
$queryFieldDescriptor->withAddedCommentLines($this->buildQueryComment($costAttribute)),
);

if (! $field) {
return $field;
}

$field->complexityFn = static function (int $childrenComplexity, array $fieldArguments) use ($costAttribute): int {
if (! $costAttribute->multipliers) {
return $costAttribute->complexity + $childrenComplexity;
}

$cost = $costAttribute->complexity + $childrenComplexity;
$needsDefaultMultiplier = true;

foreach ($costAttribute->multipliers as $multiplier) {
$value = $fieldArguments[$multiplier] ?? null;

if (! is_int($value)) {
continue;
}

$cost *= $value;
$needsDefaultMultiplier = false;
}

if ($needsDefaultMultiplier && $costAttribute->defaultMultiplier !== null) {
$cost *= $costAttribute->defaultMultiplier;
}

return $cost;
};

return $field;
}

private function buildQueryComment(Cost $costAttribute): string
{
return 'Cost: ' .
oprypkhantc marked this conversation as resolved.
Show resolved Hide resolved
implode(', ', [
'complexity = ' . $costAttribute->complexity,
'multipliers = [' . implode(', ', $costAttribute->multipliers) . ']',
'defaultMultiplier = ' . ($costAttribute->defaultMultiplier ?? 'null'),
]);
}
}
1 change: 0 additions & 1 deletion src/Middlewares/SecurityFieldMiddleware.php
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ public function process(QueryFieldDescriptor $queryFieldDescriptor, FieldHandler
}

$failWith = $annotations->getAnnotationByType(FailWith::class);
assert($failWith instanceof FailWith || $failWith === null);

// If the failWith value is null and the return type is non nullable, we must set it to nullable.
$makeReturnTypeNullable = false;
Expand Down
9 changes: 9 additions & 0 deletions src/QueryFieldDescriptor.php
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ public function withComment(string|null $comment): self
return $this->with(comment: $comment);
}

public function withAddedCommentLines(string $comment): self
{
if (! $this->comment) {
return $this->withComment($comment);
}

return $this->withComment($this->comment . "\n" . $comment);
}

public function getDeprecationReason(): string|null
{
return $this->deprecationReason;
Expand Down
12 changes: 6 additions & 6 deletions src/SchemaFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
use TheCodingMachine\GraphQLite\Mappers\TypeMapperInterface;
use TheCodingMachine\GraphQLite\Middlewares\AuthorizationFieldMiddleware;
use TheCodingMachine\GraphQLite\Middlewares\AuthorizationInputFieldMiddleware;
use TheCodingMachine\GraphQLite\Middlewares\CostFieldMiddleware;
use TheCodingMachine\GraphQLite\Middlewares\FieldMiddlewareInterface;
use TheCodingMachine\GraphQLite\Middlewares\FieldMiddlewarePipe;
use TheCodingMachine\GraphQLite\Middlewares\InputFieldMiddlewareInterface;
Expand Down Expand Up @@ -211,9 +212,7 @@ public function addParameterMiddleware(ParameterMiddlewareInterface $parameterMi
return $this;
}

/**
* @deprecated Use PHP8 Attributes instead
*/
/** @deprecated Use PHP8 Attributes instead */
public function setDoctrineAnnotationReader(Reader $annotationReader): self
{
$this->doctrineAnnotationReader = $annotationReader;
Expand Down Expand Up @@ -349,7 +348,7 @@ public function createSchema(): Schema

$namespaceFactory = new NamespaceFactory($namespacedCache, $this->classNameMapper, $this->globTTL);
$nsList = array_map(
static fn(string $namespace) => $namespaceFactory->createNamespace($namespace),
static fn (string $namespace) => $namespaceFactory->createNamespace($namespace),
$this->typeNamespaces,
);

Expand All @@ -363,6 +362,7 @@ public function createSchema(): Schema
// TODO: add a logger to the SchemaFactory and make use of it everywhere (and most particularly in SecurityFieldMiddleware)
$fieldMiddlewarePipe->pipe(new SecurityFieldMiddleware($expressionLanguage, $authenticationService, $authorizationService));
$fieldMiddlewarePipe->pipe(new AuthorizationFieldMiddleware($authenticationService, $authorizationService));
$fieldMiddlewarePipe->pipe(new CostFieldMiddleware());

$inputFieldMiddlewarePipe = new InputFieldMiddlewarePipe();
foreach ($this->inputFieldMiddlewares as $inputFieldMiddleware) {
Expand Down Expand Up @@ -390,7 +390,7 @@ public function createSchema(): Schema
$rootTypeMapper = new MyCLabsEnumTypeMapper($rootTypeMapper, $annotationReader, $symfonyCache, $nsList);
}

if (!empty($this->rootTypeMapperFactories)) {
if (! empty($this->rootTypeMapperFactories)) {
$rootSchemaFactoryContext = new RootTypeMapperFactoryContext(
$annotationReader,
$typeResolver,
Expand Down Expand Up @@ -458,7 +458,7 @@ public function createSchema(): Schema
));
}

if (!empty($this->typeMapperFactories) || !empty($this->queryProviderFactories)) {
if (! empty($this->typeMapperFactories) || ! empty($this->queryProviderFactories)) {
$context = new FactoryContext(
$annotationReader,
$typeResolver,
Expand Down
16 changes: 16 additions & 0 deletions tests/Fixtures/Integration/Controllers/ArticleController.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,27 @@

namespace TheCodingMachine\GraphQLite\Fixtures\Integration\Controllers;

use TheCodingMachine\GraphQLite\Annotations\Cost;
use TheCodingMachine\GraphQLite\Annotations\Mutation;
use TheCodingMachine\GraphQLite\Annotations\Query;
use TheCodingMachine\GraphQLite\Fixtures\Integration\Models\Article;
use TheCodingMachine\GraphQLite\Fixtures\Integration\Models\Contact;
use TheCodingMachine\GraphQLite\Fixtures\Integration\Models\User;

class ArticleController
{
/**
* @return Article[]
*/
#[Query]
#[Cost(complexity: 5, multipliers: ['take'], defaultMultiplier: 500)]
public function articles(?int $take = 10): array
{
return [
new Article('Title'),
];
}


/**
* @Mutation()
Expand Down
3 changes: 3 additions & 0 deletions tests/Fixtures/Integration/Models/Post.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
namespace TheCodingMachine\GraphQLite\Fixtures\Integration\Models;

use DateTimeInterface;
use TheCodingMachine\GraphQLite\Annotations\Cost;
use TheCodingMachine\GraphQLite\Annotations\Field;
use TheCodingMachine\GraphQLite\Annotations\Input;
use TheCodingMachine\GraphQLite\Annotations\Type;
Expand Down Expand Up @@ -38,6 +39,7 @@ class Post
* @Field(name="comment")
* @var string|null
*/
#[Cost(complexity: 5)]
private $description = 'foo';

/**
Expand All @@ -50,6 +52,7 @@ class Post
* @Field()
* @var Contact|null
*/
#[Cost(complexity: 3)]
public $author = null;

/**
Expand Down
Loading
Loading