In [1]:
# TODO: Different registry for different types of things... generators, steps, noise

import numpy as np

from enum import Enum
from types import SimpleNamespace

METHOD_REGISTRY = {}

def register_method(method_name: str) -> callable:
    # TODO: Error handling if method is already in the registry; prevent overwrite
    # TOOD: Add to specific registry? Or just different decorators for each?
    def decorator(method: callable) -> callable:
        METHOD_REGISTRY[method_name] = method
        return method
    return decorator


# def get_method(method_name: str):
#     if not method_name in METHOD_REGISTRY:
#         raise Exception(
#             "Method {method_name} not in METHOD_REGISTRY. Currently registered methods: {METHOD_REGISTRY}")

#     method = METHOD_REGISTRY[method_name]
#     return method

# MethodRegistry = enum.Enum('MethodRegistry', {'foo':42, 'bar':24})

# NO_DATASET_ERR = "Dataset {} not in DATASET_REGISTRY! Available datasets are {}"
# DATASET_REGISTRY = {}


# def RegisterDataset(dataset_name):
#     """Registers a dataset."""

#     def decorator(f):
#         DATASET_REGISTRY[dataset_name] = f
#         return f

#     return decorator


# def get_dataset_class(args):
#     if args.dataset not in DATASET_REGISTRY:
#         raise Exception(
#             NO_DATASET_ERR.format(args.dataset, DATASET_REGISTRY.keys()))

#     return DATASET_REGISTRY[args.dataset]


# def RegisterPool(pool_name):
#     """Registers a pool."""

#     def decorator(f):
#         POOL_REGISTRY[pool_name] = f
#         return f

#     return decorator

# def get_pool(pool_name):
#     """Get pool from POOL_REGISTRY based on pool_name."""

#     if not pool_name in POOL_REGISTRY:
#         raise Exception(NO_POOL_ERR.format(
#             pool_name, POOL_REGISTRY.keys()))

#     pool = POOL_REGISTRY[pool_name]

#     return pool

In [2]:
class StaticEnvironment:
    def __init__(self, params: dict) -> None:
        self.params = SimpleNamespace(**params)
        
    def _noise(self):
        return np.random.normal(loc=0, scale=self.params.y_star_std)
    
    def _ge(self, x_star: float) -> float:
        return self.params.beta_1_star * x_star + self.params.beta_0_star
    
    def generate(self, x_star: float) -> float:
        return self._ge(x_star) + self._noise()

In [3]:
@register_method("noise__zero_centered")
def noise(self) -> float:
        return np.random.normal(loc=0, scale=self.params.y_star_std)
    
@register_method("generating_function__linear")
def ge(self, x_star: float) -> float:
    return self.params.beta_1_star * x_star + self.params.beta_0_star

@register_method("data_generator__noisy")
def generate(self, x_star: float) -> float:
    return self.ge(x_star) + self.noise()

In [4]:
MethodRegistry = Enum("MethodRegistry", METHOD_REGISTRY) 

In [20]:
class StaticEnvironment:
    def __init__(self, params: dict) -> None:
        self.params = SimpleNamespace(**params)

env_params = {
    "beta_0_star" : 3,    # Linear parameter intercept
    "beta_1_star" : 2,    # Linear parameter slope
    "y_star_std"  : 0.5   # Standard deviation of sensory data
}


StaticEnvironment.noise    = MethodRegistry.noise__zero_centered
StaticEnvironment.ge       = MethodRegistry.generating_function__linear
StaticEnvironment.generate = MethodRegistry.data_generator__noisy

In [None]:
# Initialize the environment and support of x
env       = StaticEnvironment(env_params)
x_range   = np.linspace(start=0.01, stop=5, num=500)

# Generate data
y = np.zeros(x_range.shape[0])

for idx, x in enumerate(x_range):
    y[idx] = env.generate(x_star=x)

## Build method using string name

In [9]:
def add_method(method_name: str):
    # TODO: When there are multiple registries, loop through each registry and try to find it. 
    # TODO: Deal with error handling in case of multiple registries to check through
    # TODO: Consider calling this "add___" where the blank is some noise, generating function, etc?
    # if not method_name in METHOD_REGISTRY:
    #     raise Exception(
    #         "Method {method_name} not in METHOD_REGISTRY. Currently registered methods: {METHOD_REGISTRY}")

    # method = METHOD_REGISTRY[method_name]
    method = getattr(MethodRegistry, method_name)
    return method

In [10]:
class StaticEnvironment:
    def __init__(self, params: dict) -> None:
        self.params = SimpleNamespace(**params)

env_params = {
    "beta_0_star" : 3,    # Linear parameter intercept
    "beta_1_star" : 2,    # Linear parameter slope
    "y_star_std"  : 0.5   # Standard deviation of sensory data
}

StaticEnvironment.noise    = add_method("noise__zero_centered")
StaticEnvironment.ge       = add_method("generating_function__linear")
StaticEnvironment.generate = add_method("data_generator__noisy")

In [11]:
# Initialize the environment and support of x
env       = StaticEnvironment(env_params)
x_range   = np.linspace(start=0.01, stop=5, num=500)

# Generate data
y = np.zeros(x_range.shape[0])

for idx, x in enumerate(x_range):
    y[idx] = env.generate(x_star=x)

## Scratch

In [6]:
getattr(MethodRegistry, 'noise__zero_centered')

<function __main__.noise(self) -> float>

In [25]:
MethodRegistry["noise__zero_centered"].value

KeyError: 'noise__zero_centered'

In [28]:
MethodRegistry.keys()

AttributeError: keys

In [21]:
# # Initialize environment
# env          = StaticEnvironment(env_params)

# # Populate environment with class methods
# env.noise    = MethodRegistry.noise__zero_centered
# env.ge       = MethodRegistry.generating_function__linear
# env.generate = MethodRegistry.data_generator__noisy