Skip to content

Commit

Permalink
feat: copilot controller (#6272)
Browse files Browse the repository at this point in the history
  • Loading branch information
darkskygit committed Apr 10, 2024
1 parent e6a5765 commit 7c38a54
Show file tree
Hide file tree
Showing 18 changed files with 729 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,5 @@ ALTER TABLE "ai_sessions_messages" ADD CONSTRAINT "ai_sessions_messages_session_
-- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_doc_id_workspace_id_fkey" FOREIGN KEY ("doc_id", "workspace_id") REFERENCES "snapshots"("guid", "workspace_id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "ai_sessions_metadata" ADD CONSTRAINT "ai_sessions_metadata_prompt_name_fkey" FOREIGN KEY ("prompt_name") REFERENCES "ai_prompts_metadata"("name") ON DELETE CASCADE ON UPDATE CASCADE;
11 changes: 3 additions & 8 deletions packages/backend/server/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ model Workspace {
permissions WorkspaceUserPermission[]
pagePermissions WorkspacePageUserPermission[]
features WorkspaceFeatures[]
aiSessions AiSession[]
@@map("workspaces")
}
Expand Down Expand Up @@ -323,8 +322,6 @@ model Snapshot {
// but the created time of last seen update that has been merged into snapshot.
updatedAt DateTime @map("updated_at") @db.Timestamptz(6)
aiSessions AiSession[]
@@id([id, workspaceId])
@@map("snapshots")
}
Expand Down Expand Up @@ -485,11 +482,9 @@ model AiSession {
promptName String @map("prompt_name") @db.VarChar(32)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
doc Snapshot @relation(fields: [docId, workspaceId], references: [id, workspaceId], onDelete: Cascade)
prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade)
messages AiSessionMessage[]
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade)
messages AiSessionMessage[]
@@map("ai_sessions_metadata")
}
Expand Down
52 changes: 52 additions & 0 deletions packages/backend/server/src/core/workspaces/permission.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ export class PermissionService {
return data?.type as Permission;
}

/**
* check whether a workspace exists and has any one can access it
* @param workspaceId workspace id
* @returns
*/
async hasWorkspace(workspaceId: string) {
return await this.prisma.workspaceUserPermission
.count({
where: {
workspaceId,
accepted: true,
},
})
.then(count => count > 0);
}

async getOwnedWorkspaces(userId: string) {
return this.prisma.workspaceUserPermission
.findMany({
Expand Down Expand Up @@ -96,6 +112,23 @@ export class PermissionService {
return count !== 0;
}

/**
* only check permission if the workspace is a cloud workspace
* @param workspaceId workspace id
* @param userId user id, check if is a public workspace if not provided
* @param permission default is read
*/
async checkCloudWorkspace(
workspaceId: string,
userId?: string,
permission: Permission = Permission.Read
) {
const hasWorkspace = await this.hasWorkspace(workspaceId);
if (hasWorkspace) {
await this.checkWorkspace(workspaceId, userId, permission);
}
}

async checkWorkspace(
ws: string,
user?: string,
Expand Down Expand Up @@ -263,6 +296,25 @@ export class PermissionService {
/// End regin: workspace permission

/// Start regin: page permission
/**
* only check permission if the workspace is a cloud workspace
* @param workspaceId workspace id
* @param pageId page id aka doc id
* @param userId user id, check if is a public page if not provided
* @param permission default is read
*/
async checkCloudPagePermission(
workspaceId: string,
pageId: string,
userId?: string,
permission = Permission.Read
) {
const hasWorkspace = await this.hasWorkspace(workspaceId);
if (hasWorkspace) {
await this.checkPagePermission(workspaceId, pageId, userId, permission);
}
}

async checkPagePermission(
ws: string,
page: string,
Expand Down
151 changes: 151 additions & 0 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import {
BadRequestException,
Controller,
Get,
InternalServerErrorException,
Param,
Query,
Req,
Sse,
} from '@nestjs/common';
import {
concatMap,
connect,
EMPTY,
from,
map,
merge,
Observable,
switchMap,
toArray,
} from 'rxjs';

import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { CopilotProviderService } from './providers';
import { ChatSessionService } from './session';
import { CopilotCapability } from './types';

export interface ChatEvent {
data: string;
id?: string;
}

@Controller('/api/copilot')
export class CopilotController {
constructor(
private readonly chatSession: ChatSessionService,
private readonly provider: CopilotProviderService
) {}

@Public()
@Get('/chat/:sessionId')
async chat(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});

try {
delete params.message;
const content = await provider.generateText(
session.finish(params),
session.model,
{
signal: req.signal,
user: user.id,
}
);

session.push({
role: 'assistant',
content,
createdAt: new Date(),
});
await session.save();

return content;
} catch (e: any) {
throw new InternalServerErrorException(
e.message || "Couldn't generate text"
);
}
}

@Public()
@Sse('/chat/:sessionId/stream')
async chatStream(
@CurrentUser() user: CurrentUser,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});

delete params.message;
return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: req.signal,
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(map(data => ({ id: sessionId, data }))),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
)
);
}
}
7 changes: 7 additions & 0 deletions packages/backend/server/src/plugins/copilot/index.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import { ServerFeature } from '../../core/config';
import { QuotaService } from '../../core/quota';
import { PermissionService } from '../../core/workspaces/permission';
import { Plugin } from '../registry';
import { CopilotController } from './controller';
import { PromptService } from './prompt';
import {
assertProvidersConfigs,
CopilotProviderService,
OpenAIProvider,
registerCopilotProvider,
} from './providers';
import { CopilotResolver, UserCopilotResolver } from './resolver';
import { ChatSessionService } from './session';

registerCopilotProvider(OpenAIProvider);
Expand All @@ -16,10 +19,14 @@ registerCopilotProvider(OpenAIProvider);
name: 'copilot',
providers: [
PermissionService,
QuotaService,
ChatSessionService,
CopilotResolver,
UserCopilotResolver,
PromptService,
CopilotProviderService,
],
controllers: [CopilotController],
contributesTo: ServerFeature.Copilot,
if: config => {
if (config.flavor.graphql) {
Expand Down
8 changes: 4 additions & 4 deletions packages/backend/server/src/plugins/copilot/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ export class ChatPrompt {
) {
return new ChatPrompt(
options.name,
options.action,
options.model,
options.action || undefined,
options.model || undefined,
options.messages
);
}

constructor(
public readonly name: string,
public readonly action: string | null,
public readonly model: string | null,
public readonly action: string | undefined,
public readonly model: string | undefined,
private readonly messages: PromptMessage[]
) {
this.encoder = getTokenEncoder(model);
Expand Down
18 changes: 10 additions & 8 deletions packages/backend/server/src/plugins/copilot/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ import assert from 'node:assert';
import { ClientOptions, OpenAI } from 'openai';

import {
ChatMessage,
ChatMessageRole,
CopilotCapability,
CopilotProviderType,
CopilotTextToEmbeddingProvider,
CopilotTextToTextProvider,
PromptMessage,
} from '../types';

const DEFAULT_DIMENSIONS = 256;

export class OpenAIProvider
implements CopilotTextToTextProvider, CopilotTextToEmbeddingProvider
{
Expand Down Expand Up @@ -50,7 +52,7 @@ export class OpenAIProvider
return OpenAIProvider.capabilities;
}

private chatToGPTMessage(messages: ChatMessage[]) {
private chatToGPTMessage(messages: PromptMessage[]) {
// filter redundant fields
return messages.map(message => ({
role: message.role,
Expand All @@ -63,7 +65,7 @@ export class OpenAIProvider
embeddings,
model,
}: {
messages?: ChatMessage[];
messages?: PromptMessage[];
embeddings?: string[];
model: string;
}) {
Expand Down Expand Up @@ -106,7 +108,7 @@ export class OpenAIProvider
// ====== text to text ======

async generateText(
messages: ChatMessage[],
messages: PromptMessage[],
model: string = 'gpt-3.5-turbo',
options: {
temperature?: number;
Expand Down Expand Up @@ -134,8 +136,8 @@ export class OpenAIProvider
}

async *generateTextStream(
messages: ChatMessage[],
model: string,
messages: PromptMessage[],
model: string = 'gpt-3.5-turbo',
options: {
temperature?: number;
maxTokens?: number;
Expand Down Expand Up @@ -179,15 +181,15 @@ export class OpenAIProvider
dimensions: number;
signal?: AbortSignal;
user?: string;
} = { dimensions: 256 }
} = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model });

const result = await this.instance.embeddings.create({
model: model,
input: messages,
dimensions: options.dimensions,
dimensions: options.dimensions || DEFAULT_DIMENSIONS,
user: options.user,
});
return result.data.map(e => e.embedding);
Expand Down

0 comments on commit 7c38a54

Please sign in to comment.