diff --git a/crates/ide/src/code_action.rs b/crates/ide/src/code_action.rs index d2690077..e6ddcfec 100644 --- a/crates/ide/src/code_action.rs +++ b/crates/ide/src/code_action.rs @@ -434,11 +434,13 @@ mod handlers { mod add_instance_parens; mod add_missing_connections; mod add_missing_parameters; + mod convert_literal_base; mod convert_ordered_connections; mod remove_empty_port_connections; pub(crate) fn all() -> &'static [Handler] { &[ + convert_literal_base::convert_literal_base, add_missing_connections::add_missing_connections, add_missing_parameters::add_missing_parameters, convert_ordered_connections::convert_ordered_ports, @@ -510,6 +512,23 @@ mod tests { } fn apply_action_without_diagnostics(text: &str, action_name: &str) -> Option { + apply_action_without_diagnostics_by(text, |action| action.id.name == action_name) + } + + fn apply_action_without_diagnostics_with_label( + text: &str, + action_name: &str, + label: &str, + ) -> Option { + apply_action_without_diagnostics_by(text, |action| { + action.id.name == action_name && action.label == label + }) + } + + fn apply_action_without_diagnostics_by( + text: &str, + pred: impl Fn(&CodeAction) -> bool, + ) -> Option { let (db, file_id, offset) = db_with_file(text); let actions = code_action( &db, @@ -518,7 +537,7 @@ mod tests { CodeActionDiagnostics::default(), CodeActionResolveStrategy::All, ); - let action = actions.into_iter().find(|action| action.id.name == action_name)?; + let action = actions.into_iter().find(pred)?; let mut text = text.replace("/*caret*/", ""); let edit = action.source_change?.text_edits.remove(&file_id)?; edit.apply(&mut text); @@ -649,6 +668,90 @@ mod tests { ); } + #[test] + fn literal_base_converts_plain_decimal_to_sized_signed_hexadecimal() { + let text = "module top; localparam int value = /*caret*/42; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_label( + text, + "convert_literal_base", + "Convert literal to hexadecimal", + ) + .unwrap(); + + assert_eq!(fixed, "module top; localparam int value = 32'sh2a; endmodule\n"); + } + + #[test] + fn literal_base_preserves_plain_decimal_sign_bit() { + let text = "module top; localparam longint value = /*caret*/2147483648; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_label( + text, + "convert_literal_base", + "Convert literal to hexadecimal", + ) + .unwrap(); + + assert_eq!(fixed, "module top; localparam longint value = 33'sh80000000; endmodule\n"); + } + + #[test] + fn literal_base_preserves_size_and_signed_base() { + let text = "module top; localparam logic [7:0] value = /*caret*/8'sh2A; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_label( + text, + "convert_literal_base", + "Convert literal to binary", + ) + .unwrap(); + + assert_eq!(fixed, "module top; localparam logic [7:0] value = 8'sb101010; endmodule\n"); + } + + #[test] + fn literal_base_converts_unsized_based_literal_to_based_decimal() { + let text = "module top; localparam int value = /*caret*/'hff; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_label( + text, + "convert_literal_base", + "Convert literal to decimal", + ) + .unwrap(); + + assert_eq!(fixed, "module top; localparam int value = 'd255; endmodule\n"); + } + + #[test] + fn literal_base_preserves_unsized_signed_base() { + let text = "module top; localparam int value = /*caret*/'shff; endmodule\n"; + let fixed = apply_action_without_diagnostics_with_label( + text, + "convert_literal_base", + "Convert literal to decimal", + ) + .unwrap(); + + assert_eq!(fixed, "module top; localparam int value = 'sd255; endmodule\n"); + } + + #[test] + fn literal_base_does_not_offer_decimal_for_unknown_bits() { + let labels = action_labels_without_diagnostics( + "module top; logic [3:0] value = /*caret*/'hx; endmodule\n", + ); + + assert!(labels.iter().any(|label| label == "Convert literal to binary")); + assert!(!labels.iter().any(|label| label == "Convert literal to decimal")); + } + + #[test] + fn literal_base_is_not_available_for_string_literals() { + let labels = action_labels_without_diagnostics( + "module top; string value = /*caret*/\"42\"; endmodule\n", + ); + + assert!(!labels.iter().any(|label| label.starts_with("Convert literal to "))); + } + #[test] fn missing_connection_repair_fills_named_connections() { let text = "module child(input a, input b); endmodule\nmodule top; child u(/*caret*/.a()); endmodule\n"; diff --git a/crates/ide/src/code_action/handlers/convert_literal_base.rs b/crates/ide/src/code_action/handlers/convert_literal_base.rs new file mode 100644 index 00000000..95d75e83 --- /dev/null +++ b/crates/ide/src/code_action/handlers/convert_literal_base.rs @@ -0,0 +1,170 @@ +use syntax::{ + LiteralBase, SVInt, + ast::{self, AstNode}, + has_text_range::HasTextRange, +}; +use utils::text_edit::TextRange; + +use crate::code_action::{CodeActionCollector, CodeActionCtx, CodeActionId, CodeActionKind}; + +const ID: CodeActionId = CodeActionId { + name: "convert_literal_base", + kind: CodeActionKind::RefactorRewrite, + repair: None, +}; + +pub(super) fn convert_literal_base( + collector: &mut CodeActionCollector, + ctx: &CodeActionCtx, +) -> Option<()> { + let literal = literal_at(ctx)?; + + for target_base in IntegerBase::ALL { + if target_base == literal.base { + continue; + } + + let Some(replacement) = literal.render(target_base) else { + continue; + }; + let label = format!("Convert literal to {}", target_base.label()); + collector.add(ID, label, literal.range, |builder| { + builder.replace(literal.range, replacement); + }); + } + + Some(()) +} + +#[derive(Debug)] +struct IntegerLiteral { + range: TextRange, + value: SVInt, + base: IntegerBase, + notation: IntegerLiteralNotation, +} + +impl IntegerLiteral { + fn render(&self, base: IntegerBase) -> Option { + if base == IntegerBase::Dec && self.value.has_unknown() { + return None; + } + + let digits = self.value.serialize(base.radix()); + Some(match &self.notation { + IntegerLiteralNotation::PlainDecimal => { + format!("{}'s{}{}", self.plain_decimal_width(), base.specifier(), digits) + } + IntegerLiteralNotation::Based { size: Some(size), signed } => { + format!("{size}'{}{}{}", signed_specifier(*signed), base.specifier(), digits) + } + IntegerLiteralNotation::Based { size: None, signed } => { + format!("'{}{}{}", signed_specifier(*signed), base.specifier(), digits) + } + }) + } + + fn plain_decimal_width(&self) -> usize { + let width = self.value.get_bit_width(); + if width < 32 { + 32 + } else if self.value.is_signed() { + width + } else { + width + 1 + } + } +} + +#[derive(Debug)] +enum IntegerLiteralNotation { + PlainDecimal, + Based { size: Option, signed: bool }, +} + +fn signed_specifier(signed: bool) -> &'static str { + if signed { "s" } else { "" } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum IntegerBase { + Bin, + Oct, + Dec, + Hex, +} + +impl IntegerBase { + const ALL: [Self; 4] = [Self::Bin, Self::Oct, Self::Dec, Self::Hex]; + + fn from_literal_base(base: LiteralBase) -> Self { + match base { + LiteralBase::Bin => Self::Bin, + LiteralBase::Oct => Self::Oct, + LiteralBase::Dec => Self::Dec, + LiteralBase::Hex => Self::Hex, + } + } + + fn radix(self) -> usize { + match self { + Self::Bin => 2, + Self::Oct => 8, + Self::Dec => 10, + Self::Hex => 16, + } + } + + fn specifier(self) -> &'static str { + match self { + Self::Bin => "b", + Self::Oct => "o", + Self::Dec => "d", + Self::Hex => "h", + } + } + + fn label(self) -> &'static str { + match self { + Self::Bin => "binary", + Self::Oct => "octal", + Self::Dec => "decimal", + Self::Hex => "hexadecimal", + } + } +} + +fn literal_at(ctx: &CodeActionCtx) -> Option { + if let Some(literal) = + ctx.find_node_at_offset::().and_then(integer_vector_literal) + { + return Some(literal); + } + + let literal = ctx.find_node_at_offset::()?; + let ast::LiteralExpression::IntegerLiteralExpression(integer) = literal else { + return None; + }; + + let token = integer.child_token(0)?; + Some(IntegerLiteral { + range: integer.text_range()?, + value: token.int()?, + base: IntegerBase::Dec, + notation: IntegerLiteralNotation::PlainDecimal, + }) +} + +fn integer_vector_literal(literal: ast::IntegerVectorExpression) -> Option { + let base = literal.base()?; + let value = literal.value()?; + Some(IntegerLiteral { + range: literal.syntax().text_range()?, + value: value.int()?, + base: IntegerBase::from_literal_base(base.base()?), + notation: IntegerLiteralNotation::Based { + size: literal.size().map(|size| size.raw_text().to_string()), + signed: base.raw_text().as_bytes().iter().any(|byte| byte.eq_ignore_ascii_case(&b's')), + }, + }) +}