Skip to content

Commit

Permalink
Introduce DiagnosticsReporter in place of check_diagnostics (#2018)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkaput committed Feb 3, 2023
1 parent 84608af commit 18b55fe
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 81 deletions.
165 changes: 110 additions & 55 deletions crates/cairo-lang-compiler/src/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,80 +17,135 @@ mod test;
#[error("Compilation failed.")]
pub struct DiagnosticsError;

/// Checks if there are diagnostics and reports them to the provided callback as strings.
/// Returns `true` if diagnostics were found.
pub fn check_diagnostics(
db: &mut RootDatabase,
on_diagnostic: Option<&mut (dyn FnMut(String) + '_)>,
) -> bool {
let mut ignore = |_| ();
let on_diagnostic = on_diagnostic.unwrap_or(&mut ignore);

let mut found_diagnostics = false;
for crate_id in db.crates() {
let Ok(module_file) = db.module_main_file(ModuleId::CrateRoot(crate_id)) else {
found_diagnostics = true;
on_diagnostic("Failed to get main module file".to_string());
continue;
};

if db.file_content(module_file).is_none() {
match db.lookup_intern_file(module_file) {
FileLongId::OnDisk(path) => {
on_diagnostic(format!("{} not found\n", path.display()))
}
FileLongId::Virtual(_) => panic!("Missing virtual file."),
trait DiagnosticCallback {
fn on_diagnostic(&mut self, diagnostic: String);
}

impl<'a> DiagnosticCallback for Option<Box<dyn DiagnosticCallback + 'a>> {
fn on_diagnostic(&mut self, diagnostic: String) {
if let Some(callback) = self {
callback.on_diagnostic(diagnostic)
}
}
}

/// Collects compilation diagnostics and presents them in preconfigured way.
pub struct DiagnosticsReporter<'a> {
callback: Option<Box<dyn DiagnosticCallback + 'a>>,
}

impl DiagnosticsReporter<'static> {
/// Create a reporter which does not print or collect diagnostics at all.
pub fn ignoring() -> Self {
Self { callback: None }
}

/// Create a reporter which prints all diagnostics to [`std::io::Stderr`].
pub fn stderr() -> Self {
Self::callback(|diagnostic| {
eprint!("{diagnostic}");
})
}
}

impl<'a> DiagnosticsReporter<'a> {
// NOTE(mkaput): If Rust will ever have intersection types, one could write
// impl<F> DiagnosticCallback for F where F: FnMut(String)
// and `new` could accept regular functions without need for this separate method.
/// Create a reporter which calls `callback` for each diagnostic.
pub fn callback(callback: impl FnMut(String) + 'a) -> Self {
struct Func<F>(F);

impl<F> DiagnosticCallback for Func<F>
where
F: FnMut(String),
{
fn on_diagnostic(&mut self, diagnostic: String) {
(self.0)(diagnostic)
}
found_diagnostics = true;
}

for module_id in &*db.crate_modules(crate_id) {
for file_id in db.module_files(*module_id).unwrap_or_default() {
let diag = db.file_syntax_diagnostics(file_id);
if !diag.get_all().is_empty() {
found_diagnostics = true;
on_diagnostic(diag.format(db));
Self::new(Func(callback))
}

/// Create a reporter which appends all diagnostics to provided string.
pub fn write_to_string(string: &'a mut String) -> Self {
Self::callback(|diagnostic| {
string.push_str(&diagnostic);
})
}

/// Create a reporter which calls [`DiagnosticCallback::on_diagnostic`].
fn new(callback: impl DiagnosticCallback + 'a) -> Self {
Self { callback: Some(Box::new(callback)) }
}

/// Checks if there are diagnostics and reports them to the provided callback as strings.
/// Returns `true` if diagnostics were found.
pub fn check(&mut self, db: &mut RootDatabase) -> bool {
let mut found_diagnostics = false;
for crate_id in db.crates() {
let Ok(module_file) = db.module_main_file(ModuleId::CrateRoot(crate_id)) else {
found_diagnostics = true;
self.callback.on_diagnostic("Failed to get main module file".to_string());
continue;
};

if db.file_content(module_file).is_none() {
match db.lookup_intern_file(module_file) {
FileLongId::OnDisk(path) => {
self.callback.on_diagnostic(format!("{} not found\n", path.display()))
}
FileLongId::Virtual(_) => panic!("Missing virtual file."),
}
found_diagnostics = true;
}

if let Ok(diag) = db.module_semantic_diagnostics(*module_id) {
if !diag.get_all().is_empty() {
found_diagnostics = true;
on_diagnostic(diag.format(db));
for module_id in &*db.crate_modules(crate_id) {
for file_id in db.module_files(*module_id).unwrap_or_default() {
let diag = db.file_syntax_diagnostics(file_id);
if !diag.get_all().is_empty() {
found_diagnostics = true;
self.callback.on_diagnostic(diag.format(db));
}
}
}

if let Ok(diag) = db.module_lowering_diagnostics(*module_id) {
if !diag.get_all().is_empty() {
found_diagnostics = true;
on_diagnostic(diag.format(db));
if let Ok(diag) = db.module_semantic_diagnostics(*module_id) {
if !diag.get_all().is_empty() {
found_diagnostics = true;
self.callback.on_diagnostic(diag.format(db));
}
}

if let Ok(diag) = db.module_lowering_diagnostics(*module_id) {
if !diag.get_all().is_empty() {
found_diagnostics = true;
self.callback.on_diagnostic(diag.format(db));
}
}
}
}
found_diagnostics
}
found_diagnostics
}

/// Checks if there are diagnostics and reports them to the provided callback as strings.
/// Returns `Err` if diagnostics were found.
pub fn ensure_diagnostics(
db: &mut RootDatabase,
on_diagnostic: Option<&mut (dyn FnMut(String) + '_)>,
) -> Result<(), DiagnosticsError> {
if check_diagnostics(db, on_diagnostic) { Err(DiagnosticsError) } else { Ok(()) }
}

pub fn check_and_eprint_diagnostics(db: &mut RootDatabase) -> bool {
check_diagnostics(db, Some(&mut eprint_diagnostic))
/// Checks if there are diagnostics and reports them to the provided callback as strings.
/// Returns `Err` if diagnostics were found.
pub fn ensure(&mut self, db: &mut RootDatabase) -> Result<(), DiagnosticsError> {
if self.check(db) { Err(DiagnosticsError) } else { Ok(()) }
}
}

pub fn eprint_diagnostic(diag: String) {
eprint!("{diag}");
impl Default for DiagnosticsReporter<'static> {
fn default() -> Self {
DiagnosticsReporter::stderr()
}
}

/// Returns a string with all the diagnostics in the db.
///
/// This is a shortcut for `DiagnosticsReporter::write_to_string(&mut string).check(db)`.
pub fn get_diagnostics_as_string(db: &mut RootDatabase) -> String {
let mut diagnostics = String::default();
check_diagnostics(db, Some(&mut |s| diagnostics += &s));
DiagnosticsReporter::write_to_string(&mut diagnostics).check(db);
diagnostics
}
18 changes: 9 additions & 9 deletions crates/cairo-lang-compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,25 @@ use cairo_lang_sierra_generator::db::SierraGenGroup;
use cairo_lang_sierra_generator::replace_ids::replace_sierra_ids_in_program;

use crate::db::RootDatabase;
use crate::diagnostics::{ensure_diagnostics, eprint_diagnostic};
use crate::diagnostics::DiagnosticsReporter;
use crate::project::{get_main_crate_ids_from_project, setup_project, ProjectConfig};

pub mod db;
pub mod diagnostics;
pub mod project;

/// Configuration for the compiler.
pub struct CompilerConfig {
pub on_diagnostic: Option<Box<dyn FnMut(String)>>,
pub struct CompilerConfig<'c> {
pub diagnostics_reporter: DiagnosticsReporter<'c>,

/// Replaces sierra ids with human-readable ones.
pub replace_ids: bool,
}

/// The default compiler configuration.
impl Default for CompilerConfig {
impl Default for CompilerConfig<'static> {
fn default() -> Self {
CompilerConfig { on_diagnostic: Some(Box::new(eprint_diagnostic)), replace_ids: false }
CompilerConfig { diagnostics_reporter: DiagnosticsReporter::default(), replace_ids: false }
}
}

Expand All @@ -49,7 +49,7 @@ pub type SierraProgram = Arc<Program>;
/// * `Err(anyhow::Error)` - Compilation failed.
pub fn compile_cairo_project_at_path(
path: &Path,
compiler_config: CompilerConfig,
compiler_config: CompilerConfig<'_>,
) -> Result<SierraProgram> {
let mut db = RootDatabase::builder().detect_corelib().build()?;
let main_crate_ids = setup_project(&mut db, path)?;
Expand All @@ -67,7 +67,7 @@ pub fn compile_cairo_project_at_path(
/// * `Err(anyhow::Error)` - Compilation failed.
pub fn compile(
project_config: ProjectConfig,
compiler_config: CompilerConfig,
compiler_config: CompilerConfig<'_>,
) -> Result<SierraProgram> {
let mut db = RootDatabase::builder().with_project_config(project_config.clone()).build()?;
let main_crate_ids = get_main_crate_ids_from_project(&mut db, &project_config);
Expand All @@ -89,9 +89,9 @@ pub fn compile(
pub fn compile_prepared_db(
db: &mut RootDatabase,
main_crate_ids: Vec<CrateId>,
mut compiler_config: CompilerConfig,
mut compiler_config: CompilerConfig<'_>,
) -> Result<SierraProgram> {
ensure_diagnostics(db, compiler_config.on_diagnostic.as_deref_mut())?;
compiler_config.diagnostics_reporter.ensure(db)?;

let mut sierra_program = db
.get_sierra_program(main_crate_ids)
Expand Down
4 changes: 2 additions & 2 deletions crates/cairo-lang-runner/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::path::Path;

use anyhow::{Context, Ok};
use cairo_lang_compiler::db::RootDatabase;
use cairo_lang_compiler::diagnostics::check_and_eprint_diagnostics;
use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
use cairo_lang_compiler::project::setup_project;
use cairo_lang_diagnostics::ToOption;
use cairo_lang_runner::SierraCasmRunner;
Expand Down Expand Up @@ -35,7 +35,7 @@ fn main() -> anyhow::Result<()> {

let main_crate_ids = setup_project(db, Path::new(&args.path))?;

if check_and_eprint_diagnostics(db) {
if DiagnosticsReporter::stderr().check(db) {
anyhow::bail!("failed to compile: {}", args.path);
}

Expand Down
11 changes: 5 additions & 6 deletions crates/cairo-lang-starknet/src/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::path::Path;

use anyhow::{ensure, Context, Result};
use cairo_lang_compiler::db::RootDatabase;
use cairo_lang_compiler::diagnostics::ensure_diagnostics;
use cairo_lang_compiler::project::setup_project;
use cairo_lang_compiler::CompilerConfig;
use cairo_lang_defs::ids::TopLevelLanguageElementId;
Expand Down Expand Up @@ -93,7 +92,7 @@ pub struct ContractEntryPoint {
/// Compile the contract given by path.
///
/// Errors if no contracts or more than 1 are found.
pub fn compile_path(path: &Path, compiler_config: CompilerConfig) -> Result<ContractClass> {
pub fn compile_path(path: &Path, compiler_config: CompilerConfig<'_>) -> Result<ContractClass> {
let mut db = RootDatabase::builder().detect_corelib().with_starknet().build()?;

let main_crate_ids = setup_project(&mut db, Path::new(&path))?;
Expand All @@ -107,7 +106,7 @@ pub fn compile_path(path: &Path, compiler_config: CompilerConfig) -> Result<Cont
fn compile_only_contract_in_prepared_db(
db: &mut RootDatabase,
main_crate_ids: Vec<CrateId>,
compiler_config: CompilerConfig,
compiler_config: CompilerConfig<'_>,
) -> Result<ContractClass> {
let contracts = find_contracts(db, &main_crate_ids);
ensure!(!contracts.is_empty(), "Contract not found.");
Expand All @@ -133,9 +132,9 @@ fn compile_only_contract_in_prepared_db(
pub fn compile_prepared_db(
db: &mut RootDatabase,
contracts: &[&ContractDeclaration],
mut compiler_config: CompilerConfig,
mut compiler_config: CompilerConfig<'_>,
) -> Result<Vec<ContractClass>> {
ensure_diagnostics(db, compiler_config.on_diagnostic.as_deref_mut())?;
compiler_config.diagnostics_reporter.ensure(db)?;

contracts
.iter()
Expand All @@ -153,7 +152,7 @@ pub fn compile_prepared_db(
fn compile_contract_with_prepared_and_checked_db(
db: &mut RootDatabase,
contract: &ContractDeclaration,
compiler_config: &CompilerConfig,
compiler_config: &CompilerConfig<'_>,
) -> Result<ContractClass> {
let external_functions: Vec<_> = get_module_functions(db, contract, EXTERNAL_MODULE)?
.into_iter()
Expand Down
4 changes: 2 additions & 2 deletions crates/cairo-lang-test-runner/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};

use anyhow::{bail, Context};
use cairo_lang_compiler::db::RootDatabase;
use cairo_lang_compiler::diagnostics::check_and_eprint_diagnostics;
use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
use cairo_lang_compiler::project::setup_project;
use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::ids::{FreeFunctionId, FunctionWithBodyId, ModuleItemId};
Expand Down Expand Up @@ -76,7 +76,7 @@ fn main() -> anyhow::Result<()> {

let main_crate_ids = setup_project(db, Path::new(&args.path))?;

if check_and_eprint_diagnostics(db) {
if DiagnosticsReporter::stderr().check(db) {
bail!("failed to compile: {}", args.path);
}
let all_tests = find_all_tests(db, main_crate_ids);
Expand Down
10 changes: 5 additions & 5 deletions tests/e2e_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use cairo_lang_compiler::db::RootDatabase;
use cairo_lang_compiler::diagnostics::check_and_eprint_diagnostics;
use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
use cairo_lang_semantic::test_utils::setup_test_module;
use cairo_lang_sierra_generator::db::SierraGenGroup;
use cairo_lang_sierra_generator::replace_ids::replace_sierra_ids_in_program;
Expand Down Expand Up @@ -40,14 +40,14 @@ cairo_lang_test_utils::test_file_test!(
);

fn run_small_e2e_test(inputs: &OrderedHashMap<String, String>) -> OrderedHashMap<String, String> {
let db = &mut RootDatabase::builder().detect_corelib().build().unwrap();
let mut db = RootDatabase::builder().detect_corelib().build().unwrap();
// Parse code and create semantic model.
let test_module = setup_test_module(db, inputs["cairo"].as_str()).unwrap();
assert!(!check_and_eprint_diagnostics(db));
let test_module = setup_test_module(&mut db, inputs["cairo"].as_str()).unwrap();
DiagnosticsReporter::stderr().ensure(&mut db).unwrap();

// Compile to Sierra.
let sierra_program = db.get_sierra_program(vec![test_module.crate_id]).unwrap();
let sierra_program = replace_sierra_ids_in_program(db, &sierra_program);
let sierra_program = replace_sierra_ids_in_program(&db, &sierra_program);
let sierra_program_str = sierra_program.to_string();

// Compute the metadata.
Expand Down
4 changes: 2 additions & 2 deletions tests/examples_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::path::PathBuf;
use assert_matches::assert_matches;
use cairo_felt::{self as felt, felt_str, Felt};
use cairo_lang_compiler::db::RootDatabase;
use cairo_lang_compiler::diagnostics::check_and_eprint_diagnostics;
use cairo_lang_compiler::diagnostics::DiagnosticsReporter;
use cairo_lang_compiler::project::setup_project;
use cairo_lang_filesystem::ids::CrateId;
use cairo_lang_runner::{RunResultValue, SierraCasmRunner, DUMMY_BUILTIN_GAS_COST};
Expand All @@ -24,7 +24,7 @@ fn setup(name: &str) -> (RootDatabase, Vec<CrateId>) {

let mut db = RootDatabase::builder().detect_corelib().build().unwrap();
let main_crate_ids = setup_project(&mut db, path.as_path()).expect("Project setup failed.");
assert!(!check_and_eprint_diagnostics(&mut db));
DiagnosticsReporter::stderr().ensure(&mut db).unwrap();
(db, main_crate_ids)
}

Expand Down

0 comments on commit 18b55fe

Please sign in to comment.