Skip to content

Commit

Permalink
change expose and rename Client.clientTransport
Browse files Browse the repository at this point in the history
change type of Client.Transport to pointer
change add exposed Parent to transport
change use Parent in RoundTripper
  • Loading branch information
Fritte795 committed Nov 4, 2023
1 parent 8e8f9cc commit ec993ac
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
38 changes: 20 additions & 18 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ var (
)

type Client struct {
ClientError error
session http.Client
clientTransport transport
ClientError error
session http.Client
Transport *transport
}

// NewClient constructs a new client given a URL to a Postgrest instance.
Expand All @@ -31,34 +31,35 @@ func NewClient(rawURL, schema string, headers map[string]string) *Client {
t := transport{
header: http.Header{},
baseURL: *baseURL,
Parent: http.DefaultTransport,
}

c := Client{
session: http.Client{Transport: t},
clientTransport: t,
session: http.Client{Transport: &t},
Transport: &t,
}

if schema == "" {
schema = "public"
}

// Set required headers
c.clientTransport.header.Set("Accept", "application/json")
c.clientTransport.header.Set("Content-Type", "application/json")
c.clientTransport.header.Set("Accept-Profile", schema)
c.clientTransport.header.Set("Content-Profile", schema)
c.clientTransport.header.Set("X-Client-Info", "postgrest-go/"+version)
c.Transport.header.Set("Accept", "application/json")
c.Transport.header.Set("Content-Type", "application/json")
c.Transport.header.Set("Accept-Profile", schema)
c.Transport.header.Set("Content-Profile", schema)
c.Transport.header.Set("X-Client-Info", "postgrest-go/"+version)

// Set optional headers if they exist
for key, value := range headers {
c.clientTransport.header.Set(key, value)
c.Transport.header.Set(key, value)
}

return &c
}

func (c *Client) Ping() bool {
req, err := http.NewRequest("GET", path.Join(c.clientTransport.baseURL.Path, ""), nil)
req, err := http.NewRequest("GET", path.Join(c.Transport.baseURL.Path, ""), nil)
if err != nil {
c.ClientError = err

Expand All @@ -83,15 +84,15 @@ func (c *Client) Ping() bool {

// TokenAuth sets authorization headers for subsequent requests.
func (c *Client) TokenAuth(token string) *Client {
c.clientTransport.header.Set("Authorization", "Bearer "+token)
c.clientTransport.header.Set("apikey", token)
c.Transport.header.Set("Authorization", "Bearer "+token)
c.Transport.header.Set("apikey", token)
return c
}

// ChangeSchema modifies the schema for subsequent requests.
func (c *Client) ChangeSchema(schema string) *Client {
c.clientTransport.header.Set("Accept-Profile", schema)
c.clientTransport.header.Set("Content-Profile", schema)
c.Transport.header.Set("Accept-Profile", schema)
c.Transport.header.Set("Content-Profile", schema)
return c
}

Expand All @@ -115,7 +116,7 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {
}

readerBody := bytes.NewBuffer(byteBody)
url := path.Join(c.clientTransport.baseURL.Path, "rpc", name)
url := path.Join(c.Transport.baseURL.Path, "rpc", name)
req, err := http.NewRequest("POST", url, readerBody)
if err != nil {
c.ClientError = err
Expand Down Expand Up @@ -152,6 +153,7 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {
type transport struct {
header http.Header
baseURL url.URL
Parent http.RoundTripper
}

func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
Expand All @@ -162,5 +164,5 @@ func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
}

req.URL = t.baseURL.ResolveReference(req.URL)
return http.DefaultTransport.RoundTrip(req)
return t.Parent.RoundTrip(req)
}
2 changes: 1 addition & 1 deletion execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func executeHelper(client *Client, method string, body []byte, urlFragments []st
}

readerBody := bytes.NewBuffer(body)
baseUrl := path.Join(append([]string{client.clientTransport.baseURL.Path}, urlFragments...)...)
baseUrl := path.Join(append([]string{client.Transport.baseURL.Path}, urlFragments...)...)
req, err := http.NewRequest(method, baseUrl, readerBody)
if err != nil {
return nil, 0, fmt.Errorf("error creating request: %s", err.Error())
Expand Down

0 comments on commit ec993ac

Please sign in to comment.