From 8f904dc83c52c75549117663b68958c49d016eac Mon Sep 17 00:00:00 2001 From: adrianwit Date: Wed, 28 Feb 2018 10:11:05 -0800 Subject: [PATCH] updated sql parser with SQLCriteria --- file_manager.go | 4 +- file_scanner.go | 1 - manager.go | 6 +- record_mapper.go | 7 +-- sql_manager.go | 4 -- sql_parser.go | 132 ++++++++++++++++++++++++++---------------- sql_parser_test.go | 4 +- sql_predicate.go | 21 +++---- sql_predicate_test.go | 9 ++- 9 files changed, 108 insertions(+), 80 deletions(-) diff --git a/file_manager.go b/file_manager.go index 0adf496..8eccdef 100755 --- a/file_manager.go +++ b/file_manager.go @@ -194,7 +194,7 @@ func (m *FileManager) modifyRecords(tableURL string, statement *DmlStatement, pa var err error var predicate toolbox.Predicate if len(statement.Criteria) > 0 { - predicate, err = NewSQLCriteriaPredicate(parameters, statement.Criteria...) + predicate, err = NewSQLCriteriaPredicate(parameters, statement.SQLCriteria) if err != nil { return 0, fmt.Errorf("failed to read data from %v due to %v", statement.SQL, err) } @@ -429,7 +429,7 @@ func (m *FileManager) ReadAllOnWithHandlerOnConnection(connection Connection, qu var predicate toolbox.Predicate if len(statement.Criteria) > 0 { parameters := toolbox.NewSliceIterator(sqlParameters) - predicate, err = NewSQLCriteriaPredicate(parameters, statement.Criteria...) + predicate, err = NewSQLCriteriaPredicate(parameters, statement.SQLCriteria) if err != nil { return fmt.Errorf("failed to read data from %v due to %v", query, err) } diff --git a/file_scanner.go b/file_scanner.go index 0512718..b13f24d 100755 --- a/file_scanner.go +++ b/file_scanner.go @@ -29,7 +29,6 @@ func (s *FileScanner) Scan(destinations ...interface{}) (err error) { return nil } } - var columns, _ = s.Columns() for i, dest := range destinations { if value, found := s.Values[columns[i]]; found { diff --git a/manager.go b/manager.go index 336e2b2..c743951 100755 --- a/manager.go +++ b/manager.go @@ -250,7 +250,6 @@ func (am *AbstractManager) PersistAllOnConnection(connection Connection, dataPoi if err != nil { return 0, 0, err } - var isStructPointer = structType.Kind() == reflect.Ptr var insertableMapping map[int]int if descriptor.Autoincrement { @@ -375,11 +374,9 @@ func (am *AbstractManager) PersistData(connection Connection, data []interface{} func (am *AbstractManager) fetchDataInBatches(connection Connection, sqlsWihtArguments []*ParametrizedSQL, mapper RecordMapper) (*[][]interface{}, error) { var rows = make([][]interface{}, 0) for _, sqlWihtArguments := range sqlsWihtArguments { - if len(sqlWihtArguments.Values) == 0 { break } - err := am.Manager.ReadAllOnConnection(connection, &rows, sqlWihtArguments.SQL, sqlWihtArguments.Values, mapper) if err != nil { return nil, err @@ -410,7 +407,7 @@ func (am *AbstractManager) fetchExistingData(connection Connection, table string //ClassifyDataAsInsertableOrUpdatable classifies passed in data as insertable or updatable. func (am *AbstractManager) ClassifyDataAsInsertableOrUpdatable(connection Connection, dataPointer interface{}, table string, provider DmlProvider) ([]interface{}, []interface{}, error) { if provider == nil { - return nil, nil, errors.New("Provider was nil") + return nil, nil, errors.New("provider was nil") } var rowsByKey = make(map[string]interface{}, 0) @@ -433,6 +430,7 @@ func (am *AbstractManager) ClassifyDataAsInsertableOrUpdatable(connection Connec //process existing rows and add mapped entires as updatables for _, row := range rows { + key := toolbox.JoinAsString(row, "") if instance, ok := rowsByKey[key]; ok { updatables = append(updatables, instance) diff --git a/record_mapper.go b/record_mapper.go index a1d44bb..0e3bc3c 100755 --- a/record_mapper.go +++ b/record_mapper.go @@ -184,6 +184,7 @@ func (rm *mapRecordMapper) Map(scanner Scanner) (interface{}, error) { return nil, err } columns, _ := scanner.Columns() + aMap := make(map[string]interface{}) for i, column := range columns { aMap[column] = result[i] @@ -221,7 +222,7 @@ func NewRecordMapper(targetType reflect.Type) RecordMapper { return mapper } default: - panic("unsupported type: " + targetType.Name()) + panic(fmt.Sprintf("unsupported type: %v ", targetType.Name())) } return nil } @@ -237,15 +238,12 @@ func NewRecordMapperIfNeeded(mapper RecordMapper, targetType reflect.Type) Recor //ScanRow takes scanner to scans row. func ScanRow(scanner Scanner) ([]interface{}, []string, error) { columns, _ := scanner.Columns() - count := len(columns) var valuePointers = make([]interface{}, count) var rowValues = make([]interface{}, count) - for i := range rowValues { valuePointers[i] = &rowValues[i] } - err := scanner.Scan(valuePointers...) if err != nil { return nil, nil, fmt.Errorf("failed to scan row due to %v", err) @@ -262,5 +260,6 @@ func ScanRow(scanner Scanner) ([]interface{}, []string, error) { } rowValues[i] = value } + return rowValues, columns, nil } diff --git a/sql_manager.go b/sql_manager.go index dd9adc2..1ced37d 100755 --- a/sql_manager.go +++ b/sql_manager.go @@ -6,10 +6,6 @@ import ( "reflect" ) -var Logf = func(format string, args ...interface{}) { - -} - func asSQLDb(wrapped interface{}) (*sql.DB, error) { if result, ok := wrapped.(*sql.DB); ok { return result, nil diff --git a/sql_parser.go b/sql_parser.go index e29f3b4..053b1fc 100755 --- a/sql_parser.go +++ b/sql_parser.go @@ -16,22 +16,56 @@ type SQLColumn struct { FunctionArguments string } +//SQLCriteria represents SQL criteria +type SQLCriteria struct { + Criteria []*SQLCriterion + LogicalOperator string +} + +//CriteriaValues returns criteria values extracted from binding parameters, starting from parametersOffset, +func (s *SQLCriteria) CriteriaValues(parameters toolbox.Iterator) ([]interface{}, error) { + var values = make([]interface{}, 0) + for _, criterion := range s.Criteria { + if criterion.Criteria != nil && len(criterion.Criteria.Criteria) > 0 { + criteriaValues, err := criterion.Criteria.CriteriaValues(parameters) + if err != nil { + return nil, err + } + values = append(values, criteriaValues) + continue + } + var criterionValues = criterion.RightOperands + if len(criterionValues) == 0 { + criterionValues = []interface{}{criterion.RightOperand} + } + for i := range criterionValues { + value, err := bindValueIfNeeded(criterionValues[i], parameters) + if err != nil { + return nil, err + } + values = append(values, value) + } + } + return values, nil +} + //SQLCriterion represents single where clause condiction type SQLCriterion struct { - LeftOperand interface{} - Operator string - RightOperand interface{} - RightOperands []interface{} - Inverse bool // if not operator presents - LogicalOperator string + Criteria *SQLCriteria + LeftOperand interface{} + Operator string + RightOperand interface{} + RightOperands []interface{} + Inverse bool // if not operator presents + } //BaseStatement represents a base query and dml statement type BaseStatement struct { - SQL string - Table string - Columns []*SQLColumn - Criteria []*SQLCriterion + *SQLCriteria + SQL string + Table string + Columns []*SQLColumn } //ColumnNames returns a column names. @@ -67,28 +101,9 @@ func bindValueIfNeeded(source interface{}, parameters toolbox.Iterator) (interfa return source, nil } -//CriteriaValues returns criteria values extracted from binding parameters, starting from parametersOffset, -func (bs BaseStatement) CriteriaValues(parameters toolbox.Iterator) ([]interface{}, error) { - var values = make([]interface{}, 0) - for _, criterion := range bs.Criteria { - var criterionValues = criterion.RightOperands - if len(criterionValues) == 0 { - criterionValues = []interface{}{criterion.RightOperand} - } - for i := range criterionValues { - value, err := bindValueIfNeeded(criterionValues[i], parameters) - if err != nil { - return nil, err - } - values = append(values, value) - } - } - return values, nil -} - //QueryStatement represents SQL query statement. type QueryStatement struct { - BaseStatement + *BaseStatement AllField bool UnionTables []string GroupBy []*SQLColumn @@ -96,7 +111,7 @@ type QueryStatement struct { //DmlStatement represents dml statement. type DmlStatement struct { - BaseStatement + *BaseStatement Type string Values []interface{} } @@ -319,8 +334,7 @@ func (bp *baseParser) readInValues(tokenizer *toolbox.Tokenizer) (string, []inte return value, values, nil } -func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *BaseStatement, token *toolbox.Token) (err error) { - statement.Criteria = make([]*SQLCriterion, 0) +func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, sqlCriteria *SQLCriteria, token *toolbox.Token) (err error) { for { token, err = bp.expectWhitespaceFollowedBy(tokenizer, "value", sqlValue) if err != nil { @@ -330,27 +344,27 @@ func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *Base return fmt.Errorf("Expected criteria at %v", tokenizer.Index) } - index := len(statement.Criteria) - statement.Criteria = append(statement.Criteria, &SQLCriterion{LeftOperand: token.Matched}) + index := len(sqlCriteria.Criteria) + sqlCriteria.Criteria = append(sqlCriteria.Criteria, &SQLCriterion{LeftOperand: token.Matched}) token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "operator", inOperatorKeyword, likeOperatorKeyword, betweenOperatorKeyword, notKeyword, isOperatorKeyword, operator, eof) if err != nil { return err } - statement.Criteria[index].Operator = token.Matched + sqlCriteria.Criteria[index].Operator = token.Matched switch token.Token { case notKeyword: - statement.Criteria[index].Inverse = true + sqlCriteria.Criteria[index].Inverse = true token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "operator", inOperatorKeyword, likeOperatorKeyword) if err != nil { return err } - statement.Criteria[index].Operator = token.Matched + sqlCriteria.Criteria[index].Operator = token.Matched fallthrough case inOperatorKeyword: var value, values, err = bp.readInValues(tokenizer) - statement.Criteria[index].RightOperand = value - statement.Criteria[index].RightOperands = values + sqlCriteria.Criteria[index].RightOperand = value + sqlCriteria.Criteria[index].RightOperands = values if err != nil { return err } @@ -359,19 +373,19 @@ func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *Base if err != nil { return err } - statement.Criteria[index].RightOperand = token.Matched + sqlCriteria.Criteria[index].RightOperand = token.Matched case isOperatorKeyword: token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "operator", notKeyword, nullKeyword) if err != nil { return err } if token.Token == notKeyword { - statement.Criteria[index].Inverse = true + sqlCriteria.Criteria[index].Inverse = true token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "operator", nullKeyword) if err != nil { return err } - statement.Criteria[index].RightOperand = token.Matched + sqlCriteria.Criteria[index].RightOperand = token.Matched } case betweenOperatorKeyword: token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "value", sqlValue) @@ -388,7 +402,7 @@ func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *Base return err } toValue := token.Matched - statement.Criteria[index].RightOperands = []interface{}{fromValue, toValue} + sqlCriteria.Criteria[index].RightOperands = []interface{}{fromValue, toValue} } token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "or | and | eof", eof, logicalOperator, groupKeyword) if err != nil { @@ -401,7 +415,14 @@ func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *Base tokenizer.Index -= len(token.Matched) break } - statement.Criteria[index].LogicalOperator = token.Matched + if len(sqlCriteria.Criteria) == 1 { + sqlCriteria.LogicalOperator = token.Matched + } else if token.Matched == sqlCriteria.LogicalOperator { + continue + } else { + return fmt.Errorf("various logical operator not supported yet") + } + } return nil } @@ -500,7 +521,12 @@ outer: //Parse parses SQL query to build QueryStatement func (qp *QueryParser) Parse(query string) (*QueryStatement, error) { tokenizer := toolbox.NewTokenizer(query, illegal, eof, sqlMatchers) - baseStatement := BaseStatement{SQL: query} + baseStatement := &BaseStatement{ + SQL: query, + SQLCriteria: &SQLCriteria{ + Criteria: make([]*SQLCriterion, 0), + }, + } result := &QueryStatement{BaseStatement: baseStatement} var token *toolbox.Token @@ -544,7 +570,7 @@ func (qp *QueryParser) Parse(query string) (*QueryStatement, error) { } if token.Token == whereKeyword { - err = qp.readCriteria(tokenizer, &result.BaseStatement, token) + err = qp.readCriteria(tokenizer, result.BaseStatement.SQLCriteria, token) if err != nil { return nil, err } @@ -724,7 +750,7 @@ func (dp *DmlParser) parseUpdate(tokenizer *toolbox.Tokenizer, statement *DmlSta return nil } - err = dp.readCriteria(tokenizer, &statement.BaseStatement, token) + err = dp.readCriteria(tokenizer, statement.BaseStatement.SQLCriteria, token) if err != nil { return err } @@ -750,7 +776,7 @@ func (dp *DmlParser) parseDelete(tokenizer *toolbox.Tokenizer, statement *DmlSta return nil } - err = dp.readCriteria(tokenizer, &statement.BaseStatement, token) + err = dp.readCriteria(tokenizer, statement.BaseStatement.SQLCriteria, token) if err != nil { return err } @@ -759,8 +785,12 @@ func (dp *DmlParser) parseDelete(tokenizer *toolbox.Tokenizer, statement *DmlSta //Parse parses input to create DmlStatement. func (dp *DmlParser) Parse(input string) (*DmlStatement, error) { - baseStatement := BaseStatement{SQL: input} - result := &DmlStatement{BaseStatement: baseStatement} + baseStatement := BaseStatement{ + SQL: input, + SQLCriteria: &SQLCriteria{ + Criteria: make([]*SQLCriterion, 0), + }} + result := &DmlStatement{BaseStatement: &baseStatement} tokenizer := toolbox.NewTokenizer(input, illegal, eof, sqlMatchers) token, err := dp.expectOptionalWhitespaceFollowedBy(tokenizer, "INSERT INTO | UPDATE | DELETE", insertKeyword, updateKeyword, deleteKeyword) if err != nil { diff --git a/sql_parser_test.go b/sql_parser_test.go index 3a6f2ff..7325599 100755 --- a/sql_parser_test.go +++ b/sql_parser_test.go @@ -122,14 +122,14 @@ func TestQueryParser(t *testing.T) { assert.Equal(t, "column1", query.Criteria[0].LeftOperand) assert.Equal(t, "=", query.Criteria[0].Operator) assert.Equal(t, "2", query.Criteria[0].RightOperand) - assert.Equal(t, "AND", query.Criteria[0].LogicalOperator) + assert.Equal(t, "AND", query.LogicalOperator) } { assert.Equal(t, "column2", query.Criteria[1].LeftOperand) assert.Equal(t, "!=", query.Criteria[1].Operator) assert.Equal(t, "?", query.Criteria[1].RightOperand) - assert.Equal(t, "", query.Criteria[1].LogicalOperator) + assert.Equal(t, "AND", query.LogicalOperator) } diff --git a/sql_predicate.go b/sql_predicate.go index 788d315..eaab663 100755 --- a/sql_predicate.go +++ b/sql_predicate.go @@ -91,7 +91,7 @@ func NewBooleanPredicate(leftOperand bool, operator string) toolbox.Predicate { } type sqlCriteriaPredicate struct { - criteria []*SQLCriterion + *SQLCriteria predicates []toolbox.Predicate } @@ -103,8 +103,8 @@ func (p *sqlCriteriaPredicate) Apply(source interface{}) bool { result := true var logicalPredicate toolbox.Predicate - for i := 0; i < len(p.criteria); i++ { - criterion := p.criteria[i] + for i := 0; i < len(p.Criteria); i++ { + criterion := p.Criteria[i] value := sourceMap[toolbox.AsString(criterion.LeftOperand)] predicate := p.predicates[i] result = predicate.Apply(value) @@ -114,27 +114,28 @@ func (p *sqlCriteriaPredicate) Apply(source interface{}) bool { if logicalPredicate != nil { result = logicalPredicate.Apply(result) } - if criterion.LogicalOperator != "" { - if strings.ToLower(criterion.LogicalOperator) == "and" && !result { + if p.LogicalOperator != "" { + if strings.ToLower(p.LogicalOperator) == "and" && !result { //shortcut break } - logicalPredicate = NewBooleanPredicate(result, criterion.LogicalOperator) + logicalPredicate = NewBooleanPredicate(result, p.LogicalOperator) } } return result } //NewSQLCriteriaPredicate create a new sql criteria predicate, it takes binding parameters iterator, and actual criteria. -func NewSQLCriteriaPredicate(parameters toolbox.Iterator, criteria ...*SQLCriterion) (toolbox.Predicate, error) { +func NewSQLCriteriaPredicate(parameters toolbox.Iterator, sqlCriteria *SQLCriteria) (toolbox.Predicate, error) { var predicates = make([]toolbox.Predicate, 0) - for i := 0; i < len(criteria); i++ { - criterion := criteria[i] + + for i := 0; i < len(sqlCriteria.Criteria); i++ { + criterion := sqlCriteria.Criteria[i] predicate, err := NewSQLCriterionPredicate(criterion, parameters) if err != nil { return nil, err } predicates = append(predicates, predicate) } - return &sqlCriteriaPredicate{criteria: criteria, predicates: predicates}, nil + return &sqlCriteriaPredicate{SQLCriteria: sqlCriteria, predicates: predicates}, nil } diff --git a/sql_predicate_test.go b/sql_predicate_test.go index c1ce229..1632f35 100755 --- a/sql_predicate_test.go +++ b/sql_predicate_test.go @@ -64,8 +64,13 @@ func TestNewCriteriaPredicate(t *testing.T) { parameters := []interface{}{"abc%", 123} iterator := toolbox.NewSliceIterator(parameters) predicate, err := dsc.NewSQLCriteriaPredicate(iterator, - &dsc.SQLCriterion{LeftOperand: "column1", Operator: "Like", RightOperand: "?", LogicalOperator: "or"}, - &dsc.SQLCriterion{LeftOperand: "column2", Operator: "=", RightOperand: "?"}, + &dsc.SQLCriteria{ + LogicalOperator: "or", + Criteria: []*dsc.SQLCriterion{ + {LeftOperand: "column1", Operator: "Like", RightOperand: "?"}, + {LeftOperand: "column2", Operator: "=", RightOperand: "?"}, + }, + }, ) assert.Nil(t, err) {