# FuseMap Tutorial III: Gene Spatial Imputation

This tutorial demonstrates how to perform spatial imputation of unmeasured genes after integrating datasets using FuseMap. We'll use a pre-trained FuseMap gene embedding model to impute the expression of the ABCC9 gene on MERFISH data.

In [1]:
import warnings
# Install a global "ignore" filter: every warning raised from here on is suppressed.
# NOTE(review): presumably meant to keep the tutorial output uncluttered --
# remove this line while debugging so library warnings stay visible.
warnings.filterwarnings('ignore')

In [None]:
import os
import scanpy as sc
import numpy as np
import scipy
from scipy import stats

### 1. Load the pre-trained gene embeddings and cell embeddings

In [3]:
# Choose what to impute: the gene of interest and the dataset it is imputed on.
# Both values are substituted into the output-directory name below.
target_gene = 'ABCC9'
dataset = 'merfish'

# Output directory of the pre-trained FuseMap run for this dataset/gene pair.
# NOTE: this is a machine-specific absolute path -- edit it to match your setup.
save_dir = f"/Users/mingzeyuan/Workspace/fusemap_deconvolution/raw_data/rerun_version3_impt_{dataset}_{target_gene}/"

# Locations of the two AnnData files read from that directory.
celltype_path = os.path.join(save_dir, "transfer_celltype.h5ad")
gene_embed_path = os.path.join(save_dir, "ad_gene_embedding.h5ad")

# Transferred cell type annotations first, then the pre-trained gene embeddings.
celltype_all = sc.read_h5ad(celltype_path)
gene_embed = sc.read_h5ad(gene_embed_path)

### 2. Gene imputation
Based on FuseMap's mechanism, we can obtain the spatial gene imputation simply by multiplying the cell type embeddings by the transposed gene embeddings.

In [None]:
# Keep only the rows whose `name` annotation equals the selected dataset.
is_target_dataset = celltype_all.obs['name'] == dataset
celltype_data = celltype_all[is_target_dataset]

# Imputation step: multiply the per-cell matrix by the transposed gene
# embedding matrix. The result has one row per cell in `celltype_data` and
# one column per row (gene) of `gene_embed`.
spatial_gene_proj = celltype_data.X @ gene_embed.X.T

# Sanity check: print the shapes of both inputs and of the product.
print(celltype_data.shape, gene_embed.shape, spatial_gene_proj.shape)

### 3. Evaluate spatial gene imputation

In [None]:
# Load the measured spatial data and preprocess it with scanpy:
# total-count normalization, log1p transform, then scaling without centering
# (zero_center=False) with values clipped at max_value=10.
# NOTE: this is a machine-specific absolute path -- edit it to match your setup.
spatial_data = sc.read_h5ad(f'/Users/mingzeyuan/Workspace/FuseMap/data/{dataset}.h5ad')
sc.pp.normalize_total(spatial_data)
sc.pp.log1p(spatial_data)
sc.pp.scale(spatial_data, zero_center=False, max_value=10)

# Locate the target gene among the genes (var) of the measured data and among
# the rows (obs) of the gene embedding; the latter index the columns of
# `spatial_gene_proj`.
measured_hits = np.where(spatial_data.var.index == target_gene)[0]
imputed_hits = np.where(gene_embed.obs.index == target_gene)[0]

# Fail with a readable message instead of an opaque IndexError from `[0]`.
if measured_hits.size == 0:
    raise KeyError(f"{target_gene!r} is not among the genes of the {dataset!r} spatial data")
if imputed_hits.size == 0:
    raise KeyError(f"{target_gene!r} has no entry in the gene embedding")

# Extract measured and imputed expression for the target gene (first match).
measured_expr = spatial_data.X[:, measured_hits[0]]
imputed_expr = spatial_gene_proj[:, imputed_hits[0]]

# Densify if needed. `issparse` recognises both scipy sparse matrices and
# sparse arrays, whereas `isspmatrix` only recognises the legacy matrix
# classes and would let a sparse-array column through to `pearsonr`.
if scipy.sparse.issparse(measured_expr):
    measured_expr = measured_expr.toarray()
if scipy.sparse.issparse(imputed_expr):
    imputed_expr = imputed_expr.toarray()

# `pearsonr` needs plain 1-D vectors; this also flattens (n, 1) columns and
# `np.matrix` results.
measured_expr = np.asarray(measured_expr).ravel()
imputed_expr = np.asarray(imputed_expr).ravel()

# Pearson correlation between measured and imputed expression across cells.
correlation = stats.pearsonr(measured_expr, imputed_expr)[0]
print(f"Correlation between measured and imputed {target_gene}: {correlation:.3f}")