From fcfac96094ef3606bba3257dd1b64df57b4192e8 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Tue, 30 Apr 2024 12:24:09 -0400 Subject: [PATCH 1/7] add ability to specify custom headers and metadata for control and data plane requests, update tests, add verbose to just test, update NewClientParams to take in a custom rest client to allow for mocking in unit tests --- internal/mocks/mock_transport.go | 32 +++++++++++++++++++ justfile | 2 +- pinecone/client.go | 39 ++++++++++++++++------- pinecone/client_test.go | 48 +++++++++++++++++++++++++++-- pinecone/index_connection.go | 14 +++++++-- pinecone/index_connection_test.go | 51 ++++++++++++++++++++++++++++--- pinecone/models.go | 2 +- 7 files changed, 166 insertions(+), 22 deletions(-) create mode 100644 internal/mocks/mock_transport.go diff --git a/internal/mocks/mock_transport.go b/internal/mocks/mock_transport.go new file mode 100644 index 0000000..70931ca --- /dev/null +++ b/internal/mocks/mock_transport.go @@ -0,0 +1,32 @@ +package mocks + +import ( + "bytes" + "io" + "net/http" +) + +type MockTransport struct { + Req *http.Request + Resp *http.Response + Err error +} + +func (m *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + m.Req = req + return m.Resp, m.Err +} + +func CreateMockClient(jsonBody string) *http.Client { + return &http.Client { + Transport: &MockTransport{ + Resp: &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader([]byte(jsonBody))), + Header: make(http.Header), + }, + }, + } +} + +var jsonYes = `{"message": "success"}` \ No newline at end of file diff --git a/justfile b/justfile index 2196223..64d19e0 100644 --- a/justfile +++ b/justfile @@ -3,7 +3,7 @@ test: set -o allexport source .env set +o allexport - go test -count=1 ./pinecone + go test -count=1 -v ./pinecone bootstrap: go install google.golang.org/protobuf/cmd/protoc-gen-go@latest go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest diff --git a/pinecone/client.go b/pinecone/client.go index d701480..ab1c033 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -4,51 +4,68 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" + "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" "github.com/pinecone-io/go-pinecone/internal/provider" "github.com/pinecone-io/go-pinecone/internal/useragent" - "io" - "net/http" ) type Client struct { apiKey string restClient *control.Client sourceTag string + headers map[string]string } type NewClientParams struct { ApiKey string SourceTag string // optional + Headers map[string]string // optional + RestClient *http.Client // optional } func NewClient(in NewClientParams) (*Client, error) { + clientOptions := []control.ClientOption{} apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) if err != nil { return nil, err } - + clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) - client, err := control.NewClient("https://api.pinecone.io", - control.WithRequestEditorFn(apiKeyProvider.Intercept), - control.WithRequestEditorFn(userAgentProvider.Intercept), - ) + for key, value := range in.Headers { + headerProvider := provider.NewHeaderProvider(key, value) + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + } + + if in.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) + } + + client, err := control.NewClient("https://api.pinecone.io", clientOptions...) if err != nil { return nil, err } - c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag} + c := Client{apiKey: in.ApiKey, restClient: client, sourceTag: in.SourceTag, headers: in.Headers} return &c, nil } func (c *Client) Index(host string) (*IndexConnection, error) { - return c.IndexWithNamespace(host, "") + return c.IndexWithAdditionalMetadata(host, "", nil) } func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnection, error) { - idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag) + return c.IndexWithAdditionalMetadata(host, namespace, nil) +} + +func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) { + idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag, additionalMetadata) if err != nil { return nil, err } @@ -65,7 +82,7 @@ func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) } - + fmt.Printf("res.Body: %+v", res.Body) var indexList control.IndexList err = json.NewDecoder(res.Body).Decode(&indexList) if err != nil { diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 446d2d3..e591ae0 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -4,9 +4,11 @@ import ( "context" "fmt" "os" + "reflect" "testing" "github.com/google/uuid" + "github.com/pinecone-io/go-pinecone/internal/mocks" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -66,8 +68,11 @@ func (ts *ClientTests) TestNewClientParamsSet() { if client.sourceTag != "" { ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag)) } + if client.headers != nil { + ts.FailNow(fmt.Sprintf("Expected client to have nil headers, but got '%v'", client.headers)) + } if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected 2 request editors on client") + ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) } } @@ -85,7 +90,46 @@ func (ts *ClientTests) TestNewClientParamsSetSourceTag() { ts.FailNow(fmt.Sprintf("Expected client to have sourceTag '%s', but got '%s'", sourceTag, client.sourceTag)) } if len(client.restClient.RequestEditors) != 2 { - ts.FailNow("Expected 2 request editors on client") + ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) + } +} + +func (ts *ClientTests) TestNewClientParamsSetHeaders() { + apiKey := "test-api-key" + headers := map[string]string{"test-header": "test-value"} + client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers}) + if err != nil { + ts.FailNow(err.Error()) + } + if client.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected client to have apiKey '%s', but got '%s'", apiKey, client.apiKey)) + } + if !reflect.DeepEqual(client.headers, headers) { + ts.FailNow(fmt.Sprintf("Expected client to have headers '%+v', but got '%+v'", headers, client.headers)) + } + if len(client.restClient.RequestEditors) != 3 { + ts.FailNow(fmt.Sprintf("Expected client to have '%v' request editors, but got '%v'", 3, len(client.restClient.RequestEditors))) + } +} + +func (ts *ClientTests) TestHeadersAppliedToRequests() { + apiKey := "test-api-key" + headers := map[string]string{"test-header": "123456"} + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + testHeaderValue := mockTransport.Req.Header.Get("test-header") + if testHeaderValue != "123456" { + ts.FailNow(fmt.Sprintf("Expected request to have header value '123456', but got '%s'", testHeaderValue)) } } diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index 9a2131e..a52d932 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -16,11 +16,12 @@ import ( type IndexConnection struct { Namespace string apiKey string + additionalMetadata map[string]string dataClient *data.VectorServiceClient grpcConn *grpc.ClientConn } -func newIndexConnection(apiKey string, host string, namespace string, sourceTag string) (*IndexConnection, error) { +func newIndexConnection(apiKey string, host string, namespace string, sourceTag string, additionalMetadata map[string]string) (*IndexConnection, error) { config := &tls.Config{} target := fmt.Sprintf("%s:443", host) conn, err := grpc.Dial( @@ -38,7 +39,7 @@ func newIndexConnection(apiKey string, host string, namespace string, sourceTag dataClient := data.NewVectorServiceClient(conn) - idx := IndexConnection{Namespace: namespace, apiKey: apiKey, dataClient: &dataClient, grpcConn: conn} + idx := IndexConnection{Namespace: namespace, apiKey: apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: additionalMetadata} return &idx, nil } @@ -360,5 +361,12 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues { } func (idx *IndexConnection) akCtx(ctx context.Context) context.Context { - return metadata.AppendToOutgoingContext(ctx, "api-key", idx.apiKey) + newMetadata := []string{} + newMetadata = append(newMetadata, "api-key", idx.apiKey) + + for key, value := range idx.additionalMetadata{ + newMetadata = append(newMetadata, key, value) + } + + return metadata.AppendToOutgoingContext(ctx, newMetadata...) } diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index ffdca5f..d0d6780 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "reflect" "testing" "github.com/google/uuid" @@ -18,6 +19,7 @@ type IndexConnectionTests struct { apiKey string idxConn *IndexConnection sourceTag string + metadata map[string]string idxConnSourceTag *IndexConnection vectorIds []string } @@ -36,6 +38,10 @@ func TestIndexConnection(t *testing.T) { assert.NotEmptyf(t, podIndexName, "TEST_POD_INDEX_NAME env variable not set") podIdx, err := client.DescribeIndex(context.Background(), podIndexName) + if err != nil { + t.FailNow() + } + podTestSuite := new(IndexConnectionTests) podTestSuite.host = podIdx.Host podTestSuite.dimension = podIdx.Dimension @@ -45,6 +51,9 @@ func TestIndexConnection(t *testing.T) { assert.NotEmptyf(t, serverlessIndexName, "TEST_SERVERLESS_INDEX_NAME env variable not set") serverlessIdx, err := client.DescribeIndex(context.Background(), serverlessIndexName) + if err != nil { + t.FailNow() + } serverlessTestSuite := new(IndexConnectionTests) serverlessTestSuite.host = serverlessIdx.Host @@ -62,12 +71,12 @@ func (ts *IndexConnectionTests) SetupSuite() { namespace, err := uuid.NewV7() assert.NoError(ts.T(), err) - idxConn, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), "") + idxConn, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), "", nil) assert.NoError(ts.T(), err) ts.idxConn = idxConn ts.sourceTag = "test_source_tag" - idxConnSourceTag, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), ts.sourceTag) + idxConnSourceTag, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), ts.sourceTag, nil) assert.NoError(ts.T(), err) ts.idxConnSourceTag = idxConnSourceTag @@ -79,13 +88,16 @@ func (ts *IndexConnectionTests) TearDownSuite() { err := ts.idxConn.Close() assert.NoError(ts.T(), err) + + err = ts.idxConnSourceTag.Close() + assert.NoError(ts.T(), err) } func (ts *IndexConnectionTests) TestNewIndexConnection() { apiKey := "test-api-key" namespace := "" sourceTag := "" - idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag) + idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag, nil) assert.NoError(ts.T(), err) if idxConn.apiKey != apiKey { @@ -94,19 +106,47 @@ func (ts *IndexConnectionTests) TestNewIndexConnection() { if idxConn.Namespace != "" { ts.FailNow(fmt.Sprintf("Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace)) } + if idxConn.additionalMetadata != nil { + ts.FailNow(fmt.Sprintf("Expected idxConn to have nil additionalMetadata, but got '%+v'", idxConn.additionalMetadata)) + } if idxConn.dataClient == nil { ts.FailNow("Expected idxConn to have non-nil dataClient") } if idxConn.grpcConn == nil { ts.FailNow("Expected idxConn to have non-nil grpcConn") } + if idxConn.additionalMetadata != nil { + ts.FailNow("Expected idxConn to have nil additionalMetadata") + } } func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { apiKey := "test-api-key" namespace := "test-namespace" sourceTag := "test-source-tag" - idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag) + idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag, nil) + assert.NoError(ts.T(), err) + + if idxConn.apiKey != apiKey { + ts.FailNow(fmt.Sprintf("Expected idxConn to have apiKey '%s', but got '%s'", apiKey, idxConn.apiKey)) + } + if idxConn.Namespace != namespace { + ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) + } + if idxConn.dataClient == nil { + ts.FailNow("Expected idxConn to have non-nil dataClient") + } + if idxConn.grpcConn == nil { + ts.FailNow("Expected idxConn to have non-nil grpcConn") + } +} + +func (ts *IndexConnectionTests) TestNewIndexConnectionAdditionalMetadata() { + apiKey := "test-api-key" + namespace := "test-namespace" + sourceTag := "test-source-tag" + additionalMetadata := map[string]string{"test-header": "test-value"} + idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag, additionalMetadata) assert.NoError(ts.T(), err) if idxConn.apiKey != apiKey { @@ -115,6 +155,9 @@ func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { if idxConn.Namespace != namespace { ts.FailNow(fmt.Sprintf("Expected idxConn to have namespace '%s', but got '%s'", namespace, idxConn.Namespace)) } + if !reflect.DeepEqual(idxConn.additionalMetadata, additionalMetadata) { + ts.FailNow(fmt.Sprintf("Expected idxConn to have additionalMetadata '%+v', but got '%+v'", additionalMetadata, idxConn.additionalMetadata)) + } if idxConn.dataClient == nil { ts.FailNow("Expected idxConn to have non-nil dataClient") } diff --git a/pinecone/models.go b/pinecone/models.go index 4d00443..3ce895b 100644 --- a/pinecone/models.go +++ b/pinecone/models.go @@ -112,4 +112,4 @@ type Usage struct { } type Filter = structpb.Struct -type Metadata = structpb.Struct +type Metadata = structpb.Struct \ No newline at end of file From b3117707d0855c9cea8e20858e1045b77abf469f Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Tue, 30 Apr 2024 12:37:27 -0400 Subject: [PATCH 2/7] remove unneeded variable declaration --- internal/mocks/mock_transport.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/internal/mocks/mock_transport.go b/internal/mocks/mock_transport.go index 70931ca..25ddb3f 100644 --- a/internal/mocks/mock_transport.go +++ b/internal/mocks/mock_transport.go @@ -27,6 +27,4 @@ func CreateMockClient(jsonBody string) *http.Client { }, }, } -} - -var jsonYes = `{"message": "success"}` \ No newline at end of file +} \ No newline at end of file From 233a473e9058ec7c9d0db46c4171f3e9ef537de9 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Tue, 30 Apr 2024 16:34:08 -0400 Subject: [PATCH 3/7] remove print statement --- pinecone/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pinecone/client.go b/pinecone/client.go index ab1c033..7633f29 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -82,7 +82,7 @@ func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) } - fmt.Printf("res.Body: %+v", res.Body) + var indexList control.IndexList err = json.NewDecoder(res.Body).Decode(&indexList) if err != nil { From 395be590fa48594dab42a2fb2d1c97c81973a154 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Tue, 30 Apr 2024 16:50:38 -0400 Subject: [PATCH 4/7] tweak testing strings --- pinecone/client_test.go | 2 +- pinecone/index_connection_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pinecone/client_test.go b/pinecone/client_test.go index e591ae0..59bb678 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -69,7 +69,7 @@ func (ts *ClientTests) TestNewClientParamsSet() { ts.FailNow(fmt.Sprintf("Expected client to have empty sourceTag, but got '%s'", client.sourceTag)) } if client.headers != nil { - ts.FailNow(fmt.Sprintf("Expected client to have nil headers, but got '%v'", client.headers)) + ts.FailNow(fmt.Sprintf("Expected client headers to be nil, but got '%v'", client.headers)) } if len(client.restClient.RequestEditors) != 2 { ts.FailNow("Expected client to have '%v' request editors, but got '%v'", 2, len(client.restClient.RequestEditors)) diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index d0d6780..35d0bca 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -107,7 +107,7 @@ func (ts *IndexConnectionTests) TestNewIndexConnection() { ts.FailNow(fmt.Sprintf("Expected idxConn to have empty namespace, but got '%s'", idxConn.Namespace)) } if idxConn.additionalMetadata != nil { - ts.FailNow(fmt.Sprintf("Expected idxConn to have nil additionalMetadata, but got '%+v'", idxConn.additionalMetadata)) + ts.FailNow(fmt.Sprintf("Expected idxConn additionalMetadata to be nil, but got '%+v'", idxConn.additionalMetadata)) } if idxConn.dataClient == nil { ts.FailNow("Expected idxConn to have non-nil dataClient") From e5753fb9b1f81f766eb9b72fbe7ddbeba41e1a78 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Tue, 30 Apr 2024 17:46:47 -0400 Subject: [PATCH 5/7] add newIndexParameters struct and swap to using that for newIndexConnection --- pinecone/client.go | 4 ++-- pinecone/index_connection.go | 16 ++++++++++++---- pinecone/index_connection_test.go | 10 +++++----- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 7633f29..ac2891b 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -65,7 +65,7 @@ func (c *Client) IndexWithNamespace(host string, namespace string) (*IndexConnec } func (c *Client) IndexWithAdditionalMetadata(host string, namespace string, additionalMetadata map[string]string) (*IndexConnection, error) { - idx, err := newIndexConnection(c.apiKey, host, namespace, c.sourceTag, additionalMetadata) + idx, err := newIndexConnection(newIndexParameters{apiKey: c.apiKey, host: host, namespace: namespace, sourceTag: c.sourceTag, additionalMetadata: additionalMetadata}) if err != nil { return nil, err } @@ -82,7 +82,7 @@ func (c *Client) ListIndexes(ctx context.Context) ([]*Index, error) { if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) } - + var indexList control.IndexList err = json.NewDecoder(res.Body).Decode(&indexList) if err != nil { diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index a52d932..a2daf3a 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -21,15 +21,23 @@ type IndexConnection struct { grpcConn *grpc.ClientConn } -func newIndexConnection(apiKey string, host string, namespace string, sourceTag string, additionalMetadata map[string]string) (*IndexConnection, error) { +type newIndexParameters struct { + apiKey string + host string + namespace string + sourceTag string + additionalMetadata map[string]string +} + +func newIndexConnection(in newIndexParameters) (*IndexConnection, error) { config := &tls.Config{} - target := fmt.Sprintf("%s:443", host) + target := fmt.Sprintf("%s:443", in.host) conn, err := grpc.Dial( target, grpc.WithTransportCredentials(credentials.NewTLS(config)), grpc.WithAuthority(target), grpc.WithBlock(), - grpc.WithUserAgent(useragent.BuildUserAgentGRPC(sourceTag)), + grpc.WithUserAgent(useragent.BuildUserAgentGRPC(in.sourceTag)), ) if err != nil { @@ -39,7 +47,7 @@ func newIndexConnection(apiKey string, host string, namespace string, sourceTag dataClient := data.NewVectorServiceClient(conn) - idx := IndexConnection{Namespace: namespace, apiKey: apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: additionalMetadata} + idx := IndexConnection{Namespace: in.namespace, apiKey: in.apiKey, dataClient: &dataClient, grpcConn: conn, additionalMetadata: in.additionalMetadata} return &idx, nil } diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 35d0bca..fd709d6 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -71,12 +71,12 @@ func (ts *IndexConnectionTests) SetupSuite() { namespace, err := uuid.NewV7() assert.NoError(ts.T(), err) - idxConn, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), "", nil) + idxConn, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ""}) assert.NoError(ts.T(), err) ts.idxConn = idxConn ts.sourceTag = "test_source_tag" - idxConnSourceTag, err := newIndexConnection(ts.apiKey, ts.host, namespace.String(), ts.sourceTag, nil) + idxConnSourceTag, err := newIndexConnection(newIndexParameters{apiKey: ts.apiKey, host: ts.host, namespace: namespace.String(), sourceTag: ts.sourceTag}) assert.NoError(ts.T(), err) ts.idxConnSourceTag = idxConnSourceTag @@ -97,7 +97,7 @@ func (ts *IndexConnectionTests) TestNewIndexConnection() { apiKey := "test-api-key" namespace := "" sourceTag := "" - idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag, nil) + idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) assert.NoError(ts.T(), err) if idxConn.apiKey != apiKey { @@ -124,7 +124,7 @@ func (ts *IndexConnectionTests) TestNewIndexConnectionNamespace() { apiKey := "test-api-key" namespace := "test-namespace" sourceTag := "test-source-tag" - idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag, nil) + idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag}) assert.NoError(ts.T(), err) if idxConn.apiKey != apiKey { @@ -146,7 +146,7 @@ func (ts *IndexConnectionTests) TestNewIndexConnectionAdditionalMetadata() { namespace := "test-namespace" sourceTag := "test-source-tag" additionalMetadata := map[string]string{"test-header": "test-value"} - idxConn, err := newIndexConnection(apiKey, ts.host, namespace, sourceTag, additionalMetadata) + idxConn, err := newIndexConnection(newIndexParameters{apiKey: apiKey, host: ts.host, namespace: namespace, sourceTag: sourceTag, additionalMetadata: additionalMetadata}) assert.NoError(ts.T(), err) if idxConn.apiKey != apiKey { From 59a69184e47295f330e708cae1cc8f29c38c9c4a Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Fri, 3 May 2024 17:27:47 -0400 Subject: [PATCH 6/7] add new buildClientOptions function, do not pass the Api-Key header if both Api-Key and Authorization header have been provided, add unit tests --- pinecone/client.go | 62 ++++++++++++++++++++++++++++------------- pinecone/client_test.go | 36 ++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index ac2891b..7aa4cad 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/deepmap/oapi-codegen/v2/pkg/securityprovider" "github.com/pinecone-io/go-pinecone/internal/gen/control" @@ -21,31 +22,17 @@ type Client struct { } type NewClientParams struct { - ApiKey string - SourceTag string // optional - Headers map[string]string // optional - RestClient *http.Client // optional + ApiKey string // optional unless no Authorization header provided + SourceTag string // optional + Headers map[string]string // optional + RestClient *http.Client // optional } func NewClient(in NewClientParams) (*Client, error) { - clientOptions := []control.ClientOption{} - apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) + clientOptions, err := buildClientOptions(in) if err != nil { return nil, err } - clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) - - userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) - clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) - - for key, value := range in.Headers { - headerProvider := provider.NewHeaderProvider(key, value) - clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) - } - - if in.RestClient != nil { - clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) - } client, err := control.NewClient("https://api.pinecone.io", clientOptions...) if err != nil { @@ -425,3 +412,40 @@ func minOne(x int32) int32 { } return x } + +func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { + clientOptions := []control.ClientOption{} + hasAuthorizationHeader := false + hasApiKey := in.ApiKey != "" + + userAgentProvider := provider.NewHeaderProvider("User-Agent", useragent.BuildUserAgent(in.SourceTag)) + clientOptions = append(clientOptions, control.WithRequestEditorFn(userAgentProvider.Intercept)) + + for key, value := range in.Headers { + headerProvider := provider.NewHeaderProvider(key, value) + + if strings.Contains(key, "Authorization") { + hasAuthorizationHeader = true + } + + clientOptions = append(clientOptions, control.WithRequestEditorFn(headerProvider.Intercept)) + } + + if !hasAuthorizationHeader { + apiKeyProvider, err := securityprovider.NewSecurityProviderApiKey("header", "Api-Key", in.ApiKey) + if err != nil { + return nil, err + } + clientOptions = append(clientOptions, control.WithRequestEditorFn(apiKeyProvider.Intercept)) + } + + if !hasAuthorizationHeader && !hasApiKey { + return nil, fmt.Errorf("no API key provided, please pass an API key for authorization") + } + + if in.RestClient != nil { + clientOptions = append(clientOptions, control.WithHTTPClient(in.RestClient)) + } + + return clientOptions, nil +} diff --git a/pinecone/client_test.go b/pinecone/client_test.go index 59bb678..ecb8bbb 100644 --- a/pinecone/client_test.go +++ b/pinecone/client_test.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "reflect" + "strings" "testing" "github.com/google/uuid" @@ -112,6 +113,16 @@ func (ts *ClientTests) TestNewClientParamsSetHeaders() { } } +func (ts *ClientTests) TestNewClientParamsNoApiKeyNoAuthorizationHeader() { + client, err := NewClient(NewClientParams{}) + require.NotNil(ts.T(), err, "Expected error when creating client without an API key or Authorization header") + if !strings.Contains(err.Error(), "no API key provided, please pass an API key for authorization") { + ts.FailNow(fmt.Sprintf("Expected error to contain 'no API key provided, please pass an API key for authorization', but got '%s'", err.Error())) + } + + require.Nil(ts.T(), client, "Expected client to be nil when creating client without an API key or Authorization header") +} + func (ts *ClientTests) TestHeadersAppliedToRequests() { apiKey := "test-api-key" headers := map[string]string{"test-header": "123456"} @@ -133,6 +144,31 @@ func (ts *ClientTests) TestHeadersAppliedToRequests() { } } +func (ts *ClientTests) TestAuthorizationHeaderOverridesApiKey() { + apiKey := "test-api-key" + headers := map[string]string{"Authorization": "bearer fooo"} + + httpClient := mocks.CreateMockClient(`{"indexes": []}`) + client, err := NewClient(NewClientParams{ApiKey: apiKey, Headers: headers, RestClient: httpClient}) + if err != nil { + ts.FailNow(err.Error()) + } + mockTransport := httpClient.Transport.(*mocks.MockTransport) + + _, err = client.ListIndexes(context.Background()) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), mockTransport.Req, "Expected request to be made") + + apiKeyHeaderValue := mockTransport.Req.Header.Get("Api-Key") + authHeaderValue := mockTransport.Req.Header.Get("Authorization") + if authHeaderValue != "bearer fooo" { + ts.FailNow(fmt.Sprintf("Expected request to have header value 'bearer fooo', but got '%s'", authHeaderValue)) + } + if apiKeyHeaderValue != "" { + ts.FailNow(fmt.Sprintf("Expected request to not have Api-Key header, but got '%s'", apiKeyHeaderValue)) + } +} + func (ts *ClientTests) TestListIndexes() { indexes, err := ts.client.ListIndexes(context.Background()) require.NoError(ts.T(), err) From 89fc762240cd0f5927b9ab9a86886cbdff153930 Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Mon, 6 May 2024 16:05:33 -0400 Subject: [PATCH 7/7] review feedback, relax check on authorization string, update required comment --- pinecone/client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pinecone/client.go b/pinecone/client.go index 7aa4cad..97e8e91 100644 --- a/pinecone/client.go +++ b/pinecone/client.go @@ -22,7 +22,7 @@ type Client struct { } type NewClientParams struct { - ApiKey string // optional unless no Authorization header provided + ApiKey string // required unless Authorization header provided SourceTag string // optional Headers map[string]string // optional RestClient *http.Client // optional @@ -424,7 +424,7 @@ func buildClientOptions(in NewClientParams) ([]control.ClientOption, error) { for key, value := range in.Headers { headerProvider := provider.NewHeaderProvider(key, value) - if strings.Contains(key, "Authorization") { + if strings.Contains(strings.ToLower(key), "authorization") { hasAuthorizationHeader = true }