# T4: 3D slice modeling of Drosophila embryo


In [10]:
import urllib.request
import warnings

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import seaborn as sns
import torch
from diffusers import DDPMScheduler
from torch_geometric.loader import NeighborLoader

from stadiffuser import metrics
from stadiffuser import pipeline
from stadiffuser import utils as sutils
from stadiffuser.dataset import get_slice_loader
from stadiffuser.models import SpaUNet1DModel
from stadiffuser.vae import SpaAE
warnings.filterwarnings("ignore")

## Load data

In [8]:
# The processed Stereo-seq data (~200 MB) must be downloaded manually from
# https://drive.google.com/file/d/1zyZKeZljbsEqo3YqVc_2-quU1Esm55E1/view?usp=drive_link
# and saved as "adata_processed.h5ad" in the current working directory.
adata_path = "adata_processed.h5ad"
adata = sc.read_h5ad(adata_path)
adata

AnnData object with n_obs × n_vars = 14634 × 2000
    obs: 'slice_ID', 'raw_x', 'raw_y', 'new_x', 'new_y', 'new_z', 'annotation'
    var: 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'hvg', 'log1p', 'spatial_net'
    obsm: 'X_umap', 'spatial'
    layers: 'raw_counts'

In [9]:
# Number of cells per slice, largest first.
slice_sizes = adata.obs["slice_ID"].value_counts(sort=True, ascending=False)
slice_sizes

slice_ID
E16-18h_a_S11    1193
E16-18h_a_S04    1189
E16-18h_a_S05    1181
E16-18h_a_S08    1131
E16-18h_a_S09    1113
E16-18h_a_S10    1111
E16-18h_a_S07    1096
E16-18h_a_S06    1076
E16-18h_a_S12    1049
E16-18h_a_S13    1022
E16-18h_a_S03    1021
E16-18h_a_S01     985
E16-18h_a_S02     965
E16-18h_a_S14     502
Name: count, dtype: int64

In [13]:
# Build the spatial neighbor graph slice by slice (grouped by "slice_ID",
# radius cutoff 1.4) and store it under the "spatial_net" key.
adata = sutils.cal_spatial_net3D(
    adata,
    iter_comb=None,
    batch_id="slice_ID",
    rad_cutoff=1.4,
    add_key="spatial_net",
)

# Quantize the coordinates of a copy of obsm["spatial"] (the original array is left
# untouched). There is one ("division", step) entry per coordinate, presumably the
# x, y, z axes in that order; confirm against sutils.quantize_coordination.
quantize_methods = [("division", 0.8), ("division", 0.8), ("division", 0.35)]
new_spatial = sutils.quantize_coordination(adata.obsm["spatial"].copy(),
                                           methods=quantize_methods)
adata.obsm["new_spatial"] = new_spatial

------Calculating spatial network for each batch...
Calculating spatial network for batch E16-18h_a_S01...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3790 edges, 985 cells, 3.8477 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S02...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3718 edges, 965 cells, 3.8528 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S03...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 3932 edges, 1021 cells, 3.8511 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S04...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contains 4594 edges, 1189 cells, 3.8638 neighbors per cell on average.
Calculating spatial network for batch E16-18h_a_S05...
------Calculating spatial graph...
------Spatial graph calculated.
The graph contai

{'CNS': 0,
 'carcass': 1,
 'epidermis': 2,
 'fat body': 3,
 'foregut': 4,
 'hemolymph': 5,
 'midgut': 6,
 'muscle': 7,
 'salivary gland': 8,
 'trachea': 9}

## Training the autoencoder

In [17]:
# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Graph-attention autoencoder whose input dimension is the number of genes
# (adata.shape[1]). It is moved to `device` explicitly, matching how the denoiser
# is constructed further below.
autoencoder = SpaAE(
    input_dim=adata.shape[1],
    block_list=["AttnBlock"],
    gat_dim=[512, 32],
    block_out_dims=[32, 32],
).to(device)

### Pretraining on each slice

In [None]:
# Pretrain the autoencoder slice by slice, with one loader per value of "slice_ID".
# `use_batch` must name the obs column that holds the labels in `batch_list`.
# This dataset has "slice_ID" and no "replicate" column (see the adata.obs listing above).
batch_key = "slice_ID"
batch_list = adata.obs[batch_key].unique().tolist()
data = pipeline.prepare_dataset(adata, use_rep=None)
train_loaders = [
    get_slice_loader(adata, data, batch, use_batch=batch_key, batch_size=256)
    for batch in batch_list
]
autoencoder, autoencoder_loss = pipeline.pretrain_autoencoder_multi(
    train_loaders,
    autoencoder,
    pretrain_epochs=200,
    device=device,
)

### Training with triplet loss to align the spot/cell embeddings

In [None]:
# Continue training the pretrained autoencoder with a triplet loss (margin=1) across
# slices to align the cell embeddings. `use_batch` must be the obs column holding the
# labels in `batch_list`. That column is "slice_ID"; this dataset has no "replicate" column.
autoencoder, autoencoder_loss = pipeline.train_autoencoder_multi(
    adata,
    autoencoder,
    use_batch="slice_ID",
    batch_list=batch_list,
    n_epochs=300,
    margin=1,
    lr=1e-4,
    update_interval=50,
    device=device,
)

## Training the latent diffusion model

In [20]:
# Map each annotation category to an integer class id for conditional diffusion.
# np.unique returns the categories sorted, so the ids are stable across runs.
cond_name = "annotation"
cond_categories = np.unique(adata.obs[cond_name])
num_class_embeds = len(cond_categories)
class_dict = {category: idx for idx, category in enumerate(cond_categories)}
adata.obs["label_"] = adata.obs[cond_name].map(class_dict)
class_dict

{'CNS': 0,
 'carcass': 1,
 'epidermis': 2,
 'fat body': 3,
 'foregut': 4,
 'hemolymph': 5,
 'midgut': 6,
 'muscle': 7,
 'salivary gland': 8,
 'trachea': 9}

In [None]:
# Run the trained autoencoder over all cells. The next lines read the embeddings
# back from adata.obsm["latent"].
adata = pipeline.get_recon(
    adata,
    autoencoder,
    device=device,
    apply_normalize=False,
    show_progress=True,
    batch_mode=True,
)

# Min-max normalize the latents (dim=0, presumably per latent feature; confirm in
# sutils.MinMaxNormalize). `normalizer` is kept as a notebook variable.
latent_raw = adata.obsm["latent"]
normalizer = sutils.MinMaxNormalize(latent_raw, dim=0)
adata.obsm["normalized_latent"] = normalizer.normalize(latent_raw)

In [None]:
# For 3D slice modeling: in_channels = time embedding (16) + latent embedding (1)
# + z-axis embedding (1, concatenated because spatial3d_concat=True) = 18.
denoiser = SpaUNet1DModel(
    in_channels=18,
    out_channels=1,
    spatial_encoding="sinusoidal3d",
    spatial3d_concat=True,
).to(device)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# Re-derive the integer labels here so this cell does not rely on "label_" surviving
# the get_recon call above.
adata.obs["label_"] = adata.obs[cond_name].map(class_dict)

# Graph dataset built from the normalized latents, the quantized coordinates,
# the per-slice spatial graph, and the integer annotation labels.
data_latent = pipeline.prepare_dataset(
    adata,
    use_rep="normalized_latent",
    use_spatial="new_spatial",
    use_net="spatial_net",
    use_label="label_",
)
# Mini-batches of 256 seed nodes, sampling 5 neighbors at the first hop and 3 at the second.
train_loader = NeighborLoader(data_latent, num_neighbors=[5, 3], batch_size=256)
denoiser, denoise_loss = pipeline.train_denoiser(
    train_loader,
    denoiser,
    noise_scheduler,
    lr=1e-4,
    weight_decay=1e-6,
    n_epochs=500,
    num_class_embeds=num_class_embeds,
    device=device,
)