## Getting started

This notebook uses Modal to run code remotely. Before running this notebook, you need to authenticate:

```bash
./go auth
```

Then restart the notebook kernel.

### Logging

We'll start by configuring logging, using the same config both locally and remotely. The local config is applied in the cell below; the remote config is applied by a `before_each` hook that is registered in the Experiment section.

In [None]:
import logging
from utils.logging import SimpleLoggingConfig

logging_config = SimpleLoggingConfig().info('notebook', 'utils', 'mini')
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger('notebook')
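
`SimpleLoggingConfig` is a project helper, and its internals aren't shown in this notebook. As a rough idea of what "INFO for a few named loggers" can look like with only the standard library, here is a minimal sketch. The helper name `apply_info_levels` is hypothetical and is not part of `utils.logging`.

```python
import logging


def apply_info_levels(*names: str) -> None:
    """Hypothetical helper: enable INFO messages for the named loggers only."""
    # Attach a handler to the root logger. The root level stays at WARNING,
    # so loggers that are not named below remain quiet.
    logging.basicConfig(
        level=logging.WARNING,
        format='%(name)s %(levelname)s %(message)s',
    )
    for name in names:
        # Child loggers such as 'mini.experiment' inherit this level.
        logging.getLogger(name).setLevel(logging.INFO)


apply_info_levels('notebook', 'utils', 'mini')
logging.getLogger('notebook').info('Logging is configured')
```

Because the function takes no state from the notebook, the same call could be made again in a remote process to get matching output there.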

### Experiment

An [Experiment](../src/mini/experiment.py) is a specialised Modal app. It simplifies running a mixture of local and remote code.

We need to tell it which libraries to install remotely.

In [None]:
import modal

import mini
from utils.requirements import freeze, project_packages

run = mini.Experiment('demo')
run.image = (
    modal.Image.debian_slim()
    .pip_install(freeze('modal'))
    .add_local_python_source(*project_packages())
)  # fmt: skip

_ = run.before_each(logging_config.apply)
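
How `mini.Experiment` implements `before_each` isn't shown here. The general pattern is simple, though: keep a list of zero-argument hooks and run them before each function call. The sketch below illustrates that pattern under that assumption; the `HookedRunner` class and its `call` method are hypothetical stand-ins, not the real API.

```python
from typing import Any, Callable


class HookedRunner:
    """Hypothetical stand-in for an object that supports `before_each` hooks."""

    def __init__(self) -> None:
        self._hooks: list[Callable[[], None]] = []

    def before_each(self, hook: Callable[[], None]) -> Callable[[], None]:
        # Register the hook and return it, so this also works as a decorator.
        self._hooks.append(hook)
        return hook

    def call(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # Run every registered hook, in registration order, then the function.
        for hook in self._hooks:
            hook()
        return fn(*args, **kwargs)


events: list[str] = []
runner = HookedRunner()
_ = runner.before_each(lambda: events.append('configure logging'))
result = runner.call(lambda x: x * 2, 21)
```

Registering `logging_config.apply` as a hook in this way means the remote process gets configured without the training function having to know about logging.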

Sometimes a stale image build can cause issues. If a `@run.thither` call gets stuck, you can force a rebuild by uncommenting the line below and re-running the cell.

In [None]:
# run.image.force_build = True

### Local metrics with `@run.hither`

Let's get some metrics displayed right here in the notebook! We'll define a function to draw a loss chart. This function will be called several times during training, and it should update the chart each time.

The [`@run.hither`](../src/mini/experiment.py) decorator causes the `track` function to run locally, even when it is called from a remote function. Run-hither functions can't return anything, but they can take any picklable parameters.

In [None]:
from dataclasses import dataclass


@dataclass
class Metrics:
    epoch: int
    loss: float


@run.hither
def track() -> mini.AsyncCallback[Metrics]:
    # This is a factory that returns a function that always runs locally.
    from utils.nb import displayer

    history: list[Metrics] = []
    display = displayer()

    async def _track(metrics: Metrics):
        # This is the function that runs locally.
        history.append(metrics)
        fig = plot_history(history)
        display(fig)

    return _track


def plot_history(history: list[Metrics]):
    import matplotlib.pyplot as plt

    xs = [h.epoch for h in history]
    ys = [h.loss for h in history]

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.set_title('Training progress')
    ax.set_xlabel('Epoch')
    ax.set_ybound(0, 1)
    ax.plot(xs, ys, label='Loss')
    ax.legend()
    plt.close(fig)
    return fig

### Remote functions with `@run.thither`

Here we define a mock training function that will run remotely. It just loops a few times and returns a stub model function.

We specify only the packages that we'll need, to keep the image small. Version specifiers are needed (see [`freeze`](../src/dair/requirements.py)) so that:
- The remote function behaves exactly as it would locally.
- Objects can be pickled and sent back and forth.
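
To make the idea of version specifiers concrete, here is a minimal sketch of turning package names into exact pins with the standard library's `importlib.metadata.version`. The `pin` helper is hypothetical and is not the project's `freeze` function, which may work differently. The version numbers in the example are made up so that it runs without any packages installed.

```python
from importlib import metadata
from typing import Callable, Iterable


def pin(
    names: Iterable[str],
    get_version: Callable[[str], str] = metadata.version,
) -> list[str]:
    """Hypothetical helper: turn package names into `name==version` pins."""
    return [f'{name}=={get_version(name)}' for name in names]


# Fake version lookup (made-up numbers) used only for this illustration.
fake_versions = {'modal': '0.0.1', 'matplotlib': '0.0.2'}
pins = pin(['modal', 'matplotlib'], get_version=fake_versions.__getitem__)
print(pins)
```

With the default `get_version`, the pins would reflect whatever is installed in the local environment, which is what keeps the local and remote sides in step.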

The `track` function called in the training loop is the plotting function defined above! It's passed in as a parameter because it needs to be hooked up to a queue, which we'll see in the next step.

In [None]:
from mini import Callback


@run.thither(gpu=None)
async def train(epochs: int, track: Callback[Metrics]):
    # This is the function that runs in the cloud.
    from asyncio import sleep

    log.info('Training...')

    for i in range(epochs):
        track(Metrics(epoch=i + 1, loss=1 / (i + 1)))
        # Non-blocking sleep, so the event loop stays responsive.
        await sleep(0.2)

    def stub_model(x):
        if x == 'What is your quest?':
            return 'To seek the Holy Grail.'
        elif x == 'What is the air-speed velocity of an unladen swallow?':
            return 'What do you mean? An African or European swallow?'
        else:
            return "I don't know that!"

    log.info('Training complete')
    # Send the trained model back to the local machine.
    return stub_model

### Training

Now let's run the training code remotely.

Behind the scenes, a [distributed `Queue`](https://modal.com/docs/reference/modal.Queue) is used to send progress information back during training. You can push rich data onto the queue (like actual Matplotlib figures), and it transparently handles serialization. In this example, only a simple dataclass is emitted. For the calls to be routed through the queue, the run-hither function (`track`) must be used as a context manager.

The context object `track_stub` is passed to the training function. It's `track_stub` that sends the calls back to the local machine, where they are executed in the real `track` function.

In [None]:
# Create a stub to send to the remote function
async with run(), track() as track_stub:
    # Call the remote function
    model = await train(20, track_stub)
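
The forwarding mechanism described above can be sketched without Modal. In this illustration an in-process `asyncio.Queue` stands in for the distributed queue, and all of the helper names (`real_track`, `consume`, `DONE`) are hypothetical. It shows only the shape of the idea, not how `mini` implements it: the stub enqueues each call, and a local consumer replays it on the real function.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Metrics:
    epoch: int
    loss: float


async def main() -> list[Metrics]:
    queue: asyncio.Queue = asyncio.Queue()  # stand-in for a distributed queue
    received: list[Metrics] = []
    DONE = object()  # sentinel that tells the consumer to stop

    async def real_track(metrics: Metrics) -> None:
        # The "local" function; in the notebook this would redraw the chart.
        received.append(metrics)

    def track_stub(metrics: Metrics) -> None:
        # The "remote" side only enqueues the call; it never runs real_track.
        queue.put_nowait(metrics)

    async def consume() -> None:
        # Local consumer: replay each queued call on the real function.
        while (item := await queue.get()) is not DONE:
            await real_track(item)

    async def train(epochs: int, track) -> str:
        for i in range(epochs):
            track(Metrics(epoch=i + 1, loss=1 / (i + 1)))
            await asyncio.sleep(0)  # yield so the consumer can run
        return 'model'

    consumer = asyncio.create_task(consume())
    await train(3, track_stub)
    queue.put_nowait(DONE)  # comparable to leaving the context manager
    await consumer
    return received


history = asyncio.run(main())
```

In a notebook cell, where an event loop is already running, use `history = await main()` instead of `asyncio.run(main())`.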

### Testing

The model was created remotely, serialized, and sent back. Now we can run it locally!

In [None]:
from textwrap import dedent

x = 'What is your quest?'
print(
    dedent(f"""
    {x}
    {model(x)}
    """).strip()
)

x = 'What is the air-speed velocity of an unladen swallow?'
print(
    dedent(f"""
    {x}
    {model(x)}
    {model(model(x))}
    """).strip()
)