In [1]:
# Regenerate the Python3 parser sources (SCLangParser, plus a visitor base via -visitor) from the SCLang.g4 grammar.
! antlr4 SCLang.g4 -Dlanguage=Python3 -visitor

In [2]:
import os
import sys
from abc import abstractmethod
from typing import Generic, Optional, Union, TypeVar
from loader import load_yaml, load_script
from utils import Ref, RefDict
# NOTE(review): the stdlib `types` module is already in sys.modules when the kernel starts,
# so this star-import binds stdlib names, not a local types.py. Project enums such as
# LogicOperator / Comparator / VisibilityStat need to live in a differently named module.
from types import *
from itertools import chain, combinations, product
from operator import attrgetter


if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

import logging
logger = logging.getLogger(__name__)
    

T_OP = TypeVar('T_OP')
T_CHILD = TypeVar('T_CHILD')
class Expr(Generic[T_OP, T_CHILD]):
    def __init__(self, operator: T_OP, children: Union[T_CHILD, list[T_CHILD]]):
        self._operator = operator
        self._children = [children] if isinstance(children, T_CHILD) else list(children)

    def __str__(self):
        return str(self._operator).join(map(str, self.children))
    
    def __eq__(self, other):
        return (
            self._operator == other._operator and
            all(c1 == c2 for c1, c2 in zip(self._children, other._children))
        )

    def __len__(self):
        return len(self._children)

    def unique(self) -> Self:
        c_ = sorted([c_.unique() for c_ in self._children], key=attrgetter('_sort_on'))
        return self.__class__(operator=self._operator, children=c_)

    @abstractmethod
    def refactor(self, operator: Optional[T_OP]=None) -> Self:
        raise NotImplementedError()

    @property
    def _sort_on(self):
        raise NotImplementedError()

    # @abstractmethod
    # def contradicts(self, other: Self) -> bool:
    #     raise NotImplementedError()

    # @abstractmethod
    # def complements(self, other: Self) -> bool:
    #     raise NotImplementedError()


class LogicExpr(Expr):
    """Boolean expression node: ¬ (a ``LogicModifier``) or ∧ / ∨ (a ``LogicOperator``)
    applied to a list of sub-expressions stored in ``_children``."""

    def __init__(self, operator: LogicOpType, children: Union['LogicExpr', list['LogicExpr']]):
        """Store the operator and normalize ``children`` into a fresh list.

        A single ``Expr`` (a ``LogicExpr`` or an ``Attrib`` leaf) is wrapped in a
        one-element list; any other iterable is copied into a new list.
        """
        # The annotations above are quoted because the class name is not bound yet
        # while its own body executes (an unquoted name raises NameError).
        self._operator = operator
        self._children = [children] if isinstance(children, Expr) else list(children)

    def __str__(self) -> str:
        """Render with ¬ / ∧ / ∨; nested ∧/∨ groups are parenthesized."""
        symbols = {
            LogicModifier.NOT: '¬',
            LogicOperator.AND: '∧',
            LogicOperator.OR: '∨',
        }
        op_ = self._operator
        if isinstance(op_, LogicModifier):
            return symbols.get(op_, str(op_)) + f"({self._children[0]})"

        def render(child) -> str:
            text = str(child)
            # Without parentheses (a∨b)∧c would print as the ambiguous a∨b∧c.
            if isinstance(getattr(child, '_operator', None), LogicOperator):
                return f"({text})"
            return text

        # Unknown operators (e.g. a Comparator on a subclass) fall back to str(op).
        return symbols.get(op_, str(op_)).join(map(render, self._children))

    def unique(self) -> Self:
        """Return a simplified copy with every child made unique recursively.

        Children that simplify to an empty expression are dropped. An OR with two
        equal children, or an AND with two children that contradict each other,
        collapses to an empty expression.
        """
        kept = [s_ for c in self._children if len(s_ := c.unique())]

        if self._operator == LogicOperator.OR:
            # NOTE(review): OR(A, A) collapsing to empty looks suspicious (A ∨ A = A); confirm intent.
            if any(c1 == c2 for c1, c2 in combinations(kept, 2)):
                kept = []
        elif self._operator == LogicOperator.AND:
            # NOTE(review): `contradicts` is not implemented on any class yet (see the commented-out draft below).
            if any(c1.contradicts(c2) for c1, c2 in combinations(kept, 2)):
                kept = []

        return LogicExpr(self._operator, kept)

    # def contradicts(self, other: LogicExpr) -> bool:
    #     g_ = c1.contradicts(c2) for c1, c2 in product(self._children, other._children)
    #     if self._operator == LogicOperator.OR:
    #         return not all(g_)
    #     elif self._operator == LogicOperator.AND:
    #         return any(g_)
    #     elif self._operator == LogicModifier.NOT:
    #         return self.demorgan().contradicts(other)
    #     else:
    #         return False

    # def complements(self, other: LogicExpr) -> bool:
    #     if self._operator == LogicOperator.OR:
    #         return (not self.contradicts(LogicExpr(operator=LogicModifier.NOT, other)) and

    #     return all(any(c1.complements(c2) for c2 in other._children) for c1 in self._children)

    def refactor(self, target_operator: LogicOpType) -> Self:
        """Rewrite this expression so its top-level operator is ``target_operator``.

        Negations are moved with ``demorgan``. A node that already uses the target
        operator is flattened by splicing in its rewritten children. Leaves (and ¬leaf)
        are wrapped in a single-child node.
        """
        op_ = target_operator
        if op_ == LogicModifier.NOT:
            if self._operator == LogicModifier.NOT:
                return self  # already negation-rooted
            pushed = self.demorgan()
            if pushed._operator == LogicModifier.NOT:
                return pushed  # A∧B -> ¬(¬A∨¬B), A∨B -> ¬(¬A∧¬B)
            # No De Morgan dual exists (leaf): ¬¬X keeps the meaning with ¬ at the root.
            return LogicExpr(LogicModifier.NOT, [LogicExpr(LogicModifier.NOT, [self])])

        if self._operator == LogicModifier.NOT:
            pushed = self.demorgan()
            if pushed is self:
                # ¬X with nothing to push into: a single-child op_ node is equivalent
                # (recursing here would never terminate).
                return LogicExpr(op_, [self])
            return pushed.refactor(target_operator=op_)

        if self._operator == op_:
            # Same operator: flatten. Non-logic children (e.g. Attrib) are kept as atoms.
            new_children = chain.from_iterable(
                c.refactor(target_operator=op_)._children if isinstance(c, LogicExpr) else [c]
                for c in self._children
            )
        elif isinstance(self._operator, LogicOperator):
            # NOTE(review): pairing children with combinations() does not preserve meaning
            # (A∧B∧C is not (A∧B)∨(A∧C)∨(B∧C)); kept as-is pending a real CNF/DNF distribution step.
            new_children = (LogicExpr(self._operator, [c1, c2]) for c1, c2 in combinations(self._children, 2))
        else:
            # Leaf such as a Comparison: a single-child op_ node is equivalent.
            new_children = [self]
        return LogicExpr(op_, list(new_children))

    def disjunctive(self) -> Self:
        """Rewrite with ∨ at the root (see ``refactor``)."""
        return self.refactor(LogicOperator.OR)

    def conjunctive(self) -> Self:
        """Rewrite with ∧ at the root (see ``refactor``)."""
        return self.refactor(LogicOperator.AND)

    def demorgan(self) -> Self:
        """Apply De Morgan's laws once at the root.

        ¬(A ∧ B) -> ¬A ∨ ¬B and ¬(A ∨ B) -> ¬A ∧ ¬B push a negation inward.
        A ∧ B -> ¬(¬A ∨ ¬B) and A ∨ B -> ¬(¬A ∧ ¬B) pull one outward.
        Anything else (e.g. ¬ of a leaf) is returned unchanged, as the same object.
        """
        duals = {LogicOperator.AND: LogicOperator.OR, LogicOperator.OR: LogicOperator.AND}

        def negate_all(exprs):
            return [LogicExpr(LogicModifier.NOT, [e]) for e in exprs]

        if self._operator == LogicModifier.NOT:
            inner = self._children[0]
            # Leaves such as Attrib may not carry an _operator at all.
            inner_op = getattr(inner, '_operator', None)
            if inner_op in duals:
                return LogicExpr(duals[inner_op], negate_all(inner._children))
            return self

        if self._operator in duals:
            return LogicExpr(LogicModifier.NOT, [LogicExpr(duals[self._operator], negate_all(self._children))])

        # Not a pattern De Morgan's laws apply to: return as is.
        return self


class Attrib(Expr):
    """Leaf expression: ``method`` (with optional ``args``) applied to an entity reference.

    ``modifier`` is an optional ``AttribModifier`` (the SCL visitor parses it from
    ``!`` / ``$``).
    """

    def __init__(self, entity: 'Ref[Entity]', method: str, args: Optional[list[LogicExpr]]=None, modifier:Optional[AttribModifier]=None):
        # 'Ref[Entity]' is quoted because Entity is defined further down; an unquoted
        # annotation is evaluated here and would raise NameError.
        self._entity = entity
        self._method = method
        self._args = [] if args is None else args
        self._modifier = modifier

    def __str__(self) -> str:
        """Human-readable form ``<mod>entity.method(args)``.

        This overrides Expr.__str__, which reads _operator/_children that Attrib never sets.
        """
        prefix = {AttribModifier.LIVE: '!', AttribModifier.PREV: '$'}.get(self._modifier, '')
        return f"{prefix}{self._entity}.{self._method}({', '.join(map(str, self._args))})"

    def __eq__(self, other) -> bool:
        """Field-wise equality (entity, method, args, modifier).

        This overrides Expr.__eq__ for the same reason as __str__.
        """
        if not isinstance(other, Attrib):
            return NotImplemented
        return ((self._entity, self._method, self._args, self._modifier) ==
                (other._entity, other._method, other._args, other._modifier))

    def _simplify(self) -> Self:
        """A single attribute is already in simplest form."""
        return self


class Comparison(LogicExpr):
    """Binary comparison ``attrib1 <op> attrib2``.

    The comparison is stored in LogicExpr form: ``_children == [attrib1, attrib2]`` and
    ``_operator`` is the ``Comparator``.
    """

    def __init__(self, attrib1: Attrib, attrib2: Attrib, compare_op:Comparator):
        self._children = [attrib1, attrib2]
        self._operator = compare_op

    # Named views onto the LogicExpr storage (these were read by _simplify but never set).
    @property
    def _attrib1(self) -> Attrib:
        return self._children[0]

    @property
    def _attrib2(self) -> Attrib:
        return self._children[1]

    @property
    def _compare_op(self) -> Comparator:
        return self._operator

    def __str__(self) -> str:
        """Render as ``a <op> b`` using the same symbols the SCL visitor parses."""
        symbols = {
            Comparator.EQUAL: '==',
            Comparator.NOT_EQUAL: '!=',
            Comparator.LESS: '<',
            Comparator.LESS_EQUAL: '<=',
            Comparator.GREATER: '>',
            Comparator.GREATER_EQUAL: '>=',
        }
        return f"{self._attrib1} {symbols.get(self._compare_op, self._compare_op)} {self._attrib2}"

    def _simplify(self) -> Self:
        """Return an equivalent comparison in canonical orientation.

        ``>`` / ``>=`` become ``<`` / ``<=`` with the operands swapped. For the symmetric
        ``==`` / ``!=``, operands are ordered by their string form. Anything else is
        returned as a fresh, unchanged copy.
        """
        flipped = {
            Comparator.GREATER: Comparator.LESS,
            Comparator.GREATER_EQUAL: Comparator.LESS_EQUAL,
        }
        a1, a2, op = self._attrib1, self._attrib2, self._compare_op
        if op in flipped:
            return Comparison(a2, a1, flipped[op])
        if op in (Comparator.EQUAL, Comparator.NOT_EQUAL) and str(a1) > str(a2):
            return Comparison(a2, a1, op)
        return Comparison(a1, a2, op)


class Entity:
    """An entity of a referenced class, optionally constrained by a logic condition."""

    def __init__(self, ent_class: Ref[type], conditions: Optional[LogicExpr]=None):
        # Keep the class reference and the (possibly absent) condition as given.
        self._ent_class, self._conditions = ent_class, conditions


class Skill: # TODO
    """Registry record tying a skill class to an optional name."""

    def __init__(self, cls, ent_mgr=None, name: Optional[str]=None):
        """
        :param cls: the skill class being registered.
        :param ent_mgr: stored but not used yet. It now defaults to None so callers that
            only pass ``cls`` (and a keyword ``name``) work.
        :param name: registered name, or None.
        """
        self._cls = cls
        # NOTE(review): presumably an entity manager; nothing reads it yet.
        self._ent_mgr = ent_mgr
        self._name = name

    @property
    def cls(self):
        """The wrapped skill class."""
        return self._cls

    @property
    def name(self) -> Optional[str]:
        """The registered name, or None if none was given."""
        return self._name


class SkillManager:
    """Registry of ``Skill`` records keyed by name."""

    def __init__(self):
        # name -> Skill
        self._skills = {}

    @property
    def skills(self) -> dict:
        """A shallow copy of the registered skills, keyed by name."""
        return dict(self._skills)

    def wrap(self, name: Optional[str]=None):
        """Return a class decorator that registers the decorated class as a skill.

        The skill is registered under ``name``, or under the class's ``__name__`` if
        ``name`` is falsy. The class itself is returned unchanged.
        """

        def decorator(cls):
            # Positional (cls, ent_mgr, name) matches Skill's constructor.
            self.register_skill(Skill(cls, None, name or cls.__name__))
            return cls

        return decorator

    def register_skill(self, skill: Skill) -> None:
        """Add ``skill`` under its name, falling back to its class name.

        If the name is already taken, a warning is logged and the entry is overwritten.
        """
        name_ = skill.name or skill.cls.__name__
        if name_ in self._skills:
            logger.warning(f"\"{name_}\" already registered to skill manager. If intended, you can ignore this warning.")
        self._skills[name_] = skill

    def validate(self):
        """Placeholder: no validation is performed yet."""
        pass


class Intent:
    """An action ``name`` performed by an ``actor`` entity with optional arguments.

    The SCL visitor attaches visibility, a condition, an effect and metadata to it.
    """

    def __init__(self, name:str, actor: Ref[Entity],
                 args: Optional[list[LogicExpr]]=None,
                 visibility: VisibilityStat = VisibilityStat.PUBLIC):
        self._name = name
        self._visibility = visibility
        self._actor = actor
        # Normalize a missing argument list to [] (same convention as Attrib).
        self._args = [] if args is None else args
        self._condition = None
        self._effect = None
        self._metadata = None

    def set_visibility(self, visibility: VisibilityStat):
        self._visibility = visibility

    def set_condition(self, condition: Condition):
        self._condition = condition

    def set_effect(self, effect: Condition):
        self._effect = effect

    def set_metadata(self, metadata: MetaDataType):
        self._metadata = metadata

    @property
    def name(self) -> str:
        """The action/method name."""
        return self._name

    @property
    def args(self) -> list[LogicExpr]:
        """Arguments of the action ([] when none were given)."""
        return self._args

    @property
    def visibility(self) -> VisibilityStat:
        return self._visibility

    @property
    def actor(self) -> Ref[Entity]:
        return self._actor

    @property
    def condition(self) -> Condition:
        return self._condition

    @property
    def effect(self) -> Condition:
        return self._effect

    @property
    def metadata(self) -> MetaDataType:
        return self._metadata

ImportError: cannot import name 'RefManager' from 'utils' (/home/shervin/Desktop/VA Project/utils.py)

In [37]:
class App:
    """Application shell built around a ``SkillManager``."""

    def __init__(self, mgr: SkillManager):
        # Skill registry this app works with.
        self._mgr = mgr

    def init(self):
        """Initialization hook; currently does nothing."""
        return None

In [10]:
# Skill registry instance used by the registration cell below.
mgr = SkillManager()

In [11]:
# NOTE(review): this cell fails as recorded below. `register_skill` takes a single `Skill`
# instance, not two file paths, so calling it like this raises TypeError.
# If it did succeed, it would also rebind the name `Skill` and shadow the `Skill` record
# class that SkillManager constructs. Decide on the intended registration API before re-enabling.
@mgr.register_skill('definitions.yaml', 'conditions.scl')
class Skill:
    pass

TypeError: SkillManager.register_skill() takes 2 positional arguments but 3 were given

In [5]:
# Load the definitions file and the SCL condition script (paths are relative to the working directory).
# NOTE(review): the name `yaml` shadows the usual PyYAML module name; consider a more specific variable name.
yaml = load_yaml("definitions.yaml")
script = load_script("conditions.scl")

In [None]:
class VectorDatabase:
    """Placeholder for a vector-store backend (not implemented yet)."""


class GraphDatabase:
    """Placeholder for a graph-store backend (not implemented yet)."""


class DBSchemaBase:
    """Placeholder base for database schemas (not implemented yet)."""

In [12]:
from typing import Optional

class DBObject:  # TODO(design): decorator with schema as arg?
    """Base for objects that can be pushed to a database (not implemented yet)."""

    def push_db(self):
        """Persist this object; currently a no-op."""
        return None

NameError: name 'ConditionGroup' is not defined

In [2]:
from antlr4 import ParseTreeVisitor
from SCLangParser import SCLangParser
from utils import RefDict


class SCLangVisitor(ParseTreeVisitor):
    """Turn an SCLang parse tree into the SCL object model.

    It builds ``Intent`` objects (with condition, effect and metadata), ``LogicExpr`` and
    ``Comparison`` trees, and ``Attrib`` / ``Entity`` leaves. Each ``visitX`` handles
    grammar rule ``x``. The rule shapes assumed here are inferred from the context
    accessors used, because SCLang.g4 is not shown.
    """

    def __init__(self):
        super().__init__()
        # Name -> Ref table of entities and entity types declared in the current
        # definition. It is reset at the start of every definition in visitDefinition.
        self._ent_dict = RefDict()

    @staticmethod
    def _as_list(nodes) -> list:
        """Normalize an ANTLR child-accessor result to a list.

        Generated accessors may return None (child absent), one context (rule used
        once) or a list (rule repeated). Callers that iterate always get a list.
        """
        if nodes is None:
            return []
        return nodes if isinstance(nodes, list) else [nodes]

    def visitScript(self, ctx:SCLangParser.ScriptContext):
        # Default aggregation: visit every child and return the last child's result.
        return self.visitChildren(ctx)


    def visitDefinition(self, ctx:SCLangParser.DefinitionContext) -> Intent:
        """Build one Intent from an action block plus visibility, with optional
        condition, effect and metadata blocks."""
        # Entity names are scoped to a single definition.
        self._ent_dict = RefDict()

        # Visit order matters: the condition block is visited first, so entities it
        # declares are already in _ent_dict when the action and effect blocks are visited.
        condition_ = None
        if (c_:=ctx.conditionBlock()):
            condition_ = self.visit(c_)
        visib_ = self.visit(ctx.visibility())
        intent_ = self.visit(ctx.actionBlock())
        effect_ = None
        if (c_:=ctx.effectBlock()):
            effect_ = self.visit(c_)
        meta_ = None
        if (c_:=ctx.metaData()):
            meta_ = self.visit(c_)

        intent_.set_condition(condition_)
        intent_.set_visibility(visib_)
        intent_.set_effect(effect_)
        intent_.set_metadata(meta_)

        return intent_


    def visitConditionBlock(self, ctx:SCLangParser.ConditionBlockContext) -> LogicExpr:
        return self.visit(ctx.logicExpr())


    def visitVisibility(self, ctx:SCLangParser.VisibilityContext) -> VisibilityStat:
        # Member names follow Intent's default (VisibilityStat.PUBLIC) and the
        # UPPER_CASE style of the other enums used in this file.
        d_ = {
            'pub': VisibilityStat.PUBLIC,
            'priv': VisibilityStat.PRIVATE,
        }
        return d_[ctx.getText()]


    def visitActionBlock(self, ctx:SCLangParser.ActionBlockContext) -> Intent:
        """Create the Intent: the actor entity performs ``method`` with optional args."""
        ent_ = self.visit(ctx.entityGroup())
        method_ = self.visit(ctx.method())
        args_ = None
        if (c_:=ctx.argsWrapper()) is not None:
            args_ = self.visit(c_)
        # Intent's constructor names the entity parameter `actor`.
        return Intent(name=method_, actor=ent_, args=args_)


    def visitEffectBlock(self, ctx:SCLangParser.EffectBlockContext) -> LogicExpr:
        return self.visit(ctx.logicExpr())


    def visitMetaData(self, ctx:SCLangParser.MetaDataContext) -> dict[str, ValueType]:
        """Collect the (key, value) pairs of the optional meta block into a dict."""
        meta_ = {}
        if (c_:=ctx.metaBlock()):
            meta_.update(self.visit(c_))
        return meta_


    def visitLogicExpr(self, ctx:SCLangParser.LogicExprContext) -> LogicExpr:
        """A logic term, optionally wrapped in a unary modifier (``~`` -> NOT)."""
        term_ = self.visit(ctx.logicTerm())
        if not (c_:=ctx.logicMod()):
            return term_
        d_ = {
            '~': LogicModifier.NOT,
        }
        return LogicExpr(d_[self.visit(c_)], [term_])


    def visitLogicTerm(self, ctx:SCLangParser.LogicTermContext) -> Condition:
        """Either a parenthesized group, or an attribute with an optional comparison."""
        if (c_:=ctx.logicGroup()):
            return self.visit(c_)
        if (c_:=ctx.attrib()):
            attr1_ = self.visit(c_)
            # NOTE(review): assumes `comparison` is a sibling of `attrib` in the logicTerm
            # rule. visitAttrib never handles a comparison child. Confirm against SCLang.g4.
            if (cc_:=ctx.comparison()):
                cmp_op_, attr2_ = self.visit(cc_)
                return Comparison(attr1_, attr2_, cmp_op_)
            return attr1_
        return None


    def visitLogicGroup(self, ctx:SCLangParser.LogicGroupContext) -> Condition:
        return self.visit(ctx.logicOr())


    def visitLogicOr(self, ctx:SCLangParser.LogicOrContext) -> Condition:
        """OR over one or more logicAnd operands; a single operand is returned as-is."""
        parts_ = [self.visit(c) for c in self._as_list(ctx.logicAnd())]
        return parts_[0] if len(parts_) == 1 else LogicExpr(LogicOperator.OR, parts_)


    def visitLogicAnd(self, ctx:SCLangParser.LogicAndContext) -> Condition:
        """AND over one or more logicExpr operands; a single operand is returned as-is."""
        parts_ = [self.visit(c) for c in self._as_list(ctx.logicExpr())]
        return parts_[0] if len(parts_) == 1 else LogicExpr(LogicOperator.AND, parts_)


    def visitLogicMod(self, ctx:SCLangParser.LogicModContext) -> str:
        return ctx.getText()


    def visitComparison(self, ctx:SCLangParser.ComparisonContext) -> tuple[Comparator, Attrib]:
        """The right-hand side of a comparison: (operator, attribute)."""
        return (self.visit(ctx.compareOp()),
                self.visit(ctx.attrib()))


    def visitCompareOp(self, ctx:SCLangParser.CompareOpContext) -> Comparator:
        d_ = {
            '!=': Comparator.NOT_EQUAL,
            '<': Comparator.LESS,
            '<=': Comparator.LESS_EQUAL,
            '>': Comparator.GREATER,
            '>=': Comparator.GREATER_EQUAL,
            '==': Comparator.EQUAL,
        }
        return d_[ctx.getText()]


    def visitEntityRef(self, ctx:SCLangParser.EntityRefContext) -> Ref[Entity]:
        return self.visit(ctx.getChild(0))


    def visitEntityDef(self, ctx:SCLangParser.EntityDefContext) -> Ref[Entity]:
        """Declare a named entity (with an optional condition) in the current definition."""
        class_ref_, ent_name_ = self.visit(ctx.entityDecl())

        condition_ = None
        if (c_:=ctx.logicExpr()) is not None:
            condition_ = self.visit(c_)

        # Entity's constructor names this parameter `conditions`.
        ent_ = Entity(ent_class=class_ref_, conditions=condition_)
        self._ent_dict[ent_name_] = ent_

        # Read back through the RefDict rather than returning ent_ directly.
        # NOTE(review): presumably RefDict hands back a Ref wrapper; confirm in utils.
        return self._ent_dict[ent_name_]


    def visitEntityDecl(self, ctx:SCLangParser.EntityDeclContext) -> tuple[Ref[type], str]:
        """(type reference, entity name) of a declaration."""
        return (self.visit(ctx.entityType()),
                self.visit(ctx.entity()))


    def visitEntityType(self, ctx:SCLangParser.EntityTypeContext) -> Ref[type]:
        return self.visit(ctx.getChild(0))


    def visitEntityClass(self, ctx:SCLangParser.EntityClassContext) -> Ref[type]:
        # NOTE(review): presumably child index 2 holds the class name; confirm against SCLang.g4.
        t_ = ctx.getChild(2).getText()
        # NOTE(review): `d_` (a class-name -> class registry) and `ConstRef` are not
        # defined anywhere in this notebook, so this line raises NameError as written.
        ref_: ConstRef[type] = ConstRef(val=d_[t_])
        self._ent_dict[t_] = ref_
        return ref_


    def visitEntityVar(self, ctx:SCLangParser.EntityVarContext) -> Ref[ValueType]:
        """Map an SCL data-type name to the matching Python type reference."""
        d_ = {
            'Bool': bool,
            'Float': float,
            'Int': int,
            'String': str,
        }
        t_ = self.visit(ctx.dataType())
        # NOTE(review): `ConstRef` is not imported in this notebook; confirm where it lives.
        ref_: ConstRef[type] = ConstRef(val=d_[t_])
        # Stored in the same table as visitEntityClass uses (the old `_ent_type` never existed).
        self._ent_dict[t_] = ref_
        return ref_


    def visitDataType(self, ctx:SCLangParser.DataTypeContext) -> str:
        return ctx.getText()


    def visitAttrib(self, ctx:SCLangParser.AttribContext) -> Attrib:
        """``[mod] entityRef methodCall`` -> Attrib."""
        mod_ = None
        if (c_:=ctx.attribMod()) is not None:
            mod_ = self.visit(c_)

        ent_ = self.visit(ctx.entityRef())
        method_, args_ = self.visit(ctx.methodCall())

        return Attrib(entity=ent_, method=method_, args=args_, modifier=mod_)


    def visitMethodCall(self, ctx:SCLangParser.MethodCallContext) -> tuple[str, list[LogicExpr]]:
        """(method name, args or None when no argument list is present)."""
        method_ = self.visit(ctx.method())

        args_ = None
        if (c_:=ctx.argsWrapper()) is not None:
            args_ = self.visit(c_)

        return (method_, args_)


    def visitAttribMod(self, ctx:SCLangParser.AttribModContext) -> AttribModifier:
        d_ = {
            '!': AttribModifier.LIVE,
            '$': AttribModifier.PREV,
        }
        return d_[ctx.getText()]


    def visitArgsWrapper(self, ctx:SCLangParser.ArgsWrapperContext) -> list[Attrib]:
        return self.visit(ctx.argsList())


    def visitArgsList(self, ctx:SCLangParser.ArgsListContext) -> list[Attrib]:
        # Visit each argument; self.visit() cannot take the list a repeated rule returns.
        return [self.visit(c) for c in self._as_list(ctx.attrib())]


    def visitMetaBlock(self, ctx:SCLangParser.MetaBlockContext) -> list[tuple[str, ValueType]]:
        return [self.visit(c) for c in self._as_list(ctx.metaEntry())]


    def visitMetaEntry(self, ctx:SCLangParser.MetaEntryContext) -> tuple[str, ValueType]:
        return (self.visit(ctx.string()),
                self.visit(ctx.value()))


    def visitValue(self, ctx:SCLangParser.ValueContext) -> ValueType:
        return self.visit(ctx.getChild(0))


    def visitNumber(self, ctx:SCLangParser.NumberContext) -> float:
        return float(ctx.getText())


    def visitString(self, ctx:SCLangParser.StringContext) -> str:
        import ast
        # Decode the quoted literal without executing script-supplied text (eval() would run it).
        return ast.literal_eval(ctx.getText())


    def visitBool(self, ctx:SCLangParser.BoolContext) -> bool:
        return ctx.getText() == 'True'


    def visitId(self, ctx:SCLangParser.IdContext) -> str:
        return ctx.getText()


    def visitMethod(self, ctx:SCLangParser.MethodContext) -> str:
        return ctx.getText()

In [5]:
from loader import load_script

# NOTE(review): repeats the earlier `script = load_script("conditions.scl")` cell (and its import); one of the two can be dropped.
script = load_script("conditions.scl")