-
Notifications
You must be signed in to change notification settings - Fork 6
/
gorm.go
131 lines (103 loc) · 3.17 KB
/
gorm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package db
import (
"context"
"errors"
"regexp"
"strings"
"github.com/jackc/pgx/v5/pgconn"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"gorm.io/gorm"
)
// tracer creates the spans recorded by the database operations in this file.
var tracer = otel.Tracer("github.com/teamkeel/keel/db")

// GormDB implements Database by delegating to a wrapped *gorm.DB handle.
type GormDB struct {
	db *gorm.DB
}

// Compile-time check that *GormDB satisfies the Database interface.
var _ Database = (*GormDB)(nil)
// ExecuteQuery runs sqlQuery with the positional values and returns the
// resulting rows, each scanned by GORM into a map keyed by column name.
//
// If ctx carries a transaction stored by Transaction, the query runs inside
// that transaction; otherwise it runs on the wrapped *gorm.DB. Failures are
// recorded on the "Execute Query" span and converted with toDbError. The SQL
// text and the number of returned rows are attached to the span.
func (db *GormDB) ExecuteQuery(ctx context.Context, sqlQuery string, values ...any) (*ExecuteQueryResult, error) {
	ctx, span := tracer.Start(ctx, "Execute Query")
	defer span.End()

	span.SetAttributes(attribute.String("sql", sqlQuery))

	conn := db.db
	// Reuse the active transaction, if any, so the query participates in it.
	if tx, ok := ctx.Value(transactionCtxKey).(*gorm.DB); ok {
		conn = tx
	}
	// Bind ctx on both paths: a transaction handle otherwise keeps the context
	// it was opened with, so cancellation/deadlines of ctx would not apply to
	// queries issued inside a transaction. GORM sessions keep the tx's
	// connection, so this does not leave the transaction.
	conn = conn.WithContext(ctx)

	// Non-nil so an empty result is an empty slice rather than nil.
	rows := []map[string]any{}
	if err := conn.Raw(sqlQuery, values...).Scan(&rows).Error; err != nil {
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		return nil, toDbError(err)
	}

	span.SetAttributes(attribute.Int("rows.count", len(rows)))
	return &ExecuteQueryResult{Rows: rows}, nil
}
// ExecuteStatement executes sqlQuery with the positional values and reports
// how many rows it affected.
//
// If ctx carries a transaction stored by Transaction, the statement runs
// inside that transaction; otherwise it runs on the wrapped *gorm.DB.
// Failures are recorded on the "Execute Statement" span and converted with
// toDbError. The SQL text and affected-row count are attached to the span.
func (db *GormDB) ExecuteStatement(ctx context.Context, sqlQuery string, values ...any) (*ExecuteStatementResult, error) {
	ctx, span := tracer.Start(ctx, "Execute Statement")
	defer span.End()

	span.SetAttributes(attribute.String("sql", sqlQuery))

	conn := db.db
	// Reuse the active transaction, if any, so the statement participates in it.
	if tx, ok := ctx.Value(transactionCtxKey).(*gorm.DB); ok {
		conn = tx
	}
	// Bind ctx on both paths: a transaction handle otherwise keeps the context
	// it was opened with, so cancellation/deadlines of ctx would not apply to
	// statements issued inside a transaction. GORM sessions keep the tx's
	// connection, so this does not leave the transaction.
	conn = conn.WithContext(ctx)

	result := conn.Exec(sqlQuery, values...)
	if err := result.Error; err != nil {
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		return nil, toDbError(err)
	}

	span.SetAttributes(attribute.Int64("rows.affected", result.RowsAffected))
	return &ExecuteStatementResult{RowsAffected: result.RowsAffected}, nil
}
// transactionCtxKeyType is the unexported type of the context key under which
// Transaction stores the active *gorm.DB transaction. A dedicated named type
// means no other package can construct an equal key (a plain struct{}{} key,
// as used before, compares equal to any other package's struct{}{} key).
type transactionCtxKeyType struct{}

// transactionCtxKey is the context key for the active transaction handle.
var transactionCtxKey transactionCtxKeyType
// Transaction runs fn inside a database transaction opened via GORM's
// Transaction helper: the transaction is committed when fn returns nil and
// rolled back when it returns an error.
//
// fn receives a context carrying the transaction handle, so ExecuteQuery and
// ExecuteStatement called with that context run inside the transaction. The
// transaction is opened with ctx bound, so the work is tied to ctx
// (cancellation, tracing). An error is recorded on the "Database Transaction"
// span and returned unchanged.
//
// NOTE(review): a nested call opens a new, independent transaction from the
// base handle rather than a savepoint on the outer one — confirm this is
// intended.
func (db *GormDB) Transaction(ctx context.Context, fn func(context.Context) error) error {
	ctx, span := tracer.Start(ctx, "Database Transaction")
	defer span.End()

	err := db.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Derive a fresh context rather than overwriting the captured ctx.
		return fn(context.WithValue(ctx, transactionCtxKey, tx))
	})
	if err != nil {
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}
	return err
}
// Close closes the underlying connection pool obtained from gorm's DB().
// It returns the error from obtaining the pool if that fails, otherwise the
// error (if any) from closing it.
func (db *GormDB) Close() error {
	pool, err := db.db.DB()
	if err == nil {
		err = pool.Close()
	}
	return err
}
// GetDB returns the wrapped *gorm.DB for direct use of GORM APIs. The handle
// is the base one, not any transaction stored in a context by Transaction.
func (db *GormDB) GetDB() *gorm.DB { return db.db }
func toDbError(err error) error {
var pgErr *pgconn.PgError
if !errors.As(err, &pgErr) {
return err
}
dbErr := &DbError{
Table: pgErr.TableName,
Columns: []string{},
Message: pgErr.Message,
PgErrCode: pgErr.Code,
Err: pgErr,
}
switch pgErr.Code {
case PgForeignKeyConstraintViolation:
// Extract column and value from "Key (author_id)=(2L2ar5NCPvTTEdiDYqgcpF3f5QN1) is not present in table \"author\"."
out := regexp.MustCompile(`\(([^)]+)\)`).FindAllStringSubmatch(pgErr.Detail, -1)
dbErr.Columns = []string{out[0][1]}
case PgUniqueConstraintViolation:
// Extract column and value from "Key (code)=(1234) already exists."
out := regexp.MustCompile(`\(([^)]+)\)`).FindAllStringSubmatch(pgErr.Detail, -1)
dbErr.Columns = strings.Split(out[0][1], ", ")
default:
if pgErr.ColumnName != "" {
dbErr.Columns = []string{pgErr.ColumnName}
}
}
return dbErr
}