In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
data_name = 'Baron'
met = "scGCL_pcc"
d_rate = 0.1  # probability with which DropData zeroes each input entry

# round() instead of int(): int(d_rate * 100) truncates float error,
# e.g. int(0.29 * 100) == 28.
save_model_path = './' + data_name + '_' + met + str(round(d_rate * 100)) + '_dict'

In [4]:
import os
import torch
import random
import numpy as np
import scanpy as sc
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics.cluster import *
from scipy.optimize import linear_sum_assignment as linear_assignment


class CellDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing rows of an expression matrix with labels.

    Args:
        X: 2-D matrix whose rows are samples. Either dense array-like or a
            SciPy sparse matrix. Sparse input is densified here, because
            ``torch.tensor`` cannot ingest sparse matrices.
        y: sequence of per-row labels; items are returned unchanged.

    Raises:
        ValueError: if ``X`` and ``y`` have different numbers of rows.
    """

    def __init__(self, X, y):
        if hasattr(X, "toarray"):  # scipy.sparse matrix (common for AnnData .X)
            X = X.toarray()
        self.X = torch.tensor(X)
        self.y = y
        if len(self.X) != len(self.y):
            raise ValueError(
                "X has {} rows but y has {} labels".format(len(self.X), len(self.y)))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, index):
        return self.X[index], self.y[index]


def loader_construction(data_path, test_size=0.2, batch_size=512, num_workers=10, random_state=1):
    """Load an AnnData file and build train/test DataLoaders.

    Labels are taken from the first column of ``data.obs``.

    Args:
        data_path: path readable by ``sc.read_h5ad``.
        test_size: fraction of rows held out for the test loader.
        batch_size: batch size of both loaders.
        num_workers: DataLoader worker processes.
        random_state: seed for ``train_test_split``.

    Returns:
        (train_loader, test_loader, input_dim), where ``input_dim`` is the
        number of columns of ``data.X``. The train loader shuffles and the
        test loader does not.
    """
    data = sc.read_h5ad(data_path)
    X_all = data.X
    if hasattr(X_all, "toarray"):  # sparse .X cannot be turned into a torch tensor directly
        X_all = X_all.toarray()
    y_all = data.obs.values[:, 0]
    input_dim = X_all.shape[1]

    X_train, X_test, y_train, y_test = train_test_split(
        X_all, y_all, test_size=test_size, random_state=random_state)
    train_set = CellDataset(X_train, y_train)
    test_set = CellDataset(X_test, y_test)

    # Use torch's DataLoader explicitly: a later notebook cell rebinds the
    # global name ``DataLoader`` to torch_geometric's class, which would be
    # picked up silently if this function were re-run afterwards.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader, input_dim


def setup_seed(seed):
    """Seed all RNGs used here and switch cuDNN to deterministic mode.

    Covers Python's ``random`` module, NumPy's global generator, and torch's
    CPU and CUDA generators. ``PYTHONHASHSEED`` is exported as well, which
    only affects child processes, not the running interpreter.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def cluster_acc(y_true, y_pred):
    """Clustering accuracy under the best one-to-one label matching.

    Builds the contingency matrix ``w[pred, true]`` and uses the Hungarian
    algorithm (``linear_assignment``) to find the cluster-to-label mapping with
    the most matched samples. Returns the fraction of samples matched.

    Args:
        y_true: non-negative integer labels (any array-like; float arrays
            holding integers are accepted and cast).
        y_pred: non-negative integer cluster ids, same size as ``y_true``.

    Returns:
        float accuracy in [0, 1].

    Raises:
        ValueError: if the inputs are empty.
    """
    y_true = np.asarray(y_true).ravel().astype(np.int64)
    y_pred = np.asarray(y_pred).ravel().astype(np.int64)
    assert y_pred.size == y_true.size
    if y_pred.size == 0:
        raise ValueError("cluster_acc needs at least one sample")

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (y_pred, y_true), 1)  # unbuffered: repeated pairs accumulate

    # Maximise matches == minimise (max - w).
    row_ind, col_ind = linear_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size

def evaluate(y_true, y_pred):
    """Compute clustering metrics of ``y_pred`` against ``y_true``.

    Returns:
        tuple ``(acc, f1, nmi, ari, homo, comp)``. ``f1`` is a placeholder
        and is always 0.
    """
    scores = (
        cluster_acc(y_true, y_pred),
        0,  # f1: not computed; kept so the tuple layout stays the same
        normalized_mutual_info_score(y_true, y_pred),
        adjusted_rand_score(y_true, y_pred),
        homogeneity_score(y_true, y_pred),
        completeness_score(y_true, y_pred),
    )
    return scores


In [5]:
import argparse
import warnings
import torch
import numpy as np
from sklearn.cluster import KMeans
from tqdm import tqdm
import scanpy as sc
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

In [6]:
data_path = f"./{data_name}.h5"  # AnnData file for the selected dataset (read via sc.read_h5ad)

In [7]:
train_loader, test_loader, input_dim = loader_construction(data_path=data_path)  # input_dim = number of columns of data.X

In [8]:
# Re-read the AnnData file to get the full label vector (first column of obs);
# it is used below to count the distinct labels.
data = sc.read_h5ad(data_path)
y_all = data.obs.to_numpy()[:, 0]

In [9]:
import numpy as np

# Number of distinct labels in the full dataset.
n_clusters = np.unique(y_all).size

In [10]:
# NOTE(review): this rebinds the global name ``DataLoader`` (previously
# torch.utils.data.DataLoader, used by loader_construction) to
# torch_geometric's DataLoader. Re-running loader_construction after this cell
# would pick up this class. DataLoader is not referenced by the later cells
# shown here; consider dropping this import.
from torch_geometric.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
np.random.seed(0)
import sys
from torch import optim
from tensorboardX import SummaryWriter

# Fix torch's global RNG seeds and make cuDNN deterministic (NumPy is seeded
# above). train() re-seeds everything via setup_seed().
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import os
# Project-local helpers (EMA teacher update, loss_fn, repeat_1d_tensor, ...).
from utils import EMA, set_requires_grad, init_weights, update_moving_average, loss_fn, repeat_1d_tensor, currentTime
import copy
import pandas as pd
from data import Dataset
from embedder_layer import embedder
from utils import config2string
from embedder_layer import Encoder
import faiss  # K-means used by Neighbor to find cluster-consistent neighbours
from ZINB_loss import ZINB,NB
import utils

# scGCL model
# Revised from the original version in AFGRL.
# Ref:
# https://github.com/Namkyeong/AFGRL/tree/master/models/AFGRL.py

In [11]:
def DropData(batch_x, d_rate):
    """Randomly zero entries of ``batch_x`` (dropout-style masking for imputation).

    Each entry is selected independently when ``torch.rand(...) <= d_rate``
    and set to 0.

    Args:
        batch_x: input tensor (any shape).
        d_rate: per-entry selection probability.

    Returns:
        (final_mask, final_x):
          * final_mask: float32 tensor with 1.0 where an originally non-zero
            entry was dropped, 0.0 elsewhere.
          * final_x: ``batch_x`` with the selected entries set to 0.

    Boolean masks replace the former -999 sentinel for zeros, which silently
    zeroed genuine -999 values. Tensors are created on ``batch_x.device``
    rather than a global ``device``. The RNG is consumed exactly as before:
    one ``torch.rand`` call of the same shape.
    """
    sample_mask = torch.rand(batch_x.shape, device=batch_x.device) <= d_rate
    final_mask = (sample_mask & (batch_x != 0)).float()
    final_x = batch_x.masked_fill(sample_mask, 0)
    return final_mask, final_x


In [12]:
class Neighbor(nn.Module):
    """Selects positive node pairs for the student/teacher consistency loss.

    Adapted from AFGRL. For every node, the candidates are its ``top_k``
    highest inner-product teacher embeddings. A candidate becomes a positive
    if either:
      * it is also an entry of the input adjacency ``adj`` ("locality"), or
      * it shares the anchor's K-means cluster in at least one of
        ``num_kmeans`` independently seeded runs ("globality").
    """

    def __init__(self, args):
        super(Neighbor, self).__init__()
        # The original hard-coded "cuda:0". Honour ``args.device`` (an int GPU
        # index by default) and fall back to CPU when CUDA is unavailable.
        # forward() places new tensors on the device of its inputs.
        device_id = getattr(args, "device", 0)
        if not torch.cuda.is_available():
            self.device = "cpu"
        elif isinstance(device_id, int):
            self.device = "cuda:{}".format(device_id)
        else:
            self.device = str(device_id)
        self.num_centroids = args.num_centroids
        self.num_kmeans = args.num_kmeans
        self.clus_num_iters = args.clus_num_iters

    def __get_close_nei_in_back(self, indices, each_k_idx, cluster_labels, back_nei_idxs, k):
        """Boolean mask: does ``back_nei_idxs[i, j]`` share anchor ``indices[i]``'s
        cluster in K-means run ``each_k_idx``?"""
        batch_labels = cluster_labels[each_k_idx][indices]
        top_cluster_labels = cluster_labels[each_k_idx][back_nei_idxs]
        # presumably tiles the (n,) anchor labels to (n, k) to match
        # top_cluster_labels; repeat_1d_tensor lives in utils (confirm)
        batch_labels = repeat_1d_tensor(batch_labels, k)
        return torch.eq(batch_labels, top_cluster_labels)

    def forward(self, adj, student, teacher, top_k, epoch):
        """Return the indices of positive pairs and the neighbourhood size.

        Args:
            adj: (n, n) sparse adjacency of the input graph.
            student: (n, d) student embeddings.
            teacher: (n, d) teacher embeddings (detached here).
            top_k: number of nearest neighbours considered per node.
            epoch: unused; kept for interface compatibility.

        Returns:
            (indices, k): ``indices`` is the index tensor of the coalesced
            positive-pair sparse matrix, and ``k`` is ``I_knn.shape[1]``.
        """
        n_data, d = student.shape
        dev = student.device

        # Inner-product similarity. The +10 diagonal makes every node its own
        # first neighbour.
        similarity = torch.matmul(student, torch.transpose(teacher, 1, 0).detach())
        similarity += torch.eye(n_data, device=dev) * 10
        _, I_knn = similarity.topk(k=top_k, dim=1, largest=True, sorted=True)

        knn_neighbor = self.create_sparse(I_knn)
        locality = knn_neighbor * adj

        # Several K-means runs with different seeds on the teacher embeddings.
        # faiss expects a C-contiguous float32 array; convert it only once.
        teacher_np = np.ascontiguousarray(teacher.detach().cpu().numpy(), dtype=np.float32)
        pred_labels = []
        for seed in range(self.num_kmeans):
            kmeans = faiss.Kmeans(d, self.num_centroids, niter=self.clus_num_iters, gpu=False, seed=seed + 1234)
            kmeans.train(teacher_np)
            _, I_kmeans = kmeans.index.search(teacher_np, 1)
            pred_labels.append(I_kmeans[:, 0])

        # (num_kmeans, n) label matrix, moved to the embeddings' device so it
        # can be indexed with the on-device neighbour indices below.
        cluster_labels = (torch.from_numpy(np.stack(pred_labels, axis=0)).float().to(dev)
                          if pred_labels else None)

        anchors = torch.arange(n_data, device=dev)
        all_close_nei_in_back = torch.zeros_like(I_knn, dtype=torch.bool)
        with torch.no_grad():
            for each_k_idx in range(self.num_kmeans):
                curr_close_nei = self.__get_close_nei_in_back(anchors, each_k_idx, cluster_labels, I_knn, I_knn.shape[1])
                all_close_nei_in_back = all_close_nei_in_back | curr_close_nei

        globality = self.create_sparse_revised(I_knn, all_close_nei_in_back)

        pos_ = locality + globality
        return pos_.coalesce()._indices(), I_knn.shape[1]

    def create_sparse(self, I):
        """(n, n) sparse COO matrix with value 1 at ``(i, I[i, j])`` for every j.

        Values have ``I``'s dtype, and the result lives on ``I``'s device.
        """
        n_rows, k = I.shape
        rows = torch.arange(n_rows, device=I.device).repeat_interleave(k)
        cols = I.reshape(-1)
        indices = torch.stack([rows, cols])
        return torch.sparse_coo_tensor(indices, torch.ones_like(cols), [n_rows, n_rows])

    def create_sparse_revised(self, I, all_close_nei_in_back):
        """Like :meth:`create_sparse`, but keeps only the entries flagged in the mask.

        Args:
            I: (n, k) neighbour indices.
            all_close_nei_in_back: boolean mask of the same number of elements
                as ``I``.

        Returns:
            (n, n) float32 sparse COO matrix with 1.0 at each kept ``(i, I[i, j])``.
        """
        n_data, k = I.shape[0], I.shape[1]
        keep = all_close_nei_in_back.reshape(-1).to(I.device)
        rows = torch.arange(n_data, device=I.device).repeat_interleave(k)

        # Vectorised replacement of the former O(n*k) Python loop with .item().
        index = torch.masked_select(rows, keep)
        similar = torch.masked_select(I.reshape(-1).long(), keep)

        assert len(similar) == len(index)
        indices = torch.stack([index, similar])
        return torch.sparse_coo_tensor(indices, torch.ones(index.numel(), device=I.device), [n_data, n_data])

In [13]:
class AFGRL(nn.Module):
    """scGCL network adapted from AFGRL.

    A student encoder is trained against a teacher copy of itself. The teacher
    is frozen for gradients and updated through ``update_moving_average``.
    Training uses the neighbour pairs chosen by ``Neighbor``, a ZINB decoder
    head (pi / dispersion / mean) and an MSE reconstruction term.
    """

    def __init__(self, layer_config, args, **kwargs):
        super().__init__()
        dec_dim = [512, 256]  # hidden sizes of the ZINB decoder trunk
        self.student_encoder = Encoder(layer_config=layer_config, dropout=args.dropout, **kwargs)
        # The teacher starts as an exact copy and receives no gradients.
        self.teacher_encoder = copy.deepcopy(self.student_encoder)
        set_requires_grad(self.teacher_encoder, False)
        self.teacher_ema_updater = EMA(args.mad, args.epochs)
        self.neighbor = Neighbor(args)
        rep_dim = layer_config[-1]
        rep_dim_o = layer_config[0]
        self.student_predictor = nn.Sequential(nn.Linear(rep_dim, args.pred_hid), nn.BatchNorm1d(args.pred_hid), nn.ReLU(), nn.Linear(args.pred_hid, rep_dim), nn.ReLU())
        self.ZINB_Encoder = nn.Sequential(nn.Linear(rep_dim, dec_dim[0]), nn.ReLU(),
                                          nn.Linear(dec_dim[0], dec_dim[1]), nn.ReLU())
        self.pi_Encoder = nn.Sequential(nn.Linear(dec_dim[1], rep_dim_o), nn.Sigmoid())
        self.disp_Encoder = nn.Sequential(nn.Linear(dec_dim[1], rep_dim_o), nn.Softplus())
        self.mean_Encoder = nn.Linear(dec_dim[1], rep_dim_o)
        self.student_predictor.apply(init_weights)
        self.relu = nn.ReLU()
        self.topk = args.topk
        # self._device = args.device
        self._device = "cpu"

    def clip_by_tensor(self, t, t_min, t_max):
        """Clamp ``t`` element-wise into ``[t_min, t_max]`` and return float32.

        Uses ``torch.clamp`` so the result stays on the autograd graph. The
        former version rebuilt ``t`` with ``torch.tensor(t)``, which copies and
        detaches it, so no gradient reached ``disp_Encoder``/``mean_Encoder``.
        It also mapped +inf to NaN through ``0 * inf``.

        :param t: tensor (or array-like) to clip
        :param t_min: lower bound
        :param t_max: upper bound
        :return: clipped float32 tensor
        """
        t = torch.as_tensor(t, dtype=torch.float32)  # keeps the graph; no copy if already float32
        return torch.clamp(t, min=t_min, max=t_max)

    def reset_moving_average(self):
        del self.teacher_encoder
        self.teacher_encoder = None

    def update_moving_average(self):
        assert self.teacher_encoder is not None, 'teacher encoder has not been created yet'
        update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)

    def forward(self, x, y, edge_index, neighbor, edge_weight=None, epoch=None):
        """Compute the student embedding and the total training loss.

        Args:
            x: (n, d) node features; ``x.shape[0]`` sizes the adjacency.
            y: unused; kept for interface compatibility.
            edge_index: graph passed to both encoders.
            neighbor: ``[indices]`` or ``[indices, weights]`` for the sparse
                adjacency used by ``Neighbor``.
            edge_weight: optional encoder edge weights. When None, every
                adjacency entry gets weight 1.
            epoch: forwarded to ``Neighbor``.

        Returns:
            (student, loss): student embedding and the scalar loss.
        """
        student = self.student_encoder(x=x, edge_index=edge_index, edge_weight=edge_weight)
        pred = self.student_predictor(student)

        # ZINB decoder head
        z = self.ZINB_Encoder(student)
        pi = self.pi_Encoder(z)
        disp = self.clip_by_tensor(self.disp_Encoder(z), 1e-4, 1e4)
        mean = self.clip_by_tensor(torch.exp(self.mean_Encoder(z)), 1e-5, 1e6)

        # Ablation switch kept from the original experiments:
        # 0 = ZINB + neighbour + reconstruction, 1 = without ZINB, 2 = ZINB only.
        modify = 0
        with torch.no_grad():
            teacher = self.teacher_encoder(x=x, edge_index=edge_index, edge_weight=edge_weight)

        n_nodes = x.shape[0]
        if edge_weight is None:
            adj_values = torch.ones(neighbor[0].shape[1], dtype=torch.float32, device=neighbor[0].device)
        else:
            adj_values = neighbor[1]
        # Replaces the deprecated torch.sparse.FloatTensor constructor.
        adj = torch.sparse_coo_tensor(neighbor[0], adj_values, [n_nodes, n_nodes], dtype=torch.float32)

        ind, k = self.neighbor(adj, F.normalize(student, dim=-1, p=2), F.normalize(teacher, dim=-1, p=2), self.topk, epoch)
        zinb = ZINB(pi, theta=disp, ridge_lambda=0, debug=False)
        zinb_loss = zinb.loss(x, mean, mean=True)
        # Symmetric prediction loss over the selected positive pairs.
        loss1 = loss_fn(pred[ind[0]], teacher[ind[1]].detach())
        loss2 = loss_fn(pred[ind[1]], teacher[ind[0]].detach())
        loss_reforce = loss1 + loss2
        recon_loss_ = F.mse_loss(x, student, reduction='mean')

        if modify == 0:
            loss = zinb_loss + loss_reforce + recon_loss_
        elif modify == 1:
            loss = loss_reforce + recon_loss_
        elif modify == 2:
            loss = zinb_loss
        else:
            raise ValueError("unknown loss ablation mode: {}".format(modify))
        return student, loss.mean()



In [14]:
def parse_args(argv=None):
    """Build the hyper-parameter namespace for scGCL / AFGRL.

    Args:
        argv: optional list of command-line style tokens, e.g.
            ``["--topk", "8"]``. When omitted, no tokens are parsed, so the
            notebook kernel's own ``sys.argv`` is ignored and every option
            keeps its default.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--embedder", type=str, default="AFGRL")
    parser.add_argument("--dataset", type=str, default="adam", help="Name of the dataset")

    parser.add_argument('--checkpoint_dir', type=str, default='./model_checkpoints', help='directory to save checkpoint')
    parser.add_argument("--root", type=str, default="data")
    parser.add_argument("--task", type=str, default="clustering", help="Downstream task. Supported tasks are: node, clustering, similarity")

    parser.add_argument("--layers", nargs='?', default='[2048]', help="The number of units of each layer of the GNN. Default is [2048]")
    parser.add_argument("--pred_hid", type=int, default=2048, help="The number of hidden units of layer of the predictor. Default is 2048")

    parser.add_argument("--topk", type=int, default=4, help="The number of neighbors to search")
    parser.add_argument("--clus_num_iters", type=int, default=20)
    parser.add_argument("--num_centroids", type=int, default=9, help="The number of centroids for K-means Clustering")
    parser.add_argument("--num_kmeans", type=int, default=5, help="The number of K-means Clustering for being robust to randomness")
    parser.add_argument("--eval_freq", type=float, default=5, help="The frequency of model evaluation")
    parser.add_argument("--mad", type=float, default=0.9, help="Moving Average Decay for Teacher Network")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--es", type=int, default=300, help="Early Stopping Criterion")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=300)
    parser.add_argument("--dropout", type=float, default=0.0)
    # type=float so values given on the command line are numbers, like the default.
    parser.add_argument("--aug_params", "-p", nargs="+", type=float, default=[0.3, 0.4, 0.3, 0.2], help="Hyperparameters for augmentation (p_f1, p_f2, p_e1, p_e2). Default is [0.3, 0.4, 0.3, 0.2]")
    return parser.parse_args(args=[] if argv is None else list(argv))



In [15]:
# Hyper-parameters; parse_args() parses an empty token list, so the kernel's
# sys.argv is ignored and every option takes its default value.
args = parse_args()

In [16]:
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors

In [17]:
class Model(nn.Module):
    """Wrapper that builds a per-batch cosine k-NN graph and runs AFGRL on it.

    Args:
        layer_config: encoder layer sizes forwarded to AFGRL.
        args: hyper-parameter namespace forwarded to AFGRL.
        n_neighbors: neighbours per cell in the k-NN graph. The default 15 is
            the value previously hard-coded.
    """

    def __init__(self, layer_config, args, n_neighbors=15):
        super(Model, self).__init__()
        self.afgrl = AFGRL(layer_config, args)
        self.n_neighbors = n_neighbors

    def forward(self, x, y):
        """Build the k-NN graph for the rows of ``x`` and return AFGRL's ``(x_hat, loss)``."""
        features = x.detach().cpu().numpy()
        n_cells = features.shape[0]
        # Cap k for small batches (e.g. a short final batch). NearestNeighbors
        # raises if asked for more neighbours than there are samples.
        k = max(min(self.n_neighbors, n_cells - 1), 0)

        nbrs = NearestNeighbors(n_neighbors=k + 1, metric='cosine').fit(features)
        _, indices = nbrs.kneighbors(features)

        # Column 0 is normally the query row itself and is skipped. With
        # duplicate rows it may be an identical twin instead, as before.
        src = np.repeat(np.arange(n_cells), k)
        dst = indices[:, 1:k + 1].reshape(-1)
        edge_index = torch.as_tensor(np.stack([src, dst]), dtype=torch.long, device=x.device).contiguous()

        neighbor = [edge_index]
        x_hat, loss = self.afgrl(x, y, edge_index, neighbor, edge_weight=None, epoch=None)
        return x_hat, loss

In [18]:
def l1_distance(imputed_data, original_data):
    """Mean absolute difference between the imputed and original arrays."""
    abs_err = np.abs(original_data - imputed_data)
    return np.mean(abs_err)

def RMSE(imputed_data, original_data):
    """Root-mean-square error between the imputed and original arrays."""
    squared_err = (original_data - imputed_data) ** 2
    mse = np.mean(squared_err)
    return np.sqrt(mse)


def pearson_corr(imputed_data, original_data):
    """Pearson correlation over all entries of two equally sized arrays.

    Both inputs are flattened first. The result is NaN if either input is
    constant (zero variance).
    """
    pred = imputed_data.reshape(-1)
    truth = original_data.reshape(-1)
    pred_dev = pred - np.mean(pred)
    truth_dev = truth - np.mean(truth)
    numerator = np.sum(pred_dev * truth_dev)
    denominator = np.sqrt(np.sum(pred_dev ** 2)) * np.sqrt(np.sum(truth_dev ** 2))
    return numerator / denominator



In [19]:
# Use the first GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device('cpu')

In [20]:
def _move_batch(batch_x, batch_y, device):
    """Cast features to float32 on ``device``; move labels only if they are a tensor.

    String labels (e.g. cell-type names) are collated into a Python list,
    which has no ``.float()``. The model ignores labels, so non-tensor labels
    are passed through unchanged.
    """
    batch_x = batch_x.float().to(device)
    if torch.is_tensor(batch_y):
        batch_y = batch_y.float().to(device)
    return batch_x, batch_y


def _train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` (inputs masked by DropData); return the mean loss."""
    model.train()
    losses = []
    for batch_x, batch_y in loader:
        batch_x, batch_y = _move_batch(batch_x, batch_y, device)
        _, dropped_x = DropData(batch_x, d_rate)
        _, loss = model(dropped_x, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return np.mean(losses)


def _eval_epoch(model, loader, device):
    """Score ``loader`` without gradients; return (concatenated model outputs, mean loss)."""
    model.eval()
    outputs, losses = [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = _move_batch(batch_x, batch_y, device)
            _, dropped_x = DropData(batch_x, d_rate)
            x_imp, loss = model(dropped_x, batch_y)
            outputs.append(x_imp)
            losses.append(loss.item())
    return torch.cat(outputs), np.mean(losses)


def train(train_loader,
          test_loader,
          input_dim,
          lr,
          seed,
          epochs,
          device):
    """Train scGCL and return the test-set output of the lowest-loss epoch.

    Relies on notebook globals: ``args`` (hyper-parameters), ``d_rate``
    (DropData rate), ``Model``, ``DropData`` and ``setup_seed``.

    Args:
        train_loader, test_loader: loaders yielding ``(features, labels)``.
        input_dim: number of input features; sets the encoder layer sizes.
        lr: Adam learning rate.
        seed: seed applied via ``setup_seed`` before the model is built.
        epochs: number of training epochs (>= 1).
        device: torch device for the model and batches.

    Returns:
        np.ndarray: model outputs on ``test_loader`` (in loader order) from the
        epoch with the lowest mean test loss. Epoch 0 is used if no epoch ever
        improves, e.g. when every loss is NaN.

    Raises:
        ValueError: if ``epochs < 1``.
    """
    if epochs < 1:
        raise ValueError("epochs must be >= 1, got {}".format(epochs))

    # Seed before building the model so weight initialisation is reproducible
    # (previously the seed was applied only after construction).
    setup_seed(seed)
    layer_config = [input_dim, input_dim // 2, input_dim // 4, input_dim // 2, input_dim]
    model = Model(layer_config, args).to(device)
    opt_model = torch.optim.Adam(model.parameters(), lr=lr)

    np.set_printoptions(threshold=np.inf, precision=2, suppress=True)

    min_loss = float("inf")
    best_x_imp = None
    for _ in range(epochs):
        _train_epoch(model, train_loader, opt_model, device)
        x_imp, cur_loss = _eval_epoch(model, test_loader, device)

        # Keep only the best epoch's output rather than one array per epoch
        # (that used epochs x n_test x n_features memory).
        improved = cur_loss < min_loss
        if improved:
            min_loss = cur_loss
        if improved or best_x_imp is None:
            best_x_imp = x_imp.cpu().numpy()

    return best_x_imp
                
        

In [None]:
# Train for 200 epochs; x_imp is the test-set output of the lowest-loss epoch.
x_imp = train(
    train_loader=train_loader,
    test_loader=test_loader,
    input_dim=input_dim,
    lr=1e-4,
    seed=1,
    epochs=200,
    device=device,
)