### Format Knowledge-Graph Embeddings for Hopwise `dataset.get_preload_weight()` function
This notebook shows how to format the embeddings learned by a knowledge-graph embedding (KGE) method so that they can be loaded with `dataset.get_preload_weight()`.


📚 [Load Pretrained Embedding Documentation](https://recbole.io/docs/user_guide/usage/load_pretrained_embedding.html)

**Load Libraries**

In [None]:
import os

import pandas as pd
import torch

from hopwise.data import create_dataset

%cd ..

/home/gmedda/projects/hopwise




### Load Checkpoint

In [48]:
checkpoint_name = "saved/TransE-Apr-09-2025_21-23-46.pth"

In [49]:
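# The checkpoint also stores the config object, so full unpickling is required
# (PyTorch >= 2.6 defaults to weights_only=True). Only load checkpoints you trust.
checkpoint = torch.load(checkpoint_name, map_location="cpu", weights_only=False)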

**The detected embeddings are:**

In [50]:
checkpoint["state_dict"].keys()

odict_keys(['user_embedding.weight', 'entity_embedding.weight', 'relation_embedding.weight'])

**Do you want to exclude some embeddings?**

In [51]:
excluded = ["relation_bias_embedding.weight"]
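
The formatting loop further down routes each parameter to an output file by the prefix before its first underscore, so a parameter such as `relation_bias_embedding.weight` (if a model has one) would also resolve to `relation` and be written to the same `.relationemb` file as the relation embeddings. A minimal sketch of that routing, using a hypothetical helper `output_files` that mirrors the notebook's `emb_name.split("_")[0]` and flags files targeted more than once:

```python
from collections import defaultdict


def output_files(param_names, dataset_name, excluded=()):
    """Map each state_dict key to the file the formatting loop would write.

    Hypothetical helper mirroring `emb_name.split("_")[0]` from the notebook.
    Returns (routing, collisions); collisions lists files targeted twice.
    """
    routing = {}
    targets = defaultdict(list)
    for name in param_names:
        if name in excluded:
            continue
        emb_type = name.split("_")[0]  # "relation_bias_embedding.weight" -> "relation"
        filename = f"{dataset_name}.{emb_type}emb"
        routing[name] = filename
        targets[filename].append(name)
    collisions = {f: names for f, names in targets.items() if len(names) > 1}
    return routing, collisions


names = ["entity_embedding.weight", "relation_embedding.weight", "relation_bias_embedding.weight"]
_, collisions = output_files(names, "test")
print(collisions)  # {'test.relationemb': ['relation_embedding.weight', 'relation_bias_embedding.weight']}
_, collisions = output_files(names, "test", excluded=["relation_bias_embedding.weight"])
print(collisions)  # {}
```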

**The detected dataset is:**

In [52]:
dataset_name = checkpoint["config"]["dataset"]
dataset_name

'test'

**The detected dataset folder is:**

In [53]:
data_path = checkpoint["config"]["data_path"]
data_path

'tests/test_data/test'

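**Create the mappings between the embeddings and the original entities/relations/users**

- Users map 1-1 onto the `user_id` field, so they need no extra KG mapping; as for entities and relations, we only build the reverse id → token lookup.

- We assume that indexing starts at 1 (index 0 is typically reserved for `[PAD]`).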
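
The indexing assumption above can be checked before exporting anything. Below is a minimal sketch with a hypothetical helper `check_token_ids`, which takes a token → internal-id dict shaped like `dataset.field2token_id["tail_id"]`, verifies that `[PAD]` sits at index 0 and that the ids are contiguous (so embedding row `i` belongs to internal id `i`), and returns the reverse mapping:

```python
def check_token_ids(token2id, pad_token="[PAD]"):
    """Validate a token -> internal-id mapping and return its reverse.

    Hypothetical helper: checks that `pad_token` has id 0 and that the ids
    are exactly 0..n-1, so row i of an embedding matrix matches id i.
    """
    if token2id.get(pad_token) != 0:
        raise ValueError(f"expected {pad_token!r} at index 0")
    if sorted(token2id.values()) != list(range(len(token2id))):
        raise ValueError("internal ids are not contiguous from 0")
    # Reverse mapping: internal id -> original token
    return {idx: token for token, idx in token2id.items()}


# Toy mapping with the same shape as field2token_id["tail_id"]
id2token = check_token_ids({"[PAD]": 0, "91": 1, "142": 2, "201": 3})
print(id2token[1])  # 91
```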

In [54]:
dataset = create_dataset(checkpoint["config"])

In [55]:
dataset.field2token_id["tail_id"]

{'[PAD]': 0,
 '91': 1,
 '142': 2,
 '201': 3,
 '202': 4,
 '203': 5,
 '204': 6,
 '205': 7,
 '206': 8,
 '207': 9,
 '208': 10,
 '209': 11,
 '210': 12,
 '211': 13,
 '212': 14,
 '213': 15,
 '214': 16,
 '215': 17,
 '216': 18,
 '217': 19,
 '218': 20,
 '219': 21,
 '220': 22,
 '221': 23,
 '222': 24,
 '223': 25,
 '224': 26,
 '225': 27,
 '226': 28,
 '227': 29,
 '228': 30,
 '229': 31,
 '230': 32,
 '231': 33,
 '232': 34,
 '233': 35,
 '234': 36,
 '235': 37,
 '236': 38,
 '237': 39,
 '238': 40,
 '239': 41,
 '240': 42,
 '241': 43,
 '242': 44,
 '243': 45,
 '244': 46,
 '245': 47,
 '246': 48,
 '247': 49,
 '248': 50,
 '249': 51,
 '250': 52,
 '251': 53,
 '252': 54,
 '253': 55,
 '254': 56,
 '255': 57,
 '256': 58,
 '257': 59,
 '258': 60,
 '259': 61,
 '260': 62,
 '261': 63,
 '262': 64,
 '263': 65,
 '264': 66,
 '265': 67,
 '266': 68,
 '267': 69,
 '268': 70,
 '269': 71,
 '270': 72,
 '271': 73,
 '272': 74,
 '273': 75,
 '274': 76,
 '275': 77,
 '276': 78,
 '277': 79,
 '278': 80,
 '279': 81,
 '280': 82,
 '281': 83,
 

In [56]:
# Create the reverse mappings (internal id -> original token)
uid2token = {idx: token for token, idx in dataset.field2token_id["user_id"].items()}
print(uid2token)
# Entities: head_id and tail_id share the same entity remapping
eid2token = {idx: token for token, idx in dataset.field2token_id["tail_id"].items()}
print(eid2token)
rid2token = {idx: token for token, idx in dataset.field2token_id["relation_id"].items()}
print(rid2token)

{0: '[PAD]', 1: '6', 2: '38', 3: '97', 4: '7', 5: '10', 6: '99', 7: '25', 8: '59', 9: '115', 10: '138', 11: '194', 12: '11', 13: '162', 14: '135', 15: '160', 16: '42', 17: '168', 18: '58', 19: '62', 20: '44', 21: '72', 22: '82', 23: '43', 24: '90', 25: '68', 26: '172', 27: '19', 28: '5', 29: '13', 30: '1', 31: '92', 32: '151', 33: '54', 34: '14', 35: '193', 36: '158', 37: '181', 38: '16', 39: '95', 40: '145', 41: '187', 42: '184', 43: '18', 44: '144', 45: '200', 46: '142', 47: '87', 48: '197', 49: '104', 50: '83', 51: '125', 52: '23', 53: '128', 54: '60', 55: '65', 56: '137', 57: '96', 58: '117', 59: '94', 60: '130', 61: '45', 62: '131', 63: '109', 64: '198', 65: '157', 66: '56', 67: '118', 68: '189', 69: '185', 70: '22', 71: '8', 72: '15', 73: '102', 74: '77', 75: '85', 76: '108', 77: '188', 78: '161', 79: '21', 80: '113', 81: '79', 82: '150', 83: '24', 84: '17', 85: '148', 86: '110', 87: '84', 88: '26', 89: '2', 90: '57', 91: '121', 92: '186', 93: '175', 94: '153', 95: '63', 96: '66'

In [57]:
# Only if the KGE model appended a dummy relation (check the KGE code), uncomment:
# rid2token[len(rid2token)] = 'ui_dummy_relation'

In [58]:
assert len(eid2token.keys()) == checkpoint["state_dict"]["entity_embedding.weight"].shape[0]
assert len(rid2token.keys()) == checkpoint["state_dict"]["relation_embedding.weight"].shape[0]

*If an assertion fails, make sure that the KGE model was trained without explicitly adding dummy relations/entities when creating the relation/entity embeddings.*
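
When an assertion fails, the size of the gap usually hints at the cause. A minimal sketch with a hypothetical helper `diagnose_rows`: one extra row is consistent with the dummy relation that the commented-out cell above would add (`ui_dummy_relation`), but other causes (e.g. a checkpoint trained on a different dataset version) are possible.

```python
def diagnose_rows(field, n_tokens, n_rows):
    """Compare a reverse mapping's size with an embedding's row count.

    Hypothetical helper: `n_tokens` is len(mapping), including [PAD], and
    `n_rows` is weight.shape[0]. Returns a short human-readable verdict.
    """
    if n_rows == n_tokens:
        return f"{field}: OK ({n_rows} rows)"
    if n_rows == n_tokens + 1:
        return f"{field}: one extra row, possibly a dummy {field} appended during KGE training"
    return f"{field}: {n_rows} rows vs {n_tokens} tokens, check the dataset/config used for training"


print(diagnose_rows("relation", 10, 11))
```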

### Create the new embeddings

In [59]:
def format_embedding(weight, columns, emb_type):
    """Write `weight` as the tab-separated file `<dataset_name>.<emb_type>emb`."""
    weight = weight.detach().cpu().numpy()

    # Reverse mapping (internal id -> original token) for this embedding type;
    # unknown types fall back to None, i.e. the internal ids are written as-is
    mappings = {"entity": eid2token, "relation": rid2token, "user": uid2token}
    mapping = mappings.get(emb_type)

    new_emb_dict = {
        # Index column: original tokens, skipping row 0 ([PAD])
        columns[0]: [mapping[idx] if mapping is not None else idx for idx in range(1, weight.shape[0])],
        # Embedding column: each row serialized as a space-separated float sequence
        columns[1]: [" ".join(f"{x}" for x in row) for row in weight[1:]],
    }

    filename = f"{dataset_name}.{emb_type}emb"
    df = pd.DataFrame(new_emb_dict)
    print(f"[+] Saving the new {dataset_name} {columns[0]} embedding in {data_path}/{filename}!")
    df.to_csv(os.path.join(data_path, filename), sep="\t", index=False)
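
The file written by `format_embedding` follows the RecBole-style atomic-file layout: a tab-separated table whose header cells are `name:type`, with each `float_seq` value stored as space-separated numbers (matching RecBole's default `field_separator` and `seq_separator`). The stdlib-only sketch below reproduces that layout and parses it back, which can help sanity-check a generated file; `write_emb_file` and `read_emb_file` are hypothetical helpers, not part of Hopwise.

```python
import csv
import io


def write_emb_file(fh, columns, tokens, rows):
    """Write tokens and their embedding rows as a tab-separated atomic file."""
    writer = csv.writer(fh, delimiter="\t", lineterminator="\n")
    writer.writerow(columns)  # e.g. ["entity_embedding_id:token", "entity_embedding:float_seq"]
    for token, row in zip(tokens, rows):
        writer.writerow([token, " ".join(f"{x}" for x in row)])


def read_emb_file(fh):
    """Parse the file back into (header, {token: [floats]})."""
    reader = csv.reader(fh, delimiter="\t")
    header = next(reader)
    return header, {token: [float(x) for x in seq.split(" ")] for token, seq in reader}


buf = io.StringIO()
# Row 0 ([PAD]) is already skipped, as in format_embedding
write_emb_file(
    buf,
    ["entity_embedding_id:token", "entity_embedding:float_seq"],
    ["91", "142"],
    [[0.5, -1.0], [0.25, 2.0]],
)
buf.seek(0)
header, table = read_emb_file(buf)
```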

In [60]:
for emb_name, emb in checkpoint["state_dict"].items():
    if emb_name in excluded:
        continue
    # Embedding type from the name prefix: user, entity, relation, item, ...
    emb_type = emb_name.split("_")[0]
    # Create the new embedding file columns
    columns = [f"{emb_type}_embedding_id:token", f"{emb_type}_embedding:float_seq"]
    print(f"[+] Formatting {emb_name} with columns {columns}")
    format_embedding(emb, columns, emb_type)

[+] Formatting user_embedding.weight with columns ['user_embedding_id:token', 'user_embedding:float_seq']
[+] Saving the new test user_embedding_id:token embedding in tests/test_data/test/test.useremb!
[+] Formatting entity_embedding.weight with columns ['entity_embedding_id:token', 'entity_embedding:float_seq']
[+] Saving the new test entity_embedding_id:token embedding in tests/test_data/test/test.entityemb!
[+] Formatting relation_embedding.weight with columns ['relation_embedding_id:token', 'relation_embedding:float_seq']
[+] Saving the new test relation_embedding_id:token embedding in tests/test_data/test/test.relationemb!


### Next?

The dataset folder now contains the following files (this listing comes from a run on the `ml-100k` dataset):

In [15]:
os.listdir(data_path)

['ml-100k.user',
 'ml-100k.relationemb',
 'ml-100k.item',
 'ml-100k.inter',
 'ml-100k.useremb',
 'ml-100k.link',
 'ml-100k.entityemb',
 'ml-100k.kg']
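
A quick programmatic check that the three embedding files for a given dataset are present can replace eyeballing the listing. A minimal sketch with a hypothetical helper `missing_emb_files`; the default suffixes assume the user/entity/relation files produced above.

```python
import os


def missing_emb_files(data_path, dataset_name, suffixes=("useremb", "entityemb", "relationemb")):
    """Return the expected embedding files that are not present in `data_path`."""
    present = set(os.listdir(data_path))
    expected = [f"{dataset_name}.{suffix}" for suffix in suffixes]
    return [name for name in expected if name not in present]
```

For example, `missing_emb_files(data_path, dataset_name)` should return an empty list once the formatting cell has run.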

**Next, make sure that the dataset configuration is correct.**

Suppose that the output of the formatting step is:

```text
    [+] Formatting user_embedding.weight with columns ['userid:token', 'user_embedding:float_seq']
    [+] Saving the new ml-1m userid:token embedding in /home/recsysdatasets/ml-1m/ml-1m.useremb!
    [+] Formatting entity_embedding.weight with columns ['entityid:token', 'entity_embedding:float_seq']
    [+] Saving the new ml-1m entityid:token embedding in /home/recsysdatasets/ml-1m/ml-1m.entityemb!
    [+] Formatting relation_embedding.weight with columns ['relationid:token', 'relation_embedding:float_seq']
    [+] Saving the new ml-1m relationid:token embedding in /home/recsysdatasets/ml-1m/ml-1m.relationemb!
```

Then open the dataset configuration file (in our case `hopwise/properties/dataset/ml-1m.yaml`) and add the new files to be loaded:


```text
    additional_feat_suffix: [useremb, entityemb, relationemb]  
    load_col:                                                  
        useremb: [userid, user_embedding]
        entityemb: [entityid, entity_embedding]
        relationemb: [relationid, relation_embedding]
    
    alias_of_user_id: [userid]
    alias_of_entity_id: [entityid]
    alias_of_relation_id: [relationid]
    
    preload_weight:
      userid: user_embedding
      entityid: entity_embedding
      relationid: relation_embedding

```
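
The names in `load_col`, the aliases, and `preload_weight` must match the column headers of the generated files with the `:type` suffix removed. A minimal sketch with a hypothetical helper `preload_config` that derives those entries from the column lists; it assumes each suffix is `<type>emb` and that `<type>` names the alias (`alias_of_user_id`, ...). Note that with the notebook's own naming (`user_embedding_id:token`) the id fields would be `user_embedding_id` and so on, not `userid` as in the example above.

```python
def preload_config(columns_by_suffix):
    """Derive config entries from {"useremb": [id_col, emb_col], ...}.

    Hypothetical helper: strips the ":type" part of each header and assumes
    the suffix is "<type>emb", with <type> naming the alias_of_<type>_id key.
    """
    config = {"additional_feat_suffix": [], "load_col": {}, "preload_weight": {}}
    for suffix, cols in columns_by_suffix.items():
        id_field, emb_field = (col.split(":")[0] for col in cols)
        emb_type = suffix[: -len("emb")]  # "useremb" -> "user"
        config["additional_feat_suffix"].append(suffix)
        config["load_col"][suffix] = [id_field, emb_field]
        config[f"alias_of_{emb_type}_id"] = [id_field]
        config["preload_weight"][id_field] = emb_field
    return config


cfg = preload_config({
    "useremb": ["userid:token", "user_embedding:float_seq"],
    "entityemb": ["entityid:token", "entity_embedding:float_seq"],
    "relationemb": ["relationid:token", "relation_embedding:float_seq"],
})
```

The resulting dict mirrors the YAML fragment above and could be dumped with PyYAML's `yaml.safe_dump` if that package is installed.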



### The end

Your model should now be able to access the pretrained embeddings through:

*Torch*
```python
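    import torch
    from torch import nn

    # Inside the model's __init__: preloaded weights are NumPy arrays (possibly float64),
    # so cast them to float32 to match the model's other parameters
    pretrained_user_emb = dataset.get_preload_weight('userid')
    pretrained_entity_emb = dataset.get_preload_weight('entityid')
    pretrained_relation_emb = dataset.get_preload_weight('relationid')

    # from_pretrained freezes the weights by default; pass freeze=False to fine-tune them
    self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb).float())
    self.entity_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_entity_emb).float())
    self.relation_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_relation_emb).float())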
```

*NumPy*
```python
    # Plain NumPy arrays, one row per internal id
    self.pretrained_user_emb = dataset.get_preload_weight('userid')
    self.pretrained_entity_emb = dataset.get_preload_weight('entityid')
    self.pretrained_relation_emb = dataset.get_preload_weight('relationid')
```
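
To our understanding, `get_preload_weight` returns a dense matrix indexed by internal id: each row is filled from the loaded file according to its token, and ids without a pretrained vector (such as `[PAD]`) stay zero. Because rows are placed by token, their order in the `.entityemb` file does not matter, which is why the export writes original tokens rather than internal ids. A stdlib sketch of that alignment, with a hypothetical helper `build_preload_matrix`; the real loader works on NumPy arrays and may differ in details.

```python
def build_preload_matrix(token2id, token2vec, dim):
    """Place each pretrained vector at the row of its token's internal id."""
    matrix = [[0.0] * dim for _ in range(len(token2id))]  # zero rows by default
    for token, vec in token2vec.items():
        matrix[token2id[token]] = list(vec)
    return matrix


token2id = {"[PAD]": 0, "91": 1, "142": 2}
# File order ("142" before "91") does not affect the resulting rows
weights = build_preload_matrix(token2id, {"142": [0.25, 2.0], "91": [0.5, -1.0]}, dim=2)
print(weights)  # [[0.0, 0.0], [0.5, -1.0], [0.25, 2.0]]
```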


