refactor: local kb persona override and max turns
zhihil committed Aug 8, 2023
1 parent 0df7ddb commit 9efb66f
Showing 7 changed files with 31 additions and 12 deletions.
2 changes: 1 addition & 1 deletion lib/services/aiAssist.ts
@@ -70,7 +70,7 @@ import { AbstractManager } from './utils';
 * ```
 */
class AIAssist extends AbstractManager implements ContextHandler {
-  private static readonly MAX_STOREABLE_TURNS = 10;
+  private static readonly MAX_STOREABLE_TURNS = 15;

  private static readonly MSGS_PER_TURN = 2; // 1 record for user input + 1 record for assistant output
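The cap on stored AI Assist history rises from 10 to 15 turns here. A minimal sketch of how these two constants would bound a stored transcript — the `Message` shape and the `trimTranscript` helper are stand-ins for illustration, not part of this diff:

```ts
// Stand-in message shape; the real service uses BaseUtils.ai.Message.
interface Message {
  role: 'user' | 'assistant';
  content: string;
}

const MAX_STOREABLE_TURNS = 15;
const MSGS_PER_TURN = 2; // 1 record for user input + 1 record for assistant output

// Keep only the most recent turns: at most 15 * 2 = 30 messages after this commit.
function trimTranscript(transcript: Message[]): Message[] {
  const maxMessages = MAX_STOREABLE_TURNS * MSGS_PER_TURN;
  return transcript.slice(-maxMessages);
}
```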
@@ -48,7 +48,6 @@ export class KnowledgeBase implements BaseLanguageGenerator<KnowledgeBaseSettings>
      system,
      temperature,
      maxTokens,
-      chatHistory: [],
    });
  }

@@ -98,11 +97,10 @@ export class KnowledgeBase implements BaseLanguageGenerator<KnowledgeBaseSettings>
      system,
      temperature,
      maxTokens,
-      chatHistory: [],
    });
  }

-  public async generate(prompt: string, { chatHistory }: KnowledgeBaseSettings): Promise<AnswerReturn> {
+  public async generate(prompt: string, { chatHistory, persona }: KnowledgeBaseSettings): Promise<AnswerReturn> {
    const generatedQuestion = await this.promptQuestionSynthesis(prompt, chatHistory);

    if (!generatedQuestion.output) {
@@ -127,7 +125,8 @@
      };
    }

-    const generatedAnswer = await this.promptAnswerSynthesis(prompt, chatHistory, kbResult.chunks, {});
+    const kbPersona = persona ?? this.knowledgeBaseConfig.kbStrategy?.summarization;
+    const generatedAnswer = await this.promptAnswerSynthesis(prompt, chatHistory, kbResult.chunks, kbPersona);

    if (!generatedAnswer.output) {
      return {
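The change above is the heart of the KB refactor: a caller-supplied `persona` now takes precedence over the project-level `kbStrategy.summarization` settings, where previously an empty object was passed unconditionally. A sketch of the fallback, with a stand-in `AIModelParams` shape (the real one comes from `@voiceflow/base-types`):

```ts
interface AIModelParams {
  model?: string;
  temperature?: number;
  maxTokens?: number;
}

function resolveKBPersona(
  callerPersona: AIModelParams | undefined,
  projectSummarization: AIModelParams | undefined
): AIModelParams | undefined {
  // `??` falls through only on null/undefined, so any explicitly
  // supplied persona wins over the project default.
  return callerPersona ?? projectSummarization;
}

resolveKBPersona({ temperature: 0.1 }, { temperature: 0.9 }); // { temperature: 0.1 }
resolveKBPersona(undefined, { temperature: 0.9 });            // { temperature: 0.9 }
```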
@@ -4,6 +4,7 @@ import { AIResponse } from '../../../utils/ai';

export interface KnowledgeBaseSettings {
  chatHistory: BaseUtils.ai.Message[];
+  persona?: BaseUtils.ai.AIModelParams;
}

export interface KnowledgeBaseConfig {
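A hypothetical call site for the new optional field; the `knowledgeBase` instance and the literal values are illustrative, not taken from this diff:

```ts
const answer = await knowledgeBase.generate('How do refunds work?', {
  chatHistory: [],
  // Overrides the project-level KB summarization persona for this call only.
  persona: { temperature: 0.2, maxTokens: 256 },
});
```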
@@ -25,7 +25,7 @@ export class LLM implements BaseLanguageGenerator<LLMSettings> {
  }

  public async generate(prompt: string, settings: LLMSettings): Promise<GenerateReturn> {
-    const { model: modelName, temperature, maxTokens, system = '', chatHistory } = settings;
+    const { model: modelName, temperature, maxTokens, system = '', chatHistory = [] } = settings;

    const model = AI.get(modelName);

@@ -1,7 +1,7 @@
import { BaseUtils } from '@voiceflow/base-types';

export interface LLMSettings {
-  chatHistory: BaseUtils.ai.Message[];
+  chatHistory?: BaseUtils.ai.Message[];
  model?: BaseUtils.ai.GPT_MODEL;
  temperature?: number;
  maxTokens?: number;
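The two LLM changes work together: the interface makes `chatHistory` optional, and the destructuring default in `LLM.generate` (previous file) supplies `[]` when it is omitted. A minimal sketch of the same pattern with stand-in types:

```ts
interface Settings {
  chatHistory?: string[];
}

function countHistory({ chatHistory = [] }: Settings): number {
  // Safe even when the caller omits chatHistory entirely; without the
  // default, `chatHistory.length` would throw on undefined at runtime.
  return chatHistory.length;
}

countHistory({}); // 0
```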
@@ -191,7 +191,7 @@ export interface ResolvedPromptVariant extends BaseResolvedVariant {
      maxLength: number | null;
      systemPrompt: string | null;
    };
-  };
+  } | null;
}

export type ResolvedVariant = ResolvedPromptVariant | ResolvedJSONVariant | ResolvedTextVariant;
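Making `prompt` nullable on `ResolvedPromptVariant` forces consumers to narrow before reading its fields; the `trace()` guard in the next file is the production counterpart of this pattern. A sketch, with an illustrative error:

```ts
declare const variant: ResolvedPromptVariant;

if (!variant.prompt) {
  throw new Error('prompt-type variant has no prompt payload'); // illustrative only
}

// From here TypeScript has narrowed `variant.prompt` to non-null.
const { persona } = variant.prompt;
```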
@@ -1,4 +1,5 @@
import { BaseTrace, BaseUtils } from '@voiceflow/base-types';
+import { InternalException } from '@voiceflow/exception';
import VError from '@voiceflow/verror';
import { match } from 'ts-pattern';

@@ -84,18 +85,28 @@ export class PromptVariant extends BaseVariant<ResolvedPromptVariant> {
    return messagesByTurn.slice(messagesByTurn.length - maxTurns).flat();
  }

-  private async resolveByLLM(prompt: string): Promise<BaseTrace.V3.TextTrace> {
+  private getLocalPersona(): BaseUtils.ai.AIModelParams | null {
+    if (!this.rawVariant.prompt) return null;
+
    const {
-      persona: { model: modelName, temperature, maxLength, systemPrompt },
+      persona: { model, temperature, maxLength, systemPrompt },
    } = this.rawVariant.prompt;

    const resolvedSystem = systemPrompt ? serializeResolvedMarkup(this.varContext.resolveMarkup([systemPrompt])) : null;

-    const { output } = await this.langGen.llm.generate(prompt, {
-      ...(modelName && { model: modelName }),
+    return {
+      ...(model && { model }),
      ...(temperature && { temperature }),
      ...(maxLength && { maxTokens: maxLength }),
      ...(systemPrompt && { prompt: resolvedSystem }),
+    };
+  }
+
+  private async resolveByLLM(prompt: string): Promise<BaseTrace.V3.TextTrace> {
+    const persona = this.getLocalPersona();
+
+    const { output } = await this.langGen.llm.generate(prompt, {
+      ...persona,
      chatHistory: this.chatHistory,
    });
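`getLocalPersona` builds its result with the `...(value && { key })` idiom: when the value is falsy, the spread contributes nothing, so the key is omitted rather than set to `undefined`. A self-contained illustration (note that a legitimate `0` temperature is also dropped by this pattern):

```ts
const temperature = 0; // falsy, so the key is dropped entirely
const maxTokens = 256;

const params = {
  ...(temperature && { temperature }),
  ...(maxTokens && { maxTokens }),
};

console.log(params); // { maxTokens: 256 } — no `temperature` key at all
```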

@@ -108,8 +119,10 @@
  }

  private async resolveByKB(prompt: string): Promise<ArrayOrElement<BaseTrace.V3.TextTrace | BaseTrace.V3.DebugTrace>> {
+    const persona = this.getLocalPersona();
    const genResult = await this.langGen.knowledgeBase.generate(prompt, {
      chatHistory: this.chatHistory,
+      ...(persona && { persona }),
    });

    if (genResult.output === null) {
@@ -167,6 +180,12 @@
  }

  async trace(): Promise<ArrayOrElement<BaseTrace.V3.TextTrace | BaseTrace.V3.DebugTrace>> {
+    if (!this.rawVariant.prompt?.text) {
+      throw new InternalException({
+        message: 'prompt-type variant is missing a prompt and so it could not be resolved',
+      });
+    }
+
    const { text } = this.rawVariant.prompt;
    const resolvedPrompt = serializeResolvedMarkup(this.varContext.resolveMarkup(text));

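Taken together, the variant-level and KB-level changes compose into a single precedence order: the variant's local persona wins when defined, the project's `kbStrategy.summarization` persona applies otherwise, and the model falls back to its defaults when neither exists. A sketch of how the two pieces interact (stand-in types; illustration only):

```ts
type Persona = { temperature?: number };

// resolveByKB spreads `persona` into the settings only when the variant has one…
function kbSettings(localPersona: Persona | null) {
  return {
    chatHistory: [] as string[],
    ...(localPersona && { persona: localPersona }), // key omitted when null
  };
}

// …and KnowledgeBase.generate then falls back to the project summarization persona.
function pickPersona(settings: { persona?: Persona }, projectDefault?: Persona) {
  return settings.persona ?? projectDefault;
}

pickPersona(kbSettings({ temperature: 0.1 }), { temperature: 0.9 }); // variant persona wins
pickPersona(kbSettings(null), { temperature: 0.9 });                 // project default applies
pickPersona(kbSettings(null), undefined);                            // undefined → model defaults
```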
