In [1]:
from sympy import *
init_printing()

In [2]:
class Compartment(Function):
    nargs = 1

    def __str__(self):
        return f'[{self.args[0]}]'

    def _sympystr(self, printer=None):
        return f'[{self.args[0]}]'

    def _latex(self, printer=None):
        return '\\left[' + printer.doprint(self.args[0]) + '\\right]'

class ContentChange(Function):
    nargs = 2

    def __str__(self):
        return f'{self.args}'

    def _sympystr(self, printer=None):
        return f'{self.args}'

    def _latex(self, printer=None):
        return printer.doprint(self.args)


class Context:

    def __init__(self, numSpecies: int):
        self.D = numSpecies
        self.gamma = IndexedBase('\gamma', integer=True, shape=self.D)

    def compartment(self, *args):
        return Compartment(*args)

    def change(self, *args):
        return ContentChange(*args)

    def __str__(self):
        return f'Context({self.D})'

    def doit(self, transition):
        self.deltaM(transition)

    def deltaM(self, transition):
        expr = transition.lhs
        if expr.func == Compartment:
            content = expr.args[ 0 ]
            species = self.getContentPerSpecies(content);
            self.getGamma(species)
        else:
            print("it's not a compartment")
            print(expr.func)
            print(type(expr.func))


    def getContentPerSpecies(self, content):
        """Get an array of scalars representing compartment content for species 0..D"""

        if content.func == Add:
            xs = [self.getContentPerSpecies(arg) for arg in content.args]
            return [Add(*x) for x in zip(*xs)]
        elif content.func == Mul:
            xs = [self.getContentPerSpecies(arg) for arg in content.args]
            return [Mul(*x) for x in zip(*xs)]
        elif content.func == IndexedBase:
            return [content[i] for i in range(self.D)]
        elif content.func == ContentChange:
            return [content.args[i] for i in range(self.D)]
        elif content.func == Integer:
            return [content] * self.D
        else:
            print(type(content))
            print("...hrmm")


    def getGamma(self, content):
        expr = Mul(*[content[i] ** self.gamma[i] for i in range(self.D)])
        display(expr)






class Transition(Basic):
    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs

    def __str__(self):
        return f'{self.lhs} ---> {self.rhs})'

    def _latex(self, printer=None):
        # Always use printer.doprint() otherwise nested expressions won't
        # work. See the example of ModOpWrong.
        l = printer.doprint(self.lhs)
        r = printer.doprint(self.rhs)
        return l + '\longrightarrow{}' + r

In [3]:
C = Context(2)

X = IndexedBase('X', integer=True, shape=C.D)
Y = IndexedBase('Y', integer=True, shape=C.D)
Exit = Transition(C.compartment(Y) + C.compartment(X), EmptySet())
display(Exit)
C.doit(Exit)

[X] + [Y] ---> EmptySet())

it's not a compartment
<class 'sympy.core.add.Add'>
<class 'sympy.core.assumptions.ManagedProperties'>
