## Mouse Spleen

In [None]:
import os
import sys
import importlib

import dgl
import pandas as pd
import scanpy as sc
import torch as th
import torch.nn as nn
from sklearn.utils import shuffle

# Project directory. `model_integration` (imported below) presumably lives here,
# so the chdir must happen before that import.
MODEL_DIR = '/data/xiangdw/MODEL/'
os.chdir(MODEL_DIR)
print(os.getcwd())

# NOTE(review): the star import hides which names later cells take from this
# module; replace it with explicit imports once its public API is confirmed.
from model_integration import *

# Imported after the star import (as before) so matplotlib's rcParams wins.
from matplotlib import rcParams

# Global matplotlib style: Times New Roman for Latin text.
config = {
    "font.family": 'Times New Roman',
    "font.size": 20,  # base font size in points
    # Draw negative signs with ASCII '-' instead of U+2212, which some fonts cannot render.
    "axes.unicode_minus": False,
}
rcParams.update(config)

### Prediction

#### Batch1 to Batch2

In [None]:
import datetime

from SpaMIE.create_graph import Sagegraph
from SpaMIE.spamie_main import Sagewrapper

# Cross-batch prediction. train_size counts the spots labelled batch '1'.
# in_feat comes from omics1 (RNA) features and out_feat from omics2 (ADT), so the
# model presumably maps RNA -> ADT — TODO confirm against SpaMIE docs.

# `device` was previously defined only in the later Integration cell (unless the
# `model_integration` star import happens to export it), so this cell raised
# NameError on a fresh kernel. Define it the same way the Integration cell does.
device = th.device('cuda:1' if th.cuda.is_available() else 'cpu')

file_fold = '/data/xiangdw/data/data/'
output_dir = '/data/xiangdw/data/pred result/'
layers_nums = 3
weight = [0, 0, 3]  # forwarded to Sagewrapper(weight=...); per-entry meaning is defined by SpaMIE
# NOTE(review): the Batch2->Batch1 cell builds identical file names, so enabling
# save_csv in both cells would make the later one overwrite these CSVs.
pred_name = 'simu_SpaMIE_' + str(layers_nums) + '_pred.csv'
true_name = 'simu_SpaMIE_' + str(layers_nums) + '_truth.csv'

for i in range(1):
    starttime = datetime.datetime.now()
    # `seeds` is both the input-file prefix and the model seed.
    seeds = i + 1
    adata_omics1 = sc.read_h5ad(file_fold + str(seeds) + 'adata_RNA_spots.h5ad')
    adata_omics2 = sc.read_h5ad(file_fold + str(seeds) + 'adata_ADT_spots.h5ad')

    modalities = [adata_omics1, adata_omics2]
    # Sagegraph returns the graphs plus (possibly modified) AnnData objects that replace the inputs.
    # NOTE(review): batch=False here but batch=True in the Batch2->Batch1 cell — confirm intended.
    (g_spatial_omics1, g_feature_omics1,
     g_spatial_omics2, g_feature_omics2,
     adata_omics1, adata_omics2) = Sagegraph(modalities, device, datatype='spots', batch=False)

    in_feat = adata_omics1.obsm['feat'].shape[1]
    out_feat = adata_omics2.X.shape[1]
    # Training-set size = number of batch-'1' spots (this line previously referenced
    # the undefined name `ada_omics1` and raised NameError).
    # Assumes obs['batch'] holds string labels; how fit() picks these rows is internal to SpaMIE.
    train_size = adata_omics1[adata_omics1.obs['batch'] == '1'].shape[0]

    model = Sagewrapper(seed=seeds, device=device, in_feat=in_feat, n_hidden=256, out_feat=out_feat,
                        task='prediction', datatype='spots', layers_nums=layers_nums, weight=weight,
                        epoch=600, res_type='res_add', activation=nn.LeakyReLU, sagetype='mean',
                        lr=2e-4, lr2=0.002)

    (adata_omics1_pred, adata_omics2_pred,
     test_idx, train_idx, wt, alph) = model.fit(g_spatial_omics1, g_feature_omics1,
                                              g_spatial_omics2, g_feature_omics2,
                                              adata_omics1, adata_omics2,
                                              output_dir=output_dir, pred_name=pred_name,
                                              true_name=true_name, train_size=train_size,
                                              weight=True, save_csv=False)

    print(f'Run {seeds} finished in {datetime.datetime.now() - starttime}')



#### Batch2 to Batch1

In [None]:
import datetime

from SpaMIE.create_graph import Sagegraph
from SpaMIE.spamie_main import Sagewrapper

# Cross-batch prediction in the reverse direction. train_size counts the spots
# labelled batch '2'. in_feat comes from omics1 (RNA) features and out_feat from
# omics2 (ADT), so the model presumably maps RNA -> ADT — TODO confirm.

# `device` was previously defined only in the later Integration cell (unless the
# `model_integration` star import happens to export it), so this cell raised
# NameError on a fresh kernel. Define it the same way the Integration cell does.
device = th.device('cuda:1' if th.cuda.is_available() else 'cpu')

file_fold = '/data/xiangdw/data/data/'
output_dir = '/data/xiangdw/data/pred result/'
layers_nums = 3
weight = [0, 0, 1]  # forwarded to Sagewrapper(weight=...); per-entry meaning is defined by SpaMIE
# NOTE(review): the Batch1->Batch2 cell builds identical file names, so enabling
# save_csv in both cells would make the later one overwrite the earlier CSVs.
pred_name = 'simu_SpaMIE_' + str(layers_nums) + '_pred.csv'
true_name = 'simu_SpaMIE_' + str(layers_nums) + '_truth.csv'

for i in range(1):
    starttime = datetime.datetime.now()
    # `seeds` is both the input-file prefix and the model seed.
    seeds = i + 1
    adata_omics1 = sc.read_h5ad(file_fold + str(seeds) + 'adata_RNA_spots.h5ad')
    adata_omics2 = sc.read_h5ad(file_fold + str(seeds) + 'adata_ADT_spots.h5ad')

    modalities = [adata_omics1, adata_omics2]
    # Sagegraph returns the graphs plus (possibly modified) AnnData objects that replace the inputs.
    # NOTE(review): batch=True here but batch=False in the Batch1->Batch2 cell; presumably
    # batch=True reorders spots so batch '2' comes first for training — TODO confirm.
    (g_spatial_omics1, g_feature_omics1,
     g_spatial_omics2, g_feature_omics2,
     adata_omics1, adata_omics2) = Sagegraph(modalities, device, datatype='spots', batch=True)

    in_feat = adata_omics1.obsm['feat'].shape[1]
    out_feat = adata_omics2.X.shape[1]
    # Training-set size = number of batch-'2' spots (this line previously referenced
    # the undefined name `ada_omics1` and raised NameError).
    # Assumes obs['batch'] holds string labels; how fit() picks these rows is internal to SpaMIE.
    train_size = adata_omics1[adata_omics1.obs['batch'] == '2'].shape[0]

    model = Sagewrapper(seed=seeds, device=device, in_feat=in_feat, n_hidden=256, out_feat=out_feat,
                        task='prediction', datatype='spots', layers_nums=layers_nums, weight=weight,
                        epoch=600, res_type='res_add', activation=nn.LeakyReLU, sagetype='mean',
                        lr=2e-4, lr2=0.002)

    (adata_omics1_pred, adata_omics2_pred,
     test_idx, train_idx, wt, alph) = model.fit(g_spatial_omics1, g_feature_omics1,
                                              g_spatial_omics2, g_feature_omics2,
                                              adata_omics1, adata_omics2,
                                              output_dir=output_dir, pred_name=pred_name,
                                              true_name=true_name, train_size=train_size,
                                              weight=True, save_csv=False)

    print(f'Run {seeds} finished in {datetime.datetime.now() - starttime}')



### Integration

In [None]:
import torch.nn.functional as F
from SpaMIE.create_graph import Sagegraph
from SpaMIE.spamie_main import Sagewrapper
import numpy as np
from model_integration import set_seed
device = th.device('cuda:1' if th.cuda.is_available() else 'cpu')

# Integrate the RNA (omics1) and ADT (omics2) spot data. The first element of
# fit()'s output is stored in adata_omics2.obsm['SpaMIE'].
path = '/data/xiangdw/data/data/'

for i in range(1):
    # File-name prefix (str); also the model seed via int().
    seeds = str(i + 1)
    # omics2 previously re-read 'adata_RNA_spots.h5ad', so RNA was "integrated" with
    # itself. The ADT file matches the RNA/ADT pairing used in the prediction cells.
    adata_omics1 = sc.read_h5ad(path + seeds + 'adata_RNA_spots.h5ad')
    adata_omics2 = sc.read_h5ad(path + seeds + 'adata_ADT_spots.h5ad')
    set_seed(2024)
    # Scale each modality in place to zero mean / unit variance per feature.
    sc.pp.scale(adata_omics1)
    sc.pp.scale(adata_omics2)

    modalities = [adata_omics1, adata_omics2]
    # Sagegraph returns the graphs plus (possibly modified) AnnData objects that replace the inputs.
    (g_spatial_omics1, g_feature_omics1,
     g_spatial_omics2, g_feature_omics2,
     adata_omics1, adata_omics2) = Sagegraph(modalities, device, datatype="spots", batch=True)

    in_feat = adata_omics1.obsm['feat'].shape[1]
    out_feat = adata_omics2.X.shape[1]
    weight = [1, 1, 1]  # forwarded to Sagewrapper(weight=...); per-entry meaning is defined by SpaMIE
    model = Sagewrapper(seed=int(seeds), device=device, in_feat=in_feat, n_hidden=256, out_feat=out_feat,
                        task='integration', datatype='spots', layers_nums=3, weight=weight,
                        epoch=600, res_type='res_add', activation=nn.LeakyReLU, sagetype='mean',
                        lr=2e-4, lr2=0.002)

    # NOTE(review): the meaning of weight_factors=[1, 5, 1, 1] is defined by SpaMIE — document it here.
    output = model.fit(g_spatial_omics1, g_feature_omics1, g_spatial_omics2, g_feature_omics2,
                       adata_omics1, adata_omics2, weight_factors=[1, 5, 1, 1])

    # Presumably the integrated per-spot embedding — TODO confirm what fit() returns at index 0.
    adata_omics2.obsm['SpaMIE'] = output[0]
    