Skip to content

Commit

Permalink
maintain underlying error structs to allow for type conversion (#293)
Browse files Browse the repository at this point in the history
* maintain underlying error structs to allow for type conversion and
defensive error checking

* allow Error.Is for Azure responses

* update readme, add tests to ensure type conversion

* fix whitespacing

* read me

* add import to readme example
  • Loading branch information
qhenkart committed May 3, 2023
1 parent 24aa200 commit a24581d
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
34 changes: 29 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/).
* DALL·E 2
* Whisper

Installation:
### Installation:
```
go get github.com/sashabaranov/go-openai
```


ChatGPT example usage:
### ChatGPT example usage:

```go
package main
Expand Down Expand Up @@ -52,9 +52,7 @@ func main() {

```



Other examples:
### Other examples:

<details>
<summary>ChatGPT streaming completion</summary>
Expand Down Expand Up @@ -462,3 +460,29 @@ func main() {
}
```
</details>

<details>
<summary>Error handling</summary>

Open-AI maintains clear documentation on how to [handle API errors](https://platform.openai.com/docs/guides/error-codes/api-errors)

example:
```
e := &openai.APIError{}
if errors.As(err, &e) {
switch e.HTTPStatusCode {
case 401:
// invalid auth or key (do not retry)
case 429:
// rate limiting or engine overload (wait and retry)
case 500:
// openai server error (retry)
default:
// unhandled
}
}
```
</details>


7 changes: 4 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,16 @@ func (c *Client) handleErrorResp(resp *http.Response) error {
var errRes ErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errRes)
if err != nil || errRes.Error == nil {
reqErr := RequestError{
reqErr := &RequestError{
HTTPStatusCode: resp.StatusCode,
Err: err,
}
if errRes.Error != nil {
reqErr.Err = errRes.Error
}
return fmt.Errorf("error, %w", &reqErr)
return reqErr
}

errRes.Error.HTTPStatusCode = resp.StatusCode
return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error)
return errRes.Error
}
9 changes: 8 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field

import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -106,7 +107,7 @@ func TestHandleErrorResp(t *testing.T) {
}
}`,
)),
expected: "error, status code 401, message: Access denied due to Virtual Network/Firewall rules.",
expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.",
},
{
name: "503 Model Overloaded",
Expand Down Expand Up @@ -135,6 +136,12 @@ func TestHandleErrorResp(t *testing.T) {
t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected)
t.Fail()
}

e := &APIError{}
if !errors.As(err, &e) {
t.Errorf("(%s) Expected error to be of type APIError", tc.name)
t.Fail()
}
})
}
}
6 changes: 5 additions & 1 deletion error.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type ErrorResponse struct {
}

func (e *APIError) Error() string {
if e.HTTPStatusCode > 0 {
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message)
}

return e.Message
}

Expand Down Expand Up @@ -70,7 +74,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
}

func (e *RequestError) Error() string {
return fmt.Sprintf("status code %d, message: %s", e.HTTPStatusCode, e.Err)
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err)
}

func (e *RequestError) Unwrap() error {
Expand Down

0 comments on commit a24581d

Please sign in to comment.