Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 1 addition & 23 deletions redisinsight/api/config/features-config.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"version": 2.4602,
"version": 2.4603,
"features": {
"insightsRecommendations": {
"flag": true,
Expand Down Expand Up @@ -47,28 +47,6 @@
}
}
},
"documentationChat": {
"flag": true,
"perc": [[0,100]],
"filters": [
{
"name": "config.server.buildType",
"value": "ELECTRON",
"cond": "eq"
}
]
},
"databaseChat": {
"flag": true,
"perc": [[0,100]],
"filters": [
{
"name": "config.server.buildType",
"value": "ELECTRON",
"cond": "eq"
}
]
},
"cloudSsoRecommendedSettings": {
"flag": true,
"perc": [[0, 100]],
Expand Down
2 changes: 1 addition & 1 deletion redisinsight/api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"reflect-metadata": "^0.1.13",
"rxjs": "^7.5.6",
"socket.io": "^4.6.2",
"socket.io-client": "^4.7.5",
"source-map-support": "^0.5.19",
"sqlite3": "5.1.6",
"swagger-ui-express": "^4.1.4",
Expand Down Expand Up @@ -125,7 +126,6 @@
"nyc": "^15.1.0",
"object-diff": "^0.0.4",
"rimraf": "^3.0.2",
"socket.io-client": "^4.4.1",
"socket.io-mock": "^1.3.2",
"supertest": "^4.0.2",
"ts-jest": "^26.1.0",
Expand Down
2 changes: 2 additions & 0 deletions redisinsight/api/src/__mocks__/cloud-capi-key.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,6 @@ export const mockCloudCapiKeyService = jest.fn(() => ({
export const mockCloudCapiKeyAnalytics = jest.fn(() => ({
sendCloudAccountKeyGenerated: jest.fn(),
sendCloudAccountKeyGenerationFailed: jest.fn(),
sendCloudAccountSecretGenerated: jest.fn(),
sendCloudAccountSecretGenerationFailed: jest.fn(),
}));
29 changes: 28 additions & 1 deletion redisinsight/api/src/__mocks__/cloud-common.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,46 @@
import { HttpStatus } from '@nestjs/common';
import ERROR_MESSAGES from 'src/constants/error-messages';
import { CustomErrorCodes } from 'src/constants';

export const mockCapiUnauthorizedError = {
message: 'Request failed with status code 401',
response: {
status: 401,
},
};

export const mockApiInternalServerError = {
export const mockSmApiUnauthorizedError = mockCapiUnauthorizedError;

export const mockSmApiInternalServerError = {
message: 'Something wrong',
response: {
status: 500,
},
};

export const mockSmApiBadRequestError = {
message: 'Bad Request',
response: {
status: 400,
},
};

export const mockUtm = {
source: 'redisinsight',
medium: 'sso',
campaign: 'workbench',
};

export const mockCloudApiUnauthorizedExceptionResponse = {
error: 'CloudApiUnauthorized',
errorCode: CustomErrorCodes.CloudApiUnauthorized,
message: ERROR_MESSAGES.UNAUTHORIZED,
statusCode: HttpStatus.UNAUTHORIZED,
};

export const mockCloudApiBadRequestExceptionResponse = {
error: 'CloudApiBadRequest',
errorCode: CustomErrorCodes.CloudApiBadRequest,
message: ERROR_MESSAGES.BAD_REQUEST,
statusCode: HttpStatus.BAD_REQUEST,
};
1 change: 1 addition & 0 deletions redisinsight/api/src/__mocks__/cloud-user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ export const mockCloudUserRepository = jest.fn(() => ({
export const mockCloudUserApiService = jest.fn(() => ({
getCapiKeys: jest.fn().mockResolvedValue(mockCloudCapiAuthDto),
me: jest.fn().mockResolvedValue(mockCloudUser),
getCloudUser: jest.fn().mockResolvedValue(mockCloudUser),
setCurrentAccount: jest.fn(),
updateUser: jest.fn(),
}));
2 changes: 2 additions & 0 deletions redisinsight/api/src/constants/telemetry-events.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ export enum TelemetryEvents {
// Event for cloud CAPI keys
CloudAccountKeyGenerated = 'CLOUD_ACCOUNT_KEY_GENERATED',
CloudAccountKeyGenerationFailed = 'CLOUD_ACCOUNT_KEY_GENERATION_FAILED',
CloudAccountSecretGenerated = 'CLOUD_ACCOUNT_SECRET_GENERATED',
CloudAccountSecretGenerationFailed = 'CLOUD_ACCOUNT_SECRET_GENERATION_FAILED',

// Events for cli tool
CliClientCreated = 'CLI_CLIENT_CREATED',
Expand Down
214 changes: 110 additions & 104 deletions redisinsight/api/src/modules/ai/query/ai-query.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,133 +83,139 @@ export class AiQueryService {
dto: SendAiQueryMessageDto,
res: Response,
) {
let socket: Socket;
return this.aiQueryAuthProvider.callWithAuthRetry(sessionMetadata, async () => {
let socket: Socket;

try {
const auth = await this.aiQueryAuthProvider.getAuthData(sessionMetadata);
const history = await this.aiQueryMessageRepository.list(sessionMetadata, databaseId, auth.accountId);
try {
const auth = await this.aiQueryAuthProvider.getAuthData(sessionMetadata);
const history = await this.aiQueryMessageRepository.list(sessionMetadata, databaseId, auth.accountId);

const client = await this.databaseClientFactory.getOrCreateClient({
sessionMetadata,
databaseId,
context: ClientContext.AI,
});

let context = await this.aiQueryContextRepository.getFullDbContext(sessionMetadata, databaseId, auth.accountId);

if (!context) {
context = await this.aiQueryContextRepository.setFullDbContext(
const client = await this.databaseClientFactory.getOrCreateClient({
sessionMetadata,
databaseId,
auth.accountId,
await getFullDbContext(client),
);
}
context: ClientContext.AI,
});

const question = classToClass(AiQueryMessage, {
type: AiQueryMessageType.HumanMessage,
content: dto.content,
databaseId,
accountId: auth.accountId,
createdAt: new Date(),
});

const answer = classToClass(AiQueryMessage, {
type: AiQueryMessageType.AiMessage,
content: '',
databaseId,
accountId: auth.accountId,
});

socket = await this.aiQueryProvider.getSocket(auth);

socket.on(AiQueryWsEvents.REPLY_CHUNK, (chunk) => {
answer.content += chunk;
res.write(chunk);
});

socket.on(AiQueryWsEvents.GET_INDEX, async (index, cb) => {
try {
const indexContext = await this.aiQueryContextRepository.getIndexContext(
let context = await this.aiQueryContextRepository.getFullDbContext(sessionMetadata, databaseId, auth.accountId);

if (!context) {
context = await this.aiQueryContextRepository.setFullDbContext(
sessionMetadata,
databaseId,
auth.accountId,
index,
await getFullDbContext(client),
);
}

const question = classToClass(AiQueryMessage, {
type: AiQueryMessageType.HumanMessage,
content: dto.content,
databaseId,
accountId: auth.accountId,
createdAt: new Date(),
});

const answer = classToClass(AiQueryMessage, {
type: AiQueryMessageType.AiMessage,
content: '',
databaseId,
accountId: auth.accountId,
});

socket = await this.aiQueryProvider.getSocket(sessionMetadata, auth);

socket.on(AiQueryWsEvents.REPLY_CHUNK, (chunk) => {
answer.content += chunk;
res.write(chunk);
});

if (!context) {
return cb(await this.aiQueryContextRepository.setIndexContext(
socket.on(AiQueryWsEvents.GET_INDEX, async (index, cb) => {
try {
const indexContext = await this.aiQueryContextRepository.getIndexContext(
sessionMetadata,
databaseId,
auth.accountId,
index,
await getIndexContext(client, index),
));
);

if (!indexContext) {
return cb(await this.aiQueryContextRepository.setIndexContext(
sessionMetadata,
databaseId,
auth.accountId,
index,
await getIndexContext(client, index),
));
}

return cb(indexContext);
} catch (e) {
this.logger.warn('Unable to create index content', e);
return cb(e.message);
}

return cb(indexContext);
} catch (e) {
this.logger.warn('Unable to create index content', e);
return cb(e.message);
}
});

socket.on(AiQueryWsEvents.RUN_QUERY, async (data, cb) => {
try {
if (!COMMANDS_WHITELIST[(data?.[0] || '').toLowerCase()]) {
return cb('-ERR: This command is not allowed');
});

socket.on(AiQueryWsEvents.RUN_QUERY, async (data, cb) => {
try {
if (!COMMANDS_WHITELIST[(data?.[0] || '').toLowerCase()]) {
return cb('-ERR: This command is not allowed');
}

return cb(await client.sendCommand(data, { replyEncoding: 'utf8' }));
} catch (e) {
this.logger.warn('Query execution error', e);
return cb(e.message);
}

return cb(await client.sendCommand(data, { replyEncoding: 'utf8' }));
} catch (e) {
this.logger.warn('Query execution error', e);
return cb(e.message);
}
});

socket.on(AiQueryWsEvents.TOOL_CALL, async (data) => {
answer.steps.push(plainToClass(AiQueryIntermediateStep, {
type: AiQueryIntermediateStepType.TOOL_CALL,
data,
}));
});

socket.on(AiQueryWsEvents.TOOL_REPLY, async (data) => {
answer.steps.push(plainToClass(AiQueryIntermediateStep, {
type: AiQueryIntermediateStepType.TOOL,
data,
}));
});

await socket.emitWithAck('stream', dto.content, context, AiQueryService.prepareHistory(history));
socket.close();
await this.aiQueryMessageRepository.createMany(sessionMetadata, [question, answer]);

return res.end();
} catch (e) {
socket?.close?.();
throw wrapAiQueryError(e, 'Unable to send the question');
}
});

socket.on(AiQueryWsEvents.TOOL_CALL, async (data) => {
answer.steps.push(plainToClass(AiQueryIntermediateStep, {
type: AiQueryIntermediateStepType.TOOL_CALL,
data,
}));
});

socket.on(AiQueryWsEvents.TOOL_REPLY, async (data) => {
answer.steps.push(plainToClass(AiQueryIntermediateStep, {
type: AiQueryIntermediateStepType.TOOL,
data,
}));
});

await socket.emitWithAck('stream', dto.content, context, AiQueryService.prepareHistory(history));
socket.close();
await this.aiQueryMessageRepository.createMany(sessionMetadata, [question, answer]);

return res.end();
} catch (e) {
socket?.close?.();
throw wrapAiQueryError(e, 'Unable to send the question');
}
});
}

async getHistory(sessionMetadata: SessionMetadata, databaseId: string): Promise<AiQueryMessage[]> {
try {
const auth = await this.aiQueryAuthProvider.getAuthData(sessionMetadata);
return await this.aiQueryMessageRepository.list(sessionMetadata, databaseId, auth.accountId);
} catch (e) {
throw wrapAiQueryError(e, 'Unable to get history');
}
return this.aiQueryAuthProvider.callWithAuthRetry(sessionMetadata, async () => {
try {
const auth = await this.aiQueryAuthProvider.getAuthData(sessionMetadata);
return await this.aiQueryMessageRepository.list(sessionMetadata, databaseId, auth.accountId);
} catch (e) {
throw wrapAiQueryError(e, 'Unable to get history');
}
});
}

async clearHistory(sessionMetadata: SessionMetadata, databaseId: string): Promise<void> {
try {
const auth = await this.aiQueryAuthProvider.getAuthData(sessionMetadata);
return this.aiQueryAuthProvider.callWithAuthRetry(sessionMetadata, async () => {
try {
const auth = await this.aiQueryAuthProvider.getAuthData(sessionMetadata);

await this.aiQueryContextRepository.reset(sessionMetadata, databaseId, auth.accountId);
await this.aiQueryContextRepository.reset(sessionMetadata, databaseId, auth.accountId);

return this.aiQueryMessageRepository.clearHistory(sessionMetadata, databaseId, auth.accountId);
} catch (e) {
throw wrapAiQueryError(e, 'Unable to clear history');
}
return this.aiQueryMessageRepository.clearHistory(sessionMetadata, databaseId, auth.accountId);
} catch (e) {
throw wrapAiQueryError(e, 'Unable to clear history');
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export class AiQueryBadRequestException extends HttpException {
message,
statusCode: HttpStatus.BAD_REQUEST,
error: 'AiQueryBadRequest',
errorCode: CustomErrorCodes.QueryAiInternalServerError,
errorCode: CustomErrorCodes.QueryAiBadRequest,
};

super(response, response.statusCode, options);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { AxiosError } from 'axios';
import { get } from 'lodash';
import { HttpException } from '@nestjs/common';
import {
AiQueryUnauthorizedException,
Expand All @@ -13,11 +14,12 @@ export const wrapAiQueryError = (error: AxiosError, message?: string): HttpExcep
return error;
}

const { response } = error;
// TransportError or Axios error
const response = get(error, ['description', 'target', '_req', 'res'], error.response);

if (response) {
const errorOptions = { cause: new Error(response?.data as string) };
switch (response?.status) {
switch (response?.status || response?.statusCode) {
case 401:
return new AiQueryUnauthorizedException(message, errorOptions);
case 403:
Expand Down
Loading