In [9]:
class RestorableState:
    def __init__(self, context, data, cm_raises_exception):
        self.data = data
        self.restore_data = None
        self.context = context
        self.cm_raises_exception = cm_raises_exception
        
    def __enter__(self):
        if self.cm_raises_exception:
            ex, *args = self.cm_raises_exception
            raise ex(*args)
        return self.context

    def __exit__(self, *args):
        print("restore data to ", self.restore_data)
        self.data = self.restore_data
        return False

#     def __call__(self, *args):
#         print("set restore data")
#         self.restore_data = self.data
#         self.data = args
#         return self #.color

    def __repr__(self):
        return f"{self.__class__.__name__}: {self.data}"

    
def normalize_color(args):
    if args is None:
        return None
    
    if len(args) == 1:
        c = args[0]
        return (c, c, c, 1)
    elif len(args) == 2:
        return (args[0], args[1], 1, 1)
    elif len(args) == 3:
        return (*args, 1)
    else:
        return args

    
# class ColorState(RestorableState):
#     def __init__(self, context, color, cm_raises_exception=None):
#         RestorableState.__init__(self, context, color, cm_raises_exception=cm_raises_exception)
        
#     def __call__(self, *args):
#         color = normalize_color(args)
#         return super().__call__(color)


class Color(RestorableState):
    def __init__(self, context, color, cm_raises_exception=None):
        RestorableState.__init__(self, context, color, cm_raises_exception=cm_raises_exception)
        
    def copy(self, *args):
        return Color(self.context, args, self.cm_raises_exception)
#     def __call__(self, *args):
#         color = normalize_color(args)
#         return super().__call__(*color)

    

IDENTITY_MATRIX = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]
    
class Transform(RestorableState):
    def __init__(self, context, matrix=None):
        if matrix is None:
            matrix = IDENTITY_MATRIX
        RestorableState.__init__(self, context, matrix, cm_raises_exception=False)
        
    def skew(self, x=0, y=0):
        if (x, y) == (0, 0):
            return self
        return self
    
    def rotate(self, degrees):
        return self
    
    def push(self, t):
        self.__call__(t)
        return self.__enter__()
    
    def pop(self):
        return self.__exit__()

class Context:
    def __init__(self):
        self._fill = Color(self, (1, 1, 1, 1))
        self._stroke = Color(self, (0, 0, 0, 1))
        self._transform = Transform(self)
        
    def fill(self, *args):
        if not len(args):
            return self._fill
        self._fill = self._fill.copy(args)
        return self._fill
    
    def stroke(self, *args):
        return self._stroke.copy(args)
        
#     def move(self, x, y):
#         self.transform.move_to(x, y)
    
#     def push(self):
#         return self.transform.push()
    
#     def pop(self):
#         return self.transform.pop()
    
#     def skew(self, *args):
#         return self.transform.skew(*args)
        

ctx = Context()

print("\n\n\n")
ctx.fill(.1, .1, .1, 1)

print(ctx.fill(0))
with ctx.fill(1, 0, 0, 1), ctx.stroke(0):
    print(ctx.fill(.1, .2, .3, 1.), ctx.stroke(.9, .9, .9, 1.))
    
print("-- back to normal")
print(ctx.fill())





Color: ((0,),)
Color: ((0.1, 0.2, 0.3, 1.0),) Color: ((0.9, 0.9, 0.9, 1.0),)
restore data to  None
restore data to  None
-- back to normal
Color: ((0.1, 0.2, 0.3, 1.0),)
