## Modal test

This notebook uses Modal to run code remotely. Before running it, you need to authenticate with Modal:

```bash
uv run modal setup
```

Then restart the notebook kernel.

## Remote function

Let's define a mock training function that will run remotely. It just loops a few times and returns a stub model function.

A [distributed `Queue`](https://modal.com/docs/reference/modal.Queue) is used to send progress information back. You can push rich data (like actual Matplotlib figures) onto the queue, and it transparently handles serialization, but let's keep it simple: send plain metrics and leave the display to the caller.

We specify only the packages that we'll need in the image, to keep it small. Version specifiers are needed (see [`freeze`](src/ai_eggs/requirements.py)) so that:
- The remote function behaves exactly as it would locally.
- Objects can be pickled and sent back and forth.

Since the queue reaches the function through the `emit_metrics` parameter, we add `modal` itself as a dependency so that it can be unpickled remotely. In this example, we also add the local package `ai_eggs`. It's not strictly needed for the function to run, but Modal currently auto-mounts it and issues a warning, so we may as well be explicit about it.
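
As a rough illustration of the version specifiers mentioned above, a helper like `freeze` could look up the locally installed version of each package and emit an exact `name==version` string. This is only a sketch under that assumption: `freeze_sketch` and its `get_version` parameter are hypothetical names, and the real `ai_eggs.requirements.freeze` may work differently.

```python
from collections.abc import Callable
from importlib.metadata import version


def freeze_sketch(
    *packages: str,
    get_version: Callable[[str], str] = version,
) -> list[str]:
    """Pin each package to the locally installed version (hypothetical helper)."""
    # importlib.metadata.version raises PackageNotFoundError if a package
    # is not installed in the current environment
    return [f"{name}=={get_version(name)}" for name in packages]
```

Calling `freeze_sketch('modal')` would return a one-element list pinning `modal` to whatever version is installed locally. The `get_version` argument is injectable only so that the sketch can be tried without that package installed.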


In [None]:
from time import sleep
from collections.abc import Callable
import modal

from ai_eggs.requirements import freeze

app = modal.App()


@app.function(
    image=(
        modal.Image
        .debian_slim()
        # 'modal' is needed to unpickle the Queue
        .pip_install(freeze('modal'))
        .add_local_python_source('ai_eggs')
    ),
    # gpu="T4",
)
def train(epochs: int, emit_metrics: Callable[[dict], None]):
    for i in range(epochs):
        # Mock loss that decays as 1, 1/2, 1/3, ...
        emit_metrics({"epoch": i, "loss": 1 / (i + 1)})
        sleep(0.5)

    def stub_model(x):
        if x == "What is the meaning of life?":
            return "Forty-two."
        else:
            return "I don't know that!"
    return stub_model

## Display progress locally

Let's get some metrics displayed right here in the notebook! We'll define a function to chart metrics. This function will be called several times during training, and it should update the chart each time.

In [None]:
from IPython.display import display
import matplotlib.pyplot as plt


def plot_loss(losses: list[float]):
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.set_title('Training progress')
    ax.set_xlabel('Epoch')
    ax.set_ybound(0, 1)
    ax.plot(range(len(losses)), losses, label='Loss')
    ax.legend()
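    # Close the figure so the notebook doesn't also render it automatically;
    # the returned object can still be displayed explicitly
    plt.close(fig)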
    return fig


def Progress():
    losses: list[float] = []
    # Use a display_id to update the plot without clearing the whole cell output
    display_id = display(plot_loss(losses), display_id=True)

    def progress(message: dict[str, float | int]):
        losses.append(message['loss'])
        display_id.update(plot_loss(losses))

    return progress

## Run remotely and display progress locally

Now let's run the training code remotely.

If we only cared about the final result, or if we were happy just printing progress to stdout, we could call `train` synchronously. But by calling it asynchronously, we can chart the metrics _while it runs_. The progress function is wrapped in a context manager that encapsulates communication with a `modal.Queue`.
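
To make the context-manager pattern concrete, here is a minimal local sketch of what a helper like `send_to` might do. It is not the actual `ai_eggs.comms.send_to` implementation: `send_to_sketch`, `_DONE`, and `demo` are hypothetical names, and an in-process `asyncio.Queue` stands in for the distributed `modal.Queue`.

```python
import asyncio
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager

_DONE = object()  # sentinel telling the consumer to stop


@asynccontextmanager
async def send_to_sketch(
    handler: Callable[[dict], None],
) -> AsyncIterator[Callable[[dict], None]]:
    """Yield an emit function; forward each emitted message to `handler`."""
    # Stand-in for modal.Queue: an in-process queue, not a distributed one
    q: asyncio.Queue = asyncio.Queue()

    async def consume() -> None:
        # Hand messages to the handler in arrival order until the sentinel
        while (item := await q.get()) is not _DONE:
            handler(item)

    consumer = asyncio.create_task(consume())
    try:
        yield q.put_nowait
    finally:
        # Let the consumer drain any remaining messages, then stop it
        q.put_nowait(_DONE)
        await consumer


async def demo() -> list[dict]:
    received: list[dict] = []
    async with send_to_sketch(received.append) as emit:
        emit({"epoch": 0, "loss": 1.0})
        emit({"epoch": 1, "loss": 0.5})
    return received
```

In the real setup the producer runs in a remote container, so the queue has to be distributed and the emit function has to survive pickling. This in-process version only shows the shape of the pattern: producers call `emit`, and the handler runs on the caller's side.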

In [None]:
from ai_eggs.comms import send_to

async with app.run(), send_to(Progress()) as emit_metrics:
    model = await train.remote.aio(5, emit_metrics)

The model function was created remotely, serialized, and sent back. Now we can run it locally!

In [None]:
x = "What is the meaning of life?"
print(f"Q: {x}\nA: {model(x)}")

x = "What is the airspeed velocity of an unladen swallow?"
print(f"Q: {x}\nA: {model(x)}")