# Spotiflow: training your own model in 3D

**NOTE**: visualizing the data in this notebook requires `napari` (optional but recommended). You can install it, e.g., via `pip install "napari[all]"` (see [the instructions](https://napari.org/stable/tutorials/fundamentals/installation.html) if you run into any issues).

Let's first import the libraries we need to detect spots in our volumes.

In [None]:
from spotiflow.model import Spotiflow, SpotiflowModelConfig
from spotiflow.sample_data import load_dataset
from spotiflow.utils import get_data

As in the 2D case, we first load our dataset. We will use the `synth_3d` dataset (called `synthetic-3d` in the paper), which is a good starting point if you later want to fine-tune on your own data. If you have your own annotated data, you can load it into six variables: the training images and spots, the validation images and spots, and the test images and spots. You can use the `get_data()` function for that (please [see the docs](https://weigertlab.github.io/spotiflow) to check which data format the function expects).

In [None]:
trX, trY, valX, valY, testX, testY = load_dataset("synth_3d", include_test=True)
# trX, trY, valX, valY, testX, testY = get_data("/FOLDER/WITH/DATA", include_test=True, is_3d=True) # with test data
# trX, trY, valX, valY = get_data("/FOLDER/WITH/DATA", include_test=False, is_3d=True) # without test data
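If you prepare your own data, it helps to know what these variables hold. The sketch below assumes (consistent with how the volumes and spots are overlaid in `napari` below) that each volume is a 3D array ordered `(z, y, x)` and that each annotation is an `(N, 3)` array of spot coordinates in the same axis order. `make_toy_volume` is a hypothetical helper, not part of Spotiflow; it only renders a small synthetic volume with Gaussian spots so you can check shapes and dtypes against your own data.

```python
import numpy as np

def make_toy_volume(shape=(16, 64, 64), n_spots=5, sigma=1.0, seed=0):
    """Hypothetical helper: render Gaussian spots into a (z, y, x) volume.

    Returns the volume (float32) and an (N, 3) array of spot centers
    in the same (z, y, x) order as the volume axes.
    """
    rng = np.random.default_rng(seed)
    # Random sub-pixel spot centers, kept away from the borders
    lo = np.full(3, 2.0)
    hi = np.array(shape, dtype=float) - 3.0
    points = rng.uniform(lo, hi, size=(n_spots, 3)).astype(np.float32)

    # One coordinate grid per axis, each with the full volume shape
    zz, yy, xx = np.meshgrid(*(np.arange(s) for s in shape), indexing="ij")
    volume = np.zeros(shape, dtype=np.float32)
    for z, y, x in points:
        d2 = (zz - z) ** 2 + (yy - y) ** 2 + (xx - x) ** 2
        volume += np.exp(-d2 / (2 * sigma**2))
    return volume, points

vol, pts = make_toy_volume()
print(vol.shape, vol.dtype)  # (16, 64, 64) float32
print(pts.shape)             # (5, 3)
```

If your own `trX[i]` and `trY[i]` follow the same layout (one 3D array per volume, one `(N, 3)` coordinate array per volume), they should line up when overlaid.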

The first two variables contain the training images and annotations, the third and fourth the validation ones, and the last two the test ones (if `include_test=True` was given). While 2D images are easy to display in Python, 3D volumes are harder to inspect, so we will use the `napari` library to visualize them. If you don't have it installed, see the note at the top of the notebook. Without `napari`, the cell below simply skips the visualization, and you can still run the rest of the notebook.

In [None]:
try:
    import napari
    viewer = napari.Viewer(ndisplay=3)  # open the viewer in 3D rendering mode
    viewer.add_image(trX[0], name="Training volume", colormap="gray")
    viewer.add_points(trY[0], name="Training spots", face_color="orange", edge_color="orange", size=5, symbol="ring")
except ImportError:
    # Any other error is raised as usual
    print("napari not installed, skipping visualization")
    viewer = None
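If `napari` is not available, a quick fallback is a maximum-intensity projection (MIP): take the maximum along one axis of the volume and drop the same axis from the spot coordinates, so both can be overlaid in an ordinary 2D plot. `max_projection_with_points` is a hypothetical helper written for this notebook, and the `matplotlib` lines are only shown as a commented-out usage example.

```python
import numpy as np

def max_projection_with_points(volume, points, axis=0):
    """Project a (z, y, x) volume and its spots onto a 2D plane.

    Returns the max-intensity projection along `axis` and the spot
    coordinates with that axis dropped, so both can be overlaid in 2D.
    """
    volume = np.asarray(volume)
    points = np.asarray(points)
    mip = volume.max(axis=axis)
    points_2d = np.delete(points, axis, axis=1)
    return mip, points_2d

# Tiny self-check: a single bright voxel at (z, y, x) = (2, 1, 3)
vol = np.zeros((4, 5, 6))
vol[2, 1, 3] = 1.0
mip, pts2d = max_projection_with_points(vol, [[2, 1, 3]])
print(mip.shape, pts2d)  # (5, 6) [[1 3]]

# Usage with matplotlib (columns are x, rows are y):
# import matplotlib.pyplot as plt
# mip, pts2d = max_projection_with_points(trX[0], trY[0])
# plt.imshow(mip, cmap="gray")
# plt.scatter(pts2d[:, 1], pts2d[:, 0], s=10, facecolors="none", edgecolors="orange")
# plt.show()
```

A MIP hides depth information, so spots stacked along `z` will overlap in the plot; it is only meant as a quick sanity check.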

Training with the default model configuration is straightforward, although slightly less so than in the 2D case: we first need to instantiate a 3D model configuration (check [the documentation](https://weigertlab.github.io/spotiflow) for more information about other options):

In [None]:
config = SpotiflowModelConfig(
    is_3d=True, # 3D model
    grid=(2, 2, 2), # predict on a downsampled grid, this is the value used in the paper
)
model = Spotiflow(config=config)
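To get a feel for what `grid=(2, 2, 2)` implies, here is a back-of-the-envelope sketch. It assumes that the model predicts on an array whose size along each axis is the input size divided by the corresponding grid factor, so that a spot at input coordinate p lands at roughly p / grid on the output grid. This is a simplification for intuition, not Spotiflow's internal code. The divisibility check is likewise our own precaution rather than a documented requirement; `grid_shape` and `to_grid_coords` are hypothetical helpers.

```python
import numpy as np

def grid_shape(input_shape, grid):
    """Shape of an output downsampled by `grid` along each axis.

    Assumes integer division per axis; raises if an axis is not divisible.
    """
    out = []
    for size, g in zip(input_shape, grid):
        if size % g != 0:
            raise ValueError(f"axis of size {size} is not divisible by grid factor {g}")
        out.append(size // g)
    return tuple(out)

def to_grid_coords(points, grid):
    """Map (z, y, x) input coordinates to coordinates on the downsampled grid."""
    return np.asarray(points, dtype=float) / np.asarray(grid, dtype=float)

# A training crop of depth 32 and lateral size 128 x 128
print(grid_shape((32, 128, 128), (2, 2, 2)))                 # (16, 64, 64)
print(to_grid_coords([[10.0, 40.0, 41.0]], (2, 2, 2)).tolist())  # [[5.0, 20.0, 20.5]]
```

With a grid of 2 per axis, the output has 8 times fewer voxels than the input, which is the main reason to predict on a downsampled grid in 3D.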

We can now train the model by calling `.fit()` after choosing where the model should be stored. As in the 2D case, you need to define the training parameters. If you want to change some values (_e.g._ the number of epochs), simply change the corresponding parameter (for more information, check [the documentation](https://weigertlab.github.io/spotiflow)):

In [None]:
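save_folder = "models/synth_3d" # change to where you want to store the model
train_config = {
    "num_epochs": 1, # 1 epoch keeps this demo short; train for longer to get a useful model
    "crop_size": 128, # lateral (YX) size of the training crops
    "crop_size_depth": 32, # axial (Z) size of the training crops
    "smart_crop": True, # bias the training crops towards regions with spots
}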
model.fit(
    trX,
    trY,
    valX,
    valY,
    save_dir=save_folder,
    train_config=train_config,
)
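The `crop_size` and `crop_size_depth` entries control the size of the random sub-volumes the model sees during training, and `smart_crop`, as its name suggests, favors crops that actually contain spots. The sketch below illustrates the cropping idea with a hypothetical `random_crop_3d` helper: it draws a `(crop_size_depth, crop_size, crop_size)` crop, shifts the spot coordinates into the crop's frame and drops spots that fall outside. Its "smart" mode, which simply retries until the crop contains at least one spot, is our own simplification and not Spotiflow's actual cropping strategy.

```python
import numpy as np

def random_crop_3d(volume, points, crop_size=128, crop_size_depth=32,
                   smart=False, max_tries=10, rng=None):
    """Hypothetical helper: random (z, y, x) crop of a volume and its spots.

    Returns the cropped volume and the spots inside it, expressed in the
    crop's coordinate frame. With smart=True, retries (up to max_tries)
    until the crop contains at least one spot.
    """
    rng = np.random.default_rng() if rng is None else rng
    points = np.asarray(points, dtype=float).reshape(-1, 3)
    size = np.array([crop_size_depth, crop_size, crop_size])
    if np.any(size > np.array(volume.shape)):
        raise ValueError("crop is larger than the volume")

    for _ in range(max(1, max_tries) if smart else 1):
        # Crop corner, chosen so the whole crop fits inside the volume
        start = np.array([rng.integers(0, s - c + 1) for s, c in zip(volume.shape, size)])
        shifted = points - start
        inside = np.all((shifted >= 0) & (shifted < size), axis=1)
        if not smart or inside.any():
            break

    crop = volume[tuple(slice(s, s + c) for s, c in zip(start, size))]
    return crop, shifted[inside]

vol = np.random.default_rng(0).random((40, 160, 160)).astype(np.float32)
pts = np.array([[20.0, 80.0, 80.0]])
crop, crop_pts = random_crop_3d(vol, pts, smart=True, rng=np.random.default_rng(1))
print(crop.shape)  # (32, 128, 128)
```

Training on crops keeps memory usage bounded regardless of the full volume size; retrying for spot-containing crops helps when spots are sparse and most random crops would otherwise be empty.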

Our model is now ready to be used! Let's first check the save folder to make sure the model was stored properly: it should contain two `.pt` files (`best.pt` and `last.pt`) as well as three `.yaml` configuration files.

In [None]:
!ls $save_folder

We can also quickly predict on a test volume that was not seen during training (this requires `include_test=True` earlier; see [the inference notebook](./2_inference.ipynb) for more information about prediction and model loading):

In [None]:
test_pred, _ = model.predict(testX[0], device="auto")

Let's now visualize the results with `napari` (if it is available):

In [None]:
if viewer is not None:
    while len(viewer.layers) > 0:
        viewer.layers.pop()
    viewer.add_image(testX[0], name="Test volume", colormap="gray")
    viewer.add_points(test_pred, name="Predicted test spots", face_color="orange", edge_color="orange", size=5, symbol="ring")
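Since `testY` holds the ground-truth spots for the test volumes, you can also put a rough number on the prediction quality. The sketch below uses a simple greedy matching (closest pairs first, each spot matched at most once, only within a distance `cutoff` in voxels) and reports precision, recall and F1. This is our own simplified metric, not the evaluation used in the Spotiflow paper; it ignores anisotropic voxel sizes, `cutoff=3.0` is an arbitrary illustrative value, and `match_spots` is a hypothetical helper.

```python
import numpy as np

def match_spots(pred, gt, cutoff=3.0):
    """Greedy one-to-one matching of predicted and ground-truth spots.

    Pairs are accepted in order of increasing Euclidean distance (in voxels)
    as long as both spots are still unmatched and the distance is <= cutoff.
    Returns true/false positive and false negative counts plus
    precision, recall and F1.
    """
    pred = np.asarray(pred, dtype=float).reshape(-1, 3)
    gt = np.asarray(gt, dtype=float).reshape(-1, 3)
    tp = 0
    if len(pred) and len(gt):
        # Pairwise distance matrix of shape (n_pred, n_gt)
        dist = np.linalg.norm(pred[:, None, :] - gt[None, :, :], axis=-1)
        used_pred = np.zeros(len(pred), dtype=bool)
        used_gt = np.zeros(len(gt), dtype=bool)
        # Visit all pairs from closest to farthest
        for i, j in zip(*np.unravel_index(np.argsort(dist, axis=None), dist.shape)):
            if dist[i, j] > cutoff:
                break
            if not used_pred[i] and not used_gt[j]:
                used_pred[i] = used_gt[j] = True
                tp += 1
    fp, fn = len(pred) - tp, len(gt) - tp
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return dict(tp=tp, fp=fp, fn=fn, precision=precision, recall=recall, f1=f1)

# Usage on the test volume predicted above:
# print(match_spots(test_pred, testY[0], cutoff=3.0))
```

Greedy matching is easy to follow but not guaranteed to find the best possible assignment; an optimal alternative is the Hungarian algorithm (e.g. `scipy.optimize.linear_sum_assignment`) on the same distance matrix.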

This notebook shows the most user-friendly way to train models. If you are already comfortable with training deep learning models and want to dive deeper into the model architecture or tweak the code, please check [the documentation](https://weigertlab.github.io/spotiflow) to get started.