In [None]:
# Dataset to process; both input paths below are derived from this name.
dataset_name = 'WebNLG'
# Training triples and the relation dictionary file for that dataset.
train_path = f"data/{dataset_name}/train_triples.json"
rel_dict_path = f"data/{dataset_name}/rel2id.json"

In [2]:
import json
import torch
from torch import nn
# Load the relation dictionary file. The context manager closes the handle
# deterministically instead of leaving it open until garbage collection, and
# the explicit encoding keeps the result independent of the platform locale.
with open(rel_dict_path, encoding="utf-8") as rel_file:
    f = json.load(rel_file)
# The number of relations is the size of the first mapping stored in the file.
num_rels = len(f[0])

In [3]:
def co_occurance(train_path, rel_dict_path):
    """Build the symmetric relation co-occurrence count matrix.

    Args:
        train_path: Path to a JSON file holding a list of samples. Each sample
            is a dict with a ``'triple_list'`` entry whose items carry the
            relation name at index 1.
        rel_dict_path: Path to a JSON file whose element at index 1 maps a
            relation name to its integer id.

    Returns:
        A float tensor of shape ``(num_rels, num_rels)``. For every pair of
        triples inside one sample, the entries ``[a][b]`` and ``[b][a]`` of
        their relation ids are incremented by one; two triples sharing the
        same relation therefore add 2 to the diagonal entry.

    Raises:
        KeyError: If a relation name (after stripping whitespace) is missing
            from the relation dictionary.
    """
    # Initialisation: read both files with context managers so the handles are
    # closed immediately rather than leaked.
    with open(train_path, encoding="utf-8") as train_file:
        train_data = json.load(train_file)
    with open(rel_dict_path, encoding="utf-8") as rel_file:
        rel_dic = json.load(rel_file)[1]
    co_occurance_matrix = torch.zeros([len(rel_dic), len(rel_dic)])

    # Count co-occurrence frequencies sample by sample.
    for data_dic in train_data:
        triple_list = data_dic['triple_list']
        rels = [rel_dic[x[1].strip()] for x in triple_list]

        # Visit every unordered pair of triples once and update both
        # symmetric cells.
        for i in range(len(rels)):
            for j in range(i + 1, len(rels)):
                co_occurance_matrix[rels[i], rels[j]] += 1
                co_occurance_matrix[rels[j], rels[i]] += 1

    return co_occurance_matrix

def weight(x_ij, x_max=100, alpha=0.75):
    """Down-weight rare co-occurrence counts and cap frequent ones at 1.

    Computes ``(x_ij / x_max) ** alpha`` for counts below ``x_max`` and ``1``
    otherwise (the same form as the weighting function used by GloVe).

    Args:
        x_ij: Co-occurrence count(s). May be a Python number or a tensor of
            any shape; the computation is element-wise.
        x_max: Count at which the weight saturates at 1.
        alpha: Exponent applied to the scaled count.

    Returns:
        A floating-point tensor with the same shape as ``x_ij`` (0-dim for a
        scalar input). The result is always a tensor, including in the
        saturated case.
    """
    counts = torch.as_tensor(x_ij)
    if not counts.is_floating_point():
        # Integer counts must be promoted before the fractional power.
        counts = counts.to(torch.get_default_dtype())
    # Clamping the ratio at 1 reproduces the "else 1" branch element-wise,
    # because 1 ** alpha == 1.
    return torch.clamp(counts / x_max, max=1.0).pow(alpha)

In [4]:
# Small constant referenced by the commented-out normalisation step.
eps = 1e-5
# Relation co-occurrence counts gathered from the training triples.
co = co_occurance(train_path=train_path, rel_dict_path=rel_dict_path)

In [5]:
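# Optional row normalisation of the co-occurrence counts (currently disabled).
# NOTE: as written, `co.sum(dim=1)` has shape (n,) and would broadcast along
# the columns; use `co.sum(dim=1, keepdim=True)` if each row should be divided
# by its own sum.
# co += eps
# co = co/co.sum(dim=1)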

In [6]:
class RCO(nn.Module):
    """Relation co-occurrence model scoring pairs of relation ids.

    The score of a pair ``(i, j)`` is ``dot(em[i], em[j]) + bias[i] + bias[j]``.

    Args:
        num_rels: Number of distinct relations (rows of both tables).
        em_dim: Dimensionality of each relation embedding.
    """

    def __init__(self, num_rels, em_dim):
        super().__init__()
        self.em = nn.Embedding(num_rels, em_dim)
        self.bias = nn.Embedding(num_rels, 1)

    def forward(self, id_pair):
        """Score a batch of relation-id pairs.

        Args:
            id_pair: Integer tensor of shape ``(batch, 2)``; column 0 holds
                ``i`` and column 1 holds ``j``.

        Returns:
            Tensor of shape ``(batch,)`` with one score per pair.
        """
        i, j = id_pair[:, 0], id_pair[:, 1]
        wi = self.em(i)
        wj = self.em(j)
        # The bias lookups have shape (batch, 1); drop the trailing axis so
        # they add element-wise to the (batch,) dot products instead of
        # broadcasting the sum to (batch, batch).
        bi = self.bias(i).squeeze(-1)
        bj = self.bias(j).squeeze(-1)
        # Row-wise dot product, without building the full batch x batch matrix
        # only to read its diagonal.
        return (wi * wj).sum(dim=-1) + bi + bj
        

In [7]:
# Relation co-occurrence model with one 768-dimensional embedding per relation.
model = RCO(num_rels=num_rels, em_dim=768)

# Adam at a 1e-3 learning rate, trained against a mean-squared-error objective.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

In [8]:
# Build the training set: one sample per unordered relation pair (i, j) with
# j < i, whose target is the co-occurrence count co[i][j].
# tril_indices with offset=-1 lists the strictly-lower-triangular positions
# row by row, i.e. in the same order as the nested `for i / for j in range(i)`
# loops, but without per-element Python indexing into the tensor.
n_repeats = 10  # every pair is presented this many times per pass over the data
pair_rows, pair_cols = torch.tril_indices(num_rels, num_rels, offset=-1)

# x: (n_pairs * n_repeats, 2) integer id pairs; y: matching float targets.
x = torch.stack([pair_rows, pair_cols], dim=1).repeat(n_repeats, 1)
y = co[pair_rows, pair_cols].repeat(n_repeats)

from torch.utils.data import Dataset, DataLoader
class Mydata(Dataset):
    """Map-style dataset pairing each input in ``x`` with its target in ``y``.

    Args:
        x: Iterable of inputs.
        y: Iterable of targets, aligned with ``x``.
    """

    def __init__(self, x, y):
        # Materialise the pairs once so lookups are plain list indexing.
        self.data = list(zip(x, y))

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at position ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # No `assert` bounds check here: list indexing already raises
        # IndexError, which is what the sequence-iteration protocol needs to
        # stop cleanly. An AssertionError would escape from
        # `for item in dataset`, and asserts are stripped under `python -O`.
        return self.data[idx]

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.data)


In [9]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# One (pair, target) sample per optimisation step, drawn in shuffled order.
dataset = Mydata(x, y)
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

In [10]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Relies on the notebook-level names ``device``, ``co`` and ``weight``.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches; each row of ``X`` is
            an ``(i, j)`` relation-id pair and ``y`` holds the matching targets.
        model: Module called as ``model(X)`` to produce predictions.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar loss.
        optimizer: Optimizer that updates the parameters of ``model``.

    Returns:
        A 0-dim tensor holding the mean of the per-batch weighted losses.
    """
    model = model.to(device)
    batch_losses = []
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)

        # Weight every sample by its OWN co-occurrence count. Taking the
        # weight from X[0] only is correct solely for batch_size == 1.
        # Converting the ids to Python ints also avoids indexing `co`
        # (created on the default device) with tensors that live on `device`.
        weighted = [
            weight(co[i][j]) * loss_fn(pred[k:k + 1], y[k:k + 1])
            for k, (i, j) in enumerate(X.tolist())
        ]
        # For a single-sample batch this equals weight * loss_fn(pred, y).
        loss = torch.stack(weighted).mean()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    return torch.mean(torch.tensor(batch_losses))

In [14]:
# Train for ten passes over the data, reporting the mean weighted loss of each.
for epoch in range(10):
    epoch_loss_tensor = train(dataloader, model, loss_fn, optimizer)
    loss = epoch_loss_tensor.item()
    print(f"loss: {loss:>7f}")

loss: 0.187095
loss: 0.180174
loss: 0.186903
loss: 0.177222
loss: 0.161811
loss: 0.151648
loss: 0.152723
loss: 0.149716
loss: 0.143706
loss: 0.150055


In [15]:
# Persist only the relation embedding table. The file name follows
# `dataset_name`, so switching datasets does not write under the WebNLG name.
torch.save(model.em.state_dict(), '{}.em'.format(dataset_name))