Skip to content

Commit

Permalink
token repeat limit for prediction requests (ollama#3080)
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceMacD authored and zhewang1-intc committed May 13, 2024
1 parent 2f755f7 commit 4a5cddb
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions llm/dyn_ext_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,14 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
}

retryNeeded := false
// keep track of the last token generated, this is used to abort if the model starts looping
var lastToken string
var tokenRepeat int
out:
for {
select {
case <-ctx.Done():
// This handles the request cancellation
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
if resp.id < 0 {
return extServerResponseToErr(resp)
} else {
return nil
}
return cancelCompletion(llm, resp)
default:
var result C.ext_server_task_result_t
C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
Expand All @@ -261,6 +258,20 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
break out
}

switch {
case strings.TrimSpace(p.Content) == lastToken:
tokenRepeat++
default:
lastToken = strings.TrimSpace(p.Content)
tokenRepeat = 0
}

// 30 picked as an arbitrary max token repeat limit, modify as needed
if tokenRepeat > 30 {
slog.Debug("prediction aborted, token repeat limit reached")
return cancelCompletion(llm, resp)
}

if p.Content != "" {
fn(PredictResult{
Content: p.Content,
Expand Down Expand Up @@ -288,6 +299,15 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
return fmt.Errorf("max retries exceeded")
}

func cancelCompletion(llm *dynExtServer, resp C.ext_server_resp_t) error {
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
if resp.id < 0 {
return extServerResponseToErr(resp)
} else {
return nil
}
}

func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil {
Expand Down

0 comments on commit 4a5cddb

Please sign in to comment.