Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package app

import (
"context"
"fmt"

tea "github.com/charmbracelet/bubbletea"
"github.com/cockroachdb/errors"
Expand All @@ -28,8 +29,8 @@ func NewApp(provider llmprovider.Provider, git interfaces.GitClient, prompt stri
}
}

func (app *App) Run(ctx context.Context, userMessage string) error {
model := ui.InitialModel(ctx, app.llmProvider, app.git, app.prompt, userMessage)
func (app *App) Run(ctx context.Context, userMessage string, yolo bool) error {
model := ui.InitialModel(ctx, app.llmProvider, app.git, app.prompt, userMessage, yolo)
p := tea.NewProgram(model)

finalModel, err := p.Run()
Expand All @@ -51,6 +52,11 @@ func (app *App) Run(ctx context.Context, userMessage string) error {
}

if m.Action() == ui.Commit {
if yolo {
fmt.Println(ui.BoldGreen.Render("COMMIT MESSAGE:"))
fmt.Println(m.CommitMessage())
fmt.Println()
}
err = app.git.Commit(m.CommitMessage())
if err != nil {
return err
Expand Down
22 changes: 21 additions & 1 deletion internal/ui/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type Model struct {
promptView tea.Model
action Action
err error
yolo bool
}

func InitialModel(
Expand All @@ -65,6 +66,7 @@ func InitialModel(
gitClient interfaces.GitClient,
systemPrompt string,
userMessage string,
yolo bool,
) *Model {
return &Model{
ctx: ctx,
Expand All @@ -76,6 +78,7 @@ func InitialModel(
spinner: spinner.New(spinner.WithSpinner(spinner.MiniDot)),
spinnerMessage: Magenta.Render("Running git commands to determine modified files..."),
action: None,
yolo: yolo,
}
}

Expand Down Expand Up @@ -111,7 +114,6 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, m.generateSummary(m.diff, "")

case llmResultMsg:
m.state = showCommitView
commitMessage := string(msg)
if m.userMessage != "" {
// append the user supplied message
Expand All @@ -126,6 +128,18 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
}
commitMessage = strings.ReplaceAll(commitMessage, "\n\n\n", "\n\n")

if m.yolo {
if commitMessage == "" {
m.err = errors.New("failed to generate a commit summary")
return m, tea.Quit
}
m.action = Commit
m.commitMessage = commitMessage
return m, tea.Quit
}
Comment thread
rm-hull marked this conversation as resolved.

m.state = showCommitView
m.commitView, m.err = initialCommitViewModel(commitMessage)
if m.err != nil {
return m, tea.Quit
Expand Down Expand Up @@ -185,8 +199,14 @@ func (m *Model) View() string {
case showSpinner:
return m.spinner.View() + " " + m.spinnerMessage
case showCommitView:
if m.commitView == nil {
return m.spinner.View() + " " + m.spinnerMessage
}
return m.commitView.View()
case showRegeneratePrompt:
if m.commitView == nil || m.promptView == nil {
return m.spinner.View() + " " + m.spinnerMessage
}
return m.commitView.View() + m.promptView.View()
default:
return ""
Expand Down
22 changes: 21 additions & 1 deletion internal/ui/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestModel_Update(t *testing.T) {
// Explicitly use the types to avoid "imported and not used" warnings
var _ interfaces.GitClient = mockGit
var _ llmprovider.Provider = mockLLM
return InitialModel(ctx, mockLLM, mockGit, "system prompt", "user message")
return InitialModel(ctx, mockLLM, mockGit, "system prompt", "user message", false)
}

t.Run("tea.KeyMsg - CtrlC in showSpinner state", func(t *testing.T) {
Expand Down Expand Up @@ -291,6 +291,26 @@ func TestModel_Update(t *testing.T) {
assert.Nil(t, cmd) // Mock returns nil cmd
mockPromptView.AssertCalled(t, "Update", testMsg)
})

t.Run("llmResultMsg - YOLO mode", func(t *testing.T) {
m := InitialModel(ctx, mockLLM, mockGit, "system prompt", "user message", true)
updatedModel, cmd := m.Update(llmResultMsg("commit summary"))

assert.Equal(t, Commit, updatedModel.(*Model).action)
assert.Contains(t, updatedModel.(*Model).commitMessage, "commit summary")
assert.NotNil(t, cmd)
assert.IsType(t, tea.QuitMsg{}, cmd())
})

t.Run("llmResultMsg - YOLO mode - empty summary", func(t *testing.T) {
m := InitialModel(ctx, mockLLM, mockGit, "system prompt", "", true)
updatedModel, cmd := m.Update(llmResultMsg(""))

assert.NotNil(t, updatedModel.(*Model).err)
assert.Equal(t, "failed to generate a commit summary", updatedModel.(*Model).err.Error())
assert.NotNil(t, cmd)
assert.IsType(t, tea.QuitMsg{}, cmd())
})
}

// mockTeaModel is a generic mock for tea.Model interface
Expand Down
10 changes: 6 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ func main() {
var llmProvider string
var runSetupWizard *bool
var showVersion *bool
var yoloMode *bool
var addAll *bool

rootCmd := &cobra.Command{
Use: "git-commit-summary",
Expand Down Expand Up @@ -60,9 +62,8 @@ func main() {
provider, err := llmprovider.NewProvider(ctx, cfg)
handleError(err)

addAll, _ := cmd.Flags().GetBool("all")
application := app.NewApp(provider, git.NewClient(addAll), cfg.Prompt)
err = application.Run(ctx, userMessage)
application := app.NewApp(provider, git.NewClient(*addAll), cfg.Prompt)
err = application.Run(ctx, userMessage, *yoloMode)
if err != nil {
handleError(err)
}
Expand All @@ -71,9 +72,10 @@ func main() {

showVersion = rootCmd.PersistentFlags().BoolP("version", "v", false, "Display version information")
runSetupWizard = rootCmd.PersistentFlags().Bool("setup-wizard", false, "Run setup wizard")
yoloMode = rootCmd.PersistentFlags().Bool("yolo", false, "Commit immediately without asking for confirmation")
addAll = rootCmd.PersistentFlags().BoolP("all", "a", false, "Add all tracked files to the commit")
rootCmd.PersistentFlags().StringVarP(&userMessage, "message", "m", "", "Append a message to the commit summary")
rootCmd.PersistentFlags().StringVar(&llmProvider, "llm-provider", cfg.LLMProvider, "Use specific LLM provider, overrides environment variable LLM_PROVIDER")
rootCmd.PersistentFlags().BoolP("all", "a", false, "Add all tracked files to the commit")

_ = rootCmd.Execute()
}
Expand Down
Loading