In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import anndata as ad
import scanpy.external as sce
from sklearn import preprocessing
import pickle5 as pickle
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors
from sklearn import preprocessing
import sklearn
from sklearn.metrics import accuracy_score


from utils import *


# Tiny positive constant. Nothing in the visible cells reads it; presumably kept
# for helper code elsewhere -- TODO confirm it is still needed.
eps = 1e-100


In [None]:
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import random_split

import torch.nn as nn
class NNTransfer(nn.Module):
    """Two-layer perceptron mapping a feature vector to class probabilities.

    Architecture:
        Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, output_dim) -> Softmax(dim=1)

    Args:
        input_dim: Size of each input feature vector.
        output_dim: Number of output classes.
        hidden_dim: Width of the hidden layer. Defaults to 256, the value that
            used to be hard-coded.

    Note:
        ``forward`` returns softmax probabilities, not logits. A criterion that
        applies its own (log-)softmax, such as ``nn.CrossEntropyLoss``, would
        normalise twice -- NOTE(review): confirm which criterion callers pass.
    """

    def __init__(self, input_dim=128, output_dim=10, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.activate = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape (batch, output_dim) for a (batch, input_dim) input."""
        # nn.functional.relu keeps the class dependent only on the `nn` alias,
        # not on a separately imported `torch` name.
        hidden = nn.functional.relu(self.fc1(x))
        return self.activate(self.fc2(hidden))

def NNTransferTrain(model, criterion, optimizer, train_loader,val_loader, device, save_dir,label_key,focus_name, epochs=200):
    """Train `model` with early stopping driven by validation accuracy.

    Every epoch performs one optimisation pass over `train_loader` and then
    scores the model on `val_loader` through `NNTransferEvaluate`. When the
    validation accuracy beats the best value seen so far, a progress line is
    printed and the patience counter is reset. Training stops as soon as more
    than 10 consecutive epochs bring no improvement.

    `save_dir`, `label_key` and `focus_name` are accepted but not used here
    (checkpoint saving is commented out). Returns None.
    """
    best_accuracy = 0
    stale_epochs = 0
    for epoch in range(epochs):
        model.train()
        loss_all = 0
        for batch_inputs, batch_labels in train_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_labels)
            batch_loss.backward()
            optimizer.step()
            loss_all += batch_loss.item()

        # Only the accuracy drives model selection; the validation loss is ignored.
        _, eval_accuracy = NNTransferEvaluate(model, val_loader, criterion, device)
        if eval_accuracy > best_accuracy:
            best_accuracy = eval_accuracy
#             torch.save(model.state_dict(), pth+f"/predictor/FuseMap_NNtransfer_transfer_tissue_region_sub_refine.pt")
            print(f"Epoch {epoch}/{epochs} - Train Loss: {loss_all / len(train_loader)}, Accuracy: {eval_accuracy}")
            stale_epochs = 0
        else:
            stale_epochs += 1
            # The counter can only exceed the limit right after being incremented.
            if stale_epochs > 10:
                print(f"Epoch {epoch}/{epochs} - early stopping due to patience count")
                break

def NNTransferEvaluate(model, dataloader, criterion, device):
    """Score `model` on `dataloader` without tracking gradients.

    Returns:
        (mean_loss, accuracy): the sum of per-batch criterion values divided by
        ``len(dataloader)``, and the percentage (0-100) of samples whose
        highest-scoring output index equals the label.
    """
    model.eval()
    batch_losses = []
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            batch_outputs = model(batch_inputs)
            batch_losses.append(criterion(batch_outputs, batch_labels).item())
            # Index of the largest output along dim 1 is the predicted class.
            predicted = batch_outputs.max(1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    accuracy = 100. * n_correct / n_seen
    return sum(batch_losses) / len(dataloader), accuracy

def NNTransferPredictWithUncertainty(model, dataloader, device):
    """Predict a class and an uncertainty score for every sample in `dataloader`.

    Each batch yielded by `dataloader` is indexed with ``[0]`` to obtain the
    inputs. The predicted class is the index of the largest model output along
    dim 1, and the uncertainty is ``1 - max output``. That is a probability
    complement only if the model emits probabilities -- assumes a
    softmax-terminated model with 2-D ``(batch, classes)`` output, TODO confirm.

    Returns:
        (predictions, uncertainties): two NumPy column vectors of shape
        ``(n_samples, 1)``. An empty `dataloader` yields two ``(0, 1)`` arrays
        instead of raising.
    """
    model.eval()
    batch_predictions = []
    batch_uncertainties = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0].to(device)
            outputs = model(inputs)
            # A single reduction gives both the winning score and its index.
            confidence, predicted = torch.max(outputs, dim=1)
            batch_predictions.append(predicted.detach().cpu().numpy())
            batch_uncertainties.append((1 - confidence).detach().cpu().numpy())

    if not batch_predictions:
        # np.concatenate / np.vstack raise on an empty list; return empty columns instead.
        return np.empty((0, 1), dtype=np.int64), np.empty((0, 1), dtype=np.float32)

    # Concatenate whole batches, then reshape to the (n_samples, 1) layout callers receive.
    predictions = np.concatenate(batch_predictions).reshape(-1, 1)
    uncertainties = np.concatenate(batch_uncertainties).reshape(-1, 1)
    return predictions, uncertainties


In [None]:
# Load the AnnData object stored in ad_embed.h5ad; later cells read its
# obs['GM_label'] column and obsm['X_umap'] coordinates.
ad_gene_embedding = sc.read_h5ad("../source_data/ad_embed.h5ad")


# plot gene modules

In [None]:
# Display the summary of the loaded object to check which fields are available.
ad_gene_embedding

In [None]:
# uncomment to read all WGCNA modules and save in 'GM_label'
# import PyWGCNA
# pyWGCNA_5xFAD = PyWGCNA.readWGCNA("../source_data/PyWGCNA/all/all.p")
# ad_gene_embedding.obs.loc[pyWGCNA_5xFAD.datExpr.var.index,'GM_label']=list(pyWGCNA_5xFAD.datExpr.var['moduleColors'])


In [None]:
import os

import seaborn as sns  # not used in this cell; a later plotting cell relies on `sns`

# Fixed seed so that every gene module receives the same random colour on each run.
np.random.seed(5)
module_labels = ad_gene_embedding.obs['GM_label'].unique()
color_palette = np.random.rand(module_labels.shape[0], 3)
# Map each module label to one RGB row of the palette.
color_dic = dict(zip(module_labels, color_palette))

fig, ax = plt.subplots(figsize=(10, 10))
# NOTE(review): legend_loc=[] looks intended to suppress the legend -- confirm
# against the installed scanpy version.
ax = sc.pl.umap(ad_gene_embedding,
                color='GM_label', size=10,
                legend_loc=[],
                palette=color_dic,
                ax=ax, show=False)

# Overlay the hub genes of every module as larger, black-edged markers.
directory = "../source_data/PyWGCNA/all/hub_gene/"
# sorted() gives a deterministic drawing order; os.listdir returns entries in arbitrary order.
for hub_file in sorted(os.listdir(directory)):
    hub_path = os.path.join(directory, hub_file)
    # Skip hidden files and sub-directories (e.g. .DS_Store, .ipynb_checkpoints),
    # which cannot be parsed as hub-gene tables.
    if hub_file.startswith('.') or not os.path.isfile(hub_path):
        continue
    hub_table = pd.read_csv(hub_path, index_col=0)
    hub_genes = [gene.upper() for gene in hub_table.index]
    # Slice the embedding once and reuse the coordinates for both axes.
    hub_umap = ad_gene_embedding[hub_genes].obsm['X_umap']
    # The module label is the third '_'-separated token of the file name.
    module = hub_file.split('_')[2]
    ax.scatter(hub_umap[:, 0], hub_umap[:, 1],
               color=color_dic[module], edgecolors='k',
               s=60)

# plt.savefig('test.png',dpi=300,transparent=True)

# Compare gene selection results

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.neural_network import MLPClassifier

from sklearn.metrics import accuracy_score
def compute_cell_type_classification_accuracy(scrna,selected_genes,randomseed):
    """Held-out accuracy of an MLP that predicts cell type from a gene subset.

    The genes of `selected_genes` that are present in ``scrna.var.index`` are
    kept, the expression matrix is total-count normalised, log1p-transformed and
    scaled (no centring, clipped at 10), and the labels in
    ``scrna.obs['Description']`` are integer-encoded. Cells are split 80/20
    with a fixed seed; an ``MLPClassifier`` is fitted on the 80% and scored on
    the remaining 20%.

    Previously both the training and the test matrices were read through
    ``Subset.dataset``, i.e. the full dataset, so the reported score was the
    accuracy on the training data itself. The split is now actually applied.

    Args:
        scrna: AnnData-like object with an ``obs['Description']`` column; it is
            not modified.
        selected_genes: Iterable of gene names to use as features.
        randomseed: ``random_state`` of the classifier. The train/validation
            split always uses seed 0, so it is identical across calls.

    Returns:
        Mean accuracy (0-1) on the held-out 20% of cells.

    Raises:
        ValueError: If there are too few cells to form a non-empty training and
            validation set.

    Side effects:
        Prints the shape of the selected matrix and reseeds NumPy's global RNG
        with 0 (as before).
    """
    shared_genes = np.intersect1d(scrna.var.index, list(selected_genes))
    # Work on a real copy: the preprocessing below writes to the object and
    # must not go through a view of `scrna`.
    ad_selectgene = scrna[:, shared_genes].copy()
    ad_selectgene.obs['ct_label'] = ad_selectgene.obs['Description']
    sc.pp.normalize_total(ad_selectgene)#, target_sum=1e4)
    sc.pp.log1p(ad_selectgene)
    sc.pp.scale(ad_selectgene, zero_center=False, max_value=10)

    # Densify only when the matrix is sparse; a dense array has no `toarray`.
    if hasattr(ad_selectgene.X, 'toarray'):
        ad_selectgene.X = ad_selectgene.X.toarray()
    print(ad_selectgene.shape)

    features = np.asarray(ad_selectgene.X, dtype=np.float32)
    labels = preprocessing.LabelEncoder().fit_transform(list(ad_selectgene.obs['ct_label']))

    # Seeded shuffle so that the 80/20 split is reproducible.
    np.random.seed(0)
    n_cells = features.shape[0]
    shuffled = np.random.permutation(n_cells)
    train_size = int(0.8 * n_cells)  # Use 80% of the data for training
    if train_size == 0 or train_size == n_cells:
        raise ValueError(
            f"Need at least 2 cells to build a train/validation split, got {n_cells}")
    train_idx, val_idx = shuffled[:train_size], shuffled[train_size:]

    clf = MLPClassifier(alpha=1, max_iter=1000, random_state=randomseed)
    clf.fit(features[train_idx], labels[train_idx])

    # Score strictly on cells the classifier has not seen.
    return clf.score(features[val_idx], labels[val_idx])

In [None]:
###### import scanpy as sc
# Load the scRNA-seq reference and upper-case its gene names (the hub-gene names
# are upper-cased the same way earlier in the notebook). Upper-casing can create
# duplicate names, so make them unique afterwards.
scrna = sc.read_h5ad('../source_data/Atlas8_scrnaseq.h5ad')
scrna.var.index = list(map(str.upper, scrna.var.index))
scrna.var_names_make_unique()

In [None]:
# Experiment grid: classifier seeds and sizes of the selected gene panel.
randomseedrange = range(5)
genenumberrange = [10, 20, 35, 50, 75, 100, 125, 150]

# Column-wise accumulator for the results; each column gets its own list.
pd_random = {column: [] for column in ('type', 'acc', 'randomseed', 'genenumber')}


Uncomment the cell below to recompute the classification accuracies and save them to CSV.

In [None]:
# ## random
# for randomseed in randomseedrange:
#     for genenumber in genenumberrange:
#         np.random.seed(randomseed)
#         intergene = np.intersect1d(ad_gene_embedding.obs.index,scrna.var.index)
#         selected_genes = intergene[np.random.permutation(intergene.shape[0])[:genenumber]]
#         pd_random['type'].append('random')
#         pd_random['acc'].append(compute_cell_type_classification_accuracy(scrna,selected_genes,randomseed=randomseed))
#         pd_random['randomseed'].append(randomseed)
#         pd_random['genenumber'].append(genenumber)


# ## get intersection genes
# intergene = np.intersect1d(ad_gene_embedding.obs.index,scrna.var.index)
# scrna_inter = scrna[:,intergene]
# sc.pp.highly_variable_genes(scrna_inter, n_top_genes=1100, flavor="seurat_v3")

# scrna_inter_top = scrna_inter[:,scrna_inter.var['highly_variable']==True]
# scrna_inter_top_sort = scrna_inter_top.var.sort_values(by='highly_variable_rank',ascending=True)#[:genenumber]



# ## gene module guided selection
# scrna_inter_top_sort['GM_label'] = ad_gene_embedding.obs.loc[scrna_inter_top_sort.index,'GM_label']

# for randomseed in randomseedrange:
#     for genenumber in genenumberrange:
#         sorted_genes = scrna_inter_top_sort #.sort_values(['GM_label', 'mean'], ascending=[True, False])
#         selected_genes = scrna_inter_top_sort.groupby('GM_label').head(np.floor(genenumber/13))
#         if len(selected_genes) < genenumber:
#             additional_genes_needed = genenumber - len(selected_genes)
#             additional_genes = (sorted_genes[~sorted_genes.index.isin(selected_genes.index)]
#                                 .sort_values('means', ascending=False)
#                                 .head(additional_genes_needed))
#             selected_genes = pd.concat([selected_genes, additional_genes])
#         selected_genes = selected_genes.index

#         np.random.seed(randomseed)
#         pd_random['type'].append('spatial_mean_umap_high_top_sort_1')
#         pd_random['acc'].append(compute_cell_type_classification_accuracy(scrna,selected_genes,randomseed=randomseed))
#         pd_random['randomseed'].append(randomseed)
#         pd_random['genenumber'].append(genenumber)
        
        
# ## highly variable selection
# for randomseed in randomseedrange:
#     for genenumber in genenumberrange:
#         np.random.seed(randomseed)
#         selected_genes = scrna_inter_top_sort[:genenumber].index
#         pd_random['type'].append('highvariable_inter_top_sort')
#         pd_random['acc'].append(compute_cell_type_classification_accuracy(scrna,selected_genes,randomseed=randomseed))
#         pd_random['randomseed'].append(randomseed)
#         pd_random['genenumber'].append(genenumber)


# pd_random_final = pd.DataFrame(pd_random)
# pd_random_final.to_csv('../source_data/result/gene_selection/subtype_classification.csv')

### plot accuracy

In [None]:
def plot_gene_selection_accuracy(csv_path, title, ylim, palette):
    """Plot classification accuracy against gene-panel size for each selection strategy.

    Reads `csv_path` (columns used: 'genenumber', 'acc', 'type') and draws, on
    one new 6x6 figure, a point plot per strategy without error bars and
    unfilled box plots of the same data. The legend is removed.

    Args:
        csv_path: CSV file with the accuracy records; its first column is the index.
        title: Axes title.
        ylim: y-axis limits as ``[low, high]``.
        palette: Mapping from strategy name to colour, used by the point plot.

    Returns:
        The loaded DataFrame.
    """
    accuracy_table = pd.read_csv(csv_path, index_col=0)
    fig, ax = plt.subplots(figsize=(6, 6))
    # NOTE(review): `ci=None` is deprecated in newer seaborn releases in favour
    # of `errorbar=None`; kept for compatibility with the installed version.
    sns.pointplot(data=accuracy_table, x="genenumber", y="acc", hue="type", dodge=False,
                  palette=palette, ci=None, ax=ax)
    sns.boxplot(data=accuracy_table, x="genenumber", y="acc", hue="type", dodge=False,
                boxprops=dict(facecolor='none'), ax=ax)
    ax.legend().remove()
    ax.set_title(title)
    ax.set_ylim(ylim)
    return accuracy_table


# One colour per gene-selection strategy (keys match the 'type' column).
color_map = {
    'highvariable_inter_top_sort': np.array([0.267004, 0.004874, 0.329415, 0.1]),
    'random': np.array([0.127568, 0.566949, 0.550556, 0.1]),
    'spatial_mean_umap_high_top_sort_1': 'orange',
}

pd_random_final = plot_gene_selection_accuracy(
    '../source_data/result/gene_selection/maintype_classification.csv',
    'Cell type granularity: main types', [0, 1], color_map)

# As before, `pd_random_final` ends up holding the subtype table.
pd_random_final = plot_gene_selection_accuracy(
    '../source_data/result/gene_selection/subtype_classification.csv',
    'Cell type granularity: subtypes', [0, 0.8], color_map)
