In [15]:
import json
import decision_rules as dr
from decision_rules.serialization import JSONSerializer
from decision_rules.classification import ClassificationRuleSet
from decision_rules import measures
import pandas as pd

# Training data: the 'class' column is the target, every other column is a feature.
df = pd.read_csv('./train.csv')
y = df['class']
X = df.drop(columns='class')

# The file handle is only needed while the JSON document is being parsed;
# deserialization into a ClassificationRuleSet works on the parsed object.
with open('./deeprules.json') as f:
    ruleset_json = json.load(f)
ruleset: ClassificationRuleSet = JSONSerializer().deserialize(
    ruleset_json, ClassificationRuleSet
)

# Call `update` with the training data and the `c2` measure; its return value
# is deliberately discarded so the cell produces no rich output.
_ = ruleset.update(X, y, measure=measures.c2)


  self._operator_func = lambda A, B: A > B


In [16]:
from decision_rules.similarity import calculate_rule_similarity, SimilarityType, SimilarityMeasure

# Compare the rule set against itself using the SYNTACTIC similarity type.
# NOTE(review): the recorded output of this cell is
# "AttributeError: 'CompoundCondition' object has no attribute 'column_index'",
# i.e. this call currently fails for this rule set.
calculate_rule_similarity(
    ruleset, ruleset, df, similarity_type=SimilarityType.SYNTACTIC
)

AttributeError: 'CompoundCondition' object has no attribute 'column_index'

In [None]:
from __future__ import annotations
from collections import defaultdict
from itertools import product

import numpy as np
import pandas as pd
from decision_rules.conditions import CompoundCondition, LogicOperators
from decision_rules.conditions import ElementaryCondition
from decision_rules.conditions import NominalCondition
from decision_rules.core.ruleset import AbstractRuleSet
from decision_rules.core.simplifier import RulesetSimplifier
from dataclasses import dataclass
from typing import Optional


@dataclass
class Interval:
    """
    A numeric interval with independently open or closed endpoints.

    Attributes:
        left: lower bound (may be ``float("-inf")``).
        right: upper bound (may be ``float("inf")``).
        left_closed: whether ``left`` itself belongs to the interval.
        right_closed: whether ``right`` itself belongs to the interval.
    """

    left: float
    right: float
    left_closed: bool
    right_closed: bool

    @classmethod
    def from_elementary_condition(cls, condition: ElementaryCondition) -> Interval:
        """
        Builds an interval from an object exposing ``left``, ``right``,
        ``left_closed`` and ``right_closed``.

        A ``None`` bound is treated as unbounded on that side.
        """
        return cls(
            left=float("-inf") if condition.left is None else condition.left,
            right=float("inf") if condition.right is None else condition.right,
            left_closed=condition.left_closed,
            right_closed=condition.right_closed,
        )

    def __eq__(self, other):
        if not isinstance(other, Interval):
            return NotImplemented
        return (
            self.left == other.left
            and self.right == other.right
            and self.left_closed == other.left_closed
            and self.right_closed == other.right_closed
        )

    def is_empty(self) -> bool:
        """Checks if the interval contains no points."""
        if self.left > self.right:
            return True
        if self.left == self.right:
            # A single point survives only if both endpoints include it.
            return not (self.left_closed and self.right_closed)
        return False

    def calculate_intersection(self, other: Interval) -> Optional[Interval]:
        """
        Calculates the intersection of two intervals.

        Returns:
            A new Interval, or None if the intersection is empty
            (e.g. ``[1, 2)`` and ``[2, 3)``).
        """
        if self.is_empty() or other.is_empty():
            return None

        # Left boundary: the larger of the two lefts wins. When both lefts
        # coincide, the point is included only if both intervals include it.
        if self.left > other.left:
            new_left, new_left_closed = self.left, self.left_closed
        elif self.left < other.left:
            new_left, new_left_closed = other.left, other.left_closed
        else:
            new_left = self.left
            new_left_closed = self.left_closed and other.left_closed

        # Right boundary: the smaller of the two rights wins, same tie rule.
        if self.right < other.right:
            new_right, new_right_closed = self.right, self.right_closed
        elif self.right > other.right:
            new_right, new_right_closed = other.right, other.right_closed
        else:
            new_right = self.right
            new_right_closed = self.right_closed and other.right_closed

        intersection = Interval(
            new_left, new_right, new_left_closed, new_right_closed
        )
        return None if intersection.is_empty() else intersection

    def add(self, other: Interval) -> list[Interval]:
        """
        Adds two intervals, merging them if they overlap or touch.

        Two intervals touch when one ends exactly where the other starts and
        at least one of them includes that point (``[1, 2)`` and ``[2, 3)``
        merge into ``[1, 3)``; ``[1, 2)`` and ``(2, 3]`` do not merge).

        Returns:
            A list with one merged interval, or with both intervals ordered by
            their left boundary when they cannot be merged.
        """
        # Order by left boundary; on a tie `self` stays first.
        i1, i2 = (self, other) if self.left <= other.left else (other, self)

        overlapping = i1.right > i2.left
        touching = i1.right == i2.left and (i1.right_closed or i2.left_closed)
        if not (overlapping or touching):
            return [i1, i2]

        new_left = i1.left
        new_left_closed = i1.left_closed
        if i1.left == i2.left:
            # Both start at the same point: the union includes it if either does.
            new_left_closed = i1.left_closed or i2.left_closed

        if i1.right > i2.right:
            # i1 extends further right than i2.
            new_right, new_right_closed = i1.right, i1.right_closed
        elif i1.right < i2.right:
            new_right, new_right_closed = i2.right, i2.right_closed
        else:
            # Both end at the same point: the union includes it if either does.
            new_right = i1.right
            new_right_closed = i1.right_closed or i2.right_closed

        return [Interval(new_left, new_right, new_left_closed, new_right_closed)]


def add_multiple_intervals(intervals: list[Interval]) -> list[Interval]:
    """
    Adds (merges) a list of intervals, returning a minimal set of non-overlapping intervals.

    The input list is not modified.

    Args:
        intervals: intervals to merge, in any order. Each element must expose
            ``left`` and an ``add(other)`` method that returns a one-element
            list when the two intervals were merged.

    Returns:
        The merged intervals ordered by their left boundary; an empty list
        for empty input.
    """
    if not intervals:
        return []

    # Sort a copy so the caller's list keeps its original order.
    ordered = sorted(intervals, key=lambda interval: interval.left)

    merged_intervals: list[Interval] = []
    current_interval = ordered[0]

    for next_interval in ordered[1:]:
        merged_result = current_interval.add(next_interval)
        if len(merged_result) == 1:
            # The two merged: keep extending the running interval.
            current_interval = merged_result[0]
        else:
            # Gap between them: the running interval is final, start a new one.
            merged_intervals.append(current_interval)
            current_interval = next_interval

    # The last running interval has not been emitted yet.
    merged_intervals.append(current_interval)
    return merged_intervals

class AggregatedNumericalConditions:
    """
    Placeholder class: no attributes or behaviour are defined yet.

    NOTE(review): the original statement had no colon and no body, which made
    the whole cell a syntax error. The intended design (presumably grouping
    numerical conditions per attribute) is not shown here - confirm before
    implementing.
    """

class SyntacticRuleSimilarityCalculator:
    """
    Calculator of syntactic rule similarity.
    Caveat: the assumption is that the conditions in rules are connected only with conjunction operators.

    Each rule is reduced to a dict with two entries:
    ``"elementary"`` maps an attribute name to a ``(left, right)`` tuple and
    ``"nominal"`` maps an attribute name to a list of values.
    """

    def __init__(
        self,
        ruleset1: AbstractRuleSet,
        ruleset2: AbstractRuleSet,
        dataset: pd.DataFrame,
    ):
        """
        Args:
            ruleset1: first rule set; its rules become the rows of the result.
            ruleset2: second rule set; its rules become the columns of the result.
            dataset: data used to replace infinite interval bounds with the
                observed column minimum / maximum.
        """
        # The dataset must be stored first: parsing calls _evaluate_boundary,
        # which reads self.dataset.
        self.dataset = dataset
        self.ruleset1 = self._parse_rules_to_conditions(ruleset1)
        self.ruleset2 = self._parse_rules_to_conditions(ruleset2)

    def calculate(self) -> np.ndarray:
        """
        Computes the similarity of every rule pair.

        Returns:
            Array of shape ``(len(ruleset1), len(ruleset2))`` where entry
            ``[i, j]`` is the similarity of rule ``i`` of the first rule set
            and rule ``j`` of the second.
        """
        # product() yields pairs in row-major order, which matches the reshape.
        similarities = [
            self._calculate_rule_sim(rule1, rule2)
            for rule1, rule2 in product(self.ruleset1, self.ruleset2)
        ]
        return np.array(similarities, dtype=float).reshape(
            len(self.ruleset1), len(self.ruleset2)
        )

    def _calculate_rule_sim(self, rule1: dict, rule2: dict) -> float:
        """
        Similarity of two parsed rules: the summed per-attribute similarities
        divided by the total number of parsed conditions in both rules.

        Returns 0.0 when neither rule has any parsed condition.
        """
        denominator = (
            len(rule1["elementary"])
            + len(rule2["elementary"])
            + len(rule1["nominal"])
            + len(rule2["nominal"])
        )
        if denominator == 0:
            # Nothing to compare; avoid dividing by zero.
            return 0.0
        elem_sum = self._calculate_elementary_condition_sim_sum(
            rule1["elementary"], rule2["elementary"]
        )
        nomin_sum = self._calculate_nominal_condition_sim_sum(
            rule1["nominal"], rule2["nominal"]
        )
        return (elem_sum + nomin_sum) / denominator

    def _parse_rules_to_conditions(self, ruleset: AbstractRuleSet) -> list[dict]:
        """
        Converts every rule of ``ruleset`` into the dict form used by this class.

        Raises:
            NotImplementedError: if a condition is neither an
                ElementaryCondition nor a NominalCondition (this includes
                compound conditions nested inside the premise).
        """
        rules = []
        for rule in ruleset.rules:
            rule_conditions = defaultdict(dict)
            all_conditions = (
                rule.premise.subconditions
                if isinstance(rule.premise, CompoundCondition)
                else [rule.premise]
            )
            for condition in all_conditions:
                if isinstance(condition, ElementaryCondition):
                    key = ruleset.column_names[condition.column_index]
                    left = self._evaluate_boundary(condition.left, key)
                    right = self._evaluate_boundary(condition.right, key)
                    # NOTE(review): a later elementary condition on the same
                    # attribute replaces the earlier one instead of being
                    # intersected with it.
                    rule_conditions["elementary"][key] = left, right
                elif isinstance(condition, NominalCondition):
                    key = ruleset.column_names[condition.column_index]
                    # Nominal values for the same attribute are accumulated.
                    rule_conditions["nominal"][key] = rule_conditions["nominal"].get(
                        key, []
                    ) + [condition.value]
                else:
                    raise NotImplementedError(
                        "Only elementary and nominal conditions are supported"
                    )
            rules.append(rule_conditions)
        return rules

    def _evaluate_boundary(self, boundary: float, column_name: str) -> float:
        """
        Replaces an infinite bound of a one-sided condition with the actual
        max (for +inf) or min (for -inf) of the column in the dataset;
        finite bounds are returned unchanged.
        """
        if boundary == float("inf"):
            return self.dataset[column_name].max()
        if boundary == float("-inf"):
            return self.dataset[column_name].min()
        return boundary

    def _calculate_elementary_condition_sim_sum(
        self, elementary_conditions1, elementary_conditions2
    ) -> float:
        """
        Sums, over attributes present in both rules, the overlap length
        divided by each interval's own length (two terms per attribute).

        An interval of non-positive length contributes nothing, which avoids
        a division by zero.
        """
        keys = set(elementary_conditions1) & set(elementary_conditions2)
        sim_sum = 0.0
        for key in keys:
            interval1 = elementary_conditions1[key]
            interval2 = elementary_conditions2[key]
            overlap = self._calculate_overlap(interval1, interval2)
            for left, right in (interval1, interval2):
                width = right - left
                if width > 0:
                    sim_sum += overlap / width
        return sim_sum

    @staticmethod
    def _calculate_nominal_condition_sim_sum(
        nominal_conditions1, nominal_conditions2
    ) -> float:
        """
        Sums, over attributes present in both rules, the number of shared
        values divided by each rule's own number of distinct values.
        """
        keys = set(nominal_conditions1) & set(nominal_conditions2)
        sim_sum = 0.0
        for key in keys:
            set_ij = set(nominal_conditions1[key])
            set_ji = set(nominal_conditions2[key])
            overlap = set_ij & set_ji
            sim_sum += len(overlap) / len(set_ij)
            sim_sum += len(overlap) / len(set_ji)
        return sim_sum

    @staticmethod
    def _calculate_overlap(
        interval1: tuple[float, float], interval2: tuple[float, float]
    ) -> float:
        """Length of the overlap of two ``(left, right)`` intervals; 0.0 if disjoint."""
        left_bound = max(interval1[0], interval2[0])
        right_bound = min(interval1[1], interval2[1])
        return max(0.0, right_bound - left_bound)