## Modal test

This notebook uses Modal to run code remotely. Before running this notebook, you need to authenticate:

```bash
uv run modal setup
```

Then restart the notebook kernel.

## Remote function

Let's define a mock training function that will run remotely. It just loops a few times, emitting a made-up loss on each iteration, and then returns a stub model function.

A [distributed `Queue`](https://modal.com/docs/reference/modal.Queue) is used to send progress information back. You can push rich data (such as actual Matplotlib figures) onto the queue, and it handles serialization transparently. To keep things simple, though, we send only plain metrics dicts and leave the display to the caller.
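
The division of labour described above (the remote side pushes plain data, the caller decides how to show it) can be sketched with the standard library alone. This is a local analogue under stated assumptions, not how `modal.Queue` or `ai_eggs.comms` is implemented: a thread stands in for the remote container, `queue.Queue` stands in for `modal.Queue`, and the `DONE` sentinel and both function names are hypothetical.

```python
import queue
import threading
from typing import Callable

# End-of-stream marker: a convention of this sketch, not a Modal API.
DONE = object()


def fake_remote_train(epochs: int, q: queue.Queue) -> str:
    """Stand-in for the remote function: push plain dicts, then return a result."""
    for i in range(epochs):
        q.put({"epoch": i, "loss": 1 / (i + 1)})
    q.put(DONE)
    return "model"


def run_with_progress(epochs: int, on_message: Callable[[dict], None]) -> str:
    q: queue.Queue = queue.Queue()
    results: list[str] = []
    # The thread plays the role of the remote container.
    worker = threading.Thread(
        target=lambda: results.append(fake_remote_train(epochs, q))
    )
    worker.start()
    # The caller drains the queue and decides how to display each message.
    while (msg := q.get()) is not DONE:
        on_message(msg)
    worker.join()
    return results[0]


seen: list[dict] = []
print(run_with_progress(3, seen.append))  # model
print([m["loss"] for m in seen])  # [1.0, 0.5, 0.3333333333333333]
```

Keeping the messages as plain dicts means the producer never needs to know whether the consumer prints, plots, or logs them.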

We list only the packages the image actually needs, to keep it small, and pin their exact versions (see [`freeze`](src/ai_eggs/requirements.py)) so that:
- The remote function behaves exactly as it would locally.
- Objects can be pickled and sent back and forth.
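
The linked `freeze` helper isn't reproduced here. One plausible shape, assuming it simply pins each named package to the version installed in the local environment via `importlib.metadata`, is the hypothetical `freeze_sketch` below; the real helper may differ (for example, it might also pin dependencies).

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable


def freeze_sketch(*names: str, lookup: Callable[[str], str] = version) -> list[str]:
    """Hypothetical stand-in for `freeze`: pin each package to its locally installed version."""
    pins = []
    for name in names:
        try:
            pins.append(f"{name}=={lookup(name)}")
        except PackageNotFoundError as err:
            raise RuntimeError(f"{name!r} is not installed locally, so it can't be pinned") from err
    return pins


# `lookup` is injectable so the sketch can be tried without `modal` installed.
print(freeze_sketch("modal", lookup=lambda name: "1.0.0"))  # ['modal==1.0.0']
```

Reading the version from the local environment is what keeps the remote image in lockstep with the notebook kernel.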

Because the queue is passed in as a parameter and has to be unpickled in the container, `modal` itself is a dependency too. In this example, we also add the local package `ai_eggs`. The function doesn't strictly need it, but Modal currently auto-mounts it anyway and issues a warning about doing so, so we may as well make that explicit.
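
Why the container needs `modal` installed just to receive a queue: pickle stores a *reference* to an object's class (its module and qualified name), not the class's code, so unpickling has to import that module. The standard-library sketch below shows this with a hypothetical module `fake_lib` standing in for `modal`; Modal's own serialization may differ in detail.

```python
import pickle
import sys
import types

# Register a hypothetical module, standing in for a library such as `modal`.
fake_lib = types.ModuleType("fake_lib")
sys.modules["fake_lib"] = fake_lib


class Handle:
    """Stand-in for a library object, such as a queue handle."""

    def __init__(self, name: str):
        self.name = name


# Make the class look as if it were defined in `fake_lib`.
Handle.__module__ = "fake_lib"
Handle.__qualname__ = "Handle"
fake_lib.Handle = Handle

payload = pickle.dumps(Handle("progress"))
print(b"fake_lib" in payload)  # True: only a reference to the class is stored

# Simulate a receiver whose environment lacks the library.
del sys.modules["fake_lib"]
try:
    pickle.loads(payload)
except ModuleNotFoundError as err:
    print(err)  # No module named 'fake_lib'
```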


```python
from time import sleep
from typing import Callable

import modal

from ai_eggs.requirements import freeze

app = modal.App()


@app.function(
    image=(
        modal.Image
        .debian_slim()
        # 'modal' is needed to unpickle the Queue
        .pip_install(freeze('modal'))
        .add_local_python_source('ai_eggs')
    ),
    # gpu="T4",
)
def train(epochs: int, emit_metrics: Callable[[dict], None]):
    # A plain (sync) function: inside an `async def`, the blocking sleep and
    # emit_metrics calls would stall the event loop. The caller can still
    # await it with `train.remote.aio(...)`.
    for i in range(epochs):
        emit_metrics({"epoch": i, "loss": 1 / (i + 1)})
        sleep(0.5)

    def stub_model(x):
        return f"model({x})"

    return stub_model
```
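
Because `emit_metrics` is just a callable that takes a dict, the loop can be smoke-tested locally before starting a remote container. The sketch below copies the body of `train` into an undecorated, hypothetical `train_locally` (with a hypothetical `delay` parameter) and uses `list.append` as the callback.

```python
from time import sleep
from typing import Callable


def train_locally(epochs: int, emit_metrics: Callable[[dict], None], delay: float = 0.0):
    """Same loop as `train`, minus the Modal decorator."""
    for i in range(epochs):
        emit_metrics({"epoch": i, "loss": 1 / (i + 1)})
        sleep(delay)

    def stub_model(x):
        return f"model({x})"

    return stub_model


seen: list[dict] = []
model = train_locally(3, seen.append)
print([m["loss"] for m in seen])  # [1.0, 0.5, 0.3333333333333333]
print(model(42))  # model(42)
```

Modal functions also expose `.local(...)` for calling the decorated function in-process, which avoids duplicating the loop body.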

## Call and display progress locally

Now let's run that code remotely.

If we only cared about the final result, or were happy just printing progress to stdout, we could call `train.remote(...)` synchronously. Calling it asynchronously with `.remote.aio(...)` instead leaves the notebook free to process progress messages, so we can chart the metrics while the function runs.
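
The benefit of awaiting can be sketched with plain `asyncio`. Here the hypothetical `fake_remote_call` stands in for `train.remote.aio(...)` and an `asyncio.Queue` stands in for the real channel; while the call is awaited, a consumer task runs concurrently and records each loss as it arrives.

```python
import asyncio
import contextlib
from typing import Callable


async def fake_remote_call(epochs: int, emit: Callable[[dict], None]) -> str:
    """Stand-in for `train.remote.aio(...)`: each await hands control back to the loop."""
    for i in range(epochs):
        emit({"epoch": i, "loss": 1 / (i + 1)})
        await asyncio.sleep(0.01)
    return "model"


async def run(epochs: int) -> tuple[str, list[float]]:
    q: asyncio.Queue = asyncio.Queue()
    losses: list[float] = []

    async def consume() -> None:
        # Runs concurrently with the awaited call, like the chart updates in the notebook.
        while True:
            message = await q.get()
            losses.append(message["loss"])
            q.task_done()

    consumer = asyncio.create_task(consume())
    result = await fake_remote_call(epochs, q.put_nowait)
    await q.join()  # let the consumer finish any messages still queued
    consumer.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await consumer
    return result, losses


print(asyncio.run(run(3)))  # ('model', [1.0, 0.5, 0.3333333333333333])
```

If `fake_remote_call` used a blocking `time.sleep` instead of `await asyncio.sleep`, the consumer would not get a turn until the call returned, so every update would be processed at once at the end.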

```python
from IPython.display import clear_output, display
import matplotlib.pyplot as plt

from ai_eggs.comms import simple_comms


losses: list[float] = []


def progress(message: dict[str, float | int]):
    # Receives each metrics dict sent from the remote function and redraws the chart.
    losses.append(message['loss'])
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(range(len(losses)), losses)
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    display(fig)
    plt.close(fig)


# Jupyter allows top-level `async with` and `await` in cells.
async with app.run(), simple_comms(progress) as emit_metrics:
    model = await train.remote.aio(5, emit_metrics)

print(model(42))
```