In [None]:
!pip install ampligraph
import numpy as np
from ampligraph.datasets import load_fb15k_237
from ampligraph.latent_features import ScoringBasedEmbeddingModel
from ampligraph.evaluation import mrr_score, hits_at_n_score
from ampligraph.latent_features.loss_functions import get as get_loss
from ampligraph.latent_features.regularizers import get as get_regularizer
from ampligraph.utils import save_model
import tensorflow as tf

# Load the FB15k-237 dataset; its 'train', 'valid' and 'test' entries are
# consumed further down in the script.
X = load_fb15k_237()

# Embedding model that scores triples with TransE.
model = ScoringBasedEmbeddingModel(
    k=400,
    eta=30,
    scoring_type='TransE',
)

# Training ingredients: an LP regularizer (p=2), the multiclass NLL loss
# and an Adam optimizer.
regularizer = get_regularizer('LP', {'p': 2, 'lambda': 0.0001})
loss = get_loss('multiclass_nll')
optim = tf.keras.optimizers.Adam(learning_rate=0.0001)

# Attach optimizer, loss and regularizer to the model.
model.compile(
    optimizer=optim,
    loss=loss,
    entity_relation_regularizer=regularizer,
)

# Filter used during validation and evaluation to discard corruptions that
# are in fact known positive statements. It is built by concatenating all
# the positives (train + valid + test). It is named `positives_filter`
# rather than `filter` so the Python builtin is not shadowed.
positives_filter = {
    'test': np.concatenate((X['train'], X['valid'], X['test'])),
}

# Early stopping callback: monitors 'val_hits10' (mode='max', i.e. higher
# is better), stops once the patience of 5 is exhausted and restores the
# best weights seen so far.
checkpoint = tf.keras.callbacks.EarlyStopping(
    monitor='val_hits10',
    min_delta=0,
    patience=5,
    verbose=1,
    mode='max',
    restore_best_weights=True,
)

# Fit the model on the training set, validating on the validation set.
# The batch size is derived from the desired number of batches per epoch.
batch_count = 64
model.fit(
    X['train'],
    batch_size=X['train'].shape[0] // batch_count,
    epochs=4000,                         # Number of training epochs
    validation_freq=20,                  # Epochs between successive validation
    validation_burn_in=100,              # Epoch to start validation
    validation_data=X['valid'],          # Validation data
    validation_filter=positives_filter,  # Filter positives from validation corruptions
    callbacks=[checkpoint],              # Early stopping callback (more from tf.keras.callbacks are supported)
    verbose=True,                        # Enable stdout messages
)


# Run the evaluation procedure on the test set, corrupting both the
# subject and the object side and filtering out known positives.
ranks = model.evaluate(
    X['test'],
    use_filter=positives_filter,
    corrupt_side='s,o',
)

# Compute and print metrics from the ranks of the test triples.
mrr = mrr_score(ranks)
hits_10 = hits_at_n_score(ranks, n=10)
# The `:f` format spec renders six decimal places, same as "%f".
print(f"MRR: {mrr:f}, Hits@10: {hits_10:f}")

# Save the trained model under the "output" path.
storage_path = "output"
save_model(model, model_name_path=storage_path)