# Spotiflow: training your own model

Let's first import the libraries we need to train a spot detection model and to visualize our data.

In [None]:
from spotiflow.model import Spotiflow, SpotiflowModelConfig, SpotiflowTrainingConfig
from spotiflow.sample_data import load_dataset
from spotiflow.utils import get_data

import matplotlib.pyplot as plt
import numpy as np

Then, we can load our dataset. We will use one of the datasets from the Spotiflow paper, `Synthetic (complex)`, which is a good starting point if you later want to fine-tune on your own data. If you have your own annotated data, load it into six variables: the training images and spots, the validation images and spots, and the test images and spots. The `get_data()` function (commented out in the cell below) can do this for you; please [see the docs](https://weigertlab.github.io/spotiflow) for the data format it accepts.

In [None]:
trX, trY, valX, valY, testX, testY = load_dataset("synth_complex", include_test=True)
# trX, trY, valX, valY, testX, testY = get_data("/FOLDER/WITH/DATA", include_test=True)

The first two variables contain the training images and their annotations, the next two the validation images and annotations, and the last two the test ones. Let's define a function that will help us visualize an image together with its spots.

In [None]:
def plot_image_with_spots(img, spots, title=None, pred=False):
    """Show an image next to a copy with its spots overlaid (spots given as (y, x))."""
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    # Contrast limits from the 1st and 99.8th percentiles so outliers don't dominate
    clim = tuple(np.percentile(img, (1, 99.8)))
    axs.flat[0].imshow(img, clim=clim, cmap="gray")
    axs.flat[1].imshow(img, clim=clim, cmap="gray")
    # Spots are (row, column) = (y, x), whereas scatter expects (x, y)
    axs.flat[1].scatter(spots[:, 1], spots[:, 0], facecolors="none", edgecolors="orange")

    axs.flat[0].axis("off")
    axs.flat[1].axis("off")
    if isinstance(title, str):
        axs.flat[0].set_title(title)
        axs.flat[1].set_title(f"{title} (w/ {'prediction' if pred else 'annotation'})")
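The plotting helper sets the display range from the 1st and 99.8th intensity percentiles, so that a few very bright pixels do not wash out the rest of the image. The same idea is commonly used to normalize microscopy images before they are fed to a network. The sketch below is a minimal, hypothetical `normalize_percentile` helper written with NumPy; it illustrates the technique and is not Spotiflow's own normalization routine.

```python
import numpy as np

def normalize_percentile(img, pmin=1.0, pmax=99.8, eps=1e-8):
    """Hypothetical helper: map the pmin/pmax intensity percentiles of img to 0 and 1."""
    img = np.asarray(img, dtype=np.float32)
    lo, hi = np.percentile(img, (pmin, pmax))
    # Values below lo / above hi end up below 0 / above 1; they are not clipped
    return (img - lo) / (hi - lo + eps)
```

For display, `plt.imshow(normalize_percentile(img), cmap="gray", clim=(0, 1))` gives roughly the same view as the `clim` computed inside `plot_image_with_spots`.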

We can now visualize the images with their annotations. These are the first two images of the training dataset (change the indices to see others!): 

In [None]:
plot_image_with_spots(trX[0], trY[0], title="Training image")
plot_image_with_spots(trX[1], trY[1], title="Training image")

And here are the last two images of the validation dataset:

In [None]:
plot_image_with_spots(valX[-2], valY[-2], title="Validation image")
plot_image_with_spots(valX[-1], valY[-1], title="Validation image")
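Besides the visual check, a quick programmatic check can catch shape or coordinate-order problems before training, which is especially useful when loading your own data. The sketch below assumes each split is a list of images of shape `(H, W)` or `(H, W, C)` paired with a list of `(N, 2)` spot arrays in `(y, x)` order, as in the plotting helper above. `check_split` is a hypothetical helper, not part of Spotiflow.

```python
import numpy as np

def check_split(images, spots, name="split"):
    """Hypothetical sanity check: one (N, 2) array of (y, x) spots per image, inside its bounds."""
    if len(images) != len(spots):
        raise ValueError(f"{name}: {len(images)} images but {len(spots)} spot arrays")
    for i, (img, pts) in enumerate(zip(images, spots)):
        img = np.asarray(img)
        if img.ndim not in (2, 3):  # (H, W) or (H, W, C)
            raise ValueError(f"{name}[{i}]: expected a 2D or 3D image, got shape {img.shape}")
        pts = np.asarray(pts, dtype=float)
        if pts.size == 0:
            continue  # an image without any spot is allowed
        if pts.ndim != 2 or pts.shape[1] != 2:
            raise ValueError(f"{name}[{i}]: expected (N, 2) spots, got shape {pts.shape}")
        h, w = img.shape[:2]
        # Column 0 is y (rows), column 1 is x (columns)
        inside = (pts[:, 0] >= 0) & (pts[:, 0] < h) & (pts[:, 1] >= 0) & (pts[:, 1] < w)
        if not inside.all():
            raise ValueError(f"{name}[{i}]: {int((~inside).sum())} spot(s) outside the image")
    return len(images)
```

For example, `check_split(trX, trY, "train")` and `check_split(valX, valY, "val")` should both return the number of images without raising. A spot outside the image often means the `(y, x)` columns were swapped.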

The images and the corresponding annotations look good, so we can now train our own Spotiflow model on this data! Training with the default model configuration is very straightforward. First we need to instantiate the model:

In [None]:
model = Spotiflow()

Since no configuration was passed, the model is created with the default values. If you want to change some of the parameters, uncomment the following block and adjust them as needed (for more information, check [the documentation](https://weigertlab.github.io/spotiflow)):

In [None]:
# config = SpotiflowModelConfig(
#     in_channels=3, # e.g. for RGB
#     sigma=5., # for larger spots
# )
# model = Spotiflow(config=config)
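To build some intuition for `sigma`, it helps to picture each annotated spot as a Gaussian blob of that width. The sketch below renders such a heatmap from an `(N, 2)` array of `(y, x)` coordinates with NumPy. It is a simplified illustration that assumes `sigma` controls the width of the Gaussian targets the network learns to predict; `gaussian_heatmap` is hypothetical and does not reproduce Spotiflow's internal target generation.

```python
import numpy as np

def gaussian_heatmap(shape, spots, sigma=1.0):
    """Hypothetical illustration: render (y, x) spots as Gaussian blobs of width sigma."""
    yy, xx = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    heatmap = np.zeros(shape, dtype=np.float32)
    for y, x in np.asarray(spots, dtype=float).reshape(-1, 2):
        blob = np.exp(-((yy - y) ** 2 + (xx - x) ** 2) / (2 * sigma**2))
        # Overlapping blobs are combined with a pixel-wise maximum
        heatmap = np.maximum(heatmap, blob)
    return heatmap
```

Comparing `gaussian_heatmap((64, 64), trY[0], sigma=1.0)` with `sigma=5.0` (for the first training image) shows why a larger `sigma` suits larger spots: each blob covers more pixels, but nearby spots start to merge.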

We can now train the model by calling `.fit()`, after choosing where the model should be stored. As with the model, the training configuration is populated with default values; the cell below only overrides the number of epochs through `SpotiflowTrainingConfig`. Change or add parameters there as needed (for more information, check [the documentation](https://weigertlab.github.io/spotiflow)):

In [None]:
save_folder = "models/synth_complex" # change to where you want to store the model
train_config = SpotiflowTrainingConfig(
    num_epochs=2, # very small number of epochs for debugging purposes; increase it for real training
)
model.fit(
    trX,
    trY,
    valX,
    valY,
    save_dir=save_folder,
    train_config=train_config,
)

Our model is now ready to be used! Let's first check that it was stored properly in the save folder: there should be two `.pt` files (`best.pt` and `last.pt`) as well as three `.yaml` configuration files.

In [None]:
!ls $save_folder
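The `!ls` shell escape only works inside Jupyter/IPython. If you run training as a plain Python script instead, a small check with `pathlib` does the same job. The checkpoint names below come from the description above; `check_saved_model` is a hypothetical helper, and it does not assume any particular names for the `.yaml` files.

```python
from pathlib import Path

def check_saved_model(save_folder):
    """Hypothetical helper: confirm both checkpoints exist and list the .yaml files."""
    folder = Path(save_folder)
    missing = [name for name in ("best.pt", "last.pt") if not (folder / name).is_file()]
    if missing:
        raise FileNotFoundError(f"Missing checkpoint(s) in {folder}: {missing}")
    yaml_files = sorted(p.name for p in folder.glob("*.yaml"))
    print(f"Checkpoints found; {len(yaml_files)} .yaml file(s): {yaml_files}")
    return yaml_files
```

Calling `check_saved_model(save_folder)` after training should list three `.yaml` files.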

We can also quickly predict on a test image that was not seen during training (see [the inference notebook](./2_inference.ipynb) for more information about prediction and model loading):

In [None]:
test_pred, _ = model.predict(testX[0])

plot_image_with_spots(testX[0], test_pred, title="Test image", pred=True)
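A single plot only gives a qualitative impression. To quantify performance on the test set, a common approach is to match predicted and annotated spots one-to-one within a distance cutoff and report precision, recall and F1. The sketch below does this with SciPy's Hungarian solver (`scipy.optimize.linear_sum_assignment`). The helper name `match_spots` and the default 3-pixel cutoff are illustrative assumptions, and this is not necessarily the exact metric used in the Spotiflow paper.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def match_spots(gt, pred, cutoff=3.0):
    """Hypothetical metric: one-to-one matching of (y, x) spots within `cutoff` pixels."""
    gt = np.asarray(gt, dtype=float).reshape(-1, 2)
    pred = np.asarray(pred, dtype=float).reshape(-1, 2)
    tp = 0
    if len(gt) > 0 and len(pred) > 0:
        dist = cdist(gt, pred)
        # Penalty large enough that avoiding one out-of-cutoff pair always beats
        # shortening the in-cutoff ones, so the number of matches is maximized first
        big = cutoff * min(len(gt), len(pred)) + 1.0
        cost = np.where(dist <= cutoff, dist, big)
        rows, cols = linear_sum_assignment(cost)
        tp = int(np.sum(dist[rows, cols] <= cutoff))
    fp, fn = len(pred) - tp, len(gt) - tp
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return {"tp": tp, "fp": fp, "fn": fn, "precision": precision, "recall": recall, "f1": f1}
```

For example, `match_spots(testY[0], test_pred)` scores the prediction shown above. To evaluate the whole test set, predict each image in `testX` in turn and either average the per-image F1 or sum `tp`, `fp` and `fn` over all images before computing the scores.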

This notebook shows the most user-friendly way to train models. If you want to dive deeper into the model architecture and tweak the code, and you are already comfortable with training deep learning models, please check [the documentation](https://weigertlab.github.io/spotiflow) to get started.