In [165]:
import PIL
from PIL import Image
import requests

from transformers import CLIPProcessor, CLIPModel
import torch

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)


class Expr:
    pass


class OpaqueFunc:
    def __init__(self, name, source):
        self.name = name
        self.source = source

    def __call__(self, *args):
        return self.func(*args)

    def __repr__(self):
        return f"OpaqueFunc({self.name=})"

from typing import List

class Reference:
    def __init__(self, *names):
        self.names = list(names)

    def __repr__(self):
        return f"Reference({self.names=})"

class UDFCall(Expr):
    def __init__(self, name : Reference, args : List[Expr]):
        self.name = name
        self.args = args

    def __str__(self):
        return f"{self.name}({', '.join(map(str, self.args))})"

    def __repr__(self):
        return f"UDFCall({self.name=},\n{self.args=})"

class ExprTemplate:
    def __init__(self, params, expr : Expr):
        self.params = params
        self.expr = expr

    def __to_expr__(self):
        return self.expr

    def __repr__(self):
        return f"ExprTemplate({self.params=}, {self.expr=})"

class Constant(Expr):
    def __init__(self, value):
        self.value = value

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return f"Constant({self.value})"


class Parameter(Expr):
    def __init__(self, name):
        """ placeholder for expression used passed on to template """
        self.name = name

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"Parameter({self.name})"

class DictLookup(Expr):
    def __init__(self, obj, attr):
        self.obj = obj
        self.attr = attr

    def __str__(self):
        return f"{self.obj}.{self.attr}"

    def __repr__(self):
        return f"DictLookup({self.obj=}, {self.attr=})"


def simple_fun(image : PIL.Image.Image) -> torch.Tensor:
    image_tensor = processor.image_processor(image=image, return_tensors='pt')['pixel_values']
    return model.get_image_features(image_tensor)

def complex_fun(image : PIL.Image.Image) -> torch.Tensor:
    image_tensor = processor.image_processor(image=image, return_tensors='pt')['pixel_values']
    while True:
        image_tensor = model.get_image_features(image_tensor)

In [124]:
## manual version of making a template for simple_fun
params = [Parameter('image')]
udf1 = UDFCall(Reference('processor', 'image_processor'), [params[0]])
tmp1 = DictLookup(udf1, 'pixel_values')
udf2 = UDFCall(Reference('model', 'get_image_features'), [tmp1])
ExprTemplate(params=params, expr=udf2)

ExprTemplate(self.params=[Parameter(image)], self.expr=UDFCall(self.name=Reference(self.names=['model', 'get_image_features']),
self.args=[DictLookup(self.obj=UDFCall(self.name=Reference(self.names=['processor', 'image_processor']),
self.args=[Parameter(image)]), self.attr='pixel_values')]))

In [166]:
def udf(fun):
    from collections import deque
    import inspect
    import dis # disassembler
    """
        Decorator for UDF.

        If UDF is simple, eg.
            image or string preproc followed by model call.
              convert it to an Expression template

        Otherwise, eg  bc of complex control flow, bail out by simply registering it as a function.
    """

    supported = True

    sig = inspect.signature(fun)
    bcode = dis.Bytecode(fun.__code__)
    stack = deque()
    store = { k: Parameter(k) for k in sig.parameters.keys() }

    for instr in bcode:
        print(instr.opname, instr.argval)
        if instr.opname == "LOAD_GLOBAL":
            stack.append(Reference(instr.argval))
        elif instr.opname == "LOAD_ATTR" : # getattr
            stack[-1].names.append(instr.argval)
        elif instr.opname == "LOAD_FAST": # load local var
            stack.append(store[instr.argval])
        elif instr.opname == "STORE_FAST": # store local var
            store[instr.argval] = stack.pop()
        elif instr.opname == "LOAD_CONST": # constant val
            stack.append(instr.argval)
        elif instr.opname == "CALL_FUNCTION_KW": # call function
            kwargs = []
            key_words = deque(stack.pop())
            while True:
                top = stack.pop()
                if isinstance(top, Reference):
                    kwargs.reverse()
                    stack.append(UDFCall(top, kwargs))
                    break
                else:
                    keyword = key_words.pop()
                    kwargs.append((keyword,top))
        elif instr.opname == "BINARY_SUBSCR": # dictionary lookup
            attr = stack.pop()
            dc = stack.pop()
            stack.append(DictLookup(dc, attr))
        elif instr.opname == "LOAD_METHOD":
            ref = stack.pop()
            assert isinstance(ref, Reference)
            ref.names.append(instr.argval)
            stack.append(ref)
        elif instr.opname == "CALL_METHOD":
            args = []
            while True:
                top = stack.pop()
                if isinstance(top, Reference):
                    stack.append(UDFCall(top, args))
                    break
                else:
                    args.append(top)
        elif instr.opname == "RETURN_VALUE":
            break
        else:
            # assert False, f'{instr.opname=}, {instr.argval=}'
            supported = False
            break  # bail out, some operation is not supported.

    if supported:
        return ExprTemplate(params=params, expr=stack.pop())
    else:
        return OpaqueFunc(fun.__code__.co_name, inspect.getsource(fun))

In [167]:
@udf
def simple_fun2(image : PIL.Image.Image) -> torch.Tensor:
    image_tensor = processor.image_processor(image=image, return_tensors='pt')['pixel_values']
    return model.get_image_features(image_tensor)

LOAD_GLOBAL processor
LOAD_ATTR image_processor
LOAD_FAST image
LOAD_CONST pt
LOAD_CONST ('image', 'return_tensors')
CALL_FUNCTION_KW 2
LOAD_CONST pixel_values
BINARY_SUBSCR None
STORE_FAST image_tensor
LOAD_GLOBAL model
LOAD_METHOD get_image_features
LOAD_FAST image_tensor
CALL_METHOD 1
RETURN_VALUE None


In [169]:
simple_fun2

ExprTemplate(self.params=[Parameter(image)], self.expr=UDFCall(self.name=Reference(self.names=['model', 'get_image_features']),
self.args=[DictLookup(self.obj=UDFCall(self.name=Reference(self.names=['processor', 'image_processor']),
self.args=[('image', Parameter(image)), ('return_tensors', 'pt')]), self.attr='pixel_values')]))

In [162]:
udf(complex_fun)

LOAD_GLOBAL processor
LOAD_ATTR image_processor
LOAD_FAST image
LOAD_CONST pt
LOAD_CONST ('image', 'return_tensors')
CALL_FUNCTION_KW 2
LOAD_CONST pixel_values
BINARY_SUBSCR None
STORE_FAST image_tensor
LOAD_GLOBAL model
LOAD_METHOD get_image_features
LOAD_FAST image_tensor
CALL_METHOD 1
STORE_FAST image_tensor
JUMP_ABSOLUTE 18


OpaqueFunc(self.name='complex_fun')