<a href="https://colab.research.google.com/github/yl3394/GMMDA/blob/main/gmmda_coauthor_physics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GMMDA for Coauthor-Physics Dataset
## Section 1: Initialization Section for GNN
This section defines the functions and classes needed to train and evaluate the GMM-based auxiliary-learning GCN for the graph data augmentation task. 

In [None]:
import os
import torch
# The TORCH environment variable is expanded by the shell in the wheel-index
# URLs below (`torch-${TORCH}.html`).
# Alternative kept for reference: derive the tag from the installed torch.
#os.environ['TORCH'] = torch.__version__
os.environ['TORCH'] = '1.13.0+cu116' # now the system is 1.13.1+cu116, but no precompiled wheel of it
# Show the torch version actually installed, for comparison with the pinned tag.
print(torch.__version__)

# Install the PyG companion packages from the wheel index selected above.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# NOTE(review): this installs PyTorch Geometric from the unpinned git default
# branch, so the installed version depends on when the cell is run.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git


In [None]:
#@title Connect to Google Drive 
ds = 'coauthor-physics'
# If model weights should be saved directly in google drive (takes around 4-5 GB).
save_to_gdrive = True 
if save_to_gdrive:
    from google.colab import drive
    drive.mount('/content/drive')

# Enter the directory name to save model at.
OUTPUT_DIR = f"gda_dnml/{ds}" 
if save_to_gdrive:
    OUTPUT_DIR = "/content/drive/MyDrive/" + OUTPUT_DIR
else:
    OUTPUT_DIR = "/content/" + OUTPUT_DIR

print(f"[*] Weights will be saved at {OUTPUT_DIR}")
!mkdir -p $OUTPUT_DIR

DATA_DIR = OUTPUT_DIR + '/data'
print(f"[*] Data will be saved at {DATA_DIR}")
!mkdir -p $DATA_DIR

MODEL_DIR = OUTPUT_DIR + '/model'
print(f"[*] Model will be saved at {DATA_DIR}")
!mkdir -p $MODEL_DIR

EXP_RESULT_DIR = OUTPUT_DIR + '/output'
print(f"[*] Experimental results will be saved at {EXP_RESULT_DIR}")
!mkdir -p $EXP_RESULT_DIR

FIGURE_DIR = OUTPUT_DIR + '/figure'
print(f"[*] Image will be saved at {FIGURE_DIR}")
!mkdir -p $FIGURE_DIR

FIGURE_ABLATION_DIR = FIGURE_DIR + '/ablation_analysis'
print(f"[*] Image of ablation analysis will be saved at {FIGURE_ABLATION_DIR}")
!mkdir -p $FIGURE_ABLATION_DIR

# make sub-directories for analysis 
c_analysis_dir = FIGURE_ABLATION_DIR + '/candidate_analysis'
tsne_dnml_dir = c_analysis_dir + '/dnml_tsne_best_emb'
!mkdir -p $c_analysis_dir
!mkdir -p $tsne_dnml_dir


In [None]:
#@title Load Data to Get Number of Components
import pickle 

# Read from DATA_DIR so the location follows `save_to_gdrive`; with
# save_to_gdrive=True this is the same Drive path that used to be hard-coded.
# WARNING: pickle.load can execute arbitrary code stored in the file. Only load
# files that you produced yourself.
with open(f"{DATA_DIR}/{ds}-preprocessed-data", 'rb') as fp:
    data = pickle.load(fp)

with open(f"{DATA_DIR}/{ds}-preprocessed-dataset", 'rb') as fp:
    dataset = pickle.load(fp)

# The number of classes is also used as the number of GMM components below.
num_classes = dataset.num_classes

# NOTE(review): the "augmented" and "original" figures are read from the same
# masks, so each pair always prints the same number twice - confirm which
# object was meant to supply the "original" sizes.
print(f"""{ds}:
  Size of augmented training set: {data.train_mask.sum()}, size of original training set: {data.train_mask.sum()}; 
  Size of augmented test set: {data.test_mask.sum()}, size of original test set: {data.test_mask.sum()};
  Size of augmented validation set: {data.val_mask.sum()}, size of original validation set: {data.val_mask.sum()}.
""")

In [None]:
#@title GNN Code 
##### import sys
import math
import copy
import random
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from scipy.sparse import csr_matrix
import gc 

class GNN(object):
    """Graph Neural Networks that can be easily called and used.

    Authors of this code package:
    Tong Zhao, tzhao2@nd.edu
    Tianwen Jiang, twjiang@ir.hit.edu.cn

    Last updated: 11/25/2019

    The constructor builds a ``DataLoader``, a ``GNN_model`` encoder, a
    ``Classification`` head (unless ``learn_method == 'unsup'``) and a
    ``GaussianMixtureModel`` whose number of components is read from the
    module-level ``num_classes``.

    Parameters
    ----------
    adj_matrix: scipy.sparse.csr_matrix
        The adjacency matrix of the graph, where nonzero entries indicate edges.
        The value of each nonzero entry is the number of edges between the two nodes.

    features: numpy.ndarray, optional
        The 2-dimensional array that stores the raw feature of each node, where the
        i-th row is the raw feature vector of node i.
        When raw features are not given, one-hot degree features are used.

    labels: list or 1-D numpy.ndarray, optional
        The class label of each node. Required unless ``learn_method == 'unsup'``.

    learn_method: {'unsup', 'sup', 'aux'}, default 'unsup'
        Whether to use supervised learning, unsupervised learning or
        supervised+unsupervised (auxiliary) learning.

    model: {'gat', 'gcn', 'graphsage'}, default 'gat'
        The GNN model to be used. Any value other than 'gat' or 'gcn' selects GraphSAGE.
        - 'graphsage' is GraphSAGE: https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
        - 'gat' is graph attention network: https://arxiv.org/pdf/1710.10903.pdf

    n_layer: int, optional, default 2
        Number of layers in the GNN

    emb_size: int, optional, default 64
        Size of the node embeddings to be learnt

    random_state: int, optional, default 1234
        Random seed (applied to ``random``, ``numpy`` and ``torch``)

    device: {'cpu', 'cuda', 'auto'}, default 'auto'
        The device to use. 'auto' picks CUDA when it is available.

    epochs: int, optional, default 5
        Number of epochs for training

    batch_size: int, optional, default 20
        Number of nodes per batch for training

    lr: float, optional, default 0.7
        Learning rate

    unsup_loss_type: {'margin', 'normal'}, default 'margin'
        Loss function to be used for unsupervised learning
        - 'margin' is a hinge loss with margin of 3
        - 'normal' is the unsupervised loss function described in the paper of GraphSAGE

    print_progress: bool, optional, default True
        Whether to print the training progress
    """
    def __init__(self, 
                 adj_matrix, 
                 features=None, 
                 labels=None, 
                 learn_method = 'unsup',
                 model='gat', 
                 n_layer=2, 
                 emb_size=64, 
                 random_state=1234, 
                 device='auto', 
                 epochs=5, 
                 batch_size=20, 
                 lr=0.7, 
                 unsup_loss_type='margin', 
                 print_progress=True):
        super(GNN, self).__init__()
        # fix random seeds
        random.seed(random_state)
        np.random.seed(random_state)
        torch.manual_seed(random_state)
        torch.cuda.manual_seed_all(random_state)
        # set parameters
        self.learn_method = learn_method
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.unsup_loss_type = unsup_loss_type
        self.print_progress = print_progress
        self.gat = False
        self.gcn = False
        if model == 'gat':
            self.gat = True
            self.model_name = 'GAT'
        elif model == 'gcn':
            self.gcn = True
            self.model_name = 'GCN'
        else:
            self.model_name = 'GraphSAGE'
        # set device
        if device == 'auto':
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # load data
        self.dl = DataLoader(adj_matrix, features, labels, learn_method, self.device)

        self.gnn = GNN_model(n_layer, emb_size, self.dl, self.device, gat=self.gat, gcn=self.gcn)
        self.gnn.to(self.device)

        # the classification head only exists when labels are used
        if learn_method != 'unsup':
            n_classes = len(set(labels))
            self.classification = Classification(emb_size, n_classes)
            self.classification.to(self.device)

        # the number of GMM components comes from the module-level `num_classes`
        print(f'K has been set to = {num_classes}')
        self.gmm = GaussianMixtureModel(n_components = num_classes)
        self.gmm.to(self.device)

    def _clear_cache(self):
        """Release cached CUDA memory (when CUDA is available) and run the garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def fit(self):
        """Train the encoder, the classifier and the GMM.

        Stages, in order:
        1. 200 full-batch pre-training steps on the random-walk margin loss plus
           the supervised loss.
        2. Per epoch: fit the GMM to the current training embeddings with SGD,
           stopping early if a mixture weight becomes negative.
        3. Per epoch: mini-batch updates of the encoder (and classifier) on the
           task loss plus the GMM loss computed in stage 2.

        NOTE(review): stage 1 always uses `unsup_loss`, `num_neg`, `labels` and
        `self.classification`. `unsup_loss`/`num_neg` are only created for
        'aux' and 'unsup', and `labels`/`self.classification` only exist for
        'sup' and 'aux', so only learn_method='aux' gets past pre-training.
        `num_neg` is also left undefined when `unsup_loss_type` is neither
        'margin' nor 'normal'.
        """
        train_nodes = copy.deepcopy(self.dl.nodes_train)

        if self.learn_method == 'sup':
            # supervised learning
            labels = self.dl.labels
            models = [self.gnn, self.classification]
        elif self.learn_method == 'aux':
            # supervised learning
            labels = self.dl.labels
            models = [self.gnn, self.classification]
            # unsupervised learning
            unsup_loss = Unsup_Loss(self.dl, self.device)
            if self.unsup_loss_type == 'margin':
                num_neg = 6
            elif self.unsup_loss_type == 'normal':
                num_neg = 100
        else:
            unsup_loss = Unsup_Loss(self.dl, self.device)
            models = [self.gnn]
            if self.unsup_loss_type == 'margin':
                num_neg = 6
            elif self.unsup_loss_type == 'normal':
                num_neg = 100

        print(f"Initializing a GNN... with shape {train_nodes.shape}")
        nodes_batch_pretrain = train_nodes
        # extend nodes batch for unsupervised learning
        nodes_batch_pretrain = np.asarray(list(unsup_loss.extend_nodes(nodes_batch_pretrain, num_neg=num_neg)))
        gnn_pretrain_optimizer = torch.optim.SGD(self.gnn.parameters(), lr=self.lr)
        classification_pretrain_optimizer = torch.optim.SGD(self.classification.parameters(), lr=self.lr)

        # pre-train (O_1): full-batch steps on margin loss + supervised loss
        for step in range(200):
            self._clear_cache()

            # feed nodes batch to the GNN and get the node embeddings
            embs_batch_pretrain = self.gnn(nodes_batch_pretrain)

            loss_net_pretrain = unsup_loss.get_loss_margin(embs_batch_pretrain, nodes_batch_pretrain)

            # negative log-likelihood of the true class, averaged over the batch
            logists_pretrain = self.classification(embs_batch_pretrain)
            labels_batch = labels[nodes_batch_pretrain]
            loss_sup_pretrain = -torch.sum(logists_pretrain[range(logists_pretrain.size(0)), labels_batch], 0)
            loss_sup_pretrain /= len(nodes_batch_pretrain)

            loss_pretrain = loss_sup_pretrain + loss_net_pretrain

            self._clear_cache()

            loss_pretrain.backward(retain_graph=True)
            if step % 50 == 0:
                print(f"GNN pre-training step {step}th iteration's loss = {loss_net_pretrain}")
            gnn_pretrain_optimizer.step()
            gnn_pretrain_optimizer.zero_grad()

            classification_pretrain_optimizer.step()
            classification_pretrain_optimizer.zero_grad()

        print(f'outputing embedding pretrain with shape {embs_batch_pretrain.shape}')

        # number of passes of the inner I_3 loop (also shown in the progress bar)
        i3_iterations = 300

        for epoch in range(self.epochs):
            np.random.shuffle(train_nodes)

            # fit the GMM to the current embeddings of all training nodes
            embs_all = self.gnn(train_nodes)
            gmm_optimizer = torch.optim.SGD(self.gmm.parameters(), lr=0.0000001, momentum=0.9)

            # I_2
            for gmm_step in range(5_00):
                gmm_pretrain_loss = self.gmm(embs_all)
                # embs_all is reused on every step, so its graph must be retained
                gmm_pretrain_loss.backward(retain_graph=True)
                if gmm_step % 500 == 0:
                    print(f"GMM training step {gmm_step}th iteration's loss = {gmm_pretrain_loss}")
                gmm_optimizer.step()
                # the mixture weights are used as Categorical probabilities,
                # so stop as soon as one of them turns negative
                if torch.all(self.gmm._get_weights() >= 0):
                    gmm_optimizer.zero_grad()
                else:
                    break
                self._clear_cache()

            weights, means, stdevs = self.gmm._get_parameters()

            # update GMM parameters
            self.gmm._update_parameters(weights, means, stdevs)
            loss_gmm = self.gmm(embs_all)
            loss_gmm /= len(train_nodes)

            # collect every trainable parameter of the models being optimised
            params = []
            for model in models:
                for param in model.parameters():
                    if param.requires_grad:
                        params.append(param)
            optimizer = torch.optim.SGD(params, lr=self.lr)
            optimizer.zero_grad()
            for model in models:
                model.zero_grad()

            batches = math.ceil(len(train_nodes) / self.batch_size)
            visited_nodes = set()
            if self.print_progress:
                tqdm_bar = tqdm(range(batches), ascii=True, leave=False)
            else:
                tqdm_bar = range(batches)

            # I_3
            # NOTE(review): `visited_nodes` is not reset between I_3 passes, so
            # for learn_method != 'sup' every pass after the one that touches
            # all training nodes stops at its first batch.
            for i3 in range(i3_iterations):
                for index in tqdm_bar:
                    if self.learn_method != 'sup' and len(visited_nodes) == len(train_nodes):
                        # finish this epoch if all nodes are visited
                        if self.print_progress:
                            tqdm_bar.close()
                        break
                    nodes_batch = train_nodes[index*self.batch_size:(index+1)*self.batch_size]
                    # extend nodes batch for unsupervised learning
                    if self.learn_method != 'sup':
                        nodes_batch = np.asarray(list(unsup_loss.extend_nodes(nodes_batch, num_neg=num_neg)))
                    visited_nodes |= set(nodes_batch)
                    # feed nodes batch to the GNN and get the node embeddings
                    embs_batch = self.gnn(nodes_batch)

                    # calculate loss
                    if self.learn_method == 'sup':
                        # supervised learning
                        logists = self.classification(embs_batch)
                        labels_batch = labels[nodes_batch]
                        loss_sup = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
                        loss_sup /= len(nodes_batch)
                        loss = loss_sup
                    elif self.learn_method == 'aux':
                        # supervised learning
                        logists = self.classification(embs_batch)
                        labels_batch = labels[nodes_batch]
                        loss_sup = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
                        loss_sup /= len(nodes_batch)
                        # unsupervised learning
                        if self.unsup_loss_type == 'margin':
                            loss_net = unsup_loss.get_loss_margin(embs_batch, nodes_batch)
                        elif self.unsup_loss_type == 'normal':
                            loss_net = unsup_loss.get_loss_sage(embs_batch, nodes_batch)
                        loss = loss_sup + loss_net
                    else:
                        if self.unsup_loss_type == 'margin':
                            loss_net = unsup_loss.get_loss_margin(embs_batch, nodes_batch)
                        elif self.unsup_loss_type == 'normal':
                            loss_net = unsup_loss.get_loss_sage(embs_batch, nodes_batch)
                        loss = loss_net

                    # add losses together
                    loss = loss + loss_gmm

                    if self.print_progress:
                        progress_message = '{} Epoch: [{}/{}], I3 iteration [{}/{}] current loss: {:.4f}, touched nodes [{}/{}] '.format(
                                        self.model_name, epoch+1, self.epochs, i3, i3_iterations, loss.item(), len(visited_nodes), len(train_nodes))
                        tqdm_bar.set_description(progress_message)

                    loss.backward()
                    for model in models:
                        nn.utils.clip_grad_norm_(model.parameters(), 5)
                    optimizer.step()
                    optimizer.zero_grad()
                    for model in models:
                        model.zero_grad()

                    self._clear_cache()

            print('{} Epoch: [{}/{}], current loss: {:.4f}'.format(
                self.model_name, epoch+1, self.epochs, loss.item()))

            self._clear_cache()

    def _get_models(self):
        """Return the (encoder, GMM, classifier) triple."""
        return(self.gnn, self.gmm, self.classification)

    def generate_embeddings(self, nodes):
        """Embed `nodes` in batches of 500 without tracking gradients.

        Parameters
        ----------
        nodes: numpy.ndarray
            Node ids to embed (its `.shape` is printed).

        Returns
        -------
        numpy.ndarray with one embedding row per entry of `nodes`.
        """
        print(f'nodes: {nodes.shape}')
        b_sz = 500
        batches = math.ceil(len(nodes) / b_sz)
        embs = []
        for index in range(batches):
            nodes_batch = nodes[index*b_sz:(index+1)*b_sz]
            with torch.no_grad():
                embs_batch = self.gnn(nodes_batch)
            assert len(embs_batch) == len(nodes_batch)
            embs.append(embs_batch)
        assert len(embs) == batches
        embs = torch.cat(embs, 0)
        assert len(embs) == len(nodes)
        return embs.cpu().numpy()

    def predict(self, nodes=None):
        """Predict a class for every node in `nodes`, in batches of 500.

        Parameters
        ----------
        nodes: sequence of node ids, optional
            Nodes to classify. When omitted (None) the training nodes of the
            data loader are used. Earlier versions ignored this argument and
            always used the training nodes.

        Returns
        -------
        numpy.ndarray of predicted class indices, aligned with `nodes`.
        """
        if self.learn_method == 'unsup':
            # `sys` is not imported by this cell, so import it where it is used
            import sys
            print('GNN.predict() is only supported for supervised learning.')
            sys.exit(0)
        if nodes is None:
            nodes = self.dl.nodes_train
        b_sz = 500
        batches = math.ceil(len(nodes) / b_sz)
        preds = []
        for index in range(batches):
            nodes_batch = nodes[index*b_sz:(index+1)*b_sz]
            with torch.no_grad():
                embs_batch = self.gnn(nodes_batch)
                logists = self.classification(embs_batch)
                _, predicts = torch.max(logists, 1)
                preds.append(predicts)
        assert len(preds) == batches
        preds = torch.cat(preds, 0)
        assert len(preds) == len(nodes)
        return preds.cpu().numpy()

    @staticmethod
    def eval_node_cls(nc_logits, labels):
        """Return the accuracy of arg-max predictions from `nc_logits` against `labels`.

        Declared static because it takes no `self`; it can be called on the
        class or on an instance.
        """
        preds = torch.argmax(nc_logits, dim=1)
        correct = torch.sum(preds == labels)
        acc = correct.item() / len(labels)
        return acc

    def release_cuda_cache(self):
        """Release the cached CUDA memory held by PyTorch."""
        torch.cuda.empty_cache()


class DataLoader(object):
    """Holds the graph structure, node features, data split and labels.

    The train/validation/test node ids are read from the module-level
    `dataCenter` object, from the attributes named `<ds>_train`, `<ds>_val`
    and `<ds>_test`, where `ds` is the module-level dataset tag.

    Parameters
    ----------
    adj_matrix: sparse adjacency matrix (rows are indexed and converted with `.toarray()`)
    raw_features: array-like or None
        Node features; when None, one-hot degree features are built.
    labels: sequence of class labels (only stored when learn_method != 'unsup')
    learn_method: {'unsup', 'sup', 'aux'}
    device: device the feature tensor is moved to
    """
    def __init__(self, adj_matrix, raw_features, labels, learn_method, device):
        super(DataLoader, self).__init__()
        self.adj_matrix = adj_matrix
        # load adjacency list and node features
        self.adj_list = self.get_adj_list(adj_matrix)
        if raw_features is None:
            features = self.get_features()
        else:
            features = raw_features
        assert features.shape[0] == len(self.adj_list) == self.adj_matrix.shape[0]
        self.features = torch.FloatTensor(features).to(device)

        # data split, taken from the module-level `dataCenter`
        self.nodes_test = getattr(dataCenter, ds+'_test')
        self.nodes_val = getattr(dataCenter, ds+'_val')
        self.nodes_train = getattr(dataCenter, ds+'_train')

        if learn_method != 'unsup':
            self.labels = np.asarray(labels)

    def get_adj_list(self, adj_matrix):
        """Build an adjacency list {node: set of neighbour ids} from the adjacency matrix."""
        adj_list = {}
        for i in range(adj_matrix.shape[0]):
            # column indices of the nonzero entries of row i
            adj_list[i] = set(np.where(adj_matrix[i].toarray() != 0)[1])
        return adj_list

    def get_features(self):
        """
        When raw features are not available,
        build one-hot degree features from the adjacency list.

        A node with d >= 1 neighbours gets a 1 in column d-1. A node without
        neighbours keeps an all-zero row; it used to be written to column -1,
        which made it indistinguishable from a node of maximum degree.
        """
        max_degree = np.max(np.sum(self.adj_matrix != 0, axis=1))
        features = np.zeros((self.adj_matrix.shape[0], max_degree))
        for node, neighbors in self.adj_list.items():
            degree = len(neighbors)
            if degree == 0:
                continue
            features[node, degree-1] = 1
        return features


class Classification(nn.Module):
    """Two-layer classification head that maps node embeddings to log-probabilities.

    The hidden layer has 64 units; ELU is applied after both linear layers and
    the result is passed through a log-softmax over the class dimension.
    """
    def __init__(self, emb_size, num_classes):
        super().__init__()
        hidden_width = 64
        self.fc1 = nn.Linear(emb_size, hidden_width)
        self.fc2 = nn.Linear(hidden_width, num_classes)

    def forward(self, embeds):
        """Return per-class log-probabilities for a 2-D batch of embeddings."""
        hidden = F.elu(self.fc1(embeds))
        scores = F.elu(self.fc2(hidden))
        return F.log_softmax(scores, dim=1)


class Unsup_Loss(object):
    """Random-walk based unsupervised losses over node embeddings.

    `extend_nodes` samples, for a batch of target nodes, positive partners
    (nodes met on short random walks) and negative partners (training nodes
    more than `N_WALK_LEN` hops away). `get_loss_sage` and `get_loss_margin`
    then score the embeddings of exactly that extended batch.

    Parameters
    ----------
    dl: object providing `adj_list`, `adj_matrix` and `nodes_train`
    device: device used for the constant tensor in the margin loss
    """
    def __init__(self, dl, device):
        super(Unsup_Loss, self).__init__()
        self.Q = 10           # weight of the negative term in get_loss_sage
        self.N_WALKS = 4      # random walks started per target node
        self.WALK_LEN = 4     # steps per random walk
        self.N_WALK_LEN = 5   # hop radius excluded from negative sampling
        self.MARGIN = 3       # margin of the hinge loss
        self.adj_lists = dl.adj_list
        self.adj_matrix = dl.adj_matrix
        self.train_nodes = dl.nodes_train
        self.device = device

        self.target_nodes = None
        self.positive_pairs = []
        self.negative_pairs = []
        self.node_positive_pairs = {}
        self.node_negative_pairs = {}
        self.unique_nodes_batch = []

    @staticmethod
    def _pair_similarity(embeddings, pairs, node2index):
        """Cosine similarity between the two endpoints of every (node, other) pair."""
        left_nodes, right_nodes = zip(*pairs)
        left_rows = [node2index[n] for n in left_nodes]
        right_rows = [node2index[n] for n in right_nodes]
        return F.cosine_similarity(embeddings[left_rows], embeddings[right_rows])

    def get_loss_sage(self, embeddings, nodes):
        """GraphSAGE-style loss: -log(sigmoid(pos)) - Q * mean(log(sigmoid(-neg))).

        `embeddings` must hold one row per node of the batch returned by the
        latest `extend_nodes` call, in the same order as `nodes`.
        Target nodes with no positive or no negative pairs are skipped.
        """
        assert len(embeddings) == len(self.unique_nodes_batch)
        assert False not in [nodes[i]==self.unique_nodes_batch[i] for i in range(len(nodes))]
        node2index = {n:i for i,n in enumerate(self.unique_nodes_batch)}

        nodes_score = []
        print(f'len(self.node_positive_pairs):{ len(self.node_positive_pairs)};len(self.node_negative_pairs):{len(self.node_negative_pairs)} ')
        assert len(self.node_positive_pairs) == len(self.node_negative_pairs)
        for node in self.node_positive_pairs:
            pps = self.node_positive_pairs[node]
            nps = self.node_negative_pairs[node]
            if len(pps) == 0 or len(nps) == 0:
                continue

            # Q * expectation of the negative score
            neg_score = self._pair_similarity(embeddings, nps, node2index)
            neg_score = self.Q*torch.mean(torch.log(torch.sigmoid(-neg_score)), 0)

            # one positive score per positive pair
            pos_score = self._pair_similarity(embeddings, pps, node2index)
            pos_score = torch.log(torch.sigmoid(pos_score))

            nodes_score.append(torch.mean(- pos_score - neg_score).view(1,-1))

        loss = torch.mean(torch.cat(nodes_score, 0))
        return loss

    def get_loss_margin(self, embeddings, nodes):
        """Hinge loss: max(0, hardest negative - hardest positive + MARGIN), averaged over nodes.

        `embeddings` must hold one row per node of the batch returned by the
        latest `extend_nodes` call, in the same order as `nodes`.
        Target nodes with no positive or no negative pairs are skipped.
        """
        assert len(embeddings) == len(self.unique_nodes_batch)
        assert False not in [nodes[i]==self.unique_nodes_batch[i] for i in range(len(nodes))]
        node2index = {n:i for i,n in enumerate(self.unique_nodes_batch)}

        nodes_score = []
        zero = torch.tensor(0.0).to(self.device)

        assert len(self.node_positive_pairs) == len(self.node_negative_pairs)
        for node in self.node_positive_pairs:
            pps = self.node_positive_pairs[node]
            nps = self.node_negative_pairs[node]
            if len(pps) == 0 or len(nps) == 0:
                continue

            # least similar positive pair
            pos_score = self._pair_similarity(embeddings, pps, node2index)
            pos_score, _ = torch.min(torch.log(torch.sigmoid(pos_score)), 0)

            # most similar negative pair
            neg_score = self._pair_similarity(embeddings, nps, node2index)
            neg_score, _ = torch.max(torch.log(torch.sigmoid(neg_score)), 0)

            nodes_score.append(torch.max(zero, neg_score-pos_score+self.MARGIN).view(1, -1))
        loss = torch.mean(torch.cat(nodes_score, 0), 0)
        return loss

    def extend_nodes(self, nodes, num_neg=6):
        """Sample positive and negative partners for `nodes`.

        Returns the list of unique nodes that appear in any sampled pair; the
        same list is stored in `self.unique_nodes_batch`. Positive sampling
        runs before negative sampling.
        """
        self.positive_pairs = []
        self.node_positive_pairs = {}
        self.negative_pairs = []
        self.node_negative_pairs = {}

        self.target_nodes = nodes
        self.get_positive_nodes(nodes)
        self.get_negative_nodes(nodes, num_neg)
        self.unique_nodes_batch = list(set([i for x in self.positive_pairs for i in x])
                                       | set([i for x in self.negative_pairs for i in x]))
        assert set(self.target_nodes) <= set(self.unique_nodes_batch)
        return self.unique_nodes_batch

    def get_positive_nodes(self, nodes):
        """Collect positive pairs for `nodes` via random walks."""
        return self._run_random_walks(nodes)

    def get_negative_nodes(self, nodes, num_neg):
        """Sample up to `num_neg` training nodes farther than N_WALK_LEN hops from each node."""
        # built once per call instead of once per node
        train_set = set(self.train_nodes)
        for node in nodes:
            # breadth-first expansion of the N_WALK_LEN-hop neighbourhood
            neighbors = set([node])
            frontier = set([node])
            for _ in range(self.N_WALK_LEN):
                current = set()
                for outer in frontier:
                    current |= self.adj_lists[int(outer)]
                frontier = current - neighbors
                neighbors |= current
            far_nodes = train_set - neighbors
            # random.sample no longer accepts a set (TypeError since Python
            # 3.11); sampling from list(far_nodes) gives the same draws as
            # the old implicit conversion
            neg_samples = random.sample(list(far_nodes), num_neg) if num_neg < len(far_nodes) else far_nodes
            self.negative_pairs.extend([(node, neg_node) for neg_node in neg_samples])
            self.node_negative_pairs[node] = [(node, neg_node) for neg_node in neg_samples]
        return self.negative_pairs

    def _run_random_walks(self, nodes):
        """Run N_WALKS walks of WALK_LEN steps from each node and record co-occurring training nodes.

        The next step is drawn with probability proportional to the integer
        edge count stored in the adjacency matrix.
        """
        # built once per call; replaces a linear scan on every walk step
        train_set = set(self.train_nodes)
        for node in nodes:
            if len(self.adj_lists[int(node)]) == 0:
                # isolated node: keep an (empty) entry so the positive and
                # negative dictionaries stay aligned; the losses skip it
                self.node_positive_pairs[node] = []
                continue
            cur_pairs = []
            for _ in range(self.N_WALKS):
                curr_node = node
                for _ in range(self.WALK_LEN):
                    cnts = self.adj_matrix[int(curr_node)].toarray().squeeze()
                    # repeat each neighbour once per parallel edge
                    neighs = []
                    for n in np.where(cnts != 0)[0]:
                        neighs.extend([n] * int(cnts[n]))
                    next_node = random.choice(neighs)
                    # self co-occurrences are useless
                    if next_node != node and next_node in train_set:
                        self.positive_pairs.append((node,next_node))
                        cur_pairs.append((node,next_node))
                    curr_node = next_node

            self.node_positive_pairs[node] = cur_pairs
        return self.positive_pairs


class SageLayer(nn.Module):
    """
    One 'convolutional' GraphSAGE-style layer: a single weight matrix followed by ReLU.

    For GraphSAGE the input is the concatenation of a node's own features and
    its aggregated neighbour features (width 2 * input_size). For GAT and GCN
    only the aggregated features are used (width input_size).
    """
    def __init__(self, input_size, out_size, gat=False, gcn=False):
        super(SageLayer, self).__init__()

        self.input_size = input_size
        self.out_size = out_size
        self.gat = gat
        self.gcn = gcn

        if self.gat or self.gcn:
            in_width = self.input_size
        else:
            in_width = 2 * self.input_size
        self.weight = nn.Parameter(torch.FloatTensor(out_size, in_width))

        self.init_params()

    def init_params(self):
        """Initialise every parameter of the layer with Xavier-uniform values."""
        for tensor in self.parameters():
            nn.init.xavier_uniform_(tensor)

    def forward(self, self_feats, aggregate_feats):
        """
        Compute the layer output for a batch of nodes.
        self_feats      -- the nodes' own features (ignored for GAT/GCN)
        aggregate_feats -- the aggregated neighbourhood features
        """
        uses_aggregate_only = self.gat or self.gcn
        if uses_aggregate_only:
            layer_input = aggregate_feats
        else:
            layer_input = torch.cat([self_feats, aggregate_feats], dim=1)
        projected = self.weight.mm(layer_input.t())
        return F.relu(projected).t()

class Attention(nn.Module):
    """Computes the self-attention between pairs of nodes.

    Two scorers are kept: `attention_raw` for inputs of width `input_size` and
    `attention_emb` for inputs of width `out_size`. When both sizes are equal,
    `attention_raw` is always the one selected.
    """
    def __init__(self, input_size, out_size):
        super(Attention, self).__init__()

        self.input_size = input_size
        self.out_size = out_size
        self.attention_raw = nn.Linear(2*input_size, 1, bias=False)
        self.attention_emb = nn.Linear(2*out_size, 1, bias=False)

    def forward(self, row_embs, col_embs):
        """Return LeakyReLU(0.2) attention scores, one per (row, col) pair.

        Raises
        ------
        ValueError
            If the width of `row_embs` matches neither `input_size` nor
            `out_size` (this used to fail with UnboundLocalError).
        """
        width = row_embs.size(1)
        if width == self.input_size:
            att = self.attention_raw
        elif width == self.out_size:
            att = self.attention_emb
        else:
            raise ValueError(
                'Attention expects embeddings of width {} or {}, got {}'.format(
                    self.input_size, self.out_size, width))
        e = att(torch.cat((row_embs, col_embs), dim=1))
        return F.leaky_relu(e, negative_slope=0.2)

class GNN_model(nn.Module):
    """Multi-layer neighbourhood-aggregation encoder (GraphSAGE, GCN or GAT variant).

    Parameters
    ----------
    num_layers: number of aggregation layers (registered as `sage_layer1..N`)
    out_size: width of every layer's output
    dl: object providing `features` (2-D tensor), `adj_list` and `adj_matrix`
    device: stored on the module
    gat, gcn: select the GAT / GCN variant; with both False neighbours are
        sampled and self features are concatenated (GraphSAGE)
    agg_func: {'MEAN', 'MAX'} aggregation over the neighbourhood
    """
    def __init__(self, num_layers, out_size, dl, device, gat=False, gcn=False, agg_func='MEAN'):
        super(GNN_model, self).__init__()

        self.input_size = dl.features.size(1)
        self.out_size = out_size
        self.num_layers = num_layers
        self.gat = gat
        self.gcn = gcn
        self.device = device
        self.agg_func = agg_func

        self.raw_features = dl.features
        self.adj_lists = dl.adj_list
        self.adj_matrix = dl.adj_matrix

        # the first layer consumes raw features, later layers consume embeddings
        for index in range(1, num_layers+1):
            layer_size = out_size if index != 1 else self.input_size
            setattr(self, 'sage_layer'+str(index), SageLayer(layer_size, out_size, gat=self.gat, gcn=self.gcn))
        if self.gat:
            self.attention = Attention(self.input_size, out_size)

    def forward(self, nodes_batch):
        """
        Generates embeddings for a batch of nodes.
        nodes_batch	-- batch of nodes to learn the embeddings
        Returns one embedding row per node of `nodes_batch`, in order.
        """
        # expand the batch outwards, one neighbourhood per layer; the
        # outermost neighbourhood ends up at index 0
        lower_layer_nodes = list(nodes_batch)
        nodes_batch_layers = [(lower_layer_nodes,)]
        for _ in range(self.num_layers):
            lower_layer_nodes, lower_samp_neighs, lower_layer_nodes_dict= self._get_unique_neighs_list(lower_layer_nodes)
            nodes_batch_layers.insert(0, (lower_layer_nodes, lower_samp_neighs, lower_layer_nodes_dict))

        assert len(nodes_batch_layers) == self.num_layers + 1

        # aggregate inwards, from the outermost neighbourhood to the batch
        pre_hidden_embs = self.raw_features
        for index in range(1, self.num_layers+1):
            nb = nodes_batch_layers[index][0]
            pre_neighs = nodes_batch_layers[index-1]
            aggregate_feats = self.aggregate(nb, pre_hidden_embs, pre_neighs)
            sage_layer = getattr(self, 'sage_layer'+str(index))
            if index > 1:
                # after the first layer, rows are positions in the previous
                # layer's node list rather than global node ids
                nb = self._nodes_map(nb, pre_neighs)
            cur_hidden_embs = sage_layer(self_feats=pre_hidden_embs[nb], aggregate_feats=aggregate_feats)
            pre_hidden_embs = cur_hidden_embs

        return pre_hidden_embs

    def _nodes_map(self, nodes, neighs):
        """Translate node ids into row positions of the previous layer's node list."""
        _, samp_neighs, layer_nodes_dict = neighs
        assert len(samp_neighs) == len(nodes)
        index = [layer_nodes_dict[x] for x in nodes]
        return index

    def _get_unique_neighs_list(self, nodes, num_sample=10):
        """Collect the (possibly sampled) neighbourhood of every node, each including the node itself.

        Returns
        -------
        (unique node list, per-node neighbour sets, {node: position in the unique list})
        """
        to_neighs = [self.adj_lists[int(node)] for node in nodes]
        if self.gcn or self.gat:
            # GCN / GAT use the full neighbourhood
            samp_neighs = to_neighs
        else:
            # GraphSAGE samples `num_sample` neighbours of high-degree nodes.
            # random.sample no longer accepts a set (TypeError since Python
            # 3.11); sampling from list(...) gives the same draws as the old
            # implicit conversion.
            samp_neighs = [
                set(random.sample(list(to_neigh), num_sample)) if len(to_neigh) >= num_sample else to_neigh
                for to_neigh in to_neighs
            ]
        samp_neighs = [samp_neigh | set([nodes[i]]) for i, samp_neigh in enumerate(samp_neighs)]
        _unique_nodes_list = list(set.union(*samp_neighs))
        # unique node 2 index
        unique_nodes = {node: position for position, node in enumerate(_unique_nodes_list)}
        return _unique_nodes_list, samp_neighs, unique_nodes

    def aggregate(self, nodes, pre_hidden_embs, pre_neighs):
        """Aggregate neighbour embeddings for `nodes`.

        Neighbours are weighted by the square root of the edge count (and by
        attention for GAT), rows are L1-normalised, and the result is either
        the weighted mean ('MEAN') or the element-wise maximum ('MAX').

        Raises
        ------
        ValueError
            If `agg_func` is neither 'MEAN' nor 'MAX' (this used to fail with
            UnboundLocalError at the end of the method).
        """
        if self.agg_func not in ('MEAN', 'MAX'):
            raise ValueError("agg_func must be 'MEAN' or 'MAX', got {!r}".format(self.agg_func))

        unique_nodes_list, samp_neighs, unique_nodes = pre_neighs

        assert len(nodes) == len(samp_neighs)
        indicator = [(nodes[i] in samp_neighs[i]) for i in range(len(samp_neighs))]
        assert False not in indicator
        if not self.gat and not self.gcn:
            # GraphSAGE handles the node's own features separately
            samp_neighs = [(samp_neighs[i]-set([nodes[i]])) for i in range(len(samp_neighs))]
        # when the embedding table already has one row per unique node it is
        # used as is, otherwise the rows of the unique nodes are selected
        if len(pre_hidden_embs) == len(unique_nodes):
            embed_matrix = pre_hidden_embs
        else:
            embed_matrix = pre_hidden_embs[torch.LongTensor(unique_nodes_list)]
        # get row and column nonzero indices for the mask tensor
        row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
        column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
        # get the edge counts for each edge
        edge_counts = self.adj_matrix[nodes][:, unique_nodes_list].toarray()
        edge_counts = torch.FloatTensor(edge_counts).to(embed_matrix.device)
        torch.sqrt_(edge_counts)
        if self.gat:
            indices = (torch.LongTensor(row_indices), torch.LongTensor(column_indices))
            nodes_indices = torch.LongTensor([unique_nodes[nodes[n]] for n in row_indices])
            row_embs = embed_matrix[nodes_indices]
            col_embs = embed_matrix[column_indices]
            atts = self.attention(row_embs, col_embs).squeeze()
            mask = torch.zeros(len(samp_neighs), len(unique_nodes)).to(embed_matrix.device)
            mask.index_put_(indices, atts)
            mask = mask * edge_counts
            # softmax restricted to the nonzero entries
            mask = torch.exp(mask) * (mask != 0).float()
            mask = F.normalize(mask, p=1, dim=1)
        else:
            mask = torch.zeros(len(samp_neighs), len(unique_nodes)).to(embed_matrix.device)
            mask[row_indices, column_indices] = 1
            # multiply edge counts to mask
            mask = mask * edge_counts
            mask = F.normalize(mask, p=1, dim=1)
            mask = mask.to(embed_matrix.device)

        if self.agg_func == 'MEAN':
            aggregate_feats = mask.mm(embed_matrix)
        else:
            # 'MAX': element-wise maximum over each node's selected neighbours
            indexs = [x.nonzero() for x in mask != 0]
            aggregate_feats = []
            for feat in [embed_matrix[x.squeeze()] for x in indexs]:
                if len(feat.size()) == 1:
                    aggregate_feats.append(feat.view(1, -1))
                else:
                    aggregate_feats.append(torch.max(feat,0)[0].view(1, -1))
            aggregate_feats = torch.cat(aggregate_feats, 0)

        return aggregate_feats


In [None]:
#@title GaussianMixtureModel code
import torch
from torch import nn
from torch import optim
import torch.distributions as D

class GaussianMixtureModel(torch.nn.Module):
    """Gaussian mixture whose weights, means and stdevs are learnable.

    Adapted from
    https://discuss.pytorch.org/t/fit-gaussian-mixture-model/121826

    Each of the three parameter vectors has one entry per component and is
    stored as a ``torch.nn.Parameter``. ``forward`` returns the mean
    negative log-likelihood of its input under the mixture.
    """

    def __init__(self, n_components: int=num_classes):
        """Create a mixture with ``n_components`` randomly initialised components.

        Args:
            n_components (int): number of mixture components. The default is
                the notebook-level ``num_classes`` evaluated when the class
                is defined.
        """
        super().__init__()
        # Random draws happen in the order torch -> numpy.
        initial_means = torch.randn(n_components, )
        # NOTE: built from a NumPy array, so this tensor is float64 while the
        # weights and means use torch's default dtype.
        initial_stdevs = torch.tensor(np.abs(np.random.randn(n_components, )))
        self._update_parameters(
            torch.ones(n_components, ),
            initial_means,
            initial_stdevs,
        )

    def _update_parameters(self, new_weights, new_means, new_stdevs):
        """Replace all three parameter tensors with fresh ``Parameter`` objects."""
        self.weights = torch.nn.Parameter(new_weights)
        self.means   = torch.nn.Parameter(new_means)
        self.stdevs  = torch.nn.Parameter(new_stdevs)

    def _get_weights(self):
        """Return the mixture-weight parameter."""
        return self.weights

    def _get_parameters(self):
        """Return the tuple ``(weights, means, stdevs)``."""
        return self.weights, self.means, self.stdevs

    def forward(self, x):
        """Return the mean negative log-likelihood of ``x`` under the mixture.

        ``self.weights`` is handed to ``Categorical`` as its first positional
        argument and ``self.stdevs`` is used directly as the scale of the
        Normal components.
        """
        mixture = D.MixtureSameFamily(
            D.Categorical(self.weights),
            D.Normal(self.means, self.stdevs),
        )
        return - mixture.log_prob(x).mean()
    
    

In [None]:
#@title Data loader 
import sys
import os

from collections import defaultdict
import numpy as np

class DataCenter(object):
    """Loads a graph dataset and exposes it through dynamically named attributes.

    After ``load_dataSet(name)`` the instance carries ``<name>_feats``,
    ``<name>_labels``, ``<name>_adj_lists`` and the random index splits
    ``<name>_train``, ``<name>_val`` and ``<name>_test``.
    """

    # Datasets fetched through the notebook-level ``datasets`` module
    # (presumably ``torch_geometric.datasets`` -- confirm against the imports):
    # dataset name -> (class name inside ``datasets``, constructor kwargs).
    _PYG_DATASETS = {
        'citeseer': ('Planetoid', {'root': '/tmp/Citeseer', 'name': 'Citeseer'}),
        'cora-full': ('CoraFull', {'root': '/tmp/CoraFull'}),
        'coauthor-cs': ('Coauthor', {'root': '/tmp/Coauthor-CS', 'name': 'CS'}),
        'coauthor-physics': ('Coauthor', {'root': '/tmp/Coauthor-Physics', 'name': 'Physics'}),
        'pubmed': ('Planetoid', {'root': '/tmp/PubMed', 'name': 'PubMed'}),
    }

    def __init__(self):
        super(DataCenter, self).__init__()

    def load_dataSet(self, dataSet='cora'):
        """Load ``dataSet`` and store its arrays and splits on the instance.

        Args:
            dataSet (str): ``'cora'`` (read from ``DATA_DIR``) or one of the
                keys of ``_PYG_DATASETS``.

        Raises:
            ValueError: if ``dataSet`` is not a supported name. Previously an
                unknown name was silently ignored and the failure only
                surfaced later as a missing attribute.
        """
        if dataSet == 'cora':
            feat_data, labels, adj_lists = self._load_cora()
            # Every node must have features, a label and at least one edge.
            assert len(feat_data) == len(labels) == len(adj_lists)
        elif dataSet in self._PYG_DATASETS:
            feat_data, labels, adj_lists = self._load_pyg(dataSet)
        else:
            supported = ['cora'] + sorted(self._PYG_DATASETS)
            raise ValueError(
                'Unknown dataSet {!r}; expected one of {}'.format(dataSet, supported)
            )

        test_indexs, val_indexs, train_indexs = self._split_data(feat_data.shape[0])

        setattr(self, dataSet+'_test', test_indexs)
        setattr(self, dataSet+'_val', val_indexs)
        setattr(self, dataSet+'_train', train_indexs)

        setattr(self, dataSet+'_feats', feat_data)
        setattr(self, dataSet+'_labels', labels)
        setattr(self, dataSet+'_adj_lists', adj_lists)

    def _load_cora(self):
        """Parse ``cora.content`` / ``cora.cites`` from ``DATA_DIR``.

        Returns:
            tuple: ``(feat_data, labels, adj_lists)`` where ``feat_data`` is a
            float array, ``labels`` an int64 array of label ids assigned in
            order of first appearance, and ``adj_lists`` maps a node id to the
            set of its neighbours (edges are stored in both directions).
        """
        cora_content_file = DATA_DIR + '/cora.content'
        cora_cite_file = DATA_DIR + '/cora.cites'

        feat_data = []
        labels = []  # label id of each node, in file order
        node_map = {}  # raw node name -> node id
        label_map = {}  # raw label name -> label id
        with open(cora_content_file) as fp:
            for i, line in enumerate(fp):
                info = line.strip().split()
                # first token is the node name, last the label, the rest features
                feat_data.append([float(x) for x in info[1:-1]])
                node_map[info[0]] = i
                if info[-1] not in label_map:
                    label_map[info[-1]] = len(label_map)
                labels.append(label_map[info[-1]])
        feat_data = np.asarray(feat_data)
        labels = np.asarray(labels, dtype=np.int64)

        adj_lists = defaultdict(set)
        with open(cora_cite_file) as fp:
            for line in fp:
                info = line.strip().split()
                assert len(info) == 2
                paper1 = node_map[info[0]]
                paper2 = node_map[info[1]]
                adj_lists[paper1].add(paper2)
                adj_lists[paper2].add(paper1)

        return feat_data, labels, adj_lists

    def _load_pyg(self, dataSet):
        """Fetch one of the ``_PYG_DATASETS`` and convert it to NumPy structures.

        The graph is read straight into CPU NumPy arrays; no accelerator is
        involved, so this helper does not depend on a global ``device``.

        Returns:
            tuple: ``(feat_data, labels, adj_lists)`` in the same layout as
            ``_load_cora``.
        """
        class_name, kwargs = self._PYG_DATASETS[dataSet]
        data = getattr(datasets, class_name)(**kwargs)[0]

        feat_data = data.x.to('cpu').detach().numpy().copy()
        labels = data.y.to('cpu').detach().numpy().copy()
        edge_index = data.edge_index.to('cpu').detach().numpy().copy()

        adj_lists = defaultdict(set)
        for (origin, dest) in zip(edge_index[0], edge_index[1]):
            adj_lists[origin].add(dest)
            adj_lists[dest].add(origin)

        return feat_data, labels, adj_lists

    def _split_data(self, num_nodes, test_split = 4, val_split = 4):
        """Randomly partition ``range(num_nodes)`` into test / val / train indices.

        Args:
            num_nodes (int): number of nodes to split.
            test_split (int): the test set gets ``num_nodes // test_split`` nodes.
            val_split (int): the validation set gets ``num_nodes // val_split`` nodes.

        Returns:
            tuple: ``(test_indexs, val_indexs, train_indexs)``; the training
            set receives every node not assigned to test or validation.
        """
        rand_indices = np.random.permutation(num_nodes)

        test_size = num_nodes // test_split
        val_size = num_nodes // val_split

        test_indexs = rand_indices[:test_size]
        val_indexs = rand_indices[test_size:(test_size+val_size)]
        train_indexs = rand_indices[(test_size+val_size):]

        return test_indexs, val_indexs, train_indexs

In [None]:
#@title Model evaluation
import sys
import os
import torch
import random
import math

from sklearn.utils import shuffle
from sklearn.metrics import f1_score

import torch.nn as nn
import numpy as np

def evaluate(dataCenter, ds, gnn_model, classification_model, device, max_vali_f1, name, cur_epoch):
    """Score the models on the validation split and checkpoint on improvement.

    The micro-F1 on the validation nodes is computed first. Only when it
    beats ``max_vali_f1`` is the test split scored, the test F1 printed and
    both models saved under ``models/``.

    Args:
        dataCenter: object exposing ``<ds>_test``, ``<ds>_val`` and
            ``<ds>_labels`` attributes.
        ds (str): dataset name used as attribute prefix.
        gnn_model: callable mapping a collection of node ids to embeddings.
        classification_model: callable mapping embeddings to per-class scores.
        device: unused; kept for interface compatibility.
        max_vali_f1 (float): best validation F1 seen so far.
        name (str): tag inserted into the checkpoint file name.
        cur_epoch (int): epoch number inserted into the checkpoint file name.

    Returns:
        float: the (possibly updated) best validation F1.
    """
    nodes_test = getattr(dataCenter, ds+'_test')
    nodes_val = getattr(dataCenter, ds+'_val')
    labels = getattr(dataCenter, ds+'_labels')

    models = [gnn_model, classification_model]

    # Temporarily freeze every trainable parameter; remember which ones were
    # touched so exactly those are re-enabled afterwards.
    params = []
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.requires_grad = False
                params.append(param)

    test_f1 = None
    try:
        embs = gnn_model(nodes_val)
        logists = classification_model(embs)
        _, predicts = torch.max(logists, 1)
        labels_val = labels[nodes_val]
        assert len(labels_val) == len(predicts)

        vali_f1 = f1_score(labels_val, predicts.cpu().data, average="micro")

        if vali_f1 > max_vali_f1:
            max_vali_f1 = vali_f1
            embs = gnn_model(nodes_test)
            logists = classification_model(embs)
            _, predicts = torch.max(logists, 1)
            labels_test = labels[nodes_test]
            assert len(labels_test) == len(predicts)

            test_f1 = f1_score(labels_test, predicts.cpu().data, average="micro")
            print("Test F1:", test_f1)
    finally:
        # Restore trainability even if the forward pass or scoring raised,
        # so a failed evaluation cannot leave the models frozen.
        for param in params:
            param.requires_grad = True

    if test_f1 is not None:
        # Saved after un-freezing so the checkpoint holds trainable parameters.
        # torch.save does not create missing directories.
        os.makedirs('models', exist_ok=True)
        torch.save(models, 'models/model_best_{}_ep{}_{:.4f}.torch'.format(name, cur_epoch, test_f1))

    return max_vali_f1

def get_gnn_embeddings(gnn_model, dataCenter, ds):
    """Compute embeddings for every node of dataset ``ds`` in batches.

    Args:
        gnn_model: callable taking a list of node ids and returning a tensor
            with one row per node.
        dataCenter: object exposing a ``<ds>_labels`` attribute; only its
            length is used, to enumerate node ids ``0 .. N-1``.
        ds (str): dataset name used as attribute prefix.

    Returns:
        torch.Tensor: the concatenated embeddings, detached from the autograd
        graph, with row ``i`` belonging to node ``i``.
    """
    print('Loading embeddings from trained GNN model.')
    num_nodes = len(getattr(dataCenter, ds+'_labels'))
    nodes = list(range(num_nodes))
    b_sz = 500  # nodes per forward pass
    batches = math.ceil(num_nodes / b_sz)
    embs = []
    for index in range(batches):
        nodes_batch = nodes[index*b_sz:(index+1)*b_sz]
        embs_batch = gnn_model(nodes_batch)
        assert len(embs_batch) == len(nodes_batch)
        embs.append(embs_batch)

    assert len(embs) == batches
    embs = torch.cat(embs, 0)
    assert len(embs) == num_nodes
    print('Embeddings loaded.')
    return embs.detach()

def train_classification(dataCenter, gnn_model, classification_model, ds, device, max_vali_f1, name, epochs=800):
    """Train the classifier on fixed GNN embeddings and evaluate after each epoch.

    Embeddings are computed once with ``get_gnn_embeddings`` (detached), so
    only ``classification_model`` is updated here.

    Args:
        dataCenter: object exposing ``<ds>_train`` and ``<ds>_labels``.
        gnn_model: trained GNN used to produce the embeddings.
        classification_model: model mapping embeddings to per-class scores.
            The loss is the negative mean of the score at the true label, so
            the outputs are presumably log-probabilities -- confirm.
        ds (str): dataset name used as attribute prefix.
        device: forwarded to ``evaluate``.
        max_vali_f1 (float): best validation F1 seen so far.
        name (str): tag forwarded to ``evaluate`` for checkpoint naming.
        epochs (int): number of passes over the training nodes.

    Returns:
        tuple: ``(classification_model, max_vali_f1)``.
    """
    print('Training Classification ...')
    c_optimizer = torch.optim.SGD(classification_model.parameters(), lr=0.5)
    b_sz = 100
    train_nodes = getattr(dataCenter, ds+'_train')
    labels = getattr(dataCenter, ds+'_labels')
    features = get_gnn_embeddings(gnn_model, dataCenter, ds)
    for epoch in range(epochs):
        train_nodes = shuffle(train_nodes)
        batches = math.ceil(len(train_nodes) / b_sz)
        for index in range(batches):
            nodes_batch = train_nodes[index*b_sz:(index+1)*b_sz]
            labels_batch = labels[nodes_batch]
            embs_batch = features[nodes_batch]

            logists = classification_model(embs_batch)
            loss = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
            loss /= len(nodes_batch)

            loss.backward()

            # Clip the gradients of the model being trained (the argument),
            # not whatever global happens to be called ``classification``.
            nn.utils.clip_grad_norm_(classification_model.parameters(), 5)
            c_optimizer.step()
            c_optimizer.zero_grad()

        max_vali_f1 = evaluate(dataCenter, ds, gnn_model, classification_model, device, max_vali_f1, name, epoch)
    return classification_model, max_vali_f1


## Section 2: DNML Computation Code

In [None]:
#@title DNML Computation
!pip install pyper

"""DNML GMM

Functions to calculate the parametric complexity for DNML.
They are only used for comparison methods.

.. References:
    * Tianyi Wu, Shinya Sugawara, and Kenji Yamanishi. Decomposed Normalized
    Maximum Likelihood Codelength Criterion for Selecting Hierarchical Latent
    Variable Models. In Proceedings of the 23rd ACM SIGKDD International
    Conference on Knowledge Discovery and Data Mining. Halifax, Canada,
    1165--1174.
    * Kenji Yamanishi, Tianyi Wu, Shinya Sugawara, and Makoto Okada.
    The Decomposed Normalized Maximum Likelihood Code-Length Criterion for
    Selecting Hierarchical Latent Variable Models. Data Mining and Knowledge
    Discovery, 33, 1017--1058, 2019.
"""

import math
import numpy as np
from scipy.special import logsumexp, loggamma


def _pc_multinomial(N, K):
    """parametric complexity for multinomial distributions.

    Args:
        N (int): number of data.
        K (int): number of clusters.

    Returns:
        float: parametric complexity for multinomial distributions.
    """
    PC_list = [0]

    # K = 1
    if K >= 1:
        PC_list.append(1)

    # K = 2
    if K >= 2:
        r1 = np.arange(N + 1)
        r2 = N - r1
        logpc_2 = logsumexp(sum([
            loggamma(N + 1),
            (-1) * loggamma(r1 + 1),
            (-1) * loggamma(r2 + 1),
            r1 * np.log(r1 / N + 1e-50),
            r2 * np.log(r2 / N + 1e-50)
        ]))
        PC_list.append(np.exp(logpc_2))

    # K >= 3
    for k in range(3, K + 1):
        PC_list.append(
            PC_list[k - 1] + N * PC_list[k - 2] / (k - 2)
        )

    return PC_list[-1]


def _log_pc_gaussian(N_list, D, R, lmd_min):
    """log parametric complexity for a Gaussian distribution.

    Args:
        N_list (np.ndarray): list of the number of data.
        D (int): dimension of data.
        R (float): upper bound of ||mean||^2.
        lmd_min (float): lower bound of the eigenvalues of the covariance
            matrix.

    Returns:
        np.ndarray: list of the parametric complexity.
    """
    N_list = np.array(N_list)

    log_PC_list = sum([
        D * N_list * np.log(N_list / 2 / math.e) / 2,
        (-1) * D * (D - 1) * np.log(math.pi) / 4,
        (-1) * np.sum(
            loggamma((N_list.reshape(-1, 1) - np.arange(1, D + 1)) / 2),
            axis=1
        ),
        (D + 1) * np.log(2 / D),
        (-1) * loggamma(D / 2),
        D * np.log(R) / 2,
        (-1) * D**2 * np.log(lmd_min) / 2
    ])

    return log_PC_list


def log_pc_gmm(K_max, N_max, D, *, R=1e+3, lmd_min=1e-3):
    """Table of log parametric complexities of Gaussian mixture models.

    Row 1 holds the single-Gaussian values from ``_log_pc_gaussian``; every
    further row ``k`` is built from row 1 and row ``k - 1`` by summing over
    all ways to split ``n`` points between one component and the rest.

    Args:
        K_max (int): max number of clusters.
        N_max (int): max number of data.
        D (int): dimension of data.
        R (float): upper bound of ||mean||^2.
        lmd_min (float): lower bound of the eigenvalues of the covariance
            matrix.

    Returns:
        np.ndarray: array with ``returns[K, N] = log C(K, N)``; entries for
        ``K = 0``, ``N = 0`` and (in row 1) ``N < D + 1`` are ``-inf``.
    """
    table = np.zeros([K_max + 1, N_max + 1])
    first_valid_n = D + 1  # smallest N for which row 1 is filled in

    # Boundary cases: no data, or no components.
    table[:, 0] = -np.inf
    table[0, :] = -np.inf

    # Single Gaussian: undefined below first_valid_n, closed form from there on.
    table[1, :first_valid_n] = -np.inf
    table[1, first_valid_n:] = _log_pc_gaussian(
        np.arange(first_valid_n, N_max + 1),
        D=D,
        R=R,
        lmd_min=lmd_min
    )

    # Mixtures with two or more components.
    for k in range(2, K_max + 1):
        for n in range(1, N_max + 1):
            in_first = np.arange(n + 1)   # points assigned to one component
            in_rest = n - in_first        # points left for the other k - 1
            log_terms = (
                loggamma(n + 1)
                - loggamma(in_first + 1)
                - loggamma(in_rest + 1)
                + in_first * np.log(in_first / n + 1e-100)
                + in_rest * np.log(in_rest / n + 1e-100)
                + table[1, in_first]
                + table[k - 1, in_rest]
            )
            table[k, n] = logsumexp(log_terms)

    return table


from copy import deepcopy
import sys

sys.path.append('../')

import numpy as np
from tqdm.notebook import tqdm


"""Functions for Gaussian Mixture Model.
"""

import math

import numpy as np
from numpy.random import RandomState
from scipy.special import logsumexp
from scipy.stats import multivariate_normal
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture



class GMMUtils():
    """Sampling and density helpers for a fixed Gaussian mixture.
    """

    def __init__(self, rho, means, covariances):
        """
        Args:
            rho (ndarray): Mixture proportion (shape = (K,)).
            means (ndarray): Mean vectors (shape = (K, D)).
            covariances (ndarray): Covariance matrices (shape = (K, D, D)).
        """
        self.rho = rho
        self.means = means
        self.covariances = covariances
        self.K = len(rho)

    def sample(self, N=100, random_state=None):
        """Draw points from the mixture.

        Component sizes come from one multinomial draw; the points of each
        component are then drawn in component order from the same generator.

        Args:
            N (int): Number of the points to sample.
            random_state (Optional[int]): Random state.
        Returns:
            ndarray: Sampled points, grouped by component.
        """
        generator = RandomState(random_state)
        counts = generator.multinomial(N, self.rho)
        points = []
        for mean, cov, count in zip(self.means, self.covariances, counts):
            drawn = multivariate_normal.rvs(
                mean=mean,
                cov=cov,
                size=count,
                random_state=generator
            )
            if count == 0:
                continue
            if count == 1:
                # a single draw is not wrapped in an outer axis by rvs
                points.append(drawn)
            else:
                points.extend(drawn)
        return np.array(points)

    def logpdf(self, X):
        """Per-component Gaussian log densities.

        Args:
            X (ndarray): Data (shape = (N, D)).
        Returns:
            ndarray: Matrix of log pdf (shape = (N, K)).
        """
        log_density = np.zeros([len(X), self.K])
        for component in range(self.K):
            log_density[:, component] = multivariate_normal.logpdf(
                X,
                self.means[component],
                self.covariances[component],
                allow_singular=True
            )
        return log_density

    def prob_latent(self, X):
        """Posterior probability of each component for each point.

        Args:
            X (ndarray): Data (shape = (N, D)).
        Returns:
            ndarray: Matrix of latent probabilities (shape = (N, K)).
        """
        # log(rho_k) + log N(x | k); the tiny offset guards against log(0).
        joint = np.log(self.rho + 1e-50) + self.logpdf(X)
        # Normalise each row in log space before exponentiating.
        evidence = logsumexp(joint, axis=1, keepdims=True)
        return np.exp(joint - evidence)


def _comp_loglike(*, X, Z, rho, means, covariances):
    """complete log-likelihood

    Args:
        X (ndarray): Data (shape = (N, K)).
        Z (ndarray): Latent variables (shape = (N,)).
        rho (ndarray): Mixture proportion (shape = (K,)).
        means (ndarray): Mean vectors (shape = (K, D)).
        covariances (ndarray): Covariance matrices (shape = (K, D, D)).
    Returns:
        float: Complete log likelihood.
    """
    _, D = X.shape
    K = len(means)
    nk = np.bincount(Z, minlength=K)

    if min(nk) <= 0:
        return np.nan
    else:
        c_loglike = 0
        for k in range(K):
            #print(f'current c_loglike = {c_loglike} with k = {k} out of {range(K)}...')
            c_loglike += nk[k] * np.log(rho[k])
            #print(f'nk[k]={nk[k]}; rho[k]={rho[k]}; c_loglike={c_loglike}')
            c_loglike -= 0.5 * nk[k] * D * np.log(2 * math.pi * math.e)
            #print(f'np.log(2 * math.pi * math.e)={np.log(2 * math.pi * math.e)}; c_loglike={c_loglike}')
            c_loglike -= 0.5 * nk[k] * np.log(np.linalg.det(covariances[k]))
            #print(f'covariances[k]={covariances[k]};np.log(np.linalg.det(covariances[k])={np.log(np.linalg.det(covariances[k]))}; c_loglike={c_loglike}')
        return c_loglike


class GMMModelSelection():
    """Model Selection of Gaussian Mixture Model.

    Fits one mixture per candidate number of components and keeps the one
    with the smallest criterion (BIC or DNML), or fits a single Bayesian
    mixture in ``'BGMM'`` mode.
    """

    _MODES = ('GMM_BIC', 'GMM_DNML', 'BGMM')

    def __init__(
        self,
        K_min=1,
        K_max=20,
        reg_covar=1e-3,
        random_state=None,
        mode='GMM_BIC',
        weight_concentration_prior=1.0,
        tol=1e-3,
        degrees_of_freedom_prior=None,
    ):
        """
        Args:
            K_min (int): Minimum number of the components.
            K_max (int): Maximum number of the components.
            reg_covar (float): Regularization for covariance.
            random_state (Optional[int]): Random state.
            mode (str): Estimation mode. Choose from the following:
                'GMM_BIC' (EM algorithm + BIC)
                'GMM_DNML' (EM algorithm + DNML)
                'BGMM' (Variational Bayes based on Dirichlet distribution).
            weight_concentration_prior (float): Weight concentration prior
                for BGMM.
            tol (float): Tolerance for GMM convergence.
            degrees_of_freedom_prior (Optional[float]): Forwarded to the
                Bayesian mixture in 'BGMM' mode.
        """
        self.K_max = K_max
        self.K_min = K_min
        self.reg_covar = reg_covar
        self.random_state = random_state
        self.mode = mode
        self.weight_concentration_prior = weight_concentration_prior
        self.tol = tol
        self.degrees_of_freedom_prior = degrees_of_freedom_prior

    def fit(self, X):
        """Select the best model.

        Sets ``model_best_``, ``K_``, ``rho_``, ``means_`` and
        ``covariances_`` on the instance.

        Args:
            X (ndarray): Data (shape = (N, D)).
        Returns:
            list: criterion value for every K in ``K_min .. K_max`` (entries
            may be ``nan`` in DNML mode when a component ends up empty);
            an empty list in 'BGMM' mode, where no criterion is computed.
        Raises:
            ValueError: if ``mode`` is not one of the supported modes.
        """
        if self.mode not in self._MODES:
            raise ValueError('methods should be GMM_BIC, GMM_DNML or BGMM.')

        # Defined for every mode so the final return never hits an unbound name.
        criterion_list = []

        if self.mode == 'BGMM':
            self.model_best_ = BayesianGaussianMixture(
                n_components=self.K_max,
                reg_covar=self.reg_covar,
                random_state=self.random_state,
                weight_concentration_prior=self.weight_concentration_prior,
                weight_concentration_prior_type='dirichlet_distribution',
                max_iter=10000,
                n_init=10,
                tol=self.tol,
                degrees_of_freedom_prior=self.degrees_of_freedom_prior
            )
            self.model_best_.fit(X)
        else:
            if self.mode == 'GMM_DNML':
                # Only the single-Gaussian row (index 1) of the table is read
                # below, so K_max=1 avoids building the O(K * N^2) mixture rows.
                log_pc_array = log_pc_gmm(
                    K_max=1,
                    N_max=X.shape[0],
                    D=X.shape[1]
                )

            model_list = []
            for K in range(self.K_min, self.K_max + 1):
                model_new = GaussianMixture(
                    n_components=K,
                    reg_covar=self.reg_covar,
                    random_state=self.random_state,
                    n_init=10,
                    tol=self.tol,
                    max_iter=10000
                )
                model_new.fit(X)
                model_list.append(model_new)

                if self.mode == 'GMM_BIC':
                    criterion_list.append(model_new.bic(X))
                else:
                    Z = model_new.predict(X)
                    loglike = _comp_loglike(
                        X=X,
                        Z=Z,
                        rho=model_new.weights_,
                        means=model_new.means_,
                        covariances=model_new.covariances_
                    )
                    complexity = np.log(_pc_multinomial(len(X), K))
                    # Add the Gaussian complexity of each cluster; sizes for
                    # which the table holds -inf are skipped.
                    for cluster_size in np.bincount(Z, minlength=K)[:K]:
                        if log_pc_array[1, cluster_size] != - np.inf:
                            complexity += log_pc_array[1, cluster_size]
                    criterion_list.append(- loglike + complexity)

            # nanargmin ignores candidates whose criterion is nan.
            idx_best = np.nanargmin(criterion_list)
            self.model_best_ = model_list[idx_best]

        self.K_ = self.model_best_.n_components
        self.rho_ = self.model_best_.weights_
        self.means_ = self.model_best_.means_
        self.covariances_ = self.model_best_.covariances_
        return criterion_list

    def prob_latent(self, X):
        """Probability of the latent variables.

        Args:
            X (ndarray): Data (shape = (N, D)).
        Returns:
            ndarray: Matrix of latent probabilities (shape = (N, K)).
        """
        analysis = GMMUtils(
            rho=self.rho_,
            means=self.means_,
            covariances=self.covariances_
        )
        return analysis.prob_latent(X)

    def predict(self, X):
        """Predict latent labels.

        Args:
            X (ndarray): Data (shape = (N, D)).
        Returns:
            ndarray: predicted labels (shape = (N,)).
        """
        prob_latent_ = self.prob_latent(X)
        return np.argmax(prob_latent_, axis=1)

    
def experiment_gmm_dnml(X, Z_true, repeat=1):
    """Run DNML-scored GMM fitting ``repeat`` times with seeds ``0 .. repeat-1``.

    Every trial fits a ``GMMModelSelection`` in ``'GMM_DNML'`` mode with the
    number of components pinned to the notebook-level ``num_classes``.

    Args:
        X (ndarray): data handed to ``GMMModelSelection.fit``.
        Z_true: accepted for interface compatibility; not used.
        repeat (int): number of trials.

    Returns:
        list: one entry per trial, each being the criterion list returned by
        ``fit``.
    """
    criteria_per_trial = []
    for seed in tqdm(range(repeat), leave=False):
        selector = GMMModelSelection(
            K_min=num_classes,
            K_max=num_classes,
            mode='GMM_DNML',
            random_state=seed,
            tol=1e-3,
        )
        criteria_per_trial.append(selector.fit(X))
    return criteria_per_trial

In [None]:
#@title Conditional Reversion 
!pip install dgl

class InnerProductDecoder(torch.nn.Module):
    r"""Inner-product edge decoder from the `"Variational Graph Auto-Encoders"
    <https://arxiv.org/abs/1611.07308>`_ paper.

    .. math::
        \sigma(\mathbf{Z}\mathbf{Z}^{\top})

    Here :math:`\mathbf{Z} \in \mathbb{R}^{N \times d}` is the latent space
    produced by the encoder."""
    def forward(self, z, edge_index, sigmoid=True):
        r"""Score the node pairs listed in :obj:`edge_index`.

        Args:
            z (Tensor): The latent space :math:`\mathbf{Z}`.
            edge_index (Tensor): Row 0 holds the first node of every pair,
                row 1 the second.
            sigmoid (bool, optional): If set to :obj:`False`, the raw inner
                products are returned instead of probabilities.
                (default: :obj:`True`)
        """
        first_nodes = z[edge_index[0]]
        second_nodes = z[edge_index[1]]
        scores = (first_nodes * second_nodes).sum(dim=1)
        if sigmoid:
            return torch.sigmoid(scores)
        return scores


    def forward_all(self, z, sigmoid=True):
        r"""Score every pair of nodes, giving a dense :math:`N \times N` matrix.

        Args:
            z (Tensor): The latent space :math:`\mathbf{Z}`.
            sigmoid (bool, optional): If set to :obj:`False`, the raw inner
                products are returned instead of probabilities.
                (default: :obj:`True`)
        """
        gram = torch.matmul(z, z.t())
        if sigmoid:
            return torch.sigmoid(gram)
        return gram
    

class CosineSimilarityDecoder(torch.nn.Module):
    r"""Dense edge decoder based on pairwise cosine similarity.

    Entry :math:`(i, j)` of the decoded matrix is the cosine similarity
    between rows :math:`i` and :math:`j` of the latent space
    :math:`\mathbf{Z} \in \mathbb{R}^{N \times d}`, optionally passed through
    a logistic sigmoid."""


    def forward_all(self, z, sigmoid=True):
        r"""Decodes the latent variables :obj:`z` into a dense
        :math:`N \times N` similarity matrix.

        The matrix is computed one row at a time and progress is printed
        every 1000 rows. The result lives on the same device and has the
        same dtype as :obj:`z`.

        Args:
            z (Tensor): The latent space :math:`\mathbf{Z}`.
            sigmoid (bool, optional): If set to :obj:`False`, does not apply
                the logistic sigmoid function to the output.
                (default: :obj:`True`)
        """
        num_nodes = z.shape[0]
        if num_nodes == 0:
            # torch.stack rejects an empty list; keep the (0, 0) result.
            adj = z.new_empty((0, 0))
        else:
            # Collect the rows and stack once: growing the matrix with
            # torch.cat copied it on every iteration, and its CPU/default-dtype
            # seed tensor did not follow z's device.
            rows = []
            for n in range(num_nodes):
                if n % 1000 == 0:
                    print(f'decoded {n} nodes')
                rows.append(F.cosine_similarity(z[n].unsqueeze(0), z))
            adj = torch.stack(rows, dim=0)
        return torch.sigmoid(adj) if sigmoid else adj
    

def th_delete(tensor, indices):
    """Return ``tensor`` without the entries at ``indices``.

    A boolean keep-mask with one slot per element is built, the slots at
    ``indices`` are cleared, and the tensor is filtered through the mask.
    """
    keep = th.full((tensor.numel(),), True, dtype=th.bool)
    keep[indices] = False
    return tensor[keep]


In [None]:
#@title Embeddings Selection
import pickle
from itertools import chain
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

import numpy as np
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt

from sklearn.manifold import TSNE
import seaborn as sns
import pandas as pd

# conduct augmented embeddings selection
class EmbeddingsSelection():
  """Select the best GMM-sampled batch of augmented embeddings.

  A Gaussian mixture is fitted on the node embeddings. ``compute_dnml`` then
  draws ``num_candidates`` candidate batches of ``B`` synthetic embeddings,
  scores each batch and keeps the batch with the smallest score, appended to
  the original embeddings.

  Args:
    emb: Node embeddings (array-like with a ``shape`` attribute), one row
      per node.
    num_classes: Number of mixture components to fit.
    true_labels: Ground-truth class label per row of ``emb``.
    prop_B: Fraction of the node count to generate; ``B = ceil(prop_B * N)``.
    num_candidates: Number of candidate batches to score.
  """

  def __init__(
      self,
      emb,
      num_classes,
      true_labels,
      prop_B=0.1,
      num_candidates=500,

      ):
    self.prop_B = prop_B
    self.emb = emb
    self.num_classes = num_classes
    self.B = math.ceil(self.prop_B*self.emb.shape[0])
    self.num_candidates = num_candidates
    self.true_labels = true_labels

    # per-candidate results; (re)filled by compute_dnml
    self.lst_emb_samples = []
    self.lst_label_samples = []
    self.lst_augmented_emb_dnml = []

    self.pretrained_gmm = GaussianMixture(n_components=self.num_classes, random_state=42).fit(self.emb)
    self.gmm_labels = self.pretrained_gmm.predict(self.emb)

    # Count co-occurrences of (GMM cluster, true class) and, for every true
    # class, keep the cluster that holds most of its nodes.
    # NOTE(review): the vote is taken per true class, so a cluster may be
    # chosen by several classes or by none; samples drawn from an unchosen
    # cluster are labelled 0 in compute_dnml. Confirm this is intended.
    pair_counts = pd.DataFrame()
    pair_counts["gmm_labels"] = self.gmm_labels
    pair_counts["true_labels"] = self.true_labels
    pair_counts = pd.DataFrame(
        pair_counts.groupby(['gmm_labels', 'true_labels']).size()
    ).reset_index()
    self.df_mapping = pair_counts.loc[pair_counts.groupby('true_labels')[0].idxmax()]

    self.gmm_lbl = np.array(self.df_mapping.gmm_labels)
    self.true_lbl = np.array(self.df_mapping.true_labels)


  # Getting GMM predicted label accuracy for all data
  def _get_GMM_label_accuracy(self):
    """Print how often the mapped GMM cluster label equals the true label."""
    # GMM label : true class label
    label_map_dict = dict(zip(self.gmm_lbl, self.true_lbl))

    # map GMM labels to their corresponding true class label; labels without
    # a mapping are left untouched
    mapped = pd.Series(self.gmm_labels).map(lambda lbl: label_map_dict.get(lbl, lbl))
    truth = pd.Series(self.true_labels)
    num_matches = int((mapped == truth).sum())

    print(f"Accuracy of GMM predicted labels is: {num_matches/len(truth)}")


  def compute_dnml(self, dnml_fn=None):
    """Score the candidate batches and return the best augmented embedding set.

    Args:
      dnml_fn: Optional scorer called as ``dnml_fn(X=..., Z_true=...)`` for
        each candidate batch. Defaults to the notebook-level
        ``experiment_gmm_dnml``.

    Returns:
      Tuple ``(embeddings, labels, B)``: the original embeddings followed by
      the selected batch, the matching labels as a ``torch.LongTensor``, and
      the batch size ``B``.
    """
    if dnml_fn is None:
        dnml_fn = experiment_gmm_dnml

    # Start from a clean slate so that calling this method again does not mix
    # scores and samples from an earlier call.
    self.lst_emb_samples = []
    self.lst_label_samples = []
    self.lst_augmented_emb_dnml = []

    all_samples = self.pretrained_gmm.sample(n_samples=self.num_candidates * self.B)

    sample_df = pd.DataFrame(all_samples[0])
    sample_df['label'] = all_samples[1]

    for cand in range(self.num_candidates):

        sub_sample_df = sample_df.sample(self.B, replace=False)

        emb_sample = sub_sample_df.iloc[: , :-1]
        cluster_ids = sub_sample_df.iloc[: , -1].to_numpy()

        # translate GMM cluster ids into true class labels (default 0)
        label_sample = np.zeros_like(cluster_ids)
        for gl, tl in zip(self.gmm_lbl, self.true_lbl):
            label_sample[cluster_ids == gl] = tl

        self.lst_emb_samples.append(emb_sample)
        self.lst_label_samples.append(label_sample)

        dnml_list = dnml_fn(X=emb_sample, Z_true=label_sample)
        self.lst_augmented_emb_dnml.append(dnml_list)

        if cand % 20 == 0:
          print(f"Finish {cand}th candidate.")

    lst_aug_emb_dnml_mean = np.mean(self.lst_augmented_emb_dnml, axis=1)
    lst_aug_emb_dnml_mean = list(chain(*lst_aug_emb_dnml_mean))

    # get best graph embeddings' index
    self.dnml_min_id = np.argmin(lst_aug_emb_dnml_mean)

    # Build the full augmented arrays only for the winner instead of
    # materialising one copy of the embeddings per candidate.
    self.dnml_min_emb = np.concatenate([self.emb, self.lst_emb_samples[self.dnml_min_id]])
    best_labels = np.concatenate([self.true_labels, self.lst_label_samples[self.dnml_min_id]])
    self.dnml_min_lbl = torch.tensor(best_labels).type(torch.LongTensor)

    # return the selected full augmented embeddings, and its labels
    return(self.dnml_min_emb, self.dnml_min_lbl, self.B)

  def _get_graph_info(self):
    """Return ``(best index, best embeddings, best labels, label mapping)``."""
    return(self.dnml_min_id, self.dnml_min_emb, self.dnml_min_lbl, self.df_mapping)

  def _plot_tsne(self, model_name=None):
    """Plot a 2-D t-SNE projection of the selected augmented embeddings.

    Args:
      model_name: Text shown in the title. When omitted, the notebook-level
        ``base_model_name`` is used if it exists, otherwise an empty string.

    Returns:
      The matplotlib figure holding the scatter plot.
    """
    print("Plotting TSEN project of the best augmented embeddings...")

    if model_name is None:
        model_name = globals().get("base_model_name", "")

    tsne = TSNE(n_components=2, verbose=1, random_state=123)
    emb_tsne = tsne.fit_transform(self.dnml_min_emb)

    df = pd.DataFrame()
    df["comp-1"] = emb_tsne[:,0]
    df["comp-2"] = emb_tsne[:,1]
    df["gmm_labels"] = np.asarray(self.dnml_min_lbl)

    # one colour per label actually present, rather than a notebook global
    num_colors = df["gmm_labels"].nunique()

    plt.figure(figsize=(5, 5))
    ax = sns.scatterplot(x="comp-1", y="comp-2", hue=df.gmm_labels.tolist(),
                    palette=sns.color_palette("hls", num_colors),
                    data=df)

    ax.get_legend().remove()
    ax.set(title=f"Selected Embeddings GMM Labels \n ($C\\%={self.num_candidates}$, {model_name})")
    fig = ax.get_figure()
    return fig



## Section 3: Downstream Node Classification Preparation

In [None]:
import os
import torch
# NOTE(review): this repeats the dependency-install cell at the top of the
# notebook; it appears to exist so this section can be run on its own.
#os.environ['TORCH'] = torch.__version__
os.environ['TORCH'] = '1.13.0+cu116' # now the system is 1.13.1+cu116, but no precompiled wheel of it
print(torch.__version__)

# The wheel index URLs below are built from the TORCH value set above
# (presumably expanded by the shell from the environment -- TODO confirm).
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git


#@title Define downstream GCN, GAT, GraphSAGE  
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv, SAGEConv

class GAT(torch.nn.Module):
    """Two-layer graph attention classifier.

    The first layer uses ``in_head`` attention heads of width ``emb_dim``;
    the second layer uses ``out_head`` heads without concatenation and emits
    one score per class. Dropout with probability 0.6 is applied to the
    input, to the hidden activations and inside both attention layers.
    """

    def __init__(self,num_node_features, num_classes, emb_dim=16):
        super().__init__()
        self.hid, self.in_head, self.out_head = emb_dim, 8, 1

        self.conv1 = GATConv(
            num_node_features,
            self.hid,
            heads=self.in_head,
            dropout=0.6,
        )
        self.conv2 = GATConv(
            self.hid * self.in_head,
            num_classes,
            heads=self.out_head,
            concat=False,
            dropout=0.6,
        )

    def forward(self, data):
        """Return per-node log-probabilities for ``data.x`` / ``data.edge_index``."""
        edges = data.edge_index

        hidden = F.dropout(data.x, p=0.6, training=self.training)
        hidden = F.elu(self.conv1(hidden, edges))
        hidden = F.dropout(hidden, p=0.6, training=self.training)
        logits = self.conv2(hidden, edges)

        return F.log_softmax(logits, dim=1)


class GCN(torch.nn.Module):
    """Two-layer graph convolutional classifier.

    A ``GCNConv`` layer of width ``emb_dim`` is followed by ReLU, dropout
    (default probability) and a second ``GCNConv`` layer that emits one score
    per class.
    """

    def __init__(self, num_node_features, num_classes, emb_dim=128):
        super().__init__()
        self.hid = emb_dim
        self.conv1 = GCNConv(num_node_features, self.hid)
        self.conv2 = GCNConv(self.hid, num_classes)

    def forward(self, data):
        """Return per-node log-probabilities for ``data.x`` / ``data.edge_index``."""
        edges = data.edge_index

        hidden = F.relu(self.conv1(data.x, edges))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.conv2(hidden, edges)

        return F.log_softmax(logits, dim=1)

class GraphSAGE(torch.nn.Module):
  """Two-layer GraphSAGE classifier.

  A ``SAGEConv`` layer of width ``emb_dim`` is followed by ReLU, dropout
  (p=0.5) and a second ``SAGEConv`` layer that emits one score per class.
  """
  def __init__(self, num_node_features, num_classes, emb_dim=8):
    super().__init__()
    self.hid = emb_dim
    self.sage1 = SAGEConv(num_node_features, self.hid)
    self.sage2 = SAGEConv(self.hid, num_classes)

  def forward(self, data):
    """Return per-node log-probabilities for ``data.x`` / ``data.edge_index``."""
    edges = data.edge_index

    hidden = F.relu(self.sage1(data.x, edges))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    logits = self.sage2(hidden, edges)
    return F.log_softmax(logits, dim=1)


In [None]:
#@title Node Classification
from torch_geometric import datasets
from torch_geometric.utils.convert import to_scipy_sparse_matrix
# Implementation of matplotlib spy function
import matplotlib.pyplot as plt
import numpy as np
import warnings

class NodeClassification():
  """Downstream node classification on a graph extended with augmented nodes.

  The augmented embeddings are decoded into a dense adjacency matrix with
  ``CosineSimilarityDecoder``; the block belonging to the original nodes is
  then overwritten with the original adjacency. ``_data_prepare`` builds the
  matching PyG data object and ``train_gnn`` trains and evaluates a GNN on it.

  Relies on the notebook-level ``device`` being defined before use.

  Args:
    B: Number of augmented nodes appended to the graph.
    emb: Embeddings of original + augmented nodes (array-like or tensor).
    ds: Dataset name. ``'cora-full'``, ``'coauthor-cs'`` and
      ``'coauthor-physics'`` are read from preprocessed pickles;
      ``'Cora'``, ``'PubMed'`` and ``'Citeseer'`` are loaded via Planetoid.
    gmm_lbl: Label tensor for all original + augmented nodes.
    generate_label: Unused here; kept for backward compatibility. The flag
      that matters is the one passed to ``_data_prepare``.
    threasold: Minimum decoded score for an edge to be created.
      NOTE(review): the decoder applies a sigmoid to a cosine similarity, so
      scores never exceed sigmoid(1) ~= 0.731; with the default 0.75 no
      decoded edge is ever created. Confirm this default is intended.
    data_root: Directory containing ``<ds>/data/<ds>-preprocessed-*``.
  """

  _PICKLED_DATASETS = ('cora-full', 'coauthor-cs', 'coauthor-physics')
  _PLANETOID_DATASETS = ('Cora', 'PubMed', 'Citeseer')
  _GNN_TYPES = ('GAT', 'GCN', 'GraphSAGE')

  def __init__(
        self,
        B,
        emb,
        ds,
        gmm_lbl,
        generate_label=False,
        threasold=0.75,
        data_root="/content/drive/MyDrive/gda_dnml"
    ):
    self.emb = emb
    self.gmm_lbl = gmm_lbl.to(device)
    self.B = B
    self.decoder = CosineSimilarityDecoder()

    if ds in self._PICKLED_DATASETS:
      # SECURITY: pickle.load can execute arbitrary code; only point
      # data_root at files produced by this project.
      data_dir = f"{data_root}/{ds}/data"
      with open(f"{data_dir}/{ds}-preprocessed-data", 'rb') as fp:
          self.org_data = pickle.load(fp)
      with open(f"{data_dir}/{ds}-preprocessed-dataset", 'rb') as fp:
          self.org_dataset = pickle.load(fp)
    elif ds in self._PLANETOID_DATASETS:
      self.org_dataset = datasets.Planetoid(root=f'/tmp/{ds}', name=ds)
      # the original code never set org_data for this branch
      self.org_data = self.org_dataset[0]
    else:
      raise ValueError(
          f"Unsupported dataset {ds!r}; expected one of "
          f"{self._PICKLED_DATASETS + self._PLANETOID_DATASETS}")

    self.org_data = self.org_data.to(device)
    self.org_data.edge_index = self.org_data.edge_index.to(device)
    # Pass num_nodes so isolated nodes with the highest ids still get a
    # row/column in the adjacency matrix.
    self.org_adj = to_scipy_sparse_matrix(
        self.org_data.edge_index, num_nodes=self.org_data.num_nodes)

    self.num_classes = self.org_dataset.num_classes

    # decode and revert back to adjacency matrix
    self.decode_embeddings(self.emb, threasold)

  def decode_embeddings(self, emb, threasold):
    """Decode ``emb`` into ``self.aug_g``, a dense 0/1 adjacency matrix.

    Pairs whose decoded score is at least ``threasold`` become edges (the
    diagonal is cleared first); the top-left block that belongs to the
    original nodes is then replaced by the original adjacency.
    """
    self.emb = emb if isinstance(emb, torch.Tensor) else torch.tensor(emb)

    scores = self.decoder.forward_all(self.emb)
    scores = scores.fill_diagonal_(0)
    self.aug_g = (scores >= threasold).float()

    # Rebuild the original adjacency as a dense block on the same device as
    # the decoded matrix.
    num_rows, num_cols = self.org_adj.shape
    target = self.aug_g.device
    edge_values = torch.from_numpy(self.org_adj.data.astype(np.float64)).to(torch.float).to(target)
    edge_index = self.org_data.edge_index.to(target)
    org_dense = torch.sparse_coo_tensor(
        edge_index, edge_values, (num_rows, num_cols)).to_dense()

    self.aug_g[:num_rows, :num_cols] = org_dense

  def _plot_adjacencies(self, markersize=0.5):
    """Plot the sparsity patterns of the original and augmented adjacency."""
    figure, axis = plt.subplots(1, 2, figsize=(15, 15))

    # plot original adjacency matrix
    axis[0].spy(self.org_adj, markersize = markersize)
    axis[0].set_title("Adjacency Plot of Original Matrix")

    # plot conditionally reverted augmented adjacency matrix
    axis[1].spy(self.aug_g, markersize = markersize)
    axis[1].set_title("Adjacency Plot of Conditionally Reverted Augmented Matrix")

  def _data_prepare(self, generate_label):
    """Build the augmented data object from a fresh copy of the original.

    Args:
      generate_label: If truthy, the ``B`` augmented nodes are added to the
        training mask; otherwise they are excluded from every mask.

    Returns:
      Tuple ``(augmented data, original data)``.
    """
    self.aug_data = self.org_data.clone()
    num_org_nodes, num_org_features = self.org_data.x.shape

    # masks: augmented nodes never enter the test/validation sets
    excluded = torch.zeros(self.B, dtype=torch.bool).to(device)
    in_train = (excluded == 0) if generate_label else excluded

    train_mask = torch.cat((self.aug_data.train_mask.clone(), in_train), -1)
    test_mask = torch.cat((self.aug_data.test_mask.clone(), excluded), -1)
    val_mask = torch.cat((self.aug_data.val_mask.clone(), excluded), -1)

    self.aug_data.train_mask = train_mask
    self.aug_data.test_mask = test_mask
    self.aug_data.val_mask = val_mask

    print(f"""
      Size of augmented training set: {self.aug_data.train_mask.sum()}, size of original training set: {self.org_data.train_mask.sum()}; 
      Size of augmented test set: {self.aug_data.test_mask.sum()}, size of original test set: {self.org_data.test_mask.sum()};
      Size of augmented validation set: {self.aug_data.val_mask.sum()}, size of original validation set: {self.org_data.val_mask.sum()}.
    """)

    # edges come from the decoded / reverted adjacency matrix
    self.aug_data.edge_index = self.aug_g.nonzero().t().contiguous()

    # labels for original + augmented nodes
    self.aug_data.y = self.gmm_lbl

    # features: original nodes keep their features and get B zero columns;
    # augmented nodes get zeros in the original columns and a one-hot
    # identity block in the B new columns
    aug_x = self.aug_data.x.clone()
    zero_rows = torch.zeros(self.B, aug_x.shape[1]).to(device)
    aug_x = torch.cat((aug_x, zero_rows), 0)
    zero_cols = torch.zeros(aug_x.shape[0], self.B).to(device)
    aug_x = torch.cat((aug_x, zero_cols), 1)
    aug_x[num_org_nodes:, num_org_features:] = torch.eye(self.B)

    self.aug_data.x = aug_x

    # double check the nodes size
    assert aug_x.shape[0] == self.aug_data.num_nodes
    self.aug_data = self.aug_data.to(device)

    return(self.aug_data, self.org_data)

  def _masked_accuracy(self, pred, mask):
    """Fraction of nodes selected by ``mask`` whose prediction matches ``y``."""
    correct = float(pred[mask].eq(self.aug_data.y[mask]).sum().item())
    return correct / mask.sum().item()

  def train_gnn(self, gnn_type='GAT', lr=0.005, weight_decay=5e-4, epochs=200):
    """Train a GNN on the augmented graph and record train/test accuracy.

    ``_data_prepare`` must have been called first. Results are stored in
    ``self.model``, ``self.losses``, ``self.train_acc`` and ``self.test_acc``.

    Raises:
      NotImplementedError: If ``gnn_type`` is not 'GAT', 'GCN' or 'GraphSAGE'.
    """
    if gnn_type not in self._GNN_TYPES:
      raise NotImplementedError(
          f"The requested method has not been implemented: {gnn_type!r}")

    # resolved after the check so an unknown name fails with a clear error
    model_cls = {'GAT': GAT, 'GCN': GCN, 'GraphSAGE': GraphSAGE}[gnn_type]
    self.model = model_cls(
        num_node_features = self.aug_data.num_node_features,
        num_classes=self.num_classes).to(device)

    print(f'Number of training nodes: {self.aug_data.train_mask.sum()}')
    print(f'Number of validation nodes: {self.aug_data.val_mask.sum()}')
    print(f'Number of test nodes: {self.aug_data.test_mask.sum()}')

    optimizer = torch.optim.Adam(
        self.model.parameters(),
        lr=lr,
        weight_decay=weight_decay
        )

    train_mask = self.aug_data.train_mask
    self.losses = []
    for epoch in range(epochs):
        self.model.train()
        optimizer.zero_grad()
        out = self.model(self.aug_data)
        loss = F.nll_loss(out[train_mask], self.aug_data.y[train_mask])
        self.losses.append(loss.item())

        if epoch % 50 == 0:
            print(loss)

        loss.backward()
        optimizer.step()

    # evaluation needs no gradients
    self.model.eval()
    with torch.no_grad():
        pred = self.model(self.aug_data).argmax(dim=1)

    self.train_acc = self._masked_accuracy(pred, train_mask)
    print('Training Set Accuracy: {:.4f}'.format(self.train_acc))

    self.test_acc = self._masked_accuracy(pred, self.aug_data.test_mask)
    print('Test Set Accuracy: {:.4f}'.format(self.test_acc))

  def _get_graph_info(self):
    """Return accuracies and size statistics of the augmented graph.

    Returns:
      Tuple ``(train_acc, test_acc, num_edges, train_data_size,
      test_data_size, val_data_size, num_nodes, num_features)``.
    """
    num_edges = len(self.aug_data.edge_index[0])
    train_data_size = self.aug_data.train_mask.sum().item()
    test_data_size = self.aug_data.test_mask.sum().item()
    val_data_size = self.aug_data.val_mask.sum().item()
    num_nodes = self.aug_data.x.shape[0]
    num_features = self.aug_data.x.shape[1]

    return(self.train_acc, self.test_acc, num_edges, train_data_size, test_data_size, val_data_size, num_nodes, num_features)


# ⛳️ GMMDA Data Augmentation and Downstream Node Classification 

In [None]:
from torch_geometric import datasets
from torch_geometric.utils.convert import to_scipy_sparse_matrix
import gc

import numpy as np
from scipy.sparse import csr_matrix
import pickle

import os

print(f'The loading dataset is {ds}')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# visualization settings 
font = {'size'   : 13}
plt.rc('font', **font)

# GNN settings 
gnn_type = 'GCN' #'GCN', 'GraphSAGE'
num_trails = 5
prop_B = 0.07
num_candidates = 1 #200 #500

# SECURITY: pickle.load can execute arbitrary code; only load files that were
# written by this notebook.
with open(MODEL_DIR + f'/{gnn_type}_aux_gmm_v2', 'rb') as fp:
    aux_gmm = pickle.load(fp)

dataCenter = DataCenter()
dataCenter.load_dataSet(ds)

# create save directory 
temp_dir = EXP_RESULT_DIR + f'/{gnn_type}_B%={prop_B}_C={num_candidates}_trails={num_trails}'
print(f"[*] Experiments of B%={prop_B} C={num_candidates} #trails={num_trails} [{gnn_type}] will be saved at {temp_dir}")
os.makedirs(temp_dir, exist_ok=True)

temp_model_dir = temp_dir + f'/model'
os.makedirs(temp_model_dir, exist_ok=True)

# load preprocessed data
preprocessed_prefix = f"/content/drive/MyDrive/gda_dnml/{ds}/data/{ds}-preprocessed"
with open(f"{preprocessed_prefix}-data", 'rb') as fp:
    data = pickle.load(fp)

with open(f"{preprocessed_prefix}-dataset", 'rb') as fp:
    dataset = pickle.load(fp)

num_classes = dataset.num_classes

adj_matrix = to_scipy_sparse_matrix(data.edge_index).tocsr()
raw_features = data.x.to('cpu').detach().numpy().copy()
labels = data.y.to('cpu').detach().numpy().copy()

# number of true classes
k_hat = len(np.unique(labels))
# all nodes indices 
nodes = np.arange(0, len(labels), 1) 

# make placeholders 
lst_embs_aux_gmm = []
lst_best_emb = []
lst_gmm_lbl = []

# column order of the per-trial evaluation tables
EVAL_COLUMNS = [
    'train_acc', 
    'test_acc', 
    'num_nodes', 
    'num_edges', 
    'train_data_size', 
    'test_data_size',
    'val_data_size',
    'num_features']

# One record (dict) per trial. The tables are rebuilt from these lists
# because DataFrame.append was removed in pandas 2.0.
records_no_aug_label = []
records_aug_label = []
df_evaluation_no_aug_label = pd.DataFrame(columns=EVAL_COLUMNS)
df_evaluation_aug_label = pd.DataFrame(columns=EVAL_COLUMNS)


def run_downstream_evaluation(node_clf, generate_label, gnn_type):
  """Prepare the augmented data, train ``gnn_type`` and collect the metrics.

  Returns the augmented data object and a dict keyed by ``EVAL_COLUMNS``.
  """
  aug_data, _ = node_clf._data_prepare(generate_label=generate_label)
  node_clf.train_gnn(gnn_type=gnn_type)

  train_acc, test_acc, num_edges, train_data_size, \
  test_data_size, val_data_size, num_nodes, num_features = node_clf._get_graph_info()

  metrics = {
    'train_acc': train_acc, 
    'test_acc': test_acc, 
    'num_nodes': num_nodes, 
    'num_edges': num_edges, 
    'train_data_size': train_data_size, 
    'test_data_size': test_data_size,
    'val_data_size': val_data_size,
    'num_features': num_features}
  return aug_data, metrics


for i in range(num_trails):
  # free up some space 
  if torch.cuda.is_available():
        torch.cuda.empty_cache()
  gc.collect()

  # Embeddings come from the pre-trained auxiliary model loaded above; the
  # per-trial retraining that used to live here was disabled in the original
  # notebook (it was kept only as an inert string literal).
  embs_aux_gmm = aux_gmm.generate_embeddings(nodes)
  lst_embs_aux_gmm.append(embs_aux_gmm)

  #save to local 
  with open(temp_dir + f'/{gnn_type}_lst_embs_aux_gmm', 'wb') as fp:
    pickle.dump(lst_embs_aux_gmm, fp)

  # conduct embedding selection based on DNML criteria
  es = EmbeddingsSelection(
      emb=embs_aux_gmm, 
      prop_B=prop_B,
      num_classes=dataset.num_classes, 
      true_labels=labels,
      num_candidates=num_candidates
  )
  best_emb, gmm_lbl, B = es.compute_dnml()

  lst_best_emb.append(best_emb)
  lst_gmm_lbl.append(gmm_lbl)

  #save to local 
  with open(temp_dir + f'/{gnn_type}_lst_best_emb', 'wb') as fp:
    pickle.dump(lst_best_emb, fp)

  with open(temp_dir + f'/{gnn_type}_lst_gmm_lbl', 'wb') as fp:
    pickle.dump(lst_gmm_lbl, fp)

  #fig = es._plot_tsne()
  #fig.savefig(tsne_dnml_dir + f'/TSNE_of_C={num_candidates}_Bpctg={prop_B}_2.png')

  nc = NodeClassification(
      B=B, 
      emb=best_emb, 
      ds=ds,
      gmm_lbl=gmm_lbl
      ) # set threshold >1 to get original graph performance #0.72
  #nc._plot_adjacencies()

  # performance of no-augmented-class-label 
  aug_data_no_aug_label, metrics = run_downstream_evaluation(
      nc, generate_label=False, gnn_type=gnn_type)
  records_no_aug_label.append(metrics)
  df_evaluation_no_aug_label = pd.DataFrame(records_no_aug_label, columns=EVAL_COLUMNS)
  df_evaluation_no_aug_label.to_csv(temp_dir+f'/df_evaluation_no_aug_label.csv', index=True)

  # performance of augmented-class-label 
  aug_data_aug_label, metrics = run_downstream_evaluation(
      nc, generate_label=True, gnn_type=gnn_type)
  records_aug_label.append(metrics)
  df_evaluation_aug_label = pd.DataFrame(records_aug_label, columns=EVAL_COLUMNS)
  df_evaluation_aug_label.to_csv(temp_dir+'/df_evaluation_aug_label.csv', index=True)