In [8]:
import sys
sys.path.insert(0, '..')

from torchkge.models import TransEModel
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils.datasets import load_fb15k
from torchkge.utils import MarginLoss, DataLoader

In [11]:
kg_train, _, _ = load_fb15k('../data')
dataloader = DataLoader(kg_train, batch_size=4)
data = next(iter(dataloader))
data

(tensor([ 3920,   839, 10094,  2587]),
 tensor([9220, 9523,  775, 7238]),
 tensor([ 791, 1273,  846,   95]))

## Negative Sampling

In [48]:
from torchkge.utils.operations import get_bernoulli_probs
from torch import tensor, bernoulli, cat, randint
from pandas import DataFrame

In [23]:
t = cat((kg_train.head_idx.view(-1,1), 
         kg_train.tail_idx.view(-1,1),
         kg_train.relations.view(-1,1)), dim=1)
t

tensor([[ 3920,  9220,   791],
        [  839,  9523,  1273],
        [10094,   775,   846],
        ...,
        [ 8170, 13876,   100],
        [13723,  8121,  1030],
        [ 6647,  6817,   851]])

### Compute bernoulli distribution for each relation

In [24]:
# get average number of head per tail for each relation

df = DataFrame(t.numpy(), columns=['from','to','rel'])
df = df.groupby(['rel','to']).count().groupby('rel').mean()
df.reset_index(inplace=True)
hpt = {i:v for i,v in df.values}
hpt

{0.0: 1.0,
 1.0: 1.0,
 2.0: 1.0,
 3.0: 1.0,
 4.0: 1.0,
 5.0: 1.0,
 6.0: 1.0,
 7.0: 1.5,
 8.0: 1.48,
 9.0: 1.8888888888888888,
 10.0: 1.4482758620689655,
 11.0: 2.0,
 12.0: 1.0,
 13.0: 2.5,
 14.0: 1.0,
 15.0: 2.0,
 16.0: 1.0,
 17.0: 1.0,
 18.0: 1.5,
 19.0: 1.0,
 20.0: 1.4285714285714286,
 21.0: 1.0,
 22.0: 40.6,
 23.0: 1.0,
 24.0: 35.8125,
 25.0: 1.0,
 26.0: 28.0,
 27.0: 2.6,
 28.0: 32.05,
 29.0: 2.642857142857143,
 30.0: 31.35,
 31.0: 1.5,
 32.0: 1.0,
 33.0: 1.0,
 34.0: 1.4,
 35.0: 1.625,
 36.0: 1.0,
 37.0: 1.0,
 38.0: 1.1428571428571428,
 39.0: 1.25,
 40.0: 3.0,
 41.0: 1.25,
 42.0: 1.5,
 43.0: 1.0,
 44.0: 1.5555555555555556,
 45.0: 1.0,
 46.0: 1.0,
 47.0: 1.0,
 48.0: 1.0,
 49.0: 1.0,
 50.0: 1.0,
 51.0: 4.0,
 52.0: 1.0,
 53.0: 4.0,
 54.0: 1.0,
 55.0: 2.0,
 56.0: 1.0,
 57.0: 3.0,
 58.0: 1.0,
 59.0: 2.6666666666666665,
 60.0: 1.0416666666666667,
 61.0: 1.6666666666666667,
 62.0: 3.5,
 63.0: 1.6666666666666667,
 64.0: 1.0,
 65.0: 1.0,
 66.0: 2.2857142857142856,
 67.0: 1.0,
 68.0: 2.0,
 69

In [25]:
# get average number of tail per head for each relation

df = DataFrame(t.numpy(), columns=['from','to','rel'])
df = df.groupby(['from','rel']).count().groupby('rel').mean()
df.reset_index(inplace=True)
tph = {i:v for i,v in df.values}
tph

{0.0: 1.0,
 1.0: 1.0,
 2.0: 1.0,
 3.0: 38.0,
 4.0: 1.0,
 5.0: 1.0,
 6.0: 1.0,
 7.0: 1.25,
 8.0: 2.3125,
 9.0: 1.3076923076923077,
 10.0: 2.625,
 11.0: 1.3333333333333333,
 12.0: 1.6666666666666667,
 13.0: 1.6666666666666667,
 14.0: 1.0,
 15.0: 1.4285714285714286,
 16.0: 1.0,
 17.0: 2.0,
 18.0: 1.0,
 19.0: 1.6666666666666667,
 20.0: 2.3076923076923075,
 21.0: 1.0,
 22.0: 9.36923076923077,
 23.0: 1.0,
 24.0: 8.815384615384616,
 25.0: 1.0,
 26.0: 1.0,
 27.0: 1.3928571428571428,
 28.0: 7.202247191011236,
 29.0: 1.48,
 30.0: 7.125,
 31.0: 1.5,
 32.0: 1.0,
 33.0: 1.0,
 34.0: 1.4,
 35.0: 1.4444444444444444,
 36.0: 1.0,
 37.0: 1.0,
 38.0: 1.0,
 39.0: 1.0,
 40.0: 1.0,
 41.0: 1.0,
 42.0: 1.0,
 43.0: 1.0,
 44.0: 1.4,
 45.0: 1.0,
 46.0: 1.0,
 47.0: 7.0,
 48.0: 1.0,
 49.0: 1.3333333333333333,
 50.0: 1.5,
 51.0: 1.0,
 52.0: 4.0,
 53.0: 1.0,
 54.0: 1.0,
 55.0: 1.0,
 56.0: 2.0,
 57.0: 1.2,
 58.0: 1.0,
 59.0: 2.2857142857142856,
 60.0: 5.0,
 61.0: 1.0,
 62.0: 1.75,
 63.0: 1.6666666666666667,
 64.0: 2.0

    relation   3: hpt = 1    , tph = 38   (1-to-N)
    relation 110: hpt = 47.86, tph = 1.23 (N-to-1)

In [44]:
bern_prob = [0.5]*kg_train.n_rel
for r in tph.keys():
    bern_prob[int(r)] = tph[r] / (tph[r] + hpt[r])

bern_prob = tensor(bern_prob).float()    
bern_prob

tensor([0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.1600, 0.4783])

    bern_prob[r] ~= 1 : tph[r] >> hpt[r] => 1-to-N
    bern_prob[r] ~= 0 : tph[r] << hpt[r] => N-to-1

### Create corrupted triplets

In [38]:
n_neg = 2 # the number of invalid triple per triple 

heads, tails, relations = data[0], data[1], data[2] # Batch triplets
batch_size = heads.shape[0]
neg_heads = heads.repeat(n_neg)
neg_tails = tails.repeat(n_neg)

In [43]:
bern_prob

[0.5,
 0.5,
 0.5,
 0.9743589743589743,
 0.5,
 0.5,
 0.5,
 0.45454545454545453,
 0.6097560975609756,
 0.40909090909090906,
 0.6444444444444445,
 0.4,
 0.625,
 0.39999999999999997,
 0.5,
 0.41666666666666663,
 0.5,
 0.6666666666666666,
 0.4,
 0.625,
 0.6176470588235293,
 0.5,
 0.1875,
 0.5,
 0.19753086419753088,
 0.5,
 0.034482758620689655,
 0.3488372093023256,
 0.1834862385321101,
 0.358974358974359,
 0.18518518518518517,
 0.5,
 0.5,
 0.5,
 0.5,
 0.4705882352941176,
 0.5,
 0.5,
 0.4666666666666667,
 0.4444444444444444,
 0.25,
 0.4444444444444444,
 0.4,
 0.5,
 0.4736842105263157,
 0.5,
 0.5,
 0.875,
 0.5,
 0.5714285714285715,
 0.6,
 0.2,
 0.8,
 0.2,
 0.5,
 0.3333333333333333,
 0.6666666666666666,
 0.2857142857142857,
 0.5,
 0.4615384615384615,
 0.8275862068965517,
 0.37499999999999994,
 0.3333333333333333,
 0.5,
 0.6666666666666666,
 0.5,
 0.5384615384615384,
 0.896551724137931,
 0.5,
 0.5454545454545454,
 0.5555555555555556,
 0.8944099378881988,
 0.5,
 0.11904761904761904,
 0.2428571428

In [62]:
mask = bernoulli(bern_prob[relations].repeat(n_neg)).double()
mask

tensor([0., 1., 1., 1., 0., 0., 1., 0.], dtype=torch.float64)

    [0,1,1,1] : 0 번째는 tail을 바꾸고, 1,2,3 번째는 head를 바꿔라
    [0,0,1,0] : 0,1,3 번째는 tail을 바꾸고, 2 번째는 head를 바꿔라

In [65]:
n_h_cor = int(mask.sum().item())

neg_heads[mask==1] = randint(1, kg_train.n_ent, (n_h_cor,))
neg_tails[mask==0] = randint(1, kg_train.n_ent, (batch_size*n_neg-n_h_cor,))

In [76]:
print(heads)
print(neg_heads)

tensor([ 3920,   839, 10094,  2587])
tensor([ 3920, 10470,  4018, 11415,  3920,   839,  5688,  2587])


In [77]:
print(tails)
print(neg_tails)

tensor([9220, 9523,  775, 7238])
tensor([ 8416,  9523,   775,  7238, 11422,  3980,   775,  3810])


### Get scalar energy using both triplets

In [78]:
ent_emb_dim=10
model = TransEModel(ent_emb_dim, kg_train.n_ent, 
                    kg_train.n_rel, dissimilarity_type='L2', )

In [79]:
pos, neg = model(heads, tails, neg_heads, neg_tails, relations)

In [143]:
from torch.nn.functional import normalize

pos_h = normalize(model.ent_emb(heads.repeat(2)), p=2, dim=1)  # keep dimension True
pos_t = normalize(model.ent_emb(tails.repeat(2)), p=2, dim=1)
rel = model.rel_emb(relations.repeat(2))

In [148]:
from torchkge.utils.dissimilarities import l2_dissimilarity

l2_dissimilarity(pos_h+rel, pos_t)

tensor([2.5196, 3.1738, 1.9569, 4.2622, 2.5196, 3.1738, 1.9569, 4.2622],
       grad_fn=<PowBackward0>)

In [163]:
temp = ((pos_h + rel)-pos_t)
temp

tensor([[ 0.5274, -0.2814,  0.3220, -0.1735,  0.1690, -0.0109,  0.6211,  0.0237,
          1.1920, -0.4390],
        [ 0.7475, -1.0720,  0.0696, -0.7078, -0.2831,  0.4369,  0.2699, -0.6730,
          0.1101, -0.3887],
        [ 0.0875, -0.7540, -0.1508, -0.7015, -0.1846, -0.2253, -0.4568,  0.0630,
         -0.5760, -0.4865],
        [ 0.1621, -1.3758,  0.1130, -0.9421, -0.5293, -0.7109,  0.4535, -0.3250,
          0.2047, -0.5513],
        [ 0.5274, -0.2814,  0.3220, -0.1735,  0.1690, -0.0109,  0.6211,  0.0237,
          1.1920, -0.4390],
        [ 0.7475, -1.0720,  0.0696, -0.7078, -0.2831,  0.4369,  0.2699, -0.6730,
          0.1101, -0.3887],
        [ 0.0875, -0.7540, -0.1508, -0.7015, -0.1846, -0.2253, -0.4568,  0.0630,
         -0.5760, -0.4865],
        [ 0.1621, -1.3758,  0.1130, -0.9421, -0.5293, -0.7109,  0.4535, -0.3250,
          0.2047, -0.5513]], grad_fn=<SubBackward0>)

In [176]:
from torch import sum

sum((abs(temp)**2), dim=1)

tensor([2.5196, 3.1738, 1.9569, 4.2622, 2.5196, 3.1738, 1.9569, 4.2622],
       grad_fn=<SumBackward1>)

In [169]:
((pos_h + rel)-pos_t).norm(p=2, dim=1)

tensor([1.5873, 1.7815, 1.3989, 2.0645, 1.5873, 1.7815, 1.3989, 2.0645],
       grad_fn=<NormBackward1>)

In [80]:
pos

tensor([-2.5196, -3.1738, -1.9569, -4.2622, -2.5196, -3.1738, -1.9569, -4.2622],
       grad_fn=<RepeatBackward>)

In [81]:
neg

tensor([-2.9096, -5.7738, -1.5816, -2.8168, -3.5563, -1.7940, -1.0089, -3.2351],
       grad_fn=<NegBackward>)