Skip to content

sql/parse: parse ORDER BY and LIMIT clauses. #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 4, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 65 additions & 14 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"strconv"
"strings"

"github.com/gitql/gitql/sql"
Expand All @@ -25,6 +26,8 @@ const (
OrderState
OrderByState
OrderClauseState
LimitState
LimitNumberState
DoneState

ExprState
Expand All @@ -43,6 +46,7 @@ type parser struct {
relation string
filterClauses []sql.Expression
sortFields []plan.SortField
limit *int
}

func newParser(input io.Reader) *parser {
Expand Down Expand Up @@ -111,16 +115,16 @@ func (p *parser) parse() error {
p.stateStack.pop()
state := p.stateStack.peek()
var (
breakKeyword string
nextState ParseState
breakKeywords []string
nextState ParseState
)

switch state {
case SelectState:
breakKeyword = "from"
breakKeywords = []string{"from"}
nextState = FromState
case WhereState:
breakKeyword = "order"
breakKeywords = []string{"order", "limit"}
nextState = OrderState
default:
p.errorf(`unexpected token %q`, t.Value)
Expand All @@ -133,11 +137,13 @@ func (p *parser) parse() error {
p.stateStack.put(ExprState)
break OuterSwitch
case KeywordToken:
if kwMatches(t.Value, breakKeyword) {
p.lexer.Backup()
p.stateStack.pop()
p.stateStack.put(nextState)
break OuterSwitch
for _, kw := range breakKeywords {
if kwMatches(t.Value, kw) {
p.lexer.Backup()
p.stateStack.pop()
p.stateStack.put(nextState)
break OuterSwitch
}
}
case EOFToken:
p.stateStack.pop()
Expand All @@ -146,8 +152,8 @@ func (p *parser) parse() error {
}
}

if breakKeyword != "" {
p.errorf(`expecting "," or %q`, breakKeyword)
if len(breakKeywords) > 0 {
p.errorf(`expecting "," or %q`, breakKeywords)
} else {
p.errorf(`expecting "," or end of sentence`)
}
Expand Down Expand Up @@ -181,7 +187,9 @@ func (p *parser) parse() error {
p.stateStack.pop()
p.stateStack.put(DoneState)
} else if t.Type != KeywordToken || !kwMatches(t.Value, "where") {
p.errorf("expecting 'WHERE', %q received", t.Value)
p.lexer.Backup()
p.stateStack.pop()
p.stateStack.put(OrderState)
} else {
p.stateStack.put(WhereClauseState)
}
Expand All @@ -202,7 +210,9 @@ func (p *parser) parse() error {
p.stateStack.pop()
p.stateStack.put(DoneState)
} else if t.Type != KeywordToken || !kwMatches(t.Value, "order") {
p.errorf("expecting 'ORDER', %q received", t.Value)
p.lexer.Backup()
p.stateStack.pop()
p.stateStack.put(LimitState)
} else {
p.stateStack.put(OrderByState)
}
Expand All @@ -224,6 +234,35 @@ func (p *parser) parse() error {
} else {
p.sortFields = fields
p.stateStack.pop()
p.stateStack.put(LimitState)
}

case LimitState:
t = p.lexer.Next()
if t == nil || t.Type == EOFToken {
p.stateStack.pop()
p.stateStack.put(DoneState)
} else if t.Type != KeywordToken || !kwMatches(t.Value, "limit") {
p.errorf("expecting 'LIMIT', %q received", t.Value)
} else {
p.stateStack.pop()
p.stateStack.put(LimitNumberState)
}

case LimitNumberState:
t = p.lexer.Next()
if t == nil || t.Type == EOFToken {
p.errorf("expecting integer, nothing received")
} else if t.Type != IntToken {
p.errorf("expecting integer, %q received", t.Value)
} else {
i, err := strconv.Atoi(t.Value)
if err != nil {
p.errorf("error parsing integer: %q", err)
}

p.limit = &i
p.stateStack.pop()
p.stateStack.put(DoneState)
}
}
Expand All @@ -248,6 +287,10 @@ func (p *parser) buildPlan() (sql.Node, error) {
node = plan.NewSort(p.sortFields, node)
}

if p.limit != nil {
node = plan.NewLimit(int64(*p.limit), node)
}

return node, nil
}

Expand Down Expand Up @@ -301,6 +344,14 @@ func parseOrderClause(q tokenQueue) ([]plan.SortField, error) {
field.Order = plan.Descending
} else if kwMatches(tk.Value, "asc") {
field.Order = plan.Ascending
} else if kwMatches(tk.Value, "limit") {
if field == nil {
return nil, errors.New(`unexpected LIMIT, expecting identifier`)
}

q.Backup()
fields = append(fields, *field)
return fields, nil
} else {
return nil, fmt.Errorf(`unexpected keyword %q, expecting "ASC", "DESC" or ","`, tk.Value)
}
Expand All @@ -312,7 +363,7 @@ func parseOrderClause(q tokenQueue) ([]plan.SortField, error) {
fields = append(fields, *field)
field = nil
case EOFToken:
if field == nil || len(fields) == 0 {
if field == nil {
return nil, errors.New(`unexpected end of input, expecting identifier`)
}

Expand Down
141 changes: 105 additions & 36 deletions sql/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,48 +6,117 @@ import (

"github.com/gitql/gitql/sql"
"github.com/gitql/gitql/sql/expression"
"github.com/gitql/gitql/sql/plan"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)

const testSelectFromWhere = `SELECT foo, bar FROM foo WHERE foo = bar;`
const testSelectFrom = `SELECT foo, bar FROM foo;`

func TestParseSelectFromWhere(t *testing.T) {
p := newParser(strings.NewReader(testSelectFromWhere))
require.Nil(t, p.parse())

require.Equal(t, p.projection, []sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
})

require.Equal(t, p.relation, "foo")

require.Equal(t, p.filterClauses, []sql.Expression{
expression.NewEquals(
var fixtures = map[string]sql.Node{
`SELECT foo, bar FROM foo;`: plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
},
plan.NewUnresolvedRelation("foo"),
),
`SELECT foo, bar FROM foo WHERE foo = bar;`: plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
},
plan.NewFilter(
expression.NewEquals(
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
),
plan.NewUnresolvedRelation("foo"),
),
})

require.Nil(t, p.sortFields)
require.Nil(t, p.err)
require.Equal(t, DoneState, p.stateStack.pop())
),
`SELECT foo, bar FROM foo WHERE foo = 'bar';`: plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
},
plan.NewFilter(
expression.NewEquals(
expression.NewUnresolvedColumn("foo"),
expression.NewLiteral("bar", sql.String),
),
plan.NewUnresolvedRelation("foo"),
),
),
`SELECT foo, bar FROM foo LIMIT 10;`: plan.NewLimit(int64(10),
plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
},
plan.NewUnresolvedRelation("foo"),
),
),
`SELECT foo, bar FROM foo ORDER BY baz DESC;`: plan.NewSort(
[]plan.SortField{{expression.NewUnresolvedColumn("baz"), plan.Descending}},
plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
},
plan.NewUnresolvedRelation("foo"),
),
),
`SELECT foo, bar FROM foo WHERE foo = bar LIMIT 10;`: plan.NewLimit(int64(10),
plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
},
plan.NewFilter(
expression.NewEquals(
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
),
plan.NewUnresolvedRelation("foo"),
),
),
),
`SELECT foo, bar FROM foo ORDER BY baz DESC LIMIT 1;`: plan.NewLimit(int64(1),
plan.NewSort(
[]plan.SortField{{expression.NewUnresolvedColumn("baz"), plan.Descending}},
plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
},
plan.NewUnresolvedRelation("foo"),
),
),
),
`SELECT foo, bar FROM foo WHERE qux = 1 ORDER BY baz DESC LIMIT 1;`: plan.NewLimit(int64(1),
plan.NewSort(
[]plan.SortField{{expression.NewUnresolvedColumn("baz"), plan.Descending}},
plan.NewProject(
[]sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
},
plan.NewFilter(
expression.NewEquals(
expression.NewUnresolvedColumn("qux"),
expression.NewLiteral(int64(1), sql.BigInteger),
),
plan.NewUnresolvedRelation("foo"),
),
),
),
),
}

func TestParseSelectFrom(t *testing.T) {
p := newParser(strings.NewReader(testSelectFrom))
require.Nil(t, p.parse())

require.Equal(t, p.projection, []sql.Expression{
expression.NewUnresolvedColumn("foo"),
expression.NewUnresolvedColumn("bar"),
})

require.Equal(t, p.relation, "foo")

require.Nil(t, p.sortFields)
require.Nil(t, p.err)
require.Equal(t, DoneState, p.stateStack.pop())
func TestParse(t *testing.T) {
assert := assert.New(t)
for query, expectedPlan := range fixtures {
p, err := Parse(strings.NewReader(query))
assert.Nil(err)
assert.Exactly(expectedPlan, p,
"plans do not match for query '%s'", query)
}
}