diff --git a/backend/gocql/db.go b/backend/gocql/db.go index 4e6dcdd..5a45d81 100644 --- a/backend/gocql/db.go +++ b/backend/gocql/db.go @@ -62,7 +62,19 @@ func (db *DB) ExecCAS(ctx context.Context, stmt string, vs ...interface{}) cql.C } func (db *DB) QueryRow(ctx context.Context, stmt string, vs ...interface{}) cql.Scanner { - return db.query(ctx, stmt, vs) + return scanner{db.query(ctx, stmt, vs)} +} + +type scanner struct { + cql.Scanner +} + +func (s scanner) Scan(vs ...interface{}) error { + if err := s.Scanner.Scan(vs...); err != gocql.ErrNotFound { + return err + } + + return cql.ErrNoRows } type cursor struct { diff --git a/cqltest/static_source.go b/cqltest/static_source.go new file mode 100644 index 0000000..80b8517 --- /dev/null +++ b/cqltest/static_source.go @@ -0,0 +1,47 @@ +package cqltest + +import ( + "context" + "io" + "io/ioutil" + "strings" + + "github.com/upfluence/cql/x/migration" +) + +type StaticSource struct { + MigrationUp string + MigrationDown string +} + +func (ss StaticSource) ID() uint { + return 1 +} + +func (ss StaticSource) Up() (io.ReadCloser, error) { + return ioutil.NopCloser(strings.NewReader(ss.MigrationUp)), nil +} + +func (ss StaticSource) Down() (io.ReadCloser, error) { + return ioutil.NopCloser(strings.NewReader(ss.MigrationDown)), nil +} + +func (ss StaticSource) Get(_ context.Context, v uint) (migration.Migration, error) { + if v != 1 { + return nil, migration.ErrNotExist + } + + return ss, nil +} + +func (ss StaticSource) First(context.Context) (migration.Migration, error) { + return ss, nil +} + +func (ss StaticSource) Next(context.Context, uint) (bool, uint, error) { + return false, 0, nil +} + +func (ss StaticSource) Prev(context.Context, uint) (bool, uint, error) { + return false, 0, nil +} diff --git a/db.go b/db.go index e8b9496..705ea17 100644 --- a/db.go +++ b/db.go @@ -1,6 +1,11 @@ package cql -import "context" +import ( + "context" + "errors" +) + +var ErrNoRows = errors.New("cql: No rows found") type BatchType uint8 diff --git a/integration/integration_test.go b/integration/integration_test.go index fd5eef0..d4f735c 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -17,9 +17,9 @@ func TestMigrationIntegration(t *testing.T) { cqltest.WithMigratorFunc(func(db cql.DB) migration.Migrator { return migration.NewMigrator( db, - staticSource{ - up: "CREATE TABLE IF NOT EXISTS foo(uuid UUID PRIMARY KEY, data blob)", - down: "DROP TABLE foo", + cqltest.StaticSource{ + MigrationUp: "CREATE TABLE IF NOT EXISTS foo(uuid UUID PRIMARY KEY, data blob)", + MigrationDown: "DROP TABLE foo", }, ) }), diff --git a/integration/static_source.go b/integration/static_source.go index d832503..76ab1b7 100644 --- a/integration/static_source.go +++ b/integration/static_source.go @@ -1,46 +1 @@ package integration - -import ( - "context" - "io" - "io/ioutil" - "strings" - - "github.com/upfluence/cql/x/migration" -) - -type staticSource struct { - up, down string -} - -func (ss staticSource) ID() uint { - return 1 -} - -func (ss staticSource) Up() (io.ReadCloser, error) { - return ioutil.NopCloser(strings.NewReader(ss.up)), nil -} - -func (ss staticSource) Down() (io.ReadCloser, error) { - return ioutil.NopCloser(strings.NewReader(ss.down)), nil -} - -func (ss staticSource) Get(_ context.Context, v uint) (migration.Migration, error) { - if v != 1 { - return nil, migration.ErrNotExist - } - - return ss, nil -} - -func (ss staticSource) First(context.Context) (migration.Migration, error) { - return ss, nil -} - -func (ss staticSource) Next(context.Context, uint) (bool, uint, error) { - return false, 0, nil -} - -func (ss staticSource) Prev(context.Context, uint) (bool, uint, error) { - return false, 0, nil -} diff --git a/x/cqlbuilder/batch_statement.go b/x/cqlbuilder/batch_statement.go new file mode 100644 index 0000000..018dd53 --- /dev/null +++ b/x/cqlbuilder/batch_statement.go @@ -0,0 +1,34 @@ +package cqlbuilder + +import ( + "context" + + "github.com/upfluence/cql" +) + +type BatchStatement struct { + Type cql.BatchType + + Statements []CASStatement +} + +type BatchExecer struct { + QueryBuilder *QueryBuilder + Statement BatchStatement +} + +func (be *BatchExecer) Exec(ctx context.Context, qvs map[string]interface{}) error { + var b = be.QueryBuilder.Batch(ctx, be.Statement.Type) + + for _, s := range be.Statement.Statements { + stmt, vs, err := s.buildQuery(qvs) + + if err != nil { + return err + } + + b.Query(stmt, vs...) + } + + return b.Exec() +} diff --git a/x/cqlbuilder/delete_statement.go b/x/cqlbuilder/delete_statement.go new file mode 100644 index 0000000..9971177 --- /dev/null +++ b/x/cqlbuilder/delete_statement.go @@ -0,0 +1,69 @@ +package cqlbuilder + +import ( + "fmt" + "strings" + "time" +) + +type LWTDeleteClause interface { + LWTClause + + isDeleteClause() +} + +type DeleteStatement struct { + Table string + + Fields []Marker + WhereClause PredicateClause + + Timestamp time.Time + LWTClause LWTDeleteClause +} + +func (ds DeleteStatement) casScanKeys() []string { + if lck, ok := ds.LWTClause.(interface{ keys() []string }); ok { + return lck.keys() + } + + return nil +} + +func (ds DeleteStatement) buildQuery(qvs map[string]interface{}) (string, []interface{}, error) { + var ( + qw queryWriter + + ks = make([]string, len(ds.Fields)) + ) + + for i, f := range ds.Fields { + k := f.ToCQL() + + if i == len(ds.Fields)-1 { + k += " " + } + + ks[i] = k + } + + fmt.Fprintf(&qw, "DELETE %sFROM %s ", strings.Join(ks, ", "), ds.Table) + + DMLOptions{Timestamp: ds.Timestamp}.writeTo(&qw) + + qw.WriteString("WHERE ") + + if err := ds.WhereClause.WriteTo(&qw, qvs); err != nil { + return "", nil, err + } + + if lc := ds.LWTClause; lc != nil { + qw.WriteRune(' ') + + if err := lc.writeTo(&qw, qvs); err != nil { + return "", nil, err + } + } + + return qw.String(), qw.args, nil +} diff --git a/x/cqlbuilder/delete_statement_test.go b/x/cqlbuilder/delete_statement_test.go new file mode 100644 index 0000000..d69fe0a --- /dev/null +++ b/x/cqlbuilder/delete_statement_test.go @@ -0,0 +1,32 @@ +package cqlbuilder + +import "testing" + +func TestDeleteStatement(t *testing.T) { + for _, stc := range []statementTestCase{ + { + name: "basic", + stmt: DeleteStatement{ + Table: "foo", + WhereClause: Eq(Column("bar")), + }, + vs: map[string]interface{}{"bar": 3}, + wantStmt: "DELETE FROM foo WHERE bar = ?", + wantArgs: []interface{}{3}, + }, + { + name: "lwt field", + stmt: DeleteStatement{ + Table: "foo", + Fields: []Marker{Column("fiz")}, + WhereClause: Eq(Column("bar")), + LWTClause: PredicateLWTClause{Predicate: Eq(Column("buz"))}, + }, + vs: map[string]interface{}{"fiz": 1, "buz": 2, "bar": 3}, + wantStmt: "DELETE fiz FROM foo WHERE bar = ? IF buz = ?", + wantArgs: []interface{}{3, 2}, + }, + } { + stc.assert(t) + } +} diff --git a/x/cqlbuilder/dml.go b/x/cqlbuilder/dml.go new file mode 100644 index 0000000..1df7062 --- /dev/null +++ b/x/cqlbuilder/dml.go @@ -0,0 +1,92 @@ +package cqlbuilder + +import ( + "fmt" + "io" + "time" +) + +type DMLOptions struct { + TTL time.Duration + Timestamp time.Time +} + +func (do DMLOptions) writeTo(w io.Writer) { + if do.TTL == 0 && do.Timestamp.IsZero() { + return + } + + io.WriteString(w, " USING") + + if do.TTL > 0 { + fmt.Fprintf(w, " TTL %d", int(do.TTL.Seconds())) + + if !do.Timestamp.IsZero() { + io.WriteString(w, " AND") + } + } + + if !do.Timestamp.IsZero() { + fmt.Fprintf( + w, + " TIMESTAMP %d", + do.Timestamp.Unix()*1000+do.Timestamp.UnixNano()/1000000, + ) + } +} + +type LWTClause interface { + writeTo(QueryWriter, map[string]interface{}) error +} + +type notExistsClause struct{} + +var NotExistsClause = notExistsClause{} + +func (notExistsClause) writeTo(qw QueryWriter, _ map[string]interface{}) error { + _, err := io.WriteString(qw, "IF NOT EXISTS") + return err +} + +func (notExistsClause) isInsertClause() {} +func (notExistsClause) isUpdateClause() {} + +type existsClause struct{} + +var ExistsClause = existsClause{} + +func (existsClause) writeTo(qw QueryWriter, _ map[string]interface{}) error { + _, err := io.WriteString(qw, "IF EXISTS") + return err +} + +func (existsClause) isUpdateClause() {} +func (existsClause) isDeleteClause() {} + +type PredicateLWTClause struct { + Predicate PredicateClause +} + +func (plc PredicateLWTClause) writeTo(qw QueryWriter, vs map[string]interface{}) error { + if _, err := io.WriteString(qw, "IF "); err != nil { + return err + } + + return plc.Predicate.WriteTo(qw, vs) +} + +func (plc PredicateLWTClause) keys() []string { + var ( + ms = plc.Predicate.Markers() + ks = make([]string, len(ms)) + ) + + for i, m := range ms { + ks[i] = m.Binding() + } + + return ks +} + +func (plc PredicateLWTClause) isUpdateClause() {} +func (plc PredicateLWTClause) isDeleteClause() {} diff --git a/x/cqlbuilder/execer.go b/x/cqlbuilder/execer.go new file mode 100644 index 0000000..298f9c7 --- /dev/null +++ b/x/cqlbuilder/execer.go @@ -0,0 +1,71 @@ +package cqlbuilder + +import ( + "context" + + "github.com/upfluence/cql" +) + +type CASScanner interface { + ScanCAS(map[string]interface{}) (bool, error) +} + +type errCASScanner struct{ error } + +func (ecs errCASScanner) ScanCAS(map[string]interface{}) (bool, error) { + return false, ecs.error +} + +type casScanner struct { + sc cql.CASScanner + ks []string +} + +func (cs *casScanner) ScanCAS(qvs map[string]interface{}) (bool, error) { + vs := make([]interface{}, len(cs.ks)) + + for i, k := range cs.ks { + v, ok := qvs[k] + + if !ok { + return false, ErrMissingKey{Key: k} + } + + vs[i] = v + } + + return cs.sc.ScanCAS(vs...) +} + +type Execer interface { + Exec(context.Context, map[string]interface{}) error + ExecCAS(context.Context, map[string]interface{}) CASScanner +} + +type execer struct { + stmt CASStatement + db cql.DB +} + +func (e *execer) Exec(ctx context.Context, qvs map[string]interface{}) error { + var stmt, vs, err = e.stmt.buildQuery(qvs) + + if err != nil { + return err + } + + return e.db.Exec(ctx, stmt, vs...) +} + +func (e *execer) ExecCAS(ctx context.Context, qvs map[string]interface{}) CASScanner { + var stmt, vs, err = e.stmt.buildQuery(qvs) + + if err != nil { + return errCASScanner{err} + } + + return &casScanner{ + sc: e.db.ExecCAS(ctx, stmt, vs...), + ks: e.stmt.casScanKeys(), + } +} diff --git a/x/cqlbuilder/insert_statement.go b/x/cqlbuilder/insert_statement.go new file mode 100644 index 0000000..91bf88f --- /dev/null +++ b/x/cqlbuilder/insert_statement.go @@ -0,0 +1,77 @@ +package cqlbuilder + +import ( + "fmt" + "strings" +) + +type LWTInsertClause interface { + LWTClause + + isInsertClause() +} + +type InsertStatement struct { + Table string + + Fields []Marker + + Options DMLOptions + LWTClause LWTInsertClause +} + +func (is InsertStatement) casScanKeys() []string { + var ks = make([]string, len(is.Fields)) + + for i, f := range is.Fields { + ks[i] = f.Binding() + } + + return ks +} + +func (is InsertStatement) buildQuery(qvs map[string]interface{}) (string, []interface{}, error) { + var ( + qw queryWriter + + ks = make([]string, len(is.Fields)) + qs = make([]string, len(is.Fields)) + ) + + if len(is.Fields) == 0 { + return "", nil, errNoMarkers + } + + for i, f := range is.Fields { + k := f.Binding() + v, ok := qvs[k] + + if !ok { + return "", nil, ErrMissingKey{Key: k} + } + + ks[i] = columnName(f) + qs[i] = "?" + qw.AddArg(v) + } + + fmt.Fprintf( + &qw, + "INSERT INTO %s(%s) VALUES (%s)", + is.Table, + strings.Join(ks, ", "), + strings.Join(qs, ", "), + ) + + if lc := is.LWTClause; lc != nil { + qw.WriteRune(' ') + + if err := lc.writeTo(&qw, qvs); err != nil { + return "", nil, err + } + } + + is.Options.writeTo(&qw) + + return qw.String(), qw.args, nil +} diff --git a/x/cqlbuilder/insert_statement_test.go b/x/cqlbuilder/insert_statement_test.go new file mode 100644 index 0000000..78e0196 --- /dev/null +++ b/x/cqlbuilder/insert_statement_test.go @@ -0,0 +1,40 @@ +package cqlbuilder + +import "testing" + +func TestInsertStatement(t *testing.T) { + for _, stc := range []statementTestCase{ + { + name: "basic", + stmt: InsertStatement{ + Table: "foo", + Fields: []Marker{Column("fiz"), Column("buz")}, + }, + vs: map[string]interface{}{"fiz": 1, "buz": 2}, + wantStmt: "INSERT INTO foo(fiz, buz) VALUES (?, ?)", + wantArgs: []interface{}{1, 2}, + }, + { + name: "basic", + stmt: InsertStatement{ + Table: "foo", + Fields: []Marker{Column("fiz"), Column("buz")}, + LWTClause: NotExistsClause, + }, + vs: map[string]interface{}{"fiz": 1, "buz": 2}, + wantStmt: "INSERT INTO foo(fiz, buz) VALUES (?, ?) IF NOT EXISTS", + wantArgs: []interface{}{1, 2}, + }, + { + name: "missing key", + stmt: InsertStatement{ + Table: "foo", + Fields: []Marker{Column("fiz"), Column("buz")}, + }, + vs: map[string]interface{}{"fiz": 1}, + wantErr: ErrMissingKey{Key: "buz"}, + }, + } { + stc.assert(t) + } +} diff --git a/x/cqlbuilder/marker.go b/x/cqlbuilder/marker.go new file mode 100644 index 0000000..8606b66 --- /dev/null +++ b/x/cqlbuilder/marker.go @@ -0,0 +1,48 @@ +package cqlbuilder + +import ( + "errors" + "fmt" +) + +var errNoMarkers = errors.New("x/sqlbuilder: No marker given to the statement") + +type ErrMissingKey struct{ Key string } + +func (emk ErrMissingKey) Error() string { + return fmt.Sprintf("%q key missing", emk.Key) +} + +type Marker interface { + Binding() string + ToCQL() string + Clone() Marker +} + +func Column(k string) Marker { return column(k) } + +type column string + +func (c column) ColumnName() string { return string(c) } +func (c column) Binding() string { return string(c) } +func (c column) ToCQL() string { return string(c) } +func (c column) Clone() Marker { return c } + +func CQLExpression(m, exp string) Marker { return cqlMarker{m: m, cql: exp} } + +type cqlMarker struct { + m string + cql string +} + +func (cm cqlMarker) Binding() string { return cm.m } +func (cm cqlMarker) ToCQL() string { return cm.cql } +func (cm cqlMarker) Clone() Marker { return cm } + +func columnName(m Marker) string { + if cn, ok := m.(interface{ ColumnName() string }); ok { + return cn.ColumnName() + } + + return m.ToCQL() +} diff --git a/x/cqlbuilder/predicate.go b/x/cqlbuilder/predicate.go new file mode 100644 index 0000000..04fe838 --- /dev/null +++ b/x/cqlbuilder/predicate.go @@ -0,0 +1,244 @@ +package cqlbuilder + +import ( + "errors" + "fmt" + "io" + "reflect" +) + +var errInvalidType = errors.New("x/cqlbuilder: invalid type") + +type PredicateClause interface { + WriteTo(QueryWriter, map[string]interface{}) error + Clone() PredicateClause + Markers() []Marker +} + +func Eq(m Marker) PredicateClause { return signClause(m, "=") } +func Ne(m Marker) PredicateClause { return signClause(m, "!=") } +func Lt(m Marker) PredicateClause { return signClause(m, "<") } +func Lte(m Marker) PredicateClause { return signClause(m, "<=") } +func Gt(m Marker) PredicateClause { return signClause(m, ">") } +func Gte(m Marker) PredicateClause { return signClause(m, ">=") } + +func signClause(m Marker, s string) *basicClause { + return &basicClause{m: m, fn: writeSignClause(s)} +} + +func writeSignClause(s string) func(QueryWriter, interface{}, string) error { + return func(qw QueryWriter, vv interface{}, k string) error { + fmt.Fprintf(qw, "%s %s ?", k, s) + qw.AddArg(vv) + return nil + } +} + +func In(m Marker) PredicateClause { + return &basicClause{m: m, fn: writeInClause} +} + +type basicClause struct { + m Marker + fn func(QueryWriter, interface{}, string) error +} + +func (bc *basicClause) Markers() []Marker { return []Marker{bc.m} } + +func (bc *basicClause) Clone() PredicateClause { + return &basicClause{m: bc.m.Clone(), fn: bc.fn} +} + +func (bc *basicClause) WriteTo(w QueryWriter, vs map[string]interface{}) error { + b := bc.m.Binding() + vv, ok := vs[b] + + if !ok { + return ErrMissingKey{b} + } + + return bc.fn(w, vv, bc.m.ToCQL()) +} + +func parseItems(vv interface{}) ([]interface{}, error) { + var v = reflect.ValueOf(vv) + + if k := v.Kind(); k != reflect.Slice && k != reflect.Array { + return nil, errInvalidType + } + + res := make([]interface{}, v.Len()) + + for i := 0; i < v.Len(); i++ { + res[i] = v.Index(i).Interface() + } + + return res, nil +} + +func writeInClause(qw QueryWriter, vv interface{}, k string) error { + vs, err := parseItems(vv) + + if err != nil { + return err + } + + if len(vs) == 0 { + io.WriteString(qw, "1=0") + return nil + } + + fmt.Fprintf(qw, "%s IN (", k) + + for i, v := range vs { + io.WriteString(qw, "?") + qw.AddArg(v) + + if i < len(vs)-1 { + io.WriteString(qw, ", ") + } + } + + io.WriteString(qw, ")") + return nil +} + +func Static(pc PredicateClause, vs map[string]interface{}) PredicateClause { + return &staticValuePredicateClauseWrapper{ + svpc: &staticClause{pc: pc, vs: vs}, + } +} + +func StaticEq(m Marker, v interface{}) PredicateClause { + return Static(Eq(m), map[string]interface{}{m.Binding(): v}) +} + +type staticValuePredicateClauseWrapper struct { + svpc StaticValuePredicateClause +} + +func (svpcw *staticValuePredicateClauseWrapper) Markers() []Marker { + return svpcw.svpc.Markers() +} + +func (svpcw *staticValuePredicateClauseWrapper) Clone() PredicateClause { + return &staticValuePredicateClauseWrapper{ + svpc: svpcw.svpc.Clone(), + } +} + +func (svpcw *staticValuePredicateClauseWrapper) WriteTo(w QueryWriter, _ map[string]interface{}) error { + return svpcw.svpc.WriteTo(w) +} + +type staticClause struct { + pc PredicateClause + vs map[string]interface{} +} + +func (sc *staticClause) Clone() StaticValuePredicateClause { + vs := make(map[string]interface{}, len(sc.vs)) + + for k, v := range sc.vs { + vs[k] = v + } + + return &staticClause{pc: sc.pc.Clone(), vs: vs} +} + +func (sc *staticClause) WriteTo(w QueryWriter) error { + return sc.pc.WriteTo(w, sc.vs) +} + +func (sc *staticClause) Markers() []Marker { + return sc.pc.Markers() +} + +type StaticValuePredicateClause interface { + WriteTo(QueryWriter) error + Clone() StaticValuePredicateClause + Markers() []Marker +} + +type multiClause struct { + wcs []PredicateClause + + op string +} + +func wrapMultiClause(wcs []PredicateClause, op string) PredicateClause { + var cs []PredicateClause + + for _, wc := range wcs { + if wc == nil { + continue + } + + if mc, ok := wc.(multiClause); ok && mc.op == op { + cs = append(cs, mc.wcs...) + continue + } + + cs = append(cs, wc) + } + + switch len(cs) { + case 0: + return nil + case 1: + return cs[0] + default: + return multiClause{wcs: cs, op: op} + } +} + +func And(wcs ...PredicateClause) PredicateClause { + return wrapMultiClause(wcs, "AND") +} + +func (mc multiClause) Markers() []Marker { + var ms []Marker + + for _, c := range mc.wcs { + ms = append(ms, c.Markers()...) + } + + return ms +} + +func (mc multiClause) Clone() PredicateClause { + var wcs []PredicateClause + + if len(mc.wcs) > 0 { + wcs = make([]PredicateClause, len(mc.wcs)) + + for i, pc := range mc.wcs { + wcs[i] = pc.Clone() + } + } + + return multiClause{wcs: wcs, op: mc.op} +} + +func (mc multiClause) WriteTo(w QueryWriter, vs map[string]interface{}) error { + if len(mc.wcs) == 0 { + io.WriteString(w, "1=0") + return nil + } + + io.WriteString(w, "(") + + for i, wc := range mc.wcs { + if err := wc.WriteTo(w, vs); err != nil { + return err + } + + if i < len(mc.wcs)-1 { + fmt.Fprintf(w, ") %s (", mc.op) + } + } + + io.WriteString(w, ")") + + return nil +} diff --git a/x/cqlbuilder/query_builder.go b/x/cqlbuilder/query_builder.go new file mode 100644 index 0000000..3a1f0e1 --- /dev/null +++ b/x/cqlbuilder/query_builder.go @@ -0,0 +1,70 @@ +package cqlbuilder + +import "github.com/upfluence/cql" + +type QueryBuilder struct { + cql.DB +} + +func (qb *QueryBuilder) PrepareInsert(is InsertStatement) *InsertExecer { + return &InsertExecer{ + execer: execer{stmt: is, db: qb.DB}, + QueryBuilder: qb, + Statement: is, + } +} + +func (qb *QueryBuilder) PrepareDelete(ds DeleteStatement) *DeleteExecer { + return &DeleteExecer{ + execer: execer{stmt: ds, db: qb.DB}, + QueryBuilder: qb, + Statement: ds, + } +} + +func (qb *QueryBuilder) PrepareUpdate(us UpdateStatement) *UpdateExecer { + return &UpdateExecer{ + execer: execer{stmt: us, db: qb.DB}, + QueryBuilder: qb, + Statement: us, + } +} + +func (qb *QueryBuilder) PrepareSelect(ss SelectStatement) *SelectQueryer { + return &SelectQueryer{QueryBuilder: qb, Statement: ss} +} + +func (qb *QueryBuilder) PrepareBatch(bs BatchStatement) *BatchExecer { + return &BatchExecer{QueryBuilder: qb, Statement: bs} +} + +type statement interface { + buildQuery(map[string]interface{}) (string, []interface{}, error) +} + +type CASStatement interface { + statement + + casScanKeys() []string +} + +type InsertExecer struct { + execer + + QueryBuilder *QueryBuilder + Statement InsertStatement +} + +type DeleteExecer struct { + execer + + QueryBuilder *QueryBuilder + Statement DeleteStatement +} + +type UpdateExecer struct { + execer + + QueryBuilder *QueryBuilder + Statement UpdateStatement +} diff --git a/x/cqlbuilder/query_builder_test.go b/x/cqlbuilder/query_builder_test.go new file mode 100644 index 0000000..1804855 --- /dev/null +++ b/x/cqlbuilder/query_builder_test.go @@ -0,0 +1,243 @@ +package cqlbuilder + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/upfluence/cql" + "github.com/upfluence/cql/cqltest" + "github.com/upfluence/cql/x/migration" +) + +type statementTestCase struct { + name string + + stmt statement + vs map[string]interface{} + + wantStmt string + wantArgs []interface{} + wantErr error +} + +func (stc statementTestCase) assert(t *testing.T) { + t.Helper() + + t.Run(stc.name, func(t *testing.T) { + t.Helper() + + stmt, args, err := stc.stmt.buildQuery(stc.vs) + + assert.Equal(t, stc.wantStmt, stmt) + assert.Equal(t, stc.wantArgs, args) + assert.Equal(t, stc.wantErr, err) + }) +} + +func integrationTest(t *testing.T, fn func(*testing.T, cql.DB)) { + cqltest.NewTestCase( + cqltest.WithMigratorFunc(func(db cql.DB) migration.Migrator { + return migration.NewMigrator( + db, + cqltest.StaticSource{ + MigrationUp: "CREATE TABLE IF NOT EXISTS fuz(foo text PRIMARY KEY, bar blob)", + MigrationDown: "DROP TABLE fuz", + }, + migration.MigrationTable("cqlbuilder_integration_migrations"), + ) + }), + ).Run(t, fn) +} + +func TestCAS(t *testing.T) { + integrationTest(t, func(t *testing.T, db cql.DB) { + qb := QueryBuilder{DB: db} + + ie := qb.PrepareInsert( + InsertStatement{ + Table: "fuz", + Fields: []Marker{Column("foo"), Column("bar")}, + LWTClause: NotExistsClause, + }, + ) + + ok, _, _, err := execCAS(ie, "foo", "bar") + assert.True(t, ok) + assert.NoError(t, err) + + ok, _, bar, err := execCAS(ie, "foo", "foo") + assert.False(t, ok) + assert.NoError(t, err) + assert.Equal(t, "bar", bar) + + ue := qb.PrepareUpdate( + UpdateStatement{ + Table: "fuz", + UpdateClauses: []UpdateClause{{Field: Column("bar"), Op: Set}}, + WhereClause: Eq(Column("foo")), + LWTClause: PredicateLWTClause{ + Predicate: StaticEq(Column("bar"), "bar"), + }, + }, + ) + + ok, _, _, err = execCAS(ue, "foo", "foo") + assert.True(t, ok) + assert.NoError(t, err) + + ok, _, bar, err = execCAS(ue, "foo", "foo") + assert.False(t, ok) + assert.NoError(t, err) + assert.Equal(t, "foo", bar) + + de := qb.PrepareDelete( + DeleteStatement{ + Table: "fuz", + WhereClause: Eq(Column("foo")), + LWTClause: PredicateLWTClause{ + Predicate: StaticEq(Column("bar"), "foo"), + }, + }, + ) + + ok, _, _, err = execCAS(de, "foo", "") + assert.True(t, ok) + assert.NoError(t, err) + }) +} + +func TestEC(t *testing.T) { + integrationTest(t, func(t *testing.T, db cql.DB) { + qb := QueryBuilder{DB: db} + + se := qb.PrepareSelect( + SelectStatement{ + Table: "fuz", + SelectClauses: []Marker{Column("bar")}, + WhereClause: Eq(Column("foo")), + }, + ) + + bar, err := queryRow(se, "foo") + assert.Equal(t, cql.ErrNoRows, err) + assert.Equal(t, "", bar) + + ie := qb.PrepareInsert( + InsertStatement{ + Table: "fuz", + Fields: []Marker{Column("foo"), Column("bar")}, + }, + ) + + err = exec(ie, "foo", "bar") + assert.NoError(t, err) + + bar, err = queryRow(se, "foo") + assert.NoError(t, err) + assert.Equal(t, "bar", bar) + + ue := qb.PrepareUpdate( + UpdateStatement{ + Table: "fuz", + UpdateClauses: []UpdateClause{{Field: Column("bar"), Op: Set}}, + WhereClause: Eq(Column("foo")), + }, + ) + + err = exec(ue, "foo", "foo") + assert.NoError(t, err) + + bar, err = queryRow(se, "foo") + assert.NoError(t, err) + assert.Equal(t, "foo", bar) + + de := qb.PrepareDelete( + DeleteStatement{Table: "fuz", WhereClause: Eq(Column("foo"))}, + ) + + err = exec(de, "foo", "") + assert.NoError(t, err) + + bar, err = queryRow(se, "foo") + assert.Equal(t, cql.ErrNoRows, err) + assert.Equal(t, "", bar) + }) +} + +func TestBatch(t *testing.T) { + integrationTest(t, func(t *testing.T, db cql.DB) { + qb := QueryBuilder{DB: db} + + be := qb.PrepareBatch( + BatchStatement{ + Type: cql.LoggedBatch, + Statements: []CASStatement{ + InsertStatement{ + Table: "fuz", + Fields: []Marker{CQLExpression("foo1", "foo"), Column("bar")}, + }, + InsertStatement{ + Table: "fuz", + Fields: []Marker{CQLExpression("foo2", "foo"), Column("bar")}, + }, + }, + }, + ) + + err := be.Exec( + context.Background(), + map[string]interface{}{"foo1": "foo", "foo2": "bar", "bar": "buz"}, + ) + + assert.NoError(t, err) + + cur := qb.PrepareSelect( + SelectStatement{ + Table: "fuz", + SelectClauses: []Marker{Column("foo"), Column("bar")}, + }, + ).Query(context.Background(), nil) + + var foo, bar string + + vs := make(map[string]string) + + for cur.Scan(map[string]interface{}{"foo": &foo, "bar": &bar}) { + vs[foo] = bar + } + + assert.NoError(t, cur.Close()) + assert.Equal(t, map[string]string{"foo": "buz", "bar": "buz"}, vs) + }) +} + +func queryRow(sq *SelectQueryer, foo string) (string, error) { + var bar string + + return bar, sq.QueryRow( + context.Background(), + map[string]interface{}{"foo": foo}, + ).Scan(map[string]interface{}{"bar": &bar}) +} + +func exec(e Execer, foo, bar string) error { + return e.Exec( + context.Background(), + map[string]interface{}{"foo": foo, "bar": bar}, + ) +} + +func execCAS(e Execer, foo, bar string) (bool, string, string, error) { + var outFoo, outBar string + + ok, err := e.ExecCAS( + context.Background(), + map[string]interface{}{"foo": foo, "bar": bar}, + ).ScanCAS( + map[string]interface{}{"foo": &outFoo, "bar": &outBar}, + ) + + return ok, outFoo, outBar, err +} diff --git a/x/cqlbuilder/query_writer.go b/x/cqlbuilder/query_writer.go new file mode 100644 index 0000000..547ea3b --- /dev/null +++ b/x/cqlbuilder/query_writer.go @@ -0,0 +1,20 @@ +package cqlbuilder + +import ( + "io" + "strings" +) + +type QueryWriter interface { + io.Writer + + AddArg(interface{}) +} + +type queryWriter struct { + strings.Builder + + args []interface{} +} + +func (qw *queryWriter) AddArg(a interface{}) { qw.args = append(qw.args, a) } diff --git a/x/cqlbuilder/queryer.go b/x/cqlbuilder/queryer.go new file mode 100644 index 0000000..570d55f --- /dev/null +++ b/x/cqlbuilder/queryer.go @@ -0,0 +1,80 @@ +package cqlbuilder + +import ( + "context" + + "github.com/upfluence/cql" + "github.com/upfluence/pkg/multierror" +) + +type Queryer interface { + Query(context.Context, map[string]interface{}) (Cursor, error) + QueryRow(context.Context, map[string]interface{}) Scanner +} + +type Scanner interface { + Scan(map[string]interface{}) error +} + +type scanner struct { + sc cql.Scanner + ks []string +} + +func (sc *scanner) Scan(vs map[string]interface{}) error { + var svs = make([]interface{}, len(sc.ks)) + + for i, k := range sc.ks { + v, ok := vs[k] + + if !ok { + return ErrMissingKey{Key: k} + } + + svs[i] = v + } + + return sc.sc.Scan(svs...) +} + +type errScanner struct{ error } + +func (es errScanner) Scan(map[string]interface{}) error { return es.error } + +type Cursor interface { + Scan(map[string]interface{}) bool + Close() error +} + +type cursor struct { + c cql.Cursor + ks []string + + err error +} + +func (c *cursor) Scan(vs map[string]interface{}) bool { + var svs = make([]interface{}, len(c.ks)) + + for i, k := range c.ks { + v, ok := vs[k] + + if !ok { + c.err = ErrMissingKey{Key: k} + return false + } + + svs[i] = v + } + + return c.c.Scan(svs...) +} + +func (c *cursor) Close() error { + return multierror.Combine(c.err, c.c.Close()) +} + +type errCursor struct{ error } + +func (ec errCursor) Scan(map[string]interface{}) bool { return false } +func (ec errCursor) Close() error { return ec.error } diff --git a/x/cqlbuilder/select_statement.go b/x/cqlbuilder/select_statement.go new file mode 100644 index 0000000..d15249d --- /dev/null +++ b/x/cqlbuilder/select_statement.go @@ -0,0 +1,102 @@ +package cqlbuilder + +import ( + "context" + "fmt" + "strings" +) + +type Direction string + +const ( + Asc Direction = "ASC" + Desc Direction = "DESC" +) + +type OrderByClause struct { + Field Marker + Direction Direction +} + +type SelectStatement struct { + Table string + + SelectClauses []Marker + WhereClause PredicateClause + OrderByClause OrderByClause + + AllowFiltering bool +} + +func (ss SelectStatement) scanKeys() []string { + var vs = make([]string, len(ss.SelectClauses)) + + for i, f := range ss.SelectClauses { + vs[i] = f.Binding() + } + + return vs +} + +func (ss SelectStatement) buildQuery(qvs map[string]interface{}) (string, []interface{}, error) { + var ( + qw queryWriter + + ks = make([]string, len(ss.SelectClauses)) + ) + + for i, f := range ss.SelectClauses { + ks[i] = f.ToCQL() + } + + fmt.Fprintf(&qw, "SELECT %s FROM %s", strings.Join(ks, ", "), ss.Table) + + if ss.WhereClause != nil { + qw.WriteString(" WHERE ") + + if err := ss.WhereClause.WriteTo(&qw, qvs); err != nil { + return "", nil, err + } + } + + if obc := ss.OrderByClause; obc.Field != nil { + fmt.Fprintf(&qw, " ORDER BY %s %s", obc.Field.ToCQL(), obc.Direction) + } + + if ss.AllowFiltering { + qw.WriteString(" ALLOW FILTERING") + } + + return qw.String(), qw.args, nil +} + +type SelectQueryer struct { + QueryBuilder *QueryBuilder + Statement SelectStatement +} + +func (sq *SelectQueryer) Query(ctx context.Context, qvs map[string]interface{}) Cursor { + stmt, vs, err := sq.Statement.buildQuery(qvs) + + if err != nil { + return errCursor{err} + } + + return &cursor{ + c: sq.QueryBuilder.Query(ctx, stmt, vs...), + ks: sq.Statement.scanKeys(), + } +} + +func (sq *SelectQueryer) QueryRow(ctx context.Context, qvs map[string]interface{}) Scanner { + stmt, vs, err := sq.Statement.buildQuery(qvs) + + if err != nil { + return errScanner{err} + } + + return &scanner{ + sc: sq.QueryBuilder.QueryRow(ctx, stmt, vs...), + ks: sq.Statement.scanKeys(), + } +} diff --git a/x/cqlbuilder/select_statement_test.go b/x/cqlbuilder/select_statement_test.go new file mode 100644 index 0000000..3d7c8c1 --- /dev/null +++ b/x/cqlbuilder/select_statement_test.go @@ -0,0 +1,33 @@ +package cqlbuilder + +import "testing" + +func TestSelectStatement(t *testing.T) { + for _, stc := range []statementTestCase{ + { + name: "basic", + stmt: SelectStatement{ + Table: "foo", + SelectClauses: []Marker{Column("fiz"), Column("buz")}, + WhereClause: Eq(Column("bar")), + AllowFiltering: true, + }, + vs: map[string]interface{}{"fiz": 1, "buz": 2, "bar": 3}, + wantStmt: "SELECT fiz, buz FROM foo WHERE bar = ? ALLOW FILTERING", + wantArgs: []interface{}{3}, + }, + { + name: "basic and", + stmt: SelectStatement{ + Table: "foo", + SelectClauses: []Marker{Column("fiz"), Column("buz")}, + WhereClause: And(Eq(Column("bar")), Eq(Column("fiz"))), + }, + vs: map[string]interface{}{"fiz": 1, "buz": 2, "bar": 3}, + wantStmt: "SELECT fiz, buz FROM foo WHERE (bar = ?) AND (fiz = ?)", + wantArgs: []interface{}{3, 1}, + }, + } { + stc.assert(t) + } +} diff --git a/x/cqlbuilder/update_statement.go b/x/cqlbuilder/update_statement.go new file mode 100644 index 0000000..cb8f7f0 --- /dev/null +++ b/x/cqlbuilder/update_statement.go @@ -0,0 +1,135 @@ +package cqlbuilder + +import ( + "errors" + "fmt" +) + +var ( + errMissingUpdateValue = errors.New("x/cqlbuilder: missing value of the key for update") + errNoUpdates = errors.New("x/cqlbuilder: no update given") +) + +type LWTUpdateClause interface { + LWTClause + + isUpdateClause() +} + +type UpdateOperation interface { + WriteTo(QueryWriter, string, interface{}, bool) error + Clone() UpdateOperation +} + +type set struct{} + +func (set) WriteTo(qw QueryWriter, k string, v interface{}, ok bool) error { + if !ok { + return errMissingUpdateValue + } + + fmt.Fprintf(qw, "%s = ?", k) + qw.AddArg(v) + + return nil +} + +func (set) Clone() UpdateOperation { return set{} } + +var Set = set{} + +type setOp struct{ op string } + +func (sp setOp) WriteTo(qw QueryWriter, k string, v interface{}, ok bool) error { + if !ok { + return errMissingUpdateValue + } + + fmt.Fprintf(qw, "%s = %s %s ?", k, k, sp.op) + qw.AddArg(v) + + return nil +} + +func (sp setOp) Clone() UpdateOperation { return sp } + +var ( + SetAdd = setOp{op: "+"} + SetRemove = setOp{op: "-"} +) + +type UpdateClause struct { + Field Marker + Op UpdateOperation +} + +func (uc UpdateClause) writeTo(qw QueryWriter, qvs map[string]interface{}) error { + var ( + k = uc.Field.Binding() + v, ok = qvs[k] + ) + + err := uc.Op.WriteTo(qw, columnName(uc.Field), v, ok) + + if err == errMissingUpdateValue { + return ErrMissingKey{Key: k} + } + + return err +} + +type UpdateStatement struct { + Table string + + UpdateClauses []UpdateClause + WhereClause PredicateClause + + Options DMLOptions + LWTClause LWTUpdateClause +} + +func (us UpdateStatement) casScanKeys() []string { + if lck, ok := us.LWTClause.(interface{ keys() []string }); ok { + return lck.keys() + } + + return nil +} + +func (us UpdateStatement) buildQuery(qvs map[string]interface{}) (string, []interface{}, error) { + var qw queryWriter + + if len(us.UpdateClauses) == 0 { + return "", nil, errNoUpdates + } + + fmt.Fprintf(&qw, "UPDATE %s", us.Table) + us.Options.writeTo(&qw) + qw.WriteString(" SET ") + + for i, uc := range us.UpdateClauses { + if err := uc.writeTo(&qw, qvs); err != nil { + return "", nil, err + } + + if i < len(us.UpdateClauses)-1 { + qw.WriteString(", ") + } + } + + qw.WriteString(" WHERE ") + + if err := us.WhereClause.WriteTo(&qw, qvs); err != nil { + return "", nil, err + } + + if lc := us.LWTClause; lc != nil { + qw.WriteRune(' ') + + if err := lc.writeTo(&qw, qvs); err != nil { + return "", nil, err + } + } + + return qw.String(), qw.args, nil +} diff --git a/x/cqlbuilder/update_statement_test.go b/x/cqlbuilder/update_statement_test.go new file mode 100644 index 0000000..05f6ba1 --- /dev/null +++ b/x/cqlbuilder/update_statement_test.go @@ -0,0 +1,99 @@ +package cqlbuilder + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/upfluence/cql" + "github.com/upfluence/cql/cqltest" + "github.com/upfluence/cql/x/migration" +) + +func TestUpdateStatement(t *testing.T) { + for _, stc := range []statementTestCase{ + { + name: "basic", + stmt: UpdateStatement{ + Table: "foo", + UpdateClauses: []UpdateClause{ + {Field: Column("fiz"), Op: Set}, + {Field: Column("buz"), Op: Set}, + }, + WhereClause: Eq(Column("bar")), + }, + vs: map[string]interface{}{"fiz": 1, "buz": 2, "bar": 3}, + wantStmt: "UPDATE foo SET fiz = ?, buz = ? WHERE bar = ?", + wantArgs: []interface{}{1, 2, 3}, + }, + { + name: "complex lwt", + stmt: UpdateStatement{ + Table: "foo", + UpdateClauses: []UpdateClause{ + {Field: Column("fiz"), Op: Set}, + }, + WhereClause: Eq(Column("bar")), + LWTClause: PredicateLWTClause{Predicate: Eq(Column("buz"))}, + }, + vs: map[string]interface{}{"fiz": 1, "buz": 2, "bar": 3}, + wantStmt: "UPDATE foo SET fiz = ? WHERE bar = ? IF buz = ?", + wantArgs: []interface{}{1, 3, 2}, + }, + { + name: "SettAdd", + stmt: UpdateStatement{ + Table: "foo", + UpdateClauses: []UpdateClause{ + {Field: Column("fiz"), Op: SetAdd}, + }, + WhereClause: In(Column("bar")), + }, + vs: map[string]interface{}{"fiz": []int{1}, "bar": []int{3, 4}}, + wantStmt: "UPDATE foo SET fiz = fiz + ? WHERE bar IN (?, ?)", + wantArgs: []interface{}{[]int{1}, 3, 4}, + }, + } { + stc.assert(t) + } +} + +func TestIntegrationSet(t *testing.T) { + cqltest.NewTestCase( + cqltest.WithMigratorFunc(func(db cql.DB) migration.Migrator { + return migration.NewMigrator( + db, + cqltest.StaticSource{ + MigrationUp: "CREATE TABLE IF NOT EXISTS fuz(foo text PRIMARY KEY, bar set)", + MigrationDown: "DROP TABLE fuz", + }, + migration.MigrationTable("cqlbuilder_set_integration_migrations"), + ) + }), + ).Run(t, func(t *testing.T, db cql.DB) { + qb := QueryBuilder{DB: db} + + ue := qb.PrepareUpdate( + UpdateStatement{ + Table: "fuz", + UpdateClauses: []UpdateClause{{Field: Column("bar"), Op: SetAdd}}, + WhereClause: Eq(Column("foo")), + }, + ) + + err := ue.Exec( + context.Background(), + map[string]interface{}{"foo": "foo", "bar": []string{"foo"}}, + ) + + assert.NoError(t, err) + + err = ue.Exec( + context.Background(), + map[string]interface{}{"foo": "foo", "bar": []string{"foo"}}, + ) + + assert.NoError(t, err) + }) +} diff --git a/x/migration/options.go b/x/migration/options.go index 65d8be9..0c81213 100644 --- a/x/migration/options.go +++ b/x/migration/options.go @@ -32,6 +32,8 @@ var defaultOptions = options{ type Option func(*options) +func MigrationTable(t string) Option { return func(o *options) { o.table = t } } + type options struct { table string