From 7aeacbc2cab851b2e76cb4b1f92f5914f70f8f9d Mon Sep 17 00:00:00 2001 From: Richard Hull Date: Fri, 31 Oct 2025 00:25:21 +0000 Subject: [PATCH 1/2] refactor(app): Introduce dependency injection Refactor the core application logic to utilize interfaces for Git and UI operations, allowing dependencies to be injected. This enables comprehensive unit testing of the application flow and state transitions, addressing the lack of existing tests. * Introduces `app.Git` and `app.UI` interfaces. * Adds `github.com/stretchr/testify` for assertion helpers. * Includes unit tests for `app`, `config`, and `llm_provider`. --- TODO.md | 4 +- go.mod | 4 + go.sum | 6 + internal/app/app.go | 45 ++++--- internal/app/app_test.go | 162 +++++++++++++++++++++++++ internal/config/config_test.go | 34 ++++++ internal/git/client.go | 15 +++ internal/git/commit.go | 2 +- internal/git/diff.go | 2 +- internal/llm_provider/provider_test.go | 42 +++++++ internal/ui/client.go | 7 ++ internal/ui/text_area.go | 2 +- main.go | 13 +- 13 files changed, 311 insertions(+), 27 deletions(-) create mode 100644 internal/app/app_test.go create mode 100644 internal/config/config_test.go create mode 100644 internal/git/client.go create mode 100644 internal/llm_provider/provider_test.go create mode 100644 internal/ui/client.go diff --git a/TODO.md b/TODO.md index 2edbbfb..8286f29 100644 --- a/TODO.md +++ b/TODO.md @@ -50,5 +50,5 @@ A `Makefile` would automate common development tasks, such as building, testing, The application currently has no unit tests. -- **Add unit tests for the core application logic:** This will help to ensure that the application is working correctly and prevent regressions. -- **Use a testing framework, such as `testify`:** A testing framework will make it easier to write and run tests. +~~- **Add unit tests for the core application logic:** This will help to ensure that the application is working correctly and prevent regressions.~~ +- ~~**Use a testing framework, such as `testify`:** A testing framework will make it easier to write and run tests.~~ diff --git a/go.mod b/go.mod index 626885f..42976ef 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/gookit/color v1.6.0 github.com/joho/godotenv v1.5.1 github.com/spf13/cobra v1.10.1 + github.com/stretchr/testify v1.11.1 google.golang.org/genai v1.33.0 ) @@ -23,6 +24,7 @@ require ( github.com/charmbracelet/x/term v0.2.1 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/galactixx/ansiwalker v1.0.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect @@ -34,6 +36,7 @@ require ( github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.16.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/tidwall/gjson v1.18.0 // indirect @@ -41,6 +44,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) require ( diff --git a/go.sum b/go.sum index 8f5a1fd..f220dfd 100644 --- a/go.sum +++ b/go.sum @@ -86,8 +86,11 @@ github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLf github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= @@ -117,6 +120,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= @@ -210,6 +215,7 @@ google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94U google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/app/app.go b/internal/app/app.go index 767b088..6e51d23 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -10,48 +10,59 @@ import ( "github.com/briandowns/spinner" "github.com/galactixx/stringwrap" "github.com/gookit/color" - "github.com/rm-hull/git-commit-summary/internal/config" "github.com/rm-hull/git-commit-summary/internal/git" llmprovider "github.com/rm-hull/git-commit-summary/internal/llm_provider" "github.com/rm-hull/git-commit-summary/internal/ui" ) +type Git interface { + Diff() (string, error) + Commit(message string) error +} + +type UI interface { + TextArea(value string) (string, bool, error) +} + +// Verify that structs implement interfaces +var _ Git = (*git.Client)(nil) +var _ UI = (*ui.Client)(nil) + type App struct { llmProvider llmprovider.Provider + git Git + ui UI prompt string } -func NewApp(ctx context.Context, cfg *config.Config) (*App, error) { - provider, err := llmprovider.NewProvider(ctx, cfg) - if err != nil { - return nil, err - } - +func NewApp(provider llmprovider.Provider, gitClient Git, uiClient UI, prompt string) *App { return &App{ llmProvider: provider, - prompt: cfg.Prompt, - }, nil + git: gitClient, + ui: uiClient, + prompt: prompt, + } } -func (a *App) Run(ctx context.Context, userMessage string) error { +func (app *App) Run(ctx context.Context, userMessage string) error { s := spinner.New(spinner.CharSets[14], 100*time.Millisecond) s.Suffix = color.Render(" Running git diff") s.Start() defer s.Stop() - out, err := git.Diff() + diffOutput, err := app.git.Diff() if err != nil { return err } - if len(out) == 0 { + if len(diffOutput) == 0 { return errors.New("no changes are staged") } - s.Suffix = color.Sprintf(" Generating commit summary (using: %s)", a.llmProvider.Model()) - text := fmt.Sprintf(a.prompt, out) + s.Suffix = color.Sprintf(" Generating commit summary (using: %s)", app.llmProvider.Model()) + text := fmt.Sprintf(app.prompt, diffOutput) - message, err := a.llmProvider.Call(ctx, "", text) + message, err := app.llmProvider.Call(ctx, "", text) if err != nil { return err } @@ -68,13 +79,13 @@ func (a *App) Run(ctx context.Context, userMessage string) error { } wrapped = strings.ReplaceAll(wrapped, "\n\n\n", "\n\n") - edited, accepted, err := ui.TextArea(wrapped) + edited, accepted, err := app.ui.TextArea(wrapped) if err != nil { return err } if accepted { - return git.Commit(edited) + return app.git.Commit(edited) } else { color.Println("ABORTED!") return nil // Or a specific error for abortion diff --git a/internal/app/app_test.go b/internal/app/app_test.go new file mode 100644 index 0000000..1d4edf4 --- /dev/null +++ b/internal/app/app_test.go @@ -0,0 +1,162 @@ +package app + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockProvider struct { + modelName string + callFunc func(ctx context.Context, systemPrompt, userPrompt string) (string, error) +} + +func (m *mockProvider) Call(ctx context.Context, systemPrompt, userPrompt string) (string, error) { + return m.callFunc(ctx, systemPrompt, userPrompt) +} + +func (m *mockProvider) Model() string { + return m.modelName +} + +type mockGitClient struct { + DiffFunc func() (string, error) + CommitFunc func(message string) error +} + +func (m *mockGitClient) Diff() (string, error) { + return m.DiffFunc() +} + +func (m *mockGitClient) Commit(message string) error { + return m.CommitFunc(message) +} + +type mockUIClient struct { + TextAreaFunc func(value string) (string, bool, error) +} + +func (m *mockUIClient) TextArea(value string) (string, bool, error) { + return m.TextAreaFunc(value) +} + +func TestNewApp(t *testing.T) { + provider := &mockProvider{modelName: "test-model"} + gitClient := &mockGitClient{} + uiClient := &mockUIClient{} + + app := NewApp(provider, gitClient, uiClient, "test-prompt") + + assert.NotNil(t, app) + assert.Equal(t, "test-prompt", app.prompt) + assert.IsType(t, &mockProvider{}, app.llmProvider) + assert.IsType(t, &mockGitClient{}, app.git) + assert.IsType(t, &mockUIClient{}, app.ui) +} + +func TestAppRun(t *testing.T) { + ctx := context.Background() + + t.Run("DiffError", func(t *testing.T) { + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "", assert.AnError }, + } + app := NewApp(&mockProvider{}, gitClient, &mockUIClient{}, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("NoStagedChanges", func(t *testing.T) { + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "", nil }, + } + app := NewApp(&mockProvider{}, gitClient, &mockUIClient{}, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.EqualError(t, err, "no changes are staged") + }) + + t.Run("LLMCallError", func(t *testing.T) { + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + } + llmProvider := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "", assert.AnError }, + } + app := NewApp(llmProvider, gitClient, &mockUIClient{}, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("TextAreaError", func(t *testing.T) { + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + } + llmProvider := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } + uiClient := &mockUIClient{ + TextAreaFunc: func(value string) (string, bool, error) { return "", false, assert.AnError }, + } + app := NewApp(llmProvider, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("UserAborted", func(t *testing.T) { + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + } + llmProvider := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } + uiClient := &mockUIClient{ + TextAreaFunc: func(value string) (string, bool, error) { return "", false, nil }, + } + app := NewApp(llmProvider, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.NoError(t, err) + }) + + t.Run("CommitError", func(t *testing.T) { + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + CommitFunc: func(message string) error { return assert.AnError }, + } + llmProvider := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } + uiClient := &mockUIClient{ + TextAreaFunc: func(value string) (string, bool, error) { return "edited message", true, nil }, + } + app := NewApp(llmProvider, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("Success", func(t *testing.T) { + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + CommitFunc: func(message string) error { return nil }, + } + llmProvider := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } + uiClient := &mockUIClient{ + TextAreaFunc: func(value string) (string, bool, error) { return "edited message", true, nil }, + } + app := NewApp(llmProvider, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.NoError(t, err) + }) +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..9d23ce0 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,34 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoad(t *testing.T) { + t.Run("Defaults", func(t *testing.T) { + t.Setenv("LLM_PROVIDER", "") + t.Setenv("GEMINI_MODEL", "") + t.Setenv("OPENAI_MODEL", "") + + cfg, err := Load() + assert.NoError(t, err) + assert.Equal(t, "google", cfg.LLMProvider) + assert.Equal(t, "gemini-2.5-flash-preview-09-2025", cfg.Gemini.Model) + assert.Equal(t, "gpt-4o", cfg.OpenAI.Model) + assert.NotEmpty(t, cfg.Prompt) + }) + + t.Run("WithEnvironmentVariables", func(t *testing.T) { + t.Setenv("LLM_PROVIDER", "openai") + t.Setenv("GEMINI_MODEL", "gemini-pro") + t.Setenv("OPENAI_MODEL", "gpt-3.5-turbo") + + cfg, err := Load() + assert.NoError(t, err) + assert.Equal(t, "openai", cfg.LLMProvider) + assert.Equal(t, "gemini-pro", cfg.Gemini.Model) + assert.Equal(t, "gpt-3.5-turbo", cfg.OpenAI.Model) + }) +} diff --git a/internal/git/client.go b/internal/git/client.go new file mode 100644 index 0000000..e6166ed --- /dev/null +++ b/internal/git/client.go @@ -0,0 +1,15 @@ +package git + +type Client struct{} + +func (c *Client) Diff() (string, error) { + out, err := diff() + if err != nil { + return "", err + } + return string(out), nil +} + +func (c *Client) Commit(message string) error { + return commit(message) +} diff --git a/internal/git/commit.go b/internal/git/commit.go index 8f9219e..3ba191b 100644 --- a/internal/git/commit.go +++ b/internal/git/commit.go @@ -6,7 +6,7 @@ import ( "os/exec" ) -func Commit(message string) error { +func commit(message string) error { tmpfile, err := os.CreateTemp("", "gitmsg-*.txt") if err != nil { return err diff --git a/internal/git/diff.go b/internal/git/diff.go index 7ef1ff9..ea02f10 100644 --- a/internal/git/diff.go +++ b/internal/git/diff.go @@ -4,7 +4,7 @@ import ( "os/exec" ) -func Diff() ([]byte, error) { +func diff() ([]byte, error) { return exec.Command( "git", "--no-pager", diff --git a/internal/llm_provider/provider_test.go b/internal/llm_provider/provider_test.go new file mode 100644 index 0000000..cf58ce0 --- /dev/null +++ b/internal/llm_provider/provider_test.go @@ -0,0 +1,42 @@ +package llmprovider + +import ( + "context" + "testing" + + "github.com/rm-hull/git-commit-summary/internal/config" + "github.com/stretchr/testify/assert" +) + +func TestNewProvider(t *testing.T) { + t.Run("GoogleProvider", func(t *testing.T) { + t.Setenv("GEMINI_API_KEY", "dummy-gemini-key") + cfg := &config.Config{ + LLMProvider: "google", + Gemini: config.GeminiConfig{Model: "gemini-test-model"}, + } + provider, err := NewProvider(context.Background(), cfg) + assert.NoError(t, err) + assert.IsType(t, &GoogleProvider{}, provider) + assert.Equal(t, "gemini-test-model", provider.Model()) + }) + + t.Run("OpenAIProvider", func(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "dummy-openai-key") + cfg := &config.Config{ + LLMProvider: "openai", + OpenAI: config.OpenAIConfig{Model: "openai-test-model"}, + } + provider, err := NewProvider(context.Background(), cfg) + assert.NoError(t, err) + assert.IsType(t, &OpenAiProvider{}, provider) + assert.Equal(t, "openai-test-model", provider.Model()) + }) + + t.Run("UnknownProvider", func(t *testing.T) { + cfg := &config.Config{LLMProvider: "unknown"} + _, err := NewProvider(context.Background(), cfg) + assert.Error(t, err) + assert.EqualError(t, err, "unknown LLM provider: unknown") + }) +} diff --git a/internal/ui/client.go b/internal/ui/client.go new file mode 100644 index 0000000..466b968 --- /dev/null +++ b/internal/ui/client.go @@ -0,0 +1,7 @@ +package ui + +type Client struct{} + +func (c *Client) TextArea(value string) (string, bool, error) { + return textArea(value) +} diff --git a/internal/ui/text_area.go b/internal/ui/text_area.go index 10206ef..aa936ae 100644 --- a/internal/ui/text_area.go +++ b/internal/ui/text_area.go @@ -10,7 +10,7 @@ import ( "github.com/charmbracelet/lipgloss" ) -func TextArea(value string) (string, bool, error) { +func textArea(value string) (string, bool, error) { p := tea.NewProgram(initialModel(value)) finalModel, err := p.Run() diff --git a/main.go b/main.go index f55127c..f96dd6a 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,9 @@ import ( "github.com/gookit/color" "github.com/rm-hull/git-commit-summary/internal/app" "github.com/rm-hull/git-commit-summary/internal/config" + "github.com/rm-hull/git-commit-summary/internal/git" + llmprovider "github.com/rm-hull/git-commit-summary/internal/llm_provider" + "github.com/rm-hull/git-commit-summary/internal/ui" "github.com/spf13/cobra" ) @@ -35,16 +38,16 @@ func main() { ctx := context.Background() - application, err := app.NewApp(ctx, cfg) - if err != nil { - handleError(err) - } + provider, err := llmprovider.NewProvider(ctx, cfg) + handleError(err) + + application := app.NewApp(provider, &git.Client{}, &ui.Client{}, cfg.Prompt) handleError(application.Run(ctx, userMessage)) }, } - rootCmd.PersistentFlags().BoolP("version", "v", false, "Display version information") rootCmd.PersistentFlags().StringVarP(&userMessage, "message", "m", "", "Append a message to the commit summary") + rootCmd.PersistentFlags().BoolP("version", "v", false, "Display version information") rootCmd.PersistentFlags().StringVarP(&llmProvider, "llm-provider", "", cfg.LLMProvider, "Use specific LLM provider, overrides environment variable LLM_PROVIDER") _ = rootCmd.Execute() From 6d40f5602ad9dc227edd642b8bd1c94a08f4d801 Mon Sep 17 00:00:00 2001 From: Richard Hull Date: Fri, 31 Oct 2025 00:54:37 +0000 Subject: [PATCH 2/2] refactor: Move spinner logic to UI interface The spinner implementation is now abstracted behind the `UIClient` interface, decoupling the core application logic from the underlying library (`github.com/briandowns/spinner`). This improves separation of concerns and testability. * Introduced `StartSpinner`, `UpdateSpinner`, and `StopSpinner` methods on `UIClient`. * Implemented spinner management in `internal/ui/client.go`. --- internal/app/app.go | 45 ++++++++-------- internal/app/app_test.go | 112 ++++++++++++++++++++++++++++----------- internal/ui/client.go | 29 +++++++++- 3 files changed, 132 insertions(+), 54 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 6e51d23..d127d6d 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -5,9 +5,7 @@ import ( "errors" "fmt" "strings" - "time" - "github.com/briandowns/spinner" "github.com/galactixx/stringwrap" "github.com/gookit/color" "github.com/rm-hull/git-commit-summary/internal/git" @@ -15,64 +13,67 @@ import ( "github.com/rm-hull/git-commit-summary/internal/ui" ) -type Git interface { +type GitClient interface { Diff() (string, error) Commit(message string) error } -type UI interface { +// Verify that git.Client implements GitClient. +var _ GitClient = (*git.Client)(nil) + +type UIClient interface { TextArea(value string) (string, bool, error) + StartSpinner(message string) + UpdateSpinner(message string) + StopSpinner() } -// Verify that structs implement interfaces -var _ Git = (*git.Client)(nil) -var _ UI = (*ui.Client)(nil) +// Verify that ui.Client implements UIClient. +var _ UIClient = (*ui.Client)(nil) type App struct { llmProvider llmprovider.Provider - git Git - ui UI + git GitClient + ui UIClient prompt string } -func NewApp(provider llmprovider.Provider, gitClient Git, uiClient UI, prompt string) *App { +func NewApp(provider llmprovider.Provider, git GitClient, ui UIClient, prompt string) *App { return &App{ llmProvider: provider, - git: gitClient, - ui: uiClient, + git: git, + ui: ui, prompt: prompt, } } func (app *App) Run(ctx context.Context, userMessage string) error { - s := spinner.New(spinner.CharSets[14], 100*time.Millisecond) - s.Suffix = color.Render(" Running git diff") - s.Start() - defer s.Stop() + app.ui.StartSpinner(" Running git diff") + defer app.ui.StopSpinner() - diffOutput, err := app.git.Diff() + out, err := app.git.Diff() if err != nil { return err } - if len(diffOutput) == 0 { + if len(out) == 0 { return errors.New("no changes are staged") } - s.Suffix = color.Sprintf(" Generating commit summary (using: %s)", app.llmProvider.Model()) - text := fmt.Sprintf(app.prompt, diffOutput) + app.ui.UpdateSpinner(color.Sprintf(" Generating commit summary (using: %s)", app.llmProvider.Model())) + text := fmt.Sprintf(app.prompt, out) message, err := app.llmProvider.Call(ctx, "", text) if err != nil { return err } - s.Stop() - if userMessage != "" { message = fmt.Sprintf("%s\n\n%s", userMessage, message) } + app.ui.StopSpinner() + wrapped, _, err := stringwrap.StringWrap(message, 72, 4, false) if err != nil { return err diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 1d4edf4..d8af32a 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -34,13 +34,34 @@ func (m *mockGitClient) Commit(message string) error { } type mockUIClient struct { - TextAreaFunc func(value string) (string, bool, error) + TextAreaFunc func(value string) (string, bool, error) + StartSpinnerFunc func(message string) + UpdateSpinnerFunc func(message string) + StopSpinnerFunc func() } func (m *mockUIClient) TextArea(value string) (string, bool, error) { return m.TextAreaFunc(value) } +func (m *mockUIClient) StartSpinner(message string) { + if m.StartSpinnerFunc != nil { + m.StartSpinnerFunc(message) + } +} + +func (m *mockUIClient) UpdateSpinner(message string) { + if m.UpdateSpinnerFunc != nil { + m.UpdateSpinnerFunc(message) + } +} + +func (m *mockUIClient) StopSpinner() { + if m.StopSpinnerFunc != nil { + m.StopSpinnerFunc() + } +} + func TestNewApp(t *testing.T) { provider := &mockProvider{modelName: "test-model"} gitClient := &mockGitClient{} @@ -59,103 +80,132 @@ func TestAppRun(t *testing.T) { ctx := context.Background() t.Run("DiffError", func(t *testing.T) { + mp := &mockProvider{modelName: "test-model"} gitClient := &mockGitClient{ DiffFunc: func() (string, error) { return "", assert.AnError }, } - app := NewApp(&mockProvider{}, gitClient, &mockUIClient{}, "prompt") + uiClient := &mockUIClient{ + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, + } + app := NewApp(mp, gitClient, uiClient, "prompt") err := app.Run(ctx, "") assert.Error(t, err) assert.Equal(t, assert.AnError, err) }) t.Run("NoStagedChanges", func(t *testing.T) { + mp := &mockProvider{modelName: "test-model"} gitClient := &mockGitClient{ DiffFunc: func() (string, error) { return "", nil }, } - app := NewApp(&mockProvider{}, gitClient, &mockUIClient{}, "prompt") + uiClient := &mockUIClient{ + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, + } + app := NewApp(mp, gitClient, uiClient, "prompt") err := app.Run(ctx, "") assert.Error(t, err) assert.EqualError(t, err, "no changes are staged") }) t.Run("LLMCallError", func(t *testing.T) { + mp := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "", assert.AnError }, + } gitClient := &mockGitClient{ DiffFunc: func() (string, error) { return "diff output", nil }, } - llmProvider := &mockProvider{ - modelName: "test-model", - callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "", assert.AnError }, + uiClient := &mockUIClient{ + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, } - app := NewApp(llmProvider, gitClient, &mockUIClient{}, "prompt") + app := NewApp(mp, gitClient, uiClient, "prompt") err := app.Run(ctx, "") assert.Error(t, err) assert.Equal(t, assert.AnError, err) }) t.Run("TextAreaError", func(t *testing.T) { - gitClient := &mockGitClient{ - DiffFunc: func() (string, error) { return "diff output", nil }, - } - llmProvider := &mockProvider{ + mp := &mockProvider{ modelName: "test-model", callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, } + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + } uiClient := &mockUIClient{ - TextAreaFunc: func(value string) (string, bool, error) { return "", false, assert.AnError }, + TextAreaFunc: func(value string) (string, bool, error) { return "", false, assert.AnError }, + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, } - app := NewApp(llmProvider, gitClient, uiClient, "prompt") + app := NewApp(mp, gitClient, uiClient, "prompt") err := app.Run(ctx, "") assert.Error(t, err) assert.Equal(t, assert.AnError, err) }) t.Run("UserAborted", func(t *testing.T) { - gitClient := &mockGitClient{ - DiffFunc: func() (string, error) { return "diff output", nil }, - } - llmProvider := &mockProvider{ + mp := &mockProvider{ modelName: "test-model", callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, } + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + } uiClient := &mockUIClient{ - TextAreaFunc: func(value string) (string, bool, error) { return "", false, nil }, + TextAreaFunc: func(value string) (string, bool, error) { return "", false, nil }, + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, } - app := NewApp(llmProvider, gitClient, uiClient, "prompt") + app := NewApp(mp, gitClient, uiClient, "prompt") err := app.Run(ctx, "") assert.NoError(t, err) }) t.Run("CommitError", func(t *testing.T) { + mp := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } gitClient := &mockGitClient{ DiffFunc: func() (string, error) { return "diff output", nil }, CommitFunc: func(message string) error { return assert.AnError }, } - llmProvider := &mockProvider{ - modelName: "test-model", - callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, - } uiClient := &mockUIClient{ - TextAreaFunc: func(value string) (string, bool, error) { return "edited message", true, nil }, + TextAreaFunc: func(value string) (string, bool, error) { return "edited message", true, nil }, + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, } - app := NewApp(llmProvider, gitClient, uiClient, "prompt") + app := NewApp(mp, gitClient, uiClient, "prompt") err := app.Run(ctx, "") assert.Error(t, err) assert.Equal(t, assert.AnError, err) }) t.Run("Success", func(t *testing.T) { + mp := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } gitClient := &mockGitClient{ DiffFunc: func() (string, error) { return "diff output", nil }, CommitFunc: func(message string) error { return nil }, } - llmProvider := &mockProvider{ - modelName: "test-model", - callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, - } uiClient := &mockUIClient{ - TextAreaFunc: func(value string) (string, bool, error) { return "edited message", true, nil }, + TextAreaFunc: func(value string) (string, bool, error) { return "edited message", true, nil }, + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, } - app := NewApp(llmProvider, gitClient, uiClient, "prompt") + app := NewApp(mp, gitClient, uiClient, "prompt") err := app.Run(ctx, "") assert.NoError(t, err) }) diff --git a/internal/ui/client.go b/internal/ui/client.go index 466b968..f5af19c 100644 --- a/internal/ui/client.go +++ b/internal/ui/client.go @@ -1,7 +1,34 @@ package ui -type Client struct{} +import ( + "time" + + "github.com/briandowns/spinner" + "github.com/gookit/color" +) + +type Client struct { + spinner *spinner.Spinner +} func (c *Client) TextArea(value string) (string, bool, error) { return textArea(value) } + +func (c *Client) StartSpinner(message string) { + c.spinner = spinner.New(spinner.CharSets[14], 100*time.Millisecond) + c.spinner.Suffix = color.Render(message) + c.spinner.Start() +} + +func (c *Client) UpdateSpinner(message string) { + if c.spinner != nil { + c.spinner.Suffix = color.Render(message) + } +} + +func (c *Client) StopSpinner() { + if c.spinner != nil { + c.spinner.Stop() + } +}