diff --git a/TODO.md b/TODO.md index 2edbbfb..8286f29 100644 --- a/TODO.md +++ b/TODO.md @@ -50,5 +50,5 @@ A `Makefile` would automate common development tasks, such as building, testing, The application currently has no unit tests. -- **Add unit tests for the core application logic:** This will help to ensure that the application is working correctly and prevent regressions. -- **Use a testing framework, such as `testify`:** A testing framework will make it easier to write and run tests. +~~- **Add unit tests for the core application logic:** This will help to ensure that the application is working correctly and prevent regressions.~~ +- ~~**Use a testing framework, such as `testify`:** A testing framework will make it easier to write and run tests.~~ diff --git a/go.mod b/go.mod index 626885f..42976ef 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/gookit/color v1.6.0 github.com/joho/godotenv v1.5.1 github.com/spf13/cobra v1.10.1 + github.com/stretchr/testify v1.11.1 google.golang.org/genai v1.33.0 ) @@ -23,6 +24,7 @@ require ( github.com/charmbracelet/x/term v0.2.1 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/galactixx/ansiwalker v1.0.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect @@ -34,6 +36,7 @@ require ( github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.16.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/tidwall/gjson v1.18.0 // indirect @@ -41,6 +44,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) require ( diff --git a/go.sum b/go.sum index 8f5a1fd..f220dfd 100644 --- a/go.sum +++ b/go.sum @@ -86,8 +86,11 @@ github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLf github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= @@ -117,6 +120,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= @@ -210,6 +215,7 @@ google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94U google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/app/app.go b/internal/app/app.go index 767b088..d127d6d 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -5,41 +5,53 @@ import ( "errors" "fmt" "strings" - "time" - "github.com/briandowns/spinner" "github.com/galactixx/stringwrap" "github.com/gookit/color" - "github.com/rm-hull/git-commit-summary/internal/config" "github.com/rm-hull/git-commit-summary/internal/git" llmprovider "github.com/rm-hull/git-commit-summary/internal/llm_provider" "github.com/rm-hull/git-commit-summary/internal/ui" ) +type GitClient interface { + Diff() (string, error) + Commit(message string) error +} + +// Verify that git.Client implements GitClient. +var _ GitClient = (*git.Client)(nil) + +type UIClient interface { + TextArea(value string) (string, bool, error) + StartSpinner(message string) + UpdateSpinner(message string) + StopSpinner() +} + +// Verify that ui.Client implements UIClient. +var _ UIClient = (*ui.Client)(nil) + type App struct { llmProvider llmprovider.Provider + git GitClient + ui UIClient prompt string } -func NewApp(ctx context.Context, cfg *config.Config) (*App, error) { - provider, err := llmprovider.NewProvider(ctx, cfg) - if err != nil { - return nil, err - } - +func NewApp(provider llmprovider.Provider, git GitClient, ui UIClient, prompt string) *App { return &App{ llmProvider: provider, - prompt: cfg.Prompt, - }, nil + git: git, + ui: ui, + prompt: prompt, + } } -func (a *App) Run(ctx context.Context, userMessage string) error { - s := spinner.New(spinner.CharSets[14], 100*time.Millisecond) - s.Suffix = color.Render(" Running git diff") - s.Start() - defer s.Stop() +func (app *App) Run(ctx context.Context, userMessage string) error { + app.ui.StartSpinner(" Running git diff") + defer app.ui.StopSpinner() - out, err := git.Diff() + out, err := app.git.Diff() if err != nil { return err } @@ -48,33 +60,33 @@ func (a *App) Run(ctx context.Context, userMessage string) error { return errors.New("no changes are staged") } - s.Suffix = color.Sprintf(" Generating commit summary (using: %s)", a.llmProvider.Model()) - text := fmt.Sprintf(a.prompt, out) + app.ui.UpdateSpinner(color.Sprintf(" Generating commit summary (using: %s)", app.llmProvider.Model())) + text := fmt.Sprintf(app.prompt, out) - message, err := a.llmProvider.Call(ctx, "", text) + message, err := app.llmProvider.Call(ctx, "", text) if err != nil { return err } - s.Stop() - if userMessage != "" { message = fmt.Sprintf("%s\n\n%s", userMessage, message) } + app.ui.StopSpinner() + wrapped, _, err := stringwrap.StringWrap(message, 72, 4, false) if err != nil { return err } wrapped = strings.ReplaceAll(wrapped, "\n\n\n", "\n\n") - edited, accepted, err := ui.TextArea(wrapped) + edited, accepted, err := app.ui.TextArea(wrapped) if err != nil { return err } if accepted { - return git.Commit(edited) + return app.git.Commit(edited) } else { color.Println("ABORTED!") return nil // Or a specific error for abortion diff --git a/internal/app/app_test.go b/internal/app/app_test.go new file mode 100644 index 0000000..d8af32a --- /dev/null +++ b/internal/app/app_test.go @@ -0,0 +1,212 @@ +package app + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockProvider struct { + modelName string + callFunc func(ctx context.Context, systemPrompt, userPrompt string) (string, error) +} + +func (m *mockProvider) Call(ctx context.Context, systemPrompt, userPrompt string) (string, error) { + return m.callFunc(ctx, systemPrompt, userPrompt) +} + +func (m *mockProvider) Model() string { + return m.modelName +} + +type mockGitClient struct { + DiffFunc func() (string, error) + CommitFunc func(message string) error +} + +func (m *mockGitClient) Diff() (string, error) { + return m.DiffFunc() +} + +func (m *mockGitClient) Commit(message string) error { + return m.CommitFunc(message) +} + +type mockUIClient struct { + TextAreaFunc func(value string) (string, bool, error) + StartSpinnerFunc func(message string) + UpdateSpinnerFunc func(message string) + StopSpinnerFunc func() +} + +func (m *mockUIClient) TextArea(value string) (string, bool, error) { + return m.TextAreaFunc(value) +} + +func (m *mockUIClient) StartSpinner(message string) { + if m.StartSpinnerFunc != nil { + m.StartSpinnerFunc(message) + } +} + +func (m *mockUIClient) UpdateSpinner(message string) { + if m.UpdateSpinnerFunc != nil { + m.UpdateSpinnerFunc(message) + } +} + +func (m *mockUIClient) StopSpinner() { + if m.StopSpinnerFunc != nil { + m.StopSpinnerFunc() + } +} + +func TestNewApp(t *testing.T) { + provider := &mockProvider{modelName: "test-model"} + gitClient := &mockGitClient{} + uiClient := &mockUIClient{} + + app := NewApp(provider, gitClient, uiClient, "test-prompt") + + assert.NotNil(t, app) + assert.Equal(t, "test-prompt", app.prompt) + assert.IsType(t, &mockProvider{}, app.llmProvider) + assert.IsType(t, &mockGitClient{}, app.git) + assert.IsType(t, &mockUIClient{}, app.ui) +} + +func TestAppRun(t *testing.T) { + ctx := context.Background() + + t.Run("DiffError", func(t *testing.T) { + mp := &mockProvider{modelName: "test-model"} + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "", assert.AnError }, + } + uiClient := &mockUIClient{ + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, + } + app := NewApp(mp, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("NoStagedChanges", func(t *testing.T) { + mp := &mockProvider{modelName: "test-model"} + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "", nil }, + } + uiClient := &mockUIClient{ + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, + } + app := NewApp(mp, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.EqualError(t, err, "no changes are staged") + }) + + t.Run("LLMCallError", func(t *testing.T) { + mp := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "", assert.AnError }, + } + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + } + uiClient := &mockUIClient{ + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, + } + app := NewApp(mp, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("TextAreaError", func(t *testing.T) { + mp := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + } + uiClient := &mockUIClient{ + TextAreaFunc: func(value string) (string, bool, error) { return "", false, assert.AnError }, + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, + } + app := NewApp(mp, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("UserAborted", func(t *testing.T) { + mp := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + } + uiClient := &mockUIClient{ + TextAreaFunc: func(value string) (string, bool, error) { return "", false, nil }, + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, + } + app := NewApp(mp, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.NoError(t, err) + }) + + t.Run("CommitError", func(t *testing.T) { + mp := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + CommitFunc: func(message string) error { return assert.AnError }, + } + uiClient := &mockUIClient{ + TextAreaFunc: func(value string) (string, bool, error) { return "edited message", true, nil }, + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, + } + app := NewApp(mp, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.Error(t, err) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("Success", func(t *testing.T) { + mp := &mockProvider{ + modelName: "test-model", + callFunc: func(ctx context.Context, systemPrompt, userPrompt string) (string, error) { return "llm message", nil }, + } + gitClient := &mockGitClient{ + DiffFunc: func() (string, error) { return "diff output", nil }, + CommitFunc: func(message string) error { return nil }, + } + uiClient := &mockUIClient{ + TextAreaFunc: func(value string) (string, bool, error) { return "edited message", true, nil }, + StartSpinnerFunc: func(message string) {}, + UpdateSpinnerFunc: func(message string) {}, + StopSpinnerFunc: func() {}, + } + app := NewApp(mp, gitClient, uiClient, "prompt") + err := app.Run(ctx, "") + assert.NoError(t, err) + }) +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..9d23ce0 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,34 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoad(t *testing.T) { + t.Run("Defaults", func(t *testing.T) { + t.Setenv("LLM_PROVIDER", "") + t.Setenv("GEMINI_MODEL", "") + t.Setenv("OPENAI_MODEL", "") + + cfg, err := Load() + assert.NoError(t, err) + assert.Equal(t, "google", cfg.LLMProvider) + assert.Equal(t, "gemini-2.5-flash-preview-09-2025", cfg.Gemini.Model) + assert.Equal(t, "gpt-4o", cfg.OpenAI.Model) + assert.NotEmpty(t, cfg.Prompt) + }) + + t.Run("WithEnvironmentVariables", func(t *testing.T) { + t.Setenv("LLM_PROVIDER", "openai") + t.Setenv("GEMINI_MODEL", "gemini-pro") + t.Setenv("OPENAI_MODEL", "gpt-3.5-turbo") + + cfg, err := Load() + assert.NoError(t, err) + assert.Equal(t, "openai", cfg.LLMProvider) + assert.Equal(t, "gemini-pro", cfg.Gemini.Model) + assert.Equal(t, "gpt-3.5-turbo", cfg.OpenAI.Model) + }) +} diff --git a/internal/git/client.go b/internal/git/client.go new file mode 100644 index 0000000..e6166ed --- /dev/null +++ b/internal/git/client.go @@ -0,0 +1,15 @@ +package git + +type Client struct{} + +func (c *Client) Diff() (string, error) { + out, err := diff() + if err != nil { + return "", err + } + return string(out), nil +} + +func (c *Client) Commit(message string) error { + return commit(message) +} diff --git a/internal/git/commit.go b/internal/git/commit.go index 8f9219e..3ba191b 100644 --- a/internal/git/commit.go +++ b/internal/git/commit.go @@ -6,7 +6,7 @@ import ( "os/exec" ) -func Commit(message string) error { +func commit(message string) error { tmpfile, err := os.CreateTemp("", "gitmsg-*.txt") if err != nil { return err diff --git a/internal/git/diff.go b/internal/git/diff.go index 7ef1ff9..ea02f10 100644 --- a/internal/git/diff.go +++ b/internal/git/diff.go @@ -4,7 +4,7 @@ import ( "os/exec" ) -func Diff() ([]byte, error) { +func diff() ([]byte, error) { return exec.Command( "git", "--no-pager", diff --git a/internal/llm_provider/provider_test.go b/internal/llm_provider/provider_test.go new file mode 100644 index 0000000..cf58ce0 --- /dev/null +++ b/internal/llm_provider/provider_test.go @@ -0,0 +1,42 @@ +package llmprovider + +import ( + "context" + "testing" + + "github.com/rm-hull/git-commit-summary/internal/config" + "github.com/stretchr/testify/assert" +) + +func TestNewProvider(t *testing.T) { + t.Run("GoogleProvider", func(t *testing.T) { + t.Setenv("GEMINI_API_KEY", "dummy-gemini-key") + cfg := &config.Config{ + LLMProvider: "google", + Gemini: config.GeminiConfig{Model: "gemini-test-model"}, + } + provider, err := NewProvider(context.Background(), cfg) + assert.NoError(t, err) + assert.IsType(t, &GoogleProvider{}, provider) + assert.Equal(t, "gemini-test-model", provider.Model()) + }) + + t.Run("OpenAIProvider", func(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "dummy-openai-key") + cfg := &config.Config{ + LLMProvider: "openai", + OpenAI: config.OpenAIConfig{Model: "openai-test-model"}, + } + provider, err := NewProvider(context.Background(), cfg) + assert.NoError(t, err) + assert.IsType(t, &OpenAiProvider{}, provider) + assert.Equal(t, "openai-test-model", provider.Model()) + }) + + t.Run("UnknownProvider", func(t *testing.T) { + cfg := &config.Config{LLMProvider: "unknown"} + _, err := NewProvider(context.Background(), cfg) + assert.Error(t, err) + assert.EqualError(t, err, "unknown LLM provider: unknown") + }) +} diff --git a/internal/ui/client.go b/internal/ui/client.go new file mode 100644 index 0000000..f5af19c --- /dev/null +++ b/internal/ui/client.go @@ -0,0 +1,34 @@ +package ui + +import ( + "time" + + "github.com/briandowns/spinner" + "github.com/gookit/color" +) + +type Client struct { + spinner *spinner.Spinner +} + +func (c *Client) TextArea(value string) (string, bool, error) { + return textArea(value) +} + +func (c *Client) StartSpinner(message string) { + c.spinner = spinner.New(spinner.CharSets[14], 100*time.Millisecond) + c.spinner.Suffix = color.Render(message) + c.spinner.Start() +} + +func (c *Client) UpdateSpinner(message string) { + if c.spinner != nil { + c.spinner.Suffix = color.Render(message) + } +} + +func (c *Client) StopSpinner() { + if c.spinner != nil { + c.spinner.Stop() + } +} diff --git a/internal/ui/text_area.go b/internal/ui/text_area.go index 10206ef..aa936ae 100644 --- a/internal/ui/text_area.go +++ b/internal/ui/text_area.go @@ -10,7 +10,7 @@ import ( "github.com/charmbracelet/lipgloss" ) -func TextArea(value string) (string, bool, error) { +func textArea(value string) (string, bool, error) { p := tea.NewProgram(initialModel(value)) finalModel, err := p.Run() diff --git a/main.go b/main.go index f55127c..f96dd6a 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,9 @@ import ( "github.com/gookit/color" "github.com/rm-hull/git-commit-summary/internal/app" "github.com/rm-hull/git-commit-summary/internal/config" + "github.com/rm-hull/git-commit-summary/internal/git" + llmprovider "github.com/rm-hull/git-commit-summary/internal/llm_provider" + "github.com/rm-hull/git-commit-summary/internal/ui" "github.com/spf13/cobra" ) @@ -35,16 +38,16 @@ func main() { ctx := context.Background() - application, err := app.NewApp(ctx, cfg) - if err != nil { - handleError(err) - } + provider, err := llmprovider.NewProvider(ctx, cfg) + handleError(err) + + application := app.NewApp(provider, &git.Client{}, &ui.Client{}, cfg.Prompt) handleError(application.Run(ctx, userMessage)) }, } - rootCmd.PersistentFlags().BoolP("version", "v", false, "Display version information") rootCmd.PersistentFlags().StringVarP(&userMessage, "message", "m", "", "Append a message to the commit summary") + rootCmd.PersistentFlags().BoolP("version", "v", false, "Display version information") rootCmd.PersistentFlags().StringVarP(&llmProvider, "llm-provider", "", cfg.LLMProvider, "Use specific LLM provider, overrides environment variable LLM_PROVIDER") _ = rootCmd.Execute()