In [None]:
import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle5 as pickle
import scanpy as sc
import scanpy.external as sce
import seaborn as sns
import sklearn
import torch
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

from utils import *


# Tiny positive constant. It is not referenced by any of the cells visible here.
# NOTE(review): presumably a guard against log(0) / division by zero in helpers or later cells -- confirm.
# NOTE(review): 1e-100 is below the float32 range and rounds to 0 if cast to float32 -- verify intended dtype.
eps=1e-100


In [None]:
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import random_split

import torch.nn as nn
class NNTransfer(nn.Module):
    """Small two-layer perceptron that turns an input row into class probabilities.

    Layout: ``Linear(input_dim, 256) -> ReLU -> Linear(256, output_dim) -> Softmax(dim=1)``.

    Args:
        input_dim: Number of features in each input row (default 128).
        output_dim: Number of output classes (default 10).

    NOTE(review): the output is already softmax-normalised. A loss that expects
    raw logits (e.g. ``nn.CrossEntropyLoss``) would normalise a second time, so
    pair this model with a loss that takes probabilities / log-probabilities.
    """

    def __init__(self, input_dim=128, output_dim=10):
        super().__init__()
        hidden_units = 256
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_units)
        self.fc2 = nn.Linear(in_features=hidden_units, out_features=output_dim)
        # Normalise over dim 1 (the class axis of a 2-D batch).
        self.activate = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the softmax over dim 1 of the two-layer network applied to ``x``."""
        hidden = self.fc1(x).relu()
        scores = self.fc2(hidden)
        return self.activate(scores)

    

def NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device,
                    save_pth=None, epochs=200):
    """Train ``model`` with early stopping on validation accuracy.

    Each epoch runs one optimisation pass over ``train_loader`` and then scores
    the model on ``val_loader`` with ``NNTransferEvaluate``. Whenever the
    validation accuracy beats the best value seen so far, progress is printed
    and, if ``save_pth`` is given, the model's ``state_dict`` is written there.
    Training stops once more than 10 consecutive epochs bring no improvement.

    Args:
        model: Module to train; called as ``model(inputs)``.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimiser already bound to ``model``'s parameters.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        device: Device the batches are moved to before the forward pass.
        save_pth: Optional path. When not ``None`` the best-so-far weights are
            checkpointed to it (previously this argument was ignored).
        epochs: Maximum number of epochs.

    Returns:
        None. The model is updated in place.
    """
    patience_limit = 10
    # Best validation accuracy so far; starts at 0, so only a strictly
    # positive accuracy counts as the first improvement.
    best_accuracy = 0
    epochs_without_improvement = 0
    for epoch in range(epochs):
        model.train()
        running_loss = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        _, eval_accuracy = NNTransferEvaluate(model, val_loader, criterion, device)
        if eval_accuracy > best_accuracy:
            best_accuracy = eval_accuracy
            epochs_without_improvement = 0
            if save_pth is not None:
                # Keep the weights of the best-scoring epoch, not the last one.
                torch.save(model.state_dict(), save_pth)
            print(f"Epoch {epoch}/{epochs} - Train Loss: {running_loss / len(train_loader)}, Accuracy: {eval_accuracy}")
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement > patience_limit:
            print(f"Epoch {epoch}/{epochs} - early stopping due to patience count")
            break
            
def NNTransferEvaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: Module to evaluate; switched to eval mode here.
        dataloader: Sized iterable of ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy)`` where ``mean_batch_loss`` is the sum of
        per-batch losses divided by the number of batches, and ``accuracy`` is
        the percentage of samples whose arg-max over dim 1 equals the label.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for features, targets in dataloader:
            features = features.to(device)
            targets = targets.to(device)
            scores = model(features)
            loss_sum += criterion(scores, targets).item()
            predicted_classes = scores.max(dim=1).indices
            n_seen += targets.size(0)
            n_correct += (predicted_classes == targets).sum().item()
    accuracy = 100. * n_correct / n_seen
    mean_batch_loss = loss_sum / len(dataloader)
    return mean_batch_loss, accuracy

def NNTransferPredictWithUncertainty(model, dataloader, device):
    """Predict a class and an uncertainty score for every sample in ``dataloader``.

    For each batch the model output is reduced over dim 1: the index of the
    largest value is the predicted class and ``1 - largest value`` is the
    uncertainty.

    Args:
        model: Module to run; switched to eval mode here.
        dataloader: Iterable of batches whose first element is the input tensor.
        device: Device the inputs are moved to.

    Returns:
        ``(predictions, uncertainties)`` as NumPy arrays built with
        ``np.vstack`` from per-sample scalars, i.e. one row per sample
        (shape ``(n_samples, 1)`` for 2-D model outputs).
    """
    model.eval()
    predicted_classes = []
    uncertainties = []

    with torch.no_grad():
        for batch in dataloader:
            features = batch[0].to(device)
            scores = model(features)
            # One reduction yields both the winning value and its index.
            top_value, top_index = torch.max(scores, dim=1)
            predicted_classes.extend(top_index.detach().cpu().numpy())
            uncertainties.extend((1 - top_value).detach().cpu().numpy())

    return np.vstack(predicted_classes), np.vstack(uncertainties)

In [None]:

# Read the colour table and build a key -> colour lookup from its
# 'key' and 'color' columns (later duplicates of a key overwrite earlier ones).
starmap_adata_obs = pd.read_csv('source_data/color/starmap_sub_old.csv', index_col=0)
color_dic = {
    key: color
    for key, color in zip(starmap_adata_obs['key'], starmap_adata_obs['color'])
}


In [None]:

# Load the AnnData object used for label transfer from a path relative to the
# notebook. Later cells read its `.X` matrix as the model input and the
# 'gt_tissue_region_main' column of `.obs` as the labels.
tissueregion_starmap=sc.read_h5ad('source_data/ad_embed.h5ad')    


### Transfer A1N main level labels 

In [None]:
# Keep only observations whose main tissue-region label is not the literal string 'NA'.
# NOTE(review): real missing values (NaN) would pass this filter -- confirm the column uses the 'NA' string.
labelled_obs = tissueregion_starmap.obs.loc[
    tissueregion_starmap.obs['gt_tissue_region_main'] != 'NA'
]
ad_embed_train = tissueregion_starmap[labelled_obs.index]


In [None]:
# Build the supervised dataset (embedding -> encoded region label) and split it
# once into train / validation / test. The split uses its own seeded generator
# and never overwrites its inputs, so re-running this cell always gives the
# same partition (the earlier two-step split halved `val_dataset` on every re-run).
SPLIT_SEED = 0

sample1_embeddings = ad_embed_train.X
sample1_labels = list(ad_embed_train.obs['gt_tissue_region_main'])

le = preprocessing.LabelEncoder()
le.fit(sample1_labels)

# LabelEncoder.transform already returns integer class codes; no str/int round-trip needed.
sample1_labels = le.transform(sample1_labels)

dataset1 = TensorDataset(torch.Tensor(sample1_embeddings),
                         torch.tensor(sample1_labels, dtype=torch.long))

# 80% for training; the remaining 20% is divided evenly into validation and test.
train_size = int(0.8 * len(dataset1))
holdout_size = len(dataset1) - train_size
val_size = int(0.5 * holdout_size)
test_size = holdout_size - val_size
assert train_size + val_size + test_size == len(dataset1)

train_dataset, val_dataset, test_dataset = random_split(
    dataset1,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
import torch.optim as optim

# Seed the global torch RNG so weight initialisation and batch shuffling are
# repeatable under Restart & Run All.
TRAIN_SEED = 0
torch.manual_seed(TRAIN_SEED)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 'balanced' weights computed over every labelled sample in `sample1_labels`.
class_weight = torch.Tensor(
    sklearn.utils.class_weight.compute_class_weight(
        class_weight='balanced',
        classes=np.unique(sample1_labels),
        y=sample1_labels,
    )
)

model = NNTransfer(input_dim=sample1_embeddings.shape[1],
                   output_dim=len(np.unique(sample1_labels)))
model.to(device)  # Move the model to GPU if available

# NNTransfer.forward ends in Softmax, so the model emits probabilities.
# nn.CrossEntropyLoss expects raw logits and would apply softmax a second time,
# flattening the gradients. Taking the log of the probabilities and using the
# class-weighted negative log-likelihood gives the intended cross-entropy.
# If the model is ever changed to return logits, switch back to CrossEntropyLoss.
weighted_nll = nn.NLLLoss(weight=class_weight.to(device))


def criterion(probabilities, targets):
    """Class-weighted cross-entropy for a model that outputs probabilities."""
    # Clamp so a probability that underflowed to 0 does not produce log(0) = -inf.
    return weighted_nll(torch.log(probabilities.clamp_min(1e-12)), targets)


optimizer = optim.Adam(model.parameters(), lr=0.001)

NNTransferTrain(model, criterion, optimizer, train_loader, val_loader, device)

In [None]:
# Predict on the held-out test split and map encoded classes back to region names.
test_predictions, sample2_uncertainty = NNTransferPredictWithUncertainty(model, test_loader, device)
# The helper stacks per-sample scalars with np.vstack, giving a column vector.
# Flatten it so LabelEncoder receives the 1-D array it expects instead of
# emitting a DataConversionWarning.
test_predictions = le.inverse_transform(test_predictions.ravel())

# Ground-truth codes in dataset order (test_loader is not shuffled, so the order matches).
all_labels = [label.item() for _, label in test_dataset]
gt_test_predictions = le.inverse_transform(all_labels)

GT_starmap_s = gt_test_predictions
PRED_starmap_s = test_predictions

In [None]:
# Contingency table: rows are ground-truth labels, columns are predicted labels.
cross_tab = pd.crosstab(
    pd.Series(GT_starmap_s, name='Original'),
    pd.Series(PRED_starmap_s, name='FuseMap'),
)

# Scale every column to sum to 1, then every row to sum to 1, and express the
# result as whole-number percentages.
cross_tab_normalized = (
    cross_tab
    .pipe(lambda counts: counts.div(counts.sum(axis=0), axis=1))
    .pipe(lambda shares: shares.div(shares.sum(axis=1), axis=0))
    .mul(100)
    .pipe(np.around)
    .astype('int')
)

In [None]:
# Heatmap of the normalised ground-truth vs. predicted label table.
cmap = sns.cubehelix_palette(start=2, rot=0, dark=0, light=1.05, reverse=False, as_cmap=True)

fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(cross_tab_normalized, cmap=cmap, ax=ax)
ax.set_title("Normalized Correspondence of Two Categories")
# plt.savefig('figures_refine/main_ct_starmap.png',dpi=300, transparent=True)
plt.show()