In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy.external as sce
import torch
from sklearn import preprocessing
import pickle5 as pickle
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors
from sklearn import preprocessing
import sklearn
import sklearn.utils.class_weight
from sklearn.neighbors import KNeighborsClassifier

from utils import *
import matplotlib.pyplot as plt

# Tiny numerical floor; not referenced by the cells visible here
# (NOTE(review): presumably consumed by helpers pulled in from `utils` — confirm).
eps = 1.0e-100


In [None]:
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import random_split
import torch.nn as nn
class NNTransfer(nn.Module):
    """Two-layer MLP that maps an embedding to class scores.

    ``forward`` returns raw, unnormalized logits of shape (batch, output_dim).
    This is what ``nn.CrossEntropyLoss`` expects, because it applies log-softmax
    internally; feeding it softmax outputs would apply softmax twice and flatten
    the gradients. Use ``predict_proba`` to get class probabilities.
    """

    def __init__(self, input_dim=128, output_dim=10):
        super(NNTransfer, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, output_dim)
        # Kept as a module so probabilities can be produced at inference time.
        self.activate = nn.Softmax(dim=1)

    def forward(self, x):
        """Return logits for a batch ``x`` of shape (batch, input_dim)."""
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

    def predict_proba(self, x):
        """Return softmax class probabilities (rows sum to 1) for batch ``x``."""
        return self.activate(self.forward(x))


def NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device, save_dir, label_key, focus_name, epochs=200, patience=10):
    """Train ``model`` with validation-accuracy early stopping.

    After each epoch the model is scored on ``val_loader``. Whenever validation
    accuracy improves, a copy of the weights is kept. Training stops after more
    than ``patience`` consecutive epochs without improvement. The best weights
    are loaded back into ``model`` before returning, so callers predict with the
    best checkpoint rather than the last one.

    Args:
        model: module returning logits compatible with ``criterion``.
        criterion: loss taking (outputs, labels).
        optimizer: optimizer over ``model``'s parameters.
        train_loader / val_loader: iterables of (inputs, labels) batches.
        device: torch device that batches are moved to.
        save_dir, label_key, focus_name: unused; kept so existing calls still work.
        epochs: maximum number of epochs.
        patience: number of non-improving epochs tolerated before stopping.

    Returns:
        Best validation accuracy in percent (``-inf`` if ``epochs`` is 0).
    """
    best_accuracy = -float('inf')
    best_state = None
    patience_count = 0
    n_batches = max(len(train_loader), 1)  # avoid ZeroDivisionError on an empty loader
    for epoch in range(epochs):
        model.train()
        loss_all = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_all += loss.item()
        eval_loss, eval_accuracy = NNTransferEvaluate(model, val_loader, criterion, device)
        if eval_accuracy > best_accuracy:
            best_accuracy = eval_accuracy
            # Detached clones so later optimizer steps don't mutate the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f"Epoch {epoch}/{epochs} - Train Loss: {loss_all / n_batches}, Accuracy: {eval_accuracy}")
            patience_count = 0
        else:
            patience_count += 1
        if patience_count > patience:
            print(f"Epoch {epoch}/{epochs} - early stopping due to patience count")
            break
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_accuracy


def NNTransferEvaluate(model, dataloader, criterion, device):
    """Return (mean per-batch loss, accuracy in percent) of ``model`` on ``dataloader``.

    An empty ``dataloader`` yields ``(0.0, 0.0)`` instead of dividing by zero.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    if total == 0:
        return 0.0, 0.0
    return total_loss / len(dataloader), 100.0 * correct / total


def NNTransferPredictWithUncertainty(model, dataloader, device):
    """Predict class indices and an uncertainty score for every sample.

    ``dataloader`` yields 1-tuples ``(inputs,)``, as a ``TensorDataset`` built
    from features only does. The uncertainty is ``1 - max softmax probability``.

    Returns:
        (predictions, uncertainties), each a column vector of shape (N, 1).
    """
    model.eval()
    all_predictions = []
    all_uncertainties = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0].to(device)
            # The model emits logits; convert to probabilities for the confidence score.
            probs = torch.softmax(model(inputs), dim=1)
            confidence, predicted = probs.max(dim=1)
            all_predictions.append(predicted.cpu().numpy())
            all_uncertainties.append((1 - confidence).cpu().numpy())

    if not all_predictions:
        return np.empty((0, 1), dtype=np.int64), np.empty((0, 1), dtype=np.float32)
    return (np.concatenate(all_predictions).reshape(-1, 1),
            np.concatenate(all_uncertainties).reshape(-1, 1))


## Read raw data

In [None]:
# Load the query STARmap replicate. Its raw counts layer becomes .X and every
# cell is tagged as part of the 'query' dataset.
X_input = []
raw_data = sc.read_h5ad('source_data/ribo_STARmap-rep2.h5ad')
raw_data.obs['name'] = 'query'
raw_data.X = raw_data.layers['counts']
X_input += [raw_data]


## Transfer main cell-type labels

In [None]:
celltype_all_o = sc.read_h5ad("source_data/transfer_celltype_ref.h5ad")
# Copy the original query annotations onto the query rows of the joint object.
# Values are assigned positionally, so this assumes the query rows appear in the
# same order as the cells of X_input[0].
query_rows = celltype_all_o.obs['name'] == 'query'
for joint_col, query_col in (
    ('query_celltype_main', 'gt_cell_type_main'),
    ('query_celltype_sub', 'gt_cell_type_sub'),
    ('query_tissueregion_main', 'gt_tissue_region_main'),
):
    celltype_all_o.obs.loc[query_rows, joint_col] = list(X_input[0].obs[query_col])


In [None]:
# UMAP of the joint single-cell embedding, colored by which dataset each cell came from.
fig, ax = plt.subplots(figsize=(7, 7))
ax = sc.pl.umap(celltype_all_o, color='name', size=1, ax=ax, show=False)
plt.axis('off')
plt.title('Single-cell embedding colored by data source')

In [None]:
### Main-level cell-type colors
color_dic_main = pd.read_csv('source_data/color/cell_type/starmap_main.csv',
                             index_col=0)
color_dic = dict(zip(color_dic_main['key'], color_dic_main['color']))

# Smooth the transferred main labels with a 10-nearest-neighbour majority vote in
# UMAP space. The query points are the training points, so each cell also votes
# for itself. The unsmoothed labels are backed up once in `<key>_raw` and
# smoothing always starts from that backup. Re-running this cell therefore gives
# the same result instead of smoothing already-smoothed labels.
main_transfer_key = 'transfer_gt_cell_type_main_STARmap_ref'
main_transfer_raw_key = f'{main_transfer_key}_raw'
if main_transfer_raw_key not in celltype_all_o.obs:
    celltype_all_o.obs[main_transfer_raw_key] = celltype_all_o.obs[main_transfer_key].copy()

label = celltype_all_o.obs[main_transfer_raw_key]
location = np.array(celltype_all_o.obsm['X_umap'])
knn = KNeighborsClassifier(n_neighbors=10)
knn.fit(location, label)

predicted_labels = knn.predict(location)
celltype_all_o.obs[main_transfer_key] = predicted_labels

# UMAP of the non-reference cells, colored by the smoothed transferred labels.
fig, ax = plt.subplots(figsize=(7, 7))
ax = sc.pl.umap(celltype_all_o[celltype_all_o.obs['name'] != 'ref', :],
                color=main_transfer_key, size=3,
                palette=color_dic, legend_loc=[],
                ax=ax, show=False)
plt.axis('off')
plt.title('Single-cell embedding, colored by transferred main level cell types')

# Spatial map of the same cells; the figure height follows the x/y extent ratio.
ad_plot = celltype_all_o[celltype_all_o.obs['name'] != 'ref', :]
coeefi = (max(ad_plot.obs['x']) - min(ad_plot.obs['x'])) / (max(ad_plot.obs['y']) - min(ad_plot.obs['y']))

plt.figure(figsize=(8, 8 / coeefi))
plt.scatter(ad_plot.obs['x'],
            ad_plot.obs['y'],
            s=1,
            c=[color_dic[i] for i in ad_plot.obs[main_transfer_key]])
plt.gca().invert_yaxis()
plt.axis('off')
plt.title('Spatial map, colored by transferred main level cell types')


In [None]:
# Colors of the original level-2 (main) query annotations, keyed by label name.
query_dic = dict(zip(X_input[0].uns['level_2_order'], X_input[0].uns['level_2_color']))

non_ref = celltype_all_o.obs['name'] != 'ref'

fig, ax = plt.subplots(figsize=(7, 7))
ax = sc.pl.umap(celltype_all_o[non_ref, :],
                color='query_celltype_main', size=3, palette=query_dic,
                legend_loc=[],
                ax=ax, show=False)
plt.axis('off')
plt.title('Single-cell embedding, colored by original main level cell types')

# Spatial map of the query slice. Positions come from X_input[0] and colors come
# row-by-row from the joint object's non-reference cells, so this assumes both
# list the cells in the same order.
query_obs = X_input[0].obs
coeefi = (max(query_obs['x']) - min(query_obs['x'])) / (max(query_obs['y']) - min(query_obs['y']))

plt.figure(figsize=(8, 8 / coeefi))
plt.scatter(query_obs['x'], query_obs['y'], s=1,
            c=[query_dic[lab] for lab in celltype_all_o[non_ref, :].obs['query_celltype_main']])
plt.gca().invert_yaxis()
plt.axis('off')
plt.title('Spatial map, colored by original main level cell types')


In [None]:
# Main-level agreement between the original query labels and the labels
# transferred by FuseMap. Cells transferred as 'Unannotated' and cells originally
# labelled 'Mix' are excluded.
focus = celltype_all_o[celltype_all_o.obs['name'] != 'ref', :]
focus = focus[focus.obs['transfer_gt_cell_type_main_STARmap_ref'] != 'Unannotated']
focus = focus[focus.obs['query_celltype_main'] != 'Mix']

GT = np.array(focus.obs['query_celltype_main'])
PRED = np.array(focus.obs['transfer_gt_cell_type_main_STARmap_ref'])

cross_tab = pd.crosstab(pd.Series(GT, name='Original'),
                        pd.Series(PRED, name='FuseMap'))
# Normalize every FuseMap column to sum to 1, then every Original row to sum to 1.
cross_tab_normalized = cross_tab.div(cross_tab.sum(axis=0), axis=1)
cross_tab_normalized = cross_tab_normalized.div(cross_tab_normalized.sum(axis=1), axis=0)

# Heatmap rows: original query categories.
old_list = ['Astrocyte', 'Astroependymal cells',
            'Cholinergic, monoaminergic and peptidergic neurons',
            'Di/Mesencephalon neurons', 'Microglia', 'Oligodendrocyte',
            'Oligodendrocytes precursor cell', 'Perivascular macrophages',
            'Telencephalon interneurons', 'Telencephalon projecting neurons',
            'Vascular cells']

# Heatmap columns: transferred (reference) categories.
new_list = ['Astrocytes', 'Olfactory ensheathing cells',
            'Choroid plexus epithelial cells', 'Ependymal cells',
            'Hindbrain neurons/Spinal cord neurons',
            'Peptidergic neurons',
            'Di- and mesencephalon excitatory neurons',
            'Microglia', 'Non-glutamatergic neuroblasts',
            'Oligodendrocytes',
            'Oligodendrocyte precursor cells', 'Cholinergic and monoaminergic neurons',
            'Telencephalon inhibitory interneurons',
            'Di- and mesencephalon inhibitory neurons',
            'Glutamatergic neuroblasts',
            'Telencephalon projecting excitatory neurons', 'Dentate gyrus granule neurons',
            'Telencephalon projecting inhibitory neurons', 'Pericytes',
            'Vascular and leptomeningeal cells', 'Vascular endothelial cells',
            'Vascular smooth muscle cells']

# reindex instead of `[new_list]` / `.loc[old_list]`. A category that does not
# occur in this query (e.g. a reference type no cell was transferred to) is then
# shown as an all-zero row/column instead of raising KeyError.
cross_tab_normalized = cross_tab_normalized.reindex(index=old_list, columns=new_list, fill_value=0)
cmap = sns.cubehelix_palette(start=2, rot=0, dark=0, light=1.05, reverse=False, as_cmap=True)

plt.figure(figsize=(10, 6))
ax = sns.heatmap(cross_tab_normalized, cmap=cmap, vmax=1)
plt.title("Normalized Correspondence of Two Categories")
plt.show()


### Transfer sub-level cell types

In [None]:
# Object and obs keys used by the sub-level cell-type transfer loop.
result_celltype = celltype_all_o
gt_ref_key, gt_ref_sub_key = 'gt_cell_type_main_STARmap', 'gt_cell_type_sub_STARmap'
transfer_key = 'transfer_gt_cell_type_main_STARmap_ref'
# '-1' is a placeholder for cells that never receive a sub-level prediction.
result_celltype.obs['transfer_' + gt_ref_sub_key] = '-1'

In [None]:
import torch.optim as optim

# Sub-level cell-type transfer. For each main type present in the reference:
# train an NNTransfer classifier on cells whose transferred main label is that
# type and whose reference sub label belongs to it. Then predict a sub label and
# an uncertainty score for every cell carrying that transferred main label.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for focus_main in result_celltype[result_celltype.obs['name'] == 'ref'].obs[gt_ref_key].unique():
    print(focus_main)

    main_mask = result_celltype.obs[transfer_key] == focus_main
    ad_embed_OB = result_celltype[main_mask]
    if ad_embed_OB.shape[0] == 0:
        continue
    sub_list = list(result_celltype[result_celltype.obs[gt_ref_key] == focus_main].obs[gt_ref_sub_key].value_counts().keys())

    ad_embed_train = ad_embed_OB[ad_embed_OB.obs[gt_ref_sub_key].isin(sub_list)]
    if ad_embed_train.shape[0] == 0:
        # Nothing to learn from: these cells keep the '-1' placeholder.
        print(f"{focus_main}: no labelled training cells, skipped")
        continue
    sample1_embeddings = ad_embed_train.X

    # LabelEncoder already yields integer codes 0..K-1.
    le = preprocessing.LabelEncoder()
    sample1_labels = le.fit_transform(list(ad_embed_train.obs[gt_ref_sub_key]))

    # Fixed seed so the split, batch shuffling and weight init are reproducible.
    torch.manual_seed(0)
    dataset1 = TensorDataset(torch.Tensor(sample1_embeddings),
                             torch.as_tensor(sample1_labels, dtype=torch.long))
    train_size = int(0.95 * len(dataset1))  # 95% train / 5% validation for early stopping
    val_size = len(dataset1) - train_size
    train_dataset, val_dataset = random_split(dataset1, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    # 'balanced' class weights counteract sub-type imbalance in the training labels.
    classes = np.unique(sample1_labels)
    class_weight = torch.Tensor(sklearn.utils.class_weight.compute_class_weight(class_weight='balanced',
                                                                                classes=classes,
                                                                                y=sample1_labels))
    model = NNTransfer(input_dim=sample1_embeddings.shape[1],
                       output_dim=len(classes))
    model.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weight.to(device))
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device, 0, 0, 0)

    dataset2 = TensorDataset(torch.Tensor(ad_embed_OB.X))
    dataloader2 = DataLoader(dataset2, batch_size=256, shuffle=False)
    sample2_predictions, sample2_uncertainty = NNTransferPredictWithUncertainty(model, dataloader2, device)
    # The predictor returns column vectors; flatten before decoding and assigning.
    sample2_predictions = le.inverse_transform(np.asarray(sample2_predictions).ravel())

    result_celltype.obs.loc[main_mask, f'transfer_{gt_ref_sub_key}'] = sample2_predictions
    result_celltype.obs.loc[main_mask, f'transfer_{gt_ref_sub_key}_uncertainty'] = np.asarray(sample2_uncertainty).ravel()
    

In [None]:
### Sub-level cell-type colors (starmap_sub.csv)
color_dic_main = pd.read_csv('source_data/color/cell_type/starmap_sub.csv',
                             index_col=0)
color_dic = {key: color for key, color in zip(color_dic_main['key'], color_dic_main['color'])}

sub_label_key = 'transfer_gt_cell_type_sub_STARmap'

# UMAP of the non-reference cells, colored by transferred sub-level labels.
fig, ax = plt.subplots(figsize=(7, 7))
ax = sc.pl.umap(celltype_all_o[celltype_all_o.obs['name'] != 'ref', :],
                color=sub_label_key, size=3,
                palette=color_dic, legend_loc=[],
                ax=ax, show=False)
plt.axis('off')
plt.title('Single-cell embedding, colored by transferred sub level cell types')

# Spatial map of the same cells; figure height follows the x/y extent ratio.
ad_plot = celltype_all_o[celltype_all_o.obs['name'] != 'ref', :]
x_vals, y_vals = ad_plot.obs['x'], ad_plot.obs['y']
coeefi = (max(x_vals) - min(x_vals)) / (max(y_vals) - min(y_vals))

plt.figure(figsize=(8, 8 / coeefi))
plt.scatter(x_vals, y_vals, s=1,
            c=[color_dic[lab] for lab in ad_plot.obs[sub_label_key]])
plt.gca().invert_yaxis()
plt.axis('off')
plt.title('Spatial map, colored by transferred sub level cell types')


In [None]:
# Colors of the original level-3 (sub) query annotations, keyed by label name.
query_dic = dict(zip(X_input[0].uns['level_3_order'], X_input[0].uns['level_3_color']))

fig, ax = plt.subplots(figsize=(7, 7))
ax = sc.pl.umap(celltype_all_o[celltype_all_o.obs['name'] != 'ref', :],
                color='query_celltype_sub', size=3,
                palette=query_dic, legend_loc=[],
                ax=ax, show=False)
plt.axis('off')
plt.title('Single-cell embedding, colored by original sub level cell types')

# NOTE(review): the aspect ratio comes from all query cells (X_input[0]), while
# the scatter shows `focus`, the filtered query subset built in the main-level
# crosstab cell. That cell must therefore have been run before this one.
x_span = max(X_input[0].obs['x']) - min(X_input[0].obs['x'])
y_span = max(X_input[0].obs['y']) - min(X_input[0].obs['y'])
coeefi = x_span / y_span

plt.figure(figsize=(7, 7 / coeefi))
plt.scatter(focus.obs['x'], focus.obs['y'], s=1,
            c=[query_dic[lab] for lab in focus.obs['query_celltype_sub']])
plt.gca().invert_yaxis()
plt.axis('off')
plt.title('Spatial map, colored by original sub level cell types')


### TEGLU neurons: original vs. transferred sub-types

In [None]:
# Split the joint cell-type object into query cells and everything else (reference).
ad_embed = celltype_all_o
is_query = ad_embed.obs['name'] == 'query'
ad_embed_query = ad_embed[is_query]
ad_embed_ref = ad_embed[~is_query]

In [None]:
# Original query labels (DGGRC plus TEGLU sub-types) kept in the TEGLU comparison.
old_list = ['DGGRC'] + ['TEGLU ' + suffix for suffix in (
    'CA1', 'CA2', 'CA3', 'L2_3', 'L4_5', 'L5', 'L6', 'COA', 'L6a', 'Mix', 'PIR')]

# Transferred reference labels (DGGRC plus numbered TEGLU clusters) used to select cells.
new_list = ['DGGRC'] + [f'TEGLU_{n}' for n in (
    5, 6, 4, 39, 25, 19, 8, 9, 40, 29, 23, 35, 3, 1, 41, 12, 10,
    14, 18, 22, 30, 37, 7, 17, 13, 24, 11, 16, 15)]


In [None]:
# Query cells: transferred label in `new_list` AND original label in `old_list`.
# Reference cells: filtered on the transferred label only.
ad_embed_query_focus = ad_embed_query[
    ad_embed_query.obs['transfer_gt_cell_type_sub_STARmap'].isin(new_list)
    & ad_embed_query.obs['query_celltype_sub'].isin(old_list)
]
ad_embed_ref_focus = ad_embed_ref[ad_embed_ref.obs['transfer_gt_cell_type_sub_STARmap'].isin(new_list)]


#### Plot spatial maps

In [None]:
# Reproducible random RGB color per transferred TEGLU label (global numpy seed 130).
np.random.seed(130)
teglu_types = ad_embed_query_focus.obs['transfer_gt_cell_type_sub_STARmap'].unique()
color_palette = np.random.rand(teglu_types.shape[0], 3)

import seaborn as sns
color_dic = dict(zip(teglu_types, color_palette))

In [None]:

# Spatial map of the query TEGLU subset, colored by transferred sub-type.
query_x = ad_embed_query_focus.obs['x']
query_y = ad_embed_query_focus.obs['y']
coeefi = (max(query_x) - min(query_x)) / (max(query_y) - min(query_y))

plt.figure(figsize=(6, 6 / coeefi))
plt.scatter(query_x, query_y, s=1,
            c=[color_dic[lab] for lab in ad_embed_query_focus.obs['transfer_gt_cell_type_sub_STARmap']])
plt.gca().invert_yaxis()
plt.axis('off')


In [None]:
# Reference slice 'sample4' restricted to the selected clusters. x and y are
# swapped in the scatter, so the figure height scales with the x/y extent ratio.
ad_embed_ref_focus_sample4 = ad_embed_ref_focus[ad_embed_ref_focus.obs['batch'] == 'sample4']
ref_x = ad_embed_ref_focus_sample4.obs['x']
ref_y = ad_embed_ref_focus_sample4.obs['y']
coeefi = (max(ref_x) - min(ref_x)) / (max(ref_y) - min(ref_y))

plt.figure(figsize=(6, 6 * coeefi))
plt.scatter(ref_y, ref_x, s=1,
            c=[color_dic[lab] for lab in ad_embed_ref_focus_sample4.obs['transfer_gt_cell_type_sub_STARmap']])
plt.gca().invert_yaxis()
plt.axis('off')


In [None]:
# Reproducible random RGB color per original query sub-type (global numpy seed 0).
np.random.seed(0)
orig_types = ad_embed_query_focus.obs['query_celltype_sub'].unique()
color_palette = np.random.rand(orig_types.shape[0], 3)
color_dic = dict(zip(orig_types, color_palette))

# NOTE(review): the aspect ratio is taken from `focus` (the filtered query cells of
# the main-level crosstab cell), not from the plotted TEGLU subset.
focus_x, focus_y = focus.obs['x'], focus.obs['y']
coeefi = (max(focus_x) - min(focus_x)) / (max(focus_y) - min(focus_y))

plt.figure(figsize=(6, 6 / coeefi))
plt.scatter(ad_embed_query_focus.obs['x'], ad_embed_query_focus.obs['y'], s=1,
            c=[color_dic[lab] for lab in ad_embed_query_focus.obs['query_celltype_sub']])
plt.gca().invert_yaxis()
plt.axis('off')


## Transfer main tissue-region labels

In [None]:
# Joint (reference + query) object carrying the spatial embedding used for tissue-region transfer.
tissueregion_all_o = sc.read_h5ad("source_data/transfer_tissueregion_ref.h5ad")

In [None]:
# UMAP of the joint spatial embedding, colored by which dataset each cell came from.
fig, ax = plt.subplots(figsize=(7, 7))
ax = sc.pl.umap(tissueregion_all_o, color='name', size=3, ax=ax, show=False)
plt.axis('off')
plt.title('Spatial embedding colored by data source')

In [None]:
tissueregion_all_o.obs.loc[tissueregion_all_o.obs['name'] == 'query', 'query_tissueregion_main'] = list(X_input[0].obs['gt_tissue_region_main'])

# Spatially smooth the transferred main regions with a per-batch 30-NN majority
# vote on (x, y). The unsmoothed labels are backed up once in `<key>_raw` and
# smoothing always starts from there, so re-running this cell is idempotent.
main_region_key = 'transfer_gt_tissue_region_main_STARmap_ref'
main_region_raw_key = f'{main_region_key}_raw'
if main_region_raw_key not in tissueregion_all_o.obs:
    tissueregion_all_o.obs[main_region_raw_key] = tissueregion_all_o.obs[main_region_key].copy()

for i in tissueregion_all_o.obs['batch'].unique():
    batch_mask = tissueregion_all_o.obs['batch'] == i
    sub = tissueregion_all_o[batch_mask]
    label = sub.obs[main_region_raw_key]
    location = np.array(sub.obs[['x', 'y']])
    # KNeighborsClassifier raises if n_neighbors exceeds the number of samples.
    knn = KNeighborsClassifier(n_neighbors=min(30, sub.n_obs))
    knn.fit(location, label)

    predicted_labels = knn.predict(location)
    tissueregion_all_o.obs.loc[batch_mask, main_region_key] = predicted_labels

In [None]:
### Main-level tissue-region colors
color_dic_main = pd.read_csv('source_data/color/tissue_domain/starmap_main.csv',
                             index_col=0)
color_dic = {key: color for key, color in zip(color_dic_main['key'], color_dic_main['color'])}

region_col = 'transfer_gt_tissue_region_main_STARmap_ref'
non_ref = tissueregion_all_o.obs['name'] != 'ref'

fig, ax = plt.subplots(figsize=(7, 7))
ax = sc.pl.umap(tissueregion_all_o[non_ref, :], color=region_col, size=3,
                palette=color_dic, legend_loc=[], ax=ax, show=False)
plt.axis('off')
plt.title('Spatial embedding, colored by transferred main level tissue regions')

# Positions from X_input[0], colors from the joint object's non-reference rows
# (assumes both list the query cells in the same order).
query_obs = X_input[0].obs
coeefi = (max(query_obs['x']) - min(query_obs['x'])) / (max(query_obs['y']) - min(query_obs['y']))

plt.figure(figsize=(5, 5 / coeefi))
plt.scatter(query_obs['x'], query_obs['y'], s=0.5,
            c=[color_dic[r] for r in tissueregion_all_o[non_ref, :].obs[region_col]])
plt.gca().invert_yaxis()
plt.axis('off')
plt.title('Spatial map, colored by transferred main level tissue regions')


In [None]:
# Original anatomical region colors: the color of the first cell seen in each
# region. This is built in one pass instead of slicing the AnnData once per region.
first_rows = X_input[0].obs.drop_duplicates(subset='region')
query_dic = dict(zip(first_rows['region'], first_rows['region_color']))

fig, ax = plt.subplots(figsize=(7, 7))
ax = sc.pl.umap(tissueregion_all_o[tissueregion_all_o.obs['name'] != 'ref', :],
                color='query_tissueregion_main', size=3,
                palette=query_dic, legend_loc=[], ax=ax, show=False)
plt.axis('off')
plt.title('Spatial embedding, colored by original anatomical tissue regions')

coeefi = (max(X_input[0].obs['x']) - min(X_input[0].obs['x'])) / (max(X_input[0].obs['y']) - min(X_input[0].obs['y']))
plt.figure(figsize=(5, 5 / coeefi))
plt.scatter(X_input[0].obs['x'],
            X_input[0].obs['y'],
            s=0.5,
            c=[query_dic[i] for i in tissueregion_all_o[tissueregion_all_o.obs['name'] != 'ref', :].obs['query_tissueregion_main']])
plt.gca().invert_yaxis()
plt.axis('off')
plt.title('Spatial map, colored by original anatomical tissue regions')

### Transfer sub-level tissue regions

In [None]:
# Object and obs keys used by the sub-level tissue-region transfer loop.
result_celltype = tissueregion_all_o
gt_ref_key, gt_ref_sub_key = 'gt_tissue_region_main_STARmap', 'gt_tissue_region_sub_STARmap'
transfer_key = 'transfer_gt_tissue_region_main_STARmap_ref'
# '-1' is a placeholder for cells that never receive a sub-level prediction.
result_celltype.obs['transfer_' + gt_ref_sub_key] = '-1'

In [None]:
import torch.optim as optim

# Sub-level tissue-region transfer. For each main region present in the
# reference: train an NNTransfer classifier on cells whose transferred main
# region is that region and whose reference sub region belongs to it. Then
# predict a sub region and an uncertainty score for every cell carrying that
# transferred main region.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for focus_main in result_celltype[result_celltype.obs['name'] == 'ref'].obs[gt_ref_key].unique():
    print(focus_main)

    main_mask = result_celltype.obs[transfer_key] == focus_main
    ad_embed_OB = result_celltype[main_mask]
    if ad_embed_OB.shape[0] == 0:
        continue
    sub_list = list(result_celltype[result_celltype.obs[gt_ref_key] == focus_main].obs[gt_ref_sub_key].value_counts().keys())

    ad_embed_train = ad_embed_OB[ad_embed_OB.obs[gt_ref_sub_key].isin(sub_list)]
    if ad_embed_train.shape[0] == 0:
        # Nothing to learn from: these cells keep the '-1' placeholder.
        print(f"{focus_main}: no labelled training cells, skipped")
        continue
    sample1_embeddings = ad_embed_train.X

    # LabelEncoder already yields integer codes 0..K-1.
    le = preprocessing.LabelEncoder()
    sample1_labels = le.fit_transform(list(ad_embed_train.obs[gt_ref_sub_key]))

    # Fixed seed so the split, batch shuffling and weight init are reproducible.
    torch.manual_seed(0)
    dataset1 = TensorDataset(torch.Tensor(sample1_embeddings),
                             torch.as_tensor(sample1_labels, dtype=torch.long))
    train_size = int(0.95 * len(dataset1))  # 95% train / 5% validation for early stopping
    val_size = len(dataset1) - train_size
    train_dataset, val_dataset = random_split(dataset1, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    # 'balanced' class weights counteract sub-region imbalance in the training labels.
    classes = np.unique(sample1_labels)
    class_weight = torch.Tensor(sklearn.utils.class_weight.compute_class_weight(class_weight='balanced',
                                                                                classes=classes,
                                                                                y=sample1_labels))
    model = NNTransfer(input_dim=sample1_embeddings.shape[1],
                       output_dim=len(classes))
    model.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weight.to(device))
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device, 0, 0, 0)

    dataset2 = TensorDataset(torch.Tensor(ad_embed_OB.X))
    dataloader2 = DataLoader(dataset2, batch_size=256, shuffle=False)
    sample2_predictions, sample2_uncertainty = NNTransferPredictWithUncertainty(model, dataloader2, device)
    # The predictor returns column vectors; flatten before decoding and assigning.
    sample2_predictions = le.inverse_transform(np.asarray(sample2_predictions).ravel())

    result_celltype.obs.loc[main_mask, f'transfer_{gt_ref_sub_key}'] = sample2_predictions
    result_celltype.obs.loc[main_mask, f'transfer_{gt_ref_sub_key}_uncertainty'] = np.asarray(sample2_uncertainty).ravel()
    

In [None]:
# Spatially smooth the transferred sub regions with a per-batch 30-NN majority vote on (x, y).
#
# `<key>_o` holds the unsmoothed network predictions and `<key>_smoothed` the
# last output of this cell. If the current column still equals that output,
# this is a re-run: restore from the backup instead of smoothing smoothed
# labels. Otherwise the transfer loop produced fresh predictions and they are
# backed up. (Previously the backup was overwritten on every run, making the
# copy-back a no-op.)
ad_embed = result_celltype
sub_region_key = 'transfer_gt_tissue_region_sub_STARmap'
backup_key = f'{sub_region_key}_o'
smoothed_key = f'{sub_region_key}_smoothed'

already_smoothed = (
    backup_key in ad_embed.obs
    and smoothed_key in ad_embed.obs
    and ad_embed.obs[sub_region_key].equals(ad_embed.obs[smoothed_key])
)
if not already_smoothed:
    ad_embed.obs[backup_key] = ad_embed.obs[sub_region_key].copy()
ad_embed.obs[sub_region_key] = ad_embed.obs[backup_key].copy()

for i in ad_embed.obs['batch'].unique():
    batch_mask = ad_embed.obs['batch'] == i
    sub = ad_embed[batch_mask]
    label = sub.obs[sub_region_key]
    location = np.array(sub.obs[['x', 'y']])
    # KNeighborsClassifier raises if n_neighbors exceeds the number of samples.
    knn = KNeighborsClassifier(n_neighbors=min(30, sub.n_obs))
    knn.fit(location, label)

    predicted_labels = knn.predict(location)
    ad_embed.obs.loc[batch_mask, sub_region_key] = predicted_labels

ad_embed.obs[smoothed_key] = ad_embed.obs[sub_region_key].copy()


In [None]:
### Sub-level tissue-region colors (starmap_sub_old.csv)
color_dic_main = pd.read_csv('source_data/color/tissue_domain/starmap_sub_old.csv',
                             index_col=0)
color_dic = {key: color for key, color in zip(color_dic_main['key'], color_dic_main['color'])}

sub_region_col = 'transfer_gt_tissue_region_sub_STARmap'

fig, ax = plt.subplots(figsize=(7, 7))
ax = sc.pl.umap(ad_embed[ad_embed.obs['name'] != 'ref', :], color=sub_region_col, size=3,
                palette=color_dic, legend_loc=[], ax=ax, show=False)
plt.axis('off')
plt.title('Spatial embedding, colored by transferred sub level tissue regions')

# Positions from X_input[0], colors from the joint object's non-reference rows
# (assumes both list the query cells in the same order).
query_obs = X_input[0].obs
coeefi = (max(query_obs['x']) - min(query_obs['x'])) / (max(query_obs['y']) - min(query_obs['y']))
plt.figure(figsize=(5, 5 / coeefi))
plt.scatter(query_obs['x'], query_obs['y'], s=0.5,
            c=[color_dic[r] for r in tissueregion_all_o[tissueregion_all_o.obs['name'] != 'ref', :].obs[sub_region_col]])
plt.gca().invert_yaxis()
plt.axis('off')
plt.title('Spatial map, colored by transferred sub level tissue regions')


In [None]:
# Correspondence between original main regions and transferred sub regions in batch 'sample0'.
focus = ad_embed[ad_embed.obs['batch'] == 'sample0']

GT = np.array(focus.obs['query_tissueregion_main'])
PRED = np.array(focus.obs['transfer_gt_tissue_region_sub_STARmap'])

cross_tab = pd.crosstab(pd.Series(GT, name='Original'),
                        pd.Series(PRED, name='FuseMap'))
# Normalize every Original row to sum to 1, then every FuseMap column to sum to 1.
row_normalized = cross_tab.div(cross_tab.sum(axis=1), axis=0)
cross_tab_normalized = row_normalized.div(row_normalized.sum(axis=0), axis=1)

# Pair each transferred sub region with the original region it overlaps most.
corresponde_sub = np.array(list(cross_tab_normalized.columns))
corresponde_main = np.array([cross_tab_normalized[col].idxmax() for col in cross_tab_normalized.columns])

# Group the sub-region columns under their best-matching original region (rows sorted).
old_list = list(np.unique(corresponde_main))
new_list = [sub_name
            for main_name in np.unique(corresponde_main)
            for sub_name in corresponde_sub[corresponde_main == main_name]]

cross_tab_normalized = cross_tab_normalized[new_list]
cross_tab_normalized = cross_tab_normalized.loc[old_list]

cmap = sns.cubehelix_palette(start=2, rot=0, dark=0, light=1.05, reverse=False, as_cmap=True)

plt.figure(figsize=(30, 3))
ax = sns.heatmap(cross_tab_normalized, cmap=cmap, vmax=1)
plt.title("Normalized Correspondence of Two Categories")
plt.show()


### Isocortex and hippocampal regions

In [None]:
# Split the tissue-region object into query and non-query (reference) cells.
is_query = ad_embed.obs['name'] == 'query'
ad_embed_query = ad_embed[is_query]
ad_embed_ref = ad_embed[~is_query]

# Main-level regions shown in the isocortex / hippocampus zoom-in.
new_list = ['CTX_1', 'HPF_CA', 'DG']

main_region_col = 'transfer_gt_tissue_region_main_STARmap_ref'
ad_embed_query_focus = ad_embed_query[ad_embed_query.obs[main_region_col].isin(new_list)]
ad_embed_ref_focus = ad_embed_ref[ad_embed_ref.obs[main_region_col].isin(new_list)]

In [None]:
# Query cells in the selected regions, colored by transferred sub region.
qx = ad_embed_query_focus.obs['x']
qy = ad_embed_query_focus.obs['y']
coeefi = (max(qx) - min(qx)) / (max(qy) - min(qy))
plt.figure(figsize=(6, 6 / coeefi))
plt.scatter(qx, qy, s=1,
            c=[color_dic[r] for r in ad_embed_query_focus.obs['transfer_gt_tissue_region_sub_STARmap']])
plt.gca().invert_yaxis()
plt.axis('off')
plt.title('Spatial map in isocortex and hippocampal region, colored by transferred sub level tissue regions')


In [None]:
# Reference slice 'sample4' in the selected regions. x and y are swapped in the
# scatter and both axes are inverted so it matches the query orientation.
ad_embed_ref_focus_sample4 = ad_embed_ref_focus[ad_embed_ref_focus.obs['batch'] == 'sample4']
rx = ad_embed_ref_focus_sample4.obs['x']
ry = ad_embed_ref_focus_sample4.obs['y']
coeefi = (max(rx) - min(rx)) / (max(ry) - min(ry))

plt.figure(figsize=(6, 6 * coeefi))
plt.scatter(ry, rx, s=1,
            c=[color_dic[r] for r in ad_embed_ref_focus_sample4.obs['transfer_gt_tissue_region_sub_STARmap']])
plt.gca().invert_yaxis()
plt.gca().invert_xaxis()
plt.axis('off')
plt.title('Spatial map in isocortex and hippocampal region, colored by reference sub level tissue regions')
