Skip to content

Commit

Permalink
Create requests.MakeMultiPartRequest to allow files upload (#801)
Browse files Browse the repository at this point in the history
* extract performRequest

* create requests.MakeMultiPartRequest for file upload

* lint

* Apply suggestions from code review

Co-authored-by: htheodore-stripe <89876392+htheodore-stripe@users.noreply.github.com>

* lint

Co-authored-by: htheodore-stripe <89876392+htheodore-stripe@users.noreply.github.com>
  • Loading branch information
pepin-stripe and htheodore-stripe committed Dec 14, 2021
1 parent b424ed8 commit c1971b9
Show file tree
Hide file tree
Showing 4 changed files with 32,950 additions and 9 deletions.
79 changes: 70 additions & 9 deletions pkg/requests/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package requests

import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -173,8 +176,33 @@ func (rb *Base) InitFlags() {
rb.Cmd.Flags().MarkHidden("api-base") // #nosec G104
}

// MakeMultiPartRequest will make a multipart/form-data request to the Stripe API with the specific variables given to it.
// Similar to making a multipart request using curl, add the local filepath to params arg with @ prefix.
// e.g. params.AppendData([]string{"photo=@/path/to/local/file.png"})
func (rb *Base) MakeMultiPartRequest(ctx context.Context, apiKey, path string, params *RequestParameters, errOnStatus bool) ([]byte, error) {
reqBody, contentType, err := rb.buildMultiPartRequest(params)
if err != nil {
return []byte{}, err
}

configure := func(req *http.Request) {
req.Header.Set("Content-Type", contentType)
}

return rb.performRequest(ctx, apiKey, path, params, reqBody.String(), errOnStatus, configure)
}

// MakeRequest will make a request to the Stripe API with the specific variables given to it
func (rb *Base) MakeRequest(ctx context.Context, apiKey, path string, params *RequestParameters, errOnStatus bool) ([]byte, error) {
data, err := rb.buildDataForRequest(params)
if err != nil {
return []byte{}, err
}

return rb.performRequest(ctx, apiKey, path, params, data, errOnStatus, nil)
}

func (rb *Base) performRequest(ctx context.Context, apiKey, path string, params *RequestParameters, data string, errOnStatus bool, additionalConfigure func(req *http.Request)) ([]byte, error) {
parsedBaseURL, err := url.Parse(rb.APIBaseURL)
if err != nil {
return []byte{}, err
Expand All @@ -186,29 +214,27 @@ func (rb *Base) MakeRequest(ctx context.Context, apiKey, path string, params *Re
Verbose: rb.showHeaders,
}

data, err := rb.buildDataForRequest(params)
if err != nil {
return []byte{}, err
}

configureReq := func(req *http.Request) {
configure := func(req *http.Request) {
rb.setIdempotencyHeader(req, params)
rb.setStripeAccountHeader(req, params)
rb.setVersionHeader(req, params)
if additionalConfigure != nil {
additionalConfigure(req)
}
}

resp, err := client.PerformRequest(ctx, rb.Method, path, data, configureReq)
resp, err := client.PerformRequest(ctx, rb.Method, path, data, configure)

if err != nil {
return []byte{}, err
}

defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)

if resp.StatusCode == 401 || (errOnStatus && resp.StatusCode >= 300) {
requestError := compileRequestError(body, resp.StatusCode)
return nil, requestError
return []byte{}, requestError
}

if !rb.SuppressOutput {
Expand Down Expand Up @@ -297,6 +323,41 @@ func (rb *Base) buildDataForRequest(params *RequestParameters) (string, error) {
return encode(keys, values), nil
}

func (rb *Base) buildMultiPartRequest(params *RequestParameters) (*bytes.Buffer, string, error) {
var body bytes.Buffer
mp := multipart.NewWriter(&body)
defer mp.Close()
for _, datum := range params.data {
splitDatum := strings.SplitN(datum, "=", 2)

if len(splitDatum) < 2 {
return nil, "", fmt.Errorf("Invalid data argument: %s", datum)
}

key := splitDatum[0]
val := splitDatum[1]

// Param values that are prefixed with @ will be parsed as a form file
if strings.HasPrefix(val, "@") {
val = val[1:]
file, err := os.Open(val)
if err != nil {
return nil, "", err
}
defer file.Close()
part, err := mp.CreateFormFile(key, val)
if err != nil {
return nil, "", err
}
io.Copy(part, file)
} else {
mp.WriteField(key, val)
}
}

return &body, mp.FormDataContentType(), nil
}

// encode creates a url encoded string with the request parameters
func encode(keys []string, values []string) string {
var buf strings.Builder
Expand Down
37 changes: 37 additions & 0 deletions pkg/requests/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"

Expand Down Expand Up @@ -155,6 +156,42 @@ func TestMakeRequest_ErrOnAPIKeyExpired(t *testing.T) {
require.Contains(t, err.Error(), "Request failed, status=401, body=")
}

func TestMakeMultiPartRequest(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("FILES!"))

reqBody, err := ioutil.ReadAll(r.Body)
require.NoError(t, err)

require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/foo/bar", r.URL.Path)
require.Equal(t, "Bearer sk_test_1234", r.Header.Get("Authorization"))
require.NotEmpty(t, r.UserAgent())
require.NotEmpty(t, r.Header.Get("X-Stripe-Client-User-Agent"))
require.Contains(t, r.Header.Get("Content-Type"), "multipart/form-data")
require.Contains(t, string(reqBody), "purpose")
require.Contains(t, string(reqBody), "app_upload")
}))
defer ts.Close()

rb := Base{APIBaseURL: ts.URL}
rb.Method = http.MethodPost

tempFile, err := os.CreateTemp("", "upload.zip")
if err != nil {
t.Error("Error creating temp file")
}
defer os.Remove(tempFile.Name())

params := &RequestParameters{
data: []string{"purpose=app_upload", fmt.Sprintf("file=@%v", tempFile.Name())},
}

_, err = rb.MakeMultiPartRequest(context.Background(), "sk_test_1234", "/foo/bar", params, true)
require.NoError(t, err)
}

func TestGetUserConfirmationRequired(t *testing.T) {
reader := bufio.NewReader(strings.NewReader("yes\n"))

Expand Down
3 changes: 3 additions & 0 deletions pkg/stripe/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (
// DefaultAPIBaseURL is the default base URL for API requests
const DefaultAPIBaseURL = "https://api.stripe.com"

// DefaultFilesAPIBaseURL is the default base URL for Files API requsts
const DefaultFilesAPIBaseURL = "https://files.stripe.com"

// DefaultDashboardBaseURL is the default base URL for dashboard requests
const DefaultDashboardBaseURL = "https://dashboard.stripe.com"

Expand Down

0 comments on commit c1971b9

Please sign in to comment.