In [None]:
import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle5 as pickle
import scanpy as sc
import scanpy.external as sce
import sklearn
# `torch` is used directly by the model/training cells; `import torch.nn as nn`
# does not bind it, so make it explicit instead of relying on `utils`.
import torch
from sklearn import preprocessing
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

from utils import *


# Tiny numerical epsilon. NOTE(review): not referenced by any cell shown here;
# presumably used by `utils` helpers or later code — remove if unused.
eps=1e-100


In [None]:
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import random_split

import torch.nn as nn
class NNTransfer(nn.Module):
    """Two-layer MLP that maps an embedding vector to class probabilities.

    Architecture: Linear(input_dim -> 256) -> ReLU -> Linear(256 -> output_dim)
    -> Softmax over the class dimension. The output rows therefore sum to 1.

    NOTE(review): the training cells feed these *probabilities* to
    ``nn.CrossEntropyLoss``, which applies log-softmax itself and expects raw
    logits, so softmax is effectively applied twice during training. The
    conventional setup would return logits here and apply softmax only where
    probabilities are needed. This is left unchanged because changing it alters
    training results.
    """

    def __init__(self, input_dim=128, output_dim=10):
        super().__init__()
        hidden_dim = 256
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.activate = nn.Softmax(dim=1)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.activate(self.fc2(hidden))

    

def NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device,
                    save_pth=None, epochs=200, patience=10):
    """Train ``model`` in place with early stopping on validation accuracy.

    Each epoch does one optimisation pass over ``train_loader`` and then scores
    the model on ``val_loader`` with ``NNTransferEvaluate``. Training stops once
    validation accuracy has failed to improve for more than ``patience``
    consecutive epochs, or after ``epochs`` epochs.

    On return the model holds the weights of the best-validation epoch. (The
    previous version tracked the best accuracy but kept the weights of the
    *last* epoch, i.e. up to ``patience + 1`` epochs past the best one.)

    Parameters
    ----------
    model : nn.Module
        Classifier to train; expected to already be on ``device``.
    criterion : callable
        Loss called as ``criterion(model(inputs), labels)``.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model.parameters()``.
    train_loader, val_loader : DataLoader
        Yield ``(inputs, labels)`` batches.
    device : torch.device
        Device each batch is moved to.
    save_pth : str or path-like, optional
        If given, the best state_dict is written here with ``torch.save`` every
        time validation accuracy improves. Default ``None`` saves nothing.
    epochs : int
        Maximum number of epochs.
    patience : int
        Non-improving epochs tolerated before stopping (default 10, the value
        that used to be hard-coded).
    """
    best_accuracy = 0
    best_state = None
    epochs_without_improvement = 0
    for epoch in range(epochs):
        model.train()
        loss_all = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_all += loss.item()

        _, eval_accuracy = NNTransferEvaluate(model, val_loader, criterion, device)
        if best_accuracy < eval_accuracy:
            best_accuracy = eval_accuracy
            # Snapshot (clone) the weights; state_dict() tensors alias the live
            # parameters and would otherwise keep changing.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            if save_pth is not None:
                torch.save(best_state, save_pth)
            print(f"Epoch {epoch}/{epochs} - Train Loss: {loss_all / len(train_loader)}, Accuracy: {eval_accuracy}")
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement > patience:
            print(f"Epoch {epoch}/{epochs} - early stopping due to patience count")
            break

    # Leave the caller with the best-validation weights, not the last epoch's.
    if best_state is not None:
        model.load_state_dict(best_state)
            
def NNTransferEvaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating it.

    Puts the model in eval mode and runs under ``torch.no_grad()``.

    Returns
    -------
    (mean_loss, accuracy)
        ``mean_loss`` is the per-batch loss averaged over batches.
        ``accuracy`` is the percentage (0-100) of samples whose arg-max output
        equals the label.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            scores = model(batch_inputs)
            loss_sum += criterion(scores, batch_labels).item()
            batch_pred = torch.max(scores, dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (batch_pred == batch_labels).sum().item()
    return loss_sum / len(dataloader), 100. * n_correct / n_seen

def NNTransferPredictWithUncertainty(model, dataloader, device):
    """Predict a class index and an uncertainty score for every sample.

    Works with loaders that yield 1-tuples ``(inputs,)`` or ``(inputs, labels)``;
    only the first element of each batch is used. Uncertainty is
    ``1 - max(model output)``, so it is meaningful when the model returns
    class probabilities (as ``NNTransfer`` does).

    Returns
    -------
    predictions : np.ndarray, shape (N, 1)
        Arg-max class indices, in loader order.
    uncertainties : np.ndarray, shape (N, 1)
        ``1 - max`` output per sample.
    An empty loader yields two ``(0, 1)`` arrays instead of raising.
    """
    model.eval()
    batch_predictions = []
    batch_uncertainties = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0].to(device)
            outputs = model(inputs)
            # A single max call gives both the top score and its index.
            confidence, predicted = torch.max(outputs, dim=1)
            batch_predictions.append(predicted.cpu().numpy())
            batch_uncertainties.append((1 - confidence).cpu().numpy())

    if not batch_predictions:
        return np.empty((0, 1), dtype=np.int64), np.empty((0, 1), dtype=np.float32)
    # Keep the historical (N, 1) column shape that callers received from np.vstack.
    return (np.concatenate(batch_predictions).reshape(-1, 1),
            np.concatenate(batch_uncertainties).reshape(-1, 1))


### Read the universal single-cell embedding

In [None]:
# Universal FuseMap embedding: one row of .X per cell (the classifier input),
# annotations in .obs and a UMAP layout in .obsm['X_umap'].
FUSEMAP_EMBED_PATH = 'source_data/ad_embed.h5ad'
ad_fusemap_emb = sc.read_h5ad(FUSEMAP_EMBED_PATH)

### Transfer A1N main-level cell-type labels (STARmap annotation)

In [None]:
# Training cells: those with a usable STARmap main-level annotation.
# Compare as strings so a real missing value (NaN) is excluded just like the
# literal 'nan' (the Allen section casts to str for the same reason); a plain
# `!= 'nan'` lets NaN through. A boolean mask also avoids re-selecting by
# obs name, which misbehaves if obs_names are not unique.
starmap_main = ad_fusemap_emb.obs['gt_celltype_main_STARmap'].astype(str)
keep = (starmap_main != 'nan') & (starmap_main != 'Unannotated')
ad_embed_train = ad_fusemap_emb[keep.to_numpy()]

In [None]:

# Encode STARmap labels as integer codes 0..K-1 and split 80% train / 20% held out.
sample1_embeddings = ad_embed_train.X
sample1_labels = list(ad_embed_train.obs['gt_celltype_main_STARmap'])

le = preprocessing.LabelEncoder()
# fit_transform already yields integer codes; the former
# .astype('str').astype('int') round trip was a no-op.
sample1_labels = le.fit_transform(sample1_labels)

dataset1 = TensorDataset(torch.Tensor(sample1_embeddings),
                         torch.as_tensor(sample1_labels, dtype=torch.long))
train_size = int(0.8 * len(dataset1))  # Use 80% of the data for training
val_size = len(dataset1) - train_size
# Seeded generator: the split (and every downstream number) is reproducible
# after a kernel restart.
train_dataset, val_dataset = random_split(dataset1, [train_size, val_size],
                                          generator=torch.Generator().manual_seed(0))


In [None]:

# Split the held-out 20% in half: 10% validation (early stopping), 10% test.
# NOTE(review): `val_dataset` is overwritten from itself, so re-running this
# cell alone halves it again — re-run the previous cell first.
val_size = int(0.5 * len(val_dataset))
test_size = len(val_dataset) - val_size
# Seeded for reproducibility across kernel restarts.
val_dataset, test_dataset = random_split(val_dataset, [val_size, test_size],
                                         generator=torch.Generator().manual_seed(0))


In [None]:


import torch.optim as optim

# Seed torch's global RNG so weight initialisation and the shuffled batch
# order repeat between runs (some CUDA kernels may still be nondeterministic).
torch.manual_seed(0)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 'balanced' class weights counter label imbalance; computed over all labelled
# cells (train + val + test), as before.
class_weight = torch.Tensor(sklearn.utils.class_weight.compute_class_weight(class_weight='balanced',
                                                                            classes=np.unique(sample1_labels),
                                                                            y=sample1_labels))
model = NNTransfer(input_dim=sample1_embeddings.shape[1],
                   output_dim=len(np.unique(sample1_labels)))
model.to(device)  # Move the model to GPU if available
criterion = nn.CrossEntropyLoss(weight=class_weight.to(device))
optimizer = optim.Adam(model.parameters(), lr=0.001)

NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device)



In [None]:
# Predict the held-out test split and decode predictions and ground truth back
# to STARmap cell-type names.
# `le` is reassigned in the Allen section; check it still belongs to `model`
# (catches the stale-encoder case whenever the class counts differ).
assert len(le.classes_) == model.fc2.out_features, \
    "`le` does not match `model`; re-run the STARmap label-encoding cell first"
test_predictions, sample2_uncertainty = NNTransferPredictWithUncertainty(model, test_loader, device)
# Predictions come back as an (N, 1) column; ravel to 1-D so LabelEncoder does
# not emit a DataConversionWarning.
test_predictions = le.inverse_transform(test_predictions.ravel())

# test_loader is unshuffled, so this follows the same order as test_predictions.
all_labels = [label.item() for _, label in test_dataset]
gt_test_predictions = le.inverse_transform(all_labels)

GT_starmap_s = gt_test_predictions
PRED_starmap_s = test_predictions

### Plot the ground-truth vs. predicted heatmap (test split)

In [None]:
# Contingency table: ground truth (rows) vs. FuseMap prediction (columns).
cross_tab = pd.crosstab(pd.Series(GT_starmap_s, name='Original'),
                        pd.Series(PRED_starmap_s, name='FuseMap'))

# Normalise each predicted column to sum 1, then each ground-truth row to sum 1,
# and express as whole percentages.
cross_tab_normalized = (
    cross_tab.div(cross_tab.sum(axis=0), axis=1)
    .pipe(lambda t: t.div(t.sum(axis=1), axis=0))
    .mul(100)
    .pipe(np.around)
    .astype('int')
)

In [None]:
# Size of the (ground truth x prediction) table before reordering.
cross_tab_normalized.shape

In [None]:
# Display order of STARmap main-level cell types for the heatmap rows/columns.
sort_list = [
    'Telencephalon projecting excitatory neurons',
    'Di- and mesencephalon excitatory neurons',
    'Telencephalon projecting inhibitory neurons',
    'Telencephalon inhibitory interneurons',
    'Di- and mesencephalon inhibitory neurons',
    'Peptidergic neurons',
    'Glutamatergic neuroblasts',
    'Dentate gyrus granule neurons',
    'Non-glutamatergic neuroblasts',
    'Olfactory inhibitory neurons',
    'Cerebellum neurons',
    'Cholinergic and monoaminergic neurons',
    'Hindbrain neurons/Spinal cord neurons',
    'Subcommissural organ hypendymal cells',
    'Ependymal cells',
    'Choroid plexus epithelial cells',
    'Astrocytes',
    'Vascular smooth muscle cells',
    'Pericytes',
    'Vascular endothelial cells',
    'Vascular and leptomeningeal cells',
    'Perivascular macrophages',
    'Microglia',
    'Oligodendrocyte precursor cells',
    'Oligodendrocytes',
    'Olfactory ensheathing cells',
]

In [None]:
# Put rows (ground truth) and columns (predictions) in the curated order.
# reindex, unlike [] / .loc[], does not raise KeyError when a cell type is
# absent from the 10% test split or never predicted; it becomes an all-zero
# row/column instead, and the affected names are reported.
absent = [ct for ct in sort_list
          if ct not in cross_tab_normalized.index or ct not in cross_tab_normalized.columns]
if absent:
    print(f"Filled with 0 (missing as a row and/or column): {absent}")
cross_tab_normalized = cross_tab_normalized.reindex(index=sort_list, columns=sort_list, fill_value=0)


In [None]:
# Size of the table after reordering to the curated cell-type list.
cross_tab_normalized.shape

In [None]:
import seaborn as sns
cmap = sns.cubehelix_palette(start=2, rot=0, dark=0, light=1.05, reverse=False, as_cmap=True)

# Heatmap of the reordered, normalised STARmap test-split table.
fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(cross_tab_normalized, cmap=cmap, ax=ax)
ax.set_title("Normalized Correspondence of Two Categories")
# fig.savefig('figures_refine/main_ct_starmap.png', dpi=300, transparent=True)
plt.show()


### Transfer the STARmap-trained labels to all cells

In [None]:

# Apply the STARmap-trained classifier to every cell in the embedding.
# `le` is reassigned in the Allen section; make sure it still belongs to `model`.
assert len(le.classes_) == model.fc2.out_features, \
    "`le` does not match `model`; re-run the STARmap label-encoding cell first"
sample2_embeddings = ad_fusemap_emb.X
dataset2 = TensorDataset(torch.Tensor(sample2_embeddings))
# Unshuffled, so predictions stay aligned with ad_fusemap_emb.obs rows.
dataloader2 = DataLoader(dataset2, batch_size=256, shuffle=False)
sample2_predictions, sample2_uncertainty = NNTransferPredictWithUncertainty(model, dataloader2, device)
# (N, 1) column -> 1-D before decoding, avoiding sklearn's DataConversionWarning.
sample2_predictions = le.inverse_transform(sample2_predictions.ravel())

ad_fusemap_emb.obs['transfer_gt_cell_type_main_STARmap'] = sample2_predictions


### Transfer A2N main-level cell-type labels (Allen annotation)

In [None]:
# Allen class labels as strings, so missing values become the literal 'nan'.
ad_fusemap_emb.obs['gt_celltype_class_allen'] = ad_fusemap_emb.obs['gt_celltype_class_allen'].astype('str')
# Boolean mask rather than re-selecting by obs name: name-based selection
# misbehaves if obs_names are not unique.
ad_embed_train = ad_fusemap_emb[(ad_fusemap_emb.obs['gt_celltype_class_allen'] != 'nan').to_numpy()]


In [None]:

# Encode Allen class labels as integer codes 0..K-1 and split 80% train / 20% held out.
# NOTE: this replaces the STARmap `le`; later STARmap cells need it re-created.
sample1_embeddings = ad_embed_train.X
sample1_labels = list(ad_embed_train.obs['gt_celltype_class_allen'])

le = preprocessing.LabelEncoder()
# fit_transform already yields integer codes; the former
# .astype('str').astype('int') round trip was a no-op.
sample1_labels = le.fit_transform(sample1_labels)

dataset1 = TensorDataset(torch.Tensor(sample1_embeddings),
                         torch.as_tensor(sample1_labels, dtype=torch.long))
train_size = int(0.8 * len(dataset1))  # Use 80% of the data for training
val_size = len(dataset1) - train_size
# Seeded generator so the split is reproducible after a kernel restart.
train_dataset, val_dataset = random_split(dataset1, [train_size, val_size],
                                          generator=torch.Generator().manual_seed(0))


In [None]:

# Split the held-out 20% in half: 10% validation (early stopping), 10% test.
# NOTE(review): `val_dataset` is overwritten from itself, so re-running this
# cell alone halves it again — re-run the previous cell first.
val_size = int(0.5 * len(val_dataset))
test_size = len(val_dataset) - val_size
# Seeded for reproducibility across kernel restarts.
val_dataset, test_dataset = random_split(val_dataset, [val_size, test_size],
                                         generator=torch.Generator().manual_seed(0))


In [None]:


import torch.optim as optim

# Seed torch's global RNG so weight initialisation and the shuffled batch
# order repeat between runs (some CUDA kernels may still be nondeterministic).
torch.manual_seed(0)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 'balanced' class weights counter label imbalance; computed over all labelled
# cells (train + val + test), as before.
class_weight = torch.Tensor(sklearn.utils.class_weight.compute_class_weight(class_weight='balanced',
                                                                            classes=np.unique(sample1_labels),
                                                                            y=sample1_labels))
model_allen = NNTransfer(input_dim=sample1_embeddings.shape[1],
                         output_dim=len(np.unique(sample1_labels)))
model_allen.to(device)  # Move the model to GPU if available
criterion = nn.CrossEntropyLoss(weight=class_weight.to(device))
optimizer = optim.Adam(model_allen.parameters(), lr=0.001)

NNTransferTrain(model_allen, criterion, optimizer, train_loader, val_loader, device)



In [None]:
# Predict the held-out test split and decode predictions and ground truth back
# to Allen class names.
# Check `le` is the Allen encoder that matches `model_allen`.
assert len(le.classes_) == model_allen.fc2.out_features, \
    "`le` does not match `model_allen`; re-run the Allen label-encoding cell first"
test_predictions, sample2_uncertainty = NNTransferPredictWithUncertainty(model_allen, test_loader, device)
# (N, 1) column -> 1-D before decoding, avoiding sklearn's DataConversionWarning.
test_predictions = le.inverse_transform(test_predictions.ravel())
# test_loader is unshuffled, so this follows the same order as test_predictions.
all_labels = [label.item() for _, label in test_dataset]

gt_test_predictions = le.inverse_transform(all_labels)
GT_allen_s = gt_test_predictions
PRED_allen_s = test_predictions

In [None]:
# Contingency table: Allen ground truth (rows) vs. FuseMap prediction (columns).
cross_tab = pd.crosstab(pd.Series(GT_allen_s, name='Original'),
                        pd.Series(PRED_allen_s, name='FuseMap'))

# Column-normalise, then row-normalise, then round to whole percentages.
column_share = cross_tab.div(cross_tab.sum(axis=0), axis='columns')
row_share = column_share.div(column_share.sum(axis=1), axis='index')
cross_tab_normalized = np.around(row_share * 100).astype('int')

In [None]:
# Display order of Allen classes for the heatmap rows/columns.
new_list = [
    '01 IT-ET Glut',
    '02 NP-CT-L6b Glut',
    '16 HY MM Glut',
    '17 MH-LH Glut',
    '18 TH Glut',
    '19 MB Glut',
    '09 CNU-LGE GABA',
    '06 CTX-CGE GABA',
    '07 CTX-MGE GABA',
    '08 CNU-MGE GABA',
    '20 MB GABA',
    '26 P GABA',
    '12 HY GABA',
    '10 LSX GABA',
    '11 CNU-HYa GABA',
    '14 HY Glut',
    '15 HY Gnrh1 Glut',
    '03 OB-CR Glut',
    '13 CNU-HYa Glut',
    '04 DG-IMN Glut',
    '05 OB-IMN GABA',
    '28 CB GABA',
    '29 CB Glut',
    '21 MB Dopa',
    '22 MB-HB Sero',
    '23 P Glut',
    '24 MY Glut',
    '27 MY GABA',
    '30 Astro-Epen',
    '33 Vascular',
    '34 Immune',
    '31 OPC-Oligo',
    '32 OEC',
]

In [None]:
# Put rows (ground truth) and columns (predictions) in the curated Allen order.
# reindex, unlike [] / .loc[], does not raise KeyError when a class is absent
# from the 10% test split or never predicted; it becomes an all-zero
# row/column instead, and the affected names are reported.
absent = [ct for ct in new_list
          if ct not in cross_tab_normalized.index or ct not in cross_tab_normalized.columns]
if absent:
    print(f"Filled with 0 (missing as a row and/or column): {absent}")
cross_tab_normalized = cross_tab_normalized.reindex(index=new_list, columns=new_list, fill_value=0)


In [None]:
import seaborn as sns
cmap = sns.cubehelix_palette(start=2, rot=0, dark=0, light=1.05, reverse=False, as_cmap=True)

# Heatmap of the reordered, normalised Allen test-split table.
fig, ax = plt.subplots(figsize=(15, 12))
sns.heatmap(cross_tab_normalized, cmap=cmap, ax=ax)
ax.set_title("Normalized Correspondence of Two Categories")
# fig.savefig('figures_refine/main_ct_allen.png', dpi=300, transparent=True)
plt.show()


### Transfer the Allen-trained labels to all cells

In [None]:

# Apply the Allen-trained classifier to every cell in the embedding.
# Check `le` is the Allen encoder that matches `model_allen`.
assert len(le.classes_) == model_allen.fc2.out_features, \
    "`le` does not match `model_allen`; re-run the Allen label-encoding cell first"
sample2_embeddings = ad_fusemap_emb.X
dataset2 = TensorDataset(torch.Tensor(sample2_embeddings))
# Unshuffled, so predictions stay aligned with ad_fusemap_emb.obs rows.
dataloader2 = DataLoader(dataset2, batch_size=256, shuffle=False)
sample2_predictions, sample2_uncertainty = NNTransferPredictWithUncertainty(model_allen, dataloader2, device)
# (N, 1) column -> 1-D before decoding, avoiding sklearn's DataConversionWarning.
sample2_predictions = le.inverse_transform(sample2_predictions.ravel())

ad_fusemap_emb.obs['transfer_gt_cell_type_main_Allen'] = sample2_predictions


### Correspondence between A1N and A2N

In [None]:
from sklearn.neighbors import KNeighborsClassifier

# Smooth both transferred label columns with a 10-nearest-neighbour majority
# vote in UMAP space. Every cell is both a training point and a query, so its
# own label is one of its 10 votes.
# NOTE(review): the columns are overwritten in place, so re-running this cell
# smooths already-smoothed labels again (re-run the transfer cells first).
location = np.array(ad_fusemap_emb.obsm['X_umap'])
querylocation = np.array(ad_fusemap_emb.obsm['X_umap'])

for column in ('transfer_gt_cell_type_main_STARmap', 'transfer_gt_cell_type_main_Allen'):
    label = ad_fusemap_emb.obs[column]
    knn = KNeighborsClassifier(n_neighbors=10)
    knn.fit(location, label)
    predicted_labels = knn.predict(querylocation)
    ad_fusemap_emb.obs[column] = predicted_labels


In [None]:
# Per-cell agreement between the two smoothed transfers: STARmap labels
# (rows) vs. Allen labels (columns), over all cells.
GT_starmap_c = np.array(ad_fusemap_emb.obs['transfer_gt_cell_type_main_STARmap'])
PRED_allen_c = np.array(ad_fusemap_emb.obs['transfer_gt_cell_type_main_Allen'])

cross_tab = pd.crosstab(pd.Series(GT_starmap_c, name='Original'),
                        pd.Series(PRED_allen_c, name='FuseMap'))

# Normalise each Allen column to sum 1, then each STARmap row, in whole percent.
per_column = cross_tab.div(cross_tab.sum(axis=0), axis=1)
per_row = per_column.div(per_column.sum(axis=1), axis=0)
cross_tab_normalized = np.around(per_row * 100).astype('int')

In [None]:
# Row order (STARmap main-level types) and column order (Allen classes).
old_list=['Telencephalon projecting excitatory neurons', 'Di- and mesencephalon excitatory neurons',
          'Telencephalon projecting inhibitory neurons', 'Telencephalon inhibitory interneurons',
          'Di- and mesencephalon inhibitory neurons', 'Peptidergic neurons', 'Glutamatergic neuroblasts',
          'Dentate gyrus granule neurons','Non-glutamatergic neuroblasts',
          'Olfactory inhibitory neurons', 'Cerebellum neurons', 'Cholinergic and monoaminergic neurons',
          'Hindbrain neurons/Spinal cord neurons', 'Subcommissural organ hypendymal cells','Ependymal cells', 
          'Choroid plexus epithelial cells','Astrocytes', 
          'Vascular smooth muscle cells' , 'Pericytes',   'Vascular endothelial cells', 'Vascular and leptomeningeal cells',
          'Perivascular macrophages', 'Microglia', 
        'Oligodendrocyte precursor cells', 'Oligodendrocytes',
        'Olfactory ensheathing cells',]

# NOTE(review): unlike the Allen test-split heatmap's list, this one omits
# '15 HY Gnrh1 Glut', so that class is dropped from this plot — confirm intended.
new_list= ['01 IT-ET Glut', '02 NP-CT-L6b Glut', 
            '16 HY MM Glut', '17 MH-LH Glut', '18 TH Glut', '19 MB Glut',
           '09 CNU-LGE GABA',  '06 CTX-CGE GABA', '07 CTX-MGE GABA',
           '08 CNU-MGE GABA', '20 MB GABA',  '26 P GABA','12 HY GABA',
           '10 LSX GABA', '11 CNU-HYa GABA',
            '14 HY Glut', '03 OB-CR Glut',  '13 CNU-HYa Glut',
           '04 DG-IMN Glut','05 OB-IMN GABA', 
            '28 CB GABA', '29 CB Glut','21 MB Dopa', '22 MB-HB Sero',
            '23 P Glut', '24 MY Glut','27 MY GABA',  '30 Astro-Epen',
            '33 Vascular', '34 Immune','31 OPC-Oligo', '32 OEC',]

# reindex, unlike [] / .loc[], does not raise KeyError when a type is absent
# after KNN smoothing; it becomes an all-zero row/column and is reported.
absent_rows = [ct for ct in old_list if ct not in cross_tab_normalized.index]
absent_cols = [ct for ct in new_list if ct not in cross_tab_normalized.columns]
if absent_rows or absent_cols:
    print(f"Filled with 0 - rows: {absent_rows}, columns: {absent_cols}")
cross_tab_normalized = cross_tab_normalized.reindex(index=old_list, columns=new_list, fill_value=0)

In [None]:
import seaborn as sns
cmap = sns.cubehelix_palette(start=2, rot=0, dark=0, light=1.05, reverse=False, as_cmap=True)

# Heatmap of STARmap (rows) vs. Allen (columns) label correspondence.
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(cross_tab_normalized, cmap=cmap, ax=ax)
ax.set_title("Normalized Correspondence of Two Categories")
# fig.savefig('figures_refine/main_ct_corr.png', dpi=300, transparent=True)
plt.show()
