In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import os
import torch

from mucstpy.utils import add_contrastive_label, get_feature
import warnings

# Suppress every Python warning for the remainder of the session.
warnings.simplefilter('ignore')
# Export the location of the local R installation through R_HOME.
os.environ.update(R_HOME='C:/Program Files/R/R-4.3.1')

In [None]:
# Read the Slide-seqV2 hippocampus AnnData object from the local .h5ad file.
adata = sc.read_h5ad(
    'D:/st_projects/data/slide_seq/v2/hippocampus/slideseqv2.h5ad'
)
# Negate the second spatial coordinate (column 1) in place.
spatial_xy = adata.obsm['spatial']
spatial_xy[:, 1] = -spatial_xy[:, 1]

In [None]:
# Compute QC metrics in place; the gene filter below reads
# `adata.var['total_counts']`.
sc.pp.calculate_qc_metrics(adata, inplace=True)
# Keep genes with more than 100 total counts. `.copy()` materialises the
# subset as a real AnnData object instead of a view, so the in-place
# preprocessing calls below do not have to implicitly copy a view.
adata = adata[:, adata.var['total_counts'] > 100].copy()

# Flag the top 3000 highly variable genes. This runs before normalisation
# because the `seurat_v3` flavour is documented to expect count data.
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)

# Drop genes detected in no cell, then normalise each cell to 1e4 counts,
# log-transform, and scale without centring (values clipped at 10).
sc.pp.filter_genes(adata, min_cells=1)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.scale(adata, zero_center=False, max_value=10)

In [None]:
# Show the spots in spatial coordinates, coloured by the `cluster` annotation.
sc.pl.embedding(
    adata,
    basis='spatial',
    color='cluster',
    size=10,
)

In [None]:
# Always build a `torch.device` (GPU when available, CPU otherwise) so that
# `device` has the same type on every machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Number of annotated clusters = number of columns in the one-hot encoding
# of `adata.obs['cluster']`.
cluster_num = len(pd.get_dummies(adata.obs['cluster']).columns)
# Display both values as the cell output.
cluster_num, device

### For SRT data without a histology image, simply set $\lambda_1 = 0$

In [None]:
from mucstpy.utils import construction_interaction

# Build the spot interaction graph using 15 neighbours per spot.
# NOTE(review): presumably this is what fills adata.obsm['adj'] and
# adata.obsm['graph_neigh'], which are read later -- confirm in mucstpy.
construction_interaction(
    adata=adata,
    n_neighbor=15,
)

In [None]:
# Prepare the contrastive-learning labels and the (real / corrupted) feature
# matrices on `adata`.
# NOTE(review): presumably these write adata.obsm['label_CSL'], ['feat'] and
# ['feat_fake'], which are read when the tensors are built -- confirm.
add_contrastive_label(adata)
get_feature(adata)

# Layer widths for the model: [input size, latent size]. The input size is
# taken from the feature matrix that is actually fed to the model, so it stays
# correct even if `get_feature` keeps only a subset of the genes in `adata`.
gene_dims = [adata.obsm['feat'].shape[1], 64]

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from mucstpy.model import Encoder, Decoder, Discriminator, AvgReadout

class MuCST_no_his(nn.Module):
    """MuCST variant for spatial transcriptomics data without a histology image.

    Real and corrupted ("fake") gene-expression features are encoded with the
    same encoder. Each embedding is summarised by a readout over the stored
    neighbourhood graph, and a discriminator scores the embeddings against
    those summaries (contrastive branch). The real embedding is additionally
    decoded back to expression space (reconstruction branch).

    Parameters
    ----------
    gene_dims : sequence of int
        Layer widths. ``gene_dims[0]`` is the input/output feature size,
        ``gene_dims[1:]`` are the encoder hidden sizes (reversed for the
        decoder) and ``gene_dims[-1]`` is the size given to the discriminator.
    graph_nei : torch.Tensor
        Neighbourhood matrix used by the readout. It is kept as a plain
        attribute, not a registered buffer, so ``model.to(device)`` does not
        move it; the caller must place it on the right device.
    """

    def __init__(self, gene_dims, graph_nei):
        super().__init__()
        self.graph = graph_nei
        # in_dim -> hidden -> project[0]
        self.gene_encoder_layer1 = Encoder(in_dims=gene_dims[0], hidden_dims=gene_dims[1:])
        self.decoder = Decoder(hidden_dims=list(reversed(gene_dims[1:])), out_dims=gene_dims[0])
        self.disc = Discriminator(gene_dims[-1])
        self.sigma = nn.Sigmoid()
        self.read = AvgReadout()
        self.mse_loss = nn.MSELoss()

    def forward_gene(self, gene, graph):
        """Encode ``gene`` features over ``graph`` and return the latent embedding."""
        # Call the module itself rather than ``.forward`` so that any
        # registered forward hooks are executed.
        return self.gene_encoder_layer1(x=gene, adj=graph)

    def recon_gene_loss(self, zg, xg, graph):
        """Return the MSE between the decoding of ``zg`` and the target ``xg``."""
        rec_gene = self.decoder(zg, graph)
        return self.mse_loss(rec_gene, xg)

    def forward(self, gene, fake_gene, graph):
        """Run the contrastive and reconstruction branches.

        Parameters
        ----------
        gene : torch.Tensor
            Real gene-expression features.
        fake_gene : torch.Tensor
            Corrupted features used as negative samples.
        graph : torch.Tensor
            Graph passed to the encoder and the decoder. The readout uses
            ``self.graph`` (given at construction time), not this argument.

        Returns
        -------
        tuple
            ``(zg, rec_gene, dis_a, dis_b)``: latent embedding of ``gene``,
            its reconstruction, and the discriminator outputs for the real
            and the corrupted summary respectively.
        """
        # encode the gene expression into latent embeddings
        zg = self.forward_gene(gene, graph)
        zg_fake = self.forward_gene(fake_gene, graph)

        emb_true = F.relu(zg)
        emb_fake = F.relu(zg_fake)

        # Neighbourhood summaries, squashed through a sigmoid.
        g = self.sigma(self.read(emb_true, self.graph))
        g_fake = self.sigma(self.read(emb_fake, self.graph))

        # Each summary is scored with its own embedding first and the other
        # embedding second.
        dis_a = self.disc(g, emb_true, emb_fake)
        dis_b = self.disc(g_fake, emb_fake, emb_true)

        # Reconstruct from the pre-activation embedding; no output activation.
        rec_gene = self.decoder(zg, graph)
        return zg, rec_gene, dis_a, dis_b

In [None]:
# Move the model inputs onto `device` as float32 tensors. `torch.tensor`
# always copies its input, so no explicit `.copy()` of the arrays is needed.
features = torch.tensor(adata.obsm['feat'], dtype=torch.float32, device=device)
features_fake = torch.tensor(adata.obsm['feat_fake'], dtype=torch.float32, device=device)
label_cont = torch.tensor(adata.obsm['label_CSL'], dtype=torch.float32, device=device)
adj = adata.obsm['adj']
# Neighbourhood matrix plus the identity. The addition already allocates a
# new array, so the stored `graph_neigh` is left untouched.
graph_neigh = torch.tensor(
    adata.obsm['graph_neigh'] + np.eye(adj.shape[0]),
    dtype=torch.float32,
    device=device,
)

In [None]:
from tqdm import tqdm

# Seed PyTorch before the model is built so that weight initialisation is
# repeatable (same seed as the PCA step later in the notebook).
torch.manual_seed(2023)

model = MuCST_no_his(gene_dims=gene_dims, graph_nei=graph_neigh).to(device)

loss_CSL = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.)

n_epochs = 1500
contrast_weight = 0.1  # weight of the two contrastive terms in the total loss

print('Begin to train MuCST without histology image...')
model.train()

for epoch in tqdm(range(n_epochs)):
    hidden_fea, rec_data, ret, ret_fake = model(features, features_fake, graph_neigh)

    # Contrastive losses for the real and the corrupted branch; both are
    # compared against the same label matrix.
    loss_cont = loss_CSL(ret, label_cont)
    loss_cont_dual = loss_CSL(ret_fake, label_cont)
    # Reconstruction loss, written as (prediction, target). MSE is symmetric,
    # so the value is the same as with the arguments swapped.
    loss_feat = F.mse_loss(rec_data, features)

    loss = loss_feat + contrast_weight * (loss_cont + loss_cont_dual)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
from sklearn.decomposition import PCA

# Inference only: switch to eval mode and disable gradient tracking.
model.eval()
with torch.no_grad():
    # Call the module itself (not `.forward`) so registered hooks run.
    _, rec_feature, _, _ = model(features, features_fake, graph_neigh)
# No `.detach()` needed: tensors produced under `no_grad` carry no graph.
rec_feature = rec_feature.cpu().numpy()

# Store the reconstructed expression and a 50-component PCA of it.
adata.obsm['rec_feature'] = rec_feature
pca = PCA(n_components=50, random_state=2023)
rec_feat_pca = pca.fit_transform(rec_feature)
adata.obsm['rec_feat_pca'] = rec_feat_pca

In [None]:
sc.set_figure_params(figsize=(4, 4))
# Leiden clusters on a neighbour graph. Build that graph from the PCA of the
# reconstructed features, under its own key, so the clustering reflects the
# model output and does not pick up (or overwrite) any neighbour graph that
# may already be stored in `adata`.
sc.pp.neighbors(adata, use_rep='rec_feat_pca', key_added='rec_feature')
sc.tl.leiden(
    adata,
    neighbors_key='rec_feature',
    key_added='leiden_rec_feature',
    resolution=0.3,
)
sc.pl.embedding(adata, basis='spatial', color=['leiden_rec_feature'], size=30)