Skip to content

Commit 74b960d

Browse files
authored
Merge pull request #1 from CEDARScript/fixes
Fixes and small changes
2 parents ef16908 + 68ef20d commit 74b960d

File tree

4 files changed

+74
-36
lines changed

4 files changed

+74
-36
lines changed

LICENSE

Whitespace-only changes.

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ classifiers = [
2222
]
2323
keywords = ["parser", "ast", "cedarscript", "code-editing", "refactoring", "code-analysis", "sql-like", "ai-assisted-development"]
2424
dependencies = [
25-
"cedarscript-grammar>=0.0.7",
25+
"cedarscript-grammar>=0.0.9",
2626
]
27-
requires-python = ">=3.12"
27+
requires-python = ">=3.8"
2828

2929
[project.urls]
3030
Homepage = "https://github.com/CEDARScript/cedarscript-ast-parser-python"

src/cedarscript_ast_parser/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.5"
1+
__version__ = "0.2.0"
22

33
from .cedarscript_ast_parser import (
44
CEDARScriptASTParser, ParseError, Command,

src/cedarscript_ast_parser/cedarscript_ast_parser.py

+71-33
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from enum import StrEnum, auto
22
from typing import TypeAlias, NamedTuple, Union
3+
from collections.abc import Sequence
34

45
from tree_sitter import Parser
56
import cedarscript_grammar
@@ -31,14 +32,23 @@ class BodyOrWhole(StrEnum):
3132

3233

3334
MarkerType = StrEnum('MarkerType', 'LINE VARIABLE FUNCTION CLASS')
34-
RelativePositionType = StrEnum('RelativePositionType', 'AT BEFORE AFTER INSIDE')
35+
RelativePositionType = StrEnum('RelativePositionType', 'AT BEFORE AFTER INSIDE_TOP INSIDE_BOTTOM')
36+
37+
38+
class MarkerCompatible:
39+
def as_marker(self) -> 'Marker':
40+
pass
3541

3642
@dataclass
37-
class Marker:
43+
class Marker(MarkerCompatible):
3844
type: MarkerType
3945
value: str
4046
offset: int | None = None
4147

48+
@property
49+
def as_marker(self) -> 'Marker':
50+
return self
51+
4252
def __str__(self):
4353
result = f"{self.type.value} '{self.value}'"
4454
if self.offset is not None:
@@ -59,7 +69,7 @@ def __str__(self):
5969
case RelativePositionType.AT:
6070
pass
6171
case _:
62-
result = f'{result} ({self.qualifier})'
72+
result = f'{result} ({self.qualifier.replace('_', ' ')})'
6373
return result
6474

6575

@@ -73,8 +83,8 @@ def __str__(self):
7383

7484

7585
MarkerOrSegment: TypeAlias = Marker | Segment
76-
Region: TypeAlias = BodyOrWhole | MarkerOrSegment
77-
RegionOrRelativeMarker: Region | RelativeMarker
86+
Region: TypeAlias = BodyOrWhole | Marker | Segment
87+
RegionOrRelativeMarker: BodyOrWhole | Marker | Segment | RelativeMarker
7888
# <file-or-identifier>
7989

8090

@@ -91,13 +101,18 @@ class SingleFileClause:
91101

92102

93103
@dataclass
94-
class IdentifierFromFile(SingleFileClause):
104+
class IdentifierFromFile(SingleFileClause, MarkerCompatible):
95105
where_clause: WhereClause
96-
identifier_type: str # VARIABLE, FUNCTION, CLASS
106+
identifier_type: MarkerType # VARIABLE, FUNCTION, CLASS (but not LINE)
97107
offset: int | None = None
98108

109+
@property
110+
def as_marker(self) -> Marker:
111+
# TODO Handle different values for field and operator in where_clause
112+
return Marker(self.identifier_type, self.where_clause.value, self.offset)
113+
99114
def __str__(self):
100-
result = f"{self.identifier_type.lower()} (self.where_clause)"
115+
result = f"{self.identifier_type.lower()} ({self.where_clause})"
101116
if self.offset is not None:
102117
result += f" at offset {self.offset}"
103118
return f"{result} from file {self.file_path}"
@@ -128,9 +143,13 @@ class DeleteClause(RegionClause):
128143

129144

130145
@dataclass
131-
class InsertClause:
146+
class InsertClause(MarkerCompatible):
132147
insert_position: RelativeMarker
133148

149+
@property
150+
def as_marker(self) -> RelativeMarker:
151+
return self.insert_position
152+
134153

135154
@dataclass
136155
class MoveClause(DeleteClause, InsertClause):
@@ -189,7 +208,7 @@ def files_to_change(self) -> tuple[str, ...]:
189208
class UpdateCommand(Command):
190209
target: FileOrIdentifierWithin
191210
action: EditingAction
192-
content: str | None = None
211+
content: str | tuple[Region, int | None] | None = None
193212

194213
@property
195214
def files_to_change(self) -> tuple[str, ...]:
@@ -241,7 +260,7 @@ def __init__(self):
241260

242261

243262
class CEDARScriptASTParser(_CEDARScriptASTParserBase):
244-
def parse_script(self, code_text: str) -> tuple[list[Command], list[ParseError]]:
263+
def parse_script(self, code_text: str) -> tuple[Sequence[Command], Sequence[ParseError]]:
245264
"""
246265
Parses the CEDARScript code and returns a tuple containing:
247266
- A list of Command objects if parsing is successful.
@@ -315,14 +334,14 @@ def _collect_parse_errors(self, node, code_text, command_ordinal: int) -> list[P
315334
errors.extend(self._collect_parse_errors(child, code_text, command_ordinal))
316335
return errors
317336

318-
def _get_expected_tokens(self, error_node) -> list[str]:
337+
def _get_expected_tokens(self, error_node) -> tuple[str]:
319338
"""
320339
Provides expected tokens based on the error_node's context.
321340
"""
322341
# Since Tree-sitter doesn't provide expected tokens directly,
323342
# you might need to implement this based on the grammar and error context.
324343
# For now, we'll return an empty list to simplify.
325-
return []
344+
return tuple()
326345

327346
def parse_command(self, node):
328347
match node.type:
@@ -341,7 +360,7 @@ def parse_command(self, node):
341360

342361
def parse_create_command(self, node):
343362
file_path = self.parse_singlefile_clause(self.find_first_by_type(node.children, 'singlefile_clause')).file_path
344-
content = self.parse_content_clause(self.find_first_by_type(node.children, 'content_clause'))
363+
content = self.parse_content(node)
345364
return CreateCommand(type='create', file_path=file_path, content=content)
346365

347366
def parse_rm_file_command(self, node):
@@ -356,7 +375,7 @@ def parse_mv_file_command(self, node):
356375
def parse_update_command(self, node):
357376
target = self.parse_update_target(node)
358377
action = self.parse_update_action(node)
359-
content = self.parse_update_content(node)
378+
content = self.parse_content(node)
360379
return UpdateCommand(type='update', target=target, action=action, content=content)
361380

362381
def parse_update_target(self, node):
@@ -377,7 +396,7 @@ def parse_update_target(self, node):
377396
raise ValueError(f"[parse_update_target] Invalid target: {invalid}")
378397

379398
def parse_identifier_from_file(self, node):
380-
identifier_type = node.children[0].type # FUNCTION, CLASS, or VARIABLE
399+
identifier_type = MarkerType(node.children[0].type.casefold())
381400
file_clause = self.find_first_by_type(node.named_children, 'singlefile_clause')
382401
where_clause = self.find_first_by_type(node.named_children, 'where_clause')
383402
offset_clause = self.find_first_by_type(node.named_children, 'offset_clause')
@@ -431,7 +450,7 @@ def parse_move_clause(self, node):
431450
destination = self.find_first_by_type(node.named_children, 'update_move_clause_destination')
432451
insert_clause = self.find_first_by_type(destination.named_children, 'insert_clause')
433452
insert_clause = self.parse_insert_clause(insert_clause)
434-
rel_indent = self.parse_relative_indentation(self.find_first_by_type(destination.named_children, 'relative_indentation'))
453+
rel_indent = self.parse_relative_indentation(destination)
435454
# TODO to_other_file
436455
return MoveClause(
437456
region=source,
@@ -460,7 +479,11 @@ def parse_region(self, node) -> Region:
460479
node = node.named_children[0]
461480
case 'relpos_bai':
462481
node = node.named_children[0]
463-
qualifier = RelativePositionType(node.child(0).type.casefold())
482+
main_type = node.child(0).type.casefold()
483+
match main_type:
484+
case 'inside':
485+
main_type += '_' + node.child(2).type.casefold()
486+
qualifier = RelativePositionType(main_type)
464487
node = node.named_children[0]
465488
case 'relpos_beforeafter':
466489
qualifier = RelativePositionType(node.child(0).type.casefold())
@@ -469,14 +492,14 @@ def parse_region(self, node) -> Region:
469492
node = node.named_children[0]
470493

471494
match node.type.casefold():
472-
case 'marker' | 'linemarker':
495+
case 'marker' | 'linemarker' | 'identifiermarker':
473496
result = self.parse_marker(node)
474497
case 'segment':
475498
result = self.parse_segment(node)
476499
case BodyOrWhole.BODY | BodyOrWhole.WHOLE as bow:
477500
result = BodyOrWhole(bow.lower())
478-
case _ as invalid:
479-
raise ValueError(f"[parse_region] Unexpected node type: {invalid}")
501+
case _:
502+
raise ValueError(f"Unexpected node type: {node.type}")
480503
if qualifier:
481504
result = RelativeMarker(qualifier=qualifier, type=result.type, value=result.value, offset=result.offset)
482505
return result
@@ -502,16 +525,24 @@ def parse_offset_clause(self, node):
502525
return None
503526
return int(self.find_first_by_type(node.children, 'number').text)
504527

505-
def parse_relative_indentation(self, node):
528+
def parse_relative_indentation(self, node) -> int:
529+
node = self.find_first_by_type(node.named_children, 'relative_indentation')
506530
if node is None:
507531
return None
508-
return int(self.find_first_by_type(node.children, 'number').text)
532+
return int(self.find_first_by_type(node.named_children, 'number').text)
533+
534+
def parse_content(self, node) -> str | tuple[Region, int | None]:
535+
content = self.find_first_by_type(node.named_children, ['content_clause', 'content_from_segment'])
536+
if not content:
537+
return None
538+
match content.type:
539+
case 'content_clause':
540+
return self.parse_content_clause(content) # str
541+
case 'content_from_segment':
542+
return self.parse_content_from_segment_clause(content) # tuple[Region, int]
543+
case _:
544+
raise ValueError(f"Invalid content type: {content.type}")
509545

510-
def parse_update_content(self, node):
511-
content_clause = self.find_first_by_type(node.children, 'content_clause')
512-
if content_clause:
513-
return self.parse_content_clause(content_clause)
514-
return None
515546

516547
def parse_singlefile_clause(self, node):
517548
if node is None or node.type != 'singlefile_clause':
@@ -521,9 +552,7 @@ def parse_singlefile_clause(self, node):
521552
raise ValueError("No file_path found in singlefile_clause")
522553
return SingleFileClause(file_path=self.parse_string(path_node))
523554

524-
def parse_content_clause(self, node):
525-
if node is None or node.type != 'content_clause':
526-
raise ValueError("Expected content_clause node")
555+
def parse_content_clause(self, node) -> str:
527556
child_type = ['string', 'relative_indent_block', 'multiline_string']
528557
content_node = self.find_first_by_type(node.children, child_type)
529558
if content_node is None:
@@ -535,6 +564,15 @@ def parse_content_clause(self, node):
535564
elif content_node.type == 'multiline_string':
536565
return self.parse_multiline_string(content_node)
537566

567+
def parse_content_from_segment_clause(self, node) -> tuple[Region, int | None]:
568+
child_type = ['marker_or_segment']
569+
content_node = self.find_first_by_type(node.children, child_type)
570+
# TODO parse relative indentation
571+
if content_node is None:
572+
raise ValueError("No content found in content_from_segment")
573+
rel_indent = self.parse_relative_indentation(node)
574+
return self.parse_region(content_node), rel_indent
575+
538576
def parse_to_value_clause(self, node):
539577
if node is None or node.type != 'to_value_clause':
540578
raise ValueError("Expected to_value_clause node")
@@ -561,7 +599,7 @@ def parse_string(self, node):
561599
def parse_multiline_string(self, node):
562600
return node.text.decode('utf8').strip("'''").strip('"""')
563601

564-
def parse_relative_indent_block(self, node):
602+
def parse_relative_indent_block(self, node) -> str:
565603
lines = []
566604
for line_node in node.children:
567605
if line_node.type == 'relative_indent_line':
@@ -572,7 +610,7 @@ def parse_relative_indent_block(self, node):
572610
lines.append(f"{' ' * (4 * indent)}{content.text}")
573611
return '\n'.join(lines)
574612

575-
def find_first_by_type(self, nodes: list[any], child_type):
613+
def find_first_by_type(self, nodes: Sequence[any], child_type):
576614
if isinstance(child_type, list):
577615
for child in nodes:
578616
if child.type in child_type:

0 commit comments

Comments
 (0)