Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/ide_assists/src/handlers/move_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ pub(crate) fn move_bounds_to_where_clause(acc: &mut Assists, ctx: &AssistContext

for type_param in type_param_list.type_params() {
if let Some(tbl) = type_param.type_bound_list() {
if let Some(predicate) = build_predicate(type_param.clone()) {
where_clause.add_predicate(predicate.clone_for_update())
if let Some(predicate) = build_predicate(type_param) {
where_clause.add_predicate(predicate)
}
tbl.remove()
}
Expand All @@ -69,7 +69,7 @@ fn build_predicate(param: ast::TypeParam) -> Option<ast::WherePred> {
make::path_unqualified(segment)
};
let predicate = make::where_pred(path, param.type_bound_list()?.bounds());
Some(predicate)
Some(predicate.clone_for_update())
}

#[cfg(test)]
Expand Down
41 changes: 20 additions & 21 deletions crates/syntax/src/ast/edit_in_place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use parser::T;
use crate::{
ast,
ted::{self, Position},
AstNode, Direction, SyntaxElement,
AstNode, Direction,
};

use super::NameOwner;
Expand All @@ -21,11 +21,11 @@ impl GenericParamsOwnerEdit for ast::Fn {
fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() {
let position = if let Some(ty) = self.ret_type() {
Position::after(ty.syntax().clone())
Position::after(ty.syntax())
} else if let Some(param_list) = self.param_list() {
Position::after(param_list.syntax().clone())
Position::after(param_list.syntax())
} else {
Position::last_child_of(self.syntax().clone())
Position::last_child_of(self.syntax())
};
create_where_clause(position)
}
Expand All @@ -37,9 +37,9 @@ impl GenericParamsOwnerEdit for ast::Impl {
fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() {
let position = if let Some(items) = self.assoc_item_list() {
Position::before(items.syntax().clone())
Position::before(items.syntax())
} else {
Position::last_child_of(self.syntax().clone())
Position::last_child_of(self.syntax())
};
create_where_clause(position)
}
Expand All @@ -51,9 +51,9 @@ impl GenericParamsOwnerEdit for ast::Trait {
fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() {
let position = if let Some(items) = self.assoc_item_list() {
Position::before(items.syntax().clone())
Position::before(items.syntax())
} else {
Position::last_child_of(self.syntax().clone())
Position::last_child_of(self.syntax())
};
create_where_clause(position)
}
Expand All @@ -69,13 +69,13 @@ impl GenericParamsOwnerEdit for ast::Struct {
ast::FieldList::TupleFieldList(it) => Some(it),
});
let position = if let Some(tfl) = tfl {
Position::after(tfl.syntax().clone())
Position::after(tfl.syntax())
} else if let Some(gpl) = self.generic_param_list() {
Position::after(gpl.syntax().clone())
Position::after(gpl.syntax())
} else if let Some(name) = self.name() {
Position::after(name.syntax().clone())
Position::after(name.syntax())
} else {
Position::last_child_of(self.syntax().clone())
Position::last_child_of(self.syntax())
};
create_where_clause(position)
}
Expand All @@ -87,11 +87,11 @@ impl GenericParamsOwnerEdit for ast::Enum {
fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() {
let position = if let Some(gpl) = self.generic_param_list() {
Position::after(gpl.syntax().clone())
Position::after(gpl.syntax())
} else if let Some(name) = self.name() {
Position::after(name.syntax().clone())
Position::after(name.syntax())
} else {
Position::last_child_of(self.syntax().clone())
Position::last_child_of(self.syntax())
};
create_where_clause(position)
}
Expand All @@ -100,19 +100,18 @@ impl GenericParamsOwnerEdit for ast::Enum {
}

fn create_where_clause(position: Position) {
let where_clause: SyntaxElement =
make::where_clause(empty()).clone_for_update().syntax().clone().into();
ted::insert(position, where_clause);
let where_clause = make::where_clause(empty()).clone_for_update();
ted::insert(position, where_clause.syntax());
}

impl ast::WhereClause {
pub fn add_predicate(&self, predicate: ast::WherePred) {
if let Some(pred) = self.predicates().last() {
if !pred.syntax().siblings_with_tokens(Direction::Next).any(|it| it.kind() == T![,]) {
ted::append_child_raw(self.syntax().clone(), make::token(T![,]));
ted::append_child_raw(self.syntax(), make::token(T![,]));
}
}
ted::append_child(self.syntax().clone(), predicate.syntax().clone())
ted::append_child(self.syntax(), predicate.syntax())
}
}

Expand All @@ -123,7 +122,7 @@ impl ast::TypeBoundList {
{
ted::remove_all(colon..=self.syntax().clone().into())
} else {
ted::remove(self.syntax().clone())
ted::remove(self.syntax())
}
}
}
65 changes: 46 additions & 19 deletions crates/syntax/src/ted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,33 @@ use parser::T;

use crate::{ast::make, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken};

/// Utility trait to allow calling `ted` functions with references or owned
/// nodes. Do not use outside of this module.
pub trait Element {
fn syntax_element(self) -> SyntaxElement;
}

impl<E: Element + Clone> Element for &'_ E {
fn syntax_element(self) -> SyntaxElement {
self.clone().syntax_element()
}
}
impl Element for SyntaxElement {
fn syntax_element(self) -> SyntaxElement {
self
}
}
impl Element for SyntaxNode {
fn syntax_element(self) -> SyntaxElement {
self.into()
}
}
impl Element for SyntaxToken {
fn syntax_element(self) -> SyntaxElement {
self.into()
}
}

#[derive(Debug)]
pub struct Position {
repr: PositionRepr,
Expand All @@ -20,24 +47,24 @@ enum PositionRepr {
}

impl Position {
pub fn after(elem: impl Into<SyntaxElement>) -> Position {
let repr = PositionRepr::After(elem.into());
pub fn after(elem: impl Element) -> Position {
let repr = PositionRepr::After(elem.syntax_element());
Position { repr }
}
pub fn before(elem: impl Into<SyntaxElement>) -> Position {
let elem = elem.into();
pub fn before(elem: impl Element) -> Position {
let elem = elem.syntax_element();
let repr = match elem.prev_sibling_or_token() {
Some(it) => PositionRepr::After(it),
None => PositionRepr::FirstChild(elem.parent().unwrap()),
};
Position { repr }
}
pub fn first_child_of(node: impl Into<SyntaxNode>) -> Position {
let repr = PositionRepr::FirstChild(node.into());
pub fn first_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
let repr = PositionRepr::FirstChild(node.clone().into());
Position { repr }
}
pub fn last_child_of(node: impl Into<SyntaxNode>) -> Position {
let node = node.into();
pub fn last_child_of(node: &(impl Into<SyntaxNode> + Clone)) -> Position {
let node = node.clone().into();
let repr = match node.last_child_or_token() {
Some(it) => PositionRepr::After(it),
None => PositionRepr::FirstChild(node),
Expand All @@ -46,11 +73,11 @@ impl Position {
}
}

pub fn insert(position: Position, elem: impl Into<SyntaxElement>) {
insert_all(position, vec![elem.into()])
pub fn insert(position: Position, elem: impl Element) {
insert_all(position, vec![elem.syntax_element()])
}
pub fn insert_raw(position: Position, elem: impl Into<SyntaxElement>) {
insert_all_raw(position, vec![elem.into()])
pub fn insert_raw(position: Position, elem: impl Element) {
insert_all_raw(position, vec![elem.syntax_element()])
}
pub fn insert_all(position: Position, mut elements: Vec<SyntaxElement>) {
if let Some(first) = elements.first() {
Expand All @@ -73,17 +100,17 @@ pub fn insert_all_raw(position: Position, elements: Vec<SyntaxElement>) {
parent.splice_children(index..index, elements);
}

pub fn remove(elem: impl Into<SyntaxElement>) {
let elem = elem.into();
pub fn remove(elem: impl Element) {
let elem = elem.syntax_element();
remove_all(elem.clone()..=elem)
}
pub fn remove_all(range: RangeInclusive<SyntaxElement>) {
replace_all(range, Vec::new())
}

pub fn replace(old: impl Into<SyntaxElement>, new: impl Into<SyntaxElement>) {
let old = old.into();
replace_all(old.clone()..=old, vec![new.into()])
pub fn replace(old: impl Element, new: impl Element) {
let old = old.syntax_element();
replace_all(old.clone()..=old, vec![new.syntax_element()])
}
pub fn replace_all(range: RangeInclusive<SyntaxElement>, new: Vec<SyntaxElement>) {
let start = range.start().index();
Expand All @@ -92,11 +119,11 @@ pub fn replace_all(range: RangeInclusive<SyntaxElement>, new: Vec<SyntaxElement>
parent.splice_children(start..end + 1, new)
}

pub fn append_child(node: impl Into<SyntaxNode>, child: impl Into<SyntaxElement>) {
pub fn append_child(node: &(impl Into<SyntaxNode> + Clone), child: impl Element) {
let position = Position::last_child_of(node);
insert(position, child)
}
pub fn append_child_raw(node: impl Into<SyntaxNode>, child: impl Into<SyntaxElement>) {
pub fn append_child_raw(node: &(impl Into<SyntaxNode> + Clone), child: impl Element) {
let position = Position::last_child_of(node);
insert_raw(position, child)
}
Expand Down