# Python Decorators for Machine Learning

## Table of Contents

### Relevant Concepts

* Variable Arguments
* Functions are first class Objects
* Closures

### Python Decorators

#### What is a Decorator
#### Implementing Decorators
* Function-based Decorators
* Class-based Decorators

#### Special Topics
* Stateful Decorators
* Decorators with Arguments
* Introspection of Decorated Functions

### Real World Examples for ML

#### Data Validation
#### Model Registry
#### Feature Engineering


### Further Reading

* https://realpython.com/primer-on-python-decorators/#returning-functions-from-functions
* https://github.com/lord63/awesome-python-decorator

### Variable Arguments

* Positional Arguments - the *args idiom
  - args is a tuple
* Keyword Arguments - the **kwargs idiom
  - kwargs is a dictionary
  
Notes:
  - Both idioms can be used when defining and when calling functions
  - *args must come before **kwargs, in definitions and in calls
  - Passing arguments as key-value pairs works for every parameter that is not positional-only

In [None]:
# define a function - packing positional arguments into a tuple
def read_data_files(*paths, format):
    data = ""
    for path in paths:
        with open(path) as handle:
            if format in ("csv", "json"):
                # file objects have no read_csv()/read_json(); read the raw text
                data += handle.read()
            else:
                raise ValueError(f"unsupported format: {format!r}")
    return data

# calling a function - unpacking a tuple into positional arguments
# (assumes the three files exist in the working directory)
csv_files = ("file1.csv", "file2.csv", "file3.csv")
read_data_files(*csv_files, format="csv")

In [None]:
# define a function - packing keyword arguments into a dict
def set_config_defaults(config, **kwargs):
    for key, value in kwargs.items():
        if key not in config:
            config[key] = value
            
# calling a function - unpacking a dict into keyword arguments
config = {"target": "bad", "verbosity": 3}
defaults = {"random_state": 1234}
set_config_defaults(config, **defaults)

In [None]:
# the most general way to define a function
def general_function(*args, **kwargs):
    for arg in args:
        print(arg)
    for key, value in kwargs.items():
        print("{} -> {}".format(key,value))
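
To see both idioms on the calling side, here is a small sketch. The names `describe_call`, `positional` and `options` are made up for this illustration.

```python
def describe_call(*args, **kwargs):
    # args arrives as a tuple, kwargs as a dict
    return args, kwargs

positional = ["train.csv", "test.csv"]
options = {"sep": ",", "header": True}

# * unpacks a sequence into positional arguments,
# ** unpacks a mapping into keyword arguments
args, kwargs = describe_call(*positional, **options)
print(type(args).__name__, args)
print(type(kwargs).__name__, kwargs)
```

The list is unpacked at the call site and re-packed into a tuple inside the function, which is why `args` is a tuple even though a list was passed.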
        

### Functions are first class Objects

Functions have attributes, such as "__name__", and can be

* assigned to a variable,
* passed as an argument, and
* returned from other functions

In [3]:
def inner():
    print("This is an inner function")

In [4]:
inner.__name__

'inner'

In [7]:
# In this case, the variable y is bound to the same object as inner
y = inner
y()

This is an inner function


In [1]:
def outer(x): 
    def inner(y): 
        return x + y 
    return inner

add15 = outer(15) 
    
print(add15(10))

25
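
The cells above show assignment and returning a function. Passing a function as an argument could look like the sketch below, where `apply_to_column` and `square` are illustrative names only.

```python
def apply_to_column(transform, values):
    # transform is an ordinary argument that happens to be a function
    return [transform(value) for value in values]

def square(x):
    return x * x

squared = apply_to_column(square, [1, 2, 3])
print(squared)
```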


### Closures

* A closure in Python is an inner function object that remembers, and has access to, variables in the local scope in which it was created, even after the outer function has finished executing.
* Another way to think about a Python closure: objects are data with methods attached, while closures are functions with data attached.

In [8]:
## add15 is a closure
add15.__closure__
add15.__closure__[0].cell_contents

15

Characteristics of a Python closure

* It is a nested (inner) function
* It has access to a 'free' variable defined in the enclosing (outer) scope
* It is returned from the enclosing (outer) function.
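
A closure can also update its free variables. The sketch below is one possible illustration, assuming a hypothetical `make_running_mean` helper; `nonlocal` is what allows the inner function to rebind the enclosed names.

```python
def make_running_mean():
    total = 0.0
    count = 0

    def update(value):
        # nonlocal lets the inner function rebind the free variables
        nonlocal total, count
        total += value
        count += 1
        return total / count

    return update

running_mean = make_running_mean()
running_mean(2.0)
latest = running_mean(4.0)
print(latest)

# the names of the free variables are recorded on the code object
free_names = running_mean.__code__.co_freevars
print(free_names)
```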

### What is a Decorator

* A decorator in Python is any callable object that is used to modify a function or a class.
* It works by passing a reference to the function "func" or the class "cls" to the decorator, which returns a modified function or class. The modified function or class usually contains calls to the original one.
* Decorators are typically applied to a group of functions or classes to augment their behavior in the same way.
* Decorators are a tool for building extensible software frameworks.
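
Most examples in this notebook decorate functions. As a minimal sketch of decorating a class, the following registers each decorated class in a dictionary. `MODEL_CLASSES`, `register` and `BaselineModel` are hypothetical names, and the decorator here returns the class unchanged rather than a modified copy.

```python
MODEL_CLASSES = {}  # hypothetical registry, for illustration only

def register(cls):
    # store the class under its name and hand the same class back
    MODEL_CLASSES[cls.__name__] = cls
    return cls

@register
class BaselineModel:
    def predict(self, rows):
        return [0 for _ in rows]

model = MODEL_CLASSES["BaselineModel"]()
print(model.predict([1, 2, 3]))
```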

### Function-based Decorators 

* Decorators are often implemented as a (higher-order) function that takes a function as an argument and returns a different (decorated) function
* The returned function adds code around the original function, which is executed before or after the original call
* The decorator function is executed once for each function it decorates.

In [None]:
@some_decorator
def some_function(arg):
    # some code
    pass

Note: Python translates this into

* some_function = some_decorator(some_function)
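
The sketch below illustrates both points: the `@` syntax and the manual assignment have the same effect, and the decorator body runs once at decoration time while the wrapper runs on every call. The names `events`, `announce`, `load` and `save` are made up for the example.

```python
events = []

def announce(func):
    # runs once, at decoration time
    events.append(f"decorating {func.__name__}")

    def wrapper(*args, **kwargs):
        # runs on every call of the decorated function
        events.append(f"calling {func.__name__}")
        return func(*args, **kwargs)
    return wrapper

@announce
def load():
    return "loaded"

def save():
    return "saved"

# manual application, equivalent to writing @announce above save
save = announce(save)

load()
load()
save()
print(events)
```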

Examples

In [9]:
# decorator function
def printlog(func):
    def wrapper(*args, **kwargs):
        print("CALLING:", func.__name__)
        return func(*args, **kwargs)
    return wrapper

@printlog
def data_ingest():
    # do something
    pass

In [None]:
# Execution Timer Decorator
from time import time

def measure_time(func):
    def wrapper(*args, **kwargs):
        start = time()
        result = func(*args, **kwargs)
        # time() returns seconds, so the difference is in seconds
        print(f'Elapsed time is {time() - start} s')
        return result
    return wrapper

In [7]:
data_ingest()

CALLING: data_ingest


In [29]:
from time import time

def timeit(func):
    def wrapper(*args, **kwargs):
        start = time()
        value = func(*args, **kwargs)
        print(f'Elapsed time is {time() - start} s')  # time() is in seconds
        return value 
    return wrapper

@timeit
def any_func():
    count = 0
    for number in range(100):
        count += number

any_func()

Elapsed time is 1.0967254638671875e-05 s


In [30]:
from datetime import datetime

def logger(func):
    def wrapper(*args, **kwargs):
        print('_' * 25)
        print(f'Run on: {datetime.today().strftime("%Y-%m-%d %H:%M:%S")}')
        print(func.__name__)
        result = func(*args, **kwargs)
        print('_' * 25)
        return result  # pass the wrapped function's return value through
    return wrapper

In [31]:
@logger
def shutdown():
    pass

@logger
def restart():
    pass

shutdown()
restart()

_________________________
Run on: 2022-04-05 20:46:31
shutdown
_________________________
_________________________
Run on: 2022-04-05 20:46:31
restart
_________________________


### Class-based Decorators

If a decorator is implemented as a class, 

* it needs to take func as an argument in its .__init__() method, and  
* the class instance needs to be callable, that is, you implement the special .__call__() method.

In [10]:
class PrintLog:
    def __init__(self, func):
        self.func = func
        
    def __call__(self, *args, **kwargs):
        print("CALLING: {}".format(self.func.__name__))
        return self.func(*args, **kwargs)

@PrintLog
def data_ingest():
    # do something
    pass

data_ingest()

Class-based decorators have a few advantages over function-based decorators:

* The decorator is a class, so you can use inheritance for a family of related decorators.
* State can be tracked in object attributes instead of in a closure.
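
A family of related decorators built with inheritance might look like the following sketch. `CallHook`, `Announce` and `KeepLast` are hypothetical classes invented for the illustration; the base class owns the call logic and subclasses only override the hooks.

```python
class CallHook:
    """Base decorator: subclasses override before() and after()."""
    def __init__(self, func):
        self.func = func

    def before(self):
        pass

    def after(self, result):
        pass

    def __call__(self, *args, **kwargs):
        self.before()
        result = self.func(*args, **kwargs)
        self.after(result)
        return result

class Announce(CallHook):
    def before(self):
        print(f"CALLING: {self.func.__name__}")

class KeepLast(CallHook):
    def after(self, result):
        # state lives in an instance attribute
        self.last_result = result

@Announce
def ingest():
    return 3

@KeepLast
def score():
    return 0.9

ingest()
score()
print(score.last_result)
```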

In [4]:
import functools

class CountCalls:
    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func
        self.num_calls = 0

    def __call__(self, *args, **kwargs):
        self.num_calls += 1
        print(f"Call {self.num_calls} of {self.func.__name__!r}")
        return self.func(*args, **kwargs)

@CountCalls
def run_model():
    pass

run_model()

Call 1 of 'run_model'


In [5]:
run_model()

Call 2 of 'run_model'


In [6]:
run_model.num_calls

2

Notes:

* The .__init__() method must store a reference to the function and can do any other necessary initialization. 
* The .__call__() method does essentially the same thing as the wrapper() function in our earlier examples.
* Because the decorator is a class instance rather than a function, call functools.update_wrapper(self, func) in .__init__() instead of applying @functools.wraps.

When creating class-based decorators that take arguments,

* the constructor accepts the decorator argument(s), not the func object to be decorated
* the __call__ method must take the func object, define a wrapper function, and return it.

In [12]:
class Add:
    def __init__(self, value):
        self.value = value

    def __call__(self, func):
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs) + self.value
        return wrapper

@Add(2)
def foo(x):
    return x ** 2

@Add(4)
def bar(n):
    return n*2

## Special Topics

### Stateful Decorators

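* A stateful decorator keeps track of state across calls to the decorated function
* One option is a variable defined in the decorator function itself; it is created once per decorated function, whereas a variable defined inside the wrapper function is re-created on every call
* Another option is to store the state in function attributes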

In [1]:
# keep a running sum and count of what a function returns
# (enough to compute a running average), per decorated function
def collect_stats(func):
    stats = {"metric": 0, "count": 0}
    def wrapper(*args, **kwargs):
        value = func(*args, **kwargs)
        stats["metric"] += value
        stats["count"] += 1
        return value
    wrapper.stats = stats
    return wrapper

@collect_stats
def model(auc):
    return auc

In [2]:
model.stats

{'metric': 0, 'count': 0}

In [3]:
model(0.8)

0.8

In [4]:
model.stats

{'metric': 0.8, 'count': 1}
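
If the running average itself is what you want to read off, one possible variation keeps the sum and count in enclosed variables and publishes the average as a function attribute. `track_average` and `evaluate` are hypothetical names for this sketch.

```python
import functools

def track_average(func):
    total = 0.0
    count = 0

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal total, count
        value = func(*args, **kwargs)
        total += value
        count += 1
        # state exposed as a function attribute
        wrapper.average = total / count
        return value

    wrapper.average = None  # no calls yet
    return wrapper

@track_average
def evaluate(auc):
    return auc

evaluate(0.5)
evaluate(1.0)
print(evaluate.average)
```

Each decorated function gets its own `total` and `count`, so the statistics of different functions do not mix.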

### Decorators with Arguments

* A function that returns a decorator

In [5]:
def add(value):
    def decorator(func):
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs) + value
        return wrapper
    return decorator

# Notice the closure here

@add(2)
def foo(x):
    return x ** 2

@add(4)
def bar(n):
    return n*2
     

In [3]:
foo(2)

6

In [4]:
bar(4)

12

In [9]:
# An example decorator from the Flask framework
from flask import Flask

app = Flask(__name__)

@app.route("/")
def hello():
    return "<html><body>Hello, World!</body></html>"
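
To make the pattern behind a decorator like `@app.route("/")` concrete, here is a toy sketch of a registration decorator with an argument. `MiniApp` is a made-up class and is not how Flask is implemented; it only shows a method that takes an argument and returns the actual decorator.

```python
class MiniApp:
    """Toy stand-in for a web framework; not Flask's implementation."""
    def __init__(self):
        self.routes = {}

    def route(self, path):
        # route() takes the argument and returns the actual decorator
        def decorator(func):
            self.routes[path] = func
            return func  # the function itself is left unchanged
        return decorator

mini = MiniApp()

@mini.route("/")
def hello():
    return "Hello, World!"

print(mini.routes["/"]())
```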

### Introspection of Decorated Functions

* Introspection is the ability of an object to know about its own attributes at runtime. 
* For instance, a function knows its own name and documentation
* Use the @functools.wraps decorator to preserve information about the original function.

In [21]:
print.__name__

'print'

In [22]:
help(print)

Help on built-in function print in module builtins:

print(...)
    print(value, ..., sep=' ', end='\n', file=sys.stdout, flush=False)
    
    Prints the values to a stream, or to sys.stdout by default.
    Optional keyword arguments:
    file:  a file-like object (stream); defaults to the current sys.stdout.
    sep:   string inserted between values, default a space.
    end:   string appended after the last value, default a newline.
    flush: whether to forcibly flush the stream.



In [37]:
def count_calls(func):
    def wrapper_count_calls(*args, **kwargs):
        wrapper_count_calls.num_calls += 1
        print(f"Call {wrapper_count_calls.num_calls} of {func.__name__!r}")
        return func(*args, **kwargs)
    wrapper_count_calls.num_calls = 0
    return wrapper_count_calls

@count_calls
def run_model():
    pass

In [38]:
run_model.__name__

'wrapper_count_calls'

In [39]:
help(run_model)

Help on function wrapper_count_calls in module __main__:

wrapper_count_calls(*args, **kwargs)



In [40]:
# Fix this with the @functools.wraps decorator

import functools

def count_calls(func):
    @functools.wraps(func)
    def wrapper_count_calls(*args, **kwargs):
        wrapper_count_calls.num_calls += 1
        print(f"Call {wrapper_count_calls.num_calls} of {func.__name__!r}")
        return func(*args, **kwargs)
    wrapper_count_calls.num_calls = 0
    return wrapper_count_calls

@count_calls
def run_model():
    pass

In [41]:
run_model.__name__

'run_model'

In [42]:
help(run_model)

Help on function run_model in module __main__:

run_model()
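
Besides copying the name and docstring, `functools.wraps` also stores the original function on the wrapper as `__wrapped__`. A minimal sketch, with `passthrough` and `train` as illustrative names:

```python
import functools

def passthrough(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@passthrough
def train(epochs=1):
    """Train for a number of epochs."""
    return epochs

print(train.__name__, "|", train.__doc__)

# the undecorated function remains reachable
original = train.__wrapped__
print(original(epochs=3))
```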



## Real World Examples for ML

### Creating Singletons

A singleton is a class with only one instance.
The following @singleton decorator turns a class into a singleton by storing the first instance of the class as an attribute. Later attempts at creating an instance simply return the stored instance:

In [1]:
import functools

def singleton(cls):
    """Make a class a Singleton class (only one instance)"""
    @functools.wraps(cls)
    def wrapper_singleton(*args, **kwargs):
        if wrapper_singleton.instance is None:
            wrapper_singleton.instance = cls(*args, **kwargs)
        return wrapper_singleton.instance
    wrapper_singleton.instance = None
    return wrapper_singleton

@singleton
class TheModel:
    pass
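
A short usage sketch of the same idea follows. The decorator is repeated so the cell runs on its own, and `ModelStore` is a made-up class.

```python
import functools

def singleton(cls):
    @functools.wraps(cls)
    def get_instance(*args, **kwargs):
        if get_instance.instance is None:
            get_instance.instance = cls(*args, **kwargs)
        return get_instance.instance
    get_instance.instance = None
    return get_instance

@singleton
class ModelStore:
    def __init__(self, name):
        self.name = name

first = ModelStore("primary")
second = ModelStore("ignored")   # arguments of later calls are not used
print(first is second, second.name)
```

One consequence of this approach is that the name `ModelStore` now refers to the wrapper function, so `isinstance(first, ModelStore)` no longer works as it would with a plain class.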

### Model registry decorator

In [2]:
# Step 0: common use case
from sklearn import datasets
import pandas as pd
iris = datasets.load_iris(as_frame = True)

y = iris['target']
X = iris['data']

from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, ExtraTreesClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

xtrain, xtest, ytrain, ytest = train_test_split(X, y)

rf = RandomForestClassifier()
rf.fit(xtrain, ytrain)
y_pred = rf.predict(xtest)
print(accuracy_score(ytest, y_pred))

rf = AdaBoostClassifier()
rf.fit(xtrain, ytrain)
y_pred = rf.predict(xtest)
print(accuracy_score(ytest, y_pred))

rf = ExtraTreesClassifier()
rf.fit(xtrain, ytrain)
y_pred = rf.predict(xtest)
print(accuracy_score(ytest, y_pred))


# Step 1: create a custom function that will do the fitting/predicting 
def fitter(model, xtrain, xtest, ytrain, ytest): 
    md = model()
    md.fit(xtrain, ytrain)
    y_pred = md.predict(xtest)
    return accuracy_score(ytest, y_pred), y_pred, md


# Step 2: create a model registry decorator for tracking of fitted models
from functools import wraps 

def model_registry(function): 
    @wraps(function)
    def wrapper(*args, **kwargs): 
        score, prediction, fmodel = function(*args, **kwargs)
        registry[args[0].__name__] = {'score': score, 'prediction': prediction} 
        return score, prediction, fmodel
    return wrapper

@model_registry
def fitter(model, xtrain, xtest, ytrain, ytest): 
    md = model()
    md.fit(xtrain, ytrain)
    y_pred = md.predict(xtest)
    return accuracy_score(ytest, y_pred), y_pred, md

# run the decorated fitter
registry = {}

score, prediction, fmodel = fitter(RandomForestClassifier, xtrain, xtest, ytrain, ytest)
score, prediction, fmodel = fitter(ExtraTreesClassifier, xtrain, xtest, ytrain, ytest)
score, prediction, fmodel = fitter(AdaBoostClassifier, xtrain, xtest, ytrain, ytest)


1.0
0.9210526315789473
1.0


In [3]:
registry

{'RandomForestClassifier': {'score': 1.0,
  'prediction': array([1, 0, 0, 1, 0, 1, 0, 0, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 0, 2, 2, 0,
         2, 2, 0, 1, 2, 1, 2, 1, 2, 2, 1, 0, 1, 0, 2, 2])},
 'ExtraTreesClassifier': {'score': 1.0,
  'prediction': array([1, 0, 0, 1, 0, 1, 0, 0, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 0, 2, 2, 0,
         2, 2, 0, 1, 2, 1, 2, 1, 2, 2, 1, 0, 1, 0, 2, 2])},
 'AdaBoostClassifier': {'score': 0.9736842105263158,
  'prediction': array([1, 0, 0, 1, 0, 1, 0, 0, 2, 0, 2, 1, 1, 0, 0, 1, 2, 0, 0, 2, 2, 0,
         2, 2, 0, 1, 2, 1, 2, 1, 2, 2, 1, 0, 1, 0, 2, 2])}}

In [4]:
# Step 3: Modify the decorator definition to log the metrics of models into files
from contextlib import contextmanager

@contextmanager
def write_metrics(*args, **kwargs): 
    with open(*args, **kwargs) as f: 
        yield f 
        

from datetime import datetime
from sklearn.metrics import classification_report 
import json

def model_registry(function): 
    @wraps(function)
    def wrapper(*args, **kwargs): 
        score, prediction, fmodel = function(*args, **kwargs)
        registry[args[0].__name__] = {'score': score, 'prediction': prediction} 
        with write_metrics("{}_report.txt".format(datetime.today().strftime('%Y-%m-%d')), mode = 'w+') as f: 
            report = classification_report(args[4], prediction)
            f.write(args[0].__name__ + "\n" + report)
        return score, prediction, fmodel
    return wrapper

@model_registry
def fitter(model, xtrain, xtest, ytrain, ytest): 
    md = model()
    md.fit(xtrain, ytrain)
    y_pred = md.predict(xtest)
    return accuracy_score(ytest, y_pred), y_pred, md

registry = {}
score, prediction, fmodel = fitter(AdaBoostClassifier, xtrain, xtest, ytrain, ytest)
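
The registry decorators above look up the model through `args[0]`, which only works when the model is passed positionally. One way to relax that, sketched here under the assumption that the decorated function has a parameter named `model`, is to bind the call with `inspect.signature`. `results`, `record_by_name`, `DummyModel` and `fit_and_score` are placeholders, and no scikit-learn is needed to run it.

```python
import functools
import inspect

results = {}  # hypothetical registry for this sketch

def record_by_name(function):
    signature = inspect.signature(function)

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        # map positional and keyword arguments onto parameter names
        bound = signature.bind(*args, **kwargs)
        output = function(*args, **kwargs)
        results[bound.arguments["model"].__name__] = output
        return output
    return wrapper

class DummyModel:
    pass

@record_by_name
def fit_and_score(model, xtrain, ytrain):
    return len(xtrain)  # placeholder for a real score

fit_and_score(DummyModel, [1, 2, 3], [0, 1, 0])
fit_and_score(xtrain=[1], ytrain=[0], model=DummyModel)
print(results)
```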

### Caching

In [6]:
from decorators import count_calls

@count_calls
def fibonacci(num):
    if num < 2:
        return num
    return fibonacci(num - 1) + fibonacci(num - 2)

fibonacci(10)

Call 1 of 'fibonacci'
Call 2 of 'fibonacci'
Call 3 of 'fibonacci'
Call 4 of 'fibonacci'
Call 5 of 'fibonacci'
Call 6 of 'fibonacci'
Call 7 of 'fibonacci'
Call 8 of 'fibonacci'
Call 9 of 'fibonacci'
Call 10 of 'fibonacci'
Call 11 of 'fibonacci'
Call 12 of 'fibonacci'
Call 13 of 'fibonacci'
Call 14 of 'fibonacci'
Call 15 of 'fibonacci'
Call 16 of 'fibonacci'
Call 17 of 'fibonacci'
Call 18 of 'fibonacci'
Call 19 of 'fibonacci'
Call 20 of 'fibonacci'
Call 21 of 'fibonacci'
Call 22 of 'fibonacci'
Call 23 of 'fibonacci'
Call 24 of 'fibonacci'
Call 25 of 'fibonacci'
Call 26 of 'fibonacci'
Call 27 of 'fibonacci'
Call 28 of 'fibonacci'
Call 29 of 'fibonacci'
Call 30 of 'fibonacci'
Call 31 of 'fibonacci'
Call 32 of 'fibonacci'
Call 33 of 'fibonacci'
Call 34 of 'fibonacci'
Call 35 of 'fibonacci'
Call 36 of 'fibonacci'
Call 37 of 'fibonacci'
Call 38 of 'fibonacci'
Call 39 of 'fibonacci'
Call 40 of 'fibonacci'
Call 41 of 'fibonacci'
Call 42 of 'fibonacci'
Call 43 of 'fibonacci'
Call 44 of 'fibonacci'
...

55

In [7]:
import functools
from decorators import count_calls

def cache(func):
    """Keep a cache of previous function calls"""
    @functools.wraps(func)
    def wrapper_cache(*args, **kwargs):
        cache_key = args + tuple(kwargs.items())
        if cache_key not in wrapper_cache.cache:
            wrapper_cache.cache[cache_key] = func(*args, **kwargs)
        return wrapper_cache.cache[cache_key]
    wrapper_cache.cache = dict()
    return wrapper_cache

@cache
@count_calls
def fibonacci(num):
    if num < 2:
        return num
    return fibonacci(num - 1) + fibonacci(num - 2)

fibonacci(10)

Call 1 of 'fibonacci'
Call 2 of 'fibonacci'
Call 3 of 'fibonacci'
Call 4 of 'fibonacci'
Call 5 of 'fibonacci'
Call 6 of 'fibonacci'
Call 7 of 'fibonacci'
Call 8 of 'fibonacci'
Call 9 of 'fibonacci'
Call 10 of 'fibonacci'
Call 11 of 'fibonacci'


55

In [8]:
import functools

@functools.lru_cache(maxsize=4)
def fibonacci(num):
    print(f"Calculating fibonacci({num})")
    if num < 2:
        return num
    return fibonacci(num - 1) + fibonacci(num - 2)

fibonacci(10)

Calculating fibonacci(10)
Calculating fibonacci(9)
Calculating fibonacci(8)
Calculating fibonacci(7)
Calculating fibonacci(6)
Calculating fibonacci(5)
Calculating fibonacci(4)
Calculating fibonacci(3)
Calculating fibonacci(2)
Calculating fibonacci(1)
Calculating fibonacci(0)


55
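
Functions wrapped with `functools.lru_cache` also report their own statistics through `cache_info()`. A small sketch, using an unbounded cache (`maxsize=None`) and a function named `fib` for illustration:

```python
import functools

@functools.lru_cache(maxsize=None)
def fib(num):
    if num < 2:
        return num
    return fib(num - 1) + fib(num - 2)

result = fib(10)
info = fib.cache_info()
print(result)
print(info)
```

The `misses` field counts the calls that actually ran the function body, one per distinct argument from 0 to 10.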

### Data Validation

In [4]:
import functools
from data_ingestion import get_data

def validate_data(*expected_args):                  
    def decorator_validate_data(func):
        @functools.wraps(func)
        def wrapper_validate_data(*args, **kwargs):
            data_object = get_data()
            for expected_arg in expected_args:      
                if expected_arg not in data_object:
                    # exception type chosen here for illustration
                    raise KeyError(f"missing expected field: {expected_arg!r}")
            return func(*args, **kwargs)
        return wrapper_validate_data
    return decorator_validate_data

In [5]:
@validate_data("score")
def update_score():
    data = get_data()
    # Update score
    return "success!"
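
The cells above depend on a `data_ingestion` module that is not shown. A self-contained sketch of the same pattern, with a stub `get_data` standing in for it and `KeyError` picked arbitrarily as the error type, could look like this:

```python
import functools

def get_data():
    # stand-in for data_ingestion.get_data (an assumption of this sketch)
    return {"score": 0.5, "id": 7}

def require_fields(*expected_fields):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            data = get_data()
            missing = [name for name in expected_fields if name not in data]
            if missing:
                raise KeyError(f"missing fields: {missing}")
            return func(*args, **kwargs)
        return wrapper
    return decorator

@require_fields("score")
def update_score():
    return "success!"

@require_fields("label")
def update_label():
    return "success!"

print(update_score())
try:
    update_label()
except KeyError as error:
    print("validation failed:", error)
```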

### Creating new feature

In [1]:
#!/usr/bin/env python3
# Copyright (c) 2008-11 Qtrac Ltd. All rights reserved.
# This program or module is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version. It is provided for educational
# purposes and is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.

"""
A simplified version of the built-in property class to show a possible
implementation and illustrate how descriptors work.
>>> contact = NameAndExtension("Joe", 135)
>>> contact.name, contact.extension
('Joe', 135)
>>> contact.X
Traceback (most recent call last):
    ...
AttributeError: 'NameAndExtension' object has no attribute 'X'
>>> contact.name = "Jane"
Traceback (most recent call last):
    ...
AttributeError: 'name' is read-only
>>> contact.name
'Joe'
>>> contact.extension = 975
>>> contact.extension
975
"""

class Property:

    def __init__(self, getter, setter=None):
        self.__getter = getter
        self.__setter = setter
        self.__name__ = getter.__name__


    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        return self.__getter(instance)


    def __set__(self, instance, value):
        if self.__setter is None:
            raise AttributeError("'{0}' is read-only".format(
                                 self.__name__))
        return self.__setter(instance, value)


    def setter(self, setter):
        self.__setter = setter
        return self


class NameAndExtension:

    def __init__(self, name, extension):
        self.__name = name
        self.extension = extension


    @Property               # Uses the custom Property descriptor
    def name(self):
        return self.__name


    @Property               # Uses the custom Property descriptor
    def extension(self):
        return self.__extension


    @extension.setter       # Uses the custom Property descriptor
    def extension(self, extension):
        self.__extension = extension


if __name__ == "__main__":
    import doctest
    doctest.testmod()
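
For comparison, the same class written with the built-in `property` decorator might look like the sketch below. Single-underscore attribute names are used here as a simplification.

```python
class NameAndExtension:
    def __init__(self, name, extension):
        self._name = name
        self.extension = extension   # goes through the setter below

    @property
    def name(self):                  # no setter defined, so read-only
        return self._name

    @property
    def extension(self):
        return self._extension

    @extension.setter
    def extension(self, extension):
        self._extension = extension

contact = NameAndExtension("Joe", 135)
contact.extension = 975
try:
    contact.name = "Jane"
except AttributeError:
    print("name is read-only")
print(contact.name, contact.extension)
```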

In [1]:
from dataclasses import dataclass

@dataclass
class Config:
    path: str
    file: str

Additional conveniences that data classes provide

* add default values to data class fields
* allow for ordering of objects
* represent immutable data
* handle inheritance
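
As a brief sketch of the ordering and immutability conveniences, the hypothetical `Version` class below uses the `order` and `frozen` parameters of `@dataclass` together with a default value.

```python
from dataclasses import dataclass, FrozenInstanceError

@dataclass(order=True, frozen=True)
class Version:
    major: int
    minor: int = 0   # default value

older = Version(1, 2)
newer = Version(2)

# with order=True, instances compare field by field, like tuples
print(older < newer)

try:
    newer.minor = 5
except FrozenInstanceError:
    print("frozen instances cannot be modified")
```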

### Creating Config

In [None]:
# default value
from dataclasses import dataclass
from typing import Optional

@dataclass
class Config:
    path: Optional[str] = None
    file: Optional[str] = None

Advanced Default Values using the field() specifier 

* default: Default value of the field
* default_factory: Function that returns the initial value of the field
* init: Use field in .__init__() method? (Default is True.)
* repr: Use field in repr of the object? (Default is True.)
* compare: Include the field in comparisons? (Default is True.)
* hash: Include the field when calculating hash()? (Default is to use the same as for compare.)
* metadata: A mapping with information about the field


In [3]:
from dataclasses import dataclass, field

@dataclass
class Loan:
    id: int
    term: float = field(default=0.0, metadata={'unit': 'count'})
    amount: float = field(default=0.0, metadata={'unit': 'dollar'})
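
A few of the other `field()` options can be sketched as follows. `Experiment` is a made-up class; `default_factory` is the usual way to give each instance its own mutable default, and `dataclasses.fields()` reads the metadata back.

```python
from dataclasses import dataclass, field, fields

@dataclass
class Experiment:
    name: str
    # a fresh list is created for every instance
    tags: list = field(default_factory=list)
    # left out of the repr, with a unit recorded as metadata
    seed: int = field(default=0, repr=False, metadata={"unit": "count"})

first = Experiment("a")
second = Experiment("b")
first.tags.append("baseline")
print(first, second)

units = {f.name: f.metadata.get("unit") for f in fields(Experiment)}
print(units)
```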

In [2]:
# Inheritance

from dataclasses import dataclass

@dataclass
class Premier:
    name: str
    data_type: str
    default: float
    valid_range: float

@dataclass
class Trended(Premier):
    trended_attribute: str
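
With data class inheritance, the fields of the base class come first in the generated `__init__`, followed by the fields of the subclass. A small sketch with simplified, hypothetical classes `Base` and `Child`:

```python
from dataclasses import dataclass, fields

@dataclass
class Base:
    name: str
    default: float

@dataclass
class Child(Base):
    trended_attribute: str

child = Child("utilization", 0.0, "slope")
field_names = [f.name for f in fields(Child)]
print(field_names)
print(child)
```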