From 679819b4f6e91ff02a87b0fa4040b9eaf9d361a6 Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Mon, 27 Apr 2026 17:27:05 +0530 Subject: [PATCH] =?UTF-8?q?@=20=20feat(ws):=20add=20result-channel=20outpu?= =?UTF-8?q?t=20streaming,=20ACK=20handling,=20and=20replay=20=E2=94=82=20~?= =?UTF-8?q?=20=20-=20Add=20ResultChannel=20interface=20and=20integrate=20i?= =?UTF-8?q?nto=20task=20polling=20runner=20for=20=20=20=20=20=20stdout/std?= =?UTF-8?q?err=20capture=20with=20interval/byte-threshold=20flush=20=20=20?= =?UTF-8?q?=20-=20Add=20WS=20client=20methods=20for=20SendOutput,=20SendFi?= =?UTF-8?q?nal,=20sendIfActive,=20replayUnacked=20=20=20=20-=20Wire=20resu?= =?UTF-8?q?lt=20outbox=20into=20WS=20client=20on=20startup=20via=20newDefa?= =?UTF-8?q?ultWebSocketRuntime=20=20=20=20-=20Extend=20ErrorPayload=20with?= =?UTF-8?q?=20HighestAcceptedSequence=20for=20retryable=20gap=20errors=20?= =?UTF-8?q?=20=20=20-=20Add=20ExecutionAttemptID=20to=20domain=20Task=20mo?= =?UTF-8?q?del=20=20=20=20-=20Add=20tests=20for=20ack-driven=20outbox=20re?= =?UTF-8?q?moval,=20reconnect=20replay,=20throttled=20=20=20=20=20=20captu?= =?UTF-8?q?re,=20retryable=20error=20handling,=20and=20HTTP=20fallback?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AGENTS.md | 7 +- app/jobs/taskjob/result_channel_test.go | 239 +++++++++++++++++++++++ app/jobs/taskjob/taskjob.go | 185 +++++++++++++++++- app/services/wsclient/client.go | 153 ++++++++++++++- app/services/wsclient/client_test.go | 182 +++++++++++++++++ domain/task/task.go | 21 +- internal/wsprotocol/ack_sequence.go | 27 +-- internal/wsprotocol/ack_sequence_test.go | 12 +- main.go | 17 +- 9 files changed, 807 insertions(+), 36 deletions(-) create mode 100644 app/jobs/taskjob/result_channel_test.go diff --git a/AGENTS.md b/AGENTS.md index 643ec7d..e344707 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1 +1,6 @@ -- Use jj instead of git +- Use jj instead of git. Check with `ls .jj/` — if no `.jj` directory exists, fall back to git. + +- **Branching:** `jj describe -m "type(scope): message"` on `@`, then `jj bookmark create ` and `jj git push -b `. +- **Before push:** `jj git fetch && jj rebase -b @ -o main`. Resolve conflicts locally, verify with tests, then push. +- **Conflict resolution:** `jj resolve --list` to find conflicts, edit files to remove markers (`<<<<<<<` / `>>>>>>>`), then squash resolution with `jj squash` (no `--interactive` flag). No detached HEAD or rebase-in-progress state to manage. +- **Undo:** `jj op undo` reverts any operation. Safe to experiment. diff --git a/app/jobs/taskjob/result_channel_test.go b/app/jobs/taskjob/result_channel_test.go new file mode 100644 index 0000000..477fc51 --- /dev/null +++ b/app/jobs/taskjob/result_channel_test.go @@ -0,0 +1,239 @@ +package taskjob + +import ( + "bytes" + "context" + "errors" + "hostlink/app/services/localtaskstore" + "hostlink/app/services/taskreporter" + "hostlink/domain/task" + "io" + "sync" + "testing" + "time" +) + +func TestTaskJobStreamsOutputAndFinalOverResultChannel(t *testing.T) { + fetcher := &fakeTaskFetcher{tasks: []task.Task{{ + ID: "task-1", + ExecutionAttemptID: "attempt-1", + Command: "printf 'out\\n'; printf 'err\\n' >&2", + Status: "pending", + }}} + reporter := &fakeTaskReporter{} + channel := &fakeResultChannel{} + job := NewJobWithConf(TaskJobConfig{Trigger: runOnceTrigger}) + + job.processTask(context.Background(), fetcher.tasks[0], reporter, channel) + + if len(channel.outputs) != 2 { + t.Fatalf("outputs len = %d, want 2", len(channel.outputs)) + } + stdout := outputByStream(channel.outputs, "stdout") + stderr := outputByStream(channel.outputs, "stderr") + if stdout == nil || stdout.Sequence != 1 || stdout.Payload != "out\n" { + t.Fatalf("stdout chunk = %#v", stdout) + } + if stderr == nil || stderr.Sequence != 1 || stderr.Payload != "err\n" { + t.Fatalf("stderr chunk = %#v", stderr) + } + if len(channel.finals) != 1 { + t.Fatalf("finals len = %d, want 1", len(channel.finals)) + } + if channel.finals[0].TaskID != "task-1" || channel.finals[0].ExecutionAttemptID != "attempt-1" || channel.finals[0].Status != "completed" { + t.Fatalf("final = %#v", channel.finals[0]) + } + if len(reporter.results) != 0 { + t.Fatalf("http reports len = %d, want 0 when result channel succeeds", len(reporter.results)) + } +} + +func TestTaskJobFallsBackToHTTPReporterWhenResultChannelDisabled(t *testing.T) { + fetcher := &fakeTaskFetcher{tasks: []task.Task{{ID: "task-1", Command: "printf 'out'", Status: "pending"}}} + reporter := &fakeTaskReporter{} + job := NewJobWithConf(TaskJobConfig{Trigger: runOnceTrigger}) + + job.processTask(context.Background(), fetcher.tasks[0], reporter, nil) + + if len(reporter.results) != 1 { + t.Fatalf("http reports len = %d, want 1", len(reporter.results)) + } + if reporter.results[0].Output != "out" { + t.Fatalf("output = %q, want out", reporter.results[0].Output) + } +} + +func TestTaskJobFallsBackToHTTPReporterWhenFinalPersistenceFails(t *testing.T) { + fetcher := &fakeTaskFetcher{tasks: []task.Task{{ + ID: "task-1", + ExecutionAttemptID: "attempt-1", + Command: "printf 'out'", + Status: "pending", + }}} + reporter := &fakeTaskReporter{} + channel := &fakeResultChannel{finalErr: errors.New("store down")} + job := NewJobWithConf(TaskJobConfig{Trigger: runOnceTrigger}) + + job.processTask(context.Background(), fetcher.tasks[0], reporter, channel) + + if len(reporter.results) != 1 { + t.Fatalf("http reports len = %d, want 1", len(reporter.results)) + } + if len(channel.finals) != 1 { + t.Fatalf("finals len = %d, want 1", len(channel.finals)) + } +} + +func TestCaptureStreamFlushesOnByteThreshold(t *testing.T) { + reader, writer := io.Pipe() + channel := &fakeResultChannel{} + job := NewJobWithConf(TaskJobConfig{ + OutputFlushInterval: time.Hour, + OutputFlushThreshold: 4, + }) + done := make(chan struct{}) + + go func() { + var sink bytes.Buffer + job.captureStream(context.Background(), task.Task{ID: "task-1", ExecutionAttemptID: "attempt-1"}, "stdout", reader, &sink, channel) + close(done) + }() + + _, _ = writer.Write([]byte("ab")) + if len(channel.outputs) != 0 { + t.Fatalf("outputs len = %d, want no flush before threshold", len(channel.outputs)) + } + _, _ = writer.Write([]byte("cd")) + waitForOutputs(t, channel, 1) + _ = writer.Close() + <-done + + if channel.outputs[0].Payload != "abcd" { + t.Fatalf("payload = %q, want abcd", channel.outputs[0].Payload) + } +} + +func TestCaptureStreamFlushesOnInterval(t *testing.T) { + reader, writer := io.Pipe() + channel := &fakeResultChannel{} + job := NewJobWithConf(TaskJobConfig{ + OutputFlushInterval: 10 * time.Millisecond, + OutputFlushThreshold: 1024, + }) + done := make(chan struct{}) + + go func() { + var sink bytes.Buffer + job.captureStream(context.Background(), task.Task{ID: "task-1", ExecutionAttemptID: "attempt-1"}, "stdout", reader, &sink, channel) + close(done) + }() + + _, _ = writer.Write([]byte("slow")) + waitForOutputs(t, channel, 1) + _ = writer.Close() + <-done + + if channel.outputs[0].Payload != "slow" { + t.Fatalf("payload = %q, want slow", channel.outputs[0].Payload) + } +} + +func TestCaptureStreamRetainsChunkWhenPersistFails(t *testing.T) { + reader, writer := io.Pipe() + channel := &fakeResultChannel{outputErrs: []error{errors.New("store down"), nil}} + job := NewJobWithConf(TaskJobConfig{ + OutputFlushInterval: 10 * time.Millisecond, + OutputFlushThreshold: 1024, + }) + done := make(chan struct{}) + + go func() { + var sink bytes.Buffer + job.captureStream(context.Background(), task.Task{ID: "task-1", ExecutionAttemptID: "attempt-1"}, "stdout", reader, &sink, channel) + close(done) + }() + + _, _ = writer.Write([]byte("retry")) + waitForOutputs(t, channel, 2) + _ = writer.Close() + <-done + + if channel.outputs[1].Payload != "retry" || channel.outputs[1].Sequence != 1 { + t.Fatalf("retry output = %#v", channel.outputs[1]) + } +} + +func runOnceTrigger(ctx context.Context, fn func() error) { + _ = fn() +} + +type fakeTaskFetcher struct { + tasks []task.Task +} + +func (f *fakeTaskFetcher) Fetch() ([]task.Task, error) { + return f.tasks, nil +} + +type fakeTaskReporter struct { + mu sync.Mutex + results []*taskreporter.TaskResult +} + +func (f *fakeTaskReporter) Report(taskID string, result *taskreporter.TaskResult) error { + f.mu.Lock() + defer f.mu.Unlock() + f.results = append(f.results, result) + return nil +} + +type fakeResultChannel struct { + mu sync.Mutex + outputs []localtaskstore.OutputChunk + finals []localtaskstore.FinalResult + outputErrs []error + finalErr error +} + +func (f *fakeResultChannel) SendOutput(ctx context.Context, chunk localtaskstore.OutputChunk) error { + f.mu.Lock() + defer f.mu.Unlock() + f.outputs = append(f.outputs, chunk) + if len(f.outputErrs) > 0 { + err := f.outputErrs[0] + f.outputErrs = f.outputErrs[1:] + return err + } + return nil +} + +func (f *fakeResultChannel) SendFinal(ctx context.Context, result localtaskstore.FinalResult) error { + f.mu.Lock() + defer f.mu.Unlock() + f.finals = append(f.finals, result) + return f.finalErr +} + +func waitForOutputs(t *testing.T, channel *fakeResultChannel, count int) { + t.Helper() + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + channel.mu.Lock() + current := len(channel.outputs) + channel.mu.Unlock() + if current >= count { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("timed out waiting for %d outputs", count) +} + +func outputByStream(outputs []localtaskstore.OutputChunk, stream string) *localtaskstore.OutputChunk { + for i := range outputs { + if outputs[i].Stream == stream { + return &outputs[i] + } + } + return nil +} diff --git a/app/jobs/taskjob/taskjob.go b/app/jobs/taskjob/taskjob.go index d0d8c05..6df56c2 100644 --- a/app/jobs/taskjob/taskjob.go +++ b/app/jobs/taskjob/taskjob.go @@ -3,14 +3,21 @@ package taskjob import ( + "bufio" + "bytes" "context" + "encoding/json" "fmt" + "hostlink/app/services/localtaskstore" "hostlink/app/services/taskfetcher" "hostlink/app/services/taskreporter" "hostlink/domain/task" + "io" "os" "os/exec" + "strings" "sync" + "time" "github.com/labstack/gommon/log" ) @@ -18,7 +25,14 @@ import ( type TriggerFunc func(context.Context, func() error) type TaskJobConfig struct { - Trigger TriggerFunc + Trigger TriggerFunc + OutputFlushInterval time.Duration + OutputFlushThreshold int +} + +type ResultChannel interface { + SendOutput(context.Context, localtaskstore.OutputChunk) error + SendFinal(context.Context, localtaskstore.FinalResult) error } type TaskJob struct { @@ -37,15 +51,25 @@ func NewJobWithConf(cfg TaskJobConfig) *TaskJob { if cfg.Trigger == nil { cfg.Trigger = Trigger } + if cfg.OutputFlushInterval == 0 { + cfg.OutputFlushInterval = 100 * time.Millisecond + } + if cfg.OutputFlushThreshold == 0 { + cfg.OutputFlushThreshold = 16 * 1024 + } return &TaskJob{ config: cfg, } } -func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr taskreporter.TaskReporter) context.CancelFunc { +func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr taskreporter.TaskReporter, channels ...ResultChannel) context.CancelFunc { ctx, cancel := context.WithCancel(ctx) tj.cancel = cancel + var channel ResultChannel + if len(channels) > 0 { + channel = channels[0] + } tj.wg.Add(1) go func() { defer tj.wg.Done() @@ -61,7 +85,7 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr } } for _, t := range incompleteTasks { - tj.processTask(t, tr) + tj.processTask(ctx, t, tr, channel) } return nil }) @@ -69,7 +93,7 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr return cancel } -func (tj *TaskJob) processTask(t task.Task, tr taskreporter.TaskReporter) { +func (tj *TaskJob) processTask(ctx context.Context, t task.Task, tr taskreporter.TaskReporter, channel ResultChannel) { tempFile, err := os.CreateTemp("", "*_script.sh") if err != nil { t.Error = fmt.Sprintf("failed to create temp file: %v", err) @@ -116,6 +140,11 @@ func (tj *TaskJob) processTask(t task.Task, tr taskreporter.TaskReporter) { return } execCmd := exec.Command("/bin/sh", "-c", tempFile.Name()) + if channel != nil && t.ExecutionAttemptID != "" { + tj.processTaskWithResultChannel(ctx, t, execCmd, tr, channel) + return + } + output, err := execCmd.CombinedOutput() exitCode := 0 errMsg := "" @@ -139,6 +168,154 @@ func (tj *TaskJob) processTask(t task.Task, tr taskreporter.TaskReporter) { } } +func (tj *TaskJob) processTaskWithResultChannel(ctx context.Context, t task.Task, execCmd *exec.Cmd, tr taskreporter.TaskReporter, channel ResultChannel) { + stdout, err := execCmd.StdoutPipe() + if err != nil { + tj.reportHTTPResult(t, tr, "failed", "", fmt.Sprintf("failed to capture stdout: %v", err), 1) + return + } + stderr, err := execCmd.StderrPipe() + if err != nil { + tj.reportHTTPResult(t, tr, "failed", "", fmt.Sprintf("failed to capture stderr: %v", err), 1) + return + } + + if err := execCmd.Start(); err != nil { + tj.reportHTTPResult(t, tr, "failed", "", err.Error(), 1) + return + } + + var stdoutBuf bytes.Buffer + var stderrBuf bytes.Buffer + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + tj.captureStream(ctx, t, "stdout", stdout, &stdoutBuf, channel) + }() + go func() { + defer wg.Done() + tj.captureStream(ctx, t, "stderr", stderr, &stderrBuf, channel) + }() + wg.Wait() + + exitCode := 0 + status := "completed" + errMsg := "" + if err := execCmd.Wait(); err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + exitCode = exitError.ExitCode() + } else { + exitCode = 1 + } + status = "failed" + errMsg = err.Error() + } + + if stderrBuf.Len() > 0 { + errMsg = stderrBuf.String() + } + output := stdoutBuf.String() + resultPayload := taskreporter.TaskResult{Status: status, Output: output, Error: errMsg, ExitCode: exitCode} + finalPayload, err := json.Marshal(resultPayload) + if err != nil { + tj.reportHTTPResult(t, tr, status, output, errMsg, exitCode) + return + } + + final := localtaskstore.FinalResult{ + MessageID: messageID(t.ID, t.ExecutionAttemptID, "final", 0), + TaskID: t.ID, + ExecutionAttemptID: t.ExecutionAttemptID, + Status: status, + ExitCode: exitCode, + Payload: string(finalPayload), + } + if err := channel.SendFinal(ctx, final); err != nil { + tj.reportHTTPResult(t, tr, status, output, errMsg, exitCode) + } +} + +func (tj *TaskJob) captureStream(ctx context.Context, t task.Task, stream string, reader io.Reader, sink *bytes.Buffer, channel ResultChannel) { + sequence := int64(1) + chunks := make(chan string, 1) + go func() { + defer close(chunks) + buffered := bufio.NewReaderSize(reader, tj.config.OutputFlushThreshold) + for { + buf := make([]byte, max(tj.config.OutputFlushThreshold, 1)) + n, err := buffered.Read(buf) + if n > 0 { + chunks <- string(buf[:n]) + } + if err != nil { + return + } + } + }() + + var pending bytes.Buffer + ticker := time.NewTicker(tj.config.OutputFlushInterval) + defer ticker.Stop() + + flush := func() bool { + if pending.Len() == 0 { + return true + } + chunk := pending.String() + err := channel.SendOutput(ctx, localtaskstore.OutputChunk{ + MessageID: messageID(t.ID, t.ExecutionAttemptID, stream, sequence), + TaskID: t.ID, + ExecutionAttemptID: t.ExecutionAttemptID, + Stream: stream, + Sequence: sequence, + Payload: chunk, + ByteCount: int64(len(chunk)), + }) + if err != nil { + return false + } + pending.Reset() + sequence++ + return true + } + + for { + select { + case chunk, ok := <-chunks: + if !ok { + flush() + return + } + sink.WriteString(chunk) + pending.WriteString(chunk) + if pending.Len() >= tj.config.OutputFlushThreshold { + flush() + } + case <-ticker.C: + flush() + case <-ctx.Done(): + return + } + } +} + +func (tj *TaskJob) reportHTTPResult(t task.Task, tr taskreporter.TaskReporter, status, output, errMsg string, exitCode int) { + if reportErr := tr.Report(t.ID, &taskreporter.TaskResult{ + Status: status, + Output: output, + Error: errMsg, + ExitCode: exitCode, + }); reportErr != nil { + log.Errorf("failed to report task %s: %v", t.ID, reportErr) + } +} + +func messageID(taskID, attemptID, stream string, sequence int64) string { + parts := []string{"msg", taskID, attemptID, stream, fmt.Sprintf("%d", sequence), fmt.Sprintf("%d", time.Now().UnixNano())} + return strings.NewReplacer("/", "-", " ", "-", "|", "-").Replace(strings.Join(parts, "-")) +} + func (tj *TaskJob) Shutdown() { if tj.cancel != nil { tj.cancel() diff --git a/app/services/wsclient/client.go b/app/services/wsclient/client.go index 45bc992..5b9a16a 100644 --- a/app/services/wsclient/client.go +++ b/app/services/wsclient/client.go @@ -2,8 +2,10 @@ package wsclient import ( "context" + "encoding/json" "errors" "fmt" + "hostlink/app/services/localtaskstore" "math/rand/v2" "net/http" "sync" @@ -38,6 +40,7 @@ type Config struct { ReconnectMax time.Duration PingInterval time.Duration SleepFunc SleepFunc + ResultOutbox localtaskstore.ResultOutbox } type Client struct { @@ -51,8 +54,11 @@ type Client struct { sleep SleepFunc mu sync.RWMutex + writeMu sync.Mutex active bool lastAck *wsprotocol.AckPayload + conn Conn + outbox localtaskstore.ResultOutbox } func New(cfg Config) (*Client, error) { @@ -92,6 +98,7 @@ func New(cfg Config) (*Client, error) { reconnectMax: cfg.ReconnectMax, pingInterval: cfg.PingInterval, sleep: cfg.SleepFunc, + outbox: cfg.ResultOutbox, }, nil } @@ -148,10 +155,12 @@ func (c *Client) runOnce(ctx context.Context) error { if err != nil { return err } + c.setConn(conn) defer conn.Close() + defer c.setConn(nil) hello := c.buildHello() - if err := conn.WriteEnvelope(ctx, hello); err != nil { + if err := c.writeEnvelope(ctx, conn, hello); err != nil { return err } @@ -169,7 +178,7 @@ func (c *Client) runOnce(ctx context.Context) error { case err := <-readErr: return err case <-ticker.C: - if err := conn.Ping(ctx); err != nil { + if err := c.ping(ctx, conn); err != nil { _ = conn.Close() return err } @@ -200,6 +209,9 @@ func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string) } if ack.AckedMessageID == helloMessageID { c.setActive(true) + if err := c.replayUnacked(ctx, conn); err != nil { + return err + } } c.setLastAck(&ack) case wsprotocol.TypeAck: @@ -207,8 +219,20 @@ func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string) if err != nil { return err } + if c.outbox != nil && ack.AckedMessageID != "" { + if err := c.outbox.AckMessage(ack.AckedMessageID); err != nil { + return err + } + } c.setLastAck(&ack) case wsprotocol.TypeError: + payload, err := wsprotocol.DecodePayload[wsprotocol.ErrorPayload](env) + if err != nil { + return err + } + if payload.Retryable { + continue + } return fmt.Errorf("websocket protocol error: %s", env.MessageID) default: return fmt.Errorf("unsupported inbound websocket message type: %s", env.Type) @@ -216,6 +240,49 @@ func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string) } } +func (c *Client) SendOutput(ctx context.Context, chunk localtaskstore.OutputChunk) error { + if c.outbox == nil { + return fmt.Errorf("result outbox is not configured") + } + if err := c.outbox.AppendOutputChunk(chunk); err != nil { + return err + } + return c.sendIfActive(ctx, envelopeFromOutboxMessage(c.agentID, localtaskstore.OutboxMessage{ + MessageID: chunk.MessageID, + TaskID: chunk.TaskID, + ExecutionAttemptID: chunk.ExecutionAttemptID, + Type: localtaskstore.OutboxMessageTypeOutput, + Stream: chunk.Stream, + Sequence: chunk.Sequence, + Payload: chunk.Payload, + ByteCount: chunk.ByteCount, + })) +} + +func (c *Client) SendFinal(ctx context.Context, result localtaskstore.FinalResult) error { + if c.outbox == nil { + return fmt.Errorf("result outbox is not configured") + } + if err := c.outbox.RecordFinal(result); err != nil { + return err + } + err := c.sendIfActive(ctx, envelopeFromOutboxMessage(c.agentID, localtaskstore.OutboxMessage{ + MessageID: result.MessageID, + TaskID: result.TaskID, + ExecutionAttemptID: result.ExecutionAttemptID, + Type: localtaskstore.OutboxMessageTypeFinal, + Payload: result.Payload, + ByteCount: int64(len(result.Payload)), + })) + if err != nil { + return err + } + if !c.IsActive() { + return fmt.Errorf("websocket result channel is inactive") + } + return nil +} + func (c *Client) buildHello() wsprotocol.Envelope { return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, @@ -233,12 +300,94 @@ func (c *Client) setActive(active bool) { c.active = active } +func (c *Client) setConn(conn Conn) { + c.mu.Lock() + defer c.mu.Unlock() + c.conn = conn +} + func (c *Client) setLastAck(ack *wsprotocol.AckPayload) { c.mu.Lock() defer c.mu.Unlock() c.lastAck = ack } +func (c *Client) sendIfActive(ctx context.Context, env wsprotocol.Envelope) error { + c.mu.RLock() + conn := c.conn + active := c.active + c.mu.RUnlock() + if conn == nil || !active { + return nil + } + return c.writeEnvelope(ctx, conn, env) +} + +func (c *Client) replayUnacked(ctx context.Context, conn Conn) error { + if c.outbox == nil { + return nil + } + messages, err := c.outbox.UnackedMessages() + if err != nil { + return err + } + for _, message := range messages { + if err := c.writeEnvelope(ctx, conn, envelopeFromOutboxMessage(c.agentID, message)); err != nil { + return err + } + } + return nil +} + +func (c *Client) writeEnvelope(ctx context.Context, conn Conn, env wsprotocol.Envelope) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + return conn.WriteEnvelope(ctx, env) +} + +func (c *Client) ping(ctx context.Context, conn Conn) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + return conn.Ping(ctx) +} + +func envelopeFromOutboxMessage(agentID string, message localtaskstore.OutboxMessage) wsprotocol.Envelope { + now := time.Now().UTC().Format(time.RFC3339) + if message.Type == localtaskstore.OutboxMessageTypeOutput { + sequence := int(message.Sequence) + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: message.MessageID, + Type: wsprotocol.TypeTaskOutput, + AgentID: agentID, + TaskID: message.TaskID, + ExecutionAttemptID: message.ExecutionAttemptID, + Sequence: &sequence, + SentAt: now, + Payload: map[string]any{ + "stream": message.Stream, + "data": message.Payload, + "byte_count": message.ByteCount, + }, + } + } + + payload := map[string]any{} + if message.Payload != "" { + _ = json.Unmarshal([]byte(message.Payload), &payload) + } + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: message.MessageID, + Type: wsprotocol.TypeTaskFinal, + AgentID: agentID, + TaskID: message.TaskID, + ExecutionAttemptID: message.ExecutionAttemptID, + SentAt: now, + Payload: payload, + } +} + func sleepContext(ctx context.Context, d time.Duration) error { timer := time.NewTimer(d) defer timer.Stop() diff --git a/app/services/wsclient/client_test.go b/app/services/wsclient/client_test.go index b89b00e..62bc4ea 100644 --- a/app/services/wsclient/client_test.go +++ b/app/services/wsclient/client_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "encoding/pem" "errors" + "hostlink/app/services/localtaskstore" "net/http" "os" "path/filepath" @@ -117,6 +118,144 @@ func TestClientHandlesAckWithoutTaskSideEffects(t *testing.T) { } } +func TestClientAckRemovesResultMessageFromOutbox(t *testing.T) { + store := newClientTestStore(t) + requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Stream: "stdout", + Sequence: 1, + Payload: "hello\n", + ByteCount: 6, + })) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithResultOutbox(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + conn.waitForWrite(t) + conn.readCh <- ackEnvelope("msg_ack", "msg-output-1", wsprotocol.TypeTaskOutput) + + waitFor(t, func() bool { + messages, err := store.UnackedMessages() + return err == nil && len(messages) == 0 + }, "ack to remove outbox message") + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientReplaysUnackedMessagesAfterHelloAck(t *testing.T) { + store := newClientTestStore(t) + requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Stream: "stdout", + Sequence: 1, + Payload: "hello\n", + ByteCount: 6, + })) + requireNoError(t, store.RecordFinal(localtaskstore.FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Status: "completed", + ExitCode: 0, + Payload: `{"status":"completed","exit_code":0,"output_truncated":false,"error_truncated":false}`, + })) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithResultOutbox(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: "msg_hello_ack", + Type: wsprotocol.TypeAgentHelloAck, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: payloadMap(t, wsprotocol.BuildAck(wsprotocol.AckOptions{ + AckedMessageID: hello.MessageID, + AckedType: wsprotocol.TypeAgentHello, + })), + } + + output := conn.waitForWrite(t) + final := conn.waitForWrite(t) + if output.MessageID != "msg-output-1" || output.Type != wsprotocol.TypeTaskOutput { + t.Fatalf("first replay = %#v", output) + } + if final.MessageID != "msg-final-1" || final.Type != wsprotocol.TypeTaskFinal { + t.Fatalf("second replay = %#v", final) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientRetryableErrorKeepsConnectionAndOutboxMessage(t *testing.T) { + store := newClientTestStore(t) + requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Stream: "stdout", + Sequence: 1, + Payload: "hello\n", + ByteCount: 6, + })) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithResultOutbox(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + conn.waitForWrite(t) + conn.readCh <- wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: "msg_error", + Type: wsprotocol.TypeError, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: payloadMap(t, wsprotocol.BuildError(wsprotocol.ErrorOptions{ + Code: "output_sequence_gap", + Message: "expected sequence 2", + Retryable: true, + RelatedMessageID: "msg-output-1", + HighestAcceptedSequence: intValuePtr(1), + })), + } + + waitFor(t, func() bool { return !conn.closed() }, "connection to remain open") + messages, err := store.UnackedMessages() + if err != nil { + t.Fatalf("unacked messages: %v", err) + } + if len(messages) != 1 || messages[0].MessageID != "msg-output-1" { + t.Fatalf("messages = %#v", messages) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + func TestClientErrorMessageTriggersReconnect(t *testing.T) { first := newFakeConn() second := newFakeConn() @@ -276,6 +415,10 @@ func WithPingInterval(d time.Duration) clientOption { return func(cfg *Config) { cfg.PingInterval = d } } +func WithResultOutbox(outbox localtaskstore.ResultOutbox) clientOption { + return func(cfg *Config) { cfg.ResultOutbox = outbox } +} + type fakeDialer struct { mu sync.Mutex conn *fakeConn @@ -387,6 +530,45 @@ func waitFor(t *testing.T, check func() bool, description string) { t.Fatalf("timed out waiting for %s", description) } +func ackEnvelope(messageID, ackedMessageID string, ackedType wsprotocol.MessageType) wsprotocol.Envelope { + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: messageID, + Type: wsprotocol.TypeAck, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: map[string]any{ + "acked_message_id": ackedMessageID, + "acked_type": string(ackedType), + }, + } +} + +func newClientTestStore(t *testing.T) *localtaskstore.Store { + t.Helper() + store, err := localtaskstore.New(localtaskstore.Config{ + Path: filepath.Join(t.TempDir(), "task_store.db"), + SpoolCapBytes: 1024 * 1024, + TerminalReserveBytes: 1024, + }) + if err != nil { + t.Fatalf("new local task store: %v", err) + } + t.Cleanup(func() { _ = store.Close() }) + return store +} + +func requireNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func intValuePtr(value int) *int { + return &value +} + func saveTestPrivateKey(t *testing.T, dir string) string { t.Helper() privateKey, err := rsa.GenerateKey(rand.Reader, 2048) diff --git a/domain/task/task.go b/domain/task/task.go index c3df94e..0713024 100644 --- a/domain/task/task.go +++ b/domain/task/task.go @@ -5,16 +5,17 @@ import ( ) type Task struct { - ID string `json:"id"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - DeletedAt *time.Time `json:"deleted_at,omitempty"` - Command string `json:"command"` - Status string `json:"status"` - Priority int `json:"priority"` - Output string `json:"output"` - Error string `json:"error"` - ExitCode int `json:"exit_code"` + ID string `json:"id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + Command string `json:"command"` + Status string `json:"status"` + Priority int `json:"priority"` + Output string `json:"output"` + Error string `json:"error"` + ExitCode int `json:"exit_code"` } type TaskFilters struct { diff --git a/internal/wsprotocol/ack_sequence.go b/internal/wsprotocol/ack_sequence.go index 0582f8d..4bd53ef 100644 --- a/internal/wsprotocol/ack_sequence.go +++ b/internal/wsprotocol/ack_sequence.go @@ -17,17 +17,19 @@ type AckOptions struct { } type ErrorPayload struct { - Code string `json:"code"` - Message string `json:"message"` - Retryable bool `json:"retryable"` - RelatedMessageID string `json:"related_message_id,omitempty"` + Code string `json:"code"` + Message string `json:"message"` + Retryable bool `json:"retryable"` + RelatedMessageID string `json:"related_message_id,omitempty"` + HighestAcceptedSequence *int `json:"highest_accepted_sequence,omitempty"` } type ErrorOptions struct { - Code string - Message string - Retryable bool - RelatedMessageID string + Code string + Message string + Retryable bool + RelatedMessageID string + HighestAcceptedSequence *int } type MessageRecordResult int @@ -76,10 +78,11 @@ func BuildAck(opts AckOptions) AckPayload { func BuildError(opts ErrorOptions) ErrorPayload { return ErrorPayload{ - Code: opts.Code, - Message: opts.Message, - Retryable: opts.Retryable, - RelatedMessageID: opts.RelatedMessageID, + Code: opts.Code, + Message: opts.Message, + Retryable: opts.Retryable, + RelatedMessageID: opts.RelatedMessageID, + HighestAcceptedSequence: opts.HighestAcceptedSequence, } } diff --git a/internal/wsprotocol/ack_sequence_test.go b/internal/wsprotocol/ack_sequence_test.go index 6b49ac9..dc57bf5 100644 --- a/internal/wsprotocol/ack_sequence_test.go +++ b/internal/wsprotocol/ack_sequence_test.go @@ -52,10 +52,11 @@ func TestSampleAckPayloadUsesCanonicalFieldNames(t *testing.T) { func TestErrorPayload(t *testing.T) { payload := BuildError(ErrorOptions{ - Code: "output_sequence_gap", - Message: "expected sequence 8", - Retryable: true, - RelatedMessageID: "msg_123", + Code: "output_sequence_gap", + Message: "expected sequence 8", + Retryable: true, + RelatedMessageID: "msg_123", + HighestAcceptedSequence: intPtr(7), }) if payload.Code != "output_sequence_gap" { @@ -67,6 +68,9 @@ func TestErrorPayload(t *testing.T) { if payload.RelatedMessageID != "msg_123" { t.Errorf("related_message_id = %q, want msg_123", payload.RelatedMessageID) } + if payload.HighestAcceptedSequence == nil || *payload.HighestAcceptedSequence != 7 { + t.Fatalf("highest_accepted_sequence = %v, want 7", payload.HighestAcceptedSequence) + } } func TestMessageTracker(t *testing.T) { diff --git a/main.go b/main.go index d0c7b9e..7208b8a 100644 --- a/main.go +++ b/main.go @@ -267,7 +267,14 @@ func runServer(ctx context.Context, cmd *cli.Command) error { // Wait for registration to complete <-registeredChan log.Println("Agent registered, starting task job...") - startWebSocketClientIfEnabled(ctx, newDefaultWebSocketRuntime) + var resultChannel taskjob.ResultChannel + startWebSocketClientIfEnabled(ctx, func() (webSocketRuntime, error) { + runtime, err := newDefaultWebSocketRuntime(localStore) + if err == nil { + resultChannel = runtime.(taskjob.ResultChannel) + } + return runtime, err + }) fetcher, err := taskfetcher.NewDefault() if err != nil { @@ -280,7 +287,7 @@ func runServer(ctx context.Context, cmd *cli.Command) error { return } taskJob := taskjob.New() - taskJob.Register(ctx, fetcher, reporter) + taskJob.Register(ctx, fetcher, reporter, resultChannel) metricsReporter, err := metrics.New() if err != nil { @@ -328,11 +335,14 @@ func startWebSocketClientIfEnabled(ctx context.Context, constructor func() (webS return true } -func newDefaultWebSocketRuntime() (webSocketRuntime, error) { +func newDefaultWebSocketRuntime(localStore *localtaskstore.Store) (webSocketRuntime, error) { state := agentstate.New(appconf.AgentStatePath()) if err := state.Load(); err != nil { return nil, fmt.Errorf("failed to load agent state: %w", err) } + if localStore == nil { + return nil, fmt.Errorf("local task store is not available") + } return wsclient.New(wsclient.Config{ URL: appconf.WebSocketURL(), AgentState: state, @@ -340,6 +350,7 @@ func newDefaultWebSocketRuntime() (webSocketRuntime, error) { ReconnectMin: appconf.WebSocketReconnectMin(), ReconnectMax: appconf.WebSocketReconnectMax(), PingInterval: appconf.WebSocketPingInterval(), + ResultOutbox: localStore, }) }