Skip to content

Commit

Permalink
Merge pull request #245 from qir-alliance/sezna/236
Browse files Browse the repository at this point in the history
Make enums hashable
  • Loading branch information
idavis committed Aug 29, 2023
2 parents e7d4892 + 25248f5 commit de2c8ed
Show file tree
Hide file tree
Showing 4 changed files with 305 additions and 6 deletions.
48 changes: 48 additions & 0 deletions pyqir/pyqir/_native.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,18 @@ class FloatConstant(Constant):
class FloatPredicate(Enum):
"""A floating-point comparison predicate."""

def __richcmp__(self, other: Value, op: int) -> bool:
"""
Compares this value to another value.
Only == and != are supported.
:param other: The other value.
:param op: The comparison operator.
:returns: The result of the comparison.
"""
...
def __hash__(self) -> int: ...

FALSE: FloatPredicate
OEQ: FloatPredicate
OGT: FloatPredicate
Expand Down Expand Up @@ -414,6 +426,18 @@ class IntConstant(Constant):
class IntPredicate(Enum):
"""An integer comparison predicate."""

def __richcmp__(self, other: Value, op: int) -> bool:
"""
Compares this value to another value.
Only == and != are supported.
:param other: The other value.
:param op: The comparison operator.
:returns: The result of the comparison.
"""
...
def __hash__(self) -> int: ...

EQ: IntPredicate
NE: IntPredicate
UGT: IntPredicate
Expand Down Expand Up @@ -444,6 +468,18 @@ class IntType(Type):
class Linkage(Enum):
"""The linkage kind for a global value in a module."""

def __richcmp__(self, other: Value, op: int) -> bool:
"""
Compares this value to another value.
Only == and != are supported.
:param other: The other value.
:param op: The comparison operator.
:returns: The result of the comparison.
"""
...
def __hash__(self) -> int: ...

APPENDING: Linkage
AVAILABLE_EXTERNALLY: Linkage
COMMON: Linkage
Expand Down Expand Up @@ -553,6 +589,18 @@ class ModuleFlagBehavior(Enum):
class Opcode(Enum):
"""An instruction opcode."""

def __richcmp__(self, other: Value, op: int) -> bool:
"""
Compares this value to another value.
Only == and != are supported.
:param other: The other value.
:param op: The comparison operator.
:returns: The result of the comparison.
"""
...
def __hash__(self) -> int: ...

ADD: Opcode
ADDR_SPACE_CAST: Opcode
ALLOCA: Opcode
Expand Down
73 changes: 69 additions & 4 deletions pyqir/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
use crate::values::{BasicBlock, Owner, Value};
#[allow(clippy::wildcard_imports)]
use llvm_sys::{core::*, prelude::*, LLVMIntPredicate, LLVMOpcode, LLVMRealPredicate};
use pyo3::{conversion::ToPyObject, prelude::*};
use std::{convert::Into, ptr::NonNull};
use pyo3::{conversion::ToPyObject, prelude::*, pyclass::CompareOp, PyRef};
use std::{
collections::hash_map::DefaultHasher,
convert::Into,
hash::{Hash, Hasher},
ptr::NonNull,
};

/// An instruction.
#[pyclass(extends = Value, subclass)]
Expand Down Expand Up @@ -96,6 +101,7 @@ impl Instruction {

/// An instruction opcode.
#[pyclass]
#[derive(PartialEq, Hash)]
pub(crate) enum Opcode {
#[pyo3(name = "ADD")]
Add,
Expand Down Expand Up @@ -233,6 +239,25 @@ pub(crate) enum Opcode {
ZExt,
}

#[pymethods]
impl Opcode {
// In order to implement the comparison operators, we have to do
// it all in one impl of __richcmp__ for pyo3 to work.
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => self.eq(other).into_py(py),
CompareOp::Ne => (!self.eq(other)).into_py(py),
_ => py.NotImplemented(),
}
}

fn __hash__(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}

impl From<LLVMOpcode> for Opcode {
fn from(opcode: LLVMOpcode) -> Self {
match opcode {
Expand Down Expand Up @@ -378,7 +403,7 @@ impl ICmp {

/// An integer comparison predicate.
#[pyclass]
#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq, Hash)]
pub(crate) enum IntPredicate {
#[pyo3(name = "EQ")]
Eq,
Expand All @@ -402,6 +427,26 @@ pub(crate) enum IntPredicate {
Sle,
}

#[allow(clippy::trivially_copy_pass_by_ref)]
#[pymethods]
impl IntPredicate {
// In order to implement the comparison operators, we have to do
// it all in one impl of __richcmp__ for pyo3 to work.
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => self.eq(other).into_py(py),
CompareOp::Ne => (!self.eq(other)).into_py(py),
_ => py.NotImplemented(),
}
}

fn __hash__(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}

impl From<LLVMIntPredicate> for IntPredicate {
fn from(pred: LLVMIntPredicate) -> Self {
match pred {
Expand Down Expand Up @@ -453,7 +498,7 @@ impl FCmp {

/// A floating-point comparison predicate.
#[pyclass]
#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq, Hash)]
pub(crate) enum FloatPredicate {
#[pyo3(name = "FALSE")]
False,
Expand Down Expand Up @@ -489,6 +534,26 @@ pub(crate) enum FloatPredicate {
True,
}

#[pymethods]
#[allow(clippy::trivially_copy_pass_by_ref)]
impl FloatPredicate {
// In order to implement the comparison operators, we have to do
// it all in one impl of __richcmp__ for pyo3 to work.
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => self.eq(other).into_py(py),
CompareOp::Ne => (!self.eq(other)).into_py(py),
_ => py.NotImplemented(),
}
}

fn __hash__(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}

impl From<LLVMRealPredicate> for FloatPredicate {
fn from(pred: LLVMRealPredicate) -> Self {
match pred {
Expand Down
26 changes: 24 additions & 2 deletions pyqir/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ use llvm_sys::{
ir_reader::LLVMParseIRInContext,
LLVMLinkage, LLVMModule,
};
use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes};
use pyo3::{exceptions::PyValueError, prelude::*, pyclass::CompareOp, types::PyBytes};
use qirlib::module::FlagBehavior;
use std::{
collections::hash_map::DefaultHasher,
ffi::CString,
hash::{Hash, Hasher},
ops::Deref,
ptr::{self, NonNull},
str,
Expand Down Expand Up @@ -289,7 +291,7 @@ impl PartialEq for Module {

/// The linkage kind for a global value in a module.
#[pyclass]
#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq, Hash)]
pub(crate) enum Linkage {
#[pyo3(name = "APPENDING")]
Appending,
Expand All @@ -315,6 +317,26 @@ pub(crate) enum Linkage {
WeakOdr,
}

#[pymethods]
#[allow(clippy::trivially_copy_pass_by_ref)]
impl Linkage {
// In order to implement the comparison operators, we have to do
// it all in one impl of __richcmp__ for pyo3 to work.
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
match op {
CompareOp::Eq => self.eq(other).into_py(py),
CompareOp::Ne => (!self.eq(other)).into_py(py),
_ => py.NotImplemented(),
}
}

fn __hash__(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}

impl From<Linkage> for LLVMLinkage {
fn from(linkage: Linkage) -> Self {
match linkage {
Expand Down
Loading

0 comments on commit de2c8ed

Please sign in to comment.