diff --git a/frontend/desktop/src/components/settingfragments/advanced.tsx b/frontend/desktop/src/components/settingfragments/advanced.tsx index 6338975..e82af57 100644 --- a/frontend/desktop/src/components/settingfragments/advanced.tsx +++ b/frontend/desktop/src/components/settingfragments/advanced.tsx @@ -1,51 +1,94 @@ import React, { FunctionComponent, useEffect, useState } from 'react' import Constants from '../../constants' import eventService from '../../events/eventService' +import toast from 'react-hot-toast' const AdvancedSettings: FunctionComponent<{}> = () => { + const [openAIKey, setOpenAIKey] = useState("") + const [openAIModel, setOpenAIModel] = useState("") + const [modelOptions, setModelOptions] = useState<{ value: string }[]>([]) - const [openAIKey, setOpenAIKey] = useState("") + useEffect(() => { + (async () => { + const result = await eventService.listSupportedAIModels() + setModelOptions(result.data.map(model => ({ value: model }))) + })(); + (async () => { + const result = await eventService.getSingleSetting(Constants.SETTING_KEYS.OPENAI_KEY) + setOpenAIKey(result.data) + })(); + (async () => { + const result = await eventService.getSingleSetting(Constants.SETTING_KEYS.OPENAI_MODEL) + setOpenAIModel(result.data) + })(); + }, []) - useEffect(() => { - (async () => { - let result = await eventService.getSingleSetting(Constants.SETTING_KEYS.OPENAI_KEY) - setOpenAIKey(result.data) - })() - }, []) - const updateOpenAIKey = async () => { - const result = await eventService.updateSingleSetting(Constants.SETTING_KEYS.OPENAI_KEY, openAIKey) - if (result.success) - setOpenAIKey(openAIKey) + const updateOpenAIKey = async () => { + const result = await eventService.updateSingleSetting(Constants.SETTING_KEYS.OPENAI_KEY, openAIKey) + if (result.success) { + setOpenAIKey(openAIKey) + toast.success("saved") } + } - return ( - -

Advanced Settings

-
-

OpenAI Key

-

Update OpenAI API key to enable Generate SQL tool.

-
-
-

- ) => { setOpenAIKey(e.target.value) }} - placeholder="Enter API key" /> -

-

- - - -

-
-
-
-
- ) + const updateOpenAIModel = async () => { + const result = await eventService.updateSingleSetting(Constants.SETTING_KEYS.OPENAI_MODEL, openAIModel) + if (result.success) + setOpenAIModel(openAIModel) + toast.success("saved") + } + + const handleModelChange = (e: React.ChangeEvent) => { + const value = e.target.value + setOpenAIModel(value) + } + + return ( + +

Advanced Settings

+
+

OpenAI Key

+

Update OpenAI API key to enable Generate SQL tool.

+
+
+

+ ) => { setOpenAIKey(e.target.value) }} + placeholder="Enter API key" /> +

+

+ + + +

+
+
+

OpenAI Model

+

Update OpenAI Model to enable Generate SQL tool.

+
+

+ + + +

+

+ +

+
+
+
+ ) } export default AdvancedSettings diff --git a/frontend/desktop/src/constants.ts b/frontend/desktop/src/constants.ts index 5b903c8..ac00535 100644 --- a/frontend/desktop/src/constants.ts +++ b/frontend/desktop/src/constants.ts @@ -64,6 +64,7 @@ const Constants: ConstantsType = { TELEMETRY_ENABLED: "TELEMETRY_ENABLED", LOGS_EXPIRE: "LOGS_EXPIRE", OPENAI_KEY: "OPENAI_KEY", + OPENAI_MODEL:"OPENAI_MODEL" }, } diff --git a/frontend/desktop/src/events/constants.ts b/frontend/desktop/src/events/constants.ts index f398872..e992d8d 100644 --- a/frontend/desktop/src/events/constants.ts +++ b/frontend/desktop/src/events/constants.ts @@ -142,6 +142,10 @@ const Events: EventType = { REQUEST: "event:ai:gensql", RESPONSE: "response:ai:gensql" }, + AI_LIST_SUPPORTEDMODELS: { + REQUEST: "event:ai:listmodels", + RESPONSE: "response:ai:listmodels" + }, } export default Events \ No newline at end of file diff --git a/frontend/desktop/src/events/eventService.ts b/frontend/desktop/src/events/eventService.ts index 5a6529f..8e19d94 100644 --- a/frontend/desktop/src/events/eventService.ts +++ b/frontend/desktop/src/events/eventService.ts @@ -210,6 +210,12 @@ const runGenerateSQL = async function (dbConnId: string, text: string): Promise< return response } +const listSupportedAIModels = async function (): Promise> { + const response = responseEvent>(Events.AI_LIST_SUPPORTEDMODELS.RESPONSE) + EventsEmit(Events.AI_LIST_SUPPORTEDMODELS.REQUEST, Events.AI_LIST_SUPPORTEDMODELS.RESPONSE) + return response +} + export default { getHealthCheck, getProjects, @@ -244,5 +250,6 @@ export default { closeTab, runConsoleCommand, checkConnection, - runGenerateSQL + runGenerateSQL, + listSupportedAIModels } \ No newline at end of file diff --git a/frontend/server/src/components/settingfragments/advanced.tsx b/frontend/server/src/components/settingfragments/advanced.tsx index 1e18a87..fde106f 100644 --- a/frontend/server/src/components/settingfragments/advanced.tsx +++ b/frontend/server/src/components/settingfragments/advanced.tsx @@ -1,17 +1,27 @@ import React, { FunctionComponent, useEffect, useState } from 'react' import Constants from '../../constants' import apiService from '../../network/apiService' +import toast from 'react-hot-toast' const AdvancedSettings: FunctionComponent<{}> = () => { - const [openAIKey, setOpenAIKey] = useState("") + const [openAIModel, setOpenAIModel] = useState("") + const [modelOptions, setModelOptions] = useState<{ value: string }[]>([]) useEffect(() => { (async () => { - let result = await apiService.getSingleSetting(Constants.SETTING_KEYS.OPENAI_KEY) + const result = await apiService.listSupportedAIModels() + setModelOptions(result.data.map(model => ({ value: model }))) + })(); + (async () => { + const result = await apiService.getSingleSetting(Constants.SETTING_KEYS.OPENAI_KEY) setOpenAIKey(result.data) - })() + })(); + (async () => { + const result = await apiService.getSingleSetting(Constants.SETTING_KEYS.OPENAI_MODEL) + setOpenAIModel(result.data) + })(); }, []) const updateOpenAIKey = async () => { @@ -20,6 +30,18 @@ const AdvancedSettings: FunctionComponent<{}> = () => { setOpenAIKey(openAIKey) } + const updateOpenAIModel = async () => { + const result = await apiService.updateSingleSetting(Constants.SETTING_KEYS.OPENAI_MODEL, openAIModel) + if (result.success) + setOpenAIModel(openAIModel) + toast.success("saved") + } + + const handleModelChange = (e: React.ChangeEvent) => { + const value = e.target.value + setOpenAIModel(value) + } + return (

Advanced Settings

@@ -27,7 +49,7 @@ const AdvancedSettings: FunctionComponent<{}> = () => {

OpenAI Key

Update OpenAI API key to enable Generate SQL tool.

-
+

= () => {

+

OpenAI Model

+

Update OpenAI Model to enable Generate SQL tool.

+
+

+ + + +

+

+ +

+

) diff --git a/frontend/server/src/constants.ts b/frontend/server/src/constants.ts index 5ed18a5..a121e59 100644 --- a/frontend/server/src/constants.ts +++ b/frontend/server/src/constants.ts @@ -88,6 +88,7 @@ const Constants: ConstantsType = { TELEMETRY_ENABLED: "TELEMETRY_ENABLED", LOGS_EXPIRE: "LOGS_EXPIRE", OPENAI_KEY: "OPENAI_KEY", + OPENAI_MODEL: "OPENAI_MODEL" }, ROLES: { ADMIN: "Admin" diff --git a/frontend/server/src/network/apiService.ts b/frontend/server/src/network/apiService.ts index c435bd9..d772a7e 100644 --- a/frontend/server/src/network/apiService.ts +++ b/frontend/server/src/network/apiService.ts @@ -306,7 +306,11 @@ const generateSQL = async function (dbConnectionId: string, text: string): Promi .then(res => res.data) } - +const listSupportedAIModels = async function (): Promise> { + return await Request.apiInstance + .get>>(`/ai/listmodels`) + .then(res => res.data) +} export default { getHealthCheck, @@ -356,5 +360,6 @@ export default { addRole, deleteRole, updateRolePermission, - generateSQL + generateSQL, + listSupportedAIModels } \ No newline at end of file diff --git a/go.mod b/go.mod index 8ac19aa..327943d 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/jackc/pgx/v4 v4.17.2 github.com/joho/godotenv v1.4.0 github.com/posthog/posthog-go v0.0.0-20221221115252-24dfed35d71a - github.com/sashabaranov/go-openai v1.9.4 + github.com/sashabaranov/go-openai v1.17.2 github.com/tdewolff/parse/v2 v2.6.5 github.com/wailsapp/wails/v2 v2.4.1 github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 diff --git a/go.sum b/go.sum index 7d39b91..9f93cf9 100644 --- a/go.sum +++ b/go.sum @@ -355,8 +355,8 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/samber/lo v1.27.1 h1:sTXwkRiIFIQG+G0HeAvOEnGjqWeWtI9cg5/n51KrxPg= github.com/samber/lo v1.27.1/go.mod h1:it33p9UtPMS7z72fP4gw/EIfQB2eI8ke7GR2wc6+Rhg= -github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM= -github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.17.2 h1:Uj1Msqh43S9XhjUXYyOqOHMiRQtgQXCo5O0FeWZz7tU= +github.com/sashabaranov/go-openai v1.17.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4= github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94/go.mod h1:90zrgN3D/WJsDd1iXHT96alCoN2KJo6/4x1DZC3wZs8= diff --git a/internal/common/controllers/ai.go b/internal/common/controllers/ai.go index 52dcd66..e459445 100644 --- a/internal/common/controllers/ai.go +++ b/internal/common/controllers/ai.go @@ -32,3 +32,7 @@ func (AIController) GenerateSQL(dbConnectionID, text string) (string, error) { return ai.GenerateSQL(dbConn.Type, text, datamodels) } + +func (AIController) GetModels() []string { + return ai.ListSupportedOpenAiModels() +} diff --git a/internal/common/controllers/setting.go b/internal/common/controllers/setting.go index 5555dd7..08edfa1 100644 --- a/internal/common/controllers/setting.go +++ b/internal/common/controllers/setting.go @@ -15,9 +15,21 @@ type SettingController struct{} func (SettingController) GetSingleSetting(name string) (interface{}, error) { setting, err := dao.Setting.GetSingleSetting(name) + + switch name { + case models.SETTING_NAME_OPENAI_KEY: + return setting.Value, nil + case models.SETTING_NAME_OPENAI_MODEL: + if setting.Value == "" { + return ai.GetOpenAiModel(), nil + } + return setting.Value, nil + } + if err != nil { - return "", errors.New("there was some problem") + return "", errors.New("setting not found") } + switch setting.Name { case models.SETTING_NAME_APP_ID: return setting.UUID().String(), nil @@ -44,6 +56,11 @@ func (SettingController) UpdateSingleSetting(name string, value string) error { } case models.SETTING_NAME_OPENAI_KEY: ai.InitClient(value) + case models.SETTING_NAME_OPENAI_MODEL: + err := ai.SetOpenAiModel(value) + if err != nil { + return err + } default: return errors.New("invalid setting name: " + name) } diff --git a/internal/common/models/setting.go b/internal/common/models/setting.go index fc2aa69..055d401 100644 --- a/internal/common/models/setting.go +++ b/internal/common/models/setting.go @@ -16,6 +16,7 @@ const ( SETTING_NAME_TELEMETRY_ENABLED = "TELEMETRY_ENABLED" SETTING_NAME_LOGS_EXPIRE = "LOGS_EXPIRE" SETTING_NAME_OPENAI_KEY = "OPENAI_KEY" + SETTING_NAME_OPENAI_MODEL = "OPENAI_MODEL" ) func NewSetting(name string, value string) *Setting { diff --git a/internal/desktop/app/events.go b/internal/desktop/app/events.go index 56b885d..d52536a 100644 --- a/internal/desktop/app/events.go +++ b/internal/desktop/app/events.go @@ -62,5 +62,6 @@ func setupEvents(ctx context.Context) { } if aiEventListeners := new(events.AIEventListeners); true { aiEventListeners.GenSQLEvent(ctx) + aiEventListeners.ListSupportedAIModelsEvent(ctx) } } diff --git a/internal/desktop/events/ai.go b/internal/desktop/events/ai.go index badc1c7..a6f00d2 100644 --- a/internal/desktop/events/ai.go +++ b/internal/desktop/events/ai.go @@ -13,7 +13,8 @@ type AIEventListeners struct{} var aiController controllers.AIController const ( - eventAIGenSQL = "event:ai:gensql" + eventAIGenSQL = "event:ai:gensql" + eventAIListModels = "event:ai:listmodels" ) func (AIEventListeners) GenSQLEvent(ctx context.Context) { @@ -37,3 +38,15 @@ func (AIEventListeners) GenSQLEvent(ctx context.Context) { }) }) } + +func (AIEventListeners) ListSupportedAIModelsEvent(ctx context.Context) { + runtime.EventsOn(ctx, eventAIListModels, func(args ...interface{}) { + responseEventName := args[0].(string) + defer recovery(ctx, responseEventName) + output := aiController.GetModels() + runtime.EventsEmit(ctx, responseEventName, map[string]interface{}{ + "success": true, + "data": output, + }) + }) +} diff --git a/internal/desktop/events/tabs.go b/internal/desktop/events/tabs.go index ba00f18..34fdd5e 100644 --- a/internal/desktop/events/tabs.go +++ b/internal/desktop/events/tabs.go @@ -28,7 +28,10 @@ func (TabsEventListeners) CreateNewTab(ctx context.Context) { modelschema := args[3].(string) modelname := args[4].(string) queryID := args[5].(string) - query := args[6].(string) + query := "" + if str, ok := args[6].(string); ok { + query = str + } tab, err := tabController.CreateTab(dbConnectionId, tabType, modelschema, modelname, queryID, query) if err != nil { runtime.EventsEmit(ctx, responseEventName, map[string]interface{}{ diff --git a/internal/desktop/setup/setup.go b/internal/desktop/setup/setup.go index 2863ba4..81d2942 100644 --- a/internal/desktop/setup/setup.go +++ b/internal/desktop/setup/setup.go @@ -41,5 +41,7 @@ func initAIClient() { setting, err := dao.Setting.GetSingleSetting(models.SETTING_NAME_OPENAI_KEY) if err == nil { ai.InitClient(setting.Value) + setting, _ := dao.Setting.GetSingleSetting(models.SETTING_NAME_OPENAI_MODEL) + ai.SetOpenAiModel(setting.Value) } } diff --git a/internal/server/app/router.go b/internal/server/app/router.go index 1ca13d8..ec4338d 100644 --- a/internal/server/app/router.go +++ b/internal/server/app/router.go @@ -14,6 +14,7 @@ import ( func SetupRoutes(app *fiber.App, assets embed.FS) { api := app.Group("/api/v1") { + api.Use(middlewares.APIResponseMiddleware()) api.Get("health", healthCheck) userGroup := api.Group("user") { @@ -121,6 +122,7 @@ func SetupRoutes(app *fiber.App, assets embed.FS) { aiGroup.Use(middlewares.FindUserMiddleware()) aiGroup.Use(middlewares.AuthUserMiddleware()) aiGroup.Post("/gensql", aiHandlers.GenerateSQL) + aiGroup.Get("/listmodels", aiHandlers.ListSupportedAIModels) } } diff --git a/internal/server/handlers/ai.go b/internal/server/handlers/ai.go index 7f78771..fd1cf55 100644 --- a/internal/server/handlers/ai.go +++ b/internal/server/handlers/ai.go @@ -16,10 +16,7 @@ func (AIHandlers) GenerateSQL(c *fiber.Ctx) error { Text string `json:"text"` } if err := c.BodyParser(&body); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } analytics.SendAISQLGeneratedEvent() output, err := aiController.GenerateSQL(body.DBConnectionID, body.Text) @@ -34,3 +31,11 @@ func (AIHandlers) GenerateSQL(c *fiber.Ctx) error { "data": output, }) } + +func (AIHandlers) ListSupportedAIModels(c *fiber.Ctx) error { + output := aiController.GetModels() + return c.JSON(map[string]interface{}{ + "success": true, + "data": output, + }) +} diff --git a/internal/server/handlers/console.go b/internal/server/handlers/console.go index d70768a..49fe2b6 100644 --- a/internal/server/handlers/console.go +++ b/internal/server/handlers/console.go @@ -18,10 +18,7 @@ func (ConsoleHandlers) RunCommand(c *fiber.Ctx) error { CmdString string `json:"cmd"` } if err := c.BodyParser(&body); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } analytics.SendRunCommandEvent() output := consoleController.RunCommand(authUser, body.DBConnectionID, body.CmdString) diff --git a/internal/server/handlers/dbconnection.go b/internal/server/handlers/dbconnection.go index 9381dd0..3240919 100644 --- a/internal/server/handlers/dbconnection.go +++ b/internal/server/handlers/dbconnection.go @@ -31,10 +31,7 @@ func (DBConnectionHandlers) CreateDBConnection(c *fiber.Ctx) error { IsTest bool `json:"isTest"` } if err := c.BodyParser(&createBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } dbConn, err := dbConnController.CreateDBConnection( createBody.ProjectID, diff --git a/internal/server/handlers/project.go b/internal/server/handlers/project.go index 9f79868..117df18 100644 --- a/internal/server/handlers/project.go +++ b/internal/server/handlers/project.go @@ -18,10 +18,7 @@ func (ProjectHandlers) CreateProject(c *fiber.Ctx) error { Name string `json:"name"` } if err := c.BodyParser(&createBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } project, projectMember, err := projectController.CreateProject(authUser, createBody.Name) if err != nil { @@ -104,10 +101,7 @@ func (ProjectHandlers) AddProjectMember(c *fiber.Ctx) error { RoleID string `json:"roleId"` } if err := c.BodyParser(&addMemberBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } newProjectMember, err := projectController.AddProjectMember(authUser, projectID, addMemberBody.Email, addMemberBody.RoleID) if err != nil { diff --git a/internal/server/handlers/query.go b/internal/server/handlers/query.go index 5154245..136f67c 100644 --- a/internal/server/handlers/query.go +++ b/internal/server/handlers/query.go @@ -24,10 +24,7 @@ func (QueryHandlers) RunQuery(c *fiber.Ctx) error { Query string `json:"query"` } if err := c.BodyParser(&runBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } analytics.SendRunQueryEvent() response, err := queryController.RunQuery(authUser, runBody.DBConnectionID, runBody.Query) @@ -128,10 +125,7 @@ func (QueryHandlers) AddSingleDataModelField(c *fiber.Ctx) error { DataType string `json:"dataType"` } if err := c.BodyParser(&reqBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } responseData, err := queryController.AddSingleDataModelField(authUser, authUserProjectIds, reqBody.DBConnectionID, reqBody.Schema, reqBody.Name, reqBody.FieldName, reqBody.DataType) if err != nil { @@ -156,10 +150,7 @@ func (QueryHandlers) DeleteSingleDataModelField(c *fiber.Ctx) error { FieldName string `json:"fieldName"` } if err := c.BodyParser(&reqBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } responseData, err := queryController.DeleteSingleDataModelField(authUser, authUserProjectIds, reqBody.DBConnectionID, reqBody.Schema, reqBody.Name, reqBody.FieldName) if err != nil { @@ -183,10 +174,7 @@ func (QueryHandlers) AddData(c *fiber.Ctx) error { Data map[string]interface{} `json:"data"` } if err := c.BodyParser(&addBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } responseData, err := queryController.AddData(authUser, dbConnId, addBody.Schema, addBody.Name, addBody.Data) if err != nil { @@ -210,10 +198,7 @@ func (QueryHandlers) DeleteData(c *fiber.Ctx) error { IDs []string `json:"ids"` // ctid for postgres, _id for mongo } if err := c.BodyParser(&deleteBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } responseData, err := queryController.DeleteData(authUser, dbConnId, deleteBody.Schema, deleteBody.Name, deleteBody.IDs) if err != nil { @@ -239,10 +224,7 @@ func (QueryHandlers) UpdateSingleData(c *fiber.Ctx) error { Value string `json:"value"` } if err := c.BodyParser(&updateBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } responseData, err := queryController.UpdateSingleData(authUser, dbConnId, updateBody.Schema, updateBody.Name, updateBody.ID, updateBody.ColumnName, updateBody.Value) if err != nil { @@ -268,10 +250,7 @@ func (QueryHandlers) AddSingleDataModelIndex(c *fiber.Ctx) error { IsUnique bool `json:"isUnique"` } if err := c.BodyParser(&reqBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } responseData, err := queryController.AddSingleDataModelIndex(authUser, reqBody.DBConnectionID, reqBody.Schema, reqBody.Name, reqBody.IndexName, reqBody.FieldNames, reqBody.IsUnique) if err != nil { @@ -295,10 +274,7 @@ func (QueryHandlers) DeleteSingleDataModelIndex(c *fiber.Ctx) error { IndexName string `json:"indexName"` } if err := c.BodyParser(&reqBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } responseData, err := queryController.DeleteSingleDataModelIndex(authUser, reqBody.DBConnectionID, reqBody.Schema, reqBody.Name, reqBody.IndexName) if err != nil { @@ -323,10 +299,7 @@ func (QueryHandlers) SaveDBQuery(c *fiber.Ctx) error { QueryID string `json:"queryId"` } if err := c.BodyParser(&createBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } analytics.SendSavedQueryEvent() queryObj, err := queryController.SaveDBQuery(authUser, authUserProjectIds, dbConnId, createBody.Name, createBody.Query, createBody.QueryID) diff --git a/internal/server/handlers/role.go b/internal/server/handlers/role.go index 0260302..13cbd13 100644 --- a/internal/server/handlers/role.go +++ b/internal/server/handlers/role.go @@ -49,10 +49,7 @@ func (RoleHandlers) AddRole(c *fiber.Ctx) error { Name string `json:"name"` } if err := c.BodyParser(&reqBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } authUser := middlewares.GetAuthUser(c) role, err := roleController.AddRole(authUser, reqBody.Name) @@ -90,10 +87,7 @@ func (RoleHandlers) UpdateRolePermission(c *fiber.Ctx) error { Value bool `json:"value"` } if err := c.BodyParser(&reqBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } authUser := middlewares.GetAuthUser(c) rp, err := roleController.AddOrUpdateRolePermission(authUser, roleID, reqBody.Name, reqBody.Value) diff --git a/internal/server/handlers/setting.go b/internal/server/handlers/setting.go index 9d16b5d..c795182 100644 --- a/internal/server/handlers/setting.go +++ b/internal/server/handlers/setting.go @@ -30,10 +30,7 @@ func (SettingHandlers) UpdateSingleSetting(c *fiber.Ctx) error { Value string `json:"value"` } if err := c.BodyParser(&reqBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } err := settingController.UpdateSingleSetting(reqBody.Name, reqBody.Value) if err != nil { diff --git a/internal/server/handlers/tabs.go b/internal/server/handlers/tabs.go index 00ba158..52a6842 100644 --- a/internal/server/handlers/tabs.go +++ b/internal/server/handlers/tabs.go @@ -22,10 +22,7 @@ func (TabsHandlers) CreateNewTab(c *fiber.Ctx) error { Query string `json:"query"` } if err := c.BodyParser(&createBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } tab, err := tabController.CreateTab(authUser.ID, createBody.DBConnectionId, createBody.TabType, createBody.Modelschema, createBody.Modelname, createBody.QueryID, createBody.Query) if err != nil { @@ -69,10 +66,7 @@ func (TabsHandlers) UpdateTab(c *fiber.Ctx) error { Metadata map[string]interface{} `json:"metadata"` } if err := c.BodyParser(&updateBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } tab, err := tabController.UpdateTab(authUser.ID, updateBody.DBConnectionID, updateBody.TabID, updateBody.TabType, updateBody.Metadata) if err != nil { diff --git a/internal/server/handlers/user.go b/internal/server/handlers/user.go index d3188d7..13190c2 100644 --- a/internal/server/handlers/user.go +++ b/internal/server/handlers/user.go @@ -21,10 +21,7 @@ func (UserHandlers) LoginUser(c *fiber.Ctx) error { Password string `json:"password"` } if err := c.BodyParser(&loginBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } userSession, err := userController.LoginUser(loginBody.Email, loginBody.Password) @@ -70,10 +67,7 @@ func (UserHandlers) EditAccount(c *fiber.Ctx) error { ProfileImageURL string `json:"profileImageUrl"` } if err := c.BodyParser(&userBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } err := userController.EditAccount(authUser, userBody.Name, userBody.ProfileImageURL) if err != nil { @@ -95,10 +89,7 @@ func (UserHandlers) ChangePassword(c *fiber.Ctx) error { NewPassword string `json:"newPassword"` } if err := c.BodyParser(&body); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } err := userController.ChangePassword(authUser, body.OldPassword, body.NewPassword) if err != nil { @@ -153,10 +144,7 @@ func (UserHandlers) AddUsers(c *fiber.Ctx) error { Password string `json:"password"` } if err := c.BodyParser(&addUserBody); err != nil { - return c.JSON(map[string]interface{}{ - "success": false, - "error": err.Error(), - }) + return fiber.ErrBadRequest } err := userController.AddUser(authUser, addUserBody.Email, addUserBody.Password) if err != nil { diff --git a/internal/server/middlewares/middlewares.go b/internal/server/middlewares/middlewares.go index 5f4470a..b643cb0 100644 --- a/internal/server/middlewares/middlewares.go +++ b/internal/server/middlewares/middlewares.go @@ -1,6 +1,8 @@ package middlewares import ( + "encoding/json" + "errors" "strings" "github.com/gofiber/fiber/v2" @@ -13,7 +15,42 @@ const ( USER_SESSION = "USER_SESSION" ) -// FindUserMiddleware is find authenticated user before sending the request to next handler +func APIResponseMiddleware() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + err := c.Next() + if err != nil { + code := fiber.StatusInternalServerError + var e *fiber.Error + if errors.As(err, &e) { + code = e.Code + } + return c.Status(code).JSON(map[string]interface{}{ + "success": false, + "error": err.Error(), + }) + } + if err == nil { + response := c.Response() + body := response.Body() + var data interface{} + json.Unmarshal(body, &data) + + if resMap, ok := data.(map[string]interface{}); ok { + if _, ok := resMap["success"]; ok { + return c.JSON(resMap) + } + } + + return c.JSON(map[string]interface{}{ + "success": true, + "data": data, + }) + } + return nil + } +} + +// FindUserMiddleware is to find authenticated user before sending the request to next handler func FindUserMiddleware() func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { tokenString := c.Cookies(config.SESSION_COOKIE_NAME, "") @@ -35,16 +72,13 @@ func FindUserMiddleware() func(c *fiber.Ctx) error { } } -// AuthUserMiddleware is checks if authUser is present else returns unauthorized error +// AuthUserMiddleware is to check if authUser is present else returns unauthorized error func AuthUserMiddleware() func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { if value := c.Context().UserValue(USER_SESSION); value != nil { return c.Next() } - return c.JSON(map[string]interface{}{ - "success": false, - "error": "Unauthorized", - }) + return fiber.ErrUnauthorized } } diff --git a/pkg/ai/client.go b/pkg/ai/client.go index cc5d37d..f568799 100644 --- a/pkg/ai/client.go +++ b/pkg/ai/client.go @@ -1,10 +1,26 @@ package ai import ( + "errors" + openai "github.com/sashabaranov/go-openai" + "github.com/slashbaseide/slashbase/internal/common/utils" ) -var client *openai.Client +var ( + client *openai.Client + openAiModel string = openai.GPT3Dot5Turbo + supportedOpenAIModels = []string{ + openai.GPT3Dot5TurboInstruct, + openai.GPT3Dot5Turbo16K, + openai.GPT3Dot5Turbo, + openai.GPT3Dot5Turbo1106, + openai.GPT4, + openai.GPT432K, + openai.GPT40613, + openai.GPT432K0613, + } +) func InitClient(token string) { if token == "" { @@ -13,3 +29,22 @@ func InitClient(token string) { } client = openai.NewClient(token) } + +func GetOpenAiModel() string { + return openAiModel +} + +func ListSupportedOpenAiModels() []string { + return supportedOpenAIModels +} + +func SetOpenAiModel(model string) error { + if model == "" { + return errors.New("cannot be empty") + } + if utils.ContainsString(supportedOpenAIModels, model) { + openAiModel = model + return nil + } + return errors.New("invalid model") +} diff --git a/pkg/ai/gensql.go b/pkg/ai/gensql.go index 8ee19ac..2101e4d 100644 --- a/pkg/ai/gensql.go +++ b/pkg/ai/gensql.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "regexp" "strings" openai "github.com/sashabaranov/go-openai" @@ -21,24 +22,48 @@ func GenerateSQL(dbtype, text string, datamodels []*qemodels.DBDataModel) (strin } dbDataModelDescription := generateDBDataModelsDescription(datamodels) - prompt := fmt.Sprintf("### %s SQL tables, with their properties:\n#\n#%s\n#\n### A query to %s:\n\n", dbtype, dbDataModelDescription, text) - - req := openai.CompletionRequest{ - Model: openai.GPT3TextDavinci003, - Temperature: 0, - MaxTokens: 150, - TopP: 1, - FrequencyPenalty: 0, - PresencePenalty: 0, - Stop: []string{"#", ";"}, - Prompt: prompt, + systemMessage := "No text, just write SQL query with ```sql." + prompt := fmt.Sprintf("%s SQL tables, with their properties:\n%s\n\nWrite a query to %s", dbtype, dbDataModelDescription, text) + + req := openai.ChatCompletionRequest{ + Model: openAiModel, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: systemMessage, + }, + { + Role: openai.ChatMessageRoleUser, + Content: prompt, + }, + }, } - resp, err := client.CreateCompletion(context.Background(), req) + resp, err := client.CreateChatCompletion(context.Background(), req) if err != nil { return "", fmt.Errorf("completion error: %v", err) } - return resp.Choices[0].Text, nil + if len(resp.Choices) == 0 { + return "", errors.New("empty response") + } + + messageContent := resp.Choices[0].Message.Content + + if strings.Contains(messageContent, "```sql") { + re, _ := regexp.Compile("```sql[\\s\\S]([\\s\\S]*)[\\s\\S]```") + submatches := re.FindStringSubmatch(messageContent) + if len(submatches) == 2 { + messageContent = submatches[1] + } + } else if strings.Contains(messageContent, "```") { + re, _ := regexp.Compile("```[\\s\\S]([\\s\\S]*)[\\s\\S]```") + submatches := re.FindStringSubmatch(messageContent) + if len(submatches) == 2 { + messageContent = submatches[1] + } + } + + return messageContent, nil } func generateDBDataModelsDescription(datamodels []*qemodels.DBDataModel) string { @@ -52,7 +77,7 @@ func generateDBDataModelsDescription(datamodels []*qemodels.DBDataModel) string } fields = append(fields, fname) } - desc += fmt.Sprintf("# %s (%s) \n", dm.Name, strings.Join(fields, ", ")) + desc += fmt.Sprintf("- %s (%s) \n", dm.Name, strings.Join(fields, ", ")) } return desc } diff --git a/wails.json b/wails.json index ec64dad..365ec87 100644 --- a/wails.json +++ b/wails.json @@ -13,7 +13,7 @@ }, "info": { "productName": "Slashbase", - "productVersion": "v0.10.1", + "productVersion": "v0.10.2", "copyright": "Copyright © Slashbase.com", "comments": "Open-source Modern database IDE" }