Skip to content

Commit

Permalink
feat: provide input via stdin and as an argument reopened (#239)
Browse files Browse the repository at this point in the history
* feat: provide prompt via stdin and cmd arg

* chore: fix tests for new feature

* chore: deactivate govulncheck

* chore: update docs with new feature
  • Loading branch information
tbckr authored Mar 22, 2024
1 parent 06d3518 commit e282744
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 41 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ $ echo -n "mass of sun" | sgpt
The mass of the sun is approximately 1.989 x 10^30 kilograms.
```
You can also add another prompt to the piped data by specifying the `stdin` modifier and then specifying the prompt:
```shell
$ echo "Say: Hello World!" | sgpt stdin 'Replace every "World" word with "ChatGPT"'
Hello ChatGPT!
```
If you want to stream the completion to the command line, you can add the `--stream` flag. This will stream the output
to the command line as it is generated.
Expand Down
7 changes: 7 additions & 0 deletions docs/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ $ echo -n "mass of sun" | sgpt
The mass of the sun is approximately 1.989 x 10^30 kilograms.
```

You can also add another prompt to the piped data by specifying the `stdin` modifier and then specifying the prompt:

```shell
$ echo "Say: Hello World!" | sgpt stdin 'Replace every "World" word with "ChatGPT"'
Hello ChatGPT!
```

If you want to stream the completion to the command line, you can add the `--stream` flag. This will stream the output
to the command line as it is generated.

Expand Down
44 changes: 25 additions & 19 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func CreateClient(config *viper.Viper, out io.Writer) (*OpenAIClient, error) {
// CreateCompletion creates a completion for the given prompt and modifier. If chatID is provided, the chat is reused
// and the completion is added to the chat with this ID. If no chatID is provided, only the modifier and prompt are
// used to create the completion. The completion is printed to the out writer of the client and returned as a string.
func (c *OpenAIClient) CreateCompletion(ctx context.Context, chatID, prompt, modifier string, input []string) (string, error) {
func (c *OpenAIClient) CreateCompletion(ctx context.Context, chatID string, prompt []string, modifier string, input []string) (string, error) {
var messages []openai.ChatCompletionMessage
var err error

Expand All @@ -128,12 +128,12 @@ func (c *OpenAIClient) CreateCompletion(ctx context.Context, chatID, prompt, mod
messages = append(messages, loadedMessages...)

// Add prompt to messages
var promptMessage openai.ChatCompletionMessage
promptMessage, err = c.createPromptMessage(prompt, input)
var promptMessages []openai.ChatCompletionMessage
promptMessages, err = c.createPromptMessages(prompt, input)
if err != nil {
return "", err
}
messages = append(messages, promptMessage)
messages = append(messages, promptMessages...)
slog.Debug("Added prompt message")

// Create request
Expand Down Expand Up @@ -218,20 +218,22 @@ func (c *OpenAIClient) loadChatMessages(isChat bool, chatID, modifier string) (m
return
}

func (c *OpenAIClient) createPromptMessage(prompt string, input []string) (message openai.ChatCompletionMessage, err error) {
func (c *OpenAIClient) createPromptMessages(prompts, input []string) (messages []openai.ChatCompletionMessage, err error) {
if len(input) > 0 {
slog.Warn("The GPT-4 Vision API is in beta and may not work as expected")
// Request to the gpt-4-vision API
slog.Warn("The GPT-4 Vision API is in beta and may not work as expected")

var messageParts []openai.ChatMessagePart
// Add prompt to message
messageParts := []openai.ChatMessagePart{
{
// We append the stdin as part of the prompt as a message part
for _, p := range prompts {
messageParts = append(messageParts, openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeText,
Text: prompt,
},
Text: p,
})
}

// Add images to message
// Add images to messages
for _, i := range input {
// By default, assume that the input is a URL
imageData := i
Expand All @@ -241,7 +243,7 @@ func (c *OpenAIClient) createPromptMessage(prompt string, input []string) (messa
// Input is a file, load image data
imageData, err = c.buildImageFileData(i)
if err != nil {
return openai.ChatCompletionMessage{}, err
return []openai.ChatCompletionMessage{}, err
}
}

Expand All @@ -253,19 +255,23 @@ func (c *OpenAIClient) createPromptMessage(prompt string, input []string) (messa
})
}

message = openai.ChatCompletionMessage{
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
MultiContent: messageParts,
}
})
} else {
// Normal prompt
message = openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: prompt,
// We append the stdin as part of the prompt
// This means we just add the prompt as a message
for _, p := range prompts {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: p,
})
}
}
slog.Debug("Added prompt message")
return message, nil
slog.Debug("Added prompt messages")
return messages, nil
}

func (c *OpenAIClient) buildImageFileData(inputFile string) (imageData string, err error) {
Expand Down
30 changes: 15 additions & 15 deletions pkg/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestSimplePrompt(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Say: Hello World!"
prompt := []string{"Say: Hello World!"}
expected := "Hello World!"

httpmock.ActivateNonDefault(client.HTTPClient)
Expand Down Expand Up @@ -119,7 +119,7 @@ func TestStreamSimplePrompt(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Say: Hello World!"
prompt := []string{"Say: Hello World!"}
expected := "Hello World!"

httpmock.ActivateNonDefault(client.HTTPClient)
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestPromptSaveAsChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Say: Hello World!"
prompt := []string{"Say: Hello World!"}
expected := "Hello World!"

httpmock.ActivateNonDefault(client.HTTPClient)
Expand Down Expand Up @@ -193,7 +193,7 @@ func TestPromptSaveAsChat(t *testing.T) {

// Check if the prompt was added
require.Equal(t, openai.ChatMessageRoleUser, messages[0].Role)
require.Equal(t, prompt, messages[0].Content)
require.Equal(t, prompt[0], messages[0].Content)

// Check if the response was added
require.Equal(t, openai.ChatMessageRoleAssistant, messages[1].Role)
Expand All @@ -212,7 +212,7 @@ func TestPromptLoadChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Repeat last message"
prompt := []string{"Repeat last message"}
expected := "World!"

httpmock.ActivateNonDefault(client.HTTPClient)
Expand Down Expand Up @@ -258,7 +258,7 @@ func TestPromptLoadChat(t *testing.T) {

// Check if the prompt was added
require.Equal(t, openai.ChatMessageRoleUser, messages[2].Role)
require.Equal(t, prompt, messages[2].Content)
require.Equal(t, prompt[0], messages[2].Content)

// Check if the response was added
require.Equal(t, openai.ChatMessageRoleAssistant, messages[3].Role)
Expand All @@ -277,7 +277,7 @@ func TestPromptWithModifier(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "Print Hello World"
prompt := []string{"Print Hello World!"}
response := `echo \"Hello World\"`
expected := `echo "Hello World"`

Expand Down Expand Up @@ -325,7 +325,7 @@ func TestPromptWithModifier(t *testing.T) {

// Check if the prompt was added
require.Equal(t, openai.ChatMessageRoleUser, messages[1].Role)
require.Equal(t, prompt, messages[1].Content)
require.Equal(t, prompt[0], messages[1].Content)

// Check if the response was added
require.Equal(t, openai.ChatMessageRoleAssistant, messages[2].Role)
Expand All @@ -344,7 +344,7 @@ func TestSimplePromptWithLocalImage(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "what can you see on the picture?"
prompt := []string{"what can you see on the picture?"}
expected := "The image shows a character that appears to be a stylized robot. It has"
inputImage := "testdata/marvin.jpg"

Expand Down Expand Up @@ -381,7 +381,7 @@ func TestSimplePromptWithLocalImageAndChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "what can you see on the picture?"
prompt := []string{"what can you see on the picture?"}
expected := "The image shows a character that appears to be a stylized robot. It has"
inputImage := "testdata/marvin.jpg"

Expand Down Expand Up @@ -423,7 +423,7 @@ func TestSimplePromptWithLocalImageAndChat(t *testing.T) {
require.Len(t, messages[0].MultiContent, 2)
// Check, if the prompt is a multi content message
require.Equal(t, "text", string(messages[0].MultiContent[0].Type))
require.Equal(t, prompt, messages[0].MultiContent[0].Text)
require.Equal(t, prompt[0], messages[0].MultiContent[0].Text)
// Check, if the image was added
require.Equal(t, "image_url", string(messages[0].MultiContent[1].Type))
require.NotEmpty(t, messages[0].MultiContent[1].ImageURL.URL)
Expand All @@ -446,7 +446,7 @@ func TestSimplePromptWithURLImageAndChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "what can you see on the picture?"
prompt := []string{"what can you see on the picture?"}
expected := "The image shows a character that appears to be a stylized robot. It has"
inputImage := "https://upload.wikimedia.org/wikipedia/en/c/cb/Marvin_%28HHGG%29.jpg"

Expand Down Expand Up @@ -488,7 +488,7 @@ func TestSimplePromptWithURLImageAndChat(t *testing.T) {
require.Len(t, messages[0].MultiContent, 2)
// Check, if the prompt is a multi content message
require.Equal(t, "text", string(messages[0].MultiContent[0].Type))
require.Equal(t, prompt, messages[0].MultiContent[0].Text)
require.Equal(t, prompt[0], messages[0].MultiContent[0].Text)
// Check, if the image was added
require.Equal(t, "image_url", string(messages[0].MultiContent[1].Type))
require.Equal(t, inputImage, messages[0].MultiContent[1].ImageURL.URL)
Expand All @@ -510,7 +510,7 @@ func TestSimplePromptWithMixedImagesAndChat(t *testing.T) {
client, err := CreateClient(testCtx.Config, writer)
require.NoError(t, err)

prompt := "what is the difference between those two pictures?"
prompt := []string{"what is the difference between those two pictures?"}
expected := "The two images provided appear to be identical. Both show the same depiction of a"
inputImageFile := "testdata/marvin.jpg"
inputImageURL := "https://upload.wikimedia.org/wikipedia/en/c/cb/Marvin_%28HHGG%29.jpg"
Expand Down Expand Up @@ -554,7 +554,7 @@ func TestSimplePromptWithMixedImagesAndChat(t *testing.T) {

// Check, if the prompt is a multi content message
require.Equal(t, "text", string(messages[0].MultiContent[0].Type))
require.Equal(t, prompt, messages[0].MultiContent[0].Text)
require.Equal(t, prompt[0], messages[0].MultiContent[0].Text)
// Check, if the URL image was added
require.Equal(t, "image_url", string(messages[0].MultiContent[1].Type))
require.Equal(t, inputImageURL, messages[0].MultiContent[1].ImageURL.URL)
Expand Down
19 changes: 12 additions & 7 deletions pkg/cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,25 +173,30 @@ ls | sort
return err
}

var prompt, input string
var prompts []string
mode := "txt"

if isPiped {
var stdinInput string
slog.Debug("Piped shell detected")
// input is provided via stdin
input, err = fs.ReadString(cmd.InOrStdin())
stdinInput, err = fs.ReadString(cmd.InOrStdin())
if err != nil {
return err
}
if len(input) == 0 {
if len(stdinInput) == 0 {
slog.Debug("No input via pipe provided")
return ErrMissingInput
}
prompt = input
prompts = append(prompts, stdinInput)
// mode is provided via command line args
if len(args) == 1 {
slog.Debug("Mode provided via command line args")
mode = args[0]
} else if len(args) == 2 {
slog.Debug("Mode and prompt provided via command line args")
mode = args[0]
prompts = append(prompts, args[1])
}

} else {
Expand All @@ -201,12 +206,12 @@ ls | sort
} else if len(args) == 1 {
// input is provided via command line args
slog.Debug("No mode provided via command line args - using default mode")
prompt = args[0]
prompts = append(prompts, args[0])
} else {
// input and mode are provided via command line args
slog.Debug("Mode and prompt provided via command line args")
mode = strings.ToLower(args[0])
prompt = args[1]
prompts = append(prompts, args[1])
}
}

Expand All @@ -218,7 +223,7 @@ ls | sort
}

var response string
response, err = client.CreateCompletion(cmd.Context(), root.chat, prompt, mode, root.input)
response, err = client.CreateCompletion(cmd.Context(), root.chat, prompts, mode, root.input)
if err != nil {
return err
}
Expand Down
51 changes: 51 additions & 0 deletions pkg/cli/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,57 @@ func TestRootCmd_SimplePromptViaPipedShellAndModifier(t *testing.T) {
wg.Wait()
}

func TestRootCmd_PipedShellAndModifierAndPrompt(t *testing.T) {
testCtx := testlib.NewTestCtx(t)
testlib.SetAPIKey(t)
mem := &exitMemento{}

var wg sync.WaitGroup
stdinReader, stdinWriter := io.Pipe()
stdoutReader, stdoutWriter := io.Pipe()

client, err := api.CreateClient(testCtx.Config, stdoutWriter)
require.NoError(t, err)

stdinPrompt := "Say: Hello World!"
prompt := "Replace every 'World' word with 'ChatGPT'"
response := "Hello ChatGPT!"
expected := "Hello ChatGPT!\n"

httpmock.ActivateNonDefault(client.HTTPClient)
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(response)

root := newRootCmd(mem.Exit, testCtx.Config, mockIsPipedShell(true, nil), useMockClient(client))
root.cmd.SetIn(stdinReader)
root.cmd.SetOut(stdoutWriter)

wg.Add(1)
go func() {
defer wg.Done()
_, errWrite := stdinWriter.Write([]byte(stdinPrompt))
require.NoError(t, stdinWriter.Close())
require.NoError(t, errWrite)
}()

wg.Add(1)
go func() {
defer wg.Done()
var buf bytes.Buffer
_, errReader := io.Copy(&buf, stdoutReader)
require.NoError(t, errReader)
require.NoError(t, stdoutReader.Close())
require.Equal(t, expected, buf.String())
}()

root.Execute([]string{"stdin", prompt})
require.Equal(t, 0, mem.code)
require.NoError(t, stdinReader.Close())
require.NoError(t, stdoutWriter.Close())

wg.Wait()
}

func TestRootCmd_SimpleShellPrompt(t *testing.T) {
testCtx := testlib.NewTestCtx(t)
testlib.SetAPIKey(t)
Expand Down
3 changes: 3 additions & 0 deletions pkg/modifiers/modifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ func GetChatModifier(config *viper.Viper, modifier string) (string, error) {
case "txt":
slog.Debug("No persona provided")
return "", nil
case "stdin":
slog.Debug("No persona provided, just using stdin with additional prompt")
return "", nil
default:
slog.Debug("Unsupported persona: " + modifier)
return "", ErrUnsupportedModifier
Expand Down

0 comments on commit e282744

Please sign in to comment.