# Ablation Experiments on Simulated Dataset

## Loading packages

In [None]:
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import scanpy as sc
import numpy as np
from sklearn.metrics.cluster import adjusted_rand_score
import anndata as ad
import os
import numpy as np
import matplotlib.pyplot as plt

# Tell the R bridge where the local R installation and rpy2 package live.
# NOTE(review): presumably these must be set before rpy2 is first imported
# (likely inside Model.utils, which provides mclust_R) — confirm.
# NOTE(review): machine-specific absolute paths; adjust for your environment.
os.environ['R_HOME'] = 'E:/R-4.3.1'
os.environ['R_USER'] = 'E:/anaconda/lib/site-packages/rpy2'
import sys
# Make the project's `Model` package importable from its local checkout.
sys.path.append(r'D:/study/learning\spatial_transcriptome/papers\spatial_multi_omics-main')
from Model.utils import mclust_R
from Model.model import DCCAE
from Model.preprocess import fix_seed
# Presumably seeds the random number generators so runs are reproducible
# (fix_seed is defined in Model.preprocess, not shown here).
fix_seed(2024)

## Loading data
We use four replicates of simulated dataset 1 for the ablation experiments; the `replicate` variable below selects which one is loaded.

In [None]:
# Index of the simulated replicate to analyse. The same value is reused in
# the "Storing results" cells to locate the matching results file.
replicate = 1

# NOTE(review): machine-specific absolute path; adjust for your environment.
data_root = 'D:/study/learning/spatial_transcriptome/papers/spatial_multi_omics-main/data'
scenario_dir = f'{data_root}/Spatial_Scenario_{replicate}'
file_fold_1 = f'{scenario_dir}/simulation{replicate}_RNA'
file_fold_2 = f'{scenario_dir}/simulation{replicate}_Protein'

adata_omics_1 = sc.read_h5ad(file_fold_1 + '.h5ad')
adata_omics_2 = sc.read_h5ad(file_fold_2 + '.h5ad')

# Use the matrix stored under uns['INR'] as the expression matrix
# (presumably the output of the INR module — confirm against the
# preprocessing that produced these files).
adata_omics_1.X = adata_omics_1.uns['INR']
adata_omics_2.X = adata_omics_2.uns['INR']

# Keep only the spots belonging to the selected batch.  Boolean indexing
# returns a *view* of the parent AnnData; the PCA call below writes into
# .obsm/.uns/.varm, which would otherwise force an implicit view-to-copy
# conversion (the warning is hidden by the global warnings filter).
# Copying explicitly makes the subsets independent objects up front.
batch = 1
adata_RNA = adata_omics_1[adata_omics_1.obs['batch'] == batch].copy()
adata_ADT = adata_omics_2[adata_omics_2.obs['batch'] == batch].copy()

# PCA over all features (no highly-variable-gene restriction); the result is
# stored in .obsm['X_pca'] and used as the DCCAE input later on.
sc.tl.pca(adata_RNA, use_highly_variable=False)
sc.tl.pca(adata_ADT, use_highly_variable=False)

## Running SpaKnit

To assess the contribution of each module in SpaKnit, we conduct ablation experiments on the simulated dataset across four spatial patterns. We compare the performance of the complete SpaKnit framework with modified versions in which specific components are removed:
- SpaKnit without INR to evaluate the impact of removing the INR module, which models spatial continuity; 
- SpaKnit without DCCAE to examine the effect of excluding DCCAE, responsible for learning cross-modal correlations; 
- SpaKnit without CCA Loss to test the importance of CCA in optimizing cross-modal alignment; 
- SpaKnit without Reconstruction Loss to investigate the necessity of reconstruction loss in preserving modality-specific structures. 

In [None]:
n_DCCA = 10  # size of each modality's DCCAE embedding

# The DCCAE is trained on the per-modality PCA embeddings.
X = adata_RNA.obsm['X_pca'].copy()
Y = adata_ADT.obsm['X_pca'].copy()
features1 = X.shape[1]  # Feature sizes
features2 = Y.shape[1]
layers1 = [256, 256, n_DCCA]  # nodes in each hidden layer and the output size
layers2 = [256, 256, n_DCCA]

epochs = 300
dcca = DCCAE(
    input_size1=features1,
    input_size2=features2,
    n_components=n_DCCA,
    layer_sizes1=layers1,
    layer_sizes2=layers2,
    epoch_num=epochs,
    learning_rate=0.001,
)
dcca.fit([X, Y])
Xs_transformed = dcca.transform([X, Y])

# Store the per-modality embeddings and their concatenation (joint embedding).
adata_RNA.obsm["DCCA_X"] = Xs_transformed[0]
adata_ADT.obsm["DCCA_Y"] = Xs_transformed[1]
adata_RNA.obsm["DCCA"] = np.concatenate(
    (adata_RNA.obsm["DCCA_X"], adata_ADT.obsm["DCCA_Y"]), axis=1
)

use_rep = ['DCCA_X', 'DCCA_Y', 'DCCA']
n = 4  # number of clusters requested from mclust


def _cluster_and_score(adata, rep, num_cluster):
    """Cluster ``adata`` on ``obsm[rep]`` with mclust and return the ARI.

    ``mclust_R`` is called for its side effect on ``adata`` (the labels are
    read back from ``adata.obs['clusters_mclust']``).  Only spots that have
    both a cluster label and a ground-truth label are scored, so NaNs in
    unrelated ``obs`` columns no longer remove spots from the comparison.
    """
    mclust_R(adata, used_obsm=rep, num_cluster=num_cluster)
    labelled = adata.obs.dropna(subset=['clusters_mclust', 'Ground Truth'])
    return adjusted_rand_score(labelled['clusters_mclust'], labelled['Ground Truth'])


# Order matters: the joint 'DCCA' embedding is clustered last on adata_RNA so
# that adata_RNA.obs['clusters_mclust'] ends up holding the joint clustering,
# which is what the "Storing results" cell saves.
ari_scores = []
for adata, rep in zip((adata_RNA, adata_ADT, adata_RNA), use_rep):
    ari = _cluster_and_score(adata, rep, n)
    print(f'n={n}, {rep}, ARI = {ari}')
    ari_scores.append(ari)

ARI_1, ARI_2, ARI_3 = ari_scores

Training Progress:   0%|          | 0/300 [00:00<?, ?it/s]

Training Progress: 100%|██████████| 300/300 [00:09<00:00, 32.40it/s]

model training finished!
fitting ...
n=4, DCCA_X, ARI = 0.9423137333367312
fitting ...
  |                                                                            




n=4, DCCA_Y, ARI = 0.8846294170335701
fitting ...
n=4, DCCA, ARI = 0.973675813012235


## Storing results

In [None]:
# Open the existing per-replicate results file so this run's output can be
# added alongside the other ablation variants already stored in it.
result_dir = './Results/sensitivity analysis and ablation experiment/Spatial_Scenario/Ablation_Experiment'
adata_result = sc.read_h5ad(f'{result_dir}/result{replicate}.h5ad')

In [24]:
# Keys under which each ablation variant is stored in `adata_result`.
# NOTE(review): 'without_liear_CCA_layer' looks like a typo for "linear"; it
# is left unchanged because it is a stored key that other files may rely on.
name = ['without_INR', 'without_DCCAE', 'without_CCA_loss', 'without_rec_loss', 'without_liear_CCA_layer']

# Variant produced by the current run ('without_rec_loss').
ablation_key = name[3]

# Save the joint embedding and the cluster labels from the last mclust run on
# adata_RNA.  `.values` assigns positionally, so this assumes adata_result and
# adata_RNA list their spots in the same order — TODO confirm.
adata_result.obsm[ablation_key] = adata_RNA.obsm['DCCA']
adata_result.obs[ablation_key] = adata_RNA.obs['clusters_mclust'].values
# adata_result.obsm['spatial'] = adata_RNA.obsm['spatial']

In [None]:
# Write the updated results back to the same per-replicate file it was read
# from, overwriting the previous version.
output_dir = './Results/sensitivity analysis and ablation experiment/Spatial_Scenario/Ablation_Experiment'
adata_result.write_h5ad(f'{output_dir}/result{replicate}.h5ad')