Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/ruby-rbs/examples/locations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ruby_rbs::node::{Node, parse};

fn main() {
let rbs_code = r#"class Foo[T] < Bar end"#;
let signature = parse(rbs_code.as_bytes()).unwrap();
let signature = parse(rbs_code).unwrap();

let declaration = signature.declarations().iter().next().unwrap();
if let Node::Class(class) = declaration {
Expand Down
87 changes: 46 additions & 41 deletions rust/ruby-rbs/src/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ use std::ptr::NonNull;
/// ```rust
/// use ruby_rbs::node::parse;
/// let rbs_code = r#"type foo = "hello""#;
/// let signature = parse(rbs_code.as_bytes());
/// let signature = parse(rbs_code);
/// assert!(signature.is_ok(), "Failed to parse RBS signature");
/// ```
pub fn parse(rbs_code: &[u8]) -> Result<SignatureNode<'_>, String> {
pub fn parse(rbs_code: &str) -> Result<SignatureNode<'_>, String> {
unsafe {
let start_ptr = rbs_code.as_ptr().cast::<std::os::raw::c_char>();
let end_ptr = start_ptr.add(rbs_code.len());
Expand Down Expand Up @@ -253,17 +253,29 @@ impl<'a> RBSString<'a> {
}

#[must_use]
#[allow(clippy::unnecessary_cast)]
pub fn as_bytes(&self) -> &[u8] {
unsafe {
let s = *self.pointer;
std::slice::from_raw_parts(s.start as *const u8, s.end.offset_from(s.start) as usize)
}
}

#[must_use]
pub fn as_str(&self) -> &str {
unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
}
}

impl std::fmt::Display for RBSString<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

impl SymbolNode<'_> {
#[must_use]
pub fn name(&self) -> &[u8] {
pub fn as_bytes(&self) -> &[u8] {
unsafe {
let constant_ptr = rbs_constant_pool_id_to_constant(
&(*self.parser.as_ptr()).constant_pool,
Expand All @@ -277,6 +289,17 @@ impl SymbolNode<'_> {
std::slice::from_raw_parts(constant.start, constant.length)
}
}

#[must_use]
pub fn as_str(&self) -> &str {
unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
}
}

impl std::fmt::Display for SymbolNode<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

#[cfg(test)]
Expand All @@ -286,37 +309,34 @@ mod tests {
#[test]
fn test_parse_error_contains_actual_message() {
let rbs_code = "class { end";
let result = parse(rbs_code.as_bytes());
let result = parse(rbs_code);
let error_message = result.unwrap_err();
assert_eq!(error_message, "expected one of class/module/constant name");
}

#[test]
fn test_parse() {
let rbs_code = r#"type foo = "hello""#;
let signature = parse(rbs_code.as_bytes());
let signature = parse(rbs_code);
assert!(signature.is_ok(), "Failed to parse RBS signature");

let rbs_code2 = r#"class Foo end"#;
let signature2 = parse(rbs_code2.as_bytes());
let signature2 = parse(rbs_code2);
assert!(signature2.is_ok(), "Failed to parse RBS signature");
}

#[test]
fn test_parse_integer() {
let rbs_code = r#"type foo = 1"#;
let signature = parse(rbs_code.as_bytes());
let signature = parse(rbs_code);
assert!(signature.is_ok(), "Failed to parse RBS signature");

let signature_node = signature.unwrap();
if let Node::TypeAlias(node) = signature_node.declarations().iter().next().unwrap()
&& let Node::LiteralType(literal) = node.type_()
&& let Node::Integer(integer) = literal.literal()
{
assert_eq!(
"1".to_string(),
String::from_utf8(integer.string_representation().as_bytes().to_vec()).unwrap()
);
assert_eq!(integer.string_representation().as_str(), "1");
} else {
panic!("No literal type node found");
}
Expand All @@ -326,7 +346,7 @@ mod tests {
fn test_rbs_hash_via_record_type() {
// RecordType stores its fields in an RBSHash via all_fields()
let rbs_code = r#"type foo = { name: String, age: Integer }"#;
let signature = parse(rbs_code.as_bytes());
let signature = parse(rbs_code);
assert!(signature.is_ok(), "Failed to parse RBS signature");

let signature_node = signature.unwrap();
Expand All @@ -350,10 +370,10 @@ mod tests {
panic!("Expected ClassInstanceType");
};

let key_name = String::from_utf8(sym.name().to_vec()).unwrap();
let key_name = sym.to_string();
let type_name_node = class_type.name();
let type_name_sym = type_name_node.name();
let type_name = String::from_utf8(type_name_sym.name().to_vec()).unwrap();
let type_name = type_name_sym.to_string();
field_types.push((key_name, type_name));
}

Expand Down Expand Up @@ -384,28 +404,19 @@ mod tests {
}

fn visit_class_node(&mut self, node: &ClassNode) {
self.visited.push(format!(
"class:{}",
String::from_utf8(node.name().name().name().to_vec()).unwrap()
));
self.visited.push(format!("class:{}", node.name().name()));

crate::node::visit_class_node(self, node);
}

fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) {
self.visited.push(format!(
"type:{}",
String::from_utf8(node.name().name().name().to_vec()).unwrap()
));
self.visited.push(format!("type:{}", node.name().name()));

crate::node::visit_class_instance_type_node(self, node);
}

fn visit_class_super_node(&mut self, node: &ClassSuperNode) {
self.visited.push(format!(
"super:{}",
String::from_utf8(node.name().name().name().to_vec()).unwrap()
));
self.visited.push(format!("super:{}", node.name().name()));

crate::node::visit_class_super_node(self, node);
}
Expand All @@ -419,10 +430,7 @@ mod tests {
}

fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) {
self.visited.push(format!(
"method:{}",
String::from_utf8(node.name().name().to_vec()).unwrap()
));
self.visited.push(format!("method:{}", node.name()));

crate::node::visit_method_definition_node(self, node);
}
Expand All @@ -434,10 +442,7 @@ mod tests {
}

fn visit_symbol_node(&mut self, node: &SymbolNode) {
self.visited.push(format!(
"symbol:{}",
String::from_utf8(node.name().to_vec()).unwrap()
));
self.visited.push(format!("symbol:{node}"));

crate::node::visit_symbol_node(self, node);
}
Expand All @@ -449,7 +454,7 @@ mod tests {
end
"#;

let signature = parse(rbs_code.as_bytes()).unwrap();
let signature = parse(rbs_code).unwrap();

let mut visitor = Visitor {
visited: Vec::new(),
Expand Down Expand Up @@ -482,7 +487,7 @@ mod tests {
#[test]
fn test_node_location_ranges() {
let rbs_code = r#"type foo = 1"#;
let signature = parse(rbs_code.as_bytes()).unwrap();
let signature = parse(rbs_code).unwrap();

let declaration = signature.declarations().iter().next().unwrap();
let Node::TypeAlias(type_alias) = declaration else {
Expand Down Expand Up @@ -510,7 +515,7 @@ mod tests {
#[test]
fn test_sub_locations() {
let rbs_code = r#"class Foo < Bar end"#;
let signature = parse(rbs_code.as_bytes()).unwrap();
let signature = parse(rbs_code).unwrap();

let declaration = signature.declarations().iter().next().unwrap();
let Node::Class(class) = declaration else {
Expand Down Expand Up @@ -545,7 +550,7 @@ mod tests {
#[test]
fn test_type_alias_sub_locations() {
let rbs_code = r#"type foo = String"#;
let signature = parse(rbs_code.as_bytes()).unwrap();
let signature = parse(rbs_code).unwrap();

let declaration = signature.declarations().iter().next().unwrap();
let Node::TypeAlias(type_alias) = declaration else {
Expand Down Expand Up @@ -573,7 +578,7 @@ mod tests {
#[test]
fn test_module_sub_locations() {
let rbs_code = r#"module Foo[T] : Bar end"#;
let signature = parse(rbs_code.as_bytes()).unwrap();
let signature = parse(rbs_code).unwrap();

let declaration = signature.declarations().iter().next().unwrap();
let Node::Module(module) = declaration else {
Expand Down Expand Up @@ -626,7 +631,7 @@ mod tests {
class Bar[out T, in U, V]
end
"#;
let signature = parse(rbs_code.as_bytes()).unwrap();
let signature = parse(rbs_code).unwrap();

let declarations: Vec<_> = signature.declarations().iter().collect();

Expand Down Expand Up @@ -706,7 +711,7 @@ mod tests {
attr_writer email(@email): String
end
"#;
let signature = parse(rbs_code.as_bytes()).unwrap();
let signature = parse(rbs_code).unwrap();

let Node::Class(class) = signature.declarations().iter().next().unwrap() else {
panic!("Expected Class");
Expand Down
2 changes: 1 addition & 1 deletion rust/ruby-rbs/tests/sanity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn all_included_rbs_can_be_parsed() {
for file in &files {
let content = std::fs::read_to_string(file).unwrap();

if let Err(e) = parse(content.as_bytes()) {
if let Err(e) = parse(&content) {
failures.push(format!("{}: {}", file.display(), e));
}
}
Expand Down