In [1]:
from sklearn.metrics import (adjusted_rand_score, normalized_mutual_info_score, 
                             silhouette_score, calinski_harabasz_score,
                             davies_bouldin_score)
import logging
import numpy as np
from tqdm import tqdm
import torch
from sklearn.preprocessing import LabelEncoder
from graphmae.utils import (
    
    build_args,
    create_optimizer,
    set_random_seed,
    TBLogger,
    get_current_lr,
    load_best_configs,

    
)
from collections import Counter
from graphmae.datasets.data_util import load_dataset
from graphmae.evaluation import node_classification_evaluation
from graphmae.models import build_model
from ogb.nodeproppred import DglNodePropPredDataset
from sklearn.cluster import KMeans
def kMeans_use(embedding,cluster_number):
    kmeans = KMeans(n_clusters=cluster_number,
                init="k-means++",
                random_state=0)
    pred = kmeans.fit_predict(embedding)
    return pred
import argparse
parser = argparse.ArgumentParser(description="GAT")
parser.add_argument("--seeds", type=int, nargs="+", default=[0])
parser.add_argument("--dataset", type=str, default="cora")
parser.add_argument("--device", type=int, default=-1)
parser.add_argument("--max_epoch", type=int, default=200,
                    help="number of training epochs")
parser.add_argument("--warmup_steps", type=int, default=-1)

parser.add_argument("--num_heads", type=int, default=4,
                    help="number of hidden attention heads")
parser.add_argument("--num_out_heads", type=int, default=1,
                    help="number of output attention heads")
parser.add_argument("--num_layers", type=int, default=2,
                    help="number of hidden layers")
parser.add_argument("--num_hidden", type=int, default=256,
                    help="number of hidden units")
parser.add_argument("--residual", action="store_true", default=False,
                    help="use residual connection")
parser.add_argument("--in_drop", type=float, default=.2,
                    help="input feature dropout")
parser.add_argument("--attn_drop", type=float, default=.1,
                    help="attention dropout")
parser.add_argument("--norm", type=str, default=None)
parser.add_argument("--lr", type=float, default=0.005,
                    help="learning rate")
parser.add_argument("--weight_decay", type=float, default=5e-4,
                    help="weight decay")
parser.add_argument("--negative_slope", type=float, default=0.2,
                    help="the negative slope of leaky relu for GAT")
parser.add_argument("--activation", type=str, default="prelu")
parser.add_argument("--mask_rate", type=float, default=0.5)
parser.add_argument("--drop_edge_rate", type=float, default=0.0)
parser.add_argument("--replace_rate", type=float, default=0.0)

parser.add_argument("--encoder", type=str, default="gat")
parser.add_argument("--decoder", type=str, default="gat")
parser.add_argument("--loss_fn", type=str, default="byol")
parser.add_argument("--alpha_l", type=float, default=2, help="`pow`inddex for `sce` loss")
parser.add_argument("--optimizer", type=str, default="adam")

parser.add_argument("--max_epoch_f", type=int, default=30)
parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation")
parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation")
parser.add_argument("--linear_prob", action="store_true", default=False)

parser.add_argument("--load_model", action="store_true")
parser.add_argument("--save_model", action="store_true")
parser.add_argument("--use_cfg", action="store_true")
parser.add_argument("--logging", action="store_true")
parser.add_argument("--scheduler", action="store_true", default=False)
parser.add_argument("--concat_hidden", action="store_true", default=False)

# for graph classification
parser.add_argument("--pooling", type=str, default="mean")
parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature")
parser.add_argument("--batch_size", type=int, default=32)

2022-09-11 15:16:06,587 - INFO - Enabling RDKit 2022.03.5 jupyter extensions


_StoreAction(option_strings=['--batch_size'], dest='batch_size', nargs=None, const=None, default=32, type=<class 'int'>, choices=None, help=None, metavar=None)

In [2]:
import pickle as pkl

import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import plotly.express as px
import pandas as pd
import scanpy as sc
import scipy as sci

def drawPicture(dataframe,col_name, row_name,colorattribute,save_file,celltype_colors =  ("#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33", "#A65628", "#F781BF", "#999999", "#E41A80",
    "#377F1C", "#4DAFAE", "#984F07", "#FF7F64", "#FFFF97", "#A6568C", "#F78223", "#9999FD", "#E41AE4", "#377F80", "#4DB012",
    "#984F6B", "#FF7FC8",  "#A656F0", "#F78287", "#999A61", "#E41B48", "#377FE4", "#4DB076", "#984FCF", "#FF802C",
    "#00005F", "#A65754", "#F782EB", "#999AC5", "#E41BAC", "#378048", "#4DB0DA", "#985033", "#FF8090", "#0000C3", "#A657B8",
    "#F7834F", "#999B29", "#E41C10", "#3780AC", "#4DB13E", "#985097", "#FF80F4", "#000127", "#A6581C", "#F783B3", "#999B8D",
    "#E41C74", "#378110", "#4DB1A2", "#9850FB", "#FF8158", "#00018B", "#A65880", "#F78417", "#999BF1", "#E41CD8", "#378174",
    "#4DB206", "#98515F", "#FF81BC", "#0001EF", "#A658E4", "#F7847B", "#999C55", "#E41D3C", "#3781D8", "#4DB26A", "#9851C3",
    "#FF8220", "#000253", "#A65948", "#F784DF", "#999CB9", "#E41DA0", "#37823C", "#4DB2CE", "#985227", "#FF8284", "#0002B7",
    "#A659AC", "#F78543", "#999D1D", "#E41E04", "#3782A0", "#4DB332", "#98528B", "#FF82E8", "#00031B", "#A65A10", "#F785A7",
    "#999D81", "#E41E68", "#378304", "#4DB396", "#9852EF", "#FF834C", "#00037F", "#A65A74", "#F7860B", "#999DE5", "#E41ECC",
    "#378368", "#4DB3FA", "#985353", "#FF83B0", "#0003E3", "#A65AD8", "#F7866F", "#999E49", "#E41F30", "#3783CC", "#4DB45E",
    "#9853B7", "#FF8414", "#000447", "#A65B3C", "#F786D3", "#999EAD", "#E41F94", "#378430", "#4DB4C2", "#98541B", "#FF8478",
    "#0004AB", "#A65BA0", "#F78737", "#999F11", "#E41FF8", "#378494", "#4DB526", "#98547F", "#FF84DC", "#00050F", "#A65C04",
    "#F7879B", "#999F75"
    ),width = 1000,height = 1000,marker_size = 10,is_show = True,is_save = False,save_type = "pdf"):

    import plotly.express as px
    import plotly.graph_objects as go
    size = len(set(dataframe[colorattribute]))
    dataframe.sort_values(by = colorattribute,inplace=True, ascending=True)
    length_col = max(dataframe[col_name]) - min(dataframe[col_name])

    length_row = max(dataframe[row_name]) - min(dataframe[row_name])
    max_length = max(length_col,length_row) + 2
    fig = px.scatter(dataframe, x = col_name, y= row_name,color = colorattribute,color_discrete_sequence=celltype_colors)
    fig.update_traces(marker_size=marker_size)
    fig.update_layout(
        xaxis = dict(
            tickmode = 'linear',
            tick0 = min(dataframe[row_name]),   # 起始点
            dtick = max_length  # 间距
        ),
        yaxis = dict(
            tickmode = 'linear',
            tick0 = min(dataframe[col_name]),   # 起始点
            dtick = max_length  # 间距
        ),
    #     xaxis_range = [min(it_mapping_csv.row)-10,min(it_mapping_csv.row) + max_length ],
    #     yaxis_range =[min(it_mapping_csv.col)-10,min(it_mapping_csv.col) + max_length]
    )
    fig.update_layout(
        autosize=False,
        width=width,
        height = height)
    if(is_show):
        fig.show()
    if(is_save):
        if(save_type == "pdf"):
            fig.write_image(save_file)
        if(save_type == "html"):
            fig.write_html(save_file)

2022-09-11 15:16:17,114 - INFO - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2022-09-11 15:16:17,116 - INFO - NumExpr defaulting to 8 threads.


In [3]:
def pretrain(model, graph, feat, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None):
    logging.info("start training..")
    graph = graph.to(device)
    x = feat.to(device)

    epoch_iter = tqdm(range(max_epoch))

    for epoch in epoch_iter:
        model.train()

        loss, loss_dict = model(graph, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")
        if logger is not None:
            loss_dict["lr"] = get_current_lr(optimizer)
            logger.note(loss_dict, step=epoch)

        #if (epoch + 1) % 200 == 0:
            #node_classification_evaluation(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True)

    # return best_model
    return model

In [4]:
import pandas as pd
#import calculate_adj
import anndata as ad
import scanpy as sc
from operator import index
import re
from scipy.spatial.distance import pdist, squareform
import pandas as pd
import numpy as np
from scipy import sparse
import dgl
import torch
folder_name = "/home/sunhang/Embedding/CCST/dataset/DLPFC/"
samplea_list = {151507,
 151508,
 151509,
 151510,
 151669,
 151670,
 151671,
 151672,
 151673,
 151674,
 151675,
 151676}
sample_name = str(151673)
#gene_exp_data_file = folder_name + sample_name + "_DLPFC_count.csv"
gene_loc_data_file = folder_name + sample_name + "_DLPFC_col_name.csv"
#save file
npz_file = folder_name + sample_name + "_DLPFC_distacne.npz"
#gene_csv = pd.read_csv(gene_exp_data_file,index_col= 0 )
gene_loc_data_csv = pd.read_csv(gene_loc_data_file,index_col=0)
gene_loc_data_csv = gene_loc_data_csv.fillna("None")
min_cells = 5
pca_n_comps = 3000

In [5]:
from scanpy import read_10x_h5
import warnings
warnings.filterwarnings("ignore")
adata = read_10x_h5("/home/sunhang/Embedding/SpaGCN/tutorial/data/" + sample_name +"/filtered_feature_bc_matrix.h5")
spatial=pd.read_csv("/home/sunhang/Embedding/SpaGCN/tutorial/data/" + sample_name +"/tissue_positions_list.txt",sep=",",header=None,na_filter=False,index_col=0) 
adata.obs["x1"]=spatial[1]
adata.obs["x2"]=spatial[2]
adata.obs["x3"]=spatial[3]
adata.obs["x4"]=spatial[4]
adata.obs["x5"]=spatial[5]
#Select captured samples
adata=adata[adata.obs["x1"]==1]
adata.var_names=[i.upper() for i in list(adata.var_names)]
adata.var["genename"]=adata.var.index.astype("str")

adata.var_names_make_unique

<bound method AnnData.var_names_make_unique of AnnData object with n_obs × n_vars = 3639 × 33538
    obs: 'x1', 'x2', 'x3', 'x4', 'x5'
    var: 'gene_ids', 'feature_types', 'genome', 'genename'>

In [6]:
pca_n_comps = 3000
adata.obs = pd.merge(gene_loc_data_csv,adata.obs,right_index=True,left_on="barcode")
adata.obs.index = adata.obs["barcode"]
le = LabelEncoder()
label = le.fit_transform(adata.obs['layer_guess_reordered_short'])
adata.obs["lay_num"] = label
sc.pp.filter_genes(adata, min_cells=5)
adata_X = sc.pp.normalize_total(adata, target_sum=1, exclude_highly_expressed=True, inplace=False)['X']
adata_X = sc.pp.scale(adata_X)
adata_X = sc.pp.pca(adata_X, n_comps=pca_n_comps)

In [7]:
row_name = "imagerow"
col_name = "imagecol"
cell_loc = adata.obs[[row_name,col_name]].values
distance_np = pdist(cell_loc, metric = "euclidean")
distance_np_X =squareform(distance_np)
distance_loc_csv = pd.DataFrame(index=adata.obs.index, columns=adata.obs.index,data = distance_np_X)
# threshold = 8
# num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
# #num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
# adj_matrix = np.zeros(distance_np_X.shape)
# non_zero_point = np.where((0 < distance_np_X) & (distance_np_X < threshold))
# adj_matrix = np.zeros(distance_np_X.shape)
# non_zero_point = np.where((0< distance_np_X)&(distance_np_X<threshold))
# for i in range(num_big):
#     x = non_zero_point[0][i]
#     y = non_zero_point[1][i]
#     adj_matrix[x][y] = 1 
# adj_matrix = adj_matrix + np.eye(distance_np_X.shape[0])
# adj_matrix  = np.float32(adj_matrix)
# adj_matrix_crs = sparse.csr_matrix(adj_matrix)
# graph = dgl.from_scipy(adj_matrix_crs,eweight_name='w')

In [8]:
gene_exp_np = adata_X
distance_gene_np = pdist(gene_exp_np, metric = "euclidean")
distance_gene_np_X =squareform(distance_gene_np)
distance_gene_csv = pd.DataFrame(index=adata.obs.index, columns=adata.obs.index,data = distance_gene_np_X)
distance_loc_rank_value_csv = pd.DataFrame(index = adata.obs.index )
distance_loc_rank_index_csv = pd.DataFrame(index = adata.obs.index )
distance_np_X[range(len(adata.obs.index)),np.argmin(distance_np_X,axis=1)]=distance_np_X.max()
K_NUM = 8
for k in range(K_NUM):
    #distance_np_X[range(len(all_time_csv.index)),np.argmin(distance_np_X,axis=1)]=distance_np_X.max()
    distance_loc_rank_value_csv[k] = np.min(distance_np_X,axis=1)
    distance_loc_rank_index_csv[k] = np.argmin(distance_np_X,axis=1)
    distance_np_X[range(len(adata.obs.index)),np.argmin(distance_np_X,axis=1)]=distance_np_X.max() 

In [9]:
# 提取gene distance rank 的矩阵
distance_gene_rank_value_csv = pd.DataFrame(index = range(K_NUM))
for i in range(len(adata.obs.index)):
    distance_gene_rank_value_csv[adata.obs.index[i]] =distance_gene_csv.iloc[i,distance_loc_rank_index_csv.iloc[i]].values
distance_gene_rank_value_np_X = distance_gene_rank_value_csv.values
result_rank_value_csv = pd.DataFrame(index = adata.obs.index )
result_rank_index_csv = pd.DataFrame(index = adata.obs.index )
distance_loc_rank_index_np = distance_loc_rank_index_csv.values
for k in range(K_NUM -3):
    #distance_np_X[range(len(all_time_csv.index)),np.argmin(distance_np_X,axis=1)]=distance_np_X.max()
    result_rank_value_csv[k] = np.min(distance_gene_rank_value_np_X ,axis=0)
    result_rank_index_csv[k] = distance_loc_rank_index_np[range(len(np.argmin(distance_gene_rank_value_np_X,axis=0))), np.argmin(distance_gene_rank_value_np_X,axis=0)]
    distance_gene_rank_value_np_X [np.argmin(distance_gene_rank_value_np_X,axis=0),range(len(distance_gene_rank_value_csv.columns))]=distance_gene_rank_value_np_X.max()


In [99]:
result_rank_index_csv

Unnamed: 0_level_0,0,1,2,3,4
barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
AAACAAGTATCTCCCA-1,3099,2329,485,397,3542
AAACAATCTACTAGCA-1,1885,3363,1502,3106,1617
AAACACCAATAACTGC-1,1849,255,2280,2629,3024
AAACAGAGCGACTCCT-1,231,3592,2241,2088,1910
AAACAGCTTTCAGAAG-1,1806,2389,446,2840,306
...,...,...,...,...,...
TTGTTTCACATCCAGG-1,1632,403,2562,2924,2638
TTGTTTCATTAGTCTA-1,2883,64,3278,844,312
TTGTTTCCATACAACT-1,2626,1973,3409,578,928
TTGTTTGTATTACACG-1,6,1974,3313,2159,2582


In [98]:
adj_matrix = np.zeros(distance_np_X.shape)
for i in range(len(result_rank_index_csv.index)):
    for j in range(len(result_rank_index_csv.columns)):
        x = i
        y = result_rank_index_csv.values[i][j]
        adj_matrix[x][y] = 1
adj_matrix = adj_matrix + np.eye(distance_np_X.shape[0])
adj_matrix  = np.float32(adj_matrix)
adj_matrix_crs = sparse.csr_matrix(adj_matrix)
graph = dgl.from_scipy(adj_matrix_crs,eweight_name='w')

In [13]:
import plotly.graph_objects as go

In [75]:
celltype_colors =  ("#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33", "#A65628", "#F781BF", "#999999", "#E41A80",
    "#377F1C", "#4DAFAE", "#984F07", "#FF7F64", "#FFFF97", "#A6568C", "#F78223", "#9999FD", "#E41AE4", "#377F80", "#4DB012",
    "#984F6B", "#FF7FC8",  "#A656F0", "#F78287", "#999A61", "#E41B48", "#377FE4", "#4DB076", "#984FCF", "#FF802C",
    "#00005F", "#A65754", "#F782EB", "#999AC5", "#E41BAC", "#378048", "#4DB0DA", "#985033", "#FF8090", "#0000C3", "#A657B8",
    "#F7834F", "#999B29", "#E41C10", "#3780AC", "#4DB13E", "#985097", "#FF80F4", "#000127", "#A6581C", "#F783B3", "#999B8D",
    "#E41C74", "#378110", "#4DB1A2", "#9850FB", "#FF8158", "#00018B", "#A65880", "#F78417", "#999BF1", "#E41CD8", "#378174",
    "#4DB206", "#98515F", "#FF81BC", "#0001EF", "#A658E4", "#F7847B", "#999C55", "#E41D3C", "#3781D8", "#4DB26A", "#9851C3",
    "#FF8220", "#000253", "#A65948", "#F784DF", "#999CB9", "#E41DA0", "#37823C", "#4DB2CE", "#985227", "#FF8284", "#0002B7",
    "#A659AC", "#F78543", "#999D1D", "#E41E04", "#3782A0", "#4DB332", "#98528B", "#FF82E8", "#00031B", "#A65A10", "#F785A7",
    "#999D81", "#E41E68", "#378304", "#4DB396", "#9852EF", "#FF834C", "#00037F", "#A65A74", "#F7860B", "#999DE5", "#E41ECC",
    "#378368", "#4DB3FA", "#985353", "#FF83B0", "#0003E3", "#A65AD8", "#F7866F", "#999E49", "#E41F30", "#3783CC", "#4DB45E",
    "#9853B7", "#FF8414", "#000447", "#A65B3C", "#F786D3", "#999EAD", "#E41F94", "#378430", "#4DB4C2", "#98541B", "#FF8478",
    "#0004AB", "#A65BA0", "#F78737", "#999F11", "#E41FF8", "#378494", "#4DB526", "#98547F", "#FF84DC", "#00050F", "#A65C04",
    "#F7879B", "#999F75"
    )

In [119]:
celltype_color = []
for i in adata.obs.index:
    celltype_color.append(celltype_colors[adata.obs.lay_num[i]])
adata.obs["celltype_colors"] = celltype_color

In [144]:
adata.obs["index_num"] = range(len(adata.obs.index))

In [145]:
adata.obs["celltype_colors"]

barcode
AAACAAGTATCTCCCA-1    #4DAF4A
AAACAATCTACTAGCA-1    #E41A1C
AAACACCAATAACTGC-1    #F781BF
AAACAGAGCGACTCCT-1    #4DAF4A
AAACAGCTTTCAGAAG-1    #FF7F00
                       ...   
TTGTTTCACATCCAGG-1    #F781BF
TTGTTTCATTAGTCTA-1    #F781BF
TTGTTTCCATACAACT-1    #FFFF33
TTGTTTGTATTACACG-1    #F781BF
TTGTTTGTGTAAATTC-1    #377EB8
Name: celltype_colors, Length: 3639, dtype: object

In [146]:
node_trace = go.Scatter(
    x = adata.obs.imagerow,
    y = adata.obs.imagecol,
    hovertext = adata.obs.index_num,
    marker=dict(
    size=5,
    color = adata.obs.celltype_colors,
    ),
    mode= "markers",
    #     xaxis = dict(
#         tickmode = 'linear',
#         tick0 = min(adata.obs[row_name]),   # 起始点
#         dtick = max_length  # 间距
#     ),
#     yaxis = dict(
#         tickmode = 'linear',
#         tick0 = min(adata.obs[col_name]),   # 起始点
#         dtick = max_length  # 间距
#     ),
)
layout = {
    
}
node_traces = []
node_traces.append(node_trace)

In [147]:
go.Scatter?

In [148]:
row_name = "imagerow"
col_name = "imagecol"
length_col = max(adata.obs[col_name]) - min(adata.obs[col_name])

length_row = max(adata.obs[row_name]) - min(adata.obs[row_name])
max_length = max(length_col,length_row) + 2
#fig.show()
width = 1000
height = 1000
marker_size = 10,
layout = {
    'xaxis' : dict(
        tickmode = 'linear',
        tick0 = min(adata.obs[row_name]),   # 起始点
        dtick = max_length  # 间距
    ),
    'yaxis' : dict(
        tickmode = 'linear',
        tick0 = min(adata.obs[col_name]),   # 起始点
        dtick = max_length  # 间距
    ), 
    'autosize' : False,
    'width' : width,
    'height' : height
    
}
fig = go.Figure(node_trace,layout=layout)
#fig.update_traces(marker_size=marker_size)
# fig.update_layout(
#     xaxis = dict(
#         tickmode = 'linear',
#         tick0 = min(adata.obs[row_name]),   # 起始点
#         dtick = max_length  # 间距
#     ),
#     yaxis = dict(
#         tickmode = 'linear',
#         tick0 = min(adata.obs[col_name]),   # 起始点
#         dtick = max_length  # 间距
#     ),
# #     xaxis_range = [min(it_mapping_csv.row)-10,min(it_mapping_csv.row) + max_length ],
# #     yaxis_range =[min(it_mapping_csv.col)-10,min(it_mapping_csv.col) + max_length]
# )
# fig.update_layout(
#     autosize=False,
#     width=width,
#     height = height)
fig.show()

In [124]:
index_name =  result_rank_index_csv.index[0]

In [126]:
adata.obs.loc[index_name,"celltype_colors"]

'#4DAF4A'

In [155]:
edge_traces = []
edge_x = []
edge_y = []
edge_color = []
line_color_list = []
index_name =  result_rank_index_csv.index[3268]
for col in result_rank_index_csv.columns:
    another_name = result_rank_index_csv.index[result_rank_index_csv.loc[index_name,col]]
    x0 = adata.obs.loc[index_name,row_name]
    y0 = adata.obs.loc[index_name,col_name]
    edge_color0 = adata.obs.loc[index_name,"celltype_colors"]
    x1 = adata.obs.loc[another_name,row_name]
    y1 = adata.obs.loc[another_name,col_name]
    edge_color1 = adata.obs.loc[another_name,"celltype_colors"]
    edge_x.append(x0)
    edge_x.append(x1)
    edge_x.append(None)
    edge_y.append(y0)
    edge_y.append(y1)
    edge_y.append(None)
    edge_color.append(edge_color0)
    edge_color.append(edge_color1)
    edge_color.append('#888')
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        legendgrouptitle_text="Line Title",
        line=dict(width=1),
        hoverinfo='none',
        mode='lines')
    edge_traces.append(edge_trace)

In [156]:
fig = go.Figure(node_traces + edge_traces,layout=layout)
fig.show()

In [107]:
adata.obs.columns

Index(['barcode', 'sample_name', 'tissue', 'row', 'col', 'imagerow',
       'imagecol', 'Cluster', 'height', 'width', 'sum_umi', 'sum_gene',
       'subject', 'position', 'replicate', 'subject_position', 'discard',
       'key', 'cell_count', 'SNN_k50_k4', 'SNN_k50_k5', 'SNN_k50_k6',
       'SNN_k50_k7', 'SNN_k50_k8', 'SNN_k50_k9', 'SNN_k50_k10', 'SNN_k50_k11',
       'SNN_k50_k12', 'SNN_k50_k13', 'SNN_k50_k14', 'SNN_k50_k15',
       'SNN_k50_k16', 'SNN_k50_k17', 'SNN_k50_k18', 'SNN_k50_k19',
       'SNN_k50_k20', 'SNN_k50_k21', 'SNN_k50_k22', 'SNN_k50_k23',
       'SNN_k50_k24', 'SNN_k50_k25', 'SNN_k50_k26', 'SNN_k50_k27',
       'SNN_k50_k28', 'GraphBased', 'Maynard', 'Martinowich', 'Layer',
       'layer_guess', 'layer_guess_reordered', 'layer_guess_reordered_short',
       'expr_chrM', 'expr_chrM_ratio', 'SpatialDE_PCA', 'SpatialDE_pool_PCA',
       'HVG_PCA', 'pseudobulk_PCA', 'markers_PCA', 'SpatialDE_UMAP',
       'SpatialDE_pool_UMAP', 'HVG_UMAP', 'pseudobulk_UMAP', 'markers_UM

In [None]:
edge_traces = []
edge_x = []
edge_y = []
edge_color = []
line_color_list = []
for edge in G.edges():

    x0, z0, y0 ,edge_color0= G.nodes[edge[0]]['pos'][0:4]
    x1, z1, y1 ,edge_color1= G.nodes[edge[1]]['pos'][0:4]
    edge_x.append(x0)
    edge_x.append(x1)
    edge_x.append(None)
    edge_y.append(y0)
    edge_y.append(y1)
    edge_y.append(None)
    edge_color.append(edge_color0)
    edge_color.append(edge_color1)
    edge_color.append('#888')
    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z = edge_z,
        legendgrouptitle_text="Line Title",
        line=dict(width=1, color=edge_color),
        hoverinfo='none',
        mode='lines')
    edge_traces.append(edge_trace)

In [None]:
edge_traces = []

In [70]:
go.Figure?

In [71]:
max_length

413.65617979399997

In [20]:
fig.show()

In [181]:
parameters = {
    "vect__max_df": np.arange(0,1,0.1),
    # 'vect__max_features': (None, 5000, 10000, 50000),
    "vect__ngram_range": ((1, 1), (1, 2)),  # unigrams or bigrams
    # 'tfidf__use_idf': (True, False),
    # 'tfidf__norm': ('l1', 'l2'),
    "clf__max_iter": (20,),
    "clf__alpha": (0.00001, 0.000001),
    "clf__penalty": ("l2", "elasticnet"),
    # 'clf__max_iter': (10, 50, 80),
}

class myDict(object):
    def __init__(self,mydict):
        self.mydict = mydict
        self.length = []
        self.keys = []
        for key,values in self.mydict.items():
            self.keys.append(key)
            self.length.append(len(values))
        self.nums = [1] * len(self.length)
        for i in range(len(self.length)):
            for j in range(i,len(self.length)):
                self.nums[i] *= self.length[j]
        self.para_dis = []
        print(self.length)
        print(self.nums)
                
    def getindex(self,index):
        result = []
        value = index
        for i in range(len(self.nums) - 1):
            result.append(value // self.nums[i+1])
            value = value - result[i] * self.nums[i+1]
        result.append(value) 
        result_dict = dict()
        for index,value in enumerate(result):
            result_dict[self.keys[index]] = self.mydict.get(self.keys[index])[value]
        return result_dict
    
    #para_dis = []
    def myiter(self):
        #para_dis = []
        for i in range(0,self.nums[0]):
            self.para_dis.append(self.getindex(i))
        return self.para_dis

In [182]:
mydict = myDict(parameters)
aa  =mydict.myiter()

[10, 2, 1, 2, 2]
[80, 8, 4, 4, 2]


In [183]:
len(aa)

80

In [189]:
parameters = {
    "dataset_name":("151673",),
    "graph_devise":("CCST",),
    "threshold_num":(8,),
    "feature_dim":("PCA",),
    "feature_dim_num" : np.append(np.arange(100,1000,100),np.arange(1000,4000,500)),
    "mask_rate" : np.arange(0,1,0.1),
    "code_networt_and_norm" : (["gat",None],["dotgat",None]),
    "num_hidden" : (128,256,512),
    "num_layers" : (1,2,3),
    "activation" : ("relu","gelu","prelu","elu"),
    "max_epoch" : np.arange(500,4000,500),
    "lr" : (0.001,)
    #"vect__max_df": (0.5, 0.75, 1.0),
    # 'vect__max_features': (None, 5000, 10000, 50000),
    #"vect__ngram_range": ((1, 1), (1, 2)),  # unigrams or bigrams
    # 'tfidf__use_idf': (True, False),
    # 'tfidf__norm': ('l1', 'l2'),
    #"clf__max_iter": (20,),
    #"clf__alpha": (0.00001, 0.000001),
    #"clf__penalty": ("l2", "elasticnet"),
    # 'clf__max_iter': (10, 50, 80),
}


In [190]:
mydict = myDict(parameters)
aa  =mydict.myiter()

[1, 1, 1, 1, 15, 10, 2, 3, 3, 4, 7, 1]
[75600, 75600, 75600, 75600, 75600, 5040, 504, 252, 84, 28, 7, 1]


In [194]:
aa【0】

SyntaxError: invalid character in identifier (3535634408.py, line 1)

In [195]:
aa[0]["mask_rate"]

0.0

In [196]:
parameters_list  = mydict.myiter()
time = 0

In [197]:
choose_parameter = parameters_list[0]

In [202]:
choose_parameter["code_networt_and_norm"]

'gat'

In [None]:
code_networt_and_norm