Skip to content

Commit

Permalink
fix: Fixes #10 typing a letter afer period was not auto completing
Browse files Browse the repository at this point in the history
  • Loading branch information
qharlie committed Apr 2, 2022
1 parent bc9bb72 commit ee8c7f5
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 37 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -6,3 +6,4 @@ __pycache__/
scratch.py
.python-version
.prql-history.txt
*.db
62 changes: 40 additions & 22 deletions pyprql/cli/PRQLCompleter.py
@@ -1,22 +1,30 @@
# -*- coding: utf-8 -*-
"""Prompt_toolkit completion engine for PyPRQL CLI."""
from datetime import datetime
from typing import Dict, Iterable, List, Optional

from enforce_typing import enforce_types
from prompt_toolkit.completion import CompleteEvent, Completer, Completion
from prompt_toolkit.document import Document


def log_to_file(s):
with open('output.txt', 'a') as f:
f.write(datetime.now().strftime("%m/%d/%Y, %H:%M:%S"))
f.write(":" + s)
f.write('\n')


class PRQLCompleter(Completer):
"""Prompt_toolkit completion engine for PyPRQL CLI."""

@enforce_types
def __init__(
self,
table_names: List[str],
column_names: List[str],
column_map: Dict[str, List[str]],
prql_keywords: List[str],
self,
table_names: List[str],
column_names: List[str],
column_map: Dict[str, List[str]],
prql_keywords: List[str],
) -> None:
"""Initialise a completer instance.
Expand Down Expand Up @@ -44,7 +52,7 @@ def __init__(
self.previous_selection: Optional[List[str]] = None

def get_completions(
self, document: Document, complete_event: CompleteEvent
self, document: Document, complete_event: CompleteEvent
) -> Iterable[Completion]:
"""Retrieve completion options.
Expand All @@ -62,6 +70,16 @@ def get_completions(
The completion object.
"""
word_before_cursor = document.get_word_before_cursor(WORD=True)
# We're only interested in everything after the dot
if '.' in word_before_cursor and not word_before_cursor.endswith('.'):
word_before_cursor = word_before_cursor.split('.')[-1]

# Same with the colon
if ':' in word_before_cursor and not word_before_cursor.endswith(':'):
word_before_cursor = word_before_cursor.split(':')[-1]

log_to_file(word_before_cursor)

completion_operators = ["[", "+", ",", ":"]
possible_matches = {
"from": self.table_names,
Expand All @@ -78,11 +96,12 @@ def get_completions(
"filter": self.column_names,
"exit": [""],
}
matches_that_need_prev_word = {
builtin_matches = {
"show": ["tables", "columns", "connection"],
"side:": ["left", "inner", "right", "outer"],
"order:": ["asc", "desc"],
"by:": self.column_names,

}
# print(word_before_cursor)
for op in completion_operators:
Expand All @@ -95,29 +114,28 @@ def get_completions(
self.previous_selection = selection
# This can be reworked to a if not in operator. No pass required.
if (
word_before_cursor == "from"
or word_before_cursor == "join"
or word_before_cursor == "sort"
or word_before_cursor == "select"
or word_before_cursor == "columns"
or word_before_cursor == "show"
or word_before_cursor == ","
or word_before_cursor == "["
or word_before_cursor == "filter"
word_before_cursor == "from"
or word_before_cursor == "join"
or word_before_cursor == "sort"
or word_before_cursor == "select"
or word_before_cursor == "columns"
or word_before_cursor == "show"
or word_before_cursor == ","
or word_before_cursor == "["
or word_before_cursor == "filter"
):
pass
else:
for m in selection:
yield Completion(m, start_position=-len(word_before_cursor))
elif word_before_cursor in matches_that_need_prev_word:
selection = matches_that_need_prev_word[word_before_cursor]
# selection = [f"{x}" for x in selection]
elif word_before_cursor in builtin_matches:
selection = builtin_matches[word_before_cursor]
for m in selection:
yield Completion(m, start_position=0)
# If its an operator
elif (
len(word_before_cursor) >= 1
and word_before_cursor[-1] in completion_operators
len(word_before_cursor) >= 1
and word_before_cursor[-1] in completion_operators
):
selection = possible_matches[word_before_cursor[-1]]
self.previous_selection = selection
Expand All @@ -128,7 +146,7 @@ def get_completions(
selection = self.column_map[table]
self.previous_selection = selection
for m in selection:
yield Completion(m, start_position=0)
yield Completion(str(m), start_position=0)
# This goes back to the first if, this is the delayed completion finally completing
elif self.previous_selection:
selection = [
Expand Down
13 changes: 0 additions & 13 deletions pyprql/lang/prql.py
Expand Up @@ -145,19 +145,6 @@ def get_join_type(self) -> Optional[_JoinType]:
else:
return None

#
# def __post_init__(self) -> None:
# if isinstance(self.name, JoinType):
# temp = self.join_type
# self.join_type = self.name
# self.name = temp # type: ignore[assignment]
# if isinstance(self.join_type, Name):
# # Now we need to shift everything , since join_type is now our left_id
# temp = self.left_id
# self.left_id = self.join_type
# self.right_id = temp
# self.join_type = None


@dataclass
class SelectField(_Ast):
Expand Down
15 changes: 13 additions & 2 deletions tests/test_employee_examples.py
Expand Up @@ -78,7 +78,7 @@ def test_index(self):
self.run_query(text)

def test_cte1(self):
text = """
q = """
table newest_employees = (
from employees
sort tenure
Expand All @@ -89,6 +89,17 @@ def test_cte1(self):
join salary [id]
select [name, salary]
"""
sql = prql.to_sql(text)
sql = prql.to_sql(q)
print(sql)
# self.run_query(text)

def test_derive_issue_20(self):
q = '''
from employees
select [ benefits_cost , salary, title ]
aggregate by:title [ ttl_sal: salary | sum ]
derive [ avg_sal: salary ]
sort ttl_sal order:desc | take 25'''
sql = prql.to_sql(q,True)
assert sql.index("avg_sal") > 0
print(sql)

0 comments on commit ee8c7f5

Please sign in to comment.