diff --git a/tests/xdsl_opt/test_xdsl_opt.py b/tests/xdsl_opt/test_xdsl_opt.py index 399bedc7ca..89cb111f1e 100644 --- a/tests/xdsl_opt/test_xdsl_opt.py +++ b/tests/xdsl_opt/test_xdsl_opt.py @@ -7,7 +7,12 @@ from xdsl.ir import MLContext from xdsl.passes import ModulePass from xdsl.utils.exceptions import DiagnosticException -from xdsl.xdsl_opt_main import xDSLOptMain +from xdsl.xdsl_opt_main import get_all_dialects, get_all_passes, xDSLOptMain + + +def test_dialects_and_passes(): + assert len(get_all_dialects()) > 0 + assert len(get_all_passes()) > 0 def test_opt(): diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 683eb0028d..08f6e12c6c 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -5,7 +5,7 @@ from io import StringIO from xdsl.frontend.symref import Symref -from xdsl.ir import MLContext +from xdsl.ir import Dialect, MLContext from xdsl.parser import Parser, ParseError from xdsl.passes import ModulePass from xdsl.printer import Printer @@ -60,6 +60,55 @@ from typing import IO, Dict, Callable, List, Sequence, Type +def get_all_dialects() -> list[Dialect]: + """Return the list of all available dialects.""" + return [ + Affine, + Arith, + Builtin, + Cf, + CMath, + DMP, + FIR, + Func, + GPU, + IRDL, + LLVM, + Math, + MemRef, + MPI, + PDL, + RISCV, + RISCV_Func, + Scf, + Snitch, + SnitchRuntime, + StencilExp, + Stencil, + Symref, + Test, + Vector, + ] + + +def get_all_passes() -> list[type[ModulePass]]: + """Return the list of all available passes.""" + return [ + ConvertStencilToLLMLIRPass, + DeadCodeElimination, + DesymrefyPass, + DmpScatterGatherTrivialLowering, + GlobalStencilToLocalStencil2DHorizontal, + LowerHaloToMPI, + LowerMPIPass, + LowerRISCVFunc, + LowerSnitchPass, + LowerSnitchRuntimePass, + RISCVRegisterAllocation, + StencilShapeInferencePass, + ] + + class xDSLOptMain: ctx: MLContext args: argparse.Namespace @@ -230,31 +279,8 @@ def register_all_dialects(self): Add other/additional dialects by overloading this function. """ - self.ctx.register_dialect(Builtin) - self.ctx.register_dialect(Func) - self.ctx.register_dialect(Arith) - self.ctx.register_dialect(MemRef) - self.ctx.register_dialect(Affine) - self.ctx.register_dialect(Scf) - self.ctx.register_dialect(Cf) - self.ctx.register_dialect(CMath) - self.ctx.register_dialect(Math) - self.ctx.register_dialect(LLVM) - self.ctx.register_dialect(Vector) - self.ctx.register_dialect(MPI) - self.ctx.register_dialect(GPU) - self.ctx.register_dialect(StencilExp) - self.ctx.register_dialect(Stencil) - self.ctx.register_dialect(PDL) - self.ctx.register_dialect(Symref) - self.ctx.register_dialect(Test) - self.ctx.register_dialect(RISCV) - self.ctx.register_dialect(Snitch) - self.ctx.register_dialect(SnitchRuntime) - self.ctx.register_dialect(RISCV_Func) - self.ctx.register_dialect(IRDL) - self.ctx.register_dialect(FIR) - self.ctx.register_dialect(DMP) + for dialect in get_all_dialects(): + self.ctx.register_dialect(dialect) def register_all_frontends(self): """ @@ -282,18 +308,8 @@ def register_all_passes(self): Add other/additional passes by overloading this function. """ - self.register_pass(LowerMPIPass) - self.register_pass(ConvertStencilToLLMLIRPass) - self.register_pass(StencilShapeInferencePass) - self.register_pass(GlobalStencilToLocalStencil2DHorizontal) - self.register_pass(DesymrefyPass) - self.register_pass(DeadCodeElimination) - self.register_pass(LowerSnitchPass) - self.register_pass(LowerSnitchRuntimePass) - self.register_pass(RISCVRegisterAllocation) - self.register_pass(LowerRISCVFunc) - self.register_pass(LowerHaloToMPI) - self.register_pass(DmpScatterGatherTrivialLowering) + for pass_ in get_all_passes(): + self.register_pass(pass_) def register_all_targets(self): """