In [26]:
from sklearn.metrics import (adjusted_rand_score, normalized_mutual_info_score, 
                             silhouette_score, calinski_harabasz_score,
                             davies_bouldin_score)

In [27]:
import logging
import numpy as np
from tqdm import tqdm
import torch

from graphmae.utils import (
    
    build_args,
    create_optimizer,
    set_random_seed,
    TBLogger,
    get_current_lr,
    load_best_configs,

    
)
from graphmae.datasets.data_util import load_dataset
from graphmae.evaluation import node_classification_evaluation
from graphmae.models import build_model
from ogb.nodeproppred import DglNodePropPredDataset
#from graphmae.my_util import kmeans_use

In [28]:
from sklearn.cluster import KMeans

In [29]:
import argparse
parser = argparse.ArgumentParser(description="GAT")
parser.add_argument("--seeds", type=int, nargs="+", default=[0])
parser.add_argument("--dataset", type=str, default="cora")
parser.add_argument("--device", type=int, default=-1)
parser.add_argument("--max_epoch", type=int, default=200,
                    help="number of training epochs")
parser.add_argument("--warmup_steps", type=int, default=-1)

parser.add_argument("--num_heads", type=int, default=4,
                   help="number of hidden attention heads")
parser.add_argument("--num_out_heads", type=int, default=1,
                    help="number of output attention heads")
parser.add_argument("--num_layers", type=int, default=2,
                    help="number of hidden layers")
parser.add_argument("--num_hidden", type=int, default=256,
                    help="number of hidden units")
parser.add_argument("--residual", action="store_true", default=False,
                    help="use residual connection")
parser.add_argument("--in_drop", type=float, default=.2,
                    help="input feature dropout")
parser.add_argument("--attn_drop", type=float, default=.1,
                    help="attention dropout")
parser.add_argument("--norm", type=str, default=None)
parser.add_argument("--lr", type=float, default=0.005,
                    help="learning rate")
parser.add_argument("--weight_decay", type=float, default=5e-4,
                    help="weight decay")
parser.add_argument("--negative_slope", type=float, default=0.2,
                    help="the negative slope of leaky relu for GAT")
parser.add_argument("--activation", type=str, default="prelu")
parser.add_argument("--mask_rate", type=float, default=0.5)
parser.add_argument("--drop_edge_rate", type=float, default=0.0)
parser.add_argument("--replace_rate", type=float, default=0.0)

parser.add_argument("--encoder", type=str, default="gat")
parser.add_argument("--decoder", type=str, default="gat")
parser.add_argument("--loss_fn", type=str, default="byol")
parser.add_argument("--alpha_l", type=float, default=2, help="`pow`inddex for `sce` loss")
parser.add_argument("--optimizer", type=str, default="adam")

parser.add_argument("--max_epoch_f", type=int, default=30)
parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation")
parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation")
parser.add_argument("--linear_prob", action="store_true", default=False)

parser.add_argument("--load_model", action="store_true")
parser.add_argument("--save_model", action="store_true")
parser.add_argument("--use_cfg", action="store_true")
parser.add_argument("--logging", action="store_true")
parser.add_argument("--scheduler", action="store_true", default=False)
parser.add_argument("--concat_hidden", action="store_true", default=False)

# for graph classification
parser.add_argument("--pooling", type=str, default="mean")
parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature")
parser.add_argument("--batch_size", type=int, default=32)

_StoreAction(option_strings=['--batch_size'], dest='batch_size', nargs=None, const=None, default=32, type=<class 'int'>, choices=None, help=None, metavar=None)

In [30]:
def normalize(adata, copy=True, highly_genes = None, filter_min_counts=True, 
              size_factors=True, normalize_input=True, logtrans_input=True):
    """
    Normalizes input data and retains only most variable genes 
    (indicated by highly_genes parameter)

    Args:
        adata ([type]): [description]
        copy (bool, optional): [description]. Defaults to True.
        highly_genes ([type], optional): [description]. Defaults to None.
        filter_min_counts (bool, optional): [description]. Defaults to True.
        size_factors (bool, optional): [description]. Defaults to True.
        normalize_input (bool, optional): [description]. Defaults to True.
        logtrans_input (bool, optional): [description]. Defaults to True.

    Raises:
        NotImplementedError: [description]

    Returns:
        [type]: [description]
    """
    if isinstance(adata, sc.AnnData):
        if copy:
            adata = adata.copy()
    elif isinstance(adata, str):
        adata = sc.read(adata)
    else:
        raise NotImplementedError
    norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.'
    assert 'n_count' not in adata.obs, norm_error
    if adata.X.size < 50e6: # check if adata.X is integer only if array is small
        if sci.sparse.issparse(adata.X):
            assert (adata.X.astype(int) != adata.X).nnz == 0, norm_error
        else:
            assert np.all(adata.X.astype(int) == adata.X), norm_error

    if filter_min_counts:
        sc.pp.filter_genes(adata, min_counts=1)#3
        sc.pp.filter_cells(adata, min_counts=1)
    if size_factors or normalize_input or logtrans_input:
        adata.raw = adata.copy()
    else:
        adata.raw = adata
    if size_factors:
        sc.pp.normalize_per_cell(adata)
        adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts)
    else:
        adata.obs['size_factors'] = 1.0
    if logtrans_input:
        sc.pp.log1p(adata)
    if highly_genes != None:
        sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5, n_top_genes = highly_genes, subset=True)
    if normalize_input:
        sc.pp.scale(adata)
    return adata

In [31]:
import pickle as pkl

import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import plotly.express as px
import pandas as pd
import scanpy as sc
import scipy as sci

def drawPicture(dataframe,col_name, row_name,colorattribute,save_file,celltype_colors =  ("#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33", "#A65628", "#F781BF", "#999999", "#E41A80",
    "#377F1C", "#4DAFAE", "#984F07", "#FF7F64", "#FFFF97", "#A6568C", "#F78223", "#9999FD", "#E41AE4", "#377F80", "#4DB012",
    "#984F6B", "#FF7FC8",  "#A656F0", "#F78287", "#999A61", "#E41B48", "#377FE4", "#4DB076", "#984FCF", "#FF802C",
    "#00005F", "#A65754", "#F782EB", "#999AC5", "#E41BAC", "#378048", "#4DB0DA", "#985033", "#FF8090", "#0000C3", "#A657B8",
    "#F7834F", "#999B29", "#E41C10", "#3780AC", "#4DB13E", "#985097", "#FF80F4", "#000127", "#A6581C", "#F783B3", "#999B8D",
    "#E41C74", "#378110", "#4DB1A2", "#9850FB", "#FF8158", "#00018B", "#A65880", "#F78417", "#999BF1", "#E41CD8", "#378174",
    "#4DB206", "#98515F", "#FF81BC", "#0001EF", "#A658E4", "#F7847B", "#999C55", "#E41D3C", "#3781D8", "#4DB26A", "#9851C3",
    "#FF8220", "#000253", "#A65948", "#F784DF", "#999CB9", "#E41DA0", "#37823C", "#4DB2CE", "#985227", "#FF8284", "#0002B7",
    "#A659AC", "#F78543", "#999D1D", "#E41E04", "#3782A0", "#4DB332", "#98528B", "#FF82E8", "#00031B", "#A65A10", "#F785A7",
    "#999D81", "#E41E68", "#378304", "#4DB396", "#9852EF", "#FF834C", "#00037F", "#A65A74", "#F7860B", "#999DE5", "#E41ECC",
    "#378368", "#4DB3FA", "#985353", "#FF83B0", "#0003E3", "#A65AD8", "#F7866F", "#999E49", "#E41F30", "#3783CC", "#4DB45E",
    "#9853B7", "#FF8414", "#000447", "#A65B3C", "#F786D3", "#999EAD", "#E41F94", "#378430", "#4DB4C2", "#98541B", "#FF8478",
    "#0004AB", "#A65BA0", "#F78737", "#999F11", "#E41FF8", "#378494", "#4DB526", "#98547F", "#FF84DC", "#00050F", "#A65C04",
    "#F7879B", "#999F75"
    ),width = 1000,height = 1000,marker_size = 10,is_show = True,is_save = False,save_type = "pdf"):

    import plotly.express as px
    import plotly.graph_objects as go
    size = len(set(dataframe[colorattribute]))
#     dataframe.sort_values(by = colorattribute,inplace=True, ascending=True)
    length_col = max(dataframe[col_name]) - min(dataframe[col_name])

    length_row = max(dataframe[row_name]) - min(dataframe[row_name])
    max_length = max(length_col,length_row) + 2
    fig = px.scatter(dataframe, x = col_name, y= row_name,color = colorattribute,color_discrete_sequence=celltype_colors)
    fig.update_traces(marker_size=marker_size)
    fig.update_layout(
        xaxis = dict(
            tickmode = 'linear',
            tick0 = min(dataframe[row_name]),   # 起始点
            dtick = max_length  # 间距
        ),
        yaxis = dict(
            tickmode = 'linear',
            tick0 = min(dataframe[col_name]),   # 起始点
            dtick = max_length  # 间距
        ),
    #     xaxis_range = [min(it_mapping_csv.row)-10,min(it_mapping_csv.row) + max_length ],
    #     yaxis_range =[min(it_mapping_csv.col)-10,min(it_mapping_csv.col) + max_length]
    )
    fig.update_layout(
        autosize=False,
        width=width,
        height = height)
    if(is_show):
        fig.show()
    if(is_save):
        if(save_type == "pdf"):
            fig.write_image(save_file)
        if(save_type == "html"):
            fig.write_html(save_file)

In [None]:
import pandas as pd
#import calculate_adj
import anndata as ad
import scanpy as sc
from operator import index
import re
from scipy.spatial.distance import pdist, squareform
import pandas as pd
import numpy as np
from scipy import sparse
import dgl
import torch
folder_name = "/home/sunhang/Embedding/CCST/dataset/DLPFC/"
samplea_list = {151507,
 151508,
 151509,
 151510,
 151669,
 151670,
 151671,
 151672,
 151673,
 151674,
 151675,
 151676}
sample_name = str(151676)
gene_exp_data_file = folder_name + sample_name + "_DLPFC_count.csv"
gene_loc_data_file = folder_name + sample_name + "_DLPFC_col_name.csv"
#save file
npz_file = folder_name + sample_name + "_DLPFC_distacne.npz"
gene_csv = pd.read_csv(gene_exp_data_file,index_col= 0 )
gene_loc_data_csv = pd.read_csv(gene_loc_data_file,index_col=0)
gene_loc_data_csv = gene_loc_data_csv.fillna("None")
row_name = "imagerow"
col_name = "imagecol"
cell_loc = gene_loc_data_csv[[row_name,col_name]].values
distance_np = pdist(cell_loc, metric = "euclidean")
distance_np_X =squareform(distance_np)
distance_loc_csv = pd.DataFrame(index=gene_loc_data_csv.index, columns=gene_loc_data_csv.index,data = distance_np_X)
threshold = 8
num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
#num_big = np.where((0< distance_np_X)&(distance_np_X < threshold))[0].shape[0]
adj_matrix = np.zeros(distance_np_X.shape)
non_zero_point = np.where((0 < distance_np_X) & (distance_np_X < threshold))
adj_matrix = np.zeros(distance_np_X.shape)
non_zero_point = np.where((0< distance_np_X)&(distance_np_X<threshold))
for i in range(num_big):
    x = non_zero_point[0][i]
    y = non_zero_point[1][i]
    adj_matrix[x][y] = 1 
adj_matrix = adj_matrix + np.eye(distance_np_X.shape[0])
adj_matrix  = np.float32(adj_matrix)
adj_matrix_crs = sparse.csr_matrix(adj_matrix)
graph = dgl.from_scipy(adj_matrix_crs,eweight_name='w')
min_cells = 5
highly_genes = 3000
adata = ad.AnnData(gene_csv.values.T, obs=distance_loc_csv, var=pd.DataFrame(index = gene_csv.index), dtype='int32')
adata = normalize(adata,
                copy=True,
                highly_genes=highly_genes,
                size_factors=False,
                normalize_input=True,
                logtrans_input=True)

adata_X = adata.X.astype(np.float32)
from sklearn.preprocessing import LabelEncoder
# Creating a instance of label Encoder.
le = LabelEncoder()
# Using .fit_transform function to fit label
# encoder and return encoded label
label = le.fit_transform(gene_loc_data_csv['layer_guess_reordered_short'])
gene_loc_data_csv["lay_num"] = label
graph.ndata["feat"] = torch.tensor(adata_X.copy())
num_features = graph.ndata["feat"].shape[1]
num_classes = len(set(gene_loc_data_csv.lay_num))

In [None]:
num_big

In [None]:
adata_X.copy()

In [None]:
adata_X

In [25]:
# print(adata_X)
random_state=0
new_pred = KMeans(n_clusters=num_classes, random_state=random_state).fit_predict(adata_X)
# print(new_pred)
# from collections import Counter
# print(Counter(new_pred))
# print(Counter(gene_loc_data_csv.lay_num.values))
# score = adjusted_rand_score(gene_loc_data_csv.lay_num.values, new_pred )
# print(score)
gene_loc_data_csv["before_pre"] = new_pred
print(gene_loc_data_csv)
drawPicture(gene_loc_data_csv,"col","row",colorattribute="before_pre",save_file= None)
print(gene_loc_data_csv)
# print(adata_X)

                                  barcode  sample_name  tissue  row  col  \
AAACAAGTATCTCCCA-1.11  AAACAAGTATCTCCCA-1       151676       1   50  102   
AAACAATCTACTAGCA-1.5   AAACAATCTACTAGCA-1       151676       1    3   43   
AAACACCAATAACTGC-1.11  AAACACCAATAACTGC-1       151676       1   59   19   
AAACAGAGCGACTCCT-1.10  AAACAGAGCGACTCCT-1       151676       1   14   94   
AAACAGGGTCTATATT-1.11  AAACAGGGTCTATATT-1       151676       1   47   13   
...                                   ...          ...     ...  ...  ...   
TTGTTGTGTGTCAAGA-1.11  TTGTTGTGTGTCAAGA-1       151676       1   31   77   
TTGTTTCACATCCAGG-1.11  TTGTTTCACATCCAGG-1       151676       1   58   42   
TTGTTTCATTAGTCTA-1.11  TTGTTTCATTAGTCTA-1       151676       1   60   30   
TTGTTTCCATACAACT-1.11  TTGTTTCCATACAACT-1       151676       1   45   27   
TTGTTTGTGTAAATTC-1.10  TTGTTTGTGTAAATTC-1       151676       1    7   51   

                         imagerow    imagecol  Cluster  height  width  ...  \
AAACAAGTA

                                  barcode  sample_name  tissue  row  col  \
GTAGTCTACGATATTG-1.11  GTAGTCTACGATATTG-1       151676       1   63   77   
AGTCGTATAAAGCAGA-1.11  AGTCGTATAAAGCAGA-1       151676       1   65   81   
ACCGACACATCTCCCA-1.11  ACCGACACATCTCCCA-1       151676       1   61   67   
CACATTTCTTGTCAGA-1.11  CACATTTCTTGTCAGA-1       151676       1   44  108   
TTGCACAATTCAGAAA-1.11  TTGCACAATTCAGAAA-1       151676       1   43   57   
...                                   ...          ...     ...  ...  ...   
GGGAATGAGCCCTCAC-1.11  GGGAATGAGCCCTCAC-1       151676       1   19  107   
GGGAAGACGGTCTGTC-1.11  GGGAAGACGGTCTGTC-1       151676       1   64   76   
GGGAACCACCTGTTTC-1.5   GGGAACCACCTGTTTC-1       151676       1    1   35   
TTGTTTGTGTAAATTC-1.10  TTGTTTGTGTAAATTC-1       151676       1    7   51   
TACTCGGCACGCCGGG-1.11  TACTCGGCACGCCGGG-1       151676       1   36   16   

                         imagerow    imagecol  Cluster  height  width  ...  \
GTAGTCTAC

In [13]:
# print(adata_X)
# random_state=0
# new_pred1 = KMeans(n_clusters=num_classes, random_state=random_state).fit_predict(adata_X)
# print(new_pred)
# from collections import Counter
# print(Counter(new_pred))
# print(Counter(gene_loc_data_csv.lay_num.values))
# score = adjusted_rand_score(gene_loc_data_csv.lay_num.values, new_pred )
# print(score)
gene_loc_data_csv["before_pre"] = new_pred
print(gene_loc_data_csv)
drawPicture(gene_loc_data_csv,"col","row",colorattribute="before_pre",save_file= None)
# print(adata_X)

[[-0.44916332 -0.3244209   4.0256653  ...  1.0916187  -0.11364536
  -0.11493861]
 [-0.44916332 -0.3244209   2.402947   ...  0.5686377  -0.11364536
  -0.11493861]
 [-0.44916332 -0.3244209  -0.37110797 ... -0.36225617 -0.11364536
  -0.11493861]
 ...
 [-0.44916332  4.614942   -0.37110797 ... -3.4746752  -0.11364536
  -0.11493861]
 [ 1.9001449  -0.3244209  -0.37110797 ...  0.8065234  -0.11364536
  -0.11493861]
 [-0.44916332 -0.3244209   2.402947   ...  0.16663085 -0.11364536
  -0.11493861]]


[[-0.44916332 -0.3244209   4.0256653  ...  1.0916187  -0.11364536
  -0.11493861]
 [-0.44916332 -0.3244209   2.402947   ...  0.5686377  -0.11364536
  -0.11493861]
 [-0.44916332 -0.3244209  -0.37110797 ... -0.36225617 -0.11364536
  -0.11493861]
 ...
 [-0.44916332  4.614942   -0.37110797 ... -3.4746752  -0.11364536
  -0.11493861]
 [ 1.9001449  -0.3244209  -0.37110797 ...  0.8065234  -0.11364536
  -0.11493861]
 [-0.44916332 -0.3244209   2.402947   ...  0.16663085 -0.11364536
  -0.11493861]]


In [14]:
num_classes

8

In [15]:
adj_matrix

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], dtype=float32)

In [16]:
graph

Graph(num_nodes=3460, num_edges=23512,
      ndata_schemes={'feat': Scheme(shape=(3025,), dtype=torch.float32)}
      edata_schemes={'w': Scheme(shape=(), dtype=torch.float32)})

In [17]:
import argparse
parser = argparse.ArgumentParser(description="GAT")
parser.add_argument("--seeds", type=int, nargs="+", default=[0])
parser.add_argument("--dataset", type=str, default="cora")
parser.add_argument("--device", type=int, default=-1)
parser.add_argument("--max_epoch", type=int, default=200,
                    help="number of training epochs")
parser.add_argument("--warmup_steps", type=int, default=-1)

parser.add_argument("--num_heads", type=int, default=4,
                    help="number of hidden attention heads")
parser.add_argument("--num_out_heads", type=int, default=1,
                    help="number of output attention heads")
parser.add_argument("--num_layers", type=int, default=2,
                    help="number of hidden layers")
parser.add_argument("--num_hidden", type=int, default=256,
                    help="number of hidden units")
parser.add_argument("--residual", action="store_true", default=False,
                    help="use residual connection")
parser.add_argument("--in_drop", type=float, default=.2,
                    help="input feature dropout")
parser.add_argument("--attn_drop", type=float, default=.1,
                    help="attention dropout")
parser.add_argument("--norm", type=str, default=None)
parser.add_argument("--lr", type=float, default=0.005,
                    help="learning rate")
parser.add_argument("--weight_decay", type=float, default=5e-4,
                    help="weight decay")
parser.add_argument("--negative_slope", type=float, default=0.2,
                    help="the negative slope of leaky relu for GAT")
parser.add_argument("--activation", type=str, default="prelu")
parser.add_argument("--mask_rate", type=float, default=0.5)
parser.add_argument("--drop_edge_rate", type=float, default=0.0)
parser.add_argument("--replace_rate", type=float, default=0.0)

parser.add_argument("--encoder", type=str, default="gat")
parser.add_argument("--decoder", type=str, default="gat")
parser.add_argument("--loss_fn", type=str, default="byol")
parser.add_argument("--alpha_l", type=float, default=2, help="`pow`inddex for `sce` loss")
parser.add_argument("--optimizer", type=str, default="adam")

parser.add_argument("--max_epoch_f", type=int, default=30)
parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation")
parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation")
parser.add_argument("--linear_prob", action="store_true", default=False)

parser.add_argument("--load_model", action="store_true")
parser.add_argument("--save_model", action="store_true")
parser.add_argument("--use_cfg", action="store_true")
parser.add_argument("--logging", action="store_true")
parser.add_argument("--scheduler", action="store_true", default=False)
parser.add_argument("--concat_hidden", action="store_true", default=False)

# for graph classification
parser.add_argument("--pooling", type=str, default="mean")
parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature")
parser.add_argument("--batch_size", type=int, default=32)

_StoreAction(option_strings=['--batch_size'], dest='batch_size', nargs=None, const=None, default=32, type=<class 'int'>, choices=None, help=None, metavar=None)

In [18]:
args = parser.parse_args([])
args.lr = 0.001
args.lr_f = 0.01
args.num_hidden = 512
args.num_heads = 4
args.weight_decay = 2e-4
args.weight_decay_f= 1e-4
args.max_epoch= 1000
args.max_epoch_f= 300
args.mask_rate= 0.5
args.num_layers= 2
args.encoder= "gat"
args.decoder= "gat" 
args.activation= "prelu"
args.in_drop= 0.2
args.attn_drop= 0.1
args.linear_prob= True
args.loss_fn= "sce" 
args.drop_edge_rate=0.0
args.optimizer= "adam"
args.replace_rate= 0.05 
args.alpha_l= 3
args.scheduler= True

In [19]:
args.dataset = "sp"

In [20]:
device = args.device if args.device >= 0 else "cpu"
seeds = args.seeds
dataset_name = args.dataset
max_epoch = args.max_epoch
max_epoch_f = args.max_epoch_f
num_hidden = args.num_hidden
num_layers = args.num_layers
encoder_type = args.encoder
decoder_type = args.decoder
replace_rate = args.replace_rate

optim_type = args.optimizer 
loss_fn = args.loss_fn

lr = args.lr
weight_decay = args.weight_decay
lr_f = args.lr_f
weight_decay_f = args.weight_decay_f
linear_prob = args.linear_prob
load_model = args.load_model
save_model = args.save_model
logs = args.logging
use_scheduler = args.scheduler

In [21]:
args.num_features = num_features

In [22]:
acc_list = []
estp_acc_list = []
for i, seed in enumerate(seeds):
    print(f"####### Run {i} for seed {seed}")
    set_random_seed(seed)

####### Run 0 for seed 0


In [23]:
if logs:
    logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}")
else:
    logger = None

In [24]:
model = build_model(args)

In [25]:
model

PreModel(
  (encoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=3025, out_features=512, bias=False)
        (feat_drop): Dropout(p=0.2, inplace=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): PReLU(num_parameters=1)
      )
      (1): GATConv(
        (fc): Linear(in_features=512, out_features=512, bias=False)
        (feat_drop): Dropout(p=0.2, inplace=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): PReLU(num_parameters=1)
      )
    )
    (head): Identity()
  )
  (decoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=512, out_features=3025, bias=False)
        (feat_drop): Dropout(p=0.2, inplace=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
      )
    )
    (head):

In [26]:
device = 1

In [27]:
model.to(device)
optimizer = create_optimizer(optim_type, model, lr, weight_decay)

if use_scheduler:
    logging.info("Use schedular")
    scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch) ) * 0.5
    # scheduler = lambda epoch: epoch / warmup_steps if epoch < warmup_steps \
            # else ( 1 + np.cos((epoch - warmup_steps) * np.pi / (max_epoch - warmup_steps))) * 0.5
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
else:
    scheduler = None

2022-09-09 21:49:29,092 - INFO - Use schedular


In [28]:
def pretrain(model, graph, feat, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None):
    logging.info("start training..")
    graph = graph.to(device)
    x = feat.to(device)

    epoch_iter = tqdm(range(max_epoch))

    for epoch in epoch_iter:
        model.train()

        loss, loss_dict = model(graph, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        epoch_iter.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")
        if logger is not None:
            loss_dict["lr"] = get_current_lr(optimizer)
            logger.note(loss_dict, step=epoch)

        #if (epoch + 1) % 200 == 0:
            #node_classification_evaluation(model, graph, x, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True)

    # return best_model
    return model

In [29]:
x = graph.ndata["feat"]
if not load_model:
    model = pretrain(model, graph, x, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger)
    model = model.cpu()

2022-09-09 21:49:29,124 - INFO - start training..
# Epoch 125: train_loss: 0.4873:  13%|████████████████▍                                                                                                                 | 126/1000 [00:04<00:28, 30.97it/s]


KeyboardInterrupt: 

In [None]:
model

In [None]:
model.to(device)

In [None]:
graph.ndata["feat"]

In [None]:
x = graph.ndata["feat"]
embedding = model.embed(graph.to(device), x.to(device))
print(embedding)
new_pred = pred = kmeans.fit_predict(embedding)
from collections import Counter
print(Counter(new_pred))
print(Counter(gene_loc_data_csv.lay_num.values))
score = adjusted_rand_score(gene_loc_data_csv.lay_num.values, new_pred )
print(score)

In [None]:
x

In [None]:
embedding

In [None]:
#test = model.embed(graph.to(device), x.to(device))
# test_new_pred = kMeans_use(test.cpu().detach().numpy(),num_classes)
# score = adjusted_rand_score(gene_loc_data_csv.lay_num.values, test_new_pred )
# print(score)

In [None]:
gene_loc_data_csv["pre"] = new_pred

In [None]:
import pickle as pkl

import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import plotly.express as px
import pandas as pd
import scanpy as sc
import scipy as sci

def drawPicture(dataframe,col_name, row_name,colorattribute,save_file,celltype_colors =  ("#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33", "#A65628", "#F781BF", "#999999", "#E41A80",
    "#377F1C", "#4DAFAE", "#984F07", "#FF7F64", "#FFFF97", "#A6568C", "#F78223", "#9999FD", "#E41AE4", "#377F80", "#4DB012",
    "#984F6B", "#FF7FC8",  "#A656F0", "#F78287", "#999A61", "#E41B48", "#377FE4", "#4DB076", "#984FCF", "#FF802C",
    "#00005F", "#A65754", "#F782EB", "#999AC5", "#E41BAC", "#378048", "#4DB0DA", "#985033", "#FF8090", "#0000C3", "#A657B8",
    "#F7834F", "#999B29", "#E41C10", "#3780AC", "#4DB13E", "#985097", "#FF80F4", "#000127", "#A6581C", "#F783B3", "#999B8D",
    "#E41C74", "#378110", "#4DB1A2", "#9850FB", "#FF8158", "#00018B", "#A65880", "#F78417", "#999BF1", "#E41CD8", "#378174",
    "#4DB206", "#98515F", "#FF81BC", "#0001EF", "#A658E4", "#F7847B", "#999C55", "#E41D3C", "#3781D8", "#4DB26A", "#9851C3",
    "#FF8220", "#000253", "#A65948", "#F784DF", "#999CB9", "#E41DA0", "#37823C", "#4DB2CE", "#985227", "#FF8284", "#0002B7",
    "#A659AC", "#F78543", "#999D1D", "#E41E04", "#3782A0", "#4DB332", "#98528B", "#FF82E8", "#00031B", "#A65A10", "#F785A7",
    "#999D81", "#E41E68", "#378304", "#4DB396", "#9852EF", "#FF834C", "#00037F", "#A65A74", "#F7860B", "#999DE5", "#E41ECC",
    "#378368", "#4DB3FA", "#985353", "#FF83B0", "#0003E3", "#A65AD8", "#F7866F", "#999E49", "#E41F30", "#3783CC", "#4DB45E",
    "#9853B7", "#FF8414", "#000447", "#A65B3C", "#F786D3", "#999EAD", "#E41F94", "#378430", "#4DB4C2", "#98541B", "#FF8478",
    "#0004AB", "#A65BA0", "#F78737", "#999F11", "#E41FF8", "#378494", "#4DB526", "#98547F", "#FF84DC", "#00050F", "#A65C04",
    "#F7879B", "#999F75"
    ),width = 1000,height = 1000,marker_size = 10,is_show = True,is_save = False,save_type = "pdf"):

    import plotly.express as px
    import plotly.graph_objects as go
    size = len(set(dataframe[colorattribute]))
    dataframe.sort_values(by = colorattribute,inplace=True, ascending=True)
    length_col = max(dataframe[col_name]) - min(dataframe[col_name])

    length_row = max(dataframe[row_name]) - min(dataframe[row_name])
    max_length = max(length_col,length_row) + 2
    fig = px.scatter(dataframe, x = col_name, y= row_name,color = colorattribute,color_discrete_sequence=celltype_colors)
    fig.update_traces(marker_size=marker_size)
    fig.update_layout(
        xaxis = dict(
            tickmode = 'linear',
            tick0 = min(dataframe[row_name]),   # 起始点
            dtick = max_length  # 间距
        ),
        yaxis = dict(
            tickmode = 'linear',
            tick0 = min(dataframe[col_name]),   # 起始点
            dtick = max_length  # 间距
        ),
    #     xaxis_range = [min(it_mapping_csv.row)-10,min(it_mapping_csv.row) + max_length ],
    #     yaxis_range =[min(it_mapping_csv.col)-10,min(it_mapping_csv.col) + max_length]
    )
    fig.update_layout(
        autosize=False,
        width=width,
        height = height)
    if(is_show):
        fig.show()
    if(is_save):
        if(save_type == "pdf"):
            fig.write_image(save_file)
        if(save_type == "html"):
            fig.write_html(save_file)

In [None]:
drawPicture(gene_loc_data_csv,"imagecol","imagerow",colorattribute="lay_num",save_file= None)

In [None]:
drawPicture(gene_loc_data_csv,"imagecol","imagerow",colorattribute="pre",save_file= None)

In [None]:
new_pred = kMeans_use(embedding.cpu().detach().numpy(),num_classes)
score = adjusted_rand_score(gene_loc_data_csv.lay_num.values, new_pred )

In [None]:
score

In [None]:
embedding

In [None]:
random_state = 0
new_pred =  KMeans(n_clusters=2, random_state=random_state).fit_predict(adata_X)

print(new_pred)
from collections import Counter
print(Counter(new_pred))
print(Counter(gene_loc_data_csv.lay_num.values))
score = adjusted_rand_score(gene_loc_data_csv.lay_num.values, new_pred )
print(score)
gene_loc_data_csv["before_pre"] = new_pred
drawPicture(gene_loc_data_csv,"col","row",colorattribute="before_pre",save_file= None)