## DeepSpot Training

In this second notebook, we explain how DeepSpot is trained and walk through a basic example of training it on spatial transcriptomics data. We assume that your data have already been preprocessed and prepared for training.

In [None]:
import os
# move two levels up so that relative paths such as example_data/ resolve
os.chdir('../../')

Import packages

In [None]:
from deepspot.utils.utils_image import get_morphology_model_and_preprocess
from deepspot.utils.utils import plot_loss_values

from deepspot.spot import DeepSpotDataLoader
from deepspot.spot import DeepSpot


from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from pathlib import Path
import lightning as L
import pandas as pd
import numpy as np
import torch
import yaml
# disable YAML anchors/aliases (&id001, *id001) so repeated values are written out in full
yaml.Dumper.ignore_aliases = lambda *args: True

Here, we specify the input parameters and the dataloader settings.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
samples = {'ZEN38'}  # slide IDs used for training
out_folder = "example_data"

In [None]:
dataloader_param = {
    # foundation model used to extract
    # the precomputed tile representations
    "morphology_model_name": "inception",
    "batch_size": 1024,
    # use the spot, its subspots, and the neighboring spots
    'spot_context': 'spot_subspot_neighbors',
    # radius used to compute the neighbors around
    # the central spot, based on the array coordinates
    'radius_neighbors': 1,
    # oversampling
    'resolution': 1,
    # whether to normalize the data during training and the
    # type of normalization, e.g. 'standard', or None to disable
    'normalize': 'standard',
    # to use 'aestetik' instead, run: pip install aestetik
    'augmentation': 'default',
}
batch_size = dataloader_param["batch_size"]
image_feature_model = dataloader_param['morphology_model_name']
num_workers = max(1, torch.get_num_threads() - 1)

# batch_size is passed to torch's DataLoader below, not to DeepSpotDataLoader
del dataloader_param["batch_size"]
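To make `radius_neighbors` more concrete, the sketch below collects the spots whose array coordinates lie within a given radius of a central spot. `neighbors_within_radius` is a hypothetical helper, not DeepSpot's implementation: it assumes a square grid and Chebyshev distance (the larger of the row and column offsets). On the Visium hexagonal grid, adjacent spots in the same row differ by 2 in the column index, so the rule DeepSpot actually applies may treat the coordinates differently.

```python
def neighbors_within_radius(coords, center, radius=1):
    """Return the coordinates within `radius` of `center`, excluding `center`.

    Hypothetical helper: assumes a square grid and Chebyshev distance,
    which may differ from the rule DeepSpot uses internally.
    """
    row0, col0 = center
    return sorted(
        (row, col)
        for row, col in coords
        if (row, col) != center
        and max(abs(row - row0), abs(col - col0)) <= radius
    )


# 3x3 grid of spots indexed by (array_row, array_col)
grid = [(row, col) for row in range(3) for col in range(3)]
print(neighbors_within_radius(grid, center=(1, 1), radius=1))  # all 8 surrounding spots
print(neighbors_within_radius(grid, center=(0, 0), radius=1))  # corner spot: 3 neighbors
```

Increasing `radius` widens the context window, at the cost of more tile representations per training example.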

In [None]:
genes = pd.read_csv(f"{out_folder}/data/info_highly_variable_genes_Visium.csv")
selected_genes_bool = genes.isPredicted.values
genes_to_predict = genes[selected_genes_bool]
genes_to_predict
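`selected_genes_bool` acts as a boolean mask over the gene table: genes flagged `True` in `isPredicted` are kept, and the number of `True` entries is the number of model outputs (used later as `output_size`). The stdlib sketch below mirrors that logic with toy gene names and flags; the values are hypothetical, not taken from the example data.

```python
from itertools import compress

# Toy stand-ins for the gene table and its isPredicted column (hypothetical values)
gene_names = ["GENE_A", "GENE_B", "GENE_C", "GENE_D"]
is_predicted = [True, False, True, True]

# Keep only the genes whose flag is True, as genes[selected_genes_bool] does
genes_to_predict = list(compress(gene_names, is_predicted))

# True counts as 1, so the sum is the number of predicted genes
output_size = sum(is_predicted)

print(genes_to_predict)  # ['GENE_A', 'GENE_C', 'GENE_D']
print(output_size)       # 3
```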

We need `feature_dim` to define the input dimension of DeepSpot.

In [None]:
_, _, feature_dim = get_morphology_model_and_preprocess(model_name=image_feature_model,
                                                        device=device)
feature_dim

Now, we prepare the DeepSpot dataloader. `out_folder` specifies the parent location where the preprocessed data is stored. `genes_to_keep` is a boolean np.array that indicates which genes to include for training and prediction. `samples` refers to the slide_ids, which are used to load the data.

In [None]:
train_data_loader_custom = DeepSpotDataLoader(
    out_folder=out_folder,
    samples=samples,
    genes_to_keep=selected_genes_bool,
    **dataloader_param
)
train_data_loader = torch.utils.data.DataLoader(dataset=train_data_loader_custom,
                                                batch_size=batch_size,
                                                num_workers=num_workers,
                                                shuffle=True)

Here, you can customize the hyperparameters of DeepSpot. For this example, we train it with the default parameters. The `scaler` passed to DeepSpot must be the same one used to normalize the training data, so that DeepSpot's predictions can be rescaled back to their original range with the `inverse_transform` function.

##### IMPORTANT: Remember to manually rescale the values, as this is not done automatically.
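```
expression_norm = model(X)
# inverse_transform expects a np.array, so convert the output first
# (assumption: model(X) returns a torch.Tensor)
expression_norm = expression_norm.detach().cpu().numpy()
expression = model.inverse_transform(expression_norm)
```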
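Why the scaler has to match can be seen with a small sketch. With `'normalize': 'standard'`, we assume the scaler behaves like a per-gene z-score scaler in the style of scikit-learn's `StandardScaler`: it stores the mean and standard deviation of each gene from the training data, and `inverse_transform` reverses the scaling with exactly those statistics. `StandardScalerSketch` is a hypothetical stdlib stand-in, not the object returned by `train_data_loader_custom.scaler`.

```python
from statistics import mean, pstdev


class StandardScalerSketch:
    """Hypothetical per-gene z-score scaler (rows = spots, columns = genes)."""

    def fit(self, rows):
        columns = list(zip(*rows))
        self.mean_ = [mean(col) for col in columns]
        # population std; genes with zero variance get a scale of 1.0
        self.scale_ = [pstdev(col) or 1.0 for col in columns]
        return self

    def transform(self, rows):
        return [[(x - m) / s for x, m, s in zip(row, self.mean_, self.scale_)]
                for row in rows]

    def inverse_transform(self, rows):
        return [[x * s + m for x, m, s in zip(row, self.mean_, self.scale_)]
                for row in rows]


train_expression = [[1.0, 10.0], [3.0, 10.0]]  # 2 spots x 2 genes (toy values)
scaler = StandardScalerSketch().fit(train_expression)
normalized = scaler.transform(train_expression)   # the scale the model predicts on
restored = scaler.inverse_transform(normalized)   # back to the original range
print(normalized)  # [[-1.0, 0.0], [1.0, 0.0]]
print(restored)    # [[1.0, 10.0], [3.0, 10.0]]
```

A scaler fitted on different data would hold different `mean_` and `scale_` values, so the same normalized predictions would map back to different expression values.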

In [None]:
param = {
    # add DeepSpot hyperparameters here; leave empty to use the defaults
}
param

In [None]:
regressor = DeepSpot(input_size=feature_dim,
                     output_size=int(selected_genes_bool.sum()),
                     scaler=train_data_loader_custom.scaler,
                     **param)

We train the model with early stopping.

In [None]:
early_stopping = EarlyStopping(monitor="train_step",
                               patience=3,
                               min_delta=0.01,
                               mode="min")
trainer = L.Trainer(max_epochs=10,
                    logger=False,
                    enable_checkpointing=False,
                    callbacks=[early_stopping])
trainer.fit(regressor, train_data_loader)
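The stopping rule configured above can be summarized with a simplified pure-Python sketch of how Lightning's `EarlyStopping` treats `patience`, `min_delta`, and `mode="min"`: a check counts as an improvement only if the monitored value drops below the best value so far by more than `min_delta`, and training stops after `patience` consecutive checks without such an improvement. The function name and the loss values are hypothetical, and the real callback has further options (for example `check_finite`) that are omitted here.

```python
def early_stopping_check(values, patience=3, min_delta=0.01):
    """Return the index of the check at which training would stop (mode="min").

    Simplified sketch of the rule behind EarlyStopping; returns None if the
    patience is never exhausted.
    """
    best = float("inf")
    wait = 0
    for i, value in enumerate(values):
        if best - value > min_delta:
            # improved by more than min_delta: remember it and reset the counter
            best = value
            wait = 0
        else:
            # no significant improvement on this check
            wait += 1
            if wait >= patience:
                return i
    return None


losses = [1.0, 0.5, 0.499, 0.498, 0.497]  # toy monitored values
print(early_stopping_check(losses))  # 4: three checks in a row improve by less than 0.01
```

Because `max_epochs=10` is also set, training ends at whichever limit is reached first.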

We also provide a function to visualize the training loss. Keep in mind that this is only an example.

In [None]:
plot_loss_values(regressor.training_loss)

Once DeepSpot is trained, you can export its weights and hyperparameters and use them for inference.

In [None]:
Path("pretrained_model_weights/example_model").mkdir(parents=True, exist_ok=True)

In [None]:
model_path = 'pretrained_model_weights/example_model/weights_Visium.pkl'
# saves the whole pickled model object, not only its state_dict
torch.save(regressor, model_path)

In [None]:
hparam = dict(regressor.hparams)
hparam['image_feature_model'] = dataloader_param['morphology_model_name']
# store only a readable description; the fitted scaler itself is saved with the pickled model
hparam['scaler'] = str(hparam['scaler'])
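Replacing the scaler with its string representation keeps the YAML file human-readable. If more hyperparameters hold non-primitive objects, the same idea can be applied generically. The sketch below uses a hypothetical `to_yaml_safe` helper that keeps strings, numbers, booleans, and `None` unchanged, recurses into dicts, lists, and tuples, and converts anything else with `str()`. Like the line above, it discards the object itself, so the fitted scaler is only preserved in the model saved with `torch.save`.

```python
def to_yaml_safe(value):
    """Hypothetical helper: make a hyperparameter value safe to dump as plain YAML."""
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    if isinstance(value, dict):
        return {str(key): to_yaml_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_yaml_safe(item) for item in value]
    # arbitrary objects (e.g. a fitted scaler) are stored by their string form
    return str(value)


class DummyScaler:  # hypothetical stand-in for a fitted scaler object
    def __repr__(self):
        return "DummyScaler()"


hparams = {"input_size": 2048, "scaler": DummyScaler(), "dims": (512, 256)}  # toy values
print(to_yaml_safe(hparams))
# {'input_size': 2048, 'scaler': 'DummyScaler()', 'dims': [512, 256]}
```

Converting tuples to lists is a deliberate choice: PyYAML's default `Dumper` would otherwise write tuples with a Python-specific `!!python/tuple` tag.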

In [None]:
# Specify the output YAML file path
yaml_file_path = 'pretrained_model_weights/example_model/hparam_Visium.yaml'

# Save the dictionary as a YAML file
with open(yaml_file_path, 'w') as yaml_file:
    yaml.dump(hparam, yaml_file)