In [39]:
from pykeen.datasets.inductive.ilp_teru import InductiveFB15k237
from pykeen.datasets.inductive.base import DisjointInductivePathDataset

from pykeen.models.inductive import InductiveNodePieceGNN

from pykeen.models import predict
from pykeen.training import SLCWATrainingLoop
from pykeen.evaluation.rank_based_evaluator import SampledRankBasedEvaluator
from pykeen.stoppers import EarlyStopper
from pykeen.losses import NSSALoss

import torch
from torch.optim import Adam

import pandas as pd

In [2]:
BASE_URL = '../data/SocialTalk'

TRAIN_URL = f'{BASE_URL}/training/train.txt'
INDUCTIVE_INFERENCE_URL = f'{BASE_URL}/inference/train.txt'
INDUCTIVE_VALIDATION_URL = f'{BASE_URL}/inference/valid.txt'
INDUCTIVE_TEST_URL = f'{BASE_URL}/inference/test.txt'

In [3]:
class InductiveClientData(DisjointInductivePathDataset):
    def __init__(self, create_inverse_triples: bool = False, **kwargs):
        """Initialize client data from triples file path"""
        super().__init__(
            transductive_training_path=TRAIN_URL,
            inductive_inference_path=INDUCTIVE_INFERENCE_URL,
            inductive_validation_path=INDUCTIVE_VALIDATION_URL,
            inductive_testing_path=INDUCTIVE_TEST_URL,
            create_inverse_triples=create_inverse_triples,
            eager=True,
            **kwargs,
        )

In [4]:
# dataset = InductiveFB15k237(version="v4", create_inverse_triples=True)
dataset = InductiveClientData(create_inverse_triples=True)

You're trying to map triples with 15 entities and 0 relations that are not in the training set. These triples will be excluded from the mapping.
In total 15 from 16084 triples were filtered out
You're trying to map triples with 9 entities and 0 relations that are not in the training set. These triples will be excluded from the mapping.
In total 9 from 8042 triples were filtered out


In [5]:
model = InductiveNodePieceGNN(
    triples_factory=dataset.transductive_training,  # training factory, will be also used for a GNN
    inference_factory=dataset.inductive_inference,  # inference factory, will be used for a GNN
    num_tokens=12,  # length of a node hash - how many unique relations per node will be used
    aggregation="mlp",  # aggregation function, defaults to an MLP, can be any PyTorch function
    loss=NSSALoss(margin=15),  # dummy loss
    random_seed=42,
    gnn_encoder=None,  # defaults to a 2-layer CompGCN with DistMult composition function
)
optimizer = Adam(params=model.parameters(), lr=0.0005)

                                                      

In [6]:
training_loop = SLCWATrainingLoop(
    triples_factory=dataset.transductive_training,  # training triples
    model=model,
    optimizer=optimizer,
    negative_sampler_kwargs=dict(num_negs_per_pos=32),
    mode="training",   # necessary to specify for the inductive mode - training has its own set of nodes
)

# Validation and Test evaluators use a restricted protocol ranking against 50 random negatives
valid_evaluator = SampledRankBasedEvaluator(
    mode="validation",   # necessary to specify for the inductive mode - this will use inference nodes
    evaluation_factory=dataset.inductive_validation,  # validation triples to predict
    additional_filter_triples=dataset.inductive_inference.mapped_triples,   # filter out true inference triples
)

# According to the original code
# https://github.com/kkteru/grail/blob/2a3dffa719518e7e6250e355a2fb37cd932de91e/test_ranking.py#L526-L529
# test filtering uses only the inductive_inference split and does not include inductive_validation triples
# If you use the full RankBasedEvaluator, both inductive_inference and inductive_validation triples
# must be added to the additional_filter_triples
test_evaluator = SampledRankBasedEvaluator(
    mode="testing",   # necessary to specify for the inductive mode - this will use inference nodes
    evaluation_factory=dataset.inductive_testing,  # test triples to predict
    additional_filter_triples=dataset.inductive_inference.mapped_triples,   # filter out true inference triples
)

early_stopper = EarlyStopper(
    model=model,
    training_triples_factory=dataset.inductive_inference,
    evaluation_triples_factory=dataset.inductive_validation,
    frequency=1,
    patience=10000,  # for test reasons, turn it off
    result_tracker=None,
    evaluation_batch_size=256,
    evaluator=valid_evaluator,
)

In [7]:
# Training starts here
training_loop.train(
    triples_factory=dataset.transductive_training,
    stopper=early_stopper,
    num_epochs=100,
)

Training epochs on cpu: 100%|██████████| 100/100 [2:13:25<00:00, 80.06s/epoch, loss=0.000162, prev_loss=0.000162] 


[0.0008836611462998171,
 0.0006925249483855627,
 0.00042594519832943274,
 0.00028770030688344173,
 0.0002420711339997736,
 0.00021706705763834005,
 0.00020351311368394672,
 0.00019560123517120488,
 0.00018845166221624648,
 0.00018524378642488078,
 0.0001822579473505665,
 0.00017933882961923488,
 0.0001771795113855765,
 0.00017653864028490708,
 0.00017603478839470674,
 0.00017394494767156076,
 0.00017303186589369616,
 0.00017116162997848574,
 0.00017160748900795555,
 0.00017197314593931696,
 0.00017116766067479637,
 0.0001711466093292549,
 0.00016993859831978422,
 0.0001684870130050393,
 0.00017005952751816833,
 0.0001687661536985014,
 0.00016803606700864564,
 0.0001673643835319347,
 0.00016759869220956002,
 0.00016666413568919414,
 0.0001671467188827367,
 0.0001666981819530103,
 0.00016750877772261244,
 0.0001669634851665084,
 0.00016658228487648308,
 0.0001655656445388352,
 0.00016657433025464712,
 0.00016659135872032493,
 0.00016523730695067694,
 0.00016535682385959902,
 0.0001655962

In [8]:
# Test evaluation
result = test_evaluator.evaluate(
    model=model,
    mapped_triples=dataset.inductive_testing.mapped_triples,
    additional_filter_triples=dataset.inductive_inference.mapped_triples,
    batch_size=256,
)

Evaluating on cpu: 100%|██████████| 8.03k/8.03k [00:03<00:00, 2.25ktriple/s]


In [9]:
result.get_metric('hits@10')

0.6296526826839288

In [None]:
torch.save(model.state_dict(), '../data/SocialTalk/kg.model')

In [33]:
def get_example_entity_representations(inference=False):
    if inference:
        entity_representation_modules = model.inference_representation
        index = dataset.inductive_inference.entity_id_to_label.values()
    else:
        entity_representation_modules = model.entity_representations
        index = dataset.transductive_training.entity_id_to_label.values()

    entity_embeddings = entity_representation_modules[0]
    entity_embedding_tensor = entity_embeddings()
    return pd.DataFrame(data=entity_embedding_tensor.detach().numpy(), index=index)

In [37]:
embeddings = get_example_entity_representations(inference=True)
embeddings.to_csv('../data/SocialTalk/inference/embeddings.csv')

### Inference from saved model


In [40]:
model2 = InductiveNodePieceGNN(
    triples_factory=dataset.transductive_training,  # training factory, will be also used for a GNN
    inference_factory=dataset.inductive_inference,  # inference factory, will be used for a GNN
    num_tokens=12,  # length of a node hash - how many unique relations per node will be used
    aggregation="mlp",  # aggregation function, defaults to an MLP, can be any PyTorch function
    loss=NSSALoss(margin=15),  # dummy loss
    random_seed=42,
    gnn_encoder=None,  # defaults to a 2-layer CompGCN with DistMult composition function
)

                                                      

In [42]:
model2.load_state_dict(torch.load('../data/SocialTalk/kg.model'))

<All keys matched successfully>

In [53]:
embeddings = get_example_entity_representations(inference=False)

In [105]:
desired_profiles = pd.read_csv('../data/SocialTalk/Apres-profiles-SocialTalk.csv').set_index('Username')

In [61]:
user_lookup = pd.read_csv('../data/SocialTalk/audience_populated.csv')['Username']

In [107]:
top_25_profile_interests = desired_profiles.iloc[:, 32:57]  # INCLUDE (list of items per user)
top_25_audience_report = desired_profiles.iloc[:, 191:]  # INCLUDE (tuples w/ percentages)
age_distribution = desired_profiles.iloc[:, 183:189]  # INCLUDE (list of percentages)

In [109]:
x = top_25_profile_interests.unstack()
x = x[x.notnull()].reset_index().drop('level_0', axis=1)
x.columns = ['user_id', 'interest_category']
x.to_csv('../data/SocialTalk/predict/interest_categories.csv', index=None)

In [110]:
x1 = top_25_audience_report[[i for i in top_25_audience_report if i[-1] == '%']]
x2 = top_25_audience_report[[i for i in top_25_audience_report if i[-1] != '%']]

filter_audiences = x1 >= 0.19
filter_audiences.columns = [i[:-2] for i in filter_audiences.columns]

x3 = x2[filter_audiences].unstack()
x3 = x3[x3.notnull()].reset_index().drop('level_0', axis=1)
x3.columns = ['user_id', 'audience_category']
x3.to_csv('../data/SocialTalk/predict/audience_categories.csv', index=None)

In [111]:
filter_ages = age_distribution >= 0.17
x = age_distribution[filter_ages]
x.columns = [i[:-2] for i in x.columns]
x = x.unstack()
x = x[x.notnull()].reset_index().drop(0, axis=1)
x.columns = ['age_category', 'user_id']
x.to_csv('../data/SocialTalk/predict/age_categories.csv', index=None)