- Contents
- TransX Models Description
- Models architecture
- Dataset
- Environment Requirements
- Script description
- Model description
- Description of Random Situation
- ModelZoo Homepage
TransE, TransH, TransR, and TransD are models for Knowledge Graph Embeddings. The "knowledge" for these models is represented as a triple (head, relation, tail), where the head and tail are entities.
The basic idea of the TransE model is to make the sum of the head vector and the relation vector as close as possible to the tail vector. The distance is calculated using the L1 or L2 norm. The loss function used for training is the margin loss computed over the scores of positive and negative samples. Negative sampling is performed by replacing the head or tail entity in the original triple. This model handles one-to-one relations well.
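For illustration, a minimal NumPy sketch (not the repository code) of the TransE score and the margin loss described above:

```python
import numpy as np

def transe_score(head, relation, tail, norm=1):
    """TransE score: distance between (head + relation) and tail, using the L1 or L2 norm."""
    return np.linalg.norm(head + relation - tail, ord=norm, axis=-1)

def margin_loss(positive_score, negative_score, margin=6.0):
    """Margin loss over the scores of positive and negative (corrupted) triplets."""
    return np.maximum(0.0, margin + positive_score - negative_score).mean()
```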
TransH tackles the problem of one-to-many, many-to-one and many-to-many relations. Its basic idea is to reinterpret relations as translations on relation-specific hyperplanes.
The idea of TransR is that entities and relations can live in different semantic spaces. It uses a trainable projection matrix to project entities into the relation space. This approach has some shortcomings: the projection matrix is determined only by the relation, the heads and tails are assumed to come from the same semantic space, and the model has a much larger number of parameters, which makes it less suitable for large-scale tasks.
TransD compensates for the flaws of the TransR model by using dynamic mapping of the head and tail entities. The projection matrices for heads and tails are calculated from the head-relation and tail-relation pairs, respectively.
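The relation-specific projections used by TransH and TransD can be sketched as follows (NumPy pseudocode based on the standard formulations from the papers below; the actual implementation lives in src/trans_x.py and may differ in details):

```python
import numpy as np

def transh_project(entity, w_r):
    """TransH: project an entity onto the hyperplane with unit normal vector w_r."""
    return entity - np.dot(entity, w_r) * w_r

def transd_project(entity, entity_proj, relation_proj):
    """TransD: dynamic mapping built from the entity and relation projection vectors.

    Equivalent to multiplying the entity by the matrix (relation_proj @ entity_proj.T + I)
    when the entity and relation embeddings have the same dimension.
    """
    return relation_proj * np.dot(entity_proj, entity) + entity
```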
- Paper TransE: Translating Embeddings for Modeling Multi-relational Data (2013)
- Paper TransH: Knowledge Graph Embedding by Translating on Hyperplanes (2014)
- Paper TransR: Learning Entity and Relation Embeddings for Knowledge Graph Completion (2015)
- Paper TransD: Knowledge Graph Embedding via Dynamic Mapping Matrix (2015)
The base elements of all models are trainable lookup tables for entities and relations which produce the embeddings.
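For example, a minimal MindSpore sketch of such lookup tables (illustration only, not the project code; the sizes are taken from the WN18RR statistics and the default dim_e/dim_r values below):

```python
import mindspore as ms
import mindspore.nn as nn

# Entity and relation lookup tables (WN18RR sizes, 50-dimensional embeddings).
entity_emb = nn.Embedding(vocab_size=40943, embedding_size=50)
relation_emb = nn.Embedding(vocab_size=11, embedding_size=50)

head_ids = ms.Tensor([0, 1, 2], ms.int32)
head_vectors = entity_emb(head_ids)  # shape (3, 50)
```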
We use the WordNet and Freebase datasets for training the models.
The preprocessed data files are available here:
- WN18RR (WordNet)
- Size: 3.7 MB
- Number of entities: 40943
- Number of relations: 11
- Number of train triplets: 86835
- Number of test triplets: 3134
- FB15K237 (Freebase)
- Size: 5.5 MB
- Number of entities: 14541
- Number of relations: 237
- Number of train triplets: 272115
- Number of test triplets: 28466
- Hardware (GPU)
    - Prepare a hardware environment with a GPU.
- Framework
    - MindSpore (the models were tested with MindSpore 1.5.0)
- For more information, please check the official MindSpore resources.
./transX
├── configs # models configuration files
│ ├── default_config.yaml
│ ├── transD_fb15k237_1gpu_config.yaml
│ ├── transD_fb15k237_8gpu_config.yaml
│ ├── transD_wn18rr_1gpu_config.yaml
│ ├── transD_wn18rr_8gpu_config.yaml
│ ├── transE_fb15k237_1gpu_config.yaml
│ ├── transE_fb15k237_8gpu_config.yaml
│ ├── transE_wn18rr_1gpu_config.yaml
│ ├── transE_wn18rr_8gpu_config.yaml
│ ├── transH_fb15k237_1gpu_config.yaml
│ ├── transH_fb15k237_8gpu_config.yaml
│ ├── transH_wn18rr_1gpu_config.yaml
│ ├── transH_wn18rr_8gpu_config.yaml
│ ├── transR_fb15k237_1gpu_config.yaml
│ ├── transR_fb15k237_8gpu_config.yaml
│ ├── transR_wn18rr_1gpu_config.yaml
│ └── transR_wn18rr_8gpu_config.yaml
├── model_utils # ModelArts utilities
│ ├── config.py
│ ├── device_adapter.py
│ ├── __init__.py
│ ├── local_adapter.py
│ └── moxing_adapter.py
├── scripts # Shell scripts for training and evaluation
│ ├── run_distributed_train_gpu.sh
│ ├── run_eval_gpu.sh
│ ├── run_export_gpu.sh
│ └── run_standalone_train_gpu.sh
├── src
│ ├── base # C++ backend code for the dataset
│ │ ├── Base.cpp
│ │ ├── CMakeLists.txt
│ │ ├── Corrupt.h
│ │ ├── Random.h
│ │ ├── Reader.h
│ │ ├── Setting.h
│ │ └── Triple.h
│ ├── dataset_lib # Compiled dataset tools
│ │ └── train_dataset_lib.so
│ ├── utils
│ │ └── logging.py # Logging utilities
│ ├── dataset.py
│ ├── loss.py
│ ├── make.sh
│ ├── metric.py
│ ├── model_builder.py # Helper functions for building the models
│ ├── trans_x.py # Models definitions
│ └── __init__.py
├── eval.py # Script for evaluation of the trained model
├── export.py # Script for exporting the trained model
├── requirements.txt # Additional dependencies
├── train.py # Script for starting the training process
└── README.md # Documentation in English
Parameters for both training and evaluation can be provided via *.yaml configuration files or by passing the arguments directly to the train.py, eval.py and export.py scripts.
device_target: "GPU" # tested with GPUs only
is_train_distributed: False # Whether to use NCCL for multi-GPU training
group_size: 1 # Number of the devices
device_id: 0 # Device ID (only for a single GPU training)
seed: 1 # Random seed
# Model options
model_name: "TransE" # Name of the model (TransE / TransH / TransR / TransD)
dim_e: 50 # Embeddings size for entities
dim_r: 50 # Embeddings size for relations
# Dataset options
dataset_root: "/path/to/dataset/root"
train_triplet_file_name: "train2id.txt"
eval_triplet_file_name: "test2id.txt"
filter_triplets_files_names: # Files with positive triplets samples
- "train2id.txt"
- "valid2id.txt"
- "test2id.txt"
entities_file_name: "entity2id.txt"
relations_file_name: "relation2id.txt"
negative_sampling_rate: 1 # The number of negative samples per positive sample.
train_batch_size: 868
# Logging options
train_output_dir: "train-outputs/"
eval_output_dir: "eval-output/"
export_output_dir: "export-output/"
ckpt_save_interval: 5
ckpt_save_on_master_only: True
keep_checkpoint_max: 10
log_interval: 100
# Training options
pre_trained: "" # Path to the pre-trained model (necessary for TransR)
lr: 0.5 # Learning rate
epochs_num: 1000 # Number of epochs
weight_decay: 0.0 # Weight decay
margin: 6.0 # Margin parameter of the margin loss
train_use_data_sink: False
# Evaluation and export options
ckpt_file: "/path/to/trained/checkpoint"
file_format: "MINDIR"
eval_use_data_sink: False
export_batch_size: 1000 # The batch size of the exported model
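The existing examples pass options such as --dataset_root and --pre_trained, whose names match the configuration keys, so other options can presumably be overridden the same way without editing the YAML file. A sketch (paths are placeholders; exact flag support may vary):

```bash
python train.py --config_path=configs/transE_wn18rr_1gpu_config.yaml --dataset_root=/path/to/wn18rr --lr=0.1 --epochs_num=500
```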
You need to compile the library that generates the corrupted triplets.
The SOTA implementation uses triplet filtering to ensure that the corrupted triplets are not actually present among the original triplets. This filtering process is difficult to vectorize and implement efficiently in pure Python, so we use a custom *.so library.
To build the library, go to the ./transX/src directory and run
bash make.sh
After the build finishes successfully, train_dataset_lib.so appears in ./transX/src/dataset_lib.
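Conceptually, the filtered negative sampling that the library implements looks like the following pure-Python sketch (illustration only; the real C++ implementation in src/base is much faster):

```python
import random

def corrupt_triplet(head, relation, tail, num_entities, known_triplets):
    """Replace the head or the tail with a random entity, rejecting known positive triplets."""
    while True:
        if random.random() < 0.5:
            candidate = (random.randrange(num_entities), relation, tail)
        else:
            candidate = (head, relation, random.randrange(num_entities))
        if candidate not in known_triplets:  # set of (h, r, t) from train/valid/test files
            return candidate
```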
You can start the single-GPU training process by running the Python script:

- Without pre-trained model

  python train.py --config_path=/path/to/model_config.yaml --dataset_root=/path/to/dataset

- With pre-trained model

  python train.py --config_path=/path/to/model_config.yaml --dataset_root=/path/to/dataset --pre_trained=/path/to/pretrain.ckpt
or by running the shell script:
- Without pre-trained model

  bash scripts/run_standalone_train_gpu.sh [DATASET_ROOT] [DATASET_NAME] [MODEL_NAME]

- With pre-trained model

  bash scripts/run_standalone_train_gpu.sh [DATASET_ROOT] [DATASET_NAME] [MODEL_NAME] [PRETRAIN_CKPT]
You can start the 8-GPU training by running the following shell script:

- Without pre-trained model

  bash scripts/run_distributed_train_gpu.sh [DATASET_ROOT] [DATASET_NAME] [MODEL_NAME]

- With pre-trained model

  bash scripts/run_distributed_train_gpu.sh [DATASET_ROOT] [DATASET_NAME] [MODEL_NAME] [PRETRAIN_CKPT]
DATASET_NAME must be "wn18rr" or "fb15k237"
MODEL_NAME must be "transE", "transH", "transR" or "transD"
Based on these names, the corresponding configuration file from the ./configs directory will be selected.
The training results will be stored in the ./train-outputs directory. If the shell scripts are used, the logged information will be redirected to the ./train-logs directory.
You can start evaluation by running the following Python script:
python eval.py --config_path=/path/to/model_config.yaml --dataset_root=/path/to/dataset --ckpt_file=/path/to/trained.ckpt
or shell script:
bash scripts/run_eval_gpu.sh [DATASET_ROOT] [DATASET_NAME] [MODEL_NAME] [CKPT_PATH]
DATASET_NAME must be "wn18rr" or "fb15k237"
MODEL_NAME must be "transE", "transH", "transR" or "transD"
Based on these names, the corresponding configuration file from the ./configs directory will be selected.
The evaluation results will be stored in the ./eval-output directory. If the shell script is used, the logged information will be redirected to the ./eval-logs directory.
Example of the evaluation output:
...
[DATE/TIME]:INFO:start evaluation
[DATE/TIME]:INFO:evaluation finished
[DATE/TIME]:INFO:Result: hit@10 = 0.5056 hit@3 = 0.3623 hit@1 = 0.0490
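Here hit@k is the fraction of test triplets for which the correct entity is ranked within the top k candidates. A minimal sketch of the metric (illustration only, not the code from src/metric.py):

```python
import numpy as np

def hit_at_k(ranks, k):
    """Fraction of test triplets whose correct entity is ranked within the top k."""
    ranks = np.asarray(ranks)
    return float((ranks <= k).mean())

# hit_at_k(ranks, 10) corresponds to the hit@10 value in the log above.
```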
You can export the model by running the following Python script:
python export.py --config_path=/path/to/model_config.yaml --dataset_root=/path/to/dataset --ckpt_file=/path/to/trained.ckpt
or by running the shell script:
bash scripts/run_export_gpu.sh [DATASET_ROOT] [DATASET_NAME] [MODEL_NAME] [CKPT_PATH]
DATASET_NAME must be "wn18rr" or "fb15k237"
MODEL_NAME must be "transE", "transH", "transR" or "transD"
Based on these names, the corresponding configuration file from the ./configs directory will be selected.
The only export format tested is MINDIR.
For training the TransR models, we used the corresponding trained TransE models as initialization. You need to train a TransE model first and pass its checkpoint via the pre_trained option to get good performance from the TransR model.
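For example, a possible sequence of commands (the checkpoint path is a placeholder and depends on your run):

```bash
# 1. Train TransE on WN18RR.
bash scripts/run_standalone_train_gpu.sh /path/to/wn18rr wn18rr transE
# 2. Train TransR, initializing it from the trained TransE checkpoint.
bash scripts/run_standalone_train_gpu.sh /path/to/wn18rr wn18rr transR /path/to/trained_transE.ckpt
```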
1 GPU Training
| Parameters | | | | | | | | |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Resource | 1x V100 | 1x V100 | 1x V100 | 1x V100 | 1x V100 | 1x V100 | 1x V100 | 1x V100 |
| Uploaded date (mm/dd/yyyy) | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 |
| MindSpore version | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 |
| Model | TransE | TransH | TransR | TransD | TransE | TransH | TransR | TransD |
| Dataset | WordNet | WordNet | WordNet | WordNet | Freebase | Freebase | Freebase | Freebase |
| Batch size | 868 | 868 | 868 | 868 | 2721 | 2721 | 453 | 2721 |
| Learning rate | 0.5 | 0.5 | 0.05 | 0.5 | 1 | 0.5 | 0.16667 | 1 |
| Epochs | 1000 | 300 | 250 | 200 | 1000 | 1000 | 1000 | 1000 |
| Accuracy (Hit@10) | 0.511 | 0.504 | 0.516 | 0.508 | 0.476 | 0.481 | 0.509 | 0.483 |
| Total time | 3m 0s | 1m 22s | 1m 10s | 1m | 19m 32s | 34m 21s | 7h 34m 16s | 33m 22s |
8 GPU Training
| Parameters | | | | | | | | |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Resource | 8x V100 | 8x V100 | 8x V100 | 8x V100 | 8x V100 | 8x V100 | 8x V100 | 8x V100 |
| Uploaded date (mm/dd/yyyy) | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 |
| MindSpore version | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 |
| Model | TransE | TransH | TransR | TransD | TransE | TransH | TransR | TransD |
| Dataset | WordNet | WordNet | WordNet | WordNet | Freebase | Freebase | Freebase | Freebase |
| Batch size | 868 | 868 | 868 | 868 | 2721 | 2721 | 453 | 2721 |
| Learning rate | 0.5 | 0.5 | 0.05 | 0.5 | 8 | 4 | 1.3333 | 8 |
| Epochs | 1000 | 300 | 250 | 200 | 1000 | 1000 | 1000 | 1000 |
| Accuracy (Hit@10) | 0.511 | 0.507 | 0.512 | 0.514 | 0.475 | 0.483 | 0.509 | 0.481 |
| Total time | 1m 17s | 31s | 27s | 52s | 3m 51s | 5m 41s | 1h 24m 18s | 6m 32s |
Evaluation Performance

| Parameters | | | | | | | | |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Resource | GPU V100 | GPU V100 | GPU V100 | GPU V100 | GPU V100 | GPU V100 | GPU V100 | GPU V100 |
| Uploaded date (mm/dd/yyyy) | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 | 02/06/2022 |
| MindSpore version | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 | 1.5.0 |
| Model | TransE | TransH | TransR | TransD | TransE | TransH | TransR | TransD |
| Dataset | WordNet | WordNet | WordNet | WordNet | Freebase | Freebase | Freebase | Freebase |
| Batch size | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| Outputs | Scores | Scores | Scores | Scores | Scores | Scores | Scores | Scores |
| Hit@10 | 0.511 | 0.507 | 0.512 | 0.514 | 0.475 | 0.483 | 0.509 | 0.481 |
We set a random seed in train.py and pass the same seed to the C++ backend of the dataset generator.
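A minimal sketch of how the global seed can be set in MindSpore (illustration only; the exact calls in train.py may differ):

```python
import mindspore as ms

ms.set_seed(1)  # value of the `seed` option from the configuration
```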
Please check the official homepage.