From bdabe1c1522f7e5dbe6259ac2733fde2b2170d62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wa=C5=9B?= Date: Sun, 14 Apr 2024 10:06:28 +0200 Subject: [PATCH] Pass parent context to requests --- trino/trino.go | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/trino/trino.go b/trino/trino.go index 2383383..1ae20f1 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -448,8 +448,8 @@ func (c *Conn) Close() error { return nil } -func (c *Conn) newRequest(method, url string, body io.Reader, hs http.Header) (*http.Request, error) { - req, err := http.NewRequest(method, url, body) +func (c *Conn) newRequest(ctx context.Context, method, url string, body io.Reader, hs http.Header) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, fmt.Errorf("trino: %w", err) } @@ -485,14 +485,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response case <-ctx.Done(): return nil, ctx.Err() case <-timer.C: - timeout := DefaultQueryTimeout - if deadline, ok := ctx.Deadline(); ok { - timeout = time.Until(deadline) - } - client := c.httpClient - client.Timeout = timeout - req.Cancel = ctx.Done() - resp, err := client.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return nil, &ErrQueryFailed{Reason: err} } @@ -845,13 +838,19 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt } } - req, err := st.conn.newRequest("POST", st.conn.baseURL+"/v1/statement", strings.NewReader(query), hs) + var cancel context.CancelFunc = func() {} + if _, ok := ctx.Deadline(); !ok { + ctx, cancel = context.WithTimeout(ctx, DefaultQueryTimeout) + } + req, err := st.conn.newRequest(ctx, "POST", st.conn.baseURL+"/v1/statement", strings.NewReader(query), hs) if err != nil { + cancel() return nil, err } resp, err := st.conn.roundTrip(ctx, req) if err != nil { + cancel() return nil, err } @@ -861,6 +860,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt d.UseNumber() err = d.Decode(&sr) if err != nil { + cancel() return nil, fmt.Errorf("trino: %w", err) } @@ -879,8 +879,12 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt } hs := make(http.Header) hs.Add(trinoUserHeader, st.user) - req, err := st.conn.newRequest("GET", nextURI, nil, hs) + req, err := st.conn.newRequest(ctx, "GET", nextURI, nil, hs) if err != nil { + if ctx.Err() == context.Canceled { + st.errors <- context.Canceled + return + } st.errors <- err return } @@ -905,6 +909,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt }() go func() { defer close(st.queryResponses) + defer cancel() for { select { case resp := <-st.httpResponses: @@ -1011,12 +1016,12 @@ func (qr *driverRows) Close() error { if qr.stmt.user != "" { hs.Add(trinoUserHeader, qr.stmt.user) } - req, err := qr.stmt.conn.newRequest("DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs) + ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout) + defer cancel() + req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs) if err != nil { return err } - ctx, cancel := context.WithTimeout(context.Background(), DefaultCancelQueryTimeout) - defer cancel() resp, err := qr.stmt.conn.roundTrip(ctx, req) if err != nil { qferr, ok := err.(*ErrQueryFailed)