Skip to content

Commit

Permalink
Feat(prql): add filter, set operations (#3291)
Browse files Browse the repository at this point in the history
* filter for prql

* add todo

* add todo and test

* add append

* update set op

* update

* update CONJUCTION

* update FILTER
  • Loading branch information
fool1280 committed Apr 9, 2024
1 parent 4cda01e commit eabb708
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
22 changes: 22 additions & 0 deletions sqlglot/dialects/prql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from sqlglot.tokens import TokenType


def _select_all(table: exp.Expression) -> t.Optional[exp.Select]:
return exp.select("*").from_(table, copy=False) if table else None


class PRQL(Dialect):
DPIPE_IS_STRING_CONCAT = False

class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ["`"]
QUOTES = ["'", '"']
Expand All @@ -26,10 +32,26 @@ class Tokenizer(tokens.Tokenizer):
}

class Parser(parser.Parser):
CONJUNCTION = {
**parser.Parser.CONJUNCTION,
TokenType.DAMP: exp.And,
TokenType.DPIPE: exp.Or,
}

TRANSFORM_PARSERS = {
"DERIVE": lambda self, query: self._parse_selection(query),
"SELECT": lambda self, query: self._parse_selection(query, append=False),
"TAKE": lambda self, query: self._parse_take(query),
"FILTER": lambda self, query: query.where(self._parse_conjunction()),
"APPEND": lambda self, query: query.union(
_select_all(self._parse_table()), distinct=False, copy=False
),
"REMOVE": lambda self, query: query.except_(
_select_all(self._parse_table()), distinct=False, copy=False
),
"INTERSECT": lambda self, query: query.intersect(
_select_all(self._parse_table()), distinct=False, copy=False
),
}

def _parse_statement(self) -> t.Optional[exp.Expression]:
Expand Down
29 changes: 28 additions & 1 deletion tests/dialects/test_prql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,35 @@ def test_prql(self):
self.validate_identity("FROM x DERIVE x = a + 1", "SELECT *, a + 1 AS x FROM x")
self.validate_identity("FROM x DERIVE {a + 1}", "SELECT *, a + 1 FROM x")
self.validate_identity("FROM x DERIVE {x = a + 1, b}", "SELECT *, a + 1 AS x, b FROM x")
self.validate_identity(
"FROM x DERIVE {x = a + 1, b} SELECT {y = x, 2}", "SELECT a + 1 AS y, 2 FROM x"
)
self.validate_identity("FROM x TAKE 10", "SELECT * FROM x LIMIT 10")
self.validate_identity("FROM x TAKE 10 TAKE 5", "SELECT * FROM x LIMIT 5")
self.validate_identity("FROM x FILTER age > 25", "SELECT * FROM x WHERE age > 25")
self.validate_identity(
"FROM x DERIVE {x = a + 1, b} SELECT {y = x, 2}", "SELECT a + 1 AS y, 2 FROM x"
"FROM x DERIVE {x = a + 1, b} FILTER age > 25",
"SELECT *, a + 1 AS x, b FROM x WHERE age > 25",
)
self.validate_identity("FROM x FILTER dept != 'IT'", "SELECT * FROM x WHERE dept <> 'IT'")
self.validate_identity(
"FROM x FILTER p == 'product' SELECT { a, b }", "SELECT a, b FROM x WHERE p = 'product'"
)
self.validate_identity(
"FROM x FILTER age > 25 FILTER age < 27", "SELECT * FROM x WHERE age > 25 AND age < 27"
)
self.validate_identity(
"FROM x FILTER (age > 25 && age < 27)", "SELECT * FROM x WHERE (age > 25 AND age < 27)"
)
self.validate_identity(
"FROM x FILTER (age > 25 || age < 27)", "SELECT * FROM x WHERE (age > 25 OR age < 27)"
)
self.validate_identity(
"FROM x FILTER (age > 25 || age < 22) FILTER age > 26 FILTER age < 27",
"SELECT * FROM x WHERE ((age > 25 OR age < 22) AND age > 26) AND age < 27",
)
self.validate_identity("FROM x APPEND y", "SELECT * FROM x UNION ALL SELECT * FROM y")
self.validate_identity("FROM x REMOVE y", "SELECT * FROM x EXCEPT ALL SELECT * FROM y")
self.validate_identity(
"FROM x INTERSECT y", "SELECT * FROM x INTERSECT ALL SELECT * FROM y"
)

0 comments on commit eabb708

Please sign in to comment.