From 33ff9e1b51a11169ae0e4d5dd08eb04ee6f33173 Mon Sep 17 00:00:00 2001 From: aman Date: Fri, 17 Apr 2026 10:08:25 +0530 Subject: [PATCH] fix: use parameterized queries in RQL utility functions --- .../store/postgres/audit_record_repository.go | 2 +- internal/store/postgres/org_billing_repository.go | 12 ++++++------ .../store/postgres/org_billing_repository_test.go | 4 ++-- .../store/postgres/org_projects_repository.go | 6 +++--- .../postgres/org_projects_repository_test.go | 4 ++-- internal/store/postgres/org_users_repository.go | 4 ++-- .../store/postgres/org_users_repository_test.go | 2 +- internal/store/postgres/prospect_repository.go | 2 +- internal/store/postgres/userpat_repository.go | 2 +- pkg/utils/rql.go | 15 ++++++++------- 10 files changed, 27 insertions(+), 26 deletions(-) diff --git a/internal/store/postgres/audit_record_repository.go b/internal/store/postgres/audit_record_repository.go index 9a0a7801a..b8bfbfbde 100644 --- a/internal/store/postgres/audit_record_repository.go +++ b/internal/store/postgres/audit_record_repository.go @@ -342,7 +342,7 @@ func (r AuditRecordRepository) buildFilteredQuery(rqlQuery *rql.Query) (*goqu.Se rqlQuery = utils.NewRQLQuery("", utils.DefaultOffset, utils.DefaultLimit, []rql.Filter{}, []rql.Sort{}, []string{}) } - baseStmt := dialect.From(TABLE_AUDITRECORDS).Where(goqu.Ex{"deleted_at": nil}) + baseStmt := dialect.From(TABLE_AUDITRECORDS).Prepared(true).Where(goqu.Ex{"deleted_at": nil}) // Apply filters baseStmt, err := utils.AddRQLFiltersInQuery(baseStmt, rqlQuery, auditRecordRQLFilterSupportedColumns, auditrecord.AuditRecordRQLSchema{}) diff --git a/internal/store/postgres/org_billing_repository.go b/internal/store/postgres/org_billing_repository.go index 1a9d85598..b3a6bc7d3 100644 --- a/internal/store/postgres/org_billing_repository.go +++ b/internal/store/postgres/org_billing_repository.go @@ -407,9 +407,9 @@ func addRQLSortInQuery(query *goqu.SelectDataset, rql *rql.Query) (*goqu.SelectD func processStringDataType(filter rql.Filter, query *goqu.SelectDataset) *goqu.SelectDataset { switch filter.Operator { case OPERATOR_EMPTY: - query = query.Where(goqu.L(fmt.Sprintf("coalesce(%s, '') = ''", filter.Name))) + query = query.Where(goqu.Or(goqu.I(filter.Name).IsNull(), goqu.I(filter.Name).Eq(""))) case OPERATOR_NOT_EMPTY: - query = query.Where(goqu.L(fmt.Sprintf("coalesce(%s, '') != ''", filter.Name))) + query = query.Where(goqu.And(goqu.I(filter.Name).IsNotNull(), goqu.I(filter.Name).Neq(""))) case OPERATOR_IN: // process the values of in operator as comma separated list query = query.Where(goqu.Cast(goqu.I(filter.Name), "TEXT").In(strings.Split(filter.Value.(string), ","))) @@ -418,14 +418,14 @@ func processStringDataType(filter rql.Filter, query *goqu.SelectDataset) *goqu.S query = query.Where(goqu.Cast(goqu.I(filter.Name), "TEXT").NotIn(strings.Split(filter.Value.(string), ","))) case OPERATOR_LIKE: // some semi string sql types like UUID require casting to text to support like operator - query = query.Where(goqu.L(fmt.Sprintf(`"%s"::TEXT LIKE '%s'`, filter.Name, filter.Value.(string)))) + query = query.Where(goqu.Cast(goqu.I(filter.Name), "TEXT").Like(filter.Value.(string))) case OPERATOR_NOT_LIKE: // some semi string sql types like UUID require casting to text to support like operator - query = query.Where(goqu.L(fmt.Sprintf(`"%s"::TEXT NOT LIKE '%s'`, filter.Name, filter.Value.(string)))) + query = query.Where(goqu.Cast(goqu.I(filter.Name), "TEXT").NotLike(filter.Value.(string))) case OPERATOR_ILIKE: - query = query.Where(goqu.L(fmt.Sprintf(`"%s"::TEXT ILIKE '%s'`, filter.Name, filter.Value.(string)))) + query = query.Where(goqu.Cast(goqu.I(filter.Name), "TEXT").ILike(filter.Value.(string))) case OPERATOR_NOT_ILIKE: - query = query.Where(goqu.L(fmt.Sprintf(`"%s"::TEXT NOT ILIKE '%s'`, filter.Name, filter.Value.(string)))) + query = query.Where(goqu.Cast(goqu.I(filter.Name), "TEXT").NotILike(filter.Value.(string))) default: query = query.Where(goqu.Ex{filter.Name: goqu.Op{filter.Operator: filter.Value.(string)}}) } diff --git a/internal/store/postgres/org_billing_repository_test.go b/internal/store/postgres/org_billing_repository_test.go index 67ad4c882..6eeb47762 100644 --- a/internal/store/postgres/org_billing_repository_test.go +++ b/internal/store/postgres/org_billing_repository_test.go @@ -103,8 +103,8 @@ func TestPrepareDataQuery(t *testing.T) { Limit: 20, Offset: 40, }, - wantSQL: `SELECT "id", "title", "name", "state", "avatar", "updated_at", "created_at", "created_by", "country", "plan_id", "plan_name", "subscription_state", "subscription_cycle_end_at", "plan_interval", "payment_mode" FROM (SELECT "organizations"."id" AS "id", "organizations"."title" AS "title", "organizations"."name" AS "name", "organizations"."avatar" AS "avatar", "organizations"."created_at" AS "created_at", "organizations"."updated_at" AS "updated_at", "organizations"."state" AS "state", organizations.metadata->>'country' AS "country", organizations.metadata->>'poc' AS "created_by", "billing_plans"."id" AS "plan_id", "billing_plans"."name" AS "plan_name", "billing_plans"."interval" AS "plan_interval", "billing_subscriptions"."state" AS "subscription_state", "billing_subscriptions"."trial_ends_at", "billing_subscriptions"."current_period_end_at" AS "subscription_cycle_end_at", "billing_customers"."payment_mode" AS "payment_mode", ROW_NUMBER() OVER (PARTITION BY "organizations"."id" ORDER BY "billing_subscriptions"."created_at" DESC) AS "row_num" FROM "organizations" LEFT JOIN "billing_customers" ON ("organizations"."id" = "billing_customers"."org_id") LEFT JOIN "billing_subscriptions" ON (("billing_subscriptions"."customer_id" = "billing_customers"."id") AND ("billing_subscriptions"."state" != $1)) LEFT JOIN "billing_plans" ON ("billing_plans"."id" = "billing_subscriptions"."plan_id")) AS "ranked_subscriptions" WHERE (("row_num" = $2) AND ("state" = $3) AND (CAST("plan_name" AS TEXT) IN ($4, $5)) AND coalesce(subscription_state, '') != '' AND ((CAST("id" AS TEXT) ILIKE $6) OR (CAST("title" AS TEXT) ILIKE $7) OR (CAST("state" AS TEXT) ILIKE $8) OR (CAST("plan_name" AS TEXT) ILIKE $9) OR (CAST("subscription_state" AS TEXT) ILIKE $10) OR (CAST("plan_interval" AS TEXT) ILIKE $11))) ORDER BY "created_at" DESC, "title" ASC LIMIT $12 OFFSET $13`, - wantParameters: []interface{}{"canceled", int64(1), "active", "free", "premium", "%test%", "%test%", "%test%", "%test%", "%test%", "%test%", int64(20), int64(40)}, + wantSQL: `SELECT "id", "title", "name", "state", "avatar", "updated_at", "created_at", "created_by", "country", "plan_id", "plan_name", "subscription_state", "subscription_cycle_end_at", "plan_interval", "payment_mode" FROM (SELECT "organizations"."id" AS "id", "organizations"."title" AS "title", "organizations"."name" AS "name", "organizations"."avatar" AS "avatar", "organizations"."created_at" AS "created_at", "organizations"."updated_at" AS "updated_at", "organizations"."state" AS "state", organizations.metadata->>'country' AS "country", organizations.metadata->>'poc' AS "created_by", "billing_plans"."id" AS "plan_id", "billing_plans"."name" AS "plan_name", "billing_plans"."interval" AS "plan_interval", "billing_subscriptions"."state" AS "subscription_state", "billing_subscriptions"."trial_ends_at", "billing_subscriptions"."current_period_end_at" AS "subscription_cycle_end_at", "billing_customers"."payment_mode" AS "payment_mode", ROW_NUMBER() OVER (PARTITION BY "organizations"."id" ORDER BY "billing_subscriptions"."created_at" DESC) AS "row_num" FROM "organizations" LEFT JOIN "billing_customers" ON ("organizations"."id" = "billing_customers"."org_id") LEFT JOIN "billing_subscriptions" ON (("billing_subscriptions"."customer_id" = "billing_customers"."id") AND ("billing_subscriptions"."state" != $1)) LEFT JOIN "billing_plans" ON ("billing_plans"."id" = "billing_subscriptions"."plan_id")) AS "ranked_subscriptions" WHERE (("row_num" = $2) AND ("state" = $3) AND (CAST("plan_name" AS TEXT) IN ($4, $5)) AND (("subscription_state" IS NOT NULL) AND ("subscription_state" != $6)) AND ((CAST("id" AS TEXT) ILIKE $7) OR (CAST("title" AS TEXT) ILIKE $8) OR (CAST("state" AS TEXT) ILIKE $9) OR (CAST("plan_name" AS TEXT) ILIKE $10) OR (CAST("subscription_state" AS TEXT) ILIKE $11) OR (CAST("plan_interval" AS TEXT) ILIKE $12))) ORDER BY "created_at" DESC, "title" ASC LIMIT $13 OFFSET $14`, + wantParameters: []interface{}{"canceled", int64(1), "active", "free", "premium", "", "%test%", "%test%", "%test%", "%test%", "%test%", "%test%", int64(20), int64(40)}, wantErr: false, }, { diff --git a/internal/store/postgres/org_projects_repository.go b/internal/store/postgres/org_projects_repository.go index 32b20523b..16faa63b8 100644 --- a/internal/store/postgres/org_projects_repository.go +++ b/internal/store/postgres/org_projects_repository.go @@ -194,9 +194,9 @@ func (r OrgProjectsRepository) applyStringFilter(filter rql.Filter, field string switch filter.Operator { case OPERATOR_EMPTY: - condition = goqu.L(fmt.Sprintf("coalesce(%s, '') = ''", field)) + condition = goqu.Or(goqu.I(field).IsNull(), goqu.I(field).Eq("")) case OPERATOR_NOT_EMPTY: - condition = goqu.L(fmt.Sprintf("coalesce(%s, '') != ''", field)) + condition = goqu.And(goqu.I(field).IsNotNull(), goqu.I(field).Neq("")) case OPERATOR_IN, OPERATOR_NOT_IN: condition = goqu.Ex{field: goqu.Op{filter.Operator: strings.Split(filter.Value.(string), ",")}} case OPERATOR_LIKE, OPERATOR_NOT_LIKE: @@ -216,7 +216,7 @@ func (r OrgProjectsRepository) applyStringFilter(filter rql.Filter, field string func (r OrgProjectsRepository) applyDatetimeFilter(filter rql.Filter, field string, stmt *goqu.SelectDataset) *goqu.SelectDataset { condition := goqu.Ex{ field: goqu.Op{ - filter.Operator: goqu.L(fmt.Sprintf("timestamp '%s'", filter.Value)), + filter.Operator: goqu.Cast(goqu.V(filter.Value), "TIMESTAMP"), }, } return stmt.Where(condition) diff --git a/internal/store/postgres/org_projects_repository_test.go b/internal/store/postgres/org_projects_repository_test.go index d1a420f2c..d3681cac1 100644 --- a/internal/store/postgres/org_projects_repository_test.go +++ b/internal/store/postgres/org_projects_repository_test.go @@ -59,8 +59,8 @@ func TestOrgProjectsRepository_prepareDataQuery(t *testing.T) { Limit: 10, Offset: 0, }, - wantSQL: `SELECT "projects"."id", "projects"."name", "projects"."title", "projects"."state", "projects"."created_at", "projects"."org_id", COUNT(DISTINCT("policies"."principal_id")) AS "member_count", array_agg(DISTINCT users.id) AS "user_ids" FROM "policies" INNER JOIN "projects" ON ("policies"."resource_id" = "projects"."id") INNER JOIN "users" ON ("policies"."principal_id" = "users"."id") WHERE ((("principal_type" = $1) AND ("projects"."org_id" = $2)) AND ("projects"."created_at" > timestamp '2023-11-02T12:10:21.470756Z')) GROUP BY "projects"."id", "projects"."name", "projects"."title", "projects"."state", "projects"."created_at", "projects"."org_id" LIMIT $3`, - wantArgs: []interface{}{"app/user", "org123", int64(10)}, + wantSQL: `SELECT "projects"."id", "projects"."name", "projects"."title", "projects"."state", "projects"."created_at", "projects"."org_id", COUNT(DISTINCT("policies"."principal_id")) AS "member_count", array_agg(DISTINCT users.id) AS "user_ids" FROM "policies" INNER JOIN "projects" ON ("policies"."resource_id" = "projects"."id") INNER JOIN "users" ON ("policies"."principal_id" = "users"."id") WHERE ((("principal_type" = $1) AND ("projects"."org_id" = $2)) AND ("projects"."created_at" > CAST($3 AS TIMESTAMP))) GROUP BY "projects"."id", "projects"."name", "projects"."title", "projects"."state", "projects"."created_at", "projects"."org_id" LIMIT $4`, + wantArgs: []interface{}{"app/user", "org123", "2023-11-02T12:10:21.470756Z", int64(10)}, wantErr: false, }, { diff --git a/internal/store/postgres/org_users_repository.go b/internal/store/postgres/org_users_repository.go index 52fb327c6..86422e2d1 100644 --- a/internal/store/postgres/org_users_repository.go +++ b/internal/store/postgres/org_users_repository.go @@ -337,9 +337,9 @@ func (r OrgUsersRepository) buildNonRoleFilterCondition(filter rql.Filter) (goqu switch filter.Operator { case "empty": - return goqu.L(fmt.Sprintf("coalesce(%s, '') = ''", columnName)), nil + return goqu.Or(goqu.I(columnName).IsNull(), goqu.I(columnName).Eq("")), nil case "notempty": - return goqu.L(fmt.Sprintf("coalesce(%s, '') != ''", columnName)), nil + return goqu.And(goqu.I(columnName).IsNotNull(), goqu.I(columnName).Neq("")), nil case "in", "notin": return goqu.Ex{columnName: goqu.Op{filter.Operator: strings.Split(filter.Value.(string), ",")}}, nil case "like": diff --git a/internal/store/postgres/org_users_repository_test.go b/internal/store/postgres/org_users_repository_test.go index ebe591472..cd91af004 100644 --- a/internal/store/postgres/org_users_repository_test.go +++ b/internal/store/postgres/org_users_repository_test.go @@ -156,7 +156,7 @@ func TestOrgUsersRepository_BuildNonRoleFilterCondition(t *testing.T) { Name: "title", Operator: "empty", }, - wantSQL: `coalesce(users.title, '') = ''`, + wantSQL: `(("users"."title" IS NULL) OR ("users"."title" = ''))`, }, { name: "invalid operator", diff --git a/internal/store/postgres/prospect_repository.go b/internal/store/postgres/prospect_repository.go index a2e4d94bb..fe1e113e5 100644 --- a/internal/store/postgres/prospect_repository.go +++ b/internal/store/postgres/prospect_repository.go @@ -130,7 +130,7 @@ func (r ProspectRepository) Get(ctx context.Context, id string) (prospect.Prospe } func (r ProspectRepository) List(ctx context.Context, rqlQuery *rql.Query) (prospect.ListProspects, error) { - baseStmt := dialect.From(TABLE_PROSPECTS) + baseStmt := dialect.From(TABLE_PROSPECTS).Prepared(true) // apply filters baseStmt, err := utils.AddRQLFiltersInQuery(baseStmt, rqlQuery, rqlFilerSupportedColumns, prospect.Prospect{}) diff --git a/internal/store/postgres/userpat_repository.go b/internal/store/postgres/userpat_repository.go index e0d5a691e..9a3249568 100644 --- a/internal/store/postgres/userpat_repository.go +++ b/internal/store/postgres/userpat_repository.go @@ -173,7 +173,7 @@ func (r UserPATRepository) buildPATFilteredQuery(userID, orgID string, rqlQuery rqlQuery = utils.NewRQLQuery("", utils.DefaultOffset, utils.DefaultLimit, []rql.Filter{}, []rql.Sort{}, []string{}) } - baseStmt := dialect.From(TABLE_USER_PATS).Where( + baseStmt := dialect.From(TABLE_USER_PATS).Prepared(true).Where( goqu.Ex{"user_id": userID}, goqu.Ex{"org_id": orgID}, goqu.Ex{"deleted_at": nil}, diff --git a/pkg/utils/rql.go b/pkg/utils/rql.go index cb8f61058..f03274c98 100644 --- a/pkg/utils/rql.go +++ b/pkg/utils/rql.go @@ -126,10 +126,11 @@ func AddRQLSearchInQuery(query *goqu.SelectDataset, rql *rql.Query, rqlSearchSup searchExpressions := make([]goqu.Expression, 0) if rql.Search != "" { + searchPattern := "%" + rql.Search + "%" for _, col := range rqlSearchSupportedColumns { - searchExpressions = append(searchExpressions, goqu.L( - fmt.Sprintf(`"%s"::TEXT ILIKE '%%%s%%'`, col, rql.Search), - )) + searchExpressions = append(searchExpressions, + goqu.Cast(goqu.I(col), "TEXT").ILike(searchPattern), + ) } } return query.Where(goqu.Or(searchExpressions...)), nil @@ -167,9 +168,9 @@ func AddRQLFiltersInQuery(query *goqu.SelectDataset, rqlInput *rql.Query, rqlFil func ProcessStringDataType(filter rql.Filter, query *goqu.SelectDataset) *goqu.SelectDataset { switch filter.Operator { case OperatorEmpty: - query = query.Where(goqu.L(fmt.Sprintf("coalesce(%s, '') = ''", filter.Name))) + query = query.Where(goqu.Or(goqu.I(filter.Name).IsNull(), goqu.I(filter.Name).Eq(""))) case OperatorNotEmpty: - query = query.Where(goqu.L(fmt.Sprintf("coalesce(%s, '') != ''", filter.Name))) + query = query.Where(goqu.And(goqu.I(filter.Name).IsNotNull(), goqu.I(filter.Name).Neq(""))) case OperatorIn, OperatorNotIn: // process the values of in and notin operators as comma separated list query = query.Where(goqu.Ex{ @@ -177,10 +178,10 @@ func ProcessStringDataType(filter rql.Filter, query *goqu.SelectDataset) *goqu.S }) case OperatorLike: // some semi-string sql types like UUID require casting to text to support like operator - query = query.Where(goqu.L(fmt.Sprintf(`"%s"::TEXT ILIKE '%s'`, filter.Name, filter.Value.(string)))) + query = query.Where(goqu.Cast(goqu.I(filter.Name), "TEXT").ILike(filter.Value.(string))) case OperatorNotLike: // some semi-string sql types like UUID require casting to text to support like operator - query = query.Where(goqu.L(fmt.Sprintf(`"%s"::TEXT NOT ILIKE '%s'`, filter.Name, filter.Value.(string)))) + query = query.Where(goqu.Cast(goqu.I(filter.Name), "TEXT").NotILike(filter.Value.(string))) default: query = query.Where(goqu.Ex{filter.Name: goqu.Op{filter.Operator: filter.Value.(string)}}) }