In [18]:
import importlib.util
import sys
import inspect
import pkgutil
import ast
import os
import yaml
from collections import namedtuple
from dioptra.sdk.utilities.contexts import plugin_dirs
from pathlib import Path
with plugin_dirs(["/c/users/jtsexton/documents/github/dioptra/task-plugins", "/c/users/jtsexton/documents/github/dioptra/examples"]):
    import dioptra_builtins

def import_file_as_module(filepath):
    spec = importlib.util.spec_from_file_location("tmp_module", filepath)
    module = importlib.util.module_from_spec(spec)
    sys.modules["tmp_module"] = module
    spec.loader.exec_module(module)
    return module

def walk_modules(module, st):
    subs = submodules(module)
    if (subs == []):
        st += [module.__name__]
    for i in subs:
        m = importlib.import_module(module.__name__ + '.' + i.name)
        st += walk_modules(m, st)
    return set(st)

def is_mod_function(mod, func):
    return inspect.isfunction(func) and inspect.getmodule(func) == mod

def list_functions(mod):
    return [func.__name__ for func in mod.__dict__.values() 
            if is_mod_function(mod, func)]
    
def generate_stubs(module, plugin_path):
    modules = walk_modules(module, [])
    for m in modules:
        filepth = m.split('.')
        mod = importlib.import_module(m)
        build_stub(mod, plugin_path)
        #write_to_file(fcre, build_stub(mod))

def build_stub(module, plugin_path):
    return generate_functions(module, plugin_path)

def generate_functions(module, plugin_path):
    #print(module)
    ret = []
    fs = list_functions(module)
    for fname in fs:
        #print(module)
        if not fname.startswith('_'):
            func = getattr(module, fname)
            build_task(func, plugin_path)
    return ret
    
def get_function_parameter_names(f):
    #print(getattr(f, "__annotations__", None))
    #print(f.__decorators__)
    argspec = inspect.getfullargspec(f)
    #print(argspec.annotations)

def submodules(m):
    try:
        return pkgutil.iter_modules(m.__path__)
    except:
        return []
        
def get_ret_name():                                                                                                          
    def extract_return(file, fname):                                                                            
        for x in ast.walk(ast.parse(open(file).read())):                                                        
            if not(isinstance(x, ast.FunctionDef)):                                                             
                continue                                                                                        
            if not(x.name == fname):                                                                            
                continue                                                                                        
            for b in x.body:                                                                                    
                if isinstance(b, ast.Return):                                                                   
                    if isinstance(b.value, ast.Name):                                                           
                        yield b.value.id                                                                        
default_types = ["integer", "string", "any", "number"]
defined_types = {}
def create_yml_type_if_not_exists(t):
    global default_types
    global defined_types
    if not isinstance(t,str):
        t = t.__name__ if hasattr(t, '__name__') else str(t)
        #print("here..")
        t = t.replace("typing.", "")
        t = t.replace("pathlib.", "")
        #print(t)
    if (t is None): 
        return ["null"], "null"
    t = t.strip()
    #print("type call:" + t)
    type_lookup = {"int":"integer", "str": "string", "float":"number", "Any":"any"}
    if t in type_lookup.keys():
        return [type_lookup[t]], type_lookup[t]
    elif t in defined_types.keys():
        return [defined_types[t]], t
    else:
        storage_name = ""
        unwrapped = t
        if (t.startswith('Optional')):
            storage_name = "union_null_"
            unwrapped = unwrap(t, 'Optional')
            #print("call1:" + str(t))
            m, name = create_yml_type_if_not_exists(unwrapped)
            storage_name += name
            t_out = {"union": [m, "null"]}
        elif (t.startswith('Dict')):
            storage_name = "mapping_" 
            unwrapped = unwrap(t, 'Dict')
            #print("call2:" + str(t))
            m, name = create_yml_type_if_not_exists(unwrapped)
            #print("m:" + str(m), "name:" + str(name))
            storage_name += name
            t_out = {"mapping": [z for z in m]}
            #print("t_out", t_out)
        elif (t.startswith('Tuple')):
            storage_name = "tuple_" 
            unwrapped = unwrap(t, 'Tuple')
            #print("call2:" + str(t))
            m, name = create_yml_type_if_not_exists(unwrapped)
            storage_name += name
            t_out = {"tuple": [z for z in m]}
        elif (t.startswith('Union')):
            storage_name = "union_"
            unwrapped = unwrap(t, 'Union')
            #print("call3:" + str(t))
            m, name = create_yml_type_if_not_exists(unwrapped)
            storage_name += name
            t_out = {"union": [z for z in m]}
        elif (t.startswith('List')):
            storage_name = "list_"
            unwrapped = unwrap(t, 'List')
            #print("call4:" + str(t))
            m, name = create_yml_type_if_not_exists(unwrapped)
            #print("m:" + str(m), "name:" + str(name))
            storage_name += name
            t_out = {"list": m}
        else:
            if ',' in unwrapped and (unwrapped.find(',') < unwrapped.find('[') or unwrapped.find('[') < 0):
                types = unwrapped.split(', ')
                first_in_list = unwrapped[:unwrapped.find(',')]
                rest_of_list = unwrapped[unwrapped.find(', ')+1:]
                #print("call5:" + str(t))
                first_t, first_name = create_yml_type_if_not_exists(first_in_list)
                #print("call6:" + str(t))
                rest_t, rest_name = create_yml_type_if_not_exists(rest_of_list)
                #print(rest_t, rest_name)
                ret1= first_t + rest_t
                ret2= storage_name + first_name + '_' + rest_name
                #print("rets:", ret1, ret2)
                return ret1, ret2
            else:
                t_out = t.lower()
                storage_name=t.lower()
                #print("ERROR:" + unwrapped)
        storage_name.replace('.', '_')
        defined_types[storage_name] = t_out
        return [t_out], storage_name
def unwrap(s, s2):
    return s[len(s2 + '['):-1]
def generate_yml_func(): 
    task_list = {}
    ret = {"tasks":task_list}
def get_plugin_name(f):
    return f.__name__
def get_plugin_inputs(f):
    argspec = inspect.getfullargspec(f)
    #print(argspec)
    defaults = argspec.defaults
    #print(defaults)
    num_def = len(defaults) if defaults is not None else 0
    #print(num_def)
    annotations = argspec.annotations
    annotations = {i:annotations[i] for i in annotations if i!='return'}
    num_param = len(annotations) if annotations is not None else 0
    ret = []
    for m in annotations:
        #print( "mastercall:" + str(m), str(annotations[m]))
        q, type_name = create_yml_type_if_not_exists(annotations[m])
        #print("type:" + str(q))
        if (num_param - num_def > 0):
            ret += [{m:type_name}]
        else:
            ret += [{"name":m, "type":type_name, "required": False}]
        num_param-=1
    return ret
def get_plugin_outputs(f):
    argspec = inspect.getfullargspec(f)
    #print(argspec)
    returns = argspec.annotations['return']
    ret = []
    #print(returns)
    if not isinstance(returns, str):
        returns = returns.__name__ if hasattr(returns, '__name__') else str(returns)
        returns = returns.replace("typing.", "")
        returns = returns.replace("pathlib.", "")

    if returns.startswith("Tuple"):
        returns = unwrap(returns, "Tuple")
        returns = returns.split(', ')
        m = 0
        for b in returns:
            m += 1
            q, type_name = create_yml_type_if_not_exists(b)    
            ret += [{"ret" + str(m):type_name}]
    else:
        q, type_name = create_yml_type_if_not_exists(returns)
        ret = {"ret":type_name}
    return ret
def build_task(f, plugin_path):
    f = inspect.unwrap(f)
    argspec = inspect.getfullargspec(f)
    if (len(argspec.annotations) > 0):
        ret = {}
        fdesc = {}
        fdesc["plugin"] = plugin_path + '.' + get_plugin_name(f)
        fdesc["inputs"] = get_plugin_inputs(f)
        fdesc["outputs"] = get_plugin_outputs(f)
        if (fdesc["inputs"] == []):
            del fdesc["inputs"]
        #print (type(fdesc["outputs"]))
        if (isinstance(fdesc["outputs"], dict) and fdesc["outputs"]["ret"] == "none"):
            del fdesc["outputs"]
        ret[f.__name__] = fdesc
        print(yaml.dump(ret))
    
def walk_example_plugins(directory):
    import os
    for root, dirs, files in os.walk(directory, topdown=False):
        for name in files:
            if (name.endswith(".py")):
                path = os.path.join(root, name)
                plugin_path = path.replace('.py','')
                plugin_path = plugin_path.replace(directory, '')
                plugin_path = plugin_path.replace('/', '.')
                #print(plugin_path)
                if (".venv" not in path and '.ipynb' not in path):
                    try: 
                        mod = import_file_as_module(path)
                        generate_stubs(mod, plugin_path)
                    except:
                        pass
        #for name in dirs:
        #print(os.path.join(root, name))

def generate_yml_defs(directory):
    pass
walk_example_plugins("/c/users/jtsexton/documents/github/dioptra/examples/task-plugins/")
#generate_stubs(dioptra_builtins)

create_adversarial_fgm_dataset:
  inputs:
  - data_dir: string
  - adv_data_dir: union_string_path
  - keras_classifier: kerasclassifier
  - image_size: tuple_integer_integer_integer
  - name: distance_metrics_list
    required: false
    type: union_null_list_tuple_string_callable[..., np.ndarray]
  - name: rescale
    required: false
    type: number
  - name: batch_size
    required: false
    type: integer
  - name: label_mode
    required: false
    type: string
  - name: eps
    required: false
    type: number
  - name: eps_step
    required: false
    type: number
  - name: minimal
    required: false
    type: number
  - name: norm
    required: false
    type: number
  - name: target_index
    required: false
    type: integer
  - name: targeted
    required: false
    type: bool
  outputs:
    ret: pd.dataframe
  plugin: dioptra_custom.custom_fgm_plugins.attacks_fgm.create_adversarial_fgm_dataset

create_image_dataset:
  inputs:
  - data_dir: string
  - subset: union_null_st

In [None]:
import tensorflow

2023-09-19 11:00:31.754524: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-19 11:00:31.986232: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-09-19 11:00:33.291342: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-09-19 11:00:33.295364: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
