## Modal demo

This notebook uses Modal to run code remotely. Before running it, you need to authenticate with Modal:

```bash
uv run modal setup
```

Then restart the notebook kernel.

### Local metrics

Let's get some metrics displayed right here in the notebook! We'll define a function that draws a loss chart, plus a handler that redraws the chart in place each time new metrics arrive during training.

In [None]:
from dataclasses import dataclass
from IPython.display import display
import matplotlib.pyplot as plt

from utils.modal import SyncHandler


@dataclass
class Metrics:
    epoch: int
    loss: float


def plot_history(history: list[Metrics]):
    xs = [h.epoch for h in history]
    ys = [h.loss for h in history]

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.set_title('Training progress')
    ax.set_xlabel('Epoch')
    ax.set_ybound(0, 1)
    ax.plot(xs, ys, label='Loss')
    ax.legend()
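    # Close the figure so the inline backend doesn't also render it automatically;
    # the caller displays the returned figure explicitly.
    plt.close(fig)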
    return fig


def progress() -> SyncHandler[Metrics]:
    history: list[Metrics] = []
    # With display_id=True, display() returns a handle that can update the output in place.
    display_id = display(plot_history(history), display_id=True)

    def receive(metrics: list[Metrics]):
        history.extend(metrics)
        display_id.update(plot_history(history))

    return receive

### Remote function

Here we define a mock training function that will run remotely. It just loops a few times and returns a stub model function.

We list only the packages that we'll need in the image, to keep it small. They are pinned to specific versions (see [`freeze`](src/dair/requirements.py)) so that:
- The remote function behaves exactly as it would locally.
- Objects can be pickled and sent back and forth.

We'll add `modal` itself as a dependency, because it's used by `emit_metrics` (see _Training_ below).
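The `freeze` helper itself isn't shown in this notebook. As a rough illustration of the idea, here is a minimal sketch of how such a helper could pin packages to the locally installed versions. The name `pin_installed`, its `lookup` parameter, and the version numbers are hypothetical and only serve the example; the real `freeze` may work differently.

```python
from importlib.metadata import version
from typing import Callable


def pin_installed(*packages: str, lookup: Callable[[str], str] = version) -> list[str]:
    """Return 'name==version' specifiers for the given packages.

    By default, `lookup` is importlib.metadata.version, which reads the
    version of the distribution installed in the current environment and
    raises PackageNotFoundError if the package isn't installed.
    """
    return [f'{name}=={lookup(name)}' for name in packages]


# Hypothetical versions, injected here only so the example runs anywhere.
fake_versions = {'modal': '1.0.0', 'matplotlib': '3.9.0'}
specs = pin_installed('modal', 'matplotlib', lookup=fake_versions.__getitem__)
print(specs)
```

Without the `lookup` argument, the specifiers would reflect whatever is installed locally, which is what keeps the remote image in step with the notebook's environment.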


In [None]:
from time import sleep
import modal

from utils.requirements import freeze
from utils.modal import SyncHandler

image = modal.Image.debian_slim().pip_install(freeze('modal', 'matplotlib')).add_local_python_source('utils')

app = modal.App()


@app.function(image=image, gpu=None)
def train(epochs: int, emit_metrics: SyncHandler[Metrics]):
    print('Training...')

    for i in range(epochs):
        emit_metrics([Metrics(epoch=i + 1, loss=1 / (i + 1))])
        sleep(0.2)

    def stub_model(x):
        if x == 'What is your quest?':
            return 'To seek the Holy Grail.'
        elif x == 'What is the air-speed velocity of an unladen swallow?':
            return 'What do you mean? An African or European swallow?'
        else:
            return "I don't know that!"

    print('Training complete')
    return stub_model

Next, we run a no-op function so that the image is built now, which keeps the build logs out of the next step's output. Currently Modal doesn't provide a way to separate the build logs from container stdout.

In [None]:
@app.function(image=image)
def prebuild():
    pass


with app.run():
    prebuild.remote()

### Training

Now let's run the training code remotely.

A [distributed `Queue`](https://modal.com/docs/reference/modal.Queue) is used to send progress information back during training. You can push rich data onto the queue (like actual Matplotlib figures) and serialization is handled transparently, but in this example only a simple dataclass is emitted. The progress function is wrapped in [`send_to`](src/dair/comms.py): a context manager that provides a simple interface over the queue.

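If we only cared about the final result, or if we were happy just printing progress to stdout, we could call `train` synchronously. Instead, we await it asynchronously, which leaves the notebook free to redraw the chart as metrics arrive.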
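To make the queue-plus-context-manager idea concrete, here is a local stand-in built on the standard library's `queue.Queue` and a consumer thread. It is not the actual `send_to` implementation and does not use Modal's distributed queue; the name `send_to_local` is hypothetical, and no batching is attempted.

```python
import queue
import threading
from contextlib import contextmanager
from typing import Callable, Iterator

_DONE = object()  # sentinel that tells the consumer to stop


@contextmanager
def send_to_local(handler: Callable[[list], None]) -> Iterator[Callable[[list], None]]:
    """Yield an `emit` function; emitted items are passed to `handler` on a consumer thread."""
    q: queue.Queue = queue.Queue()

    def consume() -> None:
        # Deliver items in the order they were emitted until the sentinel arrives.
        while True:
            item = q.get()
            if item is _DONE:
                break
            handler(item)

    worker = threading.Thread(target=consume, daemon=True)
    worker.start()
    try:
        yield q.put  # the producer only ever sees a plain callable
    finally:
        # Signal the end of the stream and wait for pending items to be handled.
        q.put(_DONE)
        worker.join()


received: list[int] = []
with send_to_local(received.extend) as emit:
    for epoch in range(3):
        emit([epoch + 1])

print(received)
```

The producer only needs the `emit` callable, which mirrors how `train` receives `emit_metrics` without knowing anything about the chart on the other end.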

In [None]:
from utils.modal import run, send_to

async with run(app), send_to.batch(progress()) as emit_metrics:
    model = await train.remote.aio(20, emit_metrics)

### Testing

The model was created remotely, serialized, and sent back. Now we can run it locally!

In [None]:
from textwrap import dedent

x = 'What is your quest?'
print(
    dedent(f"""
    {x}
    {model(x)}
    """).strip()
)

x = 'What is the air-speed velocity of an unladen swallow?'
print(
    dedent(f"""
    {x}
    {model(x)}
    {model(model(x))}
    """).strip()
)