-
Notifications
You must be signed in to change notification settings - Fork 0
/
db.go
285 lines (229 loc) · 8.52 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
package docdb
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
)
// DBIntf is the set of document-store operations offered by this package;
// *MongoDB has methods with exactly these signatures.
//
// In every method, collection names the target collection and ctx may be a
// MongoDB session context so the call takes part in a transaction.
type DBIntf interface {
	// Save inserts one document and returns its id as a string.
	Save(ctx context.Context, collection string, data interface{}) (string, error)
	// SaveMultiple inserts several documents and returns their ids.
	SaveMultiple(ctx context.Context, collection string, items []interface{}) ([]interface{}, error)
	// GetItem decodes the first document matching filter into result.
	GetItem(ctx context.Context, collection string, filter map[string]interface{}, excludedFields map[string]interface{}, result interface{}) error
	// GetItems decodes up to limit documents matching filter into results.
	GetItems(ctx context.Context, collection string, filter map[string]interface{}, limit int64, excludedFields map[string]interface{}, sort map[string]interface{}, results interface{}) error
	// CountItems counts the documents matching filter.
	CountItems(ctx context.Context, collection string, filter map[string]interface{}) (int64, error)
	// DeleteItem deletes one document matching filter and returns the deleted count.
	DeleteItem(ctx context.Context, collection string, filter map[string]interface{}) (int64, error)
	// DeleteItems deletes all documents matching filter and returns the deleted count.
	DeleteItems(ctx context.Context, collection string, filter map[string]interface{}) (int64, error)
	// UpdateItem applies update to one document matching match and returns the modified count.
	UpdateItem(ctx context.Context, collection string, match map[string]interface{}, update map[string]interface{}) (int64, error)
	// UpdateItems applies update to all documents matching match and returns the modified count.
	UpdateItems(ctx context.Context, collection string, match map[string]interface{}, update map[string]interface{}) (int64, error)
	// GetCollection returns the driver handle for the named collection.
	GetCollection(collection string) *mongo.Collection
	// GetClient returns the underlying driver client.
	GetClient() *mongo.Client
	// StartTxn opens a session and returns a Txn to queue transactional work on.
	StartTxn() (*Txn, error)
}
// Txn collects functions to be run together in one MongoDB transaction on a
// single client session. Obtain one from (*MongoDB).StartTxn, queue work with
// AddExecution, then run it with Execute.
type Txn struct {
	// session is the client session the transaction runs on.
	session mongo.Session
	// txnopts are the transaction options (write/read concern) chosen by StartTxn.
	txnopts *options.TransactionOptions
	// fns are the queued operations; Execute runs them in slice order.
	fns []func(sessionContext mongo.SessionContext) error
}
// Sentinel errors exported by this package; compare with errors.Is.
var (
	// ErrMongoDBDuplicate signals a duplicate-key conflict.
	// NOTE(review): Save passes the raw driver error through, so nothing in
	// this file currently returns this value.
	ErrMongoDBDuplicate = errors.New("duplicate entry")
	// ErrInvalidObjectID signals an id that is not a usable ObjectID.
	ErrInvalidObjectID = errors.New("invalid object ID")
	// ErrNotFound signals that no matching document exists.
	// NOTE(review): GetItem currently returns the driver's error unchanged
	// rather than this value.
	ErrNotFound = errors.New("item not found")
)
// MongoDB connection holder: pairs a driver client with the name of the
// database that GetCollection resolves collections in.
type MongoDB struct {
	// client is the driver client set by NewDB.
	client *mongo.Client
	// database is the database name passed to NewDB.
	database string
}
// NewDB creates a DB connection to the deployment at uri and returns a handle
// bound to the given database name. ctx is passed to mongo.Connect.
//
// On failure it returns a nil *MongoDB together with the error, so callers
// never receive a half-initialised handle with a nil client.
func NewDB(ctx context.Context, uri, database string) (*MongoDB, error) {
	// mongo.Connect is the driver's single-call replacement for the
	// deprecated mongo.NewClient + Client.Connect sequence.
	client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
	if err != nil {
		return nil, err
	}
	return &MongoDB{client: client, database: database}, nil
}
// Disconnect closes the mongodb connection held by db's client.
func (db *MongoDB) Disconnect(ctx context.Context) {
	// Best-effort shutdown: the method has no error return, so any failure
	// reported by the driver is intentionally dropped.
	_ = db.client.Disconnect(ctx)
}
// Ping checks that the primary is reachable. It reports (true, nil) on
// success and (false, err) with the driver's error otherwise.
func (db *MongoDB) Ping(ctx context.Context) (bool, error) {
	pingErr := db.client.Ping(ctx, readpref.Primary())
	return pingErr == nil, pingErr
}
// GetClient returns the underlying driver client, for operations this
// wrapper does not expose directly (StartTxn uses it to open sessions).
func (db *MongoDB) GetClient() *mongo.Client {
	c := db.client
	return c
}
// GetCollection returns a handle to the named collection inside the database
// this MongoDB was created for.
func (db *MongoDB) GetCollection(collection string) *mongo.Collection {
	database := db.client.Database(db.database)
	return database.Collection(collection)
}
// Save inserts data as a single document into collection c (e.g. 'users') and
// returns the document's _id as a string. ctx can be a mongodb session context
// for transactions.
//
// An ObjectID _id is returned in hex form and a string _id is returned
// verbatim. For any other _id type Save returns ErrInvalidObjectID rather than
// panicking on a failed type assertion; the document HAS been inserted in that
// case, only the id could not be rendered.
func (db *MongoDB) Save(ctx context.Context, c string, data interface{}) (string, error) {
	insertResult, err := db.GetCollection(c).InsertOne(ctx, data)
	if err != nil {
		// The raw driver error is returned (duplicate-key failures are not
		// translated to ErrMongoDBDuplicate) so callers inspecting it keep working.
		return "", err
	}
	switch id := insertResult.InsertedID.(type) {
	case primitive.ObjectID:
		return id.Hex(), nil
	case string:
		return id, nil
	default:
		return "", ErrInvalidObjectID
	}
}
// SaveMultiple inserts every element of items into collection c (e.g. 'users')
// and returns the inserted _id values reported by the driver.
// ctx can be a mongodb session context for transactions.
func (db *MongoDB) SaveMultiple(ctx context.Context, c string, items []interface{}) ([]interface{}, error) {
	res, err := db.GetCollection(c).InsertMany(ctx, items)
	if err != nil {
		return nil, err
	}
	return res.InsertedIDs, nil
}
// GetItem decodes the first document in collection c that matches filter into
// result, a pointer to the object that receives the data. excludedFields is
// used as the find projection. ctx can be a mongodb session context for
// transactions.
//
// A nil error means a document was found and decoded. Otherwise the driver's
// error is returned unchanged, including when nothing matched.
// TODO: map the driver's not-found error to ErrNotFound.
func (db *MongoDB) GetItem(ctx context.Context, c string, filter map[string]interface{}, excludedFields map[string]interface{}, result interface{}) error {
	opts := options.FindOne().SetProjection(excludedFields)
	return db.GetCollection(c).FindOne(ctx, filter, opts).Decode(result)
}
// UpdateItem applies update to one document in collection c that matches
// match, via the driver's UpdateOne, and returns the driver's ModifiedCount.
// NOTE(review): update is presumably an update-operator document such as
// {"$set": {...}} — confirm with callers.
// ctx can be a mongodb session context for transactions.
func (db *MongoDB) UpdateItem(ctx context.Context, c string, match map[string]interface{}, update map[string]interface{}) (int64, error) {
	res, err := db.GetCollection(c).UpdateOne(ctx, match, update)
	if err != nil {
		return 0, err
	}
	return res.ModifiedCount, nil
}
// UpdateItems applies update to every document in collection c that matches
// match, via the driver's UpdateMany, and returns the driver's ModifiedCount.
// NOTE(review): update is presumably an update-operator document such as
// {"$set": {...}} — confirm with callers.
// ctx can be a mongodb session context for transactions.
func (db *MongoDB) UpdateItems(ctx context.Context, c string, match map[string]interface{}, update map[string]interface{}) (int64, error) {
	res, err := db.GetCollection(c).UpdateMany(ctx, match, update)
	if err != nil {
		return 0, err
	}
	return res.ModifiedCount, nil
}
// GetItems decodes the documents in collection c (e.g. 'users') that match
// filter into results, a pointer to a slice of objects. excludedFields is the
// find projection, sort the sort specification and limit the value passed to
// the driver's SetLimit. ctx can be a mongodb session context for transactions.
//
// NOTE(review): sort is a Go map, whose iteration order is unspecified, so a
// sort specification with more than one key may not be applied in the
// intended key order.
func (db *MongoDB) GetItems(ctx context.Context, c string, filter map[string]interface{}, limit int64, excludedFields map[string]interface{}, sort map[string]interface{}, results interface{}) error {
	opts := options.Find().
		SetProjection(excludedFields).
		SetSort(sort).
		SetLimit(limit)
	cur, err := db.GetCollection(c).Find(ctx, filter, opts)
	if err != nil {
		return err
	}
	// Close the cursor once decoding has finished.
	defer cur.Close(ctx)
	return cur.All(ctx, results)
}
// CountItems returns how many documents in collection c (e.g. 'users') match
// filter, or 0 and the driver's error on failure.
// ctx can be a mongodb session context for transactions.
func (db *MongoDB) CountItems(ctx context.Context, c string, filter map[string]interface{}) (int64, error) {
	count, err := db.GetCollection(c).CountDocuments(ctx, filter, options.Count())
	if err != nil {
		return 0, err
	}
	return count, nil
}
// DeleteItem deletes one document matching filter from collection c
// (e.g. 'users') via the driver's DeleteOne, and returns the deleted count.
// ctx can be a mongodb session context for transactions.
func (db *MongoDB) DeleteItem(ctx context.Context, c string, filter map[string]interface{}) (int64, error) {
	res, err := db.GetCollection(c).DeleteOne(ctx, filter)
	if err != nil {
		return 0, err
	}
	return res.DeletedCount, nil
}
// DeleteItems deletes every document matching filter from collection c
// (e.g. 'users') via the driver's DeleteMany, and returns the deleted count.
// ctx can be a mongodb session context for transactions.
func (db *MongoDB) DeleteItems(ctx context.Context, c string, filter map[string]interface{}) (int64, error) {
	res, err := db.GetCollection(c).DeleteMany(ctx, filter)
	if err != nil {
		return 0, err
	}
	return res.DeletedCount, nil
}
// StartTxn opens a new client session and returns a Txn whose transaction
// will use majority write concern and snapshot read concern.
func (db *MongoDB) StartTxn() (*Txn, error) {
	session, err := db.GetClient().StartSession()
	if err != nil {
		return nil, err
	}
	// Building the options has no side effects, so it can follow the session.
	opts := options.Transaction().
		SetWriteConcern(writeconcern.New(writeconcern.WMajority())).
		SetReadConcern(readconcern.Snapshot())
	return &Txn{session: session, txnopts: opts}, nil
}
// AddExecution queues fn to run inside the transaction when Execute is
// called. Queued functions run in the order they were added and receive the
// transaction's session context.
func (t *Txn) AddExecution(fn func(sessionContext mongo.SessionContext) error) {
	t.fns = append(t.fns, fn)
}
// Execute runs every queued function, in the order added, inside one
// transaction on t's session, then ends the session. It returns the error
// reported by WithTransaction, which includes a failure from any queued
// function.
//
// ctx is passed to WithTransaction, so the caller's deadline and cancellation
// apply to the transaction.
//
// NOTE(review): the session is ended when Execute returns, so reusing a Txn
// afterwards presumably fails. The driver's WithTransaction may also re-run
// the callback on transient errors, so queued functions should tolerate
// re-execution. Confirm both against the driver version in use.
func (t *Txn) Execute(ctx context.Context) error {
	// A fresh context ensures the session is released even when ctx is
	// already cancelled or past its deadline.
	defer t.session.EndSession(context.Background())

	callback := func(sessionContext mongo.SessionContext) (interface{}, error) {
		for _, fn := range t.fns {
			if err := fn(sessionContext); err != nil {
				return nil, err
			}
		}
		return nil, nil
	}
	_, err := t.session.WithTransaction(ctx, callback, t.txnopts)
	return err
}