/
tx.go
38 lines (32 loc) · 828 Bytes
/
tx.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
package mongo
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
)
// TxFn is the callback type handed to Session.WithTransaction by the helpers
// in this file. It receives a mongo.SessionContext and returns an optional
// result value together with an error. Because TxFn is an alias (not a new
// named type), any function literal of this shape can be passed where the
// driver expects its own callback type.
//
// NOTE(review): the driver documents that WithTransaction may invoke the
// callback more than once on retryable errors, so implementations should be
// safe to re-run — confirm against the driver version in use.
type TxFn = func(mongo.SessionContext) (interface{}, error)
// TxObject pairs a context with a MongoDB client so that callers can run
// transactions through Run without passing either value on every call.
//
// NOTE(review): the Go context package advises against storing a Context in a
// struct; the fields are kept as-is for compatibility with existing callers.
type TxObject struct {
	ctx    context.Context // context used when running transactions via Run
	client *mongo.Client   // client on which Run starts its sessions
}
// Run executes fn inside a transaction, using the context and client stored
// in o. It is a thin wrapper over RunInCtxTransaction and returns that
// function's error unchanged.
func (o *TxObject) Run(fn TxFn) error {
	ctx, client := o.ctx, o.client
	return RunInCtxTransaction(ctx, client, fn)
}
// RunInCtxTransaction starts a new session on client and runs callback inside
// a transaction via the session's WithTransaction, using ctx for both the
// transaction and the deferred EndSession call.
//
// The result value produced by callback is discarded; only an error is
// reported. An error from StartSession is returned as-is, before any
// transaction is attempted. The session is always ended when this function
// returns.
func RunInCtxTransaction(ctx context.Context, client *mongo.Client, callback TxFn) error {
	sess, startErr := client.StartSession()
	if startErr != nil {
		return startErr
	}
	defer sess.EndSession(ctx)

	if _, txErr := sess.WithTransaction(ctx, callback); txErr != nil {
		return txErr
	}
	return nil
}
// RunInTransaction starts a new session on client and runs callback inside a
// transaction, using context.Background() for the transaction and for ending
// the session. It is exactly RunInCtxTransaction(context.Background(), ...).
// Prefer RunInCtxTransaction when a caller-supplied context (deadline or
// cancellation) is available.
//
// The callback's result value is discarded; only an error is returned.
func RunInTransaction(client *mongo.Client, callback TxFn) error {
	// Delegate instead of duplicating the session lifecycle, so that any future
	// change to session/transaction handling applies to both entry points.
	return RunInCtxTransaction(context.Background(), client, callback)
}