In [1]:
from sklearn.metrics import (adjusted_rand_score, normalized_mutual_info_score, 
                             silhouette_score, calinski_harabasz_score,
                             davies_bouldin_score)
import logging
import numpy as np
from tqdm import tqdm
import torch
from sklearn.preprocessing import LabelEncoder
from graphmae.utils import (
    
    build_args,
    create_optimizer,
    set_random_seed,
    TBLogger,
    get_current_lr,
    load_best_configs,

    
)
from collections import Counter
from graphmae.datasets.data_util import load_dataset
from graphmae.evaluation import node_classification_evaluation
from graphmae.models import build_model
from ogb.nodeproppred import DglNodePropPredDataset
from sklearn.cluster import KMeans
def kMeans_use(embedding,cluster_number):
    kmeans = KMeans(n_clusters=cluster_number,
                init="k-means++",
                random_state=0)
    pred = kmeans.fit_predict(embedding)
    return pred
import argparse
parser = argparse.ArgumentParser(description="GAT")
parser.add_argument("--seeds", type=int, nargs="+", default=[0])
parser.add_argument("--dataset", type=str, default="cora")
parser.add_argument("--device", type=int, default=-1)
parser.add_argument("--max_epoch", type=int, default=200,
                    help="number of training epochs")
parser.add_argument("--warmup_steps", type=int, default=-1)

parser.add_argument("--num_heads", type=int, default=4,
                    help="number of hidden attention heads")
parser.add_argument("--num_out_heads", type=int, default=1,
                    help="number of output attention heads")
parser.add_argument("--num_layers", type=int, default=2,
                    help="number of hidden layers")
parser.add_argument("--num_hidden", type=int, default=256,
                    help="number of hidden units")
parser.add_argument("--residual", action="store_true", default=False,
                    help="use residual connection")
parser.add_argument("--in_drop", type=float, default=.2,
                    help="input feature dropout")
parser.add_argument("--attn_drop", type=float, default=.1,
                    help="attention dropout")
parser.add_argument("--norm", type=str, default=None)
parser.add_argument("--lr", type=float, default=0.005,
                    help="learning rate")
parser.add_argument("--weight_decay", type=float, default=5e-4,
                    help="weight decay")
parser.add_argument("--negative_slope", type=float, default=0.2,
                    help="the negative slope of leaky relu for GAT")
parser.add_argument("--activation", type=str, default="prelu")
parser.add_argument("--mask_rate", type=float, default=0.5)
parser.add_argument("--drop_edge_rate", type=float, default=0.0)
parser.add_argument("--replace_rate", type=float, default=0.0)

parser.add_argument("--encoder", type=str, default="gat")
parser.add_argument("--decoder", type=str, default="gat")
parser.add_argument("--loss_fn", type=str, default="byol")
parser.add_argument("--alpha_l", type=float, default=2, help="`pow`inddex for `sce` loss")
parser.add_argument("--optimizer", type=str, default="adam")

parser.add_argument("--max_epoch_f", type=int, default=30)
parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation")
parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation")
parser.add_argument("--linear_prob", action="store_true", default=False)

parser.add_argument("--load_model", action="store_true")
parser.add_argument("--save_model", action="store_true")
parser.add_argument("--use_cfg", action="store_true")
parser.add_argument("--logging", action="store_true")
parser.add_argument("--scheduler", action="store_true", default=False)
parser.add_argument("--concat_hidden", action="store_true", default=False)

# for graph classification
parser.add_argument("--pooling", type=str, default="mean")
parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature")
parser.add_argument("--batch_size", type=int, default=32)

2022-09-12 19:33:46,137 - INFO - Enabling RDKit 2022.03.5 jupyter extensions


_StoreAction(option_strings=['--batch_size'], dest='batch_size', nargs=None, const=None, default=32, type=<class 'int'>, choices=None, help=None, metavar=None)

In [2]:
import pickle as pkl
from scanpy import read_10x_h5
import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import plotly.express as px
import pandas as pd
import scanpy as sc
import scipy as sci
import warnings
warnings.filterwarnings("ignore")
def drawPicture(dataframe,col_name, row_name,colorattribute,save_file,celltype_colors =  ("#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33", "#A65628", "#F781BF", "#999999", "#E41A80",
    "#377F1C", "#4DAFAE", "#984F07", "#FF7F64", "#FFFF97", "#A6568C", "#F78223", "#9999FD", "#E41AE4", "#377F80", "#4DB012",
    "#984F6B", "#FF7FC8",  "#A656F0", "#F78287", "#999A61", "#E41B48", "#377FE4", "#4DB076", "#984FCF", "#FF802C",
    "#00005F", "#A65754", "#F782EB", "#999AC5", "#E41BAC", "#378048", "#4DB0DA", "#985033", "#FF8090", "#0000C3", "#A657B8",
    "#F7834F", "#999B29", "#E41C10", "#3780AC", "#4DB13E", "#985097", "#FF80F4", "#000127", "#A6581C", "#F783B3", "#999B8D",
    "#E41C74", "#378110", "#4DB1A2", "#9850FB", "#FF8158", "#00018B", "#A65880", "#F78417", "#999BF1", "#E41CD8", "#378174",
    "#4DB206", "#98515F", "#FF81BC", "#0001EF", "#A658E4", "#F7847B", "#999C55", "#E41D3C", "#3781D8", "#4DB26A", "#9851C3",
    "#FF8220", "#000253", "#A65948", "#F784DF", "#999CB9", "#E41DA0", "#37823C", "#4DB2CE", "#985227", "#FF8284", "#0002B7",
    "#A659AC", "#F78543", "#999D1D", "#E41E04", "#3782A0", "#4DB332", "#98528B", "#FF82E8", "#00031B", "#A65A10", "#F785A7",
    "#999D81", "#E41E68", "#378304", "#4DB396", "#9852EF", "#FF834C", "#00037F", "#A65A74", "#F7860B", "#999DE5", "#E41ECC",
    "#378368", "#4DB3FA", "#985353", "#FF83B0", "#0003E3", "#A65AD8", "#F7866F", "#999E49", "#E41F30", "#3783CC", "#4DB45E",
    "#9853B7", "#FF8414", "#000447", "#A65B3C", "#F786D3", "#999EAD", "#E41F94", "#378430", "#4DB4C2", "#98541B", "#FF8478",
    "#0004AB", "#A65BA0", "#F78737", "#999F11", "#E41FF8", "#378494", "#4DB526", "#98547F", "#FF84DC", "#00050F", "#A65C04",
    "#F7879B", "#999F75"
    ),width = 1000,height = 1000,marker_size = 10,is_show = True,is_save = False,save_type = "pdf"):

    import plotly.express as px
    import plotly.graph_objects as go
    size = len(set(dataframe[colorattribute]))
    #dataframe.sort_values(by = colorattribute,inplace=True, ascending=True)
    length_col = max(dataframe[col_name]) - min(dataframe[col_name])

    length_row = max(dataframe[row_name]) - min(dataframe[row_name])
    max_length = max(length_col,length_row) + 2
    fig = px.scatter(dataframe, x = col_name, y= row_name,color = colorattribute,color_discrete_sequence=celltype_colors)
    fig.update_traces(marker_size=marker_size)
    fig.update_layout(
        xaxis = dict(
            tickmode = 'linear',
            tick0 = min(dataframe[row_name]),   # 起始点
            dtick = max_length  # 间距
        ),
        yaxis = dict(
            tickmode = 'linear',
            tick0 = min(dataframe[col_name]),   # 起始点
            dtick = max_length  # 间距
        ),
    #     xaxis_range = [min(it_mapping_csv.row)-10,min(it_mapping_csv.row) + max_length ],
    #     yaxis_range =[min(it_mapping_csv.col)-10,min(it_mapping_csv.col) + max_length]
    )
    fig.update_layout(
        autosize=False,
        width=width,
        height = height)
    if(is_show):
        fig.show()
    if(is_save):
        if(save_type == "pdf"):
            fig.write_image(save_file)
        if(save_type == "html"):
            fig.write_html(save_file)

2022-09-12 19:33:48,176 - INFO - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2022-09-12 19:33:48,177 - INFO - NumExpr defaulting to 8 threads.


In [3]:
def pretrain(model, graph, feat, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None):
    logging.info("start training..")
    graph = graph.to(device)
    x = feat.to(device)

    epoch_iter = tqdm(range(max_epoch))

    for epoch in epoch_iter:
        model.train()

        loss, loss_dict = model(graph, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")
        if logger is not None:
            loss_dict["lr"] = get_current_lr(optimizer)
            logger.note(loss_dict, step=epoch)

        #if (epoch + 1) % 200 == 0:
            #node_classification_evaluation(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True)

    # return best_model
    return model

In [4]:
import pandas as pd
#import calculate_adj
import anndata as ad
import scanpy as sc
from operator import index
import re
from scipy.spatial.distance import pdist, squareform
import pandas as pd
import numpy as np
from scipy import sparse
import dgl
import torch
folder_name = "/home/sunhang/Embedding/CCST/dataset/DLPFC/"
samplea_list = {151507,
 151508,
 151509,
 151510,
 151669,
 151670,
 151671,
 151672,
 151673,
 151674,
 151675,
 151676}
sample_name = str(151673)
#gene_exp_data_file = folder_name + sample_name + "_DLPFC_count.csv"
gene_loc_data_file = folder_name + sample_name + "_DLPFC_col_name.csv"
#save file
npz_file = folder_name + sample_name + "_DLPFC_distacne.npz"
#gene_csv = pd.read_csv(gene_exp_data_file,index_col= 0 )
gene_loc_data_csv = pd.read_csv(gene_loc_data_file,index_col=0)
gene_loc_data_csv = gene_loc_data_csv.fillna("None")
row_name = "imagerow"
col_name = "imagecol"
cell_loc = gene_loc_data_csv[[row_name,col_name]].values
distance_np = pdist(cell_loc, metric = "euclidean")
distance_np_X =squareform(distance_np)
distance_loc_csv = pd.DataFrame(index=gene_loc_data_csv.index, columns=gene_loc_data_csv.index,data = distance_np_X)
threshold = 8
num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
#num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
adj_matrix = np.zeros(distance_np_X.shape)
non_zero_point = np.where((0 < distance_np_X) & (distance_np_X < threshold))
adj_matrix = np.zeros(distance_np_X.shape)
non_zero_point = np.where((0< distance_np_X)&(distance_np_X<threshold))
for i in range(num_big):
    x = non_zero_point[0][i]
    y = non_zero_point[1][i]
    adj_matrix[x][y] = 1 
adj_matrix = adj_matrix + np.eye(distance_np_X.shape[0])
adj_matrix  = np.float32(adj_matrix)
adj_matrix_crs = sparse.csr_matrix(adj_matrix)
graph = dgl.from_scipy(adj_matrix_crs,eweight_name='w')
min_cells = 5
pca_n_comps = 3000

In [5]:
gene_loc_data_csv.index = gene_loc_data_csv.barcode

In [6]:
gene_loc_data_csv

Unnamed: 0_level_0,barcode,sample_name,tissue,row,col,imagerow,imagecol,Cluster,height,width,...,SpatialDE_PCA_spatial,SpatialDE_pool_PCA_spatial,HVG_PCA_spatial,pseudobulk_PCA_spatial,markers_PCA_spatial,SpatialDE_UMAP_spatial,SpatialDE_pool_UMAP_spatial,HVG_UMAP_spatial,pseudobulk_UMAP_spatial,markers_UMAP_spatial
barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACAAGTATCTCCCA-1,AAACAAGTATCTCCCA-1,151673,1,50,102,381.098123,440.639079,7,600,600,...,3,1,1,3,1,7,1,1,2,1
AAACAATCTACTAGCA-1,AAACAATCTACTAGCA-1,151673,1,3,43,126.327637,259.630972,4,600,600,...,7,5,2,2,3,2,1,4,2,3
AAACACCAATAACTGC-1,AAACACCAATAACTGC-1,151673,1,59,19,427.767792,183.078314,8,600,600,...,5,4,4,5,3,5,7,5,3,2
AAACAGAGCGACTCCT-1,AAACAGAGCGACTCCT-1,151673,1,14,94,186.813688,417.236738,6,600,600,...,3,3,1,2,2,3,4,2,1,1
AAACAGCTTTCAGAAG-1,AAACAGCTTTCAGAAG-1,151673,1,43,9,341.269139,152.700275,3,600,600,...,2,1,2,4,1,3,3,8,4,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTGTTTCACATCCAGG-1,TTGTTTCACATCCAGG-1,151673,1,58,42,422.862301,254.410450,8,600,600,...,5,4,4,5,8,5,7,5,7,5
TTGTTTCATTAGTCTA-1,TTGTTTCATTAGTCTA-1,151673,1,60,30,433.393354,217.146722,8,600,600,...,5,4,4,5,8,5,7,1,1,4
TTGTTTCCATACAACT-1,TTGTTTCCATACAACT-1,151673,1,45,27,352.430255,208.415849,6,600,600,...,6,6,7,6,3,1,3,3,3,3
TTGTTTGTATTACACG-1,TTGTTTGTATTACACG-1,151673,1,73,41,503.735391,250.720081,6,600,600,...,6,6,8,5,3,1,3,4,1,3


In [7]:
graph.num_edges()

24763

In [16]:
from scanpy import read_10x_h5
adata = read_10x_h5("/home/sunhang/Embedding/SpaGCN/tutorial/data/" + sample_name +"/filtered_feature_bc_matrix.h5")
spatial=pd.read_csv("/home/sunhang/Embedding/SpaGCN/tutorial/data/" + sample_name +"/tissue_positions_list.txt",sep=",",header=None,na_filter=False,index_col=0) 

adata.var_names=[i.upper() for i in list(adata.var_names)]
adata.var["genename"]=adata.var.index.astype("str")

adata.var_names_make_unique
adata.obs = pd.merge(gene_loc_data_csv,adata.obs,right_index=True,left_index=True)

In [17]:
adata.obs

Unnamed: 0_level_0,barcode,sample_name,tissue,row,col,imagerow,imagecol,Cluster,height,width,...,SpatialDE_PCA_spatial,SpatialDE_pool_PCA_spatial,HVG_PCA_spatial,pseudobulk_PCA_spatial,markers_PCA_spatial,SpatialDE_UMAP_spatial,SpatialDE_pool_UMAP_spatial,HVG_UMAP_spatial,pseudobulk_UMAP_spatial,markers_UMAP_spatial
barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACAAGTATCTCCCA-1,AAACAAGTATCTCCCA-1,151673,1,50,102,381.098123,440.639079,7,600,600,...,3,1,1,3,1,7,1,1,2,1
AAACAATCTACTAGCA-1,AAACAATCTACTAGCA-1,151673,1,3,43,126.327637,259.630972,4,600,600,...,7,5,2,2,3,2,1,4,2,3
AAACACCAATAACTGC-1,AAACACCAATAACTGC-1,151673,1,59,19,427.767792,183.078314,8,600,600,...,5,4,4,5,3,5,7,5,3,2
AAACAGAGCGACTCCT-1,AAACAGAGCGACTCCT-1,151673,1,14,94,186.813688,417.236738,6,600,600,...,3,3,1,2,2,3,4,2,1,1
AAACAGCTTTCAGAAG-1,AAACAGCTTTCAGAAG-1,151673,1,43,9,341.269139,152.700275,3,600,600,...,2,1,2,4,1,3,3,8,4,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTGTTTCACATCCAGG-1,TTGTTTCACATCCAGG-1,151673,1,58,42,422.862301,254.410450,8,600,600,...,5,4,4,5,8,5,7,5,7,5
TTGTTTCATTAGTCTA-1,TTGTTTCATTAGTCTA-1,151673,1,60,30,433.393354,217.146722,8,600,600,...,5,4,4,5,8,5,7,1,1,4
TTGTTTCCATACAACT-1,TTGTTTCCATACAACT-1,151673,1,45,27,352.430255,208.415849,6,600,600,...,6,6,7,6,3,1,3,3,3,3
TTGTTTGTATTACACG-1,TTGTTTGTATTACACG-1,151673,1,73,41,503.735391,250.720081,6,600,600,...,6,6,8,5,3,1,3,4,1,3


In [18]:
import argparse
parser = argparse.ArgumentParser(description="GAT")
parser.add_argument("--seeds", type=int, nargs="+", default=[0])
parser.add_argument("--dataset", type=str, default="cora")
parser.add_argument("--device", type=int, default=-1)
parser.add_argument("--max_epoch", type=int, default=200,
                    help="number of training epochs")
parser.add_argument("--warmup_steps", type=int, default=-1)

parser.add_argument("--num_heads", type=int, default=4,
                    help="number of hidden attention heads")
parser.add_argument("--num_out_heads", type=int, default=1,
                    help="number of output attention heads")
parser.add_argument("--num_layers", type=int, default=2,
                    help="number of hidden layers")
parser.add_argument("--num_hidden", type=int, default=256,
                    help="number of hidden units")
parser.add_argument("--residual", action="store_true", default=False,
                    help="use residual connection")
parser.add_argument("--in_drop", type=float, default=.2,
                    help="input feature dropout")
parser.add_argument("--attn_drop", type=float, default=.1,
                    help="attention dropout")
parser.add_argument("--norm", type=str, default=None)
parser.add_argument("--lr", type=float, default=0.005,
                    help="learning rate")
parser.add_argument("--weight_decay", type=float, default=5e-4,
                    help="weight decay")
parser.add_argument("--negative_slope", type=float, default=0.2,
                    help="the negative slope of leaky relu for GAT")
parser.add_argument("--activation", type=str, default="prelu")
parser.add_argument("--mask_rate", type=float, default=0.5)
parser.add_argument("--drop_edge_rate", type=float, default=0.0)
parser.add_argument("--replace_rate", type=float, default=0.0)

parser.add_argument("--encoder", type=str, default="gat")
parser.add_argument("--decoder", type=str, default="gat")
parser.add_argument("--loss_fn", type=str, default="byol")
parser.add_argument("--alpha_l", type=float, default=2, help="`pow`inddex for `sce` loss")
parser.add_argument("--optimizer", type=str, default="adam")

parser.add_argument("--max_epoch_f", type=int, default=30)
parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation")
parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation")
parser.add_argument("--linear_prob", action="store_true", default=False)

parser.add_argument("--load_model", action="store_true")
parser.add_argument("--save_model", action="store_true")
parser.add_argument("--use_cfg", action="store_true")
parser.add_argument("--logging", action="store_true")
parser.add_argument("--scheduler", action="store_true", default=False)
parser.add_argument("--concat_hidden", action="store_true", default=False)

# for graph classification
parser.add_argument("--pooling", type=str, default="mean")
parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature")
parser.add_argument("--batch_size", type=int, default=32)

_StoreAction(option_strings=['--batch_size'], dest='batch_size', nargs=None, const=None, default=32, type=<class 'int'>, choices=None, help=None, metavar=None)

In [9]:
pd.merge(adata.obs,gene_loc_data_csv,left_index=True,right_index=True)

Unnamed: 0,barcode,sample_name,tissue,row,col,imagerow,imagecol,Cluster,height,width,...,SpatialDE_PCA_spatial,SpatialDE_pool_PCA_spatial,HVG_PCA_spatial,pseudobulk_PCA_spatial,markers_PCA_spatial,SpatialDE_UMAP_spatial,SpatialDE_pool_UMAP_spatial,HVG_UMAP_spatial,pseudobulk_UMAP_spatial,markers_UMAP_spatial
AAACAAGTATCTCCCA-1,AAACAAGTATCTCCCA-1,151673,1,50,102,381.098123,440.639079,7,600,600,...,3,1,1,3,1,7,1,1,2,1
AAACAATCTACTAGCA-1,AAACAATCTACTAGCA-1,151673,1,3,43,126.327637,259.630972,4,600,600,...,7,5,2,2,3,2,1,4,2,3
AAACACCAATAACTGC-1,AAACACCAATAACTGC-1,151673,1,59,19,427.767792,183.078314,8,600,600,...,5,4,4,5,3,5,7,5,3,2
AAACAGAGCGACTCCT-1,AAACAGAGCGACTCCT-1,151673,1,14,94,186.813688,417.236738,6,600,600,...,3,3,1,2,2,3,4,2,1,1
AAACAGCTTTCAGAAG-1,AAACAGCTTTCAGAAG-1,151673,1,43,9,341.269139,152.700275,3,600,600,...,2,1,2,4,1,3,3,8,4,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTGTTTCACATCCAGG-1,TTGTTTCACATCCAGG-1,151673,1,58,42,422.862301,254.410450,8,600,600,...,5,4,4,5,8,5,7,5,7,5
TTGTTTCATTAGTCTA-1,TTGTTTCATTAGTCTA-1,151673,1,60,30,433.393354,217.146722,8,600,600,...,5,4,4,5,8,5,7,1,1,4
TTGTTTCCATACAACT-1,TTGTTTCCATACAACT-1,151673,1,45,27,352.430255,208.415849,6,600,600,...,6,6,7,6,3,1,3,3,3,3
TTGTTTGTATTACACG-1,TTGTTTGTATTACACG-1,151673,1,73,41,503.735391,250.720081,6,600,600,...,6,6,8,5,3,1,3,4,1,3


In [12]:
adata.obs

Unnamed: 0,x1,x2,x3,x4,x5
AAACAAGTATCTCCCA-1,1,50,102,8468,9791
AAACAATCTACTAGCA-1,1,3,43,2807,5769
AAACACCAATAACTGC-1,1,59,19,9505,4068
AAACAGAGCGACTCCT-1,1,14,94,4151,9271
AAACAGCTTTCAGAAG-1,1,43,9,7583,3393
...,...,...,...,...,...
TTGTTTCACATCCAGG-1,1,58,42,9396,5653
TTGTTTCATTAGTCTA-1,1,60,30,9630,4825
TTGTTTCCATACAACT-1,1,45,27,7831,4631
TTGTTTGTATTACACG-1,1,73,41,11193,5571


In [14]:
pca_n_comps = 3000
adata.obs = pd.merge(gene_loc_data_csv,adata.obs,right_index=True,left_on="barcode")
adata.obs.index = adata.obs["barcode"]
le = LabelEncoder()
label = le.fit_transform(adata.obs['layer_guess_reordered_short'])
adata.obs["lay_num"] = label
sc.pp.filter_genes(adata, min_cells=5)
adata_X = sc.pp.normalize_total(adata, target_sum=1, exclude_highly_expressed=True, inplace=False)['X']
adata_X = sc.pp.scale(adata_X)
adata_X = sc.pp.pca(adata_X, n_comps=pca_n_comps)
graph.ndata["feat"] = torch.tensor(adata_X.copy())
num_features = graph.ndata["feat"].shape[1]
num_classes = len(set(adata.obs.lay_num))-1

ValueError: 'barcode' is both an index level and a column label, which is ambiguous.

In [None]:
args = parser.parse_args([])
args.lr = 0.001
args.lr_f = 0.01
args.num_hidden = 512
args.num_heads = 4
args.weight_decay = 2e-4
args.weight_decay_f= 1e-4
args.max_epoch= 600
args.max_epoch_f= 300
args.mask_rate= 0.5
args.num_layers= 2
args.encoder= "gat"
args.decoder= "gat" 
args.activation= "prelu"
args.in_drop= 0.2
args.attn_drop= 0.1
args.linear_prob= True
args.loss_fn= "sce" 
args.drop_edge_rate=0.0
args.optimizer= "adam"
args.replace_rate= 0.05 
args.alpha_l= 3
args.scheduler= True
args.dataset = "sp"


#参数传递
device = args.device if args.device >= 0 else "cpu"
seeds = args.seeds
dataset_name = args.dataset
max_epoch = args.max_epoch
max_epoch_f = args.max_epoch_f
num_hidden = args.num_hidden
num_layers = args.num_layers
encoder_type = args.encoder
decoder_type = args.decoder
replace_rate = args.replace_rate

optim_type = args.optimizer 
loss_fn = args.loss_fn

lr = args.lr
weight_decay = args.weight_decay
lr_f = args.lr_f
weight_decay_f = args.weight_decay_f
linear_prob = args.linear_prob
load_model = args.load_model
save_model = args.save_model
logs = args.logging
use_scheduler = args.scheduler

In [10]:
from scanpy import read_10x_h5
adata = read_10x_h5("/home/sunhang/Embedding/SpaGCN/tutorial/data/" + sample_name +"/filtered_feature_bc_matrix.h5")
spatial=pd.read_csv("/home/sunhang/Embedding/SpaGCN/tutorial/data/" + sample_name +"/tissue_positions_list.txt",sep=",",header=None,na_filter=False,index_col=0) 
adata.obs["x1"]=spatial[1]
adata.obs["x2"]=spatial[2]
adata.obs["x3"]=spatial[3]
adata.obs["x4"]=spatial[4]
adata.obs["x5"]=spatial[5]
#Select captured samples
adata=adata[adata.obs["x1"]==1]
adata.var_names=[i.upper() for i in list(adata.var_names)]
adata.var["genename"]=adata.var.index.astype("str")

adata.var_names_make_unique

<bound method AnnData.var_names_make_unique of AnnData object with n_obs × n_vars = 3639 × 33538
    obs: 'x1', 'x2', 'x3', 'x4', 'x5'
    var: 'gene_ids', 'feature_types', 'genome', 'genename'>

In [11]:
gene_loc_data_csv

Unnamed: 0_level_0,barcode,sample_name,tissue,row,col,imagerow,imagecol,Cluster,height,width,...,SpatialDE_PCA_spatial,SpatialDE_pool_PCA_spatial,HVG_PCA_spatial,pseudobulk_PCA_spatial,markers_PCA_spatial,SpatialDE_UMAP_spatial,SpatialDE_pool_UMAP_spatial,HVG_UMAP_spatial,pseudobulk_UMAP_spatial,markers_UMAP_spatial
barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACAAGTATCTCCCA-1,AAACAAGTATCTCCCA-1,151673,1,50,102,381.098123,440.639079,7,600,600,...,3,1,1,3,1,7,1,1,2,1
AAACAATCTACTAGCA-1,AAACAATCTACTAGCA-1,151673,1,3,43,126.327637,259.630972,4,600,600,...,7,5,2,2,3,2,1,4,2,3
AAACACCAATAACTGC-1,AAACACCAATAACTGC-1,151673,1,59,19,427.767792,183.078314,8,600,600,...,5,4,4,5,3,5,7,5,3,2
AAACAGAGCGACTCCT-1,AAACAGAGCGACTCCT-1,151673,1,14,94,186.813688,417.236738,6,600,600,...,3,3,1,2,2,3,4,2,1,1
AAACAGCTTTCAGAAG-1,AAACAGCTTTCAGAAG-1,151673,1,43,9,341.269139,152.700275,3,600,600,...,2,1,2,4,1,3,3,8,4,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTGTTTCACATCCAGG-1,TTGTTTCACATCCAGG-1,151673,1,58,42,422.862301,254.410450,8,600,600,...,5,4,4,5,8,5,7,5,7,5
TTGTTTCATTAGTCTA-1,TTGTTTCATTAGTCTA-1,151673,1,60,30,433.393354,217.146722,8,600,600,...,5,4,4,5,8,5,7,1,1,4
TTGTTTCCATACAACT-1,TTGTTTCCATACAACT-1,151673,1,45,27,352.430255,208.415849,6,600,600,...,6,6,7,6,3,1,3,3,3,3
TTGTTTGTATTACACG-1,TTGTTTGTATTACACG-1,151673,1,73,41,503.735391,250.720081,6,600,600,...,6,6,8,5,3,1,3,4,1,3


In [None]:
adata.obs.columns

In [None]:
folder_name = "/home/sunhang/Embedding/Spatial_dataset/DLPFC"
sample_name = str(151673)
gene_loc_data_file = folder_name + "/" +sample_name+ "/" + sample_name + "_DLPFC_col_name.csv"
adata_file = folder_name + "/" +sample_name+ "/" + sample_name + "_filtered_feature_bc_matrix.h5"
gene_loc_data_csv = pd.read_csv(gene_loc_data_file,index_col=0)
gene_loc_data_csv.index = gene_loc_data_csv.barcode
gene_loc_data_csv = gene_loc_data_csv.fillna("None")
le = LabelEncoder()
label = le.fit_transform(gene_loc_data_csv['layer_guess_reordered_short'])
gene_loc_data_csv["lay_num"] = label
num_classes = len(set(gene_loc_data_csv.lay_num))
if((gene_loc_data_csv['layer_guess_reordered_short'] == "None").any()):
    num_classes = len(set(gene_loc_data_csv.lay_num)) - 1

# Create a group with location informatio
row_name = "imagerow"
col_name = "imagecol"
cell_loc = gene_loc_data_csv[[row_name,col_name]].values
distance_np = pdist(cell_loc, metric = "euclidean")
distance_np_X =squareform(distance_np)
distance_loc_csv = pd.DataFrame(index=gene_loc_data_csv.index, columns=gene_loc_data_csv.index,data = distance_np_X)
threshold = 8
num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
#num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
adj_matrix = np.zeros(distance_np_X.shape)
non_zero_point = np.where((0 < distance_np_X) & (distance_np_X < threshold))
adj_matrix = np.zeros(distance_np_X.shape)
non_zero_point = np.where((0< distance_np_X)&(distance_np_X<threshold))
for i in range(num_big):
    x = non_zero_point[0][i]
    y = non_zero_point[1][i]
    adj_matrix[x][y] = 1 
adj_matrix = adj_matrix + np.eye(distance_np_X.shape[0])
adj_matrix  = np.float32(adj_matrix)
adj_matrix_crs = sparse.csr_matrix(adj_matrix)
graph = dgl.from_scipy(adj_matrix_crs,eweight_name='w')


adata = read_10x_h5(adata_file)
adata.obs = pd.merge(adata.obs,gene_loc_data_csv,left_index=True,right_index=True)
adata.var_names=[i.upper() for i in list(adata.var_names)]
adata.var["genename"]=adata.var.index.astype("str")
adata.var_names_make_unique
pca_n_comps = 3000
sc.pp.filter_genes(adata, min_cells=5)
adata_X = sc.pp.normalize_total(adata, target_sum=1, exclude_highly_expressed=True, inplace=False)['X']
adata_X = sc.pp.scale(adata_X)
adata_X = sc.pp.pca(adata_X, n_comps=pca_n_comps)
graph.ndata["feat"] = torch.tensor(adata_X.copy())
num_features = graph.ndata["feat"].shape[1]
args.num_features = num_features

In [None]:
for num_set in np.arange(1,4,1):
    for num_set_two in np.arange(500,4000,500):
        print(num_set)
        print(num_set_two)
        gene_loc_data_csv = pd.read_csv(gene_loc_data_file,index_col=0)
        gene_loc_data_csv = gene_loc_data_csv.fillna("None")
        row_name = "imagerow"
        col_name = "imagecol"
        cell_loc = gene_loc_data_csv[[row_name,col_name]].values
        distance_np = pdist(cell_loc, metric = "euclidean")
        distance_np_X =squareform(distance_np)
        distance_loc_csv = pd.DataFrame(index=gene_loc_data_csv.index, columns=gene_loc_data_csv.index,data = distance_np_X)
        threshold = 8
        num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
        #num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
        adj_matrix = np.zeros(distance_np_X.shape)
        non_zero_point = np.where((0 < distance_np_X) & (distance_np_X < threshold))
        adj_matrix = np.zeros(distance_np_X.shape)
        non_zero_point = np.where((0< distance_np_X)&(distance_np_X<threshold))
        for i in range(num_big):
            x = non_zero_point[0][i]
            y = non_zero_point[1][i]
            adj_matrix[x][y] = 1 
        adj_matrix = adj_matrix + np.eye(distance_np_X.shape[0])
        adj_matrix  = np.float32(adj_matrix)
        adj_matrix_crs = sparse.csr_matrix(adj_matrix)
        graph = dgl.from_scipy(adj_matrix_crs,eweight_name='w')
        min_cells = 5
        adata = read_10x_h5("/home/sunhang/Embedding/SpaGCN/tutorial/data/" + sample_name +"/filtered_feature_bc_matrix.h5")
        spatial=pd.read_csv("/home/sunhang/Embedding/SpaGCN/tutorial/data/" + sample_name +"/tissue_positions_list.txt",sep=",",header=None,na_filter=False,index_col=0) 
        adata.obs["x1"]=spatial[1]
        adata.obs["x2"]=spatial[2]
        adata.obs["x3"]=spatial[3]
        adata.obs["x4"]=spatial[4]
        adata.obs["x5"]=spatial[5]
        #Select captured samples
        adata=adata[adata.obs["x1"]==1]
        adata.var_names=[i.upper() for i in list(adata.var_names)]
        adata.var["genename"]=adata.var.index.astype("str")
        pca_n_comps = 600
        adata.var_names_make_unique
        adata.obs = pd.merge(gene_loc_data_csv,adata.obs,right_index=True,left_on="barcode")
        adata.obs.index = adata.obs["barcode"]
        le = LabelEncoder()
        label = le.fit_transform(adata.obs['layer_guess_reordered_short'])
        adata.obs["lay_num"] = label
        sc.pp.filter_genes(adata, min_cells=5)
        adata_X = sc.pp.normalize_total(adata, target_sum=1, exclude_highly_expressed=True, inplace=False)['X']
        adata_X = sc.pp.scale(adata_X)
        adata_X = sc.pp.pca(adata_X, n_comps=pca_n_comps)
        graph.ndata["feat"] = torch.tensor(adata_X.copy())
        num_features = graph.ndata["feat"].shape[1]
        num_classes = len(set(adata.obs.lay_num))-1

        #pca_n_comps = num_set
        #adata.obs = pd.merge(gene_loc_data_csv,adata.obs,right_index=True,left_on="barcode")
    #     adata.obs.index = adata.obs["barcode"]
    #     le = LabelEncoder()
    #     label = le.fit_transform(adata.obs['layer_guess_reordered_short'])
    #     adata.obs["lay_num"] = label
    #     sc.pp.filter_genes(adata, min_cells=5)
    #     adata_X = sc.pp.normalize_total(adata, target_sum=1, exclude_highly_expressed=True, inplace=False)['X']
    #     adata_X = sc.pp.scale(adata_X)
        #adata_X = sc.pp.pca(adata_X, n_comps=pca_n_comps)
        graph.ndata["feat"] = torch.tensor(adata_X.copy())
        num_features = graph.ndata["feat"].shape[1]
        num_classes = len(set(adata.obs.lay_num))-1
        #for num_set in np.arange(1000,4000,500):
        print(num_set)
        args = parser.parse_args([])
        args.lr = 0.001
        args.lr_f = 0.01
        args.num_hidden = 256
        args.num_heads = 4
        args.weight_decay = 2e-4
        args.weight_decay_f= 1e-4
        args.max_epoch= num_set_two
        args.max_epoch_f= 500
        args.mask_rate= 0.4
        args.num_layers= num_set
        args.encoder= "gat"
        args.decoder= "gat" 
        args.activation= "prelu"
        args.in_drop= 0.2
        args.attn_drop= 0.1
        args.linear_prob= True
        args.loss_fn= "sce" 
        args.drop_edge_rate=0.0
        args.optimizer= "adam"
        args.replace_rate= 0.05 
        args.alpha_l= 3
        args.scheduler= True
        args.dataset = "sp"


        #参数传递
        device = args.device if args.device >= 0 else "cpu"
        seeds = args.seeds
        dataset_name = args.dataset
        max_epoch = args.max_epoch
        max_epoch_f = args.max_epoch_f
        num_hidden = args.num_hidden
        num_layers = args.num_layers
        encoder_type = args.encoder
        decoder_type = args.decoder
        replace_rate = args.replace_rate

        optim_type = args.optimizer 
        loss_fn = args.loss_fn

        lr = args.lr
        weight_decay = args.weight_decay
        lr_f = args.lr_f
        weight_decay_f = args.weight_decay_f
        linear_prob = args.linear_prob
        load_model = args.load_model
        save_model = args.save_model
        logs = args.logging
        use_scheduler = args.scheduler
        args.num_features = num_features

        acc_list = []
        estp_acc_list = []
        times = 3

        #print(f"####### Run {i} for seed {seed}")
        #print(i)
        seed = 0
        set_random_seed(seed)
        if logs:
            logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}")
        else:
            logger = None
        model = build_model(args)
        device = 1
        model.to(device)
        optimizer = create_optimizer(optim_type, model, lr, weight_decay)

        if use_scheduler:
            logging.info("Use schedular")
            scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch) ) * 0.5
            # scheduler = lambda epoch: epoch / warmup_steps if epoch < warmup_steps \
                    # else ( 1 + np.cos((epoch - warmup_steps) * np.pi / (max_epoch - warmup_steps))) * 0.5
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
        else:
            scheduler = None
        #训练模型
        x = graph.ndata["feat"]
        if not load_model:
            model = pretrain(model, graph, x, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger)
            model = model.cpu()
        x = graph.ndata["feat"]
        model.to(device)
        embedding = model.embed(graph.to(device), x.to(device))
        new_pred = kMeans_use(embedding.cpu().detach().numpy(),num_classes)
        adata.obs["pre"] = new_pred
        score = adjusted_rand_score(adata.obs.lay_num.values, new_pred )
        print("first cluster:")
        print(score)
        print("结果:" + str(Counter(new_pred)))
        print("ground_truth :"  + str(Counter(adata.obs.lay_num.values)))
        x = graph.ndata["feat"]
        test = model.embed(graph.to(device), x.to(device))
        test_new_pred = kMeans_use(test.cpu().detach().numpy(),num_classes)
        score = adjusted_rand_score(adata.obs.lay_num.values, test_new_pred )
        adata.obs["second_pre"] = test_new_pred
        print("second cluster:")
        print(score)
        print("结果:" + str(Counter(test_new_pred)))
        print("ground_truth :"  + str(Counter(adata.obs.lay_num.values)))
        #drawPicture(adata.obs,"imagecol","imagerow",colorattribute="lay_num",save_file= None)
        #drawPicture(adata.obs,"imagecol","imagerow",colorattribute="pre",save_file= None)
        #drawPicture(adata.obs,"imagecol","imagerow",colorattribute="second_pre",save_file= None)

In [1]:
parameters = {
    "dataset_name":("151673"),
    "graph_devise":("CCST"),
    "threshold_num":(8),
    "feature_dim":("PCA"),
    "feature_dim_num" : np.append(np.arange(100,1000,100),np.arange(1000,4000,500)),
    "mask_rate" : np.arange(0,1,0.1),
    "code_networt_and_norm" : (["gat",None],["dotgat",None]),
    "num_hidden" : (128,256,512),
    "num_layers" : (1,2,3),
    "activation" : ("relu","gelu","prelu","elu"),
    "max_epoch" : np.arange(500,4000,500),
    "lr" : (0.001)
    #"vect__max_df": (0.5, 0.75, 1.0),
    # 'vect__max_features': (None, 5000, 10000, 50000),
    #"vect__ngram_range": ((1, 1), (1, 2)),  # unigrams or bigrams
    # 'tfidf__use_idf': (True, False),
    # 'tfidf__norm': ('l1', 'l2'),
    #"clf__max_iter": (20,),
    #"clf__alpha": (0.00001, 0.000001),
    #"clf__penalty": ("l2", "elasticnet"),
    # 'clf__max_iter': (10, 50, 80),
}


NameError: name 'np' is not defined

In [None]:
graph.num_edges

In [None]:
graph.add_edges

In [None]:
for num_set in np.arange(1000,4000,500):
    print(num_set)
    args = parser.parse_args([])
    args.lr = 0.001
    args.lr_f = 0.01
    args.num_hidden = 512
    args.num_heads = 4
    args.weight_decay = 2e-4
    args.weight_decay_f= 1e-4
    args.max_epoch= 500
    args.max_epoch_f= 500
    args.mask_rate= 0.4
    args.num_layers= 2
    args.encoder= "dotgat"
    args.decoder= "dotgat" 
    args.activation= "prelu"
    args.in_drop= 0.2
    args.attn_drop= 0.1
    args.linear_prob= True
    args.loss_fn= "sce" 
    args.drop_edge_rate=0.0
    args.optimizer= "adam"
    args.replace_rate= 0.05 
    args.alpha_l= 3
    args.scheduler= True
    args.dataset = "sp"


    #参数传递
    device = args.device if args.device >= 0 else "cpu"
    seeds = args.seeds
    dataset_name = args.dataset
    max_epoch = args.max_epoch
    max_epoch_f = args.max_epoch_f
    num_hidden = args.num_hidden
    num_layers = args.num_layers
    encoder_type = args.encoder
    decoder_type = args.decoder
    replace_rate = args.replace_rate

    optim_type = args.optimizer 
    loss_fn = args.loss_fn

    lr = args.lr
    weight_decay = args.weight_decay
    lr_f = args.lr_f
    weight_decay_f = args.weight_decay_f
    linear_prob = args.linear_prob
    load_model = args.load_model
    save_model = args.save_model
    logs = args.logging
    use_scheduler = args.scheduler
    args.num_features = num_features

    acc_list = []
    estp_acc_list = []
    times = 3

    #print(f"####### Run {i} for seed {seed}")
    #print(i)
    seed = 0
    set_random_seed(seed)
    if logs:
        logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}")
    else:
        logger = None
    model = build_model(args)
    device = 1
    model.to(device)
    optimizer = create_optimizer(optim_type, model, lr, weight_decay)

    if use_scheduler:
        logging.info("Use schedular")
        scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch) ) * 0.5
        # scheduler = lambda epoch: epoch / warmup_steps if epoch < warmup_steps \
                # else ( 1 + np.cos((epoch - warmup_steps) * np.pi / (max_epoch - warmup_steps))) * 0.5
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
    else:
        scheduler = None
    #训练模型
    x = graph.ndata["feat"]
    if not load_model:
        model = pretrain(model, graph, x, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger)
        model = model.cpu()
    x = graph.ndata["feat"]
    model.to(device)
    embedding = model.embed(graph.to(device), x.to(device))
    new_pred = kMeans_use(embedding.cpu().detach().numpy(),num_classes)
    adata.obs["pre"] = new_pred
    score = adjusted_rand_score(adata.obs.lay_num.values, new_pred )
    print("first cluster:")
    print(score)
    print("结果:" + str(Counter(new_pred)))
    print("ground_truth :"  + str(Counter(adata.obs.lay_num.values)))
    x = graph.ndata["feat"]
    test = model.embed(graph.to(device), x.to(device))
    test_new_pred = kMeans_use(test.cpu().detach().numpy(),num_classes)
    score = adjusted_rand_score(adata.obs.lay_num.values, test_new_pred )
    adata.obs["second_pre"] = test_new_pred
    print("second cluster:")
    print(score)
    print("结果:" + str(Counter(test_new_pred)))
    print("ground_truth :"  + str(Counter(adata.obs.lay_num.values)))
    #drawPicture(adata.obs,"imagecol","imagerow",colorattribute="lay_num",save_file= None)
    #drawPicture(adata.obs,"imagecol","imagerow",colorattribute="pre",save_file= None)
    #drawPicture(adata.obs,"imagecol","imagerow",colorattribute="second_pre",save_file= None)

In [None]:
drawPicture(adata.obs,"imagecol","imagerow",colorattribute="second_pre",save_file= None)

In [None]:
# #训练模型
# x = graph.ndata["feat"]
# if not load_model:
#     model = pretrain(model, graph, x, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger)
#     model = model.cpu()

In [None]:
# x = graph.ndata["feat"]
# model.to(device)
# embedding = model.embed(graph.to(device), x.to(device))
# new_pred = kMeans_use(embedding.cpu().detach().numpy(),num_classes)
# adata.obs["pre"] = new_pred
# score = adjusted_rand_score(adata.obs.lay_num.values, new_pred )
# print("first cluster:")
# print(score)
# print("结果:" + str(Counter(new_pred)))
# print("ground_truth :"  + str(Counter(adata.obs.lay_num.values)))

In [None]:
# x = graph.ndata["feat"]
# test = model.embed(graph.to(device), x.to(device))
# test_new_pred = kMeans_use(test.cpu().detach().numpy(),num_classes)
# score = adjusted_rand_score(adata.obs.lay_num.values, test_new_pred )
# adata.obs["second_pre"] = test_new_pred
# print("second cluster:")
# print(score)
# print("结果:" + str(Counter(test_new_pred)))
# print("ground_truth :"  + str(Counter(adata.obs.lay_num.values)))

In [None]:
# drawPicture(adata.obs,"imagecol","imagerow",colorattribute="lay_num",save_file= None)
# drawPicture(adata.obs,"imagecol","imagerow",colorattribute="pre",save_file= None)
# drawPicture(adata.obs,"imagecol","imagerow",colorattribute="second_pre",save_file= None)