Skip to content

Commit

Permalink
feat: support file upload in router
Browse files Browse the repository at this point in the history
  • Loading branch information
pedraumcosta committed May 8, 2024
1 parent 22c2f9a commit d3d6ec7
Show file tree
Hide file tree
Showing 10 changed files with 242 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1649,9 +1649,16 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) {
return variables, false
}

func (s *Source) Load(ctx context.Context, input []byte, writer io.Writer) (err error) {
func (s *Source) Load(
ctx context.Context, input []byte, files []httpclient.File, writer io.Writer,
) (err error) {
input = s.compactAndUnNullVariables(input)
return httpclient.Do(s.httpClient, ctx, input, writer)

if files == nil {
return httpclient.Do(s.httpClient, ctx, input, writer)
}

return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, writer)
}

type GraphQLSubscriptionClient interface {
Expand Down
26 changes: 26 additions & 0 deletions v2/pkg/engine/datasource/httpclient/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package httpclient

type File interface {
Path() string
Name() string
}

type internalFile struct {
path string
name string
}

func NewFile(path string, name string) File {
return &internalFile{
path: path,
name: name,
}
}

func (f *internalFile) Path() string {
return f.path
}

func (f *internalFile) Name() string {
return f.name
}
191 changes: 191 additions & 0 deletions v2/pkg/engine/datasource/httpclient/nethttpclient.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package httpclient

import (
"bufio"
"bytes"
"compress/flate"
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -219,3 +224,189 @@ func respBodyReader(res *http.Response) (io.Reader, error) {
return res.Body, nil
}
}

func DoMultipartForm(
client *http.Client, ctx context.Context, requestInput []byte, files []File, out io.Writer,
) (err error) {
if files == nil || len(files) == 0 {
return errors.New("no files provided")
}

url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput)

formValues := map[string]io.Reader{
"operations": bytes.NewReader(body),
}

var fileMap string
for i, file := range files {
if len(fileMap) == 0 {
if len(files) == 1 {
fileMap = fmt.Sprintf(`"%d" : ["variables.file"]`, i)
} else {
fileMap = fmt.Sprintf(`"%d" : ["variables.file%d"]`, i, i+1)
}
} else {
fileMap = fmt.Sprintf(`%s, "%d" : ["variables.file%d"]`, fileMap, i, i+1)
}
key := fmt.Sprintf("%d", i)
temporaryFile, err := os.Open(file.Path())
if err != nil {
return err
}
formValues[key] = bufio.NewReader(temporaryFile)
}
formValues["map"] = strings.NewReader("{ " + fileMap + " }")

multipartBody, contentType, err := multipartBytes(formValues, files)
if err != nil {
return err
}

request, err := http.NewRequestWithContext(ctx, string(method), string(url), &multipartBody)
if err != nil {
return err
}

if headers != nil {
err = jsonparser.ObjectEach(headers, func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error {
_, err = jsonparser.ArrayEach(value, func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
if err != nil {
return
}
if len(value) == 0 {
return
}
request.Header.Add(string(key), string(value))
})
return err
})
if err != nil {
return err
}
}

if queryParams != nil {
query := request.URL.Query()
_, err = jsonparser.ArrayEach(queryParams, func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
var (
parameterName, parameterValue []byte
)
jsonparser.EachKey(value, func(i int, bytes []byte, valueType jsonparser.ValueType, err error) {
switch i {
case 0:
parameterName = bytes
case 1:
parameterValue = bytes
}
}, queryParamsKeys...)
if len(parameterName) != 0 && len(parameterValue) != 0 {
if bytes.Equal(parameterValue[:1], literal.LBRACK) {
_, _ = jsonparser.ArrayEach(parameterValue, func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
query.Add(string(parameterName), string(value))
})
} else {
query.Add(string(parameterName), string(parameterValue))
}
}
})
if err != nil {
return err
}
request.URL.RawQuery = query.Encode()
}

request.Header.Add(AcceptHeader, ContentTypeJSON)
request.Header.Add(ContentTypeHeader, contentType)
request.Header.Set(AcceptEncodingHeader, EncodingGzip)
request.Header.Add(AcceptEncodingHeader, EncodingDeflate)

response, err := client.Do(request)
if err != nil {
return err
}
defer response.Body.Close()
for _, file := range files {
err = os.Remove(file.Path())
if err != nil {
return err
}
}

respReader, err := respBodyReader(response)
if err != nil {
return err
}

if !enableTrace {
_, err = io.Copy(out, respReader)
return
}

buf := &bytes.Buffer{}
_, err = io.Copy(buf, respReader)
if err != nil {
return err
}
responseTrace := TraceHTTP{
Request: TraceHTTPRequest{
Method: request.Method,
URL: request.URL.String(),
Headers: redactHeaders(request.Header),
},
Response: TraceHTTPResponse{
StatusCode: response.StatusCode,
Status: response.Status,
Headers: redactHeaders(response.Header),
BodySize: buf.Len(),
},
}
trace, err := json.Marshal(responseTrace)
if err != nil {
return err
}
responseWithTraceExtension, err := jsonparser.Set(buf.Bytes(), trace, "extensions", "trace")
if err != nil {
return err
}
_, err = out.Write(responseWithTraceExtension)
return err
}

func multipartBytes(values map[string]io.Reader, files []File) (bytes.Buffer, string, error) {
var err error
var b bytes.Buffer
var fw io.Writer
w := multipart.NewWriter(&b)

// First create the fields to control the file upload
valuesInOrder := []string{"operations", "map"}
for _, key := range valuesInOrder {
r := values[key]
if fw, err = w.CreateFormField(key); err != nil {
return b, "", err
}
if _, err = io.Copy(fw, r); err != nil {
return b, "", err
}
}

// Now create one form for each file
for i, file := range files {
key := fmt.Sprintf("%d", i)
r := values[key]
if fw, err = w.CreateFormFile(key, file.Name()); err != nil {
return b, "", err
}
if _, err = io.Copy(fw, r); err != nil {
return b, "", err
}
}

err = w.Close()
if err != nil {
return b, "", err
}

return b, w.FormDataContentType(), nil
}
3 changes: 2 additions & 1 deletion v2/pkg/engine/datasource/introspection_datasource/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package introspection_datasource
import (
"context"
"encoding/json"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"

"github.com/wundergraph/graphql-go-tools/v2/pkg/introspection"
Expand All @@ -16,7 +17,7 @@ type Source struct {
introspectionData *introspection.Data
}

func (s *Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) {
func (s *Source) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) {
var req introspectionInput
if err := json.Unmarshal(input, &req); err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"
"regexp"
"strings"
Expand Down Expand Up @@ -290,7 +291,7 @@ type PublishDataSource struct {
pubSub PubSub
}

func (s *PublishDataSource) Load(ctx context.Context, input []byte, w io.Writer) error {
func (s *PublishDataSource) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error {
subject, err := jsonparser.GetString(input, "subject")
if err != nil {
return fmt.Errorf("error getting subject from input: %w", err)
Expand All @@ -312,7 +313,7 @@ type RequestDataSource struct {
pubSub PubSub
}

func (s *RequestDataSource) Load(ctx context.Context, input []byte, w io.Writer) error {
func (s *RequestDataSource) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error {
subject, err := jsonparser.GetString(input, "subject")
if err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package staticdatasource

import (
"context"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"

"github.com/jensneuse/abstractlogger"
Expand Down Expand Up @@ -65,7 +66,7 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration {

type Source struct{}

func (Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) {
func (Source) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) {
_, err = w.Write(input)
return
}
4 changes: 4 additions & 0 deletions v2/pkg/engine/resolve/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ import (
"net/http"
"time"

"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"go.uber.org/atomic"
)

type Context struct {
ctx context.Context
Variables []byte
Files []httpclient.File
Request Request
RenameTypeNames []RenameTypeName
TracingOptions TraceOptions
Expand Down Expand Up @@ -141,6 +143,7 @@ func (c *Context) clone(ctx context.Context) *Context {
cpy := *c
cpy.ctx = ctx
cpy.Variables = append([]byte(nil), c.Variables...)
cpy.Files = append([]httpclient.File(nil), c.Files...)
cpy.Request.Header = c.Request.Header.Clone()
cpy.RenameTypeNames = append([]RenameTypeName(nil), c.RenameTypeNames...)
return &cpy
Expand All @@ -149,6 +152,7 @@ func (c *Context) clone(ctx context.Context) *Context {
func (c *Context) Free() {
c.ctx = nil
c.Variables = nil
c.Files = nil
c.Request.Header = nil
c.RenameTypeNames = nil
c.TracingOptions.DisableAll()
Expand Down
3 changes: 2 additions & 1 deletion v2/pkg/engine/resolve/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"io"

"github.com/cespare/xxhash/v2"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
)

type DataSource interface {
Load(ctx context.Context, input []byte, w io.Writer) (err error)
Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error)
}

type SubscriptionDataSource interface {
Expand Down
4 changes: 2 additions & 2 deletions v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -1363,11 +1363,11 @@ func (l *Loader) executeSourceLoad(ctx context.Context, source DataSource, input
if res.loaderHookContext != nil {
res.err = source.Load(res.loaderHookContext, input, res.out)
} else {
res.err = source.Load(ctx, input, res.out)
res.err = source.Load(ctx, input, l.ctx.Files, res.out)
}

} else {
res.err = source.Load(ctx, input, res.out)
res.err = source.Load(ctx, input, l.ctx.Files, res.out)
}

res.statusCode = responseContext.StatusCode
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/variablesvalidation/variablesvalidation.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (v *variablesVisitor) traverseOperationType(jsonFieldRef int, operationType
v.renderVariableRequiredError(v.currentVariableName, operationTypeRef)
return
}
if v.variables.Nodes[jsonFieldRef].Kind == astjson.NodeKindNull {
if v.variables.Nodes[jsonFieldRef].Kind == astjson.NodeKindNull && varTypeName.String() != "Upload" {
v.renderVariableInvalidNullError(v.currentVariableName, operationTypeRef)
return
}
Expand Down

0 comments on commit d3d6ec7

Please sign in to comment.