## test

In [None]:
import numpy as np
import cvxpy as cp
import scipy
from scipy.io import loadmat
from scipy.sparse import diags
from scipy.sparse import coo_matrix
from scipy.sparse import csr_matrix
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import eigsh
from scipy.linalg import expm
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import pairwise_distances
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, accuracy_score
from sklearn.preprocessing import MinMaxScaler
import math
import time
import warnings

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# Module-level MinMaxScaler instance.
scaler = MinMaxScaler()
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
# Explicit CPU device handle kept next to `device`.
device1 = torch.device('cpu')
# Turn on autograd anomaly detection globally.
# NOTE(review): PyTorch documents this as a debugging aid that slows
# execution - confirm it should stay enabled for normal runs.
torch.autograd.set_detect_anomaly(True)

# List the solver back ends cvxpy can see; error_point() requests 'GLPK_MI'.
print(cp.installed_solvers())


def error_point(prep, real):
    """Match predicted cluster labels to reference labels and list mismatches.

    Samples labelled ``-1`` in either array are treated as noise and excluded
    from the matching. A 0/1 assignment between the remaining reference
    labels and predicted labels is obtained from a mixed-integer program
    (cvxpy with the ``GLPK_MI`` solver) that maximises the total overlap:
    every predicted label is assigned to exactly one reference label and each
    reference label receives at most one predicted label.

    Args:
        prep: 1-D NumPy array of predicted labels. It is not modified.
        real: 1-D NumPy array of reference labels, same length as ``prep``.

    Returns:
        Tuple ``(error_idx, relabelled)``. ``error_idx`` contains positions
        (in the original arrays) of non-noise samples whose reference label
        was matched to a predicted cluster they do not belong to.
        ``relabelled`` is a copy of ``prep`` in which each matched predicted
        label is replaced by its reference label.

    Raises:
        Whatever cvxpy/NumPy raise when no assignment can be built (for
        example when there are more predicted than reference labels).
    """
    # Copy first: the relabelling at the end must not write into the caller's
    # array (the original aliased ``prep`` and mutated it in place).
    relabelled = np.array(prep)
    error_idx = np.array([], dtype=int)
    prep_full_label = np.setdiff1d(np.unique(prep), np.array([-1]))
    real_full_label = np.setdiff1d(np.unique(real), np.array([-1]))
    nonnoise_index = np.intersect1d(
        np.where(prep != -1)[0],
        np.where(real != -1)[0])
    real = real[nonnoise_index]
    prep = prep[nonnoise_index]
    real_label = np.unique(real)
    prep_label = np.unique(prep)
    n = len(real_label)
    n_1 = len(prep_label)
    # One-hot membership matrices, one row per label.
    reallogic = (np.reshape(np.repeat(real, n), [len(real), n])
                 == real_label).T + 0
    preplogic = (np.reshape(np.repeat(prep, n_1), [len(prep), n_1])
                 == prep_label).T + 0
    # interset_matrix[a, b] = number of samples with reference label a and
    # predicted label b.
    interset_matrix = reallogic @ preplogic.T
    x = cp.Variable((n, n_1), integer=True)
    obj = cp.Minimize(-cp.sum(cp.multiply(interset_matrix, x)))
    con = [
        0 <= x, x <= 1,
        cp.sum(x, axis=0, keepdims=True) == 1,
        cp.sum(x, axis=1, keepdims=True) <= 1
    ]
    prob = cp.Problem(obj, con)
    prob.solve('GLPK_MI')
    # The variable is 0/1; thresholding tolerates solver round-off that an
    # exact ``== 1`` comparison would miss.
    index = np.array(np.where(x.value > 0.5))
    # NOTE(review): when the two setdiff results have different lengths this
    # builds a ragged array and NumPy raises; it also appends label *values*
    # to ``index``, which otherwise holds label *positions*. Left unchanged -
    # confirm the intended handling of unmatched labels.
    add_index = np.array([
        np.setdiff1d(real_full_label, real_label[index[0, :]]),
        np.setdiff1d(prep_full_label, prep_label[index[1, :]])
    ],
                         dtype=int)
    prep0 = np.setdiff1d(prep_full_label, prep_label[index[1, :]])
    index = np.concatenate((index, add_index), axis=1)

    related_index = []
    for i in range(n):
        real_iter_index = np.where(real == real_label[index[0, i]])[0]
        if i < n_1:
            matched_prep = prep_label[index[1, i]]
        else:
            matched_prep = prep0[i - n_1]
        prep_iter_index = np.where(prep == matched_prep)[0]
        # Positions are collected before any relabelling so that later
        # assignments cannot be confused with freshly written labels.
        related_index.append(np.where(relabelled == matched_prep)[0])
        error_point_i = np.setdiff1d(real_iter_index, prep_iter_index)
        error_idx = np.union1d(error_idx, error_point_i)
    for i in range(n):
        relabelled[related_index[i]] = np.ones(len(
            related_index[i])) * real_label[index[0, i]]
    # Map positions in the noise-free arrays back to the original indexing.
    error_idx = nonnoise_index[error_idx]
    return error_idx, relabelled


class load_data(Dataset):
    """Dataset backed by a multi-view MATLAB ``.mat`` file.

    The file is read with ``scipy.io.loadmat``. ``data['data'][0]`` provides
    the sequence of per-view matrices (stored densely in ``self.x``) and
    ``data['truelabel'][0][0]`` the label vector (flattened into ``self.y``).
    """

    def __init__(self, dataset):
        """Load ``dataset`` (anything ``loadmat`` accepts) into memory."""
        mat = loadmat(dataset)
        views = mat['data'][0]
        self.y = mat['truelabel'][0][0].reshape(-1)
        # The storage format of the first view decides for all of them.
        if isinstance(views[0], (csr_matrix, csc_matrix, coo_matrix)):
            self.x = [view.toarray() for view in views]
        else:
            self.x = list(views)

    def __len__(self):
        """Return the number of entries stored in ``self.x``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(self.x[idx], self.y[idx], idx)`` as torch tensors."""
        entry = torch.from_numpy(np.array(self.x[idx]))
        label = torch.from_numpy(np.array(self.y[idx]))
        position = torch.from_numpy(np.array(idx))
        return entry, label, position


def torch_intersect1d(t1: torch.Tensor, t2: torch.Tensor):
    """Split two 1-D tensors into shared and exclusive elements.

    Both inputs must already contain unique values: an element is considered
    shared when it occurs exactly twice in the concatenation of ``t1`` and
    ``t2``.

    Returns:
        Tuple ``(intersection, t1_exclusive, t2_exclusive)``; the first two
        keep the element order of ``t1``, the last one that of ``t2``.
    """
    boundary = t1.numel()
    _, inverse, counts = torch.unique(torch.cat([t1, t2]),
                                      return_inverse=True,
                                      return_counts=True)
    # Occurrence count of every element, in concatenation order.
    occurrences = counts[inverse]
    shared_in_t1 = occurrences[:boundary] == 2
    only_in_t2 = occurrences[boundary:] == 1
    return t1[shared_in_t1], t1[~shared_in_t1], t2[only_in_t2]


def NormalizeData(data):
    """Scale every row of ``data`` to unit Euclidean norm.

    Rows whose squared norm is below ``1e-14`` are divided by ``1e-7``
    instead of their norm, so all-zero rows stay zero and no division by
    zero occurs.

    The previous implementation switched to L1 scaling when the sum of
    squares was negative. That can only happen when an integer input
    overflows, so non-floating input is promoted to ``float64`` up front and
    the L2 scaling is always applied. Scaling uses plain broadcasting rather
    than multiplying by a sparse diagonal matrix.

    Args:
        data: array-like whose last axis holds the features of one sample
            (the file's callers pass 2-D ``(n_samples, n_features)`` data).

    Returns:
        ``numpy.ndarray`` of the same shape as ``data``. Floating inputs keep
        their dtype; other inputs come back as ``float64``.
    """
    samples = np.asarray(data)
    if not np.issubdtype(samples.dtype, np.inexact):
        # Squaring integers can wrap around silently; work in float64.
        samples = samples.astype(np.float64)
    squared_norms = np.sum(samples**2, axis=-1, keepdims=True)
    inv_norms = np.maximum(squared_norms, 10**-14)**-0.5
    return samples * inv_norms


def comnFun(K, sigma):
    """Return the weighted sum ``sum_i sigma[i] * K[i]`` of the kernels in ``K``.

    Args:
        K: sequence of square matrices; the first one fixes the output size.
        sigma: indexable collection of weights, one per entry of ``K``.

    Returns:
        The combined matrix, accumulated starting from a float zero matrix.
    """
    n_samples = K[0].shape[0]
    combined = np.zeros((n_samples, n_samples))
    for view, kernel in enumerate(K):
        combined = combined + sigma[view] * kernel
    return combined


def kcenter(K):
    """Centre a square kernel matrix and symmetrise the result.

    With ``d`` the vector of column means and ``e`` the mean of ``d``, the
    entry ``(i, j)`` becomes ``K[i, j] - d[i] - d[j] + e``; the outcome is
    then averaged with its transpose.
    """
    size = K.shape[0]
    col_means = np.sum(K, axis=0) / size
    grand_mean = np.sum(col_means) / size
    as_column = col_means.reshape([size, 1])
    centered = K - as_column - as_column.T + grand_mean * np.ones([size, size])
    return 0.5 * (centered + centered.T)


def kernel_regularization(K):
    """Project a square kernel tensor onto the positive semi-definite cone.

    The matrix is symmetrised, eigen-decomposed, eigenvalues that are not
    larger than ``1e-14`` are set to zero, and the matrix is rebuilt from the
    remaining spectrum.

    Fixes over the previous version:
      * ``K + K.T / 2`` (operator precedence) is now ``(K + K.T) / 2``; the
        old expression returned 1.5x the intended matrix for symmetric input.
      * ``numpy.linalg.eigh`` is used because the matrix is symmetric; it
        returns real eigenvalues and orthonormal eigenvectors.
      * Eigenvalue scaling uses broadcasting instead of multiplying by a
        sparse diagonal matrix.

    Args:
        K: square ``torch.Tensor``.

    Returns:
        ``torch.Tensor`` with the regularised kernel, moved to the
        module-level ``device``.
    """
    sym = ((K + K.T) / 2).detach().cpu().numpy()
    eigvals, eigvecs = np.linalg.eigh(sym)
    # Drop negative and numerically-zero eigenvalues.
    eigvals = eigvals * (eigvals > 10**-14)
    # (V * w) scales column j of V by w[j], i.e. V @ diag(w).
    psd = (eigvecs * eigvals) @ eigvecs.T
    # Remove round-off asymmetry.
    psd = (psd + psd.T) / 2
    return torch.tensor(psd).to(device)


def generateNeighborhood(X, k):
    """Pick, for every row of ``X``, the ``k + 1`` largest entries.

    The input is copied, so ``X`` itself is left untouched.

    Args:
        X: 2-D array; larger values are selected first.
        k: number of selection rounds minus one.

    Returns:
        Tuple ``(mask, knn)``. ``mask`` is a boolean array shaped like ``X``
        that is True at every selected position. ``knn`` is an ``int32``
        array of shape ``(rows, k + 1)`` with the selected column indices in
        selection order.
    """
    work = np.copy(X)
    row_ids = np.arange(work.shape[0])
    lowest = np.min(work)
    selected = []
    for _ in range(k + 1):
        best = np.argmax(work, axis=1)
        selected.append(best)
        # Overwrite the chosen entries with a value below the global minimum
        # so they are not chosen again and can be recognised afterwards.
        work[row_ids, best] = lowest - 1
    return work < lowest, np.array(selected, dtype=np.int32).T


def adapted_neighbor(X, min_k):
    """Grow each row's neighbour set until mutual neighbourhoods are large enough.

    In every round the largest not-yet-selected entry of each row is added
    to that row's neighbour set. Growth stops as soon as the average number
    of mutual neighbours (selected in both directions) exceeds ``min_k``, or
    once the round counter reaches ``max(int(0.02 * n), 5)``, or after
    ``int(0.5 * n)`` rounds, where ``n`` is the number of rows.

    Args:
        X: square 2-D array; larger values are selected first. Not modified.
        min_k: required average number of mutual neighbours.

    Returns:
        Boolean array shaped like ``X`` marking the selected (one-directional)
        neighbours.
    """
    work = np.copy(X)
    n_rows = work.shape[0]
    row_ids = np.arange(n_rows)
    lowest = np.min(work)
    for step in range(int(0.5 * n_rows)):
        # Mark the current best candidate of every row as selected.
        work[row_ids, np.argmax(work, axis=1)] = lowest - 1
        selected = work < lowest
        mutual_degree = np.sum(selected * selected.T, axis=1)
        if np.mean(mutual_degree) > min_k:
            print((np.sum(mutual_degree) / n_rows), step)
            return selected
        if step >= max(int(0.02 * n_rows), 5):
            # NOTE(review): the value printed here uses 0.01 while the test
            # above uses 0.02 - confirm which cap is intended.
            print((np.sum(mutual_degree) / n_rows), step,
                  max(int(0.01 * n_rows), 5))
            return selected
    return work < lowest


def knn_mean_X(X, knn):
    """Average the rows of ``X`` listed in ``knn``, ignoring non-positive entries in the count.

    ``X[knn]`` gathers the neighbour rows; along the neighbour axis the
    values are summed and divided by the number of strictly positive
    entries, with the divisor floored at one.
    """
    neighbours = X[knn]
    totals = np.sum(neighbours, axis=1)
    positive_counts = np.sum(neighbours > 0, axis=1)
    divisor = np.maximum(positive_counts, np.ones(X[0].shape))
    return totals / divisor


def compute_init_loss(K, view_K, X_bar, model):
    """Accumulate three loss terms over all views.

    Args:
        K: per-view target tensors.
        view_K: per-view tensors compared against ``K`` (sets the view count).
        X_bar: per-view tensors entering the ``log1p`` term.
        model: per-view modules exposing ``parameters()`` and a
            ``kernel_weights`` parameter that is excluded from the penalty.

    Returns:
        ``[loss0, loss1, loss_param]``: the summed mean-squared error between
        flattened ``view_K[i]`` and ``K[i]``; the summed
        ``log1p(X_bar[i]) + log1p(1 - X_bar[i])``; and the summed L1 norm of
        every other parameter divided by its largest dimension.
    """
    mse = nn.MSELoss()
    fit_loss = 0
    log_loss = 0
    weight_penalty = 0
    for view, approx in enumerate(view_K):
        net = model[view]
        for weights in net.parameters():
            if weights is not net.kernel_weights:
                weight_penalty += torch.norm(weights, p=1) / max(
                    weights.size())
        fit_loss = fit_loss + mse(approx.view(-1), K[view].view(-1))
        bar = X_bar[view]
        log_loss = log_loss + torch.sum(torch.log1p(bar) + torch.log1p(1 - bar))
    return [fit_loss, log_loss, weight_penalty]


def kernelkmeans(K, n_clusters):
    """Return eigenvectors of ``K`` for its ``n_clusters`` largest eigenvalues.

    Uses ``scipy.sparse.linalg.eigsh`` with ``which='LA'`` (largest
    algebraic); only the eigenvector matrix, one column per eigenvalue, is
    returned.
    """
    eigenpairs = eigsh(K, k=n_clusters, which='LA')
    return eigenpairs[1]


def data_sample(X, K, Smp_num):
    """Draw a random sample (with replacement) of rows and kernel sub-blocks.

    Changes over the previous version: ``torch.meshgrid`` is called with an
    explicit ``indexing='ij'`` (the behaviour it had implicitly; omitting the
    argument is deprecated), and the unused ``n_feature`` local is gone, so
    ``X[i]`` no longer has to be 2-D.

    Args:
        X: sequence of per-view tensors/arrays; ``X[0].shape[0]`` gives the
            number of samples and ``len(X)`` the number of views.
        K: per-view square kernels, indexable with torch index tensors.
        Smp_num: number of indices to draw.

    Returns:
        Tuple ``(Smp_X, Smp_K, Smp_index)``: per-view sampled rows
        ``K[i][Smp_index, :]``, per-view sampled sub-blocks
        ``K[i][Smp_index][:, Smp_index]``, and the drawn indices.
    """
    n_total = X[0].shape[0]
    Smp_index = torch.randint(0, n_total, [Smp_num])
    row_grid, col_grid = torch.meshgrid(Smp_index, Smp_index, indexing='ij')
    # NOTE(review): Smp_X is taken from K, not from X - kept as in the
    # original; confirm that kernel rows are the intended features.
    Smp_X = [K[i][Smp_index, :] for i in range(len(X))]
    Smp_K = [K[i][row_grid, col_grid] for i in range(len(X))]
    return Smp_X, Smp_K, Smp_index


def selfNN(X, options):
    """Build a boolean neighbour mask from per-row top values of ``X``.

    For every row, ``keep`` is the smaller of the number of strictly positive
    entries and ``min(4, int(n / options.n_clusters))``. The row's
    ``keep``-th largest value becomes a threshold. Because the threshold
    vector is broadcast along the last axis, entry ``(i, j)`` is compared
    with the threshold derived from row ``j``. The diagonal is always True.

    Args:
        X: tensor indexed as a square ``(n, n)`` matrix (presumably on the
            module-level ``device`` - verify against callers).
        options: object with an ``n_clusters`` attribute.

    Returns:
        Boolean tensor of shape ``(n, n)``.
    """
    size = X.shape[1]
    cap = min([4, int(size / options.n_clusters)])
    positives = torch.sum(X > 0, dim=1)
    keep = torch.minimum(
        torch.ones(size, dtype=torch.int32, device=device) * cap, positives)
    ascending = torch.sort(X, dim=1).values
    # Negative indexing picks the keep-th largest value of each row.
    threshold = ascending[torch.arange(size), -keep]
    above = X >= threshold
    identity = torch.eye(size, dtype=torch.int32, device=device)
    return (1. * above + identity) > 0


class connectivity(nn.Module):
    """Shared bias-free MLP encoder plus a learnable ``nSmp x nSmp`` matrix.

    ``options`` must provide ``nSmp`` (size of ``self.matrix``),
    ``layer_width_c`` (consecutive layer widths) and ``view_num`` (number of
    inputs processed by ``forward``).
    """

    def __init__(self, options):
        super().__init__()
        self.options = options
        self.activ = torch.nn.Sigmoid()
        self.activ1 = torch.nn.ReLU()
        self.activ2 = torch.nn.Tanh()
        n_samples = options.nSmp
        # Starts as an all-zero square parameter.
        self.matrix = torch.nn.Parameter(
            1 / n_samples *
            torch.zeros([n_samples, n_samples], dtype=torch.float32))
        widths = options.layer_width_c
        self.layerencode = nn.ModuleList([
            nn.Linear(fan_in, fan_out, bias=False, dtype=torch.float32)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ])
        # Re-initialise every layer after all of them have been created.
        for layer in self.layerencode:
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')

    def forward(self, knn_list):
        """Encode the first ``view_num`` entries of ``knn_list``.

        Every input is multiplied by each layer's transposed weight in turn,
        with a ReLU after every layer except the last.

        Returns:
            List with one encoded tensor per view.
        """
        final_depth = len(self.layerencode) - 1
        encoded = []
        for view in range(self.options.view_num):
            hidden = knn_list[view]
            for depth, layer in enumerate(self.layerencode):
                hidden = hidden @ layer.weight.T
                if depth < final_depth:
                    hidden = self.activ1(hidden)
            encoded.append(hidden)
        return encoded


def init_Kernel_train(y, view_K, options):
    knn_list = []
    optimizers = {}
    init_maxiter = options.maxIter
    # print('0', torch.cuda.memory_allocated() / 2**20)
    average_K = comnFun(view_K, np.ones(options.view_num) / options.view_num)
    np.savetxt('KNN_K.npy', average_K )
    for i in range(options.view_num):
        knn_list.append(
            torch.tensor(view_K[i], dtype=torch.float32, device=device))
    connect_kernel = connectivity(options).to(device)

    optimizers["optimizer_connectivity"] = optim.Adam(
        [i.weight for i in connect_kernel.layerencode], lr=1e-3)
    singular_value = [
        torch.ones(options.n_clusters,
                   dtype=torch.float32,
                   device=device,
                   requires_grad=True) for i in range(options.view_num)
    ]
    optimizers["optimizer_singular"] = optim.Adam(singular_value, lr=1e-3)

    best_label = np.zeros([3, options.nSmp], dtype=np.int32)
    best_accurancy = [0] * 3
    knn_temp_list = []
    knn_temp_max_list = []
    knn_temp_min_list = []
    # print('K', torch.cuda.memory_allocated() / 2**20)
    for i in range(options.view_num):
        logic_knn = torch.tensor(1. * options.knn[i],
                                 dtype=torch.float32,
                                 device=device)
        knn_temp_list.append(knn_list[i] * (1. *
                                            (logic_knn + logic_knn.T) > 0))
        # print('K' + str(i), torch.cuda.memory_allocated() / 2**20)
        knn_temp_min_list.append(knn_list[i] * (1. *
                                                (logic_knn * logic_knn.T) > 0))
        # print('K_min' + str(i), torch.cuda.memory_allocated() / 2**20)
        del logic_knn
        torch.cuda.empty_cache()
        D = ((torch.maximum(
            torch.abs(torch.sum(knn_temp_min_list[i], dim=1)),
            torch.ones(knn_temp_min_list[i].shape[0],
                       dtype=torch.float32).to(device) *
            1e-15)**-1)**.5).unsqueeze(1)
        knn_temp_min_list[i] = knn_temp_min_list[i] * D * D.T
        # print('K_min' + str(i), torch.cuda.memory_allocated() / 2**20)
        D = ((torch.maximum(
            torch.abs(torch.sum(knn_temp_list[i], dim=1)),
            torch.ones(knn_temp_list[i].shape[0],
                       dtype=torch.float32).to(device) *
            1e-15)**-1)**.5).unsqueeze(1)
        knn_temp_list[i] = knn_temp_list[i] * D * D.T
        # print('K' + str(i), torch.cuda.memory_allocated() / 2**20)
        knn_temp_max_list.append(knn_temp_list[i] @ knn_temp_list[i])
        # print('K_max' + str(i), torch.cuda.memory_allocated() / 2**20)
        knn_temp_max_list[i] = torch.maximum(
            torch.minimum(knn_temp_max_list[i] * 1. * (knn_temp_list[i] > 0),
                          knn_temp_list[i]), knn_temp_max_list[i])
        # print('K_max' + str(i), torch.cuda.memory_allocated() / 2**20)
        print(
            torch.sum(torch.sum(knn_temp_list[i] > 0, dim=1)) / options.nSmp,
            torch.sum(torch.sum(knn_temp_min_list[i] > 0, dim=1)) /
            options.nSmp,
            torch.sum(torch.sum(knn_temp_max_list[i] > 0, dim=1)) /
            options.nSmp)
        # print('3K' + str(i), torch.cuda.memory_allocated() / 2**20)
    del D, knn_list
    torch.cuda.empty_cache()
    if options.alpha == 0:
        del knn_temp_min_list
        torch.cuda.empty_cache()
    if options.beta == 0:
        del knn_temp_max_list
        torch.cuda.empty_cache()
    # print('init', torch.cuda.memory_allocated() / 2**20)
    old_loss = 1e15
    old_rank = 0
    old_inter_loss = -1e15
    object_function = np.zeros(options.maxIter - options.init_time)
    for epoch in range(init_maxiter):
        loss_init = 0
        loss_list = []
        reg_loss = 0
        start = time.time()
        X_view = connect_kernel(knn_temp_list)
        for j in range(len(options.layer_width_c) - 1):
            reg_loss += 0.5 * torch.norm(
                connect_kernel.layerencode[j].weight)**2
        print(reg_loss.item())
        # print('H', torch.cuda.memory_allocated() / 2**20)
        orx_loss = []
        connected_matrix = (connect_kernel.matrix +
                            connect_kernel.matrix.T) / 2
        # print('K_star', torch.cuda.memory_allocated() / 2**20)
        view_loss = []
        connect_loss = []
        view_loss_temp = []
        view_loss_temp_1 = []
        eta = []
        X_view_orth_grad = torch.zeros_like(X_view[0],
                                            device=device,
                                            dtype=torch.float32)
        X_view_grad = torch.zeros_like(X_view[0],
                                       device=device,
                                       dtype=torch.float32)
        print('forward_time', time.time() - start)
        for i in range(options.view_num):
            # X_view[i].retain_grad()
            start = time.time()
            connect_view_k = ((X_view[i] * singular_value[i]) @ X_view[i].T)
            # print('K_c', torch.cuda.memory_allocated() / 2**20)
            # orx_loss.append(
            #     torch.norm((or_matrix) - torch.eye(
            #         options.n_clusters, dtype=torch.float16).to(device)) +
            #     torch.norm(connect_view_k_power - connect_view_k) +
            #     torch.norm(connect_view_k - connect_view_k_pos))
            orx_loss.append(
                torch.norm(X_view[i].T @ X_view[i] - torch.eye(
                    options.n_clusters, dtype=torch.float32, device=device))**
                2)
            # print('orx_loss', torch.cuda.memory_allocated() / 2**20)
            X_view_orth_grad = 4 * (X_view[i] @ X_view[i].T @ X_view[i] -
                                    X_view[i])
            # print('X_view_orth_grad', torch.cuda.memory_allocated() / 2**20)
            # view_loss_temp_i = 1 / (1 + options.alpha +
            #                         options.beta) * torch.norm(
            #                             (knn_temp_list[i] - connect_view_k))**2

            # view_loss_temp.append(1 / (1 + options.alpha + options.beta) *
            #                       torch.norm(
            #                           (knn_temp_list[i] - connect_view_k))**2)
            view_loss_temp.append(
                1 / (1 + options.alpha + options.beta) *
                torch.trace(X_view[i].T @ (X_view[i] * singular_value[i])
                            @ X_view[i].T @ (X_view[i] * singular_value[i]) -
                            2 * X_view[i].T @ knn_temp_list[i] @ (
                                X_view[i] * singular_value[i])))
            # print('view_loss_temp', torch.cuda.memory_allocated() / 2**20)
            X_view_grad = 2 / (1 + options.alpha + options.beta) * (
                2 * connect_view_k - knn_temp_list[i] -
                knn_temp_list[i].T) @ (X_view[i] * singular_value[i])
            # print('X_view_grad', torch.cuda.memory_allocated() / 2**20)
            if options.alpha > 0:
                # view_loss_temp_1.append(
                #     options.alpha / (1 + options.alpha + options.beta) *
                #     torch.norm(knn_temp_min_list[i] - (
                #         (X_view[i] * singular_value[i]**0.5) @ X_view[i].T))**
                #     2)
                view_loss_temp_1.append(
                    options.alpha / (1 + options.alpha + options.beta) *
                    torch.trace(
                        X_view[i].T @ (X_view[i] * singular_value[i]**.5)
                        @ X_view[i].T @ (X_view[i] * singular_value[i]**.5) -
                        2 * X_view[i].T @ knn_temp_min_list[i] @ (
                            X_view[i] * singular_value[i]**.5)))
                # print('view_loss_temp_1',
                #       torch.cuda.memory_allocated() / 2**20)
                X_view_grad += 2 * options.alpha / (
                    1 + options.alpha + options.beta) * (
                        2 *
                        (X_view[i] * singular_value[i]**0.5) @ X_view[i].T -
                        knn_temp_min_list[i] - knn_temp_min_list[i].T) @ (
                            X_view[i] * singular_value[i]**.5)
            if options.beta > 0:
                view_loss_temp_1.append(
                    options.alpha / (1 + options.alpha + options.beta) *
                    torch.trace(
                        X_view[i].T @ (X_view[i] * singular_value[i]**2)
                        @ X_view[i].T @ (X_view[i] * singular_value[i]**2) -
                        2 * X_view[i].T @ knn_temp_max_list[i] @ (
                            X_view[i] * singular_value[i]**2)))

                X_view_grad += 2 * options.beta / (
                    1 + options.alpha + options.beta) * (
                        2 * (X_view[i] * singular_value[i]**2) @ X_view[i].T -
                        knn_temp_max_list[i] -
                        knn_temp_max_list[i].T) @ (X_view[i] *
                                                   (singular_value[i]**2))
            # print('before_connect' + str(i),
            #       torch.cuda.memory_allocated() / 2**20)
            print('view_time:', time.time() - start)
            if (epoch >= options.init_time):
                start = time.time()
                diag = (torch.diag(
                    connect_view_k +
                    torch.relu(connected_matrix)).detach().unsqueeze(1))**.5
                # print(diag.dtype)
                # print('diag', torch.cuda.memory_allocated() / 2**20)
                connectedness_matrix = (
                    (connect_view_k + torch.relu(connected_matrix)) / diag /
                    diag.T).detach().clone()
                del diag
                torch.cuda.empty_cache()
                connectedness = torch.sum(connectedness_matrix,
                                          dim=1).unsqueeze(1)

                neighbor_peak = torch.tensor(
                    1, device=device,
                    dtype=torch.float32) * (connectedness.repeat(
                        1, options.nSmp) < connectedness.T)

                torch.cuda.empty_cache()
                # print('neighbor_peak', torch.cuda.memory_allocated() / 2**20)
                neighbor_peak = neighbor_peak * (knn_temp_list[i] > 0)
                neighbor_peak_index = torch.where(
                    torch.sum(neighbor_peak, dim=1) == 0)[0]
                # print('neighbor_peak_index',
                #       torch.cuda.memory_allocated() / 2**20)
                parents_index = neighbor_peak * (connectedness_matrix)
                if len(neighbor_peak_index) < options.n_clusters:
                    all_index = torch.arange(options.nSmp,
                                             dtype=torch.int32,
                                             device=device)
                    rest_index = torch_intersect1d(all_index,
                                                   neighbor_peak_index)[1]
                    rest_index = (torch.sort(connectedness[rest_index])[1]
                                  [-options.n_clusters +
                                   len(neighbor_peak_index):]).squeeze(1)
                    neighbor_peak_index = torch.concatenate(
                        [neighbor_peak_index, rest_index], dim=0)
                    print("_______________", len(neighbor_peak_index),
                          len(rest_index))
                    del rest_index, all_index
                del neighbor_peak, connectedness_matrix, connectedness
                torch.cuda.empty_cache()
                parents_index[neighbor_peak_index,
                              neighbor_peak_index] = torch.ones(
                                  len(neighbor_peak_index),
                                  device=device,
                                  dtype=torch.float32)
                del neighbor_peak_index
                torch.cuda.empty_cache()
                # print('parents_index', torch.cuda.memory_allocated() / 2**20)
                D = ((torch.maximum(
                    torch.abs(torch.sum(parents_index, dim=1)),
                    torch.ones(knn_temp_list[i].shape[0],
                               device=device,
                               dtype=torch.float32) *
                    1e-15))).detach().unsqueeze(1)
                # print('D', torch.cuda.memory_allocated() / 2**20)
                parents_index = (parents_index / D)
                print('peak_time:', time.time() - start)
                # connect_loss.append(
                #     torch.norm(parents_index - (X_view[i] @ X_view[i].T))**2)
                connect_loss.append(1 * torch.trace(
                    X_view[i].T @ X_view[i] @ X_view[i].T @ X_view[i] -
                    X_view[i].T @ (parents_index + parents_index.T) @ X_view[i]
                ))
                X_view_grad += 1 * 2 * options.gamma * (
                    2 * (X_view[i] @ X_view[i].T) - parents_index -
                    parents_index.T) @ X_view[i]
                del parents_index, D
                torch.cuda.empty_cache()
                connect_loss.append(
                    options.con * (1 + options.gamma) / options.gamma *
                    torch.trace(
                        X_view[i].T @ (X_view[i] * singular_value[i])
                        @ X_view[i].T @ (X_view[i] * singular_value[i]) -
                        2 * X_view[i].T @ connected_matrix.detach().clone() @ (
                            X_view[i] * singular_value[i])))
                X_view_grad += 4 * options.con * (1 + options.gamma) * (
                    connect_view_k - connected_matrix) @ (X_view[i] *
                                                          singular_value[i])

                if torch.trace(connect_view_k) > options.n_clusters:
                    connect_loss.append(
                        (1 + options.gamma) / options.gamma *
                        (torch.trace(
                            X_view[i].T @ (X_view[i] * singular_value[i])) -
                         options.n_clusters))
                    X_view_grad += 2 * (1 + options.gamma) * (
                        X_view[i] * singular_value[i])
                elif torch.trace(connect_view_k) < options.n_clusters:
                    connect_loss.append(
                        (1 + options.gamma) / options.gamma *
                        (options.n_clusters - torch.trace(
                            X_view[i].T @ (X_view[i] * singular_value[i]))))
                    X_view_grad -= 2 * (1 + options.gamma) * (
                        X_view[i] * singular_value[i])
                print("peak", torch.cuda.memory_allocated() / 2**20)
                start = time.time()
                norm_1 = torch.norm(X_view_grad, dim=0).detach().clone()
                norm_2 = torch.norm(X_view_orth_grad, dim=0).detach().clone()
                cos_value = torch.mean(
                    torch.abs(torch.diag(X_view_orth_grad.T @ X_view_grad)) /
                    norm_1 / norm_2).item()
                if cos_value > options.eta:
                    eta.append(
                        min(options.eta / cos_value * (1 + options.gamma),
                            min(norm_1 / norm_2)))
                else:
                    eta.append(min(1 + options.gamma, min(norm_1 / norm_2)))
                print('eta_time:', time.time() - start)
        del X_view_grad, X_view_orth_grad
        torch.cuda.empty_cache()
        view_loss.append(torch.stack(view_loss_temp))
        if len(view_loss_temp_1) > 0:
            view_loss.append(torch.stack(view_loss_temp_1))
        del view_loss_temp, view_loss_temp_1
        torch.cuda.empty_cache()
        start = time.time()
        orx_loss = torch.stack(orx_loss)
        if len(connect_loss) > 0:
            eta = torch.tensor(eta, device=device, dtype=torch.float32)
            view_loss = torch.stack(view_loss)
            connect_loss = torch.stack(connect_loss)
            if torch.isnan(torch.sum(view_loss) + torch.sum(connect_loss)):
                print(view_loss, connect_loss)
                break
            loss_init = torch.sum(eta * orx_loss) + torch.sum(
                view_loss) + options.gamma * torch.sum(connect_loss)
            loss_list.append([
                orx_loss,
                view_loss,
                connect_loss,
            ])

            # for optimizer in optimizers.values():
            #     optimizer.zero_grad()
            if ((torch.sum(orx_loss)
                 < 1e-3 * options.view_num * options.n_clusters) and
                (abs(old_loss.item() - loss_init.item()) < abs(
                    loss_init.item()) * 1e-5)) or epoch == options.maxIter - 1:
                print(torch.sum(orx_loss),
                      1e-3 * options.view_num * options.n_clusters**2)
                H = kernelkmeans((connected_matrix).detach().to(
                    torch.float32).cpu().numpy(), options.n_clusters)
                # H_normalized = H
                H_normalized = H / (np.sum(H**2, axis=1)**0.5).reshape(
                    [H.shape[0], 1])
                kmeans_model = KMeans(n_clusters=options.n_clusters,
                                      n_init='auto')

                repeat = 50
                best_inertia = np.zeros([3, repeat])
                for rep in range(repeat):
                    kmeans = kmeans_model.fit(H_normalized)
                    try:
                        y_prep = error_point(kmeans.labels_, y)[1]
                    except Exception as e:
                        y_prep = kmeans.labels_
                    y_prep = y_prep.astype('int')
                    ari = adjusted_rand_score(y, y_prep)
                    nmi = normalized_mutual_info_score(y, y_prep)
                    acc = accuracy_score(y, y_prep)
                    if ari > best_accurancy[0]:
                        best_label[0, :] = y_prep
                    if nmi > best_accurancy[1]:
                        best_label[1, :] = y_prep
                    if acc > best_accurancy[2]:
                        best_label[2, :] = y_prep
                    best_inertia[:, rep] = np.array([ari, nmi, acc])
                print(np.std(best_inertia, axis=1), np.max(best_inertia,
                                                           axis=1),
                      np.min(best_inertia, axis=1))
                best_inertia = (np.max(best_inertia, axis=1)).tolist()
                best_accurancy = np.maximum(best_inertia,
                                            best_accurancy).tolist()

                return best_accurancy, best_label, best_inertia, np.array(
                    object_function)
        else:
            view_loss = torch.stack(view_loss)
            loss_init = torch.sum(view_loss)
            loss_list.append([orx_loss, view_loss])
            connect_kernel.matrix = torch.nn.Parameter(
                sum([((X_view[i] * singular_value[i])
                      @ X_view[i].T).detach().clone()
                     for i in range(options.view_num)]) / options.view_num)
            # print('K_star', torch.cuda.memory_allocated() / 2**20)
            optimizers["optimizer_matrix"] = optim.Adam(
                [connect_kernel.matrix], lr=1e-3)
        if torch.isnan(loss_init):
            print(loss_list)
            print((torch.diag(connect_view_k + connected_matrix)**0.5
                   ).unsqueeze(1).detach().clone(), )
            break

        # back_time = time.time()
        # print("back_time---------------------", time.time() - back_time)
        (loss_init + options.regular * reg_loss).backward()
        if len(connect_loss) > 0:
            object_function[epoch - options.init_time] = loss_init.item(
            ) + options.view_num * options.gamma * torch.trace(
                connected_matrix @ connected_matrix).item()
            # back_time1 = time.time()
            # torch.sum(view_loss).backward(retain_graph=True)
            # print(time.time() - back_time1)
            # back_time1 = time.time()
            # torch.sum(eta * orx_loss).backward(retain_graph=True)
            # print(time.time() - back_time1)
            # back_time1 = time.time()
            # (options.gamma *
            #  torch.sum(connect_loss)).backward(retain_graph=False)
            # print(time.time() - back_time1)
            connect_kernel.matrix.grad = options.con * (
                1 + options.gamma) * options.view_num * (
                    4 *
                    (connected_matrix @ connected_matrix @ connected_matrix) -
                    6 * (connected_matrix @ connected_matrix) +
                    2 * connected_matrix).detach().clone()
            for i in range(options.view_num):
                connect_kernel.matrix.grad += options.con * (
                    1 + options.gamma) * (2 * connected_matrix - 2 *
                                          (X_view[i] * singular_value[i])
                                          @ X_view[i].T).detach().clone()
            # loss_init.backward(retain_graph=False)
        # else:
        #     torch.sum(view_loss).backward(retain_graph=False)

        # back_time = time.time()
        # for i in range(options.view_num):
        #     for j in range(len(options.layer_width_c) - 1):
        #         connect_kernel.layerencode[j].weight.grad = torch.autograd.grad(loss_init, connect_kernel.layerencode[j].weight, retain_graph=True)[0]
        #     singular_value[i].grad = torch.autograd.grad(loss_init, singular_value[i], retain_graph=True)[0]
        # if len(connect_loss)>0:
        #     connect_kernel.matrix.grad = torch.autograd.grad(loss_init, connect_kernel.matrix, retain_graph=False)[0]
        print("back_time", time.time() - start)
        old_loss = loss_init
        if (epoch) % 25 == 0:
            print(epoch, ":")
            print('total_loss:\n', loss_init.item())
            print('one_loss:\n', loss_list)
            print("connectivity_matrix:")
            new_rank = torch.trace(connected_matrix).item() / torch.mean(
                torch.stack(singular_value)).item()
            print(new_rank)
            print(singular_value)
            print('---------------------')
        if ((epoch) % 25 == 0) and epoch != 0:
            if epoch == options.init_time:
                np.savetxt('consensus_kernel_before.npy',
                           (connected_matrix).detach().to(
                               torch.float32).cpu().numpy())
            if new_rank > 0.9 * old_rank and epoch >= options.init_time:
                start = time.time()
                np.savetxt('consensus_kernel_after.npy',
                           (connected_matrix).detach().to(
                               torch.float32).cpu().numpy())
                old_rank = min(new_rank, options.n_clusters)
                H = kernelkmeans((connected_matrix).detach().to(
                    torch.float32).cpu().numpy(), options.n_clusters)
                H_normalized = H / (np.sum(H**2, axis=1)**0.5).reshape(
                    [H.shape[0], 1])
                kmeans_model = KMeans(n_clusters=options.n_clusters,
                                      n_init='auto')

                repeat = 50
                best_inertia = np.zeros([3, repeat])
                for rep in range(repeat):
                    kmeans = kmeans_model.fit(H_normalized)
                    try:
                        y_prep = error_point(kmeans.labels_, y)[1]
                    except Exception as e:
                        y_prep = kmeans.labels_
                    y_prep = y_prep.astype('int')
                    ari = adjusted_rand_score(y, y_prep)
                    nmi = normalized_mutual_info_score(y, y_prep)
                    acc = accuracy_score(y, y_prep)
                    if ari > best_accurancy[0]:
                        best_label[0, :] = y_prep
                    if nmi > best_accurancy[1]:
                        best_label[1, :] = y_prep
                    if acc > best_accurancy[2]:
                        best_label[2, :] = y_prep
                    best_inertia[:, rep] = np.array([ari, nmi, acc])
                print(np.std(best_inertia, axis=1), np.max(best_inertia,
                                                           axis=1),
                      np.min(best_inertia, axis=1),
                      np.mean(best_inertia, axis=1))
                best_inertia = (np.max(best_inertia, axis=1)).tolist()
                best_accurancy = np.maximum(best_inertia,
                                            best_accurancy).tolist()

                print("total", best_inertia, best_accurancy)
                print('clustering_time', time.time() - start)
            elif new_rank < 0.9 * old_rank and old_loss > old_inter_loss:
                return best_accurancy, best_label, best_inertia, np.array(
                    object_function)

            for i in range(options.view_num):
                ari = adjusted_rand_score(
                    y,
                    np.argmax(X_view[i].detach().to(
                        torch.float32).cpu().numpy(),
                              axis=1))
                nmi = normalized_mutual_info_score(
                    y,
                    np.argmax(X_view[i].detach().to(
                        torch.float32).cpu().numpy(),
                              axis=1))
                print("view_" + str(i), ari, nmi)

            print('---------------------')
        for optimizer_key, optimizer_value in zip(optimizers.keys(),
                                                  optimizers.values()):
            optimizer_value.step()
            optimizer_value.zero_grad()
        connect_kernel.matrix.grad = None
        torch.cuda.empty_cache()
    return best_accurancy, best_label, best_inertia, np.array(object_function)

cuda:0
['CPLEX', 'CVXOPT', 'ECOS', 'ECOS_BB', 'GLPK', 'GLPK_MI', 'OSQP', 'SCIPY', 'SCS']


In [2]:
import sys
import os


class HiddenPrints:
    """Context manager that silences everything written to ``sys.stdout``.

    On entry ``sys.stdout`` is redirected to ``os.devnull``; on exit the
    stream that was active at entry is put back. Exceptions raised inside
    the ``with`` body are not suppressed.

    Usage::

        with HiddenPrints():
            noisy_function()
    """

    def __enter__(self):
        """Redirect ``sys.stdout`` to the null device and return ``self``."""
        self._original_stdout = sys.stdout
        # Keep our own handle so __exit__ closes exactly the file opened
        # here, even if the body rebinds sys.stdout to something else.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the null-device handle and restore the saved stdout."""
        try:
            self._devnull.close()
        finally:
            # Restore stdout even if closing the handle fails.
            sys.stdout = self._original_stdout


# ---- dataset -------------------------------------------------------------
datasets_path = "./datasets"
file_name = "YALE.mat"
data = load_data(datasets_path + '/' + file_name)
torch.cuda.empty_cache()

# ---- problem sizes read from the loaded data -------------------------------
n_clusters = len(np.unique(data.y))  # number of distinct labels in data.y
view_num = len(data.x)  # number of entries (views) in data.x
nSmp = data.x[0].shape[0]  # row count of the first view

# ---- settings object: a plain attribute bag --------------------------------
EmptyStruct = type('EmptyStruct', (), {})
options = EmptyStruct()
options.n_clusters = n_clusters
options.view_num = view_num
options.nSmp = nSmp
# Middle layer width: min(floor(sqrt(nSmp)), floor(nSmp / n_clusters)),
# but never smaller than n_clusters.
hidden_width = max(min(int(nSmp**0.5), int(nSmp / n_clusters)), n_clusters)
options.layer_width_c = [nSmp, hidden_width, n_clusters]
options.init_time = 300
options.maxIter = 600
options.threshold = 0.45
options.regular = 1e-2
options.con = 1
options.min_k = 3
options.knn = []
Sigma = np.ones(view_num) / view_num  # uniform weight vector, one per view
print(view_num, data.x[0].shape)

# ---- one kernel matrix (and one adapted_neighbor result) per view ----------
K = []
normal_X = []
for i in range(view_num):
    if file_name[-5] == 'K':
        # File names ending in "K.<3-char ext>": the view is passed through
        # kcenter (helper defined elsewhere) and then rescaled.
        TempK = kcenter(data.x[i])
        # Divide by sqrt(diag) on rows and columns -> unit diagonal.
        diag_root = (np.diag(TempK)**0.5).reshape([options.nSmp, 1])
        TempK = TempK / diag_root / diag_root.T
        # |row sum|^-1/2 scaling applied on both sides before the
        # eigen-decomposition.
        TempD = np.abs(np.sum(TempK, axis=1, keepdims=True))**-.5
        V, D = eigsh(TempK * TempD * TempD.T, k=options.n_clusters, which='LA')
        print(i, V)
        # Scale by the last returned eigenvalue (presumably the largest --
        # verify against the eigsh ordering).
        K.append(TempK / V[-1])
        options.knn.append(adapted_neighbor(TempK / V[-1], options.min_k))
        del TempK, diag_root
        continue

    if data.x[i].shape[1] == data.x[i].shape[0]:
        # Square view: used directly as a pairwise distance matrix.
        print('distance_data')
        TempK = data.x[i]
    else:
        # Feature view: normalise, then take Euclidean pairwise distances.
        temp_x = NormalizeData(data.x[i])
        TempK = pairwise_distances(temp_x, metric='euclidean')
        del temp_x

    # Gaussian kernel exp(-d^2 / (2 t^2)) with t set to the mean distance.
    t = np.mean(np.mean(TempK))
    TempK = np.exp(-TempK**2 / (2 * t**2))
    K.append(TempK)
    options.knn.append(adapted_neighbor(TempK, options.min_k))
    del TempK

2 (165, 1024)
3.1333333333333333 4
3.4 4


In [None]:
# Hyper-parameter sweep over (alpha, beta, gamma) with eta fixed.
#
# For every setting init_Kernel_train is run once; its two 3-element result
# lists are stored side by side in one row of `final_accurancy`, its loss
# curve in the matching row of `loss_value`, and the returned labels in a
# per-setting file. Both tables are re-saved after every setting so an
# interrupted sweep keeps its finished rows.
options.alpha = 1
options.beta = 0
options.eta = 0.05

# Each entry is (alpha, beta, gamma, hide_output).
# Stage 1: alpha=1, beta=0, gamma = 2^-10 .. 2^10, training output shown.
sweep = [(1, 0, 2**num, False) for num in range(-10, 11, 1)]
# Stage 2: beta=0, alpha = 2^-2 .. 2^2, gamma = 10^-2 .. 10^2, output hidden.
sweep += [(2**num, 0, 10**num_1, True)
          for num in range(-2, 3, 1)
          for num_1 in range(-2, 3, 1)]
# Stage 3: alpha=0, beta = 2^-2 .. 2^2, gamma = 10^-2 .. 10^2, output hidden.
sweep += [(0, 2**num, 10**num_1, True)
          for num in range(-2, 3, 1)
          for num_1 in range(-2, 3, 1)]

n_settings = len(sweep)  # 21 + 25 + 25 = 71
n_recorded_epochs = options.maxIter - options.init_time

# np.savetxt does not create missing directories; make sure the output
# folder exists before any training time is spent.
os.makedirs('results', exist_ok=True)
result_prefix = 'results/' + file_name[:-4]
# NOTE: despite the ".npy" suffix these files are plain text (np.savetxt).
summary_path = result_prefix + "_" + str(options.eta) + ".npy"
loss_path = result_prefix + "loss_value__" + str(options.eta) + ".npy"

final_accurancy = np.zeros([n_settings, 6])
loss_value = np.zeros([n_settings, n_recorded_epochs])
# To resume an interrupted sweep, load the previously saved tables instead:
# final_accurancy = np.loadtxt(summary_path)
# loss_value = np.loadtxt(loss_path)
# To force a setting to be re-run, zero its row, e.g.:
# final_accurancy[np.array([25]), :] = 0

# Pad tables loaded from a shorter, earlier sweep up to the current size.
if final_accurancy.shape[0] < n_settings:
    final_accurancy = np.concatenate([
        final_accurancy,
        np.zeros([n_settings - final_accurancy.shape[0], 6])
    ], axis=0)
if loss_value.shape[0] < n_settings:
    loss_value = np.concatenate([
        loss_value,
        np.zeros([n_settings - loss_value.shape[0], n_recorded_epochs])
    ], axis=0)

t = 0  # row index of the current setting in both tables
for alpha_val, beta_val, gamma_val, hide_output in sweep:
    options.alpha = alpha_val
    options.beta = beta_val
    options.gamma = gamma_val
    torch.cuda.empty_cache()
    # An all-zero row marks a setting that has not been run yet.
    if np.sum(final_accurancy[t, :]) == 0:
        if hide_output:
            with HiddenPrints():
                best, best_label, last, loss = init_Kernel_train(
                    data.y, K, options)
        else:
            best, best_label, last, loss = init_Kernel_train(
                data.y, K, options)
        print(best, last)
        best.extend(last)
        final_accurancy[t, :] = np.array(best)
        loss_value[t, :] = loss
        np.savetxt(summary_path, final_accurancy)
        np.savetxt(loss_path, loss_value)
        label_path = (result_prefix + "_" + "alpha" + str(options.alpha) +
                      "beta" + str(options.beta) + "gamma" +
                      str(options.gamma) + "eta" + str(options.eta) + ".npy")
        np.savetxt(label_path, best_label)
    t += 1
    torch.cuda.empty_cache()