Skip to content

Commit

Permalink
updated sql parser with SQLCriteria
Browse files Browse the repository at this point in the history
  • Loading branch information
adranwit committed Feb 28, 2018
1 parent b53502a commit 8f904dc
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 80 deletions.
4 changes: 2 additions & 2 deletions file_manager.go
Expand Up @@ -194,7 +194,7 @@ func (m *FileManager) modifyRecords(tableURL string, statement *DmlStatement, pa
var err error
var predicate toolbox.Predicate
if len(statement.Criteria) > 0 {
predicate, err = NewSQLCriteriaPredicate(parameters, statement.Criteria...)
predicate, err = NewSQLCriteriaPredicate(parameters, statement.SQLCriteria)
if err != nil {
return 0, fmt.Errorf("failed to read data from %v due to %v", statement.SQL, err)
}
Expand Down Expand Up @@ -429,7 +429,7 @@ func (m *FileManager) ReadAllOnWithHandlerOnConnection(connection Connection, qu
var predicate toolbox.Predicate
if len(statement.Criteria) > 0 {
parameters := toolbox.NewSliceIterator(sqlParameters)
predicate, err = NewSQLCriteriaPredicate(parameters, statement.Criteria...)
predicate, err = NewSQLCriteriaPredicate(parameters, statement.SQLCriteria)
if err != nil {
return fmt.Errorf("failed to read data from %v due to %v", query, err)
}
Expand Down
1 change: 0 additions & 1 deletion file_scanner.go
Expand Up @@ -29,7 +29,6 @@ func (s *FileScanner) Scan(destinations ...interface{}) (err error) {
return nil
}
}

var columns, _ = s.Columns()
for i, dest := range destinations {
if value, found := s.Values[columns[i]]; found {
Expand Down
6 changes: 2 additions & 4 deletions manager.go
Expand Up @@ -250,7 +250,6 @@ func (am *AbstractManager) PersistAllOnConnection(connection Connection, dataPoi
if err != nil {
return 0, 0, err
}

var isStructPointer = structType.Kind() == reflect.Ptr
var insertableMapping map[int]int
if descriptor.Autoincrement {
Expand Down Expand Up @@ -375,11 +374,9 @@ func (am *AbstractManager) PersistData(connection Connection, data []interface{}
func (am *AbstractManager) fetchDataInBatches(connection Connection, sqlsWihtArguments []*ParametrizedSQL, mapper RecordMapper) (*[][]interface{}, error) {
var rows = make([][]interface{}, 0)
for _, sqlWihtArguments := range sqlsWihtArguments {

if len(sqlWihtArguments.Values) == 0 {
break
}

err := am.Manager.ReadAllOnConnection(connection, &rows, sqlWihtArguments.SQL, sqlWihtArguments.Values, mapper)
if err != nil {
return nil, err
Expand Down Expand Up @@ -410,7 +407,7 @@ func (am *AbstractManager) fetchExistingData(connection Connection, table string
//ClassifyDataAsInsertableOrUpdatable classifies passed in data as insertable or updatable.
func (am *AbstractManager) ClassifyDataAsInsertableOrUpdatable(connection Connection, dataPointer interface{}, table string, provider DmlProvider) ([]interface{}, []interface{}, error) {
if provider == nil {
return nil, nil, errors.New("Provider was nil")
return nil, nil, errors.New("provider was nil")
}

var rowsByKey = make(map[string]interface{}, 0)
Expand All @@ -433,6 +430,7 @@ func (am *AbstractManager) ClassifyDataAsInsertableOrUpdatable(connection Connec

//process existing rows and add mapped entires as updatables
for _, row := range rows {

key := toolbox.JoinAsString(row, "")
if instance, ok := rowsByKey[key]; ok {
updatables = append(updatables, instance)
Expand Down
7 changes: 3 additions & 4 deletions record_mapper.go
Expand Up @@ -184,6 +184,7 @@ func (rm *mapRecordMapper) Map(scanner Scanner) (interface{}, error) {
return nil, err
}
columns, _ := scanner.Columns()

aMap := make(map[string]interface{})
for i, column := range columns {
aMap[column] = result[i]
Expand Down Expand Up @@ -221,7 +222,7 @@ func NewRecordMapper(targetType reflect.Type) RecordMapper {
return mapper
}
default:
panic("unsupported type: " + targetType.Name())
panic(fmt.Sprintf("unsupported type: %v ", targetType.Name()))
}
return nil
}
Expand All @@ -237,15 +238,12 @@ func NewRecordMapperIfNeeded(mapper RecordMapper, targetType reflect.Type) Recor
//ScanRow takes scanner to scans row.
func ScanRow(scanner Scanner) ([]interface{}, []string, error) {
columns, _ := scanner.Columns()

count := len(columns)
var valuePointers = make([]interface{}, count)
var rowValues = make([]interface{}, count)

for i := range rowValues {
valuePointers[i] = &rowValues[i]
}

err := scanner.Scan(valuePointers...)
if err != nil {
return nil, nil, fmt.Errorf("failed to scan row due to %v", err)
Expand All @@ -262,5 +260,6 @@ func ScanRow(scanner Scanner) ([]interface{}, []string, error) {
}
rowValues[i] = value
}

return rowValues, columns, nil
}
4 changes: 0 additions & 4 deletions sql_manager.go
Expand Up @@ -6,10 +6,6 @@ import (
"reflect"
)

var Logf = func(format string, args ...interface{}) {

}

func asSQLDb(wrapped interface{}) (*sql.DB, error) {
if result, ok := wrapped.(*sql.DB); ok {
return result, nil
Expand Down
132 changes: 81 additions & 51 deletions sql_parser.go
Expand Up @@ -16,22 +16,56 @@ type SQLColumn struct {
FunctionArguments string
}

//SQLCriteria represents SQL criteria
type SQLCriteria struct {
Criteria []*SQLCriterion
LogicalOperator string
}

//CriteriaValues returns criteria values extracted from binding parameters, starting from parametersOffset,
func (s *SQLCriteria) CriteriaValues(parameters toolbox.Iterator) ([]interface{}, error) {
var values = make([]interface{}, 0)
for _, criterion := range s.Criteria {
if criterion.Criteria != nil && len(criterion.Criteria.Criteria) > 0 {
criteriaValues, err := criterion.Criteria.CriteriaValues(parameters)
if err != nil {
return nil, err
}
values = append(values, criteriaValues)
continue
}
var criterionValues = criterion.RightOperands
if len(criterionValues) == 0 {
criterionValues = []interface{}{criterion.RightOperand}
}
for i := range criterionValues {
value, err := bindValueIfNeeded(criterionValues[i], parameters)
if err != nil {
return nil, err
}
values = append(values, value)
}
}
return values, nil
}

//SQLCriterion represents single where clause condiction
type SQLCriterion struct {
LeftOperand interface{}
Operator string
RightOperand interface{}
RightOperands []interface{}
Inverse bool // if not operator presents
LogicalOperator string
Criteria *SQLCriteria
LeftOperand interface{}
Operator string
RightOperand interface{}
RightOperands []interface{}
Inverse bool // if not operator presents

}

//BaseStatement represents a base query and dml statement
type BaseStatement struct {
SQL string
Table string
Columns []*SQLColumn
Criteria []*SQLCriterion
*SQLCriteria
SQL string
Table string
Columns []*SQLColumn
}

//ColumnNames returns a column names.
Expand Down Expand Up @@ -67,36 +101,17 @@ func bindValueIfNeeded(source interface{}, parameters toolbox.Iterator) (interfa
return source, nil
}

//CriteriaValues returns criteria values extracted from binding parameters, starting from parametersOffset,
func (bs BaseStatement) CriteriaValues(parameters toolbox.Iterator) ([]interface{}, error) {
var values = make([]interface{}, 0)
for _, criterion := range bs.Criteria {
var criterionValues = criterion.RightOperands
if len(criterionValues) == 0 {
criterionValues = []interface{}{criterion.RightOperand}
}
for i := range criterionValues {
value, err := bindValueIfNeeded(criterionValues[i], parameters)
if err != nil {
return nil, err
}
values = append(values, value)
}
}
return values, nil
}

//QueryStatement represents SQL query statement.
type QueryStatement struct {
BaseStatement
*BaseStatement
AllField bool
UnionTables []string
GroupBy []*SQLColumn
}

//DmlStatement represents dml statement.
type DmlStatement struct {
BaseStatement
*BaseStatement
Type string
Values []interface{}
}
Expand Down Expand Up @@ -319,8 +334,7 @@ func (bp *baseParser) readInValues(tokenizer *toolbox.Tokenizer) (string, []inte
return value, values, nil
}

func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *BaseStatement, token *toolbox.Token) (err error) {
statement.Criteria = make([]*SQLCriterion, 0)
func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, sqlCriteria *SQLCriteria, token *toolbox.Token) (err error) {
for {
token, err = bp.expectWhitespaceFollowedBy(tokenizer, "value", sqlValue)
if err != nil {
Expand All @@ -330,27 +344,27 @@ func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *Base
return fmt.Errorf("Expected criteria at %v", tokenizer.Index)
}

index := len(statement.Criteria)
statement.Criteria = append(statement.Criteria, &SQLCriterion{LeftOperand: token.Matched})
index := len(sqlCriteria.Criteria)
sqlCriteria.Criteria = append(sqlCriteria.Criteria, &SQLCriterion{LeftOperand: token.Matched})

token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "operator", inOperatorKeyword, likeOperatorKeyword, betweenOperatorKeyword, notKeyword, isOperatorKeyword, operator, eof)
if err != nil {
return err
}
statement.Criteria[index].Operator = token.Matched
sqlCriteria.Criteria[index].Operator = token.Matched
switch token.Token {
case notKeyword:
statement.Criteria[index].Inverse = true
sqlCriteria.Criteria[index].Inverse = true
token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "operator", inOperatorKeyword, likeOperatorKeyword)
if err != nil {
return err
}
statement.Criteria[index].Operator = token.Matched
sqlCriteria.Criteria[index].Operator = token.Matched
fallthrough
case inOperatorKeyword:
var value, values, err = bp.readInValues(tokenizer)
statement.Criteria[index].RightOperand = value
statement.Criteria[index].RightOperands = values
sqlCriteria.Criteria[index].RightOperand = value
sqlCriteria.Criteria[index].RightOperands = values
if err != nil {
return err
}
Expand All @@ -359,19 +373,19 @@ func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *Base
if err != nil {
return err
}
statement.Criteria[index].RightOperand = token.Matched
sqlCriteria.Criteria[index].RightOperand = token.Matched
case isOperatorKeyword:
token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "operator", notKeyword, nullKeyword)
if err != nil {
return err
}
if token.Token == notKeyword {
statement.Criteria[index].Inverse = true
sqlCriteria.Criteria[index].Inverse = true
token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "operator", nullKeyword)
if err != nil {
return err
}
statement.Criteria[index].RightOperand = token.Matched
sqlCriteria.Criteria[index].RightOperand = token.Matched
}
case betweenOperatorKeyword:
token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "value", sqlValue)
Expand All @@ -388,7 +402,7 @@ func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *Base
return err
}
toValue := token.Matched
statement.Criteria[index].RightOperands = []interface{}{fromValue, toValue}
sqlCriteria.Criteria[index].RightOperands = []interface{}{fromValue, toValue}
}
token, err = bp.expectOptionalWhitespaceFollowedBy(tokenizer, "or | and | eof", eof, logicalOperator, groupKeyword)
if err != nil {
Expand All @@ -401,7 +415,14 @@ func (bp *baseParser) readCriteria(tokenizer *toolbox.Tokenizer, statement *Base
tokenizer.Index -= len(token.Matched)
break
}
statement.Criteria[index].LogicalOperator = token.Matched
if len(sqlCriteria.Criteria) == 1 {
sqlCriteria.LogicalOperator = token.Matched
} else if token.Matched == sqlCriteria.LogicalOperator {
continue
} else {
return fmt.Errorf("various logical operator not supported yet")
}

}
return nil
}
Expand Down Expand Up @@ -500,7 +521,12 @@ outer:
//Parse parses SQL query to build QueryStatement
func (qp *QueryParser) Parse(query string) (*QueryStatement, error) {
tokenizer := toolbox.NewTokenizer(query, illegal, eof, sqlMatchers)
baseStatement := BaseStatement{SQL: query}
baseStatement := &BaseStatement{
SQL: query,
SQLCriteria: &SQLCriteria{
Criteria: make([]*SQLCriterion, 0),
},
}
result := &QueryStatement{BaseStatement: baseStatement}
var token *toolbox.Token

Expand Down Expand Up @@ -544,7 +570,7 @@ func (qp *QueryParser) Parse(query string) (*QueryStatement, error) {
}

if token.Token == whereKeyword {
err = qp.readCriteria(tokenizer, &result.BaseStatement, token)
err = qp.readCriteria(tokenizer, result.BaseStatement.SQLCriteria, token)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -724,7 +750,7 @@ func (dp *DmlParser) parseUpdate(tokenizer *toolbox.Tokenizer, statement *DmlSta
return nil
}

err = dp.readCriteria(tokenizer, &statement.BaseStatement, token)
err = dp.readCriteria(tokenizer, statement.BaseStatement.SQLCriteria, token)
if err != nil {
return err
}
Expand All @@ -750,7 +776,7 @@ func (dp *DmlParser) parseDelete(tokenizer *toolbox.Tokenizer, statement *DmlSta
return nil
}

err = dp.readCriteria(tokenizer, &statement.BaseStatement, token)
err = dp.readCriteria(tokenizer, statement.BaseStatement.SQLCriteria, token)
if err != nil {
return err
}
Expand All @@ -759,8 +785,12 @@ func (dp *DmlParser) parseDelete(tokenizer *toolbox.Tokenizer, statement *DmlSta

//Parse parses input to create DmlStatement.
func (dp *DmlParser) Parse(input string) (*DmlStatement, error) {
baseStatement := BaseStatement{SQL: input}
result := &DmlStatement{BaseStatement: baseStatement}
baseStatement := BaseStatement{
SQL: input,
SQLCriteria: &SQLCriteria{
Criteria: make([]*SQLCriterion, 0),
}}
result := &DmlStatement{BaseStatement: &baseStatement}
tokenizer := toolbox.NewTokenizer(input, illegal, eof, sqlMatchers)
token, err := dp.expectOptionalWhitespaceFollowedBy(tokenizer, "INSERT INTO | UPDATE | DELETE", insertKeyword, updateKeyword, deleteKeyword)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions sql_parser_test.go
Expand Up @@ -122,14 +122,14 @@ func TestQueryParser(t *testing.T) {
assert.Equal(t, "column1", query.Criteria[0].LeftOperand)
assert.Equal(t, "=", query.Criteria[0].Operator)
assert.Equal(t, "2", query.Criteria[0].RightOperand)
assert.Equal(t, "AND", query.Criteria[0].LogicalOperator)
assert.Equal(t, "AND", query.LogicalOperator)

}
{
assert.Equal(t, "column2", query.Criteria[1].LeftOperand)
assert.Equal(t, "!=", query.Criteria[1].Operator)
assert.Equal(t, "?", query.Criteria[1].RightOperand)
assert.Equal(t, "", query.Criteria[1].LogicalOperator)
assert.Equal(t, "AND", query.LogicalOperator)

}

Expand Down

0 comments on commit 8f904dc

Please sign in to comment.