# Learning Graph Embeddings with HMLET (End) Model

Kong et al., in their paper *Linear, or Non-Linear, That is the Question!*, discuss the problem of manually deciding between a linear and a non-linear transformation when learning graph embeddings. To address it, they propose a method that automates this decision. In other words, for each node of the graph (and at each layer), the model learns an additional binary vector of length 2 (i.e. 1️⃣0️⃣ selects the linear transformation for the given node, and 0️⃣1️⃣ selects the non-linear one).

The authors dubbed the model **H**ybrid **M**ethod of **L**inear and nonlin**E**ar collaborative fil**T**ering (HMLET, pronounced as Hamlet).

There are four variants:

![Four variants of HMLET in terms of the location of the non-linear propagation. HMLET(End) shows the best accuracy in experiments. It was known that the problem of over-smoothing happens with more than 2 non-linear propagation layers, and therefore we are using up to 2 non-linear layers.](https://github.com/sparsh-ai/stanza/raw/S693545/images/img0.png)

Four variants of HMLET in terms of the location of the non-linear propagation. HMLET(End) shows the best accuracy in experiments. It was known that the problem of over-smoothing happens with more than 2 non-linear propagation layers, and therefore we are using up to 2 non-linear layers.

Each variant except HMLET(All) uses at most 2 non-linear layers, since more than 2 non-linear layers are known to cause over-smoothing. The authors tested several placements for them. First, HMLET(Front) builds on the fact that GCNs are highly influenced by the close neighborhood, i.e., the first and second layers. It therefore places the gating module in those two layers and uses only linear propagation afterwards. Second, HMLET(Middle) uses only linear propagation in the first and last layers and places the gating module in the second and third layers. Last, HMLET(End) uses only linear propagation in the first two layers and places the gating module in the third and fourth layers.
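The placements described above can be summarized in a small lookup table. This is only an illustrative sketch: the names `GATED_LAYERS` and `layer_plan` are hypothetical, and the entry for HMLET(All) assumes gating in all four layers, which the variant name suggests but the text above does not spell out.

```python
# Hypothetical layout table: which of the four propagation layers carry a gating module.
GATED_LAYERS = {
    "Front": (1, 2),
    "Middle": (2, 3),
    "End": (3, 4),
    "All": (1, 2, 3, 4),  # assumption: gating in every layer
}


def layer_plan(variant, num_layers=4):
    """Return a per-layer list of 'linear' or 'gated' for the given variant."""
    gated = GATED_LAYERS[variant]
    return ["gated" if k in gated else "linear" for k in range(1, num_layers + 1)]


for name in GATED_LAYERS:
    print(f"HMLET({name}):", layer_plan(name))
```

A "gated" layer computes both the linear and the non-linear propagation and lets the gating module pick one per node; a "linear" layer only propagates linearly.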

![The comparison of NDCG@20 with all types of HMLET in three public benchmarks.](https://github.com/sparsh-ai/stanza/raw/S693545/images/img1.png)

The comparison of NDCG@20 with all types of HMLET in three public benchmarks.

Experiments revealed that the HMLET (End) variant gave superior performance on the above three public datasets. So we will focus on this variant for our in-depth analysis.
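Since the comparison above is reported in NDCG@20, here is a minimal sketch of how NDCG@k is commonly computed for top-k recommendation with binary relevance. The function name and the binary-relevance convention are assumptions for illustration; the evaluation code in the notebooks may differ in detail.

```python
import math


def ndcg_at_k(ranked_items, relevant_items, k):
    """NDCG@k with binary relevance: a hit at 0-based rank r contributes 1 / log2(r + 2)."""
    relevant = set(relevant_items)
    dcg = sum(
        1.0 / math.log2(rank + 2)
        for rank, item in enumerate(ranked_items[:k])
        if item in relevant
    )
    # Ideal ranking: all relevant items (at most k of them) placed at the top.
    ideal_hits = min(k, len(relevant))
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


# One relevant item, ranked second in a top-2 list.
print(ndcg_at_k(["x", "a"], ["a"], k=2))
```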

**HMLET (End) Model**

Here is the architecture diagram of HMLET (End) and the algorithm to train this model:

![The detailed workflow of HMLET(End). This model variant bypasses the nonlinearity propagation on the first and second layers to address the over-smoothing problem and then propagates the non-linear embedding in the third and fourth layers. We prepare both the linear and non-linear propagation steps in a layer and let our gating module with STGS (straight-through Gumbel softmax) decide which one to use for each node. For instance, it can select a sequence of linear → linear → linear → non-linear for some nodes while it can select a totally different sequence for other nodes.](https://github.com/sparsh-ai/stanza/raw/S693545/images/img2.png)

The detailed workflow of HMLET(End). This model variant bypasses the nonlinearity propagation on the first and second layers to address the over-smoothing problem and then propagates the non-linear embedding in the third and fourth layers. We prepare both the linear and non-linear propagation steps in a layer and let our gating module with STGS (straight-through Gumbel softmax) decide which one to use for each node. For instance, it can select a sequence of linear → linear → linear → non-linear for some nodes while it can select a totally different sequence for other nodes.

![Untitled](https://github.com/sparsh-ai/stanza/raw/S693545/images/img3.png)

**Linear embedding propagation update equation (equation #6):**

![Untitled](https://github.com/sparsh-ai/stanza/raw/S693545/images/img4.png)

Here, $\dfrac{1}{\sqrt{|\mathcal{N}_u||\mathcal{N}_v|}}$ is a symmetric normalization term to restrict the scale of embeddings into a reasonable boundary.
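To make the normalization term concrete, here is a dependency-free sketch of one linear propagation step on a toy bipartite graph. It assumes the LightGCN-style form in which each node sums its neighbors' embeddings weighted by the term above; the graph, the embedding values, and the name `linear_propagate` are made up for illustration.

```python
import math

# Toy bipartite graph (hypothetical): users link to items and items link back to users.
neighbors = {
    "u1": ["i1", "i2"],
    "u2": ["i2"],
    "i1": ["u1"],
    "i2": ["u1", "u2"],
}


def linear_propagate(emb, neighbors):
    """One linear step: sum neighbor embeddings weighted by 1 / sqrt(|N_u| * |N_v|)."""
    out = {}
    for u, nbrs in neighbors.items():
        acc = [0.0] * len(emb[u])
        for v in nbrs:
            w = 1.0 / math.sqrt(len(neighbors[u]) * len(neighbors[v]))
            for d, value in enumerate(emb[v]):
                acc[d] += w * value
        out[u] = acc
    return out


emb0 = {"u1": [1.0, 0.0], "u2": [0.0, 1.0], "i1": [1.0, 1.0], "i2": [2.0, 0.0]}
emb1 = linear_propagate(emb0, neighbors)
print(emb1["u2"])  # u2 only sees i2, weighted by 1 / sqrt(1 * 2)
```

In the PyTorch implementation the same operation is a single sparse matrix product with the pre-normalized adjacency matrix.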

**Non-linear embedding propagation update equation (equation #7):**

![Untitled](https://github.com/sparsh-ai/stanza/raw/S693545/images/img5.png)

where 𝜙 is a non-linear activation, e.g., ELU, Leaky ReLU, etc.
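For reference, the two activations mentioned here can be written out in a few lines. This is a plain-Python sketch with hypothetical function names; the defaults (𝛼 = 1.0 for ELU, negative slope 0.01 for Leaky ReLU) follow the values used in the experiments of this article.

```python
import math


def elu(x, alpha=1.0):
    """ELU: identity for positive inputs, alpha * (exp(x) - 1) otherwise."""
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


def leaky_relu(x, negative_slope=0.01):
    """Leaky ReLU: identity for non-negative inputs, a small linear slope otherwise."""
    return x if x >= 0 else negative_slope * x


def nonlinear_propagate(aggregated, phi=elu):
    """Apply phi element-wise to an embedding that has already been aggregated over neighbors."""
    return [phi(x) for x in aggregated]


print(nonlinear_propagate([1.5, -1.0]))
print(nonlinear_propagate([1.5, -1.0], phi=leaky_relu))
```

The only difference from linear propagation is the element-wise 𝜙 applied after the normalized neighborhood aggregation.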

**BPR Loss (equation #9):**

![Untitled](https://github.com/sparsh-ai/stanza/raw/S693545/images/img6.png)

where 𝜎 is the sigmoid function, $Θ$ denotes the initial embeddings and the parameters of the gating modules, and 𝜆 controls the $𝐿_2$ regularization strength. Each observed user-item interaction serves as a positive instance and is paired with negative instances drawn by a negative-sampling strategy.
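A small numeric sketch may help connect this loss to code. It relies on the identity $-\ln \sigma(x) = \mathrm{softplus}(-x)$, so the pairwise term can be computed as a softplus of the score difference. The helper names are hypothetical, and the batch averaging is an assumption that mirrors a common implementation choice rather than the exact form of the equation.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bpr_loss(pos_scores, neg_scores):
    """Mean over (positive, negative) pairs of -ln(sigmoid(pos - neg)) = softplus(neg - pos)."""
    pairs = list(zip(pos_scores, neg_scores))
    return sum(softplus(n - p) for p, n in pairs) / len(pairs)


def l2_reg(embeddings, batch_size):
    """Half the squared L2 norm of the given embedding vectors, averaged over the batch."""
    return 0.5 * sum(x * x for vec in embeddings for x in vec) / batch_size


lam = 1e-4  # regularization strength (lambda)
loss = bpr_loss([2.0, 0.5], [0.1, 0.7]) + lam * l2_reg([[1.0, 2.0], [0.5, 0.5]], batch_size=2)
print(loss)
```

The loss shrinks as positive items are scored higher than their sampled negatives.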

Here is the PyTorch implementation of the HMLET (End) model (for clarity, only the core functions are shown):

```python
import torch
import torch.nn as nn


class HMLET_End(nn.Module):

	def __choosing_one(self, features, gumbel_out):
		feature = torch.sum(torch.mul(features, gumbel_out), dim=1)  # batch x embedding_dim (or batch x embedding_dim x layer_num)
		return feature

	def computer(self, gum_temp, hard):     
		
		self.Graph = self.g_train   
		if self.dropout:
			if self.training:
				g_droped = self.__dropout(self.keep_prob)
			else:
				g_droped = self.Graph        
		else:
			g_droped = self.Graph
    
		# Init users & items embeddings  
		users_emb = self.embedding_user.weight
		items_emb = self.embedding_item.weight
      
		## Layer 0
		all_emb_0 = torch.cat([users_emb, items_emb])
		
		# Residual embeddings
		embs = [all_emb_0]
	
		## Layer 1
		all_emb_lin_1 = torch.sparse.mm(g_droped, all_emb_0)
		
		# Residual embeddings	
		embs.append(all_emb_lin_1)
   
		## layer 2
		all_emb_lin_2 = torch.sparse.mm(g_droped, all_emb_lin_1)
		
		# Residual embeddings
		embs.append(all_emb_lin_2)
		
		## layer 3
		all_emb_lin_3 = torch.sparse.mm(g_droped, all_emb_lin_2)
		# The non-linear path bypasses layers 1 and 2 and starts from the layer-0 embeddings
		all_emb_non_1 = self.activation_function(torch.sparse.mm(g_droped, all_emb_0))
		
		# Gating
		stack_embedding_1 = torch.stack([all_emb_lin_3, all_emb_non_1],dim=1)
		concat_embeddings_1 = torch.cat((all_emb_lin_3, all_emb_non_1),-1)

		gumbel_out_1, lin_count_3, non_count_3 = self.gating_network_list[0](concat_embeddings_1, gum_temp, hard, self.config['division_noise'])
		embedding_1 = self.__choosing_one(stack_embedding_1, gumbel_out_1)

		# Residual embeddings
		embs.append(embedding_1)
  	
		# layer 4
		all_emb_lin_4 = torch.sparse.mm(g_droped, embedding_1)
		all_emb_non_2 = self.activation_function(torch.sparse.mm(g_droped, embedding_1))
		
		# Gating
		stack_embedding_2 = torch.stack([all_emb_lin_4, all_emb_non_2],dim=1)
		concat_embeddings_2 = torch.cat((all_emb_lin_4, all_emb_non_2),-1)

		gumbel_out_2, lin_count_4, non_count_4 = self.gating_network_list[1](concat_embeddings_2, gum_temp, hard, self.config['division_noise'])
		embedding_2 = self.__choosing_one(stack_embedding_2, gumbel_out_2)

		# Residual embeddings  		
		embs.append(embedding_2)

		## Stack & mean residual embeddings
		embs = torch.stack(embs, dim=1)
		light_out = torch.mean(embs, dim=1)
   
		users, items = torch.split(light_out, [self.num_users, self.num_items])
		
		return users, items, [lin_count_3, non_count_3, lin_count_4, non_count_4], embs

	def getUsersRating(self, users, gum_temp, hard):
		all_users, all_items, gating_dist, embs = self.computer(gum_temp, hard)
		
		users_emb = all_users[users.long()]
		items_emb = all_items

		rating = self.activation_function(torch.matmul(users_emb, items_emb.t()))

		return rating, gating_dist, embs

	def getEmbedding(self, users, pos_items, neg_items, gum_temp, hard):
		all_users, all_items, gating_dist, embs = self.computer(gum_temp, hard)
		
		users_emb = all_users[users]
		pos_emb = all_items[pos_items]
		neg_emb = all_items[neg_items]

		users_emb_ego = self.embedding_user(users)
		pos_emb_ego = self.embedding_item(pos_items)
		neg_emb_ego = self.embedding_item(neg_items)

		return users_emb, pos_emb, neg_emb, users_emb_ego, pos_emb_ego, neg_emb_ego, gating_dist, embs

	def bpr_loss(self, users, pos, neg, gum_temp, hard):
		(users_emb, pos_emb, neg_emb, 
		userEmb0,  posEmb0, negEmb0, gating_dist, embs) = self.getEmbedding(users.long(), pos.long(), neg.long(), gum_temp, hard)
		
		reg_loss = (1/2)*(userEmb0.norm(2).pow(2) + 
							posEmb0.norm(2).pow(2)  +
							negEmb0.norm(2).pow(2))/float(len(users))
		
		pos_scores = torch.mul(users_emb, pos_emb)
		pos_scores = torch.sum(pos_scores, dim=1)
		neg_scores = torch.mul(users_emb, neg_emb)
		neg_scores = torch.sum(neg_scores, dim=1)
		
		loss = torch.mean(torch.nn.functional.softplus(neg_scores - pos_scores))
		
		return loss, reg_loss, gating_dist, embs
		
	def forward(self, users, items, gum_temp, hard):
		# compute embedding
		all_users, all_items, gating_dist, embs = self.computer(gum_temp, hard)

		users_emb = all_users[users]
		items_emb = all_items[items]

		inner_pro = torch.mul(users_emb, items_emb)
		gamma     = torch.sum(inner_pro, dim=1)

		return gamma, gating_dist, embs
```
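The final steps of `computer` and `forward` above reduce to two simple operations: averaging the per-layer embeddings of a node, and scoring a user-item pair by an inner product. The following plain-Python sketch shows them on made-up two-dimensional vectors; the helper names `mean_readout` and `score` are hypothetical.

```python
def mean_readout(layer_embs):
    """Average a node's embeddings over all layers (layer 0 through the last layer)."""
    n = len(layer_embs)
    return [sum(values) / n for values in zip(*layer_embs)]


def score(user_emb, item_emb):
    """Predicted preference: inner product of the final user and item embeddings."""
    return sum(a * b for a, b in zip(user_emb, item_emb))


# Five entries, standing in for layer 0 plus four propagation layers.
user_layers = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.5, 0.5], [0.5, 0.5]]
item_layers = [[2.0, 2.0]] * 5
user_final = mean_readout(user_layers)
item_final = mean_readout(item_layers)
print(user_final, score(user_final, item_final))
```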

**Gating Module**

The gating module is the core component of the model. Here is the algorithm that computes the gating decision vectors (i.e. 1️⃣0️⃣ selects the linear transformation for the given node, and 0️⃣1️⃣ selects the non-linear one):

![The intuition behind this technique is that i) the embeddings of nodes may exhibit both the linearity and the non-linearity in their characteristics and ii) the linearity and the non-linearity of nodes may vary from one layer to another.](https://github.com/sparsh-ai/stanza/raw/S693545/images/img7.png)

The intuition behind this technique is that i) the embeddings of nodes may exhibit both the linearity and the non-linearity in their characteristics and ii) the linearity and the non-linearity of nodes may vary from one layer to another.

If 𝜉 is the first or second type, the designated embedding type is selected directly. If 𝜉 is the third type, the input embeddings, i.e., the linear and non-linear embeddings, are concatenated and passed to an MLP (multi-layer perceptron) (Lines 7 and 8 in Algorithm 1). The output of the MLP is a logit vector 𝒍, which is the input to STGS (Line 9). The logit vector 𝒍 corresponds to log 𝝅 in the Gumbel-softmax formulation. 𝒈 represents the linear or non-linear selection made by the gating module, i.e., 𝒈 is a two-dimensional one-hot vector. Therefore, the gated output embedding 𝒆𝐺 equals either the linear embedding 𝒆𝐿 or the non-linear embedding 𝒆𝑁 (Line 10).
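The selection step can be illustrated without any deep-learning framework. The sketch below perturbs two logits with Gumbel noise, applies a temperature-scaled softmax, and discretizes the result into a one-hot 𝒈. It is a forward-pass illustration only: plain Python has no autograd, so the straight-through gradient trick is not shown, and the function names and toy values are assumptions.

```python
import math
import random


def sample_gumbel(rng, eps=1e-20):
    """Draw one Gumbel(0, 1) sample from a uniform draw u: -log(-log(u))."""
    u = rng.random()
    return -math.log(-math.log(u + eps) + eps)


def gumbel_softmax(logits, temperature, rng, hard=True):
    """Gumbel-softmax sample over the logits; one-hot if hard=True."""
    perturbed = [(l + sample_gumbel(rng)) / temperature for l in logits]
    top = max(perturbed)
    exps = [math.exp(p - top) for p in perturbed]  # subtract the max for stability
    total = sum(exps)
    soft = [e / total for e in exps]
    if not hard:
        return soft
    winner = soft.index(max(soft))
    return [1.0 if i == winner else 0.0 for i in range(len(soft))]


def gate(e_lin, e_non, logits, temperature, rng):
    """Return the linear or the non-linear embedding according to the one-hot g."""
    g = gumbel_softmax(logits, temperature, rng, hard=True)
    return [g[0] * a + g[1] * b for a, b in zip(e_lin, e_non)]


rng = random.Random(0)
print(gate([1.0, 2.0], [3.0, 4.0], logits=[0.0, 0.0], temperature=0.7, rng=rng))
```

With equal logits the choice is random; as the MLP learns to favor one logit for a node, that node's choice becomes increasingly stable.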

[Here](https://colab.research.google.com/gist/sparsh-ai/a46167213b2b308953b2654381c8725b/t611269-learning-graph-embeddings-of-gowalla-dataset-using-hmlet-model.ipynb#scrollTo=XvBM9e8sx8Wl&line=38&uniqifier=1) is the PyTorch implementation of this gating module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Gating_Net(nn.Module):

    def __init__(self, embedding_dim, mlp_dims):
        super(Gating_Net, self).__init__()
        self.embedding_dim = embedding_dim
        self.softmax =  nn.LogSoftmax(dim=1)
        fc_layers = []
        for i in range(len(mlp_dims)):
            if i == 0:
                fc_layers.append(nn.Linear(embedding_dim*2, mlp_dims[i]))
            else:
                fc_layers.append(nn.Linear(mlp_dims[i-1], mlp_dims[i]))	
            if i != len(mlp_dims) - 1:
                fc_layers.append(nn.BatchNorm1d(mlp_dims[i]))
                fc_layers.append(nn.ReLU(inplace=True))
        self.mlp = nn.Sequential(*fc_layers)

    def gumbel_softmax(self, logits, temperature, division_noise, hard):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
          logits: [batch_size, n_class] unnormalized log-probs
          temperature: positive scalar
          division_noise: divisor applied to the Gumbel noise before it is added to the logits
          hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
          [batch_size, n_class] sample from the Gumbel-Softmax distribution.
          If hard=True, then the returned sample will be one-hot, otherwise it will
          be a probability distribution that sums to 1 across classes
        """
        y = self.gumbel_softmax_sample(logits, temperature, division_noise)  # soft sample, e.g. (0.6, 0.4)
        if hard:
            # One-hot of the row-wise maximum, e.g. (1, 0)
            y_hard = torch.eq(y, torch.max(y, dim=1, keepdim=True)[0]).type_as(y)
            # Straight-through estimator: the forward pass uses y_hard, gradients flow through y
            y = (y_hard - y).detach() + y
        return y

    def gumbel_softmax_sample(self, logits, temperature, division_noise):
        """Draw a sample from the Gumbel-Softmax distribution."""
        noise = self.sample_gumbel(logits)
        y = (logits + (noise / division_noise)) / temperature
        return F.softmax(y, dim=-1)  # normalize over the class dimension

    def sample_gumbel(self, logits):
        """Sample Gumbel(0, 1) noise with the same shape and device as logits."""
        eps = 1e-20
        noise = torch.rand(logits.size(), device=logits.device)
        # g = -log(-log(u + eps) + eps), with u ~ Uniform(0, 1)
        noise.add_(eps).log_().neg_()
        noise.add_(eps).log_().neg_()
        return noise.float()

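    def forward(self, feature, temperature, hard, division_noise):
        # feature: [num_nodes, 2 * embedding_dim], the concatenated linear and non-linear embeddings
        x = self.mlp(feature)  # logits: [num_nodes, 2]
        out = self.gumbel_softmax(x, temperature, division_noise, hard)
        out_value = out.unsqueeze(2)  # [num_nodes, 2, 1]
        out = out_value.repeat(1, 1, self.embedding_dim)  # [num_nodes, 2, embedding_dim]

        # Sums of the gate values for the linear (index 0) and non-linear (index 1) branches;
        # with hard=True these are the numbers of nodes that chose each branch
        return out, torch.sum(out_value[:,0]), torch.sum(out_value[:,1])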
```

**Experiment on Gowalla dataset**

We filtered out users and items with fewer than ten interactions in all datasets, i.e., a 10-core setting. We then split each dataset into training (80%), validation (10%), and test (10%) sets, and chose the best hyperparameter set via grid search on the validation set. The best setting found for HMLET is as follows: the number of linear layers is set to 4; the number of non-linear layers is set to 2 for HMLET(End); the optimizer is Adam; the learning rate is 0.001; the $𝐿_2$ regularization coefficient 𝜆 is 1E-4; the mini-batch size is 2,048; the dropout rate is 0.4. We use a temperature 𝜏 initialized to 0.7, with a minimum temperature of 0.01 and a decay factor of 0.995. Also, for a fair comparison, we set the embedding size for all methods to 512. In the non-linear layers, we test two non-linear activation functions: Leaky-ReLU (negative slope = 0.01) and ELU (𝛼 = 1.0).
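As an illustration of the 10-core setting, the following sketch repeatedly removes users and items with fewer than `k` interactions until nothing changes. The function name is hypothetical, and iterating to a fixed point is an assumption: a single filtering pass is also a common reading of "k-core" in recommender preprocessing.

```python
from collections import Counter


def k_core(interactions, k=10):
    """Keep only (user, item) pairs whose user and item both have at least k interactions."""
    pairs = set(interactions)
    while True:
        user_deg = Counter(u for u, _ in pairs)
        item_deg = Counter(i for _, i in pairs)
        kept = {(u, i) for u, i in pairs if user_deg[u] >= k and item_deg[i] >= k}
        if len(kept) == len(pairs):  # fixed point reached
            return kept
        pairs = kept


toy = [("u1", "i1"), ("u1", "i2"), ("u2", "i1"), ("u2", "i2"), ("u3", "i1")]
print(sorted(k_core(toy, k=2)))  # u3 has a single interaction and is dropped
```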

Here are the parameters:

```text
==============================
Model: HMLET_End
Model config: {'embedding_dim': 512, 'activation_function': 'elu', 'dropout': 1, 'keep_prob': 0.6, 'a_split': 0, 'gating_mlp_dims': [128, 2]}
Dataset: gowalla
EPOCHS: 4
Pretrain: False
BPR batch size: 2048
Test batch size: 100
Test topks: [10, 20, 30, 40, 50]
N fold: 100
Tensorboard: True
==============================
==============================
DATA PATH: /content/data/gowalla
SAVE FILE PATH: /content/checkpoints/HMLET_End/gowalla
LOAD FILE PATH: /content/checkpoints/HMLET_End/gowalla/
EXCEL PATH: /content/excel
BOARD PATH: /content/tensorboard
==============================
==============================
Cuda: True
CUDA device: 0
==============================
==============================
Multicore: 1
CORES: 1
==============================
```

We ran it for 4 epochs. Here are the results:

```text
Loading /content/data/gowalla
==============================
810128 interactions for training
108621 interactions for testing
<__main__.Loader object at 0x7fb5d94e2510> Sparsity : 0.0008396216228570436
==============================
<__main__.Loader object at 0x7fb5d94e2510> is ready to go
activation_function: ELU(alpha=1.0)
loading adjacency matrix
successfully train loaded...
don't split the matrix

Train 1 ==============================
gum_temp: 0.7
EPOCH[1/4] loss0.386
train time: 534.3337316513062
decay gum_temp: 0.6965087354348776
model save...
Valid ==================================================
valid mode
{'precision': array([0.02332116, 0.01668453, 0.0136216 , 0.01180293, 0.01056637]), 'recall': array([0.07605411, 0.10674152, 0.13055237, 0.1500007 , 0.16690163]), 'ndcg': array([0.06059464, 0.07001512, 0.07663515, 0.08171554, 0.08589203])}
Test ==================================================
test mode
{'precision': array([0.02286154, 0.01636747, 0.01323375, 0.01149692, 0.01028535]), 'recall': array([0.0733682 , 0.10507128, 0.12697428, 0.14644623, 0.16393784]), 'ndcg': array([0.05854139, 0.06809456, 0.07418164, 0.07923892, 0.08345424])}

Train 2 ==============================
gum_temp: 0.6965087354348776
EPOCH[2/4] loss0.125
train time: 533.7940173149109
decay gum_temp: 0.6930348836244177
model save...
Valid ==================================================
valid mode
{'precision': array([0.02645946, 0.0188398 , 0.01532527, 0.01327243, 0.01181431]), 'recall': array([0.08573863, 0.11987728, 0.14631547, 0.1682178 , 0.18665247]), 'ndcg': array([0.06959012, 0.08001479, 0.0873653 , 0.09306554, 0.09759169])}
Test ==================================================
test mode
{'precision': array([0.02595954, 0.01831335, 0.01496528, 0.01297558, 0.01155402]), 'recall': array([0.08334737, 0.11692721, 0.14290043, 0.16482907, 0.18364577]), 'ndcg': array([0.06762457, 0.07773272, 0.08496656, 0.09062759, 0.09517945])}

Train 3 ==============================
gum_temp: 0.6930348836244177
EPOCH[3/4] loss0.090
train time: 534.7814464569092
decay gum_temp: 0.6895783577221438
model save...
Valid ==================================================
valid mode
{'precision': array([0.02750109, 0.01956158, 0.01601076, 0.01382172, 0.01241183]), 'recall': array([0.08927831, 0.12504565, 0.1532377 , 0.1750567 , 0.19631655]), 'ndcg': array([0.07269772, 0.08360102, 0.09145334, 0.09717201, 0.10236788])}
Test ==================================================
test mode
{'precision': array([0.02699444, 0.01916404, 0.01566191, 0.01349973, 0.01205774]), 'recall': array([0.08641382, 0.12267856, 0.15009758, 0.17216471, 0.19164879]), 'ndcg': array([0.07056823, 0.08149902, 0.08911955, 0.09480704, 0.09956793])}

Train 4 ==============================
gum_temp: 0.6895783577221438
EPOCH[4/4] loss0.077
train time: 535.70445728302
decay gum_temp: 0.6861390713147286
model save...
Valid ==================================================
valid mode
{'precision': array([0.02835181, 0.02017617, 0.01645622, 0.01423201, 0.01270456]), 'recall': array([0.09210259, 0.12944523, 0.1578907 , 0.1813253 , 0.20186935]), 'ndcg': array([0.07493502, 0.08629033, 0.09422578, 0.10031303, 0.10533204])}
Test ==================================================
test mode
{'precision': array([0.02766093, 0.01966642, 0.01600688, 0.01386563, 0.01239936]), 'recall': array([0.08931023, 0.12607232, 0.15384466, 0.1774857 , 0.19812262]), 'ndcg': array([0.07260988, 0.08374147, 0.09144997, 0.09752696, 0.1025441 ])}
```
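The `gum_temp` values in the log above are consistent with an exponential decay of the form 𝜏 = 0.7 · exp(−0.005 · t), floored at the minimum temperature. The sketch below reproduces them. The decay rate 0.005 is inferred from the logged values (exp(−0.005) ≈ 0.995, matching the stated decay factor), and the function name is hypothetical; the actual training script may apply the decay as a per-epoch multiplication instead.

```python
import math


def gumbel_temperature(epoch, init_temp=0.7, decay=0.005, min_temp=0.01):
    """Temperature used in the given 1-based epoch: exponential decay with a lower bound."""
    return max(min_temp, init_temp * math.exp(-decay * (epoch - 1)))


for epoch in range(1, 5):
    print(epoch, gumbel_temperature(epoch))
```

A lower temperature makes the soft Gumbel-softmax sample closer to a one-hot vector, so the gating decisions sharpen as training progresses.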

We saved the trained model, outputs, and other artifacts like Tensorboard metrics. Here is the folder structure and size of these artifacts:

```text
.
├── [ 12K]  checkpoints
│   └── [8.0K]  HMLET_End
│       └── [4.0K]  gowalla
├── [ 13M]  data
│   └── [ 13M]  gowalla
│       ├── [6.9M]  s_pre_adj_mat_train.npz
│       ├── [753K]  test.txt
│       ├── [4.4M]  train.txt
│       └── [752K]  val.txt
├── [ 15K]  excel
│   ├── [5.4K]  test_HMLET_End_512_0.7_0.005_0.01_1_3_1_0.6_[10, 20, 30, 40, 50].xlsx
│   └── [5.4K]  valid_HMLET_End_512_0.7_0.005_0.01_1_3_1_0.6_[10, 20, 30, 40, 50].xlsx
└── [218K]  tensorboard
    ├── [ 30K]  11-18-09h39m04s-
    │   └── [ 26K]  events.out.tfevents.1637228346.9f9d02c6c3c0.161.0
    ├── [ 43K]  11-18-09h56m39s-
    │   └── [ 39K]  events.out.tfevents.1637229401.9f9d02c6c3c0.432.0
    └── [141K]  11-18-10h21m54s-
        ├── [ 77K]  events.out.tfevents.1637230914.9f9d02c6c3c0.432.1
        ├── [4.0K]  Test_NDCG@[10, 20, 30, 40, 50]_10
        ├── [4.0K]  Test_NDCG@[10, 20, 30, 40, 50]_20
        ├── [4.0K]  Test_NDCG@[10, 20, 30, 40, 50]_30
        ├── [4.0K]  Test_NDCG@[10, 20, 30, 40, 50]_40
        ├── [4.0K]  Test_NDCG@[10, 20, 30, 40, 50]_50
        ├── [4.0K]  Test_Precision@[10, 20, 30, 40, 50]_10
        ├── [4.0K]  Test_Precision@[10, 20, 30, 40, 50]_20
        ├── [4.0K]  Test_Precision@[10, 20, 30, 40, 50]_30
        ├── [4.0K]  Test_Precision@[10, 20, 30, 40, 50]_40
        ├── [4.0K]  Test_Precision@[10, 20, 30, 40, 50]_50
        ├── [4.0K]  Test_Recall@[10, 20, 30, 40, 50]_10
        ├── [4.0K]  Test_Recall@[10, 20, 30, 40, 50]_20
        ├── [4.0K]  Test_Recall@[10, 20, 30, 40, 50]_30
        ├── [4.0K]  Test_Recall@[10, 20, 30, 40, 50]_40
        └── [4.0K]  Test_Recall@[10, 20, 30, 40, 50]_50

  13M used in 25 directories, 9 files
```

Moreover, we repeated the experimental analysis on two other datasets: Amazon-books and the Yelp 2018 business dataset. The procedure is exactly the same. Here are the Jupyter notebooks for all three experiments:

1. [Learning Graph Embeddings with HMLET model on Yelp dataset](https://nbviewer.org/gist/sparsh-ai/da0a1e723f8675ae59a7273bf78a49a6)
2. [Learning Graph Embeddings with HMLET model on Gowalla dataset](https://nbviewer.org/gist/sparsh-ai/a46167213b2b308953b2654381c8725b)
3. [Learning Graph Embeddings with HMLET model on Amazon-books dataset](https://nbviewer.org/gist/sparsh-ai/0ac3a7e2bb7c507a417ed6953a8b0ff1)