Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions backend/pkg/controller/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ type assistantWorkerCtx struct {
flowID int64
fw FlowWorker

// prompter is an optional reusable prompter for the user. When non-nil,
// LoadAssistantWorker uses it instead of issuing another GetUserPrompts
// query and merging defaults. LoadFlowWorker sets this so a flow load
// with multiple assistants only pays the DB+merge cost once.
prompter templates.Prompter

flowWorkerCtx
}

Expand Down Expand Up @@ -158,7 +164,10 @@ func NewAssistantWorker(ctx context.Context, awc newAssistantWorkerCtx) (Assista
return nil, wrapErrorEndSpan(ctx, assistantSpan, "failed to create flow assistant log worker", err)
}

prompter := templates.NewDefaultPrompter() // TODO: change to flow prompter by userID from DB
prompter, err := newUserPrompter(ctx, awc.db, awc.userID)
if err != nil {
return nil, wrapErrorEndSpan(ctx, assistantSpan, "failed to build user prompter", err)
}
executor, err := tools.NewFlowToolsExecutor(awc.db, awc.cfg, awc.docker, awc.functions, awc.userID, awc.flowID)
if err != nil {
return nil, wrapErrorEndSpan(ctx, assistantSpan, "failed to create flow tools executor", err)
Expand Down Expand Up @@ -327,7 +336,13 @@ func LoadAssistantWorker(
return nil, wrapErrorEndSpan(ctx, assistantSpan, "failed to create flow assistant log worker", err)
}

prompter := templates.NewDefaultPrompter() // TODO: change to flow prompter by userID from DB
prompter := awc.prompter
if prompter == nil {
prompter, err = newUserPrompter(ctx, awc.db, awc.userID)
if err != nil {
return nil, wrapErrorEndSpan(ctx, assistantSpan, "failed to build user prompter", err)
}
}
executor, err := tools.NewFlowToolsExecutor(awc.db, awc.cfg, awc.docker, functions, awc.userID, awc.flowID)
if err != nil {
return nil, wrapErrorEndSpan(ctx, assistantSpan, "failed to create flow tools executor", err)
Expand Down
12 changes: 9 additions & 3 deletions backend/pkg/controller/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"pentagi/pkg/providers/pconfig"
"pentagi/pkg/providers/provider"
"pentagi/pkg/resources"
"pentagi/pkg/templates"
"pentagi/pkg/tools"

dockercontainer "github.com/docker/docker/api/types/container"
Expand Down Expand Up @@ -181,7 +180,10 @@ func NewFlowWorker(
flowSpan := observation.Span(langfuse.WithSpanName("prepare flow worker"))
ctx, _ = flowSpan.Observation(ctx)

prompter := templates.NewDefaultPrompter() // TODO: change to flow prompter by userID from DB
prompter, err := newUserPrompter(ctx, fwc.db, fwc.userID)
if err != nil {
return nil, wrapErrorEndSpan(ctx, flowSpan, "failed to build user prompter", err)
}
executor, err := tools.NewFlowToolsExecutor(fwc.db, fwc.cfg, fwc.docker, fwc.functions, fwc.userID, flow.ID)
if err != nil {
return nil, wrapErrorEndSpan(ctx, flowSpan, "failed to create flow tools executor", err)
Expand Down Expand Up @@ -352,7 +354,10 @@ func LoadFlowWorker(ctx context.Context, flow database.Flow, fwc flowWorkerCtx)
return nil, wrapErrorEndSpan(ctx, flowSpan, "failed to unmarshal functions", err)
}

prompter := templates.NewDefaultPrompter() // TODO: change to flow prompter by userID from DB
prompter, err := newUserPrompter(ctx, fwc.db, flow.UserID)
if err != nil {
return nil, wrapErrorEndSpan(ctx, flowSpan, "failed to build user prompter", err)
}
Comment thread
mason5052 marked this conversation as resolved.
executor, err := tools.NewFlowToolsExecutor(fwc.db, fwc.cfg, fwc.docker, functions, flow.UserID, flow.ID)
if err != nil {
return nil, wrapErrorEndSpan(ctx, flowSpan, "failed to create flow tools executor", err)
Expand Down Expand Up @@ -445,6 +450,7 @@ func LoadFlowWorker(ctx context.Context, flow database.Flow, fwc flowWorkerCtx)
awc := assistantWorkerCtx{
userID: flow.UserID,
flowID: flow.ID,
prompter: prompter,
fw: fw,
flowWorkerCtx: fwc,
}
Expand Down
50 changes: 50 additions & 0 deletions backend/pkg/controller/prompter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package controller

import (
"context"
"fmt"

"pentagi/pkg/database"
"pentagi/pkg/templates"
)

// newUserPrompter loads the user's custom prompts from the database and
// overlays them onto the compiled default templates. Prompt types that
// the user has not customized continue to use the defaults. A database
// error is returned to the caller so that session creation fails
// explicitly instead of silently falling back to defaults.
func newUserPrompter(ctx context.Context, db database.Querier, userID int64) (templates.Prompter, error) {
userPrompts, err := db.GetUserPrompts(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to load user prompts: %w", err)
}

defaults, err := templates.LoadDefaultPromptsMap()
if err != nil {
return nil, fmt.Errorf("failed to load default templates: %w", err)
}

return buildUserPrompter(defaults, userPrompts), nil
}

// buildUserPrompter is the pure merge step extracted from newUserPrompter so
// it can be unit-tested without a database fake or filesystem access. It
// mutates the supplied defaults map by overlaying each non-empty user
// override on top, then returns a Prompter backed by that map. Callers must
// pass a fresh map (e.g., from templates.LoadDefaultPromptsMap) so the
// embedded defaults are not modified.
func buildUserPrompter(defaults templates.PromptsMap, userPrompts []database.Prompt) templates.Prompter {
for _, p := range userPrompts {
if p.Prompt == "" {
// The Prompts UI uses delete (or reset, which writes the
// default body back) to remove a customization, so an empty
// body is unexpected. Skip it instead of clobbering the
// default with an empty string that would later surface as
// ErrTemplateNotFound deep inside agent rendering.
continue
}
defaults[templates.PromptType(p.Type)] = p.Prompt
}

return templates.NewFlowPrompter(defaults)
}
209 changes: 209 additions & 0 deletions backend/pkg/controller/prompter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
package controller

import (
"context"
"errors"
"testing"

"pentagi/pkg/database"
"pentagi/pkg/templates"
)

// fakeQuerier satisfies database.Querier by embedding the interface
// (so the unused methods stay nil) and overrides only GetUserPrompts.
// Calling any other method would panic, which is exactly what we want
// for a unit test that should not touch unrelated DB code.
type fakeQuerier struct {
database.Querier
prompts []database.Prompt
err error
}

func (f *fakeQuerier) GetUserPrompts(ctx context.Context, userID int64) ([]database.Prompt, error) {
if f.err != nil {
return nil, f.err
}
return f.prompts, nil
}

// getDefaultTemplate returns the compiled default body for a prompt
// type, used to compare overridden vs. preserved prompts in tests.
func getDefaultTemplate(t *testing.T, pt templates.PromptType) string {
t.Helper()
body, err := templates.NewDefaultPrompter().GetTemplate(pt)
if err != nil {
t.Fatalf("default prompter has no template for %q: %v", pt, err)
}
return body
}

// loadDefaults returns a fresh PromptsMap of the embedded default templates so
// each test mutates its own map without leaking state between cases.
func loadDefaults(t *testing.T) templates.PromptsMap {
t.Helper()
defaults, err := templates.LoadDefaultPromptsMap()
if err != nil {
t.Fatalf("LoadDefaultPromptsMap error: %v", err)
}
return defaults
}

func TestBuildUserPrompter_NoUserPrompts(t *testing.T) {
prompter := buildUserPrompter(loadDefaults(t), nil)

for _, pt := range []templates.PromptType{
templates.PromptTypePrimaryAgent,
templates.PromptTypeAssistant,
} {
got, err := prompter.GetTemplate(pt)
if err != nil {
t.Fatalf("GetTemplate(%q) error: %v", pt, err)
}
want := getDefaultTemplate(t, pt)
if got != want {
t.Errorf("prompt %q: expected default body when user has no overrides, got divergent body", pt)
}
}
}

func TestBuildUserPrompter_SingleOverride(t *testing.T) {
const customBody = "custom primary agent prompt body"
userPrompts := []database.Prompt{
{Type: database.PromptTypePrimaryAgent, Prompt: customBody},
}

prompter := buildUserPrompter(loadDefaults(t), userPrompts)

got, err := prompter.GetTemplate(templates.PromptTypePrimaryAgent)
if err != nil {
t.Fatalf("GetTemplate(primary_agent) error: %v", err)
}
if got != customBody {
t.Errorf("primary_agent: expected custom body %q, got %q", customBody, got)
}

// Spot-check that an unrelated type still resolves to the default.
for _, pt := range []templates.PromptType{
templates.PromptTypeAssistant,
templates.PromptTypePentester,
} {
got, err := prompter.GetTemplate(pt)
if err != nil {
t.Fatalf("GetTemplate(%q) error: %v", pt, err)
}
if got != getDefaultTemplate(t, pt) {
t.Errorf("prompt %q should still match default after a single unrelated override", pt)
}
}
}

func TestBuildUserPrompter_PartialOverridesPreserveDefaults(t *testing.T) {
overrides := map[database.PromptType]string{
database.PromptTypePrimaryAgent: "custom primary",
database.PromptTypeCoder: "custom coder",
database.PromptTypeReporter: "custom reporter",
}

userPrompts := make([]database.Prompt, 0, len(overrides))
for pt, body := range overrides {
userPrompts = append(userPrompts, database.Prompt{Type: pt, Prompt: body})
}

prompter := buildUserPrompter(loadDefaults(t), userPrompts)

for pt, want := range overrides {
got, err := prompter.GetTemplate(templates.PromptType(pt))
if err != nil {
t.Fatalf("GetTemplate(%q) error: %v", pt, err)
}
if got != want {
t.Errorf("prompt %q: expected custom body %q, got %q", pt, want, got)
}
}

// Types that were not overridden must still match the defaults.
for _, pt := range []templates.PromptType{
templates.PromptTypeAssistant,
templates.PromptTypePentester,
templates.PromptTypeSearcher,
} {
got, err := prompter.GetTemplate(pt)
if err != nil {
t.Fatalf("GetTemplate(%q) error: %v", pt, err)
}
if got != getDefaultTemplate(t, pt) {
t.Errorf("prompt %q should still match default after partial overrides", pt)
}
}
}

func TestBuildUserPrompter_EmptyBodyIsIgnored(t *testing.T) {
userPrompts := []database.Prompt{
{Type: database.PromptTypePrimaryAgent, Prompt: ""},
}

prompter := buildUserPrompter(loadDefaults(t), userPrompts)

got, err := prompter.GetTemplate(templates.PromptTypePrimaryAgent)
if err != nil {
t.Fatalf("GetTemplate(primary_agent) error: %v", err)
}
if got != getDefaultTemplate(t, templates.PromptTypePrimaryAgent) {
t.Errorf("primary_agent: empty user body must not clobber default")
}
}

func TestNewUserPrompter_HappyPath(t *testing.T) {
const customAssistant = "custom assistant prompt"
const customCoder = "custom coder prompt"

db := &fakeQuerier{
prompts: []database.Prompt{
{Type: database.PromptTypeAssistant, Prompt: customAssistant},
{Type: database.PromptTypeCoder, Prompt: customCoder},
},
}

prompter, err := newUserPrompter(context.Background(), db, 42)
if err != nil {
t.Fatalf("newUserPrompter error: %v", err)
}

for pt, want := range map[templates.PromptType]string{
templates.PromptTypeAssistant: customAssistant,
templates.PromptTypeCoder: customCoder,
} {
got, err := prompter.GetTemplate(pt)
if err != nil {
t.Fatalf("GetTemplate(%q) error: %v", pt, err)
}
if got != want {
t.Errorf("prompt %q: expected custom body %q, got %q", pt, want, got)
}
}

// Sanity check that an un-customized type still falls back.
got, err := prompter.GetTemplate(templates.PromptTypePentester)
if err != nil {
t.Fatalf("GetTemplate(pentester) error: %v", err)
}
if got != getDefaultTemplate(t, templates.PromptTypePentester) {
t.Errorf("pentester: expected default body when user has no override")
}
}

func TestNewUserPrompter_DBErrorPropagates(t *testing.T) {
sentinel := errors.New("db connection lost")
db := &fakeQuerier{err: sentinel}

prompter, err := newUserPrompter(context.Background(), db, 42)
if err == nil {
t.Fatalf("expected error, got nil prompter=%v", prompter)
}
if prompter != nil {
t.Errorf("expected nil prompter on DB error, got %v", prompter)
}
if !errors.Is(err, sentinel) {
t.Errorf("expected wrapped sentinel error, got %v", err)
}
}
28 changes: 21 additions & 7 deletions backend/pkg/templates/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,31 @@ func (dp *defaultPrompter) RenderTemplate(promptType PromptType, params any) (st
}

func (dp *defaultPrompter) DumpTemplates() ([]byte, error) {
promptsMap, err := LoadDefaultPromptsMap()
if err != nil {
return nil, err
}

blob, err := json.Marshal(promptsMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal templates: %w", err)
}

return blob, nil
}

// LoadDefaultPromptsMap returns a freshly populated PromptsMap of the embedded
// default templates. Callers that need to overlay overrides on top of the
// defaults can mutate the returned map directly, without going through the
// JSON dump path used by Prompter.DumpTemplates(). Each call returns a new
// map, so mutations do not affect subsequent callers.
func LoadDefaultPromptsMap() (PromptsMap, error) {
prompts, err := promptTemplates.ReadDir("prompts")
if err != nil {
return nil, fmt.Errorf("failed to read templates: %w", err)
}

promptsMap := make(PromptsMap)
promptsMap := make(PromptsMap, len(prompts))
for _, prompt := range prompts {
promptBytes, err := promptTemplates.ReadFile(path.Join("prompts", prompt.Name()))
if err != nil {
Expand All @@ -682,12 +701,7 @@ func (dp *defaultPrompter) DumpTemplates() ([]byte, error) {
promptsMap[PromptType(promptName)] = string(promptBytes)
}

blob, err := json.Marshal(promptsMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal templates: %w", err)
}

return blob, nil
return promptsMap, nil
}

func RenderPrompt(name, prompt string, params any) (string, error) {
Expand Down