# Training DeepSpot2Cell



In [None]:
import os
import torch
import lightning as L

from deepspot2cell import DS2CDataset, DeepSpot2Cell


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
morphology_model = "phikonv2"
batch_size = 256
max_epochs = 400
num_workers = 4
save_dir = "results"

data_folder = "hest_data"
dataset_variant = "_v1"
train_ids = ["NCBI864", "NCBI873", "NCBI856", "NCBI860"]
val_ids = ["NCBI865", "NCBI857"]

# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------
# The train and validation datasets differ only in their sample ids and in
# the `shuffle` flag; every other constructor argument is shared.
shared_dataset_kwargs = {
    "dataset_variant": dataset_variant,
    "data_path": data_folder,
    "model_name": morphology_model,
    "normalize": True,
    "neighb_degree": 1,
}
train_dataset = DS2CDataset(ids_list=train_ids, shuffle=True, **shared_dataset_kwargs)
val_dataset = DS2CDataset(ids_list=val_ids, shuffle=False, **shared_dataset_kwargs)

# ---------------------------------------------------------------------------
# Data loaders
# ---------------------------------------------------------------------------
# Same batching / worker / pinning settings for both loaders; only the
# training loader shuffles.
shared_loader_kwargs = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "pin_memory": True,
}
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# Pull a single training batch and read the model's input/output widths off
# the last axis of its 'cell_embeddings' and 'spot_expression' entries.
sample_batch = next(iter(train_loader))
input_size = sample_batch["cell_embeddings"].shape[-1]
output_size = sample_batch["spot_expression"].shape[-1]
print(f"Input size: {input_size}, Output size: {output_size}")

model = DeepSpot2Cell(input_size=input_size, output_size=output_size)

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# Single-device run on whichever accelerator Lightning selects, with the
# default logger enabled.
trainer = L.Trainer(max_epochs=max_epochs, accelerator="auto", devices=1, logger=True)

# Fit on the training loader, validating against the validation loader.
trainer.fit(model, train_loader, val_loader)

Save the trained model weights and display the logged metrics.

In [None]:
# Write the trained model's state dict to <save_dir>/final_model.pt,
# creating the output directory first if it does not exist yet.
os.makedirs(save_dir, exist_ok=True)
model_save_path = f"{save_dir}/final_model.pt"
torch.save(model.state_dict(), model_save_path)
print(f"Model saved at {model_save_path}")

# Collect the trainer's logged metrics, unwrapping tensor values into plain
# Python numbers and passing any other value through unchanged.
trainer_logs = trainer.logged_metrics
available_metrics = {
    name: (metric.item() if isinstance(metric, torch.Tensor) else metric)
    for name, metric in trainer_logs.items()
}

# Report every metric, one per line.
print("Available metrics:")
for name, metric in available_metrics.items():
    print(f"  {name}: {metric}")