# Self-supervised learning method for Heterogeneous Graph


## Modeling
The model follows the autoencoder concept:

* Encoder $f_{\theta}$, built on a GNN backbone, encodes each $v_i \in \mathcal{V}$ into $h_i \in H$
* Pretext decoder $p_{\phi}$ takes $H$ as input for the pretext task  
    *A pretext task is the problem definition that the model is optimized on during self-supervised learning*

After SSL training, the trained encoder is reused for the downstream task
* $q_{\psi}$ is the downstream decoder
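
The roles of $f_{\theta}$, $p_{\phi}$ and $q_{\psi}$ can be sketched as a forward pass. This is a minimal NumPy sketch, not a reference implementation: the single GCN-style layer, the linear heads, the toy path graph and all shapes are assumptions made for illustration.

```python
import numpy as np

def normalize_adjacency(A):
    """Symmetric normalization D^{-1/2} (A + I) D^{-1/2} used by GCN-style layers."""
    A_hat = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    return A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def encoder(A, X, W):
    """f_theta (assumed form): one graph-convolution layer with ReLU, X (n x d) -> H (n x k)."""
    return np.maximum(normalize_adjacency(A) @ X @ W, 0.0)

def linear_decoder(H, W):
    """Stand-in for both p_phi (pretext) and q_psi (downstream): a linear head on H."""
    return H @ W

rng = np.random.default_rng(0)
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)  # toy 4-node path graph
X = rng.normal(size=(4, 5))                # node features
theta = rng.normal(size=(5, 3))            # encoder weights
phi = rng.normal(size=(3, 5))              # pretext head: reconstruct the 5 input features
psi = rng.normal(size=(3, 2))              # downstream head: 2-class logits

H = encoder(A, X, theta)                   # shared representation
pretext_out = linear_decoder(H, phi)       # used only during SSL training
downstream_out = linear_decoder(H, psi)    # used only for the downstream task
```

Both heads read the same $H$, which is why the encoder weights learned with the pretext head can be carried over to the downstream head.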


The objective functions include
* Objective function for Graph SSL
$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \mathcal{L}_{ssl}(f_{\theta}, p_{\phi}, \mathcal{D})
$$
where $\mathcal{D}$ is the graph data distribution and $(\mathcal{V}, \mathcal{E}) \sim \mathcal{D}$


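* Objective function for Graph SL (downstream task)
$$
\theta^{**}, \psi^* = \argmin_{\theta^*, \psi} \mathcal{L}_{sup}(f_{\theta^*}, q_{\psi}, \mathcal{G}, y)
$$
where $\mathcal{G}$ is the downstream graph, $y$ is the label, and the encoder is initialized with $\theta^*$ from SSL training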


## Self-supervised learning methods
### Categories
There are 4 categories:
* Generation-based: these methods reconstruct the graph data, either its *features* or its *graph structure*.  
Hence $\mathcal{L}_{ssl}$ usually measures the difference between the original and the reconstructed graph data. The representative is GAE.
* Auxiliary Property-based: these methods define the pretext task as a regression or classification problem. The labels are pseudo-labels derived from the graph itself. For instance,  
    * a regression problem is optimized on node degrees or the distance to a cluster. Hence $\mathcal{L}_{ssl}$ can be MSE
    * a classification problem is optimized on graph partitions or cluster indices. Hence $\mathcal{L}_{ssl}$ can be CE

    The representative for this is M3S

* Contrastive-based: the original $\mathcal{G}$ is augmented into $\mathcal{G}^1$ and $\mathcal{G}^2$. $\mathcal{L}_{ssl}$ is defined as a contrastive loss, which is similar in idea to the triplet loss. This is called *Mutual Information (MI)* maximization: the objective maximizes the MI between positive pairs from the two augmented graphs and minimizes it between negative pairs

    Representatives are DGI, GraphCL, GCC


* Hybrid: these methods combine the above methods using weighted or unweighted averaging. The representative is GMI, which optimizes edge-level reconstruction (generation-based) together with node-level contrast (contrastive-based)


### Objective function formulation
#### Generation-based method
$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \mathcal{L}_{ssl}(p_{\phi}(f_{\theta}(\hat{\mathcal{G}})), \mathcal{G})
$$
TODO: GAE brief

There are 2 sub-categories: reconstructing the *features* and reconstructing the *structure*

##### Reconstruct the features
The objective is to minimize the difference between the generated and the original features, which is
$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \mathcal{L}_{ssl}(p_{\phi}(f_{\theta}(\hat{\mathcal{G}})), \hat{X})
$$

$\hat{X}$ denotes the target features, such as the node feature matrix, the edge feature matrix, or a low-dimensional feature matrix.

Inspired by the inpainting problem in Computer Vision, *masked feature regression* is applied to graphs. The features of some nodes or edges are masked with a binary mask, and the model learns to reconstruct the original features from the masked graph. **Hence, the representation often captures node-level knowledge**

In the equation above, $\hat{\mathcal{G}} = (A, \tilde{X})$, where $\tilde{X}$ is the masked feature matrix, and the target is $\hat{X} = X$. That is, the encoder sees the masked graph and predicts the masked node / edge attributes from the neighbor information. As a result, this method has to be trained on the whole graph, or on a large sub-graph, to learn effectively. The representative is *Graph Completion*. Additionally, methods such as AttrMasking learn from both node and edge attributes, so $\hat{X} = [X, X_{edge}]$
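
Masked feature regression can be sketched as follows. This is a minimal NumPy sketch under stated assumptions: masking is done by zeroing whole feature rows, and `mean_neighbor_predictor` is a hypothetical stand-in for $p_{\phi}(f_{\theta}(\cdot))$ that simply averages neighbor features, so no training happens here.

```python
import numpy as np

def mask_node_features(X, mask_rate, rng):
    """Zero out the feature rows of a random subset of nodes; return masked X and the node mask."""
    n = X.shape[0]
    n_masked = max(1, int(round(mask_rate * n)))
    node_mask = np.zeros(n, dtype=bool)
    node_mask[rng.choice(n, size=n_masked, replace=False)] = True
    X_masked = X.copy()
    X_masked[node_mask] = 0.0
    return X_masked, node_mask

def mean_neighbor_predictor(A, X_masked):
    """Hypothetical stand-in for p_phi(f_theta(.)): average of the neighbors' features."""
    deg = np.maximum(A.sum(axis=1, keepdims=True), 1.0)
    return (A @ X_masked) / deg

def masked_mse(X_pred, X_target, node_mask):
    """MSE restricted to the masked nodes, whose original features are the targets."""
    return float(np.mean((X_pred[node_mask] - X_target[node_mask]) ** 2))

rng = np.random.default_rng(0)
n = 8
A = np.zeros((n, n))
for i in range(n - 1):  # toy path graph
    A[i, i + 1] = A[i + 1, i] = 1.0
X = rng.normal(size=(n, 3))

X_masked, node_mask = mask_node_features(X, mask_rate=0.25, rng=rng)
X_pred = mean_neighbor_predictor(A, X_masked)
loss = masked_mse(X_pred, X, node_mask)
```

The loss only looks at masked nodes, so the prediction for a masked node has to come from its neighbors, which is why a sufficiently large part of the graph is needed.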

Other methods find the original attributes hard to learn from, so the attributes are first transformed into a lower dimension, for example with PCA (*AttributeMask*). Hence, $\hat{X} = \mathrm{PCA}(X)$.
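
The target $\hat{X} = \mathrm{PCA}(X)$ can be computed with an SVD. This is a minimal NumPy sketch; the function name `pca_targets`, the number of components and the random features are assumptions for illustration only.

```python
import numpy as np

def pca_targets(X, n_components):
    """Project centered features onto the top principal directions to get low-dimensional targets."""
    X_centered = X - X.mean(axis=0, keepdims=True)
    # Rows of Vt are the principal directions, ordered by decreasing singular value
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ Vt[:n_components].T

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 6))            # 10 nodes, 6 raw attributes
X_hat = pca_targets(X, n_components=2)  # regression targets with 2 dimensions per node
```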

As generation-based methods usually rely on the autoencoder architecture, some of its variants are also applicable, such as the denoising autoencoder. The representative is *MGAE*. Here, $\hat{\mathcal{G}} = (A, \bar{X})$ with $\bar{X} = \hat{X} + S_{noise}$, where $S_{noise} \sim \mathcal{N}$ is the added random noise.
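
The denoising setup can be sketched in a few lines. This is a minimal NumPy sketch assuming zero-mean Gaussian noise drawn independently per entry; the noise scale and the identity "model" are illustrative assumptions, not MGAE itself.

```python
import numpy as np

def add_gaussian_noise(X_hat, noise_std, rng):
    """X_bar = X_hat + S_noise, with S_noise ~ N(0, noise_std^2) drawn independently per entry."""
    return X_hat + rng.normal(loc=0.0, scale=noise_std, size=X_hat.shape)

def reconstruction_mse(X_pred, X_hat):
    """The loss compares the prediction with the clean features, not the noisy input."""
    return float(np.mean((X_pred - X_hat) ** 2))

rng = np.random.default_rng(0)
X_hat = rng.normal(size=(6, 4))
X_bar = add_gaussian_noise(X_hat, noise_std=0.1, rng=rng)

# An identity "model" that copies its noisy input keeps the noise, so its loss stays above zero
identity_loss = reconstruction_mse(X_bar, X_hat)
```

Since the target is the clean $\hat{X}$, copying the input is not an optimal solution, which is the point of the denoising variant.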

TODO: Graph Completion, AttributeMask, MGAE


##### Reconstruct the structure
The objective function is as follows
$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \mathcal{L}_{ssl}(p_{\phi}(f_{\theta}(\hat{\mathcal{G}})), A)
$$

The idea is to reconstruct the topology of the graph, which is represented by the adjacency matrix $A$. The encoder maps $\hat{\mathcal{G}}$ to $H$, then the decoder uses $H$ to reconstruct $A$. The reconstruction of the structure is measured by predicting whether there is an edge between two nodes or not. **Hence, the representation often captures node-pair level information**

In particular, the problem reduces to binary classification and can be optimized with binary cross-entropy (BCE). The representative is *GAE*; other variants are *VGAE*, *SIG-VAE*, and *ARGA / ARVGA* (integrated with the GAN concept). Because graphs are usually sparse, these methods can suffer from class imbalance, which can be addressed by re-weighting or by oversampling the minority class.
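
Structure reconstruction with a re-weighted BCE can be sketched as follows. This is a minimal NumPy sketch assuming an inner-product decoder $\sigma(h_i^\top h_j)$ and a positive-class weight equal to the negative-to-positive ratio; the random $H$ stands in for a trained encoder output, and diagonal entries are counted as negatives for simplicity.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def inner_product_decoder(H):
    """Edge probabilities sigmoid(h_i . h_j) for every node pair."""
    return sigmoid(H @ H.T)

def weighted_bce(A, A_prob, eps=1e-12):
    """BCE over all entries of A, with the edge (positive) terms up-weighted to offset sparsity."""
    n = A.shape[0]
    n_pos = A.sum()
    n_neg = n * n - n_pos
    pos_weight = n_neg / max(n_pos, 1.0)
    loss = -(pos_weight * A * np.log(A_prob + eps)
             + (1.0 - A) * np.log(1.0 - A_prob + eps))
    return float(loss.mean()), float(pos_weight)

rng = np.random.default_rng(0)
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)   # toy 4-node path graph
H = rng.normal(scale=0.5, size=(4, 2))      # stand-in for f_theta(G_hat)

A_prob = inner_product_decoder(H)
loss, pos_weight = weighted_bce(A, A_prob)
```

On this toy graph 6 of the 16 entries are edges, so each edge term is weighted by 10 / 6; on real sparse graphs the ratio is much larger.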

Instead of reconstructing the whole graph, other methods randomly drop some edges, then the model learns to recover those edges. BCE is also applied, here to the predicted connections between node pairs. The representatives are *Denoising Link Reconstruction* and *EdgeMask*
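
The edge-dropping variant can be sketched as follows. This is a minimal NumPy sketch with assumptions: an undirected graph, dropped edges as positives, an equal number of sampled non-edges as negatives, and a dot-product score on a random $H$ that stands in for the encoder output on the perturbed graph.

```python
import numpy as np

def drop_edges(A, drop_rate, rng):
    """Remove a random subset of undirected edges; return the perturbed adjacency and the dropped edges."""
    rows, cols = np.triu_indices_from(A, k=1)
    is_edge = A[rows, cols] > 0
    edges = np.stack([rows[is_edge], cols[is_edge]], axis=1)
    n_drop = max(1, int(round(drop_rate * len(edges))))
    dropped = edges[rng.choice(len(edges), size=n_drop, replace=False)]
    A_perturbed = A.copy()
    A_perturbed[dropped[:, 0], dropped[:, 1]] = 0.0
    A_perturbed[dropped[:, 1], dropped[:, 0]] = 0.0
    return A_perturbed, dropped

def sample_non_edges(A, n_samples, rng):
    """Sample node pairs that are not connected in the original graph."""
    rows, cols = np.triu_indices_from(A, k=1)
    non_edge = A[rows, cols] == 0
    candidates = np.stack([rows[non_edge], cols[non_edge]], axis=1)
    idx = rng.choice(len(candidates), size=min(n_samples, len(candidates)), replace=False)
    return candidates[idx]

def link_bce(H, pos_pairs, neg_pairs, eps=1e-12):
    """BCE on dot-product link scores: label 1 for dropped edges, 0 for sampled non-edges."""
    def score(pairs):
        logits = np.sum(H[pairs[:, 0]] * H[pairs[:, 1]], axis=1)
        return 1.0 / (1.0 + np.exp(-logits))
    pos, neg = score(pos_pairs), score(neg_pairs)
    total = np.log(pos + eps).sum() + np.log(1.0 - neg + eps).sum()
    return float(-total / (len(pos) + len(neg)))

rng = np.random.default_rng(0)
n = 6
A = np.zeros((n, n))
for i in range(n):  # toy ring graph with 6 edges
    A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1.0

A_perturbed, dropped = drop_edges(A, drop_rate=1 / 3, rng=rng)
negatives = sample_non_edges(A, n_samples=len(dropped), rng=rng)
H = rng.normal(scale=0.5, size=(n, 4))  # stand-in for f_theta applied to the perturbed graph
loss = link_bce(H, dropped, negatives)
```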


### Auxiliary Property-based method
$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \mathcal{L}_{ssl} (p_{\phi}(f_{\theta}(\mathcal{G})), c)
$$

where $c$ denotes the auxiliary properties (pseudo-labels), which are designed based on whether the problem is regression or classification  
TODO: M3S brief


#### Classification task
##### Clustering
Define the pseudo-label of each node by a mapping function $\Omega : \mathcal{V} \to \mathcal{C}$.

$\Omega$ is a clustering method, and $\mathcal{C}$ is the set of pseudo-labels (cluster indices)

> My opinion: this method relies heavily on the clustering algorithm, unless $\Omega$ is trained along with the graph model

The optimization can be written as
$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \frac{1}{|\mathcal{V}|}\sum_{v_i \in \mathcal{V}} \mathcal{L}_{ce} (p_{\phi}(f_{\theta}(\mathcal{G}_{v_i})), \Omega(v_i))
$$
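
The clustering pretext task can be sketched as follows. This is a minimal NumPy sketch assuming $\Omega$ is plain k-means (Lloyd's algorithm) run on node features; the two well-separated blobs and the all-zero logits (an untrained classifier) are illustrative assumptions.

```python
import numpy as np

def assign_clusters(X, centers):
    """Index of the nearest center for every row of X."""
    dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)

def kmeans_labels(X, k, n_iters, rng):
    """Omega (assumed to be k-means): returns one cluster index per node as its pseudo-label."""
    centers = X[rng.choice(len(X), size=k, replace=False)].copy()
    for _ in range(n_iters):
        labels = assign_clusters(X, centers)
        for c in range(k):
            members = X[labels == c]
            if len(members) > 0:  # keep the old center if a cluster becomes empty
                centers[c] = members.mean(axis=0)
    return assign_clusters(X, centers)

def cross_entropy(logits, labels):
    """Mean softmax cross-entropy between per-node logits and integer pseudo-labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(0.0, 0.1, size=(5, 2)),   # nodes 0-4: first blob
                    rng.normal(5.0, 0.1, size=(5, 2))])  # nodes 5-9: second blob
pseudo_labels = kmeans_labels(X, k=2, n_iters=10, rng=rng)

logits = np.zeros((10, 2))  # stand-in for p_phi(f_theta(.)) before any training
loss = cross_entropy(logits, pseudo_labels)
```

With uniform logits the loss equals $\log 2$, the value of an uninformed two-class classifier; training would push it below that.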

##### Pair-wise
This variant defines pseudo-labels for pairs of nodes rather than for the nodes themselves, so $\Omega: \mathcal{V} \times \mathcal{V} \to \mathcal{C}$

$\Omega$ can be defined based on the distance between nodes, the number of hops, etc.

$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \frac{1}{|\mathcal{P}|}\sum_{(v_i, v_j) \in \mathcal{P}} \mathcal{L}_{ce} (p_{\phi}(f_{\theta}(\mathcal{G}_{(v_i, v_j)})), \Omega(v_i, v_j))
$$

where $\mathcal{P} \subseteq \mathcal{V} \times \mathcal{V}$

Note that this method is still a classification task, so the pseudo-labels are discrete
> My opinion: This method can learn from the structure of the graph, which is a very good learning strategy. However, the computation cost increases quickly when the graph is large, due to the pair-wise generation.
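
Pair-wise pseudo-labels based on the number of hops can be sketched as follows. This is a minimal sketch under assumptions: $\mathcal{P}$ is taken to be all unordered node pairs, hop distances come from a breadth-first search, and distances of `max_hops` or more (or unreachable pairs) share the last class.

```python
import numpy as np
from collections import deque

def hop_distances(A, source):
    """Breadth-first search: hop count from `source` to every node, -1 if unreachable."""
    dist = np.full(A.shape[0], -1, dtype=int)
    dist[source] = 0
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in np.flatnonzero(A[u]):
            if dist[v] < 0:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

def pairwise_hop_labels(A, max_hops):
    """Omega(v_i, v_j) (assumed form): class = hop distance - 1, capped at max_hops - 1."""
    n = A.shape[0]
    pairs, labels = [], []
    for i in range(n):
        dist = hop_distances(A, i)
        for j in range(i + 1, n):
            d = dist[j] if dist[j] > 0 else max_hops  # unreachable pairs fall in the last class
            pairs.append((i, j))
            labels.append(min(d, max_hops) - 1)
    return np.array(pairs), np.array(labels)

n = 5
A = np.zeros((n, n))
for i in range(n - 1):  # toy path graph 0-1-2-3-4
    A[i, i + 1] = A[i + 1, i] = 1.0

pairs, labels = pairwise_hop_labels(A, max_hops=3)
class_counts = np.bincount(labels)  # number of pairs per pseudo-class
```

Enumerating all pairs already gives $n(n-1)/2$ training examples, so in practice $\mathcal{P}$ would likely be a sampled subset.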


#### Regression task
The difference from the classification task is that the pseudo-labels are continuous and the model is optimized with MSE-based loss functions. The pseudo-labels can be generated from node degrees, local importance, etc. Similar to the classification task, the learning context can be either node-centric or pairwise-centric. In the latter, the pseudo-labels can be the similarity between the (raw) attributes of two nodes.

For the node-centric regression, the objective function is
$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \frac{1}{|\mathcal{V}|}\sum_{v_i \in \mathcal{V}} \mathcal{L}_{mse} (p_{\phi}(f_{\theta}(\mathcal{G}_{v_i})), \Omega(v_i))
$$

For the pairwise-centric regression, the objective function is
$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \frac{1}{|\mathcal{P}|}\sum_{(v_i, v_j) \in \mathcal{P}} \mathcal{L}_{mse} (p_{\phi}(f_{\theta}(\mathcal{G}_{(v_i, v_j)})), \Omega(v_i, v_j))
$$
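
Both kinds of regression pseudo-labels can be sketched as follows. This is a minimal NumPy sketch assuming the node-centric target is the raw node degree and the pairwise target is the cosine similarity of raw attributes; the toy graph and the chosen pairs are illustrative.

```python
import numpy as np

def degree_targets(A):
    """Node-centric Omega(v_i): the degree of each node."""
    return A.sum(axis=1)

def cosine_similarity_targets(X, pairs, eps=1e-12):
    """Pairwise Omega(v_i, v_j): cosine similarity between the raw attributes of the two nodes."""
    X_unit = X / (np.linalg.norm(X, axis=1, keepdims=True) + eps)
    return np.sum(X_unit[pairs[:, 0]] * X_unit[pairs[:, 1]], axis=1)

def mse(pred, target):
    return float(np.mean((pred - target) ** 2))

rng = np.random.default_rng(0)
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)  # toy 4-node path graph
X = rng.normal(size=(4, 3))
pairs = np.array([[0, 1], [1, 2], [2, 3]])

node_targets = degree_targets(A)
pair_targets = cosine_similarity_targets(X, pairs)

# A constant prediction (the mean degree) serves as a trivial node-centric baseline
baseline_loss = mse(np.full(4, node_targets.mean()), node_targets)
```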

### Contrastive-based method
$$
\theta^*, \phi^* = \argmin_{\theta, \phi} \mathcal{L}_{ssl} (p_{\phi}(f_{\theta}(\hat{\mathcal{G}}^1), f_{\theta}(\hat{\mathcal{G}}^2)))
$$
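
One possible form of the contrastive loss is sketched below. This is a minimal NumPy sketch that assumes a node-level InfoNCE objective, a common lower-bound estimator for MI; it is not necessarily the loss used by DGI, GraphCL or GCC. `H1` and `H2` stand in for $f_{\theta}(\hat{\mathcal{G}}^1)$ and $f_{\theta}(\hat{\mathcal{G}}^2)$, and the temperature is an illustrative choice.

```python
import numpy as np

def info_nce(H1, H2, temperature=0.5):
    """Row i of H1 and row i of H2 form a positive pair; all other rows of H2 act as negatives."""
    H1_unit = H1 / np.linalg.norm(H1, axis=1, keepdims=True)
    H2_unit = H2 / np.linalg.norm(H2, axis=1, keepdims=True)
    logits = (H1_unit @ H2_unit.T) / temperature
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable log-softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))

rng = np.random.default_rng(0)
H1 = rng.normal(size=(6, 4))                # embeddings of view 1
H2 = H1 + 0.01 * rng.normal(size=(6, 4))    # view 2: slightly perturbed copy of view 1

aligned_loss = info_nce(H1, H2)                        # positives are matched nodes
shuffled_loss = info_nce(H1, np.roll(H2, 1, axis=0))   # positives are mismatched nodes
```

The loss is lower when the two views of the same node agree, which is the behavior the contrastive objective rewards.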


# References
https://arxiv.org/pdf/2103.00111.pdf
