Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add prompt service #6241

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- CreateEnum
CREATE TYPE "AiPromptRole" AS ENUM ('system', 'assistant', 'user');

-- CreateTable
CREATE TABLE "ai_prompts" (
"id" VARCHAR NOT NULL,
"name" VARCHAR(20) NOT NULL,
"idx" INTEGER NOT NULL,
"role" "AiPromptRole" NOT NULL,
"content" TEXT NOT NULL,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "ai_prompts_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE UNIQUE INDEX "ai_prompts_name_idx_key" ON "ai_prompts"("name", "idx");
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
-- CreateTable
CREATE TABLE "ai_sessions" (
"id" VARCHAR NOT NULL,
"user_id" VARCHAR NOT NULL,
"workspace_id" VARCHAR NOT NULL,
"doc_id" VARCHAR NOT NULL,
"prompt_name" VARCHAR NOT NULL,
"action" BOOLEAN NOT NULL,
"model" VARCHAR NOT NULL,
"messages" JSON NOT NULL,
"created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(6) NOT NULL,

CONSTRAINT "ai_sessions_pkey" PRIMARY KEY ("id")
);

-- AddForeignKey
ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "ai_sessions" ADD CONSTRAINT "ai_sessions_doc_id_workspace_id_fkey" FOREIGN KEY ("doc_id", "workspace_id") REFERENCES "snapshots"("guid", "workspace_id") ON DELETE CASCADE ON UPDATE CASCADE;
45 changes: 45 additions & 0 deletions packages/backend/server/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ model User {
pagePermissions WorkspacePageUserPermission[]
connectedAccounts ConnectedAccount[]
sessions UserSession[]
AiSession AiSession[]

@@map("users")
}
Expand Down Expand Up @@ -96,6 +97,7 @@ model Workspace {
permissions WorkspaceUserPermission[]
pagePermissions WorkspacePageUserPermission[]
features WorkspaceFeatures[]
AiSession AiSession[]

@@map("workspaces")
}
Expand Down Expand Up @@ -321,6 +323,8 @@ model Snapshot {
// but the created time of last seen update that has been merged into snapshot.
updatedAt DateTime @map("updated_at") @db.Timestamptz(6)

AiSession AiSession[]

@@id([id, workspaceId])
@@map("snapshots")
}
Expand Down Expand Up @@ -422,6 +426,47 @@ model UserInvoice {
@@map("user_invoices")
}

enum AiPromptRole {
system
assistant
user
}

model AiPrompt {
id String @id @default(uuid()) @db.VarChar
// prompt name
name String @db.VarChar(20)
// if a group of prompts contains multiple sentences, idx specifies the order of each sentence
darkskygit marked this conversation as resolved.
Show resolved Hide resolved
idx Int @db.Integer
// system/assistant/user
role AiPromptRole
// prompt content
content String @db.Text
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)

@@unique([name, idx])
@@map("ai_prompts")
}

model AiSession {
id String @id @default(uuid()) @db.VarChar
userId String @map("user_id") @db.VarChar
workspaceId String @map("workspace_id") @db.VarChar
darkskygit marked this conversation as resolved.
Show resolved Hide resolved
docId String @map("doc_id") @db.VarChar
promptName String @map("prompt_name") @db.VarChar
action Boolean @db.Boolean
model String @db.VarChar
messages Json @db.Json
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6)

user User @relation(fields: [userId], references: [id], onDelete: Cascade)
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade)

@@map("ai_sessions")
}

model DataMigration {
id String @id @default(uuid()) @db.VarChar(36)
name String @db.VarChar
Expand Down
5 changes: 3 additions & 2 deletions packages/backend/server/src/plugins/copilot/index.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { ServerFeature } from '../../core/config';
import { Plugin } from '../registry';
import { assertProvidersConfigs, CopilotProviderService } from './provider';
import { PromptService } from './prompt';
import { assertProvidersConfigs, CopilotProviderService } from './providers';

@Plugin({
name: 'copilot',
providers: [CopilotProviderService],
providers: [PromptService, CopilotProviderService],
contributesTo: ServerFeature.Copilot,
if: config => {
if (config.flavor.graphql) {
Expand Down
72 changes: 72 additions & 0 deletions packages/backend/server/src/plugins/copilot/prompt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { Injectable } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';

import { ChatMessage } from './types';

@Injectable()
export class PromptService {
constructor(private readonly db: PrismaClient) {}

/**
* list prompt names
* @returns prompt names
*/
async list() {
return this.db.aiPrompt
.findMany({ select: { name: true } })
.then(prompts => Array.from(new Set(prompts.map(p => p.name))));
}

/**
* get prompt messages by prompt name
* @param name prompt name
* @returns prompt messages
*/
async get(name: string): Promise<ChatMessage[]> {
return this.db.aiPrompt.findMany({
where: {
name,
},
select: {
role: true,
content: true,
},
orderBy: {
idx: 'asc',
},
});
}

async set(name: string, messages: ChatMessage[]) {
return this.db.$transaction(async tx => {
const prompts = await tx.aiPrompt.count({ where: { name } });
if (prompts > 0) {
return 0;
}
return tx.aiPrompt
.createMany({
data: messages.map((m, idx) => ({ name, idx, ...m })),
})
.then(ret => ret.count);
});
}

async update(name: string, messages: ChatMessage[]) {
return this.db.$transaction(async tx => {
await tx.aiPrompt.deleteMany({ where: { name } });
return tx.aiPrompt
.createMany({
data: messages.map((m, idx) => ({ name, idx, ...m })),
})
.then(ret => ret.count);
});
}

async delete(name: string) {
return this.db.aiPrompt
.deleteMany({
where: { name },
})
.then(ret => ret.count);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ import assert from 'node:assert';

import { Injectable, Logger } from '@nestjs/common';

import { Config } from '../../fundamentals';
import { Config } from '../../../fundamentals';
import {
CapabilityToCopilotProvider,
CopilotConfig,
CopilotProvider,
CopilotProviderCapability,
CopilotProviderType,
} from './types';
} from '../types';

type CopilotProviderConfig = CopilotConfig[keyof CopilotConfig];

Expand Down
18 changes: 14 additions & 4 deletions packages/backend/server/src/plugins/copilot/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { AiPromptRole } from '@prisma/client';
import type { ClientOptions as OpenAIClientOptions } from 'openai';
import { z } from 'zod';

export interface CopilotConfig {
openai: OpenAIClientOptions;
Expand All @@ -23,10 +25,18 @@ export interface CopilotProvider {
getCapabilities(): CopilotProviderCapability[];
}

export type ChatMessage = {
role: 'system' | 'assistant' | 'user';
content: string;
};
export const ChatMessageSchema = z
.object({
role: z.enum(
Array.from(Object.values(AiPromptRole)) as [
'system' | 'assistant' | 'user',
]
),
content: z.string(),
})
.strict();

export type ChatMessage = z.infer<typeof ChatMessageSchema>;

export interface CopilotTextToTextProvider extends CopilotProvider {
generateText(messages: ChatMessage[], model: string): Promise<string>;
Expand Down
70 changes: 70 additions & 0 deletions packages/backend/server/tests/copilot.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/// <reference types="../src/global.d.ts" />

import { TestingModule } from '@nestjs/testing';
import type { TestFn } from 'ava';
import ava from 'ava';

import { AuthService } from '../src/core/auth';
import { QuotaManagementService, QuotaModule } from '../src/core/quota';
import { ConfigModule } from '../src/fundamentals/config';
import { CopilotModule } from '../src/plugins/copilot';
import { PromptService } from '../src/plugins/copilot/prompt';
import { createTestingModule } from './utils';

const test = ava as TestFn<{
auth: AuthService;
quotaManager: QuotaManagementService;
module: TestingModule;
prompt: PromptService;
}>;

test.beforeEach(async t => {
const module = await createTestingModule({
imports: [
ConfigModule.forRoot({
plugins: {
copilot: {
openai: {
apiKey: '1',
},
},
},
}),
QuotaModule,
CopilotModule,
],
});

const quotaManager = module.get(QuotaManagementService);
const auth = module.get(AuthService);
const prompt = module.get(PromptService);

t.context.module = module;
t.context.quotaManager = quotaManager;
t.context.auth = auth;
t.context.prompt = prompt;
});

test.afterEach.always(async t => {
await t.context.module.close();
});

test('should be able to manage prompt', async t => {
const { prompt } = t.context;

t.is((await prompt.list()).length, 0, 'should have no prompt');

await prompt.set('test', [
{ role: 'system', content: 'hello' },
{ role: 'user', content: 'hello' },
]);
t.is((await prompt.list()).length, 1, 'should have one prompt');
t.is((await prompt.get('test')).length, 2, 'should have two messages');

await prompt.update('test', [{ role: 'system', content: 'hello' }]);
t.is((await prompt.get('test')).length, 1, 'should have one message');

await prompt.delete('test');
t.is((await prompt.list()).length, 0, 'should have no prompt');
t.is((await prompt.get('test')).length, 0, 'should have no messages');
});
Loading