Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Expose the list of dialects without using xDSLOptMain #1079

Merged
merged 3 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/xdsl_opt/test_xdsl_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.utils.exceptions import DiagnosticException
from xdsl.xdsl_opt_main import xDSLOptMain
from xdsl.xdsl_opt_main import get_all_dialects, get_all_passes, xDSLOptMain


def test_dialects_and_passes():
assert len(get_all_dialects()) > 0
Copy link
Member

@superlopuh superlopuh Jun 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I meant something like this:

Suggested change
assert len(get_all_dialects()) > 0
names = [d.__name__ for d in get_all_dialects()]
assert names == sorted(names), "Names of dialects should be in alphabetical order"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to enforce an invariant like that? To me, the dialects are a set of dialects, so it does not even make all that much sense to talk about order 🤔

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh no, I don't think we want to have that!
I just want it in the function definition to make it cleaner!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would automate this little bit of the process, I don't see the harm in it TBH

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't have a strong opinion, it's just that comments that say "please" in them aren't followed as closely as those that have a matching test to catch transgressors :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I would be against it to be honest.
The harm is that we are adding additional constraints on the API that we don't need, and that may change in the future. So I would prefer not introducing them in the first place.
Also, this check won't work, as dialects are object, not classes, so __name__ won't exist anyway!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not see the comment before, IMO, we should remove that 😅

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@webmiche, do you mean we should remove the comment about keeping them in order?
I added it only because I felt it should be sorted, both for readability and conficts.
But if this cause too much trouble I'm fine removing it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm not sure. I would remove the comment, but I am fine with leaving it in, as long as we do not actually enforce it.

assert len(get_all_passes()) > 0


def test_opt():
Expand Down
92 changes: 54 additions & 38 deletions xdsl/xdsl_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from io import StringIO
from xdsl.frontend.symref import Symref

from xdsl.ir import MLContext
from xdsl.ir import Dialect, MLContext
from xdsl.parser import Parser, ParseError
from xdsl.passes import ModulePass
from xdsl.printer import Printer
Expand Down Expand Up @@ -60,6 +60,55 @@
from typing import IO, Dict, Callable, List, Sequence, Type


def get_all_dialects() -> list[Dialect]:
"""Return the list of all available dialects."""
return [
Affine,
Arith,
Builtin,
Cf,
CMath,
DMP,
FIR,
Func,
GPU,
IRDL,
LLVM,
Math,
MemRef,
MPI,
PDL,
RISCV,
RISCV_Func,
Scf,
Snitch,
SnitchRuntime,
StencilExp,
Stencil,
Symref,
Test,
Vector,
]


def get_all_passes() -> list[type[ModulePass]]:
"""Return the list of all available passes."""
return [
ConvertStencilToLLMLIRPass,
DeadCodeElimination,
DesymrefyPass,
DmpScatterGatherTrivialLowering,
GlobalStencilToLocalStencil2DHorizontal,
LowerHaloToMPI,
LowerMPIPass,
LowerRISCVFunc,
LowerSnitchPass,
LowerSnitchRuntimePass,
RISCVRegisterAllocation,
StencilShapeInferencePass,
]


class xDSLOptMain:
ctx: MLContext
args: argparse.Namespace
Expand Down Expand Up @@ -230,31 +279,8 @@ def register_all_dialects(self):

Add other/additional dialects by overloading this function.
"""
self.ctx.register_dialect(Builtin)
self.ctx.register_dialect(Func)
self.ctx.register_dialect(Arith)
self.ctx.register_dialect(MemRef)
self.ctx.register_dialect(Affine)
self.ctx.register_dialect(Scf)
self.ctx.register_dialect(Cf)
self.ctx.register_dialect(CMath)
self.ctx.register_dialect(Math)
self.ctx.register_dialect(LLVM)
self.ctx.register_dialect(Vector)
self.ctx.register_dialect(MPI)
self.ctx.register_dialect(GPU)
self.ctx.register_dialect(StencilExp)
self.ctx.register_dialect(Stencil)
self.ctx.register_dialect(PDL)
self.ctx.register_dialect(Symref)
self.ctx.register_dialect(Test)
self.ctx.register_dialect(RISCV)
self.ctx.register_dialect(Snitch)
self.ctx.register_dialect(SnitchRuntime)
self.ctx.register_dialect(RISCV_Func)
self.ctx.register_dialect(IRDL)
self.ctx.register_dialect(FIR)
self.ctx.register_dialect(DMP)
for dialect in get_all_dialects():
self.ctx.register_dialect(dialect)

def register_all_frontends(self):
"""
Expand Down Expand Up @@ -282,18 +308,8 @@ def register_all_passes(self):

Add other/additional passes by overloading this function.
"""
self.register_pass(LowerMPIPass)
self.register_pass(ConvertStencilToLLMLIRPass)
self.register_pass(StencilShapeInferencePass)
self.register_pass(GlobalStencilToLocalStencil2DHorizontal)
self.register_pass(DesymrefyPass)
self.register_pass(DeadCodeElimination)
self.register_pass(LowerSnitchPass)
self.register_pass(LowerSnitchRuntimePass)
self.register_pass(RISCVRegisterAllocation)
self.register_pass(LowerRISCVFunc)
self.register_pass(LowerHaloToMPI)
self.register_pass(DmpScatterGatherTrivialLowering)
for pass_ in get_all_passes():
self.register_pass(pass_)

def register_all_targets(self):
"""
Expand Down