diff --git a/.claude/settings.local.json b/.claude/settings.local.json index a7d68e20..de05ae46 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -8,6 +8,9 @@ "Bash(git commit:*)", "WebFetch(domain:expo.dev)", "WebSearch", + "Bash(find:*)", + "WebFetch(domain:github.com)", + "WebFetch(domain:raw.githubusercontent.com)", "Skill(superpowers:brainstorming)", "Bash(python3:*)" ], diff --git a/container/docs/docs.go b/container/docs/docs.go index 7533eb3c..5e8a5bbe 100644 --- a/container/docs/docs.go +++ b/container/docs/docs.go @@ -1315,6 +1315,78 @@ const docTemplate = `{ } } }, + "/v1/inference/status": { + "get": { + "description": "Check if local inference is available and get service information", + "produces": [ + "application/json" + ], + "tags": [ + "inference" + ], + "summary": "Get inference service status", + "responses": { + "200": { + "description": "Inference service status", + "schema": { + "$ref": "#/definitions/internal_handlers.InferenceStatusResponse" + } + } + } + } + }, + "/v1/inference/summarize": { + "post": { + "description": "Generate a short task summary and git branch name using local GGUF model", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "inference" + ], + "summary": "Summarize task and generate branch name", + "parameters": [ + { + "description": "Summarization request", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/internal_handlers.SummarizeRequest" + } + } + ], + "responses": { + "200": { + "description": "Successfully generated summary and branch name", + "schema": { + "$ref": "#/definitions/internal_handlers.SummarizeResponse" + } + }, + "400": { + "description": "Invalid request", + "schema": { + "$ref": "#/definitions/fiber.Map" + } + }, + "500": { + "description": "Inference error", + "schema": { + "$ref": "#/definitions/fiber.Map" + } + }, + "503": { + "description": "Inference not available on this 
platform", + "schema": { + "$ref": "#/definitions/fiber.Map" + } + } + } + } + }, "/v1/notifications": { "post": { "description": "Sends a notification event to all connected SSE clients, including the TUI app which can display native macOS notifications", @@ -1846,6 +1918,10 @@ const docTemplate = `{ } }, "definitions": { + "fiber.Map": { + "type": "object", + "additionalProperties": true + }, "github_com_vanpelt_catnip_internal_models.ClaudeActivityState": { "type": "string", "enum": [ @@ -3187,6 +3263,37 @@ const docTemplate = `{ } } }, + "internal_handlers.InferenceStatusResponse": { + "description": "Status of the local inference service", + "type": "object", + "properties": { + "architecture": { + "description": "Architecture (amd64, arm64)", + "type": "string", + "example": "arm64" + }, + "available": { + "description": "Whether inference is available on this platform", + "type": "boolean", + "example": true + }, + "error": { + "description": "Error message if initialization failed", + "type": "string", + "example": "model not found" + }, + "modelPath": { + "description": "Model path if loaded", + "type": "string", + "example": "/Users/user/.catnip/models/gemma3-270m-summarizer-Q4_K_M.gguf" + }, + "platform": { + "description": "Platform name (darwin, linux, windows)", + "type": "string", + "example": "darwin" + } + } + }, "internal_handlers.NotificationPayload": { "type": "object", "properties": { @@ -3225,6 +3332,33 @@ const docTemplate = `{ "$ref": "#/definitions/internal_handlers.ActiveSessionInfo" } }, + "internal_handlers.SummarizeRequest": { + "description": "Request to summarize a task and generate a branch name", + "type": "object", + "properties": { + "prompt": { + "description": "Task description or code changes to summarize", + "type": "string", + "example": "Add user authentication with OAuth2" + } + } + }, + "internal_handlers.SummarizeResponse": { + "description": "Response containing task summary and suggested branch name", + "type": 
"object", + "properties": { + "branchName": { + "description": "Git branch name in kebab-case with category prefix", + "type": "string", + "example": "feat/add-user-auth" + }, + "summary": { + "description": "2-4 word summary in Title Case", + "type": "string", + "example": "Add User Auth" + } + } + }, "internal_handlers.UploadResponse": { "description": "Response containing upload status and file location", "type": "object", diff --git a/container/docs/swagger.json b/container/docs/swagger.json index 1cbb0ab6..53f93ac1 100644 --- a/container/docs/swagger.json +++ b/container/docs/swagger.json @@ -1312,6 +1312,78 @@ } } }, + "/v1/inference/status": { + "get": { + "description": "Check if local inference is available and get service information", + "produces": [ + "application/json" + ], + "tags": [ + "inference" + ], + "summary": "Get inference service status", + "responses": { + "200": { + "description": "Inference service status", + "schema": { + "$ref": "#/definitions/internal_handlers.InferenceStatusResponse" + } + } + } + } + }, + "/v1/inference/summarize": { + "post": { + "description": "Generate a short task summary and git branch name using local GGUF model", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "inference" + ], + "summary": "Summarize task and generate branch name", + "parameters": [ + { + "description": "Summarization request", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/internal_handlers.SummarizeRequest" + } + } + ], + "responses": { + "200": { + "description": "Successfully generated summary and branch name", + "schema": { + "$ref": "#/definitions/internal_handlers.SummarizeResponse" + } + }, + "400": { + "description": "Invalid request", + "schema": { + "$ref": "#/definitions/fiber.Map" + } + }, + "500": { + "description": "Inference error", + "schema": { + "$ref": "#/definitions/fiber.Map" + } + }, + "503": { + "description": "Inference 
not available on this platform", + "schema": { + "$ref": "#/definitions/fiber.Map" + } + } + } + } + }, "/v1/notifications": { "post": { "description": "Sends a notification event to all connected SSE clients, including the TUI app which can display native macOS notifications", @@ -1843,6 +1915,10 @@ } }, "definitions": { + "fiber.Map": { + "type": "object", + "additionalProperties": true + }, "github_com_vanpelt_catnip_internal_models.ClaudeActivityState": { "type": "string", "enum": [ @@ -3184,6 +3260,37 @@ } } }, + "internal_handlers.InferenceStatusResponse": { + "description": "Status of the local inference service", + "type": "object", + "properties": { + "architecture": { + "description": "Architecture (amd64, arm64)", + "type": "string", + "example": "arm64" + }, + "available": { + "description": "Whether inference is available on this platform", + "type": "boolean", + "example": true + }, + "error": { + "description": "Error message if initialization failed", + "type": "string", + "example": "model not found" + }, + "modelPath": { + "description": "Model path if loaded", + "type": "string", + "example": "/Users/user/.catnip/models/gemma3-270m-summarizer-Q4_K_M.gguf" + }, + "platform": { + "description": "Platform name (darwin, linux, windows)", + "type": "string", + "example": "darwin" + } + } + }, "internal_handlers.NotificationPayload": { "type": "object", "properties": { @@ -3222,6 +3329,33 @@ "$ref": "#/definitions/internal_handlers.ActiveSessionInfo" } }, + "internal_handlers.SummarizeRequest": { + "description": "Request to summarize a task and generate a branch name", + "type": "object", + "properties": { + "prompt": { + "description": "Task description or code changes to summarize", + "type": "string", + "example": "Add user authentication with OAuth2" + } + } + }, + "internal_handlers.SummarizeResponse": { + "description": "Response containing task summary and suggested branch name", + "type": "object", + "properties": { + "branchName": { + 
"description": "Git branch name in kebab-case with category prefix", + "type": "string", + "example": "feat/add-user-auth" + }, + "summary": { + "description": "2-4 word summary in Title Case", + "type": "string", + "example": "Add User Auth" + } + } + }, "internal_handlers.UploadResponse": { "description": "Response containing upload status and file location", "type": "object", diff --git a/container/docs/swagger.yaml b/container/docs/swagger.yaml index 48cbe052..87be2c98 100644 --- a/container/docs/swagger.yaml +++ b/container/docs/swagger.yaml @@ -1,4 +1,7 @@ definitions: + fiber.Map: + additionalProperties: true + type: object github_com_vanpelt_catnip_internal_models.ClaudeActivityState: enum: - inactive @@ -1025,6 +1028,30 @@ definitions: example: feature/add-auth type: string type: object + internal_handlers.InferenceStatusResponse: + description: Status of the local inference service + properties: + architecture: + description: Architecture (amd64, arm64) + example: arm64 + type: string + available: + description: Whether inference is available on this platform + example: true + type: boolean + error: + description: Error message if initialization failed + example: model not found + type: string + modelPath: + description: Model path if loaded + example: /Users/user/.catnip/models/gemma3-270m-summarizer-Q4_K_M.gguf + type: string + platform: + description: Platform name (darwin, linux, windows) + example: darwin + type: string + type: object internal_handlers.NotificationPayload: properties: body: @@ -1050,6 +1077,26 @@ definitions: $ref: '#/definitions/internal_handlers.ActiveSessionInfo' description: Map of workspace paths to session information type: object + internal_handlers.SummarizeRequest: + description: Request to summarize a task and generate a branch name + properties: + prompt: + description: Task description or code changes to summarize + example: Add user authentication with OAuth2 + type: string + type: object + 
internal_handlers.SummarizeResponse: + description: Response containing task summary and suggested branch name + properties: + branchName: + description: Git branch name in kebab-case with category prefix + example: feat/add-user-auth + type: string + summary: + description: 2-4 word summary in Title Case + example: Add User Auth + type: string + type: object internal_handlers.UploadResponse: description: Response containing upload status and file location properties: @@ -2048,6 +2095,54 @@ paths: summary: Cleanup merged worktrees tags: - git + /v1/inference/status: + get: + description: Check if local inference is available and get service information + produces: + - application/json + responses: + "200": + description: Inference service status + schema: + $ref: '#/definitions/internal_handlers.InferenceStatusResponse' + summary: Get inference service status + tags: + - inference + /v1/inference/summarize: + post: + consumes: + - application/json + description: Generate a short task summary and git branch name using local GGUF + model + parameters: + - description: Summarization request + in: body + name: request + required: true + schema: + $ref: '#/definitions/internal_handlers.SummarizeRequest' + produces: + - application/json + responses: + "200": + description: Successfully generated summary and branch name + schema: + $ref: '#/definitions/internal_handlers.SummarizeResponse' + "400": + description: Invalid request + schema: + $ref: '#/definitions/fiber.Map' + "500": + description: Inference error + schema: + $ref: '#/definitions/fiber.Map' + "503": + description: Inference not available on this platform + schema: + $ref: '#/definitions/fiber.Map' + summary: Summarize task and generate branch name + tags: + - inference /v1/notifications: post: consumes: diff --git a/container/go.mod b/container/go.mod index 0ee1ac6c..8a63bcb2 100644 --- a/container/go.mod +++ b/container/go.mod @@ -19,6 +19,7 @@ require ( github.com/google/uuid v1.6.0 
github.com/gorilla/websocket v1.5.3 github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 + github.com/hybridgroup/yzma v0.9.0 github.com/rs/zerolog v1.34.0 github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 @@ -51,6 +52,7 @@ require ( github.com/cyphar/filepath-securejoin v0.4.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect + github.com/ebitengine/purego v0.9.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/fasthttp/websocket v1.5.12 // indirect @@ -67,6 +69,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/josharian/intern v1.0.0 // indirect + github.com/jupiterrider/ffi v0.5.1 // indirect github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect @@ -104,7 +107,7 @@ require ( golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect golang.org/x/mod v0.26.0 // indirect golang.org/x/sync v0.16.0 // indirect - golang.org/x/sys v0.34.0 // indirect + golang.org/x/sys v0.36.0 // indirect golang.org/x/text v0.27.0 // indirect golang.org/x/tools v0.35.0 // indirect google.golang.org/protobuf v1.36.6 // indirect diff --git a/container/go.sum b/container/go.sum index 667a242f..c6397e8e 100644 --- a/container/go.sum +++ b/container/go.sum @@ -101,6 +101,8 @@ github.com/disintegration/gift v1.2.1 h1:Y005a1X4Z7Uc+0gLpSAsKhWi4qLtsdEcMIbbdvd github.com/disintegration/gift v1.2.1/go.mod h1:Jh2i7f7Q2BM7Ezno3PhfezbR1xpUg9dUg3/RlKGr4HI= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A= 
+github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= @@ -186,6 +188,8 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 h1:AgcIVYPa6XJnU3phs104wLj8l5GEththEw6+F79YsIY= github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= +github.com/hybridgroup/yzma v0.9.0 h1:r0MHUpqvElcpgboci/FaGuq1Z52W4tG6StLiZ+hNIOk= +github.com/hybridgroup/yzma v0.9.0/go.mod h1:0j0lGvdDPSe+WnwmCQJWep37K6htvK+VN7lL9NHQ1V4= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= @@ -194,6 +198,8 @@ github.com/jdkato/prose v1.2.1 h1:Fp3UnJmLVISmlc57BgKUzdjr0lOtjqTZicL3PaYy6cU= github.com/jdkato/prose v1.2.1/go.mod h1:AiRHgVagnEx2JbQRQowVBKjG0bcs/vtkGCH1dYAL1rA= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/jupiterrider/ffi v0.5.1 h1:l7ANXU+Ex33LilVa283HNaf/sTzCrrht7D05k6T6nlc= +github.com/jupiterrider/ffi v0.5.1/go.mod h1:x7xdNKo8h0AmLuXfswDUBxUsd2OqUP4ekC8sCnsmbvo= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/klauspost/compress v1.18.0 
h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= @@ -355,8 +361,8 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= diff --git a/container/internal/cmd/download.go b/container/internal/cmd/download.go new file mode 100644 index 00000000..70510b93 --- /dev/null +++ b/container/internal/cmd/download.go @@ -0,0 +1,149 @@ +package cmd + +import ( + "fmt" + goruntime "runtime" + + "github.com/spf13/cobra" + "github.com/vanpelt/catnip/internal/logger" + "github.com/vanpelt/catnip/internal/services" +) + +var downloadCmd = &cobra.Command{ + Use: "download", + Short: "📦 Download inference dependencies", + Hidden: true, + Long: `Download llama.cpp libraries and GGUF model for local inference. 
+ +This command downloads: +- llama.cpp libraries for your platform (stored in ~/.catnip/lib) +- Gemma 270M summarizer model (stored in ~/.catnip/models) + +After running this command, inference will work offline without any additional downloads.`, + RunE: func(cmd *cobra.Command, args []string) error { + return runDownload(cmd) + }, +} + +func init() { + rootCmd.AddCommand(downloadCmd) + + // Add flags + downloadCmd.Flags().Bool("libraries-only", false, "Download only llama.cpp libraries") + downloadCmd.Flags().Bool("model-only", false, "Download only the GGUF model") + downloadCmd.Flags().Bool("force", false, "Force re-download even if files exist") +} + +func runDownload(cmd *cobra.Command) error { + librariesOnly, _ := cmd.Flags().GetBool("libraries-only") + modelOnly, _ := cmd.Flags().GetBool("model-only") + force, _ := cmd.Flags().GetBool("force") + + // Configure logging + logger.Configure(logger.LevelInfo, true) + + // Determine what to download + downloadLibraries := !modelOnly + downloadModel := !librariesOnly + + // Check platform support + if downloadLibraries { + if goruntime.GOOS != "darwin" && goruntime.GOOS != "linux" && goruntime.GOOS != "windows" { + logger.Warnf("⚠️ Inference not supported on %s, skipping library download", goruntime.GOOS) + downloadLibraries = false + } + } + + var libPath string + var modelPath string + + // Download libraries + if downloadLibraries { + logger.Infof("📚 Downloading llama.cpp libraries for %s/%s...", goruntime.GOOS, goruntime.GOARCH) + + downloader, err := services.NewLibraryDownloader() + if err != nil { + return fmt.Errorf("failed to create library downloader: %w", err) + } + + // Check if library exists + existingPath, _ := downloader.GetLibraryPath() + if existingPath != "" && !force { + logger.Infof("✅ Libraries already downloaded at: %s", existingPath) + logger.Infof(" Use --force to re-download") + libPath = existingPath + } else { + path, err := downloader.DownloadLibrary() + if err != nil { + return 
fmt.Errorf("failed to download libraries: %w", err) + } + libPath = path + logger.Infof("✅ Libraries installed at: %s", libPath) + } + } + + // Download model + if downloadModel { + logger.Infof("📦 Downloading GGUF model (Gemma 270M summarizer)...") + + downloader, err := services.NewModelDownloader() + if err != nil { + return fmt.Errorf("failed to create model downloader: %w", err) + } + + modelFilename := "gemma3-270m-summarizer-Q4_K_M.gguf" + modelURL := "https://huggingface.co/vanpelt/catnip-summarizer/resolve/main/gemma3-270m-summarizer-Q4_K_M.gguf" + + // Check if model exists + existingModelPath := downloader.GetModelPath(modelFilename) + if !force { + // Check if file exists and has reasonable size (> 100MB) + if info, err := services.StatFile(existingModelPath); err == nil && info.Size() > 100*1024*1024 { + logger.Infof("✅ Model already downloaded at: %s", existingModelPath) + logger.Infof(" Size: %.1f MB", float64(info.Size())/(1024*1024)) + logger.Infof(" Use --force to re-download") + modelPath = existingModelPath + } else { + // Model doesn't exist or is incomplete, download it + path, err := downloader.DownloadModel(modelURL, modelFilename, "") + if err != nil { + return fmt.Errorf("failed to download model: %w", err) + } + modelPath = path + + // Get file size for confirmation + if info, err := services.StatFile(modelPath); err == nil { + logger.Infof("✅ Model installed at: %s", modelPath) + logger.Infof(" Size: %.1f MB", float64(info.Size())/(1024*1024)) + } + } + } else { + // Force download + path, err := downloader.DownloadModel(modelURL, modelFilename, "") + if err != nil { + return fmt.Errorf("failed to download model: %w", err) + } + modelPath = path + + // Get file size for confirmation + if info, err := services.StatFile(modelPath); err == nil { + logger.Infof("✅ Model installed at: %s", modelPath) + logger.Infof(" Size: %.1f MB", float64(info.Size())/(1024*1024)) + } + } + } + + // Print summary + fmt.Println() + logger.Infof("🎉 Download 
complete!") + if downloadLibraries && libPath != "" { + logger.Infof(" Libraries: %s", libPath) + } + if downloadModel && modelPath != "" { + logger.Infof(" Model: %s", modelPath) + } + fmt.Println() + logger.Infof("💡 You can now use inference offline with 'catnip serve'") + + return nil +} diff --git a/container/internal/cmd/serve.go b/container/internal/cmd/serve.go index 416c9530..a0afc6e7 100644 --- a/container/internal/cmd/serve.go +++ b/container/internal/cmd/serve.go @@ -3,6 +3,7 @@ package cmd import ( "net/http/pprof" "os" + goruntime "runtime" "strings" "github.com/gofiber/fiber/v2" @@ -188,6 +189,25 @@ func startServer(cmd *cobra.Command) { claudeService.SetParserService(parserService) // For centralized session parsing parserService.SetClaudeService(claudeService) // For finding project directories + // Initialize inference service if enabled via CATNIP_INFERENCE=1 + var inferenceService *services.InferenceService + if os.Getenv("CATNIP_INFERENCE") == "1" { + inferenceConfig := services.InferenceConfig{ + ModelURL: "https://huggingface.co/vanpelt/catnip-summarizer/resolve/main/gemma3-270m-summarizer-Q4_K_M.gguf", + Checksum: "", // Optional checksum for verification + } + inferenceService = services.NewInferenceService(inferenceConfig) + + // Start background initialization (non-blocking) + go inferenceService.InitializeAsync() + + logger.Infof("🧠 Inference service enabled, downloading in background... 
(%s/%s)", goruntime.GOOS, goruntime.GOARCH) + } else { + logger.Debugf("🧠 Inference service disabled (set CATNIP_INFERENCE=1 to enable)") + } + + // Wire up SessionService to ClaudeService for best session file selection + claudeService.SetSessionService(sessionService) // Start parser service parserService.Start() @@ -242,6 +262,11 @@ func startServer(cmd *cobra.Command) { defer eventsHandler.Stop() portsHandler := handlers.NewPortsHandler(portMonitor).WithEvents(eventsHandler) proxyHandler := handlers.NewProxyHandler(portMonitor) + // Only create inference handler if service is enabled + var inferenceHandler *handlers.InferenceHandler + if inferenceService != nil { + inferenceHandler = handlers.NewInferenceHandler(inferenceService) + } // Connect events handler to GitService for worktree status events gitService.SetEventsHandler(eventsHandler) @@ -326,6 +351,12 @@ func startServer(cmd *cobra.Command) { v1.Post("/ports/mappings", portsHandler.SetPortMapping) v1.Delete("/ports/mappings/:port", portsHandler.DeletePortMapping) + // Inference routes (only if enabled via CATNIP_INFERENCE=1) + if inferenceHandler != nil { + v1.Post("/inference/summarize", inferenceHandler.HandleSummarize) + v1.Get("/inference/status", inferenceHandler.HandleInferenceStatus) + } + // Server info route v1.Get("/info", func(c *fiber.Ctx) error { commit, date, builtBy := GetBuildInfo() diff --git a/container/internal/cmd/summarize.go b/container/internal/cmd/summarize.go new file mode 100644 index 00000000..374c5444 --- /dev/null +++ b/container/internal/cmd/summarize.go @@ -0,0 +1,88 @@ +package cmd + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + "github.com/vanpelt/catnip/internal/logger" + "github.com/vanpelt/catnip/internal/services" +) + +var summarizeCmd = &cobra.Command{ + Use: "summarize [prompt]", + Short: "🧠 Generate task summary and branch name", + Hidden: true, + Long: `Generate a task summary and git branch name using local inference. 
+ +This command uses the local Gemma 270M model to generate: +- A concise 2-4 word task summary (Title Case) +- A git branch name (kebab-case with category prefix) + +The prompt can be provided as arguments or via the --prompt flag. + +Examples: + catnip summarize "Add user authentication with OAuth2" + catnip summarize --prompt "Fix login bug on mobile devices" + catnip summarize Add dark mode toggle to settings`, + RunE: func(cmd *cobra.Command, args []string) error { + return runSummarize(cmd, args) + }, +} + +func init() { + rootCmd.AddCommand(summarizeCmd) + + // Add flags + summarizeCmd.Flags().StringP("prompt", "p", "", "Task description to summarize") +} + +func runSummarize(cmd *cobra.Command, args []string) error { + // Configure logging (quieter for CLI usage) + logger.Configure(logger.LevelWarn, true) + + // Get prompt from flag or args + promptFlag, _ := cmd.Flags().GetString("prompt") + var prompt string + + if promptFlag != "" { + prompt = promptFlag + } else if len(args) > 0 { + prompt = strings.Join(args, " ") + } else { + return fmt.Errorf("prompt required: provide via arguments or --prompt flag") + } + + fmt.Printf("🧠 Generating summary for: %s\n\n", prompt) + + // Initialize inference service + inferenceConfig := services.InferenceConfig{ + ModelURL: "https://huggingface.co/vanpelt/catnip-summarizer/resolve/main/gemma3-270m-summarizer-Q4_K_M.gguf", + Checksum: "", + } + + inferenceService := services.NewInferenceService(inferenceConfig) + + // Run initialization synchronously for CLI usage + inferenceService.InitializeAsync() + + // Check if initialization succeeded + if !inferenceService.IsReady() { + state, message, _ := inferenceService.GetStatus() + return fmt.Errorf("failed to initialize inference service: %s (%s)\n\nTry running: catnip download", message, state) + } + + // Run inference + result, err := inferenceService.Summarize(prompt) + if err != nil { + return fmt.Errorf("inference failed: %w", err) + } + + // Print results + 
fmt.Println("📝 Summary:") + fmt.Printf(" %s\n\n", result.Summary) + fmt.Println("🌿 Branch name:") + fmt.Printf(" %s\n", result.BranchName) + + return nil +} diff --git a/container/internal/handlers/inference.go b/container/internal/handlers/inference.go new file mode 100644 index 00000000..e592e67e --- /dev/null +++ b/container/internal/handlers/inference.go @@ -0,0 +1,152 @@ +package handlers + +import ( + "fmt" + "runtime" + + "github.com/gofiber/fiber/v2" + "github.com/vanpelt/catnip/internal/logger" + "github.com/vanpelt/catnip/internal/services" +) + +// InferenceHandler handles local GGUF model inference requests +type InferenceHandler struct { + service *services.InferenceService +} + +// NewInferenceHandler creates a new inference handler +func NewInferenceHandler(service *services.InferenceService) *InferenceHandler { + return &InferenceHandler{ + service: service, + } +} + +// SummarizeRequest represents a summarization request +// @Description Request to summarize a task and generate a branch name +type SummarizeRequest struct { + // Task description or code changes to summarize + Prompt string `json:"prompt" example:"Add user authentication with OAuth2"` +} + +// SummarizeResponse represents a summarization response +// @Description Response containing task summary and suggested branch name +type SummarizeResponse struct { + // 2-4 word summary in Title Case + Summary string `json:"summary" example:"Add User Auth"` + // Git branch name in kebab-case with category prefix + BranchName string `json:"branchName" example:"feat/add-user-auth"` +} + +// InferenceStatusResponse represents the inference service status +// @Description Status of the local inference service +type InferenceStatusResponse struct { + // Whether inference is ready for requests + Available bool `json:"available" example:"true"` + // Current status: initializing, ready, failed + Status string `json:"status" example:"ready"` + // Human-readable status message + Message string 
`json:"message,omitempty" example:"Inference service ready"` + // Download progress (when initializing) + Progress *services.DownloadProgress `json:"progress,omitempty"` + // Platform name (darwin, linux, windows) + Platform string `json:"platform" example:"darwin"` + // Architecture (amd64, arm64) + Architecture string `json:"architecture" example:"arm64"` +} + +// HandleSummarize godoc +// @Summary Summarize task and generate branch name +// @Description Generate a short task summary and git branch name using local GGUF model +// @Tags inference +// @Accept json +// @Produce json +// @Param request body SummarizeRequest true "Summarization request" +// @Success 200 {object} SummarizeResponse "Successfully generated summary and branch name" +// @Failure 400 {object} fiber.Map "Invalid request" +// @Failure 500 {object} fiber.Map "Inference error" +// @Failure 503 {object} fiber.Map "Inference not available on this platform" +// @Router /v1/inference/summarize [post] +func (h *InferenceHandler) HandleSummarize(c *fiber.Ctx) error { + // Check if service is available and ready + if h.service == nil { + return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{ + "error": "Inference service not configured", + }) + } + + // Check if service is ready + if !h.service.IsReady() { + state, message, progress := h.service.GetStatus() + return c.Status(fiber.StatusServiceUnavailable).JSON(fiber.Map{ + "error": fmt.Sprintf("Inference service not ready: %s", message), + "status": string(state), + "progress": progress, + }) + } + + // Parse request + var req SummarizeRequest + if err := c.BodyParser(&req); err != nil { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "Invalid request body", + }) + } + + // Validate prompt + if req.Prompt == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "Prompt is required", + }) + } + + logger.Debugf("🧠 Inference request: %s", req.Prompt) + + // Generate summary + result, err := 
h.service.Summarize(req.Prompt) + if err != nil { + logger.Errorf("Inference error: %v", err) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "error": fmt.Sprintf("Failed to generate summary: %v", err), + }) + } + + logger.Debugf("✅ Inference result: summary=%s, branch=%s", result.Summary, result.BranchName) + + return c.JSON(SummarizeResponse{ + Summary: result.Summary, + BranchName: result.BranchName, + }) +} + +// HandleInferenceStatus godoc +// @Summary Get inference service status +// @Description Check if local inference is available and get service information +// @Tags inference +// @Produce json +// @Success 200 {object} InferenceStatusResponse "Inference service status" +// @Router /v1/inference/status [get] +func (h *InferenceHandler) HandleInferenceStatus(c *fiber.Ctx) error { + resp := InferenceStatusResponse{ + Platform: runtime.GOOS, + Architecture: runtime.GOARCH, + } + + if h.service == nil { + resp.Available = false + resp.Status = "disabled" + resp.Message = "Inference service not configured" + return c.JSON(resp) + } + + state, message, progress := h.service.GetStatus() + resp.Available = h.service.IsReady() + resp.Status = string(state) + resp.Message = message + + // Include progress if still initializing + if state == services.InferenceStateInitializing { + resp.Progress = &progress + } + + return c.JSON(resp) +} diff --git a/container/internal/services/downloader.go b/container/internal/services/downloader.go new file mode 100644 index 00000000..92259723 --- /dev/null +++ b/container/internal/services/downloader.go @@ -0,0 +1,456 @@ +package services + +import ( + "archive/zip" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" +) + +// ModelDownloader handles downloading and verifying GGUF models +type ModelDownloader struct { + cacheDir string +} + +// NewModelDownloader creates a new model downloader instance +func NewModelDownloader() (*ModelDownloader, 
error) {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return nil, fmt.Errorf("failed to get home directory: %w", err)
	}

	cacheDir := filepath.Join(homeDir, ".catnip", "models")
	if err := os.MkdirAll(cacheDir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create cache directory: %w", err)
	}

	return &ModelDownloader{
		cacheDir: cacheDir,
	}, nil
}

// DownloadModel downloads a model from the given URL to the cache directory.
// Returns the path to the downloaded model file. An already-cached file is
// reused when it passes checksum verification (or when no checksum is given);
// otherwise it is re-downloaded via a temporary file and atomically renamed.
func (d *ModelDownloader) DownloadModel(url, filename, expectedChecksum string) (string, error) {
	destPath := filepath.Join(d.cacheDir, filename)

	// Fast path: file already present and (if verifiable) valid.
	if _, err := os.Stat(destPath); err == nil {
		if expectedChecksum != "" {
			if valid, err := d.verifyChecksum(destPath, expectedChecksum); err == nil && valid {
				return destPath, nil
			}
			// Corrupt or unreadable cache entry: fall through and re-download.
		} else {
			// No checksum to verify, assume file is good
			return destPath, nil
		}
	}

	// Download to a temporary file so a partial download never shadows destPath.
	tmpPath := destPath + ".tmp"
	if err := d.downloadFile(url, tmpPath); err != nil {
		os.Remove(tmpPath) // Clean up on error
		return "", fmt.Errorf("failed to download model: %w", err)
	}

	// Verify checksum if provided
	if expectedChecksum != "" {
		valid, err := d.verifyChecksum(tmpPath, expectedChecksum)
		if err != nil {
			os.Remove(tmpPath)
			return "", fmt.Errorf("failed to verify checksum: %w", err)
		}
		if !valid {
			os.Remove(tmpPath)
			return "", fmt.Errorf("checksum verification failed")
		}
	}

	// Atomic rename into the final location.
	if err := os.Rename(tmpPath, destPath); err != nil {
		os.Remove(tmpPath)
		return "", fmt.Errorf("failed to save model: %w", err)
	}

	return destPath, nil
}

// downloadFile downloads a file from the given URL with progress reporting.
//
// FIX: issue the GET (and check the HTTP status) BEFORE creating destPath.
// The previous order created the file first, so a failed request or a
// non-200 response left an empty file behind on disk.
func (d *ModelDownloader) downloadFile(url, destPath string) error {
	// Get the data
	resp, err := http.Get(url) //nolint:gosec // URL comes from trusted config
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Check server response
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("bad status: %s", resp.Status)
	}

	// Only create the local file once the server has answered OK.
	out, err := os.Create(destPath)
	if err != nil {
		return err
	}
	defer out.Close()

	// Wrap the body so progress is printed as the copy proceeds.
	reader := &progressReader{
		reader: resp.Body,
		total:  resp.ContentLength,
		onProgress: func(current, total int64) {
			if total > 0 {
				percent := float64(current) / float64(total) * 100
				fmt.Printf("\rDownloading model: %.1f%% (%d/%d MB)",
					percent,
					current/(1024*1024),
					total/(1024*1024))
			}
		},
	}

	// Write the body to file
	if _, err := io.Copy(out, reader); err != nil {
		return err
	}

	fmt.Println() // New line after progress
	return nil
}

// verifyChecksum verifies the SHA256 checksum of a file.
// The comparison is case-insensitive: hex.EncodeToString emits lower-case
// digits, while configured digests may be upper-case.
func (d *ModelDownloader) verifyChecksum(filePath, expectedChecksum string) (bool, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return false, err
	}
	defer file.Close()

	hash := sha256.New()
	if _, err := io.Copy(hash, file); err != nil {
		return false, err
	}

	actualChecksum := hex.EncodeToString(hash.Sum(nil))
	return strings.EqualFold(actualChecksum, expectedChecksum), nil
}

// GetModelPath returns the path where a model with the given filename would be stored
func (d *ModelDownloader) GetModelPath(filename string) string {
	return filepath.Join(d.cacheDir, filename)
}

// progressReader wraps an io.Reader to report progress
type progressReader struct {
	reader     io.Reader
	total      int64
	current    int64
	onProgress func(current, total int64)
}

// Read forwards to the wrapped reader and reports cumulative progress.
func (pr *progressReader) Read(p []byte) (int, error) {
	n, err := pr.reader.Read(p)
	pr.current += int64(n)
	if pr.onProgress != nil {
		pr.onProgress(pr.current, pr.total)
	}
	return n, err
}

// LibraryDownloader handles downloading llama.cpp libraries
+type LibraryDownloader struct { + libDir string +} + +// NewLibraryDownloader creates a new library downloader instance +func NewLibraryDownloader() (*LibraryDownloader, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + + libDir := filepath.Join(homeDir, ".catnip", "lib") + if err := os.MkdirAll(libDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create lib directory: %w", err) + } + + return &LibraryDownloader{ + libDir: libDir, + }, nil +} + +// DownloadLibrary downloads the llama.cpp library for the current platform +// Returns the path to the main library file (libllama.dylib, libllama.so, etc.) +func (d *LibraryDownloader) DownloadLibrary() (string, error) { + // Determine platform-specific details + osName, archName, libExt, err := d.getPlatformInfo() + if err != nil { + return "", err + } + + // Check if library already exists + libPath := filepath.Join(d.libDir, osName, archName, "libllama"+libExt) + if _, err := os.Stat(libPath); err == nil { + // Library already exists + return libPath, nil + } + + // Get latest llama.cpp release info + releaseTag, downloadURL, err := d.getLlamaCppRelease(osName, archName) + if err != nil { + return "", fmt.Errorf("failed to get llama.cpp release: %w", err) + } + + fmt.Printf("📦 Downloading llama.cpp %s for %s/%s...\n", releaseTag, osName, archName) + + // Download archive + tmpFile := filepath.Join(d.libDir, "llama-cpp-tmp.zip") + if err := d.downloadFileWithProgress(downloadURL, tmpFile); err != nil { + os.Remove(tmpFile) + return "", fmt.Errorf("failed to download library: %w", err) + } + defer os.Remove(tmpFile) + + // Extract to platform-specific directory + extractDir := filepath.Join(d.libDir, osName, archName) + if err := os.MkdirAll(extractDir, 0755); err != nil { + return "", fmt.Errorf("failed to create extract directory: %w", err) + } + + if err := d.extractZip(tmpFile, extractDir); err != nil { + return "", 
fmt.Errorf("failed to extract archive: %w", err) + } + + fmt.Println("✅ llama.cpp libraries installed successfully") + + // Return path to main library + return libPath, nil +} + +// getPlatformInfo returns OS name, architecture, and library extension for the current platform +func (d *LibraryDownloader) getPlatformInfo() (osName, archName, libExt string, err error) { + switch runtime.GOOS { + case "darwin": + osName = "macos" + libExt = ".dylib" + case "linux": + osName = "ubuntu" // llama.cpp releases use "ubuntu" for Linux + libExt = ".so" + case "windows": + osName = "win" + libExt = ".dll" + default: + return "", "", "", fmt.Errorf("unsupported OS: %s", runtime.GOOS) + } + + switch runtime.GOARCH { + case "amd64": + archName = "x64" + case "arm64": + archName = "arm64" + default: + return "", "", "", fmt.Errorf("unsupported architecture: %s", runtime.GOARCH) + } + + return osName, archName, libExt, nil +} + +// getLlamaCppRelease fetches the latest llama.cpp release info from GitHub +func (d *LibraryDownloader) getLlamaCppRelease(osName, archName string) (tag, downloadURL string, err error) { + // Get latest release from GitHub API + resp, err := http.Get("https://api.github.com/repos/ggml-org/llama.cpp/releases/latest") + if err != nil { + return "", "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("GitHub API returned status: %s", resp.Status) + } + + // Parse response to find the right asset + // We're looking for patterns like: + // - llama-{tag}-bin-macos-arm64.zip + // - llama-{tag}-bin-ubuntu-x64.zip + // - llama-{tag}-bin-win-cpu-x64.zip + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", "", err + } + + bodyStr := string(body) + + // Extract tag_name + tagStart := strings.Index(bodyStr, `"tag_name":"`) + if tagStart == -1 { + return "", "", fmt.Errorf("could not find tag_name in GitHub response") + } + tagStart += len(`"tag_name":"`) + tagEnd := 
strings.Index(bodyStr[tagStart:], `"`) + if tagEnd == -1 { + return "", "", fmt.Errorf("could not parse tag_name") + } + tag = bodyStr[tagStart : tagStart+tagEnd] + + // Build expected filename pattern + var pattern string + switch osName { + case "macos": + pattern = fmt.Sprintf("llama-%s-bin-macos-%s.zip", tag, archName) + case "ubuntu": + pattern = fmt.Sprintf("llama-%s-bin-ubuntu-%s.zip", tag, archName) + case "win": + pattern = fmt.Sprintf("llama-%s-bin-win-cpu-%s.zip", tag, archName) + } + + // Find download URL in browser_download_url fields + searchStr := fmt.Sprintf(`"browser_download_url":"https://github.com/ggml-org/llama.cpp/releases/download/%s/%s"`, tag, pattern) + urlStart := strings.Index(bodyStr, searchStr) + if urlStart == -1 { + return "", "", fmt.Errorf("could not find download URL for %s", pattern) + } + + urlStart += len(`"browser_download_url":"`) + urlEnd := strings.Index(bodyStr[urlStart:], `"`) + downloadURL = bodyStr[urlStart : urlStart+urlEnd] + + return tag, downloadURL, nil +} + +// downloadFileWithProgress downloads a file with progress reporting +func (d *LibraryDownloader) downloadFileWithProgress(url, destPath string) error { + out, err := os.Create(destPath) + if err != nil { + return err + } + defer out.Close() + + resp, err := http.Get(url) //nolint:gosec // URL from trusted GitHub API + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bad status: %s", resp.Status) + } + + totalBytes := resp.ContentLength + reader := &progressReader{ + reader: resp.Body, + total: totalBytes, + onProgress: func(current, total int64) { + if total > 0 { + percent := float64(current) / float64(total) * 100 + fmt.Printf("\rDownloading: %.1f%% (%d/%d MB)", + percent, + current/(1024*1024), + total/(1024*1024)) + } + }, + } + + _, err = io.Copy(out, reader) + if err != nil { + return err + } + + fmt.Println() // New line after progress + return nil +} + +// extractZip extracts a 
zip file to the destination directory +func (d *LibraryDownloader) extractZip(zipPath, destDir string) error { + reader, err := zip.OpenReader(zipPath) + if err != nil { + return err + } + defer reader.Close() + + for _, file := range reader.File { + // Only extract files in build/bin/ directory (where the libraries are) + if !strings.Contains(file.Name, "build/bin/") { + continue + } + + // Get the filename relative to build/bin/ + parts := strings.Split(file.Name, "build/bin/") + if len(parts) != 2 { + continue + } + filename := parts[1] + + // Skip directories and non-library files + if file.FileInfo().IsDir() || filename == "" { + continue + } + + // Only extract .dylib, .so, .dll files + if !strings.HasSuffix(filename, ".dylib") && + !strings.HasSuffix(filename, ".so") && + !strings.HasSuffix(filename, ".dll") { + continue + } + + // Create destination path + destPath := filepath.Join(destDir, filename) + + // Extract file + if err := d.extractFile(file, destPath); err != nil { + return fmt.Errorf("failed to extract %s: %w", filename, err) + } + } + + return nil +} + +// extractFile extracts a single file from a zip archive +func (d *LibraryDownloader) extractFile(file *zip.File, destPath string) error { + // Open source file + src, err := file.Open() + if err != nil { + return err + } + defer src.Close() + + // Create destination file + dest, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode()) + if err != nil { + return err + } + defer dest.Close() + + // Copy contents + // We only extract .dylib/.so/.dll files from trusted GitHub releases + _, err = io.Copy(dest, src) //nolint:gosec // Trusted source (GitHub llama.cpp releases) + return err +} + +// GetLibraryPath returns the path where the library for the current platform would be stored +func (d *LibraryDownloader) GetLibraryPath() (string, error) { + osName, archName, libExt, err := d.getPlatformInfo() + if err != nil { + return "", err + } + + return filepath.Join(d.libDir, 
osName, archName, "libllama"+libExt), nil +} + +// StatFile is a helper function to get file info (exported for use in cmd package) +func StatFile(path string) (os.FileInfo, error) { + return os.Stat(path) +} diff --git a/container/internal/services/inference.go b/container/internal/services/inference.go new file mode 100644 index 00000000..1b4b46ae --- /dev/null +++ b/container/internal/services/inference.go @@ -0,0 +1,446 @@ +package services + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/hybridgroup/yzma/pkg/llama" + "github.com/vanpelt/catnip/internal/logger" +) + +// InferenceState represents the current state of the inference service +type InferenceState string + +// Inference service states +const ( + InferenceStateInitializing InferenceState = "initializing" + InferenceStateReady InferenceState = "ready" + InferenceStateFailed InferenceState = "failed" + InferenceStateDisabled InferenceState = "disabled" +) + +// DownloadProgress tracks the progress of library and model downloads +type DownloadProgress struct { + LibraryPercent int `json:"library"` + ModelPercent int `json:"model"` + CurrentStep string `json:"step"` // "library", "model", "loading" +} + +// InferenceService handles local GGUF model inference using llama.cpp +type InferenceService struct { + modelPath string + libraryPath string + model llama.Model + mu sync.Mutex + initialized bool + + // State management + state atomic.Value // InferenceState + stateMessage string + stateMu sync.RWMutex + progress DownloadProgress + progressMu sync.RWMutex + + // Configuration for async init + config InferenceConfig + maxRetries int +} + +// InferenceConfig holds configuration for the inference service +type InferenceConfig struct { + ModelPath string + LibraryPath string + ModelURL string + Checksum string +} + +// NewInferenceService creates a new inference service instance (non-blocking) +func NewInferenceService(config InferenceConfig) 
*InferenceService { + svc := &InferenceService{ + modelPath: config.ModelPath, + libraryPath: config.LibraryPath, + config: config, + maxRetries: 3, + } + svc.state.Store(InferenceStateInitializing) + svc.setStateMessage("Waiting to start initialization...") + return svc +} + +// InitializeAsync starts the background initialization process +func (s *InferenceService) InitializeAsync() { + var lastErr error + + for attempt := 1; attempt <= s.maxRetries; attempt++ { + if attempt > 1 { + // Exponential backoff: 2s, 4s, 8s + backoff := time.Duration(1<user\n{{ $.System }}\n{{ .Content }}\nmodel\n + fullPrompt := "user\n" + systemPrompt + "\n" + prompt + "\nmodel\n" + + // Tokenize the formatted prompt + // CRITICAL FIX: Must add special tokens (BOS) for Gemma to work correctly + addSpecial := true + parseSpecial := true + tokens := llama.Tokenize(vocab, fullPrompt, addSpecial, parseSpecial) + + // Create batch + batch := llama.BatchGetOne(tokens) + + // Setup sampler chain with parameters from Modelfile + samplerParams := llama.SamplerChainDefaultParams() + sampler := llama.SamplerChainInit(samplerParams) + defer llama.SamplerFree(sampler) // Clean up sampler when done + + // Add samplers matching llama.cpp's common_sampler_init order + // Correct order: TOP_K → TOP_P → TYPICAL_P → TEMPERATURE → PENALTIES → Dist + llama.SamplerChainAdd(sampler, llama.SamplerInitTopK(64)) // top_k=64 (from Modelfile) + llama.SamplerChainAdd(sampler, llama.SamplerInitTopP(0.95, 1)) // top_p=0.95 (from Modelfile) + llama.SamplerChainAdd(sampler, llama.SamplerInitTypical(1.0, 1)) // typical_p=1.0 (Ollama default, min_keep=1) + llama.SamplerChainAdd(sampler, llama.SamplerInitTempExt(0.8, 0.0, 1.0)) // temp=0.8 (Ollama default) + llama.SamplerChainAdd(sampler, llama.SamplerInitPenalties(64, 1.1, 0.0, 0.0)) // repeat_penalty=1.1, repeat_last_n=64 + + // Use random seed for variability (Ollama generates new seed per request) + seed := uint32(time.Now().UnixMicro() & 0xFFFFFFFF) //nolint:gosec 
// Safe: intentional truncation for seed + llama.SamplerChainAdd(sampler, llama.SamplerInitDist(seed)) + + // Generate tokens + maxTokens := int32(128) // Limit generation + var output strings.Builder + buf := make([]byte, 36) // Buffer for token text + newlineCount := 0 + + for pos := int32(0); pos < maxTokens; pos++ { + // Decode batch + llama.Decode(ctx, batch) + + // Sample next token + token := llama.SamplerSample(sampler, ctx, -1) + + // Check for end of generation (EOS token) + if llama.VocabIsEOG(vocab, token) { + break + } + + // Convert token to text + tokenLen := llama.TokenToPiece(vocab, token, buf, 0, true) + if tokenLen > 0 { + output.Write(buf[:tokenLen]) + } + + // Check for stop sequences + currentOutput := output.String() + + // Stop at (from Modelfile) + if strings.Contains(currentOutput, "") { + // Remove the stop sequence from output + currentOutput = strings.Split(currentOutput, "")[0] + output.Reset() + output.WriteString(currentOutput) + break + } + + // Count newlines - stop after we have 2 complete lines + // (We want exactly: Line1\nLine2\n) + if tokenLen > 0 && buf[0] == '\n' { + newlineCount++ + // Stop after 2 newlines (which gives us 2 lines of content) + if newlineCount >= 2 { + break + } + } + + // Create next batch with single token + batch = llama.BatchGetOne([]llama.Token{token}) + } + + // Get raw output + rawOutput := output.String() + + // Parse output into summary and branch name + return s.parseOutput(rawOutput) +} + +// parseOutput parses the model output into summary and branch name +func (s *InferenceService) parseOutput(output string) (*SummarizeResponse, error) { + lines := strings.Split(strings.TrimSpace(output), "\n") + + // Find first two non-empty lines + var summary, branchName string + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if summary == "" { + summary = line + } else if branchName == "" { + branchName = line + break + } + } + + if summary == "" || 
branchName == "" { + return nil, fmt.Errorf("invalid output format: expected 2 lines, got: %s", output) + } + + return &SummarizeResponse{ + Summary: summary, + BranchName: branchName, + }, nil +} + +// Close frees resources +func (s *InferenceService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Note: yzma doesn't expose model cleanup in current API + s.initialized = false + return nil +} diff --git a/container/internal/services/stderr_unix.go b/container/internal/services/stderr_unix.go new file mode 100644 index 00000000..a99086d5 --- /dev/null +++ b/container/internal/services/stderr_unix.go @@ -0,0 +1,70 @@ +//go:build unix + +package services + +import ( + "os" + "sync" + + "golang.org/x/sys/unix" +) + +// Stderr redirection state +var ( + savedStderrFd = -1 + stderrSuppressed bool + suppressMutex sync.Mutex +) + +// suppressStderr redirects stderr (fd 2) to /dev/null to silence llama.cpp's verbose output +func suppressStderr() { + suppressMutex.Lock() + defer suppressMutex.Unlock() + + if stderrSuppressed { + return + } + + // Open /dev/null + devNull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + if err != nil { + return // If we can't open /dev/null, just continue with normal stderr + } + + // Save the original stderr file descriptor by duplicating it + savedStderrFd, err = unix.Dup(int(os.Stderr.Fd())) + if err != nil { + devNull.Close() + return + } + + // Redirect stderr (fd 2) to /dev/null using dup2 + // unix.Dup2 works on all Unix platforms including Linux arm64 + err = unix.Dup2(int(devNull.Fd()), int(os.Stderr.Fd())) + if err != nil { + unix.Close(savedStderrFd) + devNull.Close() + return + } + + devNull.Close() // We can close devNull now, the fd is duplicated to stderr + stderrSuppressed = true +} + +// restoreStderr restores the original stderr file descriptor +func restoreStderr() { + suppressMutex.Lock() + defer suppressMutex.Unlock() + + if !stderrSuppressed || savedStderrFd < 0 { + return + } + + // Restore stderr by 
duplicating the saved fd back to fd 2 + _ = unix.Dup2(savedStderrFd, int(os.Stderr.Fd())) + + // Close the saved fd + unix.Close(savedStderrFd) + savedStderrFd = -1 + stderrSuppressed = false +} diff --git a/container/internal/tui/initialization_commands.go b/container/internal/tui/initialization_commands.go index d209e979..91c40b20 100644 --- a/container/internal/tui/initialization_commands.go +++ b/container/internal/tui/initialization_commands.go @@ -56,10 +56,14 @@ func semverCompare(a, b string) int { } parts := strings.Split(s, ".") result := make([]int, 3) - for i := 0; i < 3 && i < len(parts); i++ { + // Limit to first 3 parts + if len(parts) > 3 { + parts = parts[:3] + } + for i, part := range parts { // best-effort parse n := 0 - for _, ch := range parts[i] { + for _, ch := range part { if ch >= '0' && ch <= '9' { n = n*10 + int(ch-'0') } else {