In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import torch
import matplotlib.pyplot as plt

## Callbacks

### Callbacks as GUI events

In [4]:
import ipywidgets as widgets

def f(o): 
    print('hi')

From the [ipywidget docs](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20Events.html):

- *the button widget is used to handle mouse clicks. The on_click method of the Button can be used to register function to be called when the button is clicked*

In [16]:
# create a simple widget button
w = widgets.Button(description='Click me')
w

Button(description='Click me', style=ButtonStyle())

In [18]:
# when the button is clicked, it will call function f
# in this sense f is a callback
w.on_click(f)

**NB: When callbacks are used in this way they are often called "events".**

- **events are a kind of callback**

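- **a callback is any callable (function, lambda, or object with `__call__`) that we hand to other code so it can call us back later; it plays the role of a function pointer in languages like C**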
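To make the idea concrete without a browser, here is a minimal sketch of what an event source does internally. `SimpleButton` and its `click` method are hypothetical stand-ins, not the ipywidgets API. Only the shape of `on_click` (register a function now, call it later with the button as its argument) mirrors the real `Button.on_click`, which is why `f(o)` above takes one argument.

```python
# Hypothetical minimal event source (NOT ipywidgets): it stores callbacks
# when they are registered and calls them later when the "event" happens.
class SimpleButton:
    def __init__(self, description):
        self.description = description
        self._click_handlers = []  # callbacks registered for the "click" event

    def on_click(self, handler):
        # register a function to be called when the button is clicked
        self._click_handlers.append(handler)

    def click(self):
        # simulate a click: call every registered handler, passing the button
        for handler in self._click_handlers:
            handler(self)


clicks = []
button = SimpleButton("Click me")
button.on_click(lambda b: clicks.append(b.description))
button.on_click(lambda b: print(f"{b.description} was clicked"))
button.click()  # prints "Click me was clicked"
```

Nothing runs when the handlers are registered; they are only called when `click()` fires the event.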

Did you know that you can create interactive apps in Jupyter with these widgets? Here's an example from [plotly](https://plot.ly/python/widget-app/):

![](https://cloud.githubusercontent.com/assets/12302455/16637308/4e476280-43ac-11e6-9fd3-ada2c9506ee1.gif)

### Creating your own callback

In [19]:
from time import sleep

In [11]:
# let's make a deliberately slow calculation (sum of the squares of 0..4)
def slow_calculation():
    res = 0
    # does five calculations
    for i in range(5):
        res += i*i
        sleep(1)
    return res

In [12]:
slow_calculation()

30

In [20]:
# let's add an optional callback, called with the step index after each step
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        res += i*i
        sleep(1)
        if cb: cb(i)
    return res

In [23]:
# this is a simple callback
def show_progress(epoch):
    print(f"Awesome! We've finished epoch {epoch}!")

In [24]:
slow_calculation(show_progress)

Awesome! We've finished epoch 0!
Awesome! We've finished epoch 1!
Awesome! We've finished epoch 2!
Awesome! We've finished epoch 3!
Awesome! We've finished epoch 4!


30

### Lambdas and partials

Rather than defining a named function every time, we can create a temporary, anonymous function inline using `lambda`.
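A lambda is just an anonymous function, so it can be used anywhere a named one-argument function is expected. Here is a small sketch comparing the two. `fast_calculation` is a hypothetical copy of `slow_calculation(cb)` with the `sleep(1)` removed so the example runs instantly, and the callbacks collect epochs into lists instead of printing them.

```python
def fast_calculation(cb=None):
    # same loop as slow_calculation(cb), minus sleep(1), so it runs instantly
    res = 0
    for i in range(5):
        res += i * i
        if cb:
            cb(i)
    return res

# a named function and an equivalent lambda, both used as callbacks
seen_def, seen_lambda = [], []

def record(o):
    seen_def.append(o)

fast_calculation(record)
fast_calculation(lambda o: seen_lambda.append(o))

print(seen_def == seen_lambda)  # True
print(seen_lambda)              # [0, 1, 2, 3, 4]
```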

In [30]:
# define the callback function inline
slow_calculation(lambda o: print(f"Awesome! We've finished epoch {o}!"))

Awesome! We've finished epoch 0!
Awesome! We've finished epoch 1!
Awesome! We've finished epoch 2!
Awesome! We've finished epoch 3!
Awesome! We've finished epoch 4!


30

In [33]:
# what if we want to pass an extra argument (the exclamation) to the callback?
# slow_calculation only ever calls cb(epoch), so we wrap this two-argument
# function in a function of one argument
def show_progress(exclamation, epoch):
    print(f"{exclamation}! We've finished epoch {epoch}!")

In [34]:
slow_calculation(lambda o: show_progress("OK I guess", o))

OK I guess! We've finished epoch 0!
OK I guess! We've finished epoch 1!
OK I guess! We've finished epoch 2!
OK I guess! We've finished epoch 3!
OK I guess! We've finished epoch 4!


30

In [35]:
# or we can make a function that generates a lambda
def make_show_progress(exclamation):
    _inner = lambda epoch: print(f"{exclamation}! We've finished epoch {epoch}!")
    return _inner

In [36]:
slow_calculation(make_show_progress("Nice!"))

Nice!! We've finished epoch 0!
Nice!! We've finished epoch 1!
Nice!! We've finished epoch 2!
Nice!! We've finished epoch 3!
Nice!! We've finished epoch 4!


30

In [38]:
# or we can make a function that generates a defined function
def make_show_progress(exclamation):
    # Leading "_" is generally understood to be "private"
    def _inner(epoch): 
        print(f"{exclamation}! We've finished epoch {epoch}!")
    return _inner

slow_calculation(make_show_progress("Nice!"))

Nice!! We've finished epoch 0!
Nice!! We've finished epoch 1!
Nice!! We've finished epoch 2!
Nice!! We've finished epoch 3!
Nice!! We've finished epoch 4!


30
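Both versions of `make_show_progress` work because `_inner` is a closure: it keeps a reference to the `exclamation` it was created with, even after `make_show_progress` has returned. One way to see this in Python is the function's `__closure__` attribute. In this sketch `_inner` returns the message instead of printing it, purely so the result can be checked.

```python
def make_show_progress(exclamation):
    def _inner(epoch):
        # `exclamation` is a free variable captured from the enclosing call
        return f"{exclamation}! We've finished epoch {epoch}!"
    return _inner

nice = make_show_progress("Nice")
terrific = make_show_progress("Terrific")

# each returned function keeps its own captured value in a closure cell
print(nice.__code__.co_freevars)              # ('exclamation',)
print(nice.__closure__[0].cell_contents)      # Nice
print(terrific.__closure__[0].cell_contents)  # Terrific
print(nice(2))                                # Nice! We've finished epoch 2!
```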


In [40]:
f2 = make_show_progress("Terrific")

In [41]:
slow_calculation(f2)

Terrific! We've finished epoch 0!
Terrific! We've finished epoch 1!
Terrific! We've finished epoch 2!
Terrific! We've finished epoch 3!
Terrific! We've finished epoch 4!


30

In [42]:
slow_calculation(make_show_progress("Amazing"))

Amazing! We've finished epoch 0!
Amazing! We've finished epoch 1!
Amazing! We've finished epoch 2!
Amazing! We've finished epoch 3!
Amazing! We've finished epoch 4!


30

#### `partial` is a nice way to turn a function of multiple parameters into a function of a single parameter, by fixing the other arguments in advance

In [46]:
from functools import partial

In [47]:
slow_calculation(partial(show_progress, "OK I guess"))

OK I guess! We've finished epoch 0!
OK I guess! We've finished epoch 1!
OK I guess! We've finished epoch 2!
OK I guess! We've finished epoch 3!
OK I guess! We've finished epoch 4!


30

In [48]:
f2 = partial(show_progress, "OK I guess")
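`partial` remembers the function and the arguments fixed so far, and exposes them as the attributes `func`, `args` and `keywords`. Below is a sketch alongside a hypothetical hand-written `my_partial` built from a closure. It behaves the same for this use, but it is only an illustration of the idea, not the actual `functools` implementation. As before, `show_progress` returns the message rather than printing it so the results can be compared.

```python
from functools import partial

def show_progress(exclamation, epoch):
    return f"{exclamation}! We've finished epoch {epoch}!"

# partial stores the function and the arguments it has already been given
p = partial(show_progress, "OK I guess")
print(p.func is show_progress, p.args, p.keywords)  # True ('OK I guess',) {}

# hypothetical hand-rolled equivalent: a closure over the fixed arguments
def my_partial(func, *fixed_args, **fixed_kwargs):
    def _inner(*more_args, **more_kwargs):
        # later keyword arguments override the fixed ones
        return func(*fixed_args, *more_args, **{**fixed_kwargs, **more_kwargs})
    return _inner

q = my_partial(show_progress, "OK I guess")
print(p(3) == q(3))  # True
```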

### Callbacks as callable classes

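Instead of a function, we can pass an instance of a class that defines `__call__`. Python lets us call such an object just like a function, and the object can store settings (here, the exclamation) as attributes.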

In [50]:
class ProgressShowingCallback():
    def __init__(self, exclamation="Awesome"):
        self.exclamation = exclamation

    def __call__(self, epoch): 
        print(f"{self.exclamation}! We've finished epoch {epoch}!")

In [51]:
cb = ProgressShowingCallback("Just super")

In [52]:
slow_calculation(cb)

Just super! We've finished epoch 0!
Just super! We've finished epoch 1!
Just super! We've finished epoch 2!
Just super! We've finished epoch 3!
Just super! We've finished epoch 4!


30
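Because a callable object carries its own attributes, it can also accumulate state across calls, something a plain function can only do through globals or closures. Here is a minimal sketch with a hypothetical `EpochRecorder` that records every epoch it is called with. It again uses a sleep-free `fast_calculation` copy of `slow_calculation(cb)`.

```python
class EpochRecorder:
    """Hypothetical callback: a callable object that remembers every epoch it sees."""
    def __init__(self):
        self.epochs = []

    def __call__(self, epoch):
        self.epochs.append(epoch)

def fast_calculation(cb=None):
    # same as slow_calculation(cb) above, without sleep(1)
    res = 0
    for i in range(5):
        res += i * i
        if cb:
            cb(i)
    return res

rec = EpochRecorder()
print(callable(rec))          # True
print(fast_calculation(rec))  # 30
print(rec.epochs)             # [0, 1, 2, 3, 4]
```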

### Multiple callback funcs; `*args` and `**kwargs`

In [53]:
def f(*args, **kwargs): 
    print(f"args: {args}; kwargs: {kwargs}")

In [54]:
f(3, 'a', thing1="hello")

args: (3, 'a'); kwargs: {'thing1': 'hello'}


NB: We've been guilty of over-using kwargs in fastai. They are very convenient for the developer, but annoying for the end user unless care is taken to ensure the docs show all kwargs too. kwargs can also hide bugs, because a typo in a parameter name may be silently swallowed instead of raising an error. In [R](https://www.r-project.org/) there's a very similar issue (R uses `...` for the same thing), and matplotlib uses kwargs a lot too.
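To make the typo problem concrete, here is a sketch in which both functions are hypothetical. The `**kwargs` version silently ignores a misspelled keyword and falls back to its default, while the version with an explicit parameter raises a `TypeError`.

```python
def show_progress_loose(epoch, **kwargs):
    # any keyword is accepted, including a misspelled one
    exclamation = kwargs.get("exclamation", "Awesome")
    return f"{exclamation}! We've finished epoch {epoch}!"

def show_progress_strict(epoch, exclamation="Awesome"):
    # only the declared keyword is accepted
    return f"{exclamation}! We've finished epoch {epoch}!"

# typo: "exclamaton" instead of "exclamation"
loose_msg = show_progress_loose(0, exclamaton="Great")
print(loose_msg)  # Awesome! We've finished epoch 0!  (the typo went unnoticed)

raised = False
try:
    show_progress_strict(0, exclamaton="Great")
except TypeError as e:
    raised = True
    print("TypeError:", e)  # reports the unexpected keyword argument
```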

In the next section the callback is called both before and after each calculation step. We could pass two separate functions, but that parameter list would grow with every new hook. Instead, we pass a single object whose methods (`before_calc`, `after_calc`) are the hooks.
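For comparison, here is a minimal sketch of the two-function alternative, where each hook becomes its own keyword parameter. `fast_calculation_two_fns` is hypothetical and drops the `sleep(1)`. Every additional hook would mean yet another parameter, which is what the callback-object approach avoids.

```python
def fast_calculation_two_fns(before=None, after=None):
    """Hypothetical variant: one keyword parameter per hook instead of one callback object."""
    res = 0
    for i in range(5):
        if before:
            before(i)
        res += i * i
        if after:
            after(i, res)
    return res

events = []
total = fast_calculation_two_fns(
    before=lambda i: events.append(("before", i)),
    after=lambda i, val: events.append(("after", i, val)),
)
print(events[:4])  # [('before', 0), ('after', 0, 0), ('before', 1), ('after', 1, 1)]
print(total)       # 30
```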

In [57]:
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        if cb: cb.before_calc(i)
        res += i*i
        sleep(1)
        if cb: cb.after_calc(i, val=res)
    return res

In [61]:
class PrintStepCallback():
    def __init__(self): 
        pass

    def before_calc(self, *args, **kwargs):
        print("About to start")

    def after_calc(self, *args, **kwargs): 
        print("Done step")

In [62]:
slow_calculation(PrintStepCallback())

About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step


30

In [65]:
# to make it clear which index is being printed, we name the arguments
# we actually use (epoch, val); **kwargs still absorbs any extras

class PrintStatusCallback():
    def __init__(self): 
        pass

    def before_calc(self, epoch, **kwargs): 
        print(f"About to start: {epoch}")
        
    def after_calc(self, epoch, val, **kwargs):
        print(f"After {epoch}: {val}")

In [64]:
slow_calculation(PrintStatusCallback())

About to start: 0
After 0: 0
About to start: 1
After 1: 1
About to start: 2
After 2: 5
About to start: 3
After 3: 14
About to start: 4
After 4: 30


30

### Modifying behavior

Let's try using the callback to change the calculation's behavior, for example:

- stop the loop early
- change the running value
- skip a hook the callback doesn't define (e.g. a callback with no `before_calc`)

In [66]:
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        if cb and hasattr(cb,'before_calc'): cb.before_calc(i)
        res += i*i
        sleep(1)
        if cb and hasattr(cb,'after_calc'):
            if cb.after_calc(i, res):
                print("stopping early")
                break
    return res

In [67]:
class PrintAfterCallback():
    def after_calc (self, epoch, val):
        print(f"After {epoch}: {val}")
        if val>10:
            return True

In [68]:
slow_calculation(PrintAfterCallback())

After 0: 0
After 1: 1
After 2: 5
After 3: 14
stopping early


14
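The `hasattr` checks and the truthiness test on `after_calc`'s return value give two behaviors worth checking separately: a callback may define only some of the hooks, and a hook that returns nothing (`None`) lets the loop run to the end. Here is a sketch using a sleep-free copy of the function above and two hypothetical callbacks.

```python
def fast_calculation(cb=None):
    # same logic as the hasattr-based slow_calculation above, without sleep(1)
    res = 0
    for i in range(5):
        if cb and hasattr(cb, 'before_calc'): cb.before_calc(i)
        res += i * i
        if cb and hasattr(cb, 'after_calc'):
            if cb.after_calc(i, res):
                print("stopping early")
                break
    return res

class OnlyBeforeCallback:
    """Hypothetical callback defining before_calc only; the after_calc hook is skipped."""
    def __init__(self):
        self.started = []

    def before_calc(self, epoch):
        self.started.append(epoch)

class NeverStopCallback:
    """Hypothetical callback whose after_calc returns None (falsy), so the loop never stops."""
    def after_calc(self, epoch, val):
        pass

only_before = OnlyBeforeCallback()
res_before = fast_calculation(only_before)
res_never = fast_calculation(NeverStopCallback())
print(res_before, only_before.started)  # 30 [0, 1, 2, 3, 4]
print(res_never)                        # 30
```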

Now we turn our main function into a class, so it can retain state (here, `res`) as attributes. Because the calculator passes itself to each callback, a callback can reach into that state and change some of the values.

In [71]:
class SlowCalculator():
    def __init__(self, cb=None):
        self.cb,self.res = cb,0
    
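    def callback(self, cb_name, *args):
        
        # no callback object was supplied at all
        if not self.cb:
            return
        
        # look up the named hook; None if this callback doesn't define it (e.g. no before_calc)
        cb = getattr(self.cb, cb_name, None)
        
        if cb:
            # pass the calculator itself so the callback can read and modify its state,
            # and return whatever the hook returns (truthy means "stop early")
            return cb(self, *args)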

    def calc(self):
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i*i
            sleep(1)
            if self.callback('after_calc', i):
                print("stopping early")
                break

In [72]:
class ModifyingCallback():
    def after_calc (self, calc, epoch):
        print(f"After {epoch}: {calc.res}")

        if calc.res>10:
            return True

        if calc.res<3: 
            calc.res = calc.res*2

In [73]:
calculator = SlowCalculator(ModifyingCallback())

In [74]:
calculator.calc()
calculator.res

After 0: 0
After 1: 1
After 2: 6
After 3: 15
stopping early


15
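The same dispatch idea extends naturally to several callbacks at once. Here is a sketch of a hypothetical `MultiCalculator` (sleep removed) that keeps a list of callbacks, calls each hook in list order, and stops if any of them returns a truthy value. Splitting `ModifyingCallback` into two hypothetical callbacks shows that order matters: `StopAbove10` runs after `DoubleSmall`, so it sees the already-doubled values.

```python
class MultiCalculator:
    """Hypothetical extension of SlowCalculator: holds a list of callbacks, no sleep."""
    def __init__(self, cbs=None):
        self.cbs, self.res = list(cbs or []), 0

    def callback(self, cb_name, *args):
        stop = False
        for cb in self.cbs:  # hooks run in list order
            method = getattr(cb, cb_name, None)
            if method and method(self, *args):
                stop = True  # any truthy return requests a stop
        return stop

    def calc(self):
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i * i
            if self.callback('after_calc', i):
                print("stopping early")
                break

class DoubleSmall:
    def after_calc(self, calc, epoch):
        if calc.res < 3:
            calc.res = calc.res * 2

class StopAbove10:
    def __init__(self):
        self.seen = []

    def after_calc(self, calc, epoch):
        self.seen.append(calc.res)
        return calc.res > 10

stopper = StopAbove10()
multi_calc = MultiCalculator([DoubleSmall(), stopper])
multi_calc.calc()                   # prints "stopping early"
print(stopper.seen, multi_calc.res)  # [0, 2, 6, 15] 15
```

The final `res` of 15 matches `ModifyingCallback` above; the only visible difference is that the stopper records 2 at epoch 1, because the doubling has already happened when it runs.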