Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yet Another V2 Assistant Streaming Implemenetation #748

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err er
return
}

func sendRequestStreamV2(client *Client, req *http.Request) (stream *StreamerV2, err error) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")

resp, err := client.config.HTTPClient.Do(req)
if err != nil {
return
}

// TODO: how to handle error?
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

return NewStreamerV2(resp.Body), nil
}

func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
Expand Down
148 changes: 142 additions & 6 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,13 @@ const (
)

type RunRequest struct {
AssistantID string `json:"assistant_id"`
Model string `json:"model,omitempty"`
Instructions string `json:"instructions,omitempty"`
AdditionalInstructions string `json:"additional_instructions,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
AssistantID string `json:"assistant_id"`
Model string `json:"model,omitempty"`
Instructions string `json:"instructions,omitempty"`
AdditionalInstructions string `json:"additional_instructions,omitempty"`
AdditionalMessages []ThreadMessage `json:"additional_messages,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`

// Sampling temperature between 0 and 2. Higher values like 0.8 are more random.
// lower values are more focused and deterministic.
Expand Down Expand Up @@ -124,6 +125,11 @@ const (
TruncationStrategyLastMessages = TruncationStrategy("last_messages")
)

type RunRequestStreaming struct {
RunRequest
Stream bool `json:"stream"`
}

type RunModifyRequest struct {
Metadata map[string]any `json:"metadata,omitempty"`
}
Expand Down Expand Up @@ -337,6 +343,36 @@ func (c *Client) SubmitToolOutputs(
return
}

type SubmitToolOutputsStreamRequest struct {
SubmitToolOutputsRequest
Stream bool `json:"stream"`
}

func (c *Client) SubmitToolOutputsStream(
ctx context.Context,
threadID string,
runID string,
request SubmitToolOutputsRequest,
) (stream *StreamerV2, err error) {
urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID)
r := SubmitToolOutputsStreamRequest{
SubmitToolOutputsRequest: request,
Stream: true,
}
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix),
withBody(r),
withBetaAssistantVersion(c.config.AssistantVersion),
)
if err != nil {
return
}

return sendRequestStreamV2(c, req)
}

// CancelRun cancels a run.
func (c *Client) CancelRun(
ctx context.Context,
Expand Down Expand Up @@ -375,6 +411,106 @@ func (c *Client) CreateThreadAndRun(
return
}

type StreamMessageDelta struct {
Role string `json:"role"`
Content []MessageContent `json:"content"`
FileIDs []string `json:"file_ids"`
}

type AssistantStreamEvent struct {
ID string `json:"id"`
Object string `json:"object"`
Delta StreamMessageDelta `json:"delta,omitempty"`

// Run
CreatedAt int64 `json:"created_at,omitempty"`
ThreadID string `json:"thread_id,omitempty"`
AssistantID string `json:"assistant_id,omitempty"`
Status RunStatus `json:"status,omitempty"`
RequiredAction *RunRequiredAction `json:"required_action,omitempty"`
LastError *RunLastError `json:"last_error,omitempty"`
ExpiresAt int64 `json:"expires_at,omitempty"`
StartedAt *int64 `json:"started_at,omitempty"`
CancelledAt *int64 `json:"cancelled_at,omitempty"`
FailedAt *int64 `json:"failed_at,omitempty"`
CompletedAt *int64 `json:"completed_at,omitempty"`
Model string `json:"model,omitempty"`
Instructions string `json:"instructions,omitempty"`
Tools []Tool `json:"tools,omitempty"`
FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility
Metadata map[string]any `json:"metadata,omitempty"`
Usage Usage `json:"usage,omitempty"`

// ThreadMessage.Completed
Role string `json:"role,omitempty"`
Content []MessageContent `json:"content,omitempty"`
// IncompleteDetails
// IncompleteAt

// Run steps
RunID string `json:"run_id"`
Type RunStepType `json:"type"`
StepDetails StepDetails `json:"step_details"`
ExpiredAt *int64 `json:"expired_at,omitempty"`
}

type AssistantStream struct {
*streamReader[AssistantStreamEvent]
}

func (c *Client) CreateThreadAndRunStream(
ctx context.Context,
request CreateThreadAndRunRequest) (stream *StreamerV2, err error) {
type createThreadAndStreamRequest struct {
CreateThreadAndRunRequest
Stream bool `json:"stream"`
}

urlSuffix := "/threads/runs"
sr := createThreadAndStreamRequest{
CreateThreadAndRunRequest: request,
Stream: true,
}

req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix),
withBody(sr),
withBetaAssistantVersion(c.config.AssistantVersion),
)
if err != nil {
return
}

return sendRequestStreamV2(c, req)
}

func (c *Client) CreateRunStream(
ctx context.Context,
threadID string,
request RunRequest) (stream *StreamerV2, err error) {
urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID)

r := RunRequestStreaming{
RunRequest: request,
Stream: true,
}

req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix),
withBody(r),
withBetaAssistantVersion(c.config.AssistantVersion),
)
if err != nil {
return
}

return sendRequestStreamV2(c, req)
}

// RetrieveRunStep retrieves a run step.
func (c *Client) RetrieveRunStep(
ctx context.Context,
Expand Down
25 changes: 25 additions & 0 deletions run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,31 @@ func TestRun(t *testing.T) {
})
checks.NoError(t, err, "CreateThreadAndRun error")

_, err = client.CreateThreadAndRunStream(ctx, openai.CreateThreadAndRunRequest{
RunRequest: openai.RunRequest{
AssistantID: assistantID,
},
Thread: openai.ThreadRequest{
Messages: []openai.ThreadMessage{
{
Role: openai.ThreadMessageRoleUser,
Content: "Hello, World!",
},
},
},
})
checks.NoError(t, err, "CreateThreadAndStream error")

_, err = client.CreateRunStream(ctx, threadID, openai.RunRequest{
AssistantID: assistantID,
})
checks.NoError(t, err, "CreateRunStreaming error")

_, err = client.SubmitToolOutputsStream(ctx, threadID, runID, openai.SubmitToolOutputsRequest{
ToolOutputs: nil,
})
checks.NoError(t, err, "SubmitToolOutputsStream error")

_, err = client.RetrieveRunStep(ctx, threadID, runID, stepID)
checks.NoError(t, err, "RetrieveRunStep error")

Expand Down
Loading
Loading