In [1]:
import pandas as pd
import numpy as np
import scanpy as sc
import os
import torch
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import gc
import warnings
# from model.train import train_SMART
from model.utils import fix_seed

fix_seed(2025)

warnings.filterwarnings('ignore')
# Environment configuration. SpatialGlue pacakge can be implemented with either CPU or GPU. GPU acceleration is highly recommend for imporoved efficiency.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [2]:
# Load data
file_fold = '/data/hulei/ZhaoruiJiang/Data/SpatialGlue/'
adata_omics1 = sc.read_h5ad(file_fold + 'Dataset1_Mouse_Spleen1/adata_RNA.h5ad')
adata_omics2 = sc.read_h5ad(file_fold + 'Dataset1_Mouse_Spleen1/adata_ADT.h5ad')
adata_omics1.var_names_make_unique()
adata_omics2.var_names_make_unique()
test_adata_omics1 = sc.read_h5ad(file_fold + 'Dataset2_Mouse_Spleen2/adata_RNA.h5ad')
test_adata_omics2 = sc.read_h5ad(file_fold + 'Dataset2_Mouse_Spleen2/adata_ADT.h5ad')
test_adata_omics1.var_names_make_unique()
test_adata_omics2.var_names_make_unique()

In [3]:
from model.utils import pca
from model.utils import clr_normalize_each_cell

In [4]:
#Normalization
sc.pp.highly_variable_genes(adata_omics1, flavor="seurat_v3", n_top_genes=3000)
sc.pp.normalize_total(adata_omics1, target_sum=1e4)
sc.pp.log1p(adata_omics1)
sc.pp.highly_variable_genes(test_adata_omics1, flavor="seurat_v3", n_top_genes=3000)
sc.pp.normalize_total(test_adata_omics1, target_sum=1e4)
sc.pp.log1p(test_adata_omics1)

In [5]:
adata_omics2 = clr_normalize_each_cell(adata_omics2)
# sc.pp.log1p(adata_omics2)
test_adata_omics2 = clr_normalize_each_cell(test_adata_omics2)
# sc.pp.log1p(test_adata_omics2)

In [6]:
adata_omics1_high = adata_omics1[:, adata_omics1.var['highly_variable']]
adata_omics1.obsm['feat'] = pca(adata_omics1_high, n_comps=adata_omics2.n_vars-1)

test_adata_omics1_high = test_adata_omics1[:, test_adata_omics1.var['highly_variable']]
test_adata_omics1.obsm['feat'] = pca(test_adata_omics1_high, n_comps=test_adata_omics2.n_vars-1)

In [7]:
adata_omics2 = adata_omics2[adata_omics1.obs_names].copy() 
adata_omics2.obsm['feat'] = pca(adata_omics2, n_comps=adata_omics2.n_vars-1)
test_adata_omics2 = test_adata_omics2[test_adata_omics1.obs_names].copy() 
test_adata_omics2.obsm['feat'] = pca(test_adata_omics2, n_comps=test_adata_omics2.n_vars-1)

In [8]:
from model.utils import Cal_Spatial_Net

Cal_Spatial_Net(adata_omics1, model="KNN", n_neighbors=3)
Cal_Spatial_Net(adata_omics2, model="KNN", n_neighbors=3)
Cal_Spatial_Net(test_adata_omics1, model="KNN", n_neighbors=3)
Cal_Spatial_Net(test_adata_omics2, model="KNN", n_neighbors=3)

The graph contains 7704 edges, 2568 cells.
3.0000 neighbors per cell on average.
The graph contains 7704 edges, 2568 cells.
3.0000 neighbors per cell on average.
The graph contains 8304 edges, 2768 cells.
3.0000 neighbors per cell on average.
The graph contains 8304 edges, 2768 cells.
3.0000 neighbors per cell on average.


In [9]:
from model.utils import Mutual_Nearest_Neighbors
anchors1, positives1, negatives1 = Mutual_Nearest_Neighbors(adata_omics1, key="feat", n_nearest_neighbors=4,
                                                            farthest_ratio=0.6)
anchors2, positives2, negatives2 = Mutual_Nearest_Neighbors(adata_omics2, key="feat", n_nearest_neighbors=4,
                                                            farthest_ratio=0.6)

distances calculation completed!
The data use feature 'feat' contains 3340 mnn_anchors
distances calculation completed!
The data use feature 'feat' contains 4298 mnn_anchors


In [10]:
torch.FloatTensor(adata_omics2.X).shape[1]

21

In [11]:
from model.train import train_STProtein

SyntaxError: positional argument follows keyword argument (train.py, line 14)

In [None]:
out, model = train_STProtein(adata = adata_omics1,
                  ground_truth= torch.FloatTensor(adata_omics2.X),
                  feature_key="feat",
                  edge_key="edgeList",
                  weights=[1, 1],
                  n_epochs=1000,
                  weight_decay=0.001
                  )

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:07<00:00, 136.75it/s]


In [None]:
import torch
from tqdm import tqdm

from model.model import SMART,SAGEConv_Encoder,SAGEConv_Decoder
import torch.nn.functional as F


def train_STProteinv2(adata, ground_truth ,triplet_samples, feature_key="feat", edge_key="edgeList", weights=None,emb_dim=64, n_epochs=500,
               lr=0.0001,weight_decay=1e-5,device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    x1, edge_index1 = torch.FloatTensor(adata.obsm[feature_key]), torch.LongTensor(adata.uns[edge_key])
    emb_dim = ground_truth.shape[1]
    target = ground_truth.to(device)
    print(x1)
    
    print(edge_index1)
    anchors1, positives1, negatives1 = triplet_samples
    print(x1.shape)
    hidden_dims=[x1.shape, emb_dim]
    print(hidden_dims)
    model = SMART(hidden_dims=[x1.shape[1], emb_dim])

    x1, edge_index1= x1.to(device), edge_index1.to(device)
    model.to(device)
    
    n_epochs = n_epochs
    loss_list = []
    w1, w2 = weights
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,weight_decay=weight_decay)
    
    for epoch in tqdm(range(1, n_epochs + 1)):
        model.train()
        optimizer.zero_grad()
        z, x1_rec= model(x1, edge_index1)
    
        anchor_arr1 = x1_rec[anchors1,]
        positive_arr1 = x1_rec[positives1,]
        negative_arr1 = x1_rec[negatives1,]
    
        triplet_loss = torch.nn.TripletMarginLoss(margin=1, p=2, reduction='mean')
        tri_output1 = triplet_loss(anchor_arr1, positive_arr1, negative_arr1)

    
        loss = w1 * F.mse_loss(x1, x1_rec) + w2*tri_output1 + w1*F.mse_loss(target, z)
        loss_list.append(loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

    
    model.eval()
    x1, edge_index1 = x1.to(device), edge_index1.to(device)
    z, x1_rec = model(x1, edge_index1)

    return z, model
        
    


In [None]:
# out, model = train_STProteinv2(adata = adata_omics1,
#                   ground_truth= torch.FloatTensor(adata_omics2.X),
#                   triplet_samples=(anchors1, positives1, negatives1),
#                   feature_key="feat",
#                   edge_key="edgeList",
#                   weights=[1, 1],
#                   n_epochs=1000,
#                   weight_decay=0.001)

tensor([[ 1.4073e+00,  2.4436e+00, -1.7743e-01,  ...,  1.5488e+00,
          1.4655e-01,  1.3487e+00],
        [ 2.7535e+00, -5.1512e-01, -6.2535e-01,  ...,  1.3381e-02,
         -2.6619e+00, -7.3983e-02],
        [ 3.6504e+00,  1.7775e+00, -1.8608e+00,  ..., -1.2904e+00,
         -5.0396e-01, -1.9837e-01],
        ...,
        [-5.1598e+00, -9.6904e-01, -1.9860e+00,  ...,  1.2161e-01,
         -8.0446e-01, -7.5742e-01],
        [ 1.5710e+00, -9.7967e-01, -8.6028e-01,  ..., -1.5448e+00,
         -2.4997e+00,  2.2396e+00],
        [-2.0188e+00, -2.1807e-03, -1.5388e+00,  ..., -6.4205e-01,
          1.5537e+00, -1.5805e+00]])
tensor([[   0,    0,    0,  ..., 2567, 2567, 2567],
        [2330, 1190, 2125,  ..., 1256, 1371, 1983]])
torch.Size([2568, 20])
[torch.Size([2568, 20]), 21]


  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:11<00:00, 89.57it/s]


In [None]:
out

tensor([[ 0.9201,  0.6181,  1.0360,  ...,  0.0499,  0.0829,  0.4687],
        [ 1.2642,  0.8110,  1.4424,  ...,  0.1793, -0.0323,  0.5874],
        [ 1.2289,  0.6098,  1.0697,  ...,  0.0665,  0.1847,  0.5152],
        ...,
        [ 0.2237,  1.1503,  0.8767,  ...,  0.2778,  0.3653,  0.8861],
        [ 0.6627,  0.7835,  1.2358,  ...,  0.2099,  0.3569,  0.4614],
        [ 0.3432,  0.5640,  0.7524,  ...,  0.0936,  0.2047,  0.3281]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
target_1 = torch.FloatTensor(adata_omics2.X).to(device)

In [None]:
target_1

tensor([[0.8446, 0.6563, 1.6487,  ..., 0.2694, 0.3722, 0.5872],
        [0.8021, 0.7874, 1.6377,  ..., 0.2557, 0.3282, 0.4385],
        [0.9648, 0.6843, 1.7725,  ..., 0.3220, 0.3028, 0.6643],
        ...,
        [0.5764, 0.7217, 0.7081,  ..., 0.3089, 0.2882, 0.4778],
        [1.0276, 0.8202, 1.3457,  ..., 0.1555, 0.2471, 0.6598],
        [0.8590, 0.8590, 0.9274,  ..., 0.2703, 0.0912, 0.6825]],
       device='cuda:0')

In [None]:
import torch.nn.functional as F
F.mse_loss(target_1, out)

tensor(0.1268, device='cuda:0', grad_fn=<MseLossBackward0>)

In [None]:

test_x1, test_edge_index1 = torch.FloatTensor(test_adata_omics1.obsm["feat"]), torch.LongTensor(test_adata_omics1.uns["edgeList"])
test_z, test_out = model(test_x1.to(device), test_edge_index1.to(device))

In [None]:
test_z

tensor([[ 0.7102,  0.6781,  1.0333,  ...,  0.2398,  0.2326,  0.4020],
        [ 1.0262,  0.2326,  0.7908,  ..., -0.2600, -0.0425,  0.1531],
        [ 0.3880,  1.1827,  1.2729,  ...,  0.8883,  0.5605,  0.8280],
        ...,
        [ 0.9790,  0.4583,  0.9958,  ..., -0.0115,  0.1313,  0.0965],
        [ 0.5296,  0.6493,  0.8867,  ...,  0.0662,  0.1503,  0.4268],
        [ 0.3977,  0.9528,  1.1948,  ...,  0.2197,  0.2451,  0.6287]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
target_2 = torch.FloatTensor(test_adata_omics2.X).to(device)

In [None]:
F.mse_loss(target_2, test_z)

tensor(0.2229, device='cuda:0', grad_fn=<MseLossBackward0>)

In [None]:
out.to('cpu').detach().numpy()

array([[ 0.9200578 ,  0.6180627 ,  1.0360086 , ...,  0.04990932,
         0.08289612,  0.4687003 ],
       [ 1.2642105 ,  0.8109873 ,  1.4424335 , ...,  0.17926188,
        -0.03232881,  0.58735865],
       [ 1.2289163 ,  0.6097862 ,  1.0696971 , ...,  0.06648731,
         0.18469982,  0.51522267],
       ...,
       [ 0.22366384,  1.1502502 ,  0.8767188 , ...,  0.27784175,
         0.36527306,  0.8861254 ],
       [ 0.6626982 ,  0.7835271 ,  1.2358161 , ...,  0.20994337,
         0.35687056,  0.46137583],
       [ 0.34324738,  0.5640084 ,  0.7523655 , ...,  0.09359919,
         0.20470257,  0.328089  ]], dtype=float32)