In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import anndata as ad
import scanpy.external as sce
from sklearn import preprocessing
import pickle5 as pickle
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors
from sklearn import preprocessing
import sklearn
from sklearn.neighbors import KNeighborsClassifier
import seaborn as sns

from utils import *
# Tiny numerical floor; not referenced in the cells shown in this notebook.
eps = 1e-100
# Project helper from `utils`; presumably seeds the random number generators (TODO confirm in utils).
seed_all(0)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import random_split


class NNTransfer(nn.Module):
    """Two-layer MLP classifier used to transfer labels from reference embeddings.

    ``forward`` returns raw (unnormalised) logits, which is what
    ``nn.CrossEntropyLoss`` expects: that loss applies log-softmax itself, so
    feeding it softmax outputs would apply softmax twice and weaken the
    training signal. Use ``predict_proba`` when class probabilities are needed.

    Args:
        input_dim: length of each input embedding vector.
        output_dim: number of classes.
    """

    def __init__(self, input_dim=128, output_dim=10):
        super(NNTransfer, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, output_dim)
        # Softmax over the class dimension; used only by predict_proba().
        self.activate = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to (batch, output_dim) logits."""
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

    def predict_proba(self, x):
        """Return class probabilities of shape (batch, output_dim); rows sum to 1."""
        return self.activate(self.forward(x))


def NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device, save_dir, label_key, focus_name, epochs=200, patience=10):
    """Train ``model`` with early stopping on validation accuracy.

    When training ends, the weights from the epoch with the best validation
    accuracy are loaded back into ``model``. Previously the best accuracy was
    tracked but the last-epoch weights were kept.

    Args:
        model: network returning logits (e.g. ``NNTransfer``).
        criterion: loss applied to ``(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: DataLoader yielding ``(inputs, labels)`` for training.
        val_loader: DataLoader yielding ``(inputs, labels)`` for early stopping.
        device: torch device that inputs and labels are moved to.
        save_dir, label_key, focus_name: unused; kept so existing calls keep working.
        epochs: maximum number of epochs.
        patience: training stops once more than ``patience`` consecutive epochs
            fail to improve the best validation accuracy.

    Returns:
        The best validation accuracy (percent) reached.
    """
    best_accuracy = 0
    best_state = None
    patience_count = 0
    for epoch in range(epochs):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
        _, eval_accuracy = NNTransferEvaluate(model, val_loader, criterion, device)
        if eval_accuracy > best_accuracy:
            best_accuracy = eval_accuracy
            # Snapshot the weights; clone so later optimizer steps do not overwrite them.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_count = 0
        else:
            patience_count += 1
        if patience_count > patience:
            break
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_accuracy


def NNTransferEvaluate(model, dataloader, criterion, device):
    """Compute mean batch loss and accuracy of ``model`` on ``dataloader``.

    Args:
        model: network returning logits.
        dataloader: DataLoader yielding ``(inputs, labels)``.
        criterion: loss applied to ``(logits, labels)``.
        device: torch device that inputs and labels are moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``; both are NaN for an empty loader.
    """
    model.eval()
    n_batches = len(dataloader)
    if n_batches == 0:
        return float('nan'), float('nan')
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            total_loss += criterion(logits, labels).item()
            correct += logits.argmax(dim=1).eq(labels).sum().item()
            total += labels.size(0)
    return total_loss / n_batches, 100. * correct / total


def NNTransferPredictWithUncertainty(model, dataloader, device):
    """Predict class indices and an uncertainty score for every sample.

    Uncertainty is ``1 - max softmax probability``.

    Args:
        model: network returning logits (e.g. ``NNTransfer``).
        dataloader: DataLoader whose batches are ``(inputs,)`` tuples (as from a
            single-tensor ``TensorDataset``) or bare input tensors.
        device: torch device that inputs are moved to.

    Returns:
        ``(predictions, uncertainties)`` as ``(n, 1)`` arrays (int64 / float32).
    """
    model.eval()
    pred_chunks = []
    unc_chunks = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            probs = torch.softmax(model(inputs.to(device)), dim=1)
            confidence, predicted = probs.max(dim=1)
            pred_chunks.append(predicted.cpu().numpy())
            unc_chunks.append((1 - confidence).cpu().numpy())
    if not pred_chunks:
        return np.empty((0, 1), dtype=np.int64), np.empty((0, 1), dtype=np.float32)
    return (np.concatenate(pred_chunks).reshape(-1, 1),
            np.concatenate(unc_chunks).reshape(-1, 1))


In [None]:
# Combined reference + query object: embeddings in .X, UMAP coordinates in .obsm['X_umap'].
celltype_all_o = sc.read_h5ad(filename='source_data/transfer_celltype_ref.h5ad')

In [None]:
# Split the combined object into reference and query cells.
celltype_all_o_ref = celltype_all_o[celltype_all_o.obs['name'] == 'ref']
celltype_all_o_query = celltype_all_o[celltype_all_o.obs['name'] != 'ref']

# Store the column as plain strings so transferred labels can be written into it.
celltype_all_o.obs['transfer_gt_cell_type_main_ref'] = (
    celltype_all_o.obs['transfer_gt_cell_type_main_ref'].astype('str')
)

# Main-level transfer: 5-nearest-neighbour vote in UMAP space, fitted on the
# reference cells' annotated main types and applied to the query cells.
label = celltype_all_o_ref.obs['gt_cell_type_main']
location = np.array(celltype_all_o_ref.obsm['X_umap'])
knn = KNeighborsClassifier(n_neighbors=5).fit(location, label)

querylocation = np.array(celltype_all_o_query.obsm['X_umap'])
predicted_labels = knn.predict(querylocation)

# Only query rows are overwritten; reference rows keep their existing value.
celltype_all_o.obs.loc[
    celltype_all_o.obs['name'] != 'ref', 'transfer_gt_cell_type_main_ref'
] = predicted_labels


### Transfer main-level cell types

In [None]:
# Overview UMAP of all cells, coloured by the `name` column ('ref' marks reference cells).
fig, ax = plt.subplots(figsize=(4, 4))
ax = sc.pl.umap(celltype_all_o, color='name', size=5, ax=ax, show=False)
ax.axis('off')
ax.set_title('')


In [None]:
color_dic_main = pd.read_csv('source_data/color/cell_type/allen_main.csv', index_col=0)
# Main cell-type label -> plotting colour (a repeated key keeps its last colour).
color_dic = {key: colour for key, colour in zip(color_dic_main['key'], color_dic_main['color'])}


In [None]:
# Main-level cell types on the UMAP, one figure per group with a shared palette:
# first the reference cells with their annotated labels, then the query cells
# with the KNN-transferred labels.
# legend_loc=None turns the legend off explicitly; the former legend_loc=[] is not
# a documented value and can break on newer scanpy releases.
fig,ax = plt.subplots(figsize=(4,4))
ax = sc.pl.umap(celltype_all_o[celltype_all_o.obs['name']=='ref',:],
                color='gt_cell_type_main',size=5,
                palette=color_dic,legend_loc=None,
                ax=ax,show=False)
plt.axis('off')
plt.title('Spatial embedding, colored by reference main level cell types')

fig,ax = plt.subplots(figsize=(4,4))
ax = sc.pl.umap(celltype_all_o[celltype_all_o.obs['name']!='ref',:],
                color='transfer_gt_cell_type_main_ref',size=5,
                palette=color_dic,legend_loc=None,
                ax=ax,show=False)
plt.axis('off')
plt.title('Spatial embedding, colored by transferred main level cell types')

### Transfer sub-level cell types

In [None]:
import torch.optim as optim

# Sub-level transfer: for each main type assigned to the query cells, train an
# NNTransfer classifier on the reference cells of that main type
# (embedding .X -> reference sub type). Use it to label every cell whose
# transferred main type equals it. Groups with fewer than 3 labelled reference
# cells are skipped and keep the placeholder '-1'.
result_celltype=celltype_all_o
gt_ref_key='gt_cell_type_main'
gt_ref_sub_key='gt_cell_type_sub'
transfer_key='transfer_gt_cell_type_main_ref'
result_celltype.obs['transfer_gt_cell_type_sub_ref']='-1'
result_celltype.obs['transfer_gt_cell_type_sub_ref']=result_celltype.obs['transfer_gt_cell_type_sub_ref'].astype('str')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for focus_main in result_celltype[result_celltype.obs['name']!='ref'].obs[transfer_key].unique():
    print(focus_main)

    # Training cells: annotated with this main type and a non-missing sub type.
    ad_embed_OB = result_celltype[result_celltype.obs[gt_ref_key]==focus_main]
    ad_embed_train = ad_embed_OB[ad_embed_OB.obs[gt_ref_sub_key].notna()]
    # Size check after dropping missing labels, so an all-NaN group cannot reach
    # LabelEncoder / compute_class_weight with no samples.
    if ad_embed_train.shape[0]<3:
        print(f'  skipped {focus_main}: fewer than 3 labelled reference cells')
        continue
    sample1_embeddings = ad_embed_train.X

    le = preprocessing.LabelEncoder()
    sample1_labels = le.fit_transform(list(ad_embed_train.obs[gt_ref_sub_key]))

    dataset1 = TensorDataset(torch.Tensor(sample1_embeddings),
                             torch.as_tensor(sample1_labels, dtype=torch.long))
    # 95% of the labelled cells for training, the rest held out for early stopping.
    train_size = int(0.95 * len(dataset1))
    val_size = len(dataset1) - train_size
    train_dataset, val_dataset = random_split(dataset1, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    # 'balanced' weights offset sub types with very different cell counts.
    class_weight = torch.Tensor(sklearn.utils.class_weight.compute_class_weight(class_weight='balanced',
                                                                                classes=np.unique(sample1_labels),
                                                                                y=sample1_labels))
    model = NNTransfer(input_dim=sample1_embeddings.shape[1],
                       output_dim=len(le.classes_))
    model.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weight.to(device))
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device, 0, 0, 0)

    # Predict for every cell whose transferred main type is focus_main.
    query_mask = result_celltype.obs[transfer_key]==focus_main
    sample2_embeddings = result_celltype[query_mask].X
    dataset2 = TensorDataset(torch.Tensor(sample2_embeddings))
    dataloader2 = DataLoader(dataset2, batch_size=256, shuffle=False)
    sample2_predictions,sample2_uncertainty = NNTransferPredictWithUncertainty(model, dataloader2, device)
    # Predictions come back as an (n, 1) column; flatten before decoding to labels.
    sample2_predictions = le.inverse_transform(np.ravel(sample2_predictions))

    result_celltype.obs.loc[query_mask, 'transfer_gt_cell_type_sub_ref'] = sample2_predictions


In [None]:
import torch.optim as optim

# Supertype transfer, one level below the sub-level step. For each transferred
# sub type, train an NNTransfer classifier on the reference cells of that sub
# type (embedding .X -> reference supertype). Use it to label every cell whose
# transferred sub type equals it. Groups that cannot be trained (fewer than 3
# labelled reference cells) keep their transferred sub-level label.
result_celltype=celltype_all_o
gt_ref_key='gt_cell_type_sub'
gt_ref_sub_key='gt_cell_type_supertype'
transfer_key='transfer_gt_cell_type_sub_ref'
result_celltype.obs['transfer_gt_cell_type_supertype_ref']=result_celltype.obs['transfer_gt_cell_type_sub_ref']

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for focus_main in result_celltype[result_celltype.obs['name']!='ref'].obs[transfer_key].unique():
    print(focus_main)

    # Training cells: annotated with this sub type and a non-missing supertype.
    ad_embed_OB = result_celltype[result_celltype.obs[gt_ref_key]==focus_main]
    ad_embed_train = ad_embed_OB[ad_embed_OB.obs[gt_ref_sub_key].notna()]
    # Size check after dropping missing labels, so an all-NaN group cannot reach
    # LabelEncoder / compute_class_weight with no samples.
    if ad_embed_train.shape[0]<3:
        print(f'  skipped {focus_main}: fewer than 3 labelled reference cells')
        continue
    sample1_embeddings = ad_embed_train.X

    le = preprocessing.LabelEncoder()
    sample1_labels = le.fit_transform(list(ad_embed_train.obs[gt_ref_sub_key]))

    dataset1 = TensorDataset(torch.Tensor(sample1_embeddings),
                             torch.as_tensor(sample1_labels, dtype=torch.long))
    # 95% of the labelled cells for training, the rest held out for early stopping.
    train_size = int(0.95 * len(dataset1))
    val_size = len(dataset1) - train_size
    train_dataset, val_dataset = random_split(dataset1, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    # 'balanced' weights offset supertypes with very different cell counts.
    class_weight = torch.Tensor(sklearn.utils.class_weight.compute_class_weight(class_weight='balanced',
                                                                                classes=np.unique(sample1_labels),
                                                                                y=sample1_labels))
    model = NNTransfer(input_dim=sample1_embeddings.shape[1],
                       output_dim=len(le.classes_))
    model.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weight.to(device))
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device, 0, 0, 0)

    # Predict for every cell whose transferred sub type is focus_main.
    query_mask = result_celltype.obs[transfer_key]==focus_main
    sample2_embeddings = result_celltype[query_mask].X
    dataset2 = TensorDataset(torch.Tensor(sample2_embeddings))
    dataloader2 = DataLoader(dataset2, batch_size=256, shuffle=False)
    sample2_predictions,sample2_uncertainty = NNTransferPredictWithUncertainty(model, dataloader2, device)
    # Predictions come back as an (n, 1) column; flatten before decoding to labels.
    sample2_predictions = le.inverse_transform(np.ravel(sample2_predictions))

    result_celltype.obs.loc[query_mask, 'transfer_gt_cell_type_supertype_ref'] = sample2_predictions


In [None]:
ad_embed = result_celltype
starmap_adata_obs = pd.read_csv('source_data/color/cell_type/allen_all.csv')
# Label -> plotting colour for all annotation levels; replaces the main-level map.
color_dic = {key: colour for key, colour in zip(starmap_adata_obs['key'], starmap_adata_obs['value'])}

In [None]:

# Supertype labels on the UMAP: first the reference annotations, then the
# labels transferred to the query cells, with a shared palette.
# legend_loc=None turns the legend off explicitly; the former legend_loc=[] is not
# a documented value and can break on newer scanpy releases.
fig,ax = plt.subplots(figsize=(4,4))
ax = sc.pl.umap(celltype_all_o[celltype_all_o.obs['name']=='ref',:],
                color='gt_cell_type_supertype',size=5,
                palette=color_dic,legend_loc=None,
                ax=ax,show=False)
plt.axis('off')
plt.title('Spatial embedding, colored by reference sublevel cell types')


fig,ax = plt.subplots(figsize=(4,4))
ax = sc.pl.umap(result_celltype[result_celltype.obs['name']!='ref',:],
                color='transfer_gt_cell_type_supertype_ref',size=5,
                palette=color_dic,legend_loc=None,
                ax=ax,show=False)
plt.axis('off')
plt.title('Spatial embedding, colored by targeted sublevel cell types')


### Spatial maps

In [None]:
focus = result_celltype[result_celltype.obs['batch'] == 'sample2', :]

# x is drawn vertically and y horizontally, so the figure height is scaled by
# (x range) / (y range) to keep the section's proportions.
coeefi = (
    (max(focus.obs['x']) - min(focus.obs['x']))
    / (max(focus.obs['y']) - min(focus.obs['y']))
)

plt.figure(figsize=(5, 5 * coeefi))
# Unlike the query maps, this reference map does not invert the y axis.
plt.scatter(
    focus.obs['y'],
    focus.obs['x'],
    s=2,
    c=list(map(color_dic.__getitem__, focus.obs['gt_cell_type_supertype'])),
)
plt.axis('off')
plt.title('Spatial map, colored by reference sublevel cell types')


In [None]:
# Query sections in tissue coordinates, coloured by the transferred supertype.
for batch_key in ('sample0', 'sample1'):
    focus = result_celltype[result_celltype.obs['batch'] == batch_key, :]
    coeefi = (
        (max(focus.obs['x']) - min(focus.obs['x']))
        / (max(focus.obs['y']) - min(focus.obs['y']))
    )

    plt.figure(figsize=(5, 5 * coeefi))
    plt.scatter(
        focus.obs['y'],
        focus.obs['x'],
        s=2,
        c=list(map(color_dic.__getitem__, focus.obs['transfer_gt_cell_type_supertype_ref'])),
    )
    plt.gca().invert_yaxis()
    plt.axis('off')
    plt.title(f'Spatial map of {batch_key}, colored by targeted sublevel cell types')


### Glutamatergic neurons

In [None]:
celltype_all_query = celltype_all_o[celltype_all_o.obs['name'] != 'ref', :]
# keep_list1: query-side main labels; keep_list2: matching reference main classes.
keep_list1 = ['L2/3 IT','L4/5 IT','L6 CT','L5 IT','L6 IT','L5 ET','L5/6 NP','L6 IT Car3']
keep_list2 = ['01 IT-ET Glut','02 NP-CT-L6b Glut']

# Query cells whose original main label is in keep_list1 AND whose transferred
# main label is in keep_list2.
celltype_all_query_focus = celltype_all_query[
    celltype_all_query.obs['query_gt_celltype_main'].isin(keep_list1)
    & celltype_all_query.obs['transfer_gt_cell_type_main_ref'].isin(keep_list2)
]

# All cells whose annotated reference main type is in keep_list2.
celltype_all_ref_focus = celltype_all_o[celltype_all_o.obs['gt_cell_type_main'].isin(keep_list2)]


In [None]:
# Reference section 'sample2': all cells in light grey, with the selected
# reference cells overlaid and coloured by their annotated supertype.
focus=celltype_all_ref_focus[celltype_all_ref_focus.obs['batch']=='sample2',:]
focus_all=celltype_all_o[celltype_all_o.obs['batch']=='sample2']

# Figure height scaled by (x range)/(y range) of the whole section (x is drawn vertically).
coeefi=(max(focus_all.obs['x'])-min(focus_all.obs['x']))/(max(focus_all.obs['y'])-min(focus_all.obs['y']))

plt.figure(figsize=(5,5*coeefi))
plt.scatter(focus_all.obs['y'],
            focus_all.obs['x'],
            s=2,
            c='lightgrey')

plt.scatter(focus.obs['y'],
            focus.obs['x'],
            s=2,
            c=[color_dic[i] for i in focus.obs['gt_cell_type_supertype']])
plt.axis('off')
plt.title('Spatial map of glutamatergic neurons, colored by reference sublevel cell types')


In [None]:
# Query sections: all cells in light grey, with the selected query cells
# overlaid and coloured by their transferred supertype.
for batch_key in ['sample0','sample1']:
    focus=celltype_all_query_focus[celltype_all_query_focus.obs['batch']==batch_key,:]
    focus_all=celltype_all_o[celltype_all_o.obs['batch']==batch_key]

    coeefi=(max(focus_all.obs['x'])-min(focus_all.obs['x']))/(max(focus_all.obs['y'])-min(focus_all.obs['y']))

    plt.figure(figsize=(5,5*coeefi))
    plt.scatter(focus_all.obs['y'],
                focus_all.obs['x'],
                s=2,
                c='lightgrey')

    plt.scatter(focus.obs['y'],
                focus.obs['x'],
                s=2,
                c=[color_dic[i] for i in focus.obs['transfer_gt_cell_type_supertype_ref']])
    # Same orientation as the full query maps of these samples, which invert y.
    plt.gca().invert_yaxis()
    plt.axis('off')
    plt.title(f'Spatial map of glutamatergic neurons of {batch_key}, colored by targeted sublevel cell types')


### Confusion matrix

In [None]:
# Cells for the confusion matrix: the same subset built earlier as
# `celltype_all_query_focus` (query cells whose original main label is in
# keep_list1 and whose transferred main label is in keep_list2). Reusing it keeps
# the matrix on exactly the cells shown in the spatial maps, instead of
# repeating the label lists here where the two copies could drift apart.
focus = celltype_all_query_focus

In [None]:
# Original query sub-level labels vs. FuseMap-transferred supertypes.
GT = np.array(focus.obs['query_gt_celltype_sub'])
PRED = np.array(focus.obs['transfer_gt_cell_type_supertype_ref'])

cross_tab = pd.crosstab(
    pd.Series(GT, name='Original'),
    pd.Series(PRED, name='FuseMap'),
)

# Row-normalise to fractions, then rescale every column so it sums to 1.
cross_tab_normalized = (
    cross_tab
    .pipe(lambda t: t.div(t.sum(axis=1), axis=0))
    .pipe(lambda t: t.div(t.sum(axis=0), axis=1))
)


In [None]:
# Reorder rows (original labels) and columns (FuseMap labels) of the normalised
# matrix, presumably so matching categories sit near the diagonal in the heatmap.
#
# Step 1: for every column, the row label with the largest value.
corresponde_sub=[]
corresponde_main=[]
for i in cross_tab_normalized.columns:
    corresponde_sub.append(i)
    # Series.argmax returns a position; map it back to the row label.
    corresponde_main.append(cross_tab_normalized.index[cross_tab_normalized[i].argmax()])

    
corresponde_main = np.array(corresponde_main)
corresponde_sub = np.array(corresponde_sub)
# Rows that are the best match of at least one column, in np.unique (sorted) order.
old_list = list(np.unique(corresponde_main))

# Step 2: column order = columns grouped by their best-matching row (rows in
# sorted order, columns within a group in their original order).
# corres_ship maps column -> best-matching row.
corres_ship={}
new_list=[]
for i in np.unique(corresponde_main):
    for t in corresponde_sub[corresponde_main==i]:
        new_list.append(t)
        corres_ship[t]=i
        
# Step 3: for every row, the column label with the largest value.
corresponde_sub_2=[]
corresponde_main_2=[]
for i in cross_tab_normalized.index:
    corresponde_sub_2.append(i)
    corresponde_main_2.append(cross_tab_normalized.columns[cross_tab_normalized.loc[i].argmax()])

    
corresponde_main_2 = np.array(corresponde_main_2)
corresponde_sub_2 = np.array(corresponde_sub_2)
# NOTE(review): old_list_2 is not used below.
old_list_2 = list(np.unique(corresponde_sub_2))


# corres_ship_2 maps row -> best-matching column.
corres_ship_2={}
for i in np.unique(corresponde_main_2):
    for t in corresponde_sub_2[corresponde_main_2==i]:
        corres_ship_2[t]=i
        
# Step 4: rows that were nobody's best match are inserted directly after the
# row that their own best column is anchored to (via corres_ship). Each insert
# goes immediately after the anchor, so several rows sharing one anchor end up
# in reverse order relative to cross_tab_normalized.index.
for i in cross_tab_normalized.index:
    if i not in old_list:        
        index_of_b= old_list.index(corres_ship[corres_ship_2[i]])
        old_list.insert(index_of_b + 1, i)
        
# Apply the computed column order, then the row order.
cross_tab_normalized = cross_tab_normalized[new_list]

cross_tab_normalized = cross_tab_normalized.loc[old_list]


In [None]:
# Sequential cubehelix colormap for the heatmap; vmax=1 fixes the top of the scale.
cmap = sns.cubehelix_palette(start=2, rot=0, dark=0, light=1.05, reverse=False, as_cmap=True)

fig, ax = plt.subplots(figsize=(20, 10))
ax = sns.heatmap(cross_tab_normalized, cmap=cmap, vmax=1, ax=ax)
ax.set_title("Normalized Correspondence of Two Categories")
plt.show()
