In [15]:
# Load the dataset: read the triples spreadsheet and report how many rows it has.
import pandas as pd

df = pd.read_excel('/home/cjw/KGtest/triples.xls')
row_num = df.shape[0]  # one row per triple
print("三元组数：", row_num)

三元组数： 181


In [23]:
import numpy as np
# Shuffle the triples and split them 6:2:2 into train / validation / test sets.
# A fixed random_state keeps the split reproducible across kernel restarts.
shuffled = df.sample(frac=1, random_state=42)
train_end, valid_end = int(.6 * len(df)), int(.8 * len(df))
# Positional slicing replaces np.split(DataFrame, ...), which relies on the
# deprecated DataFrame.swapaxes in recent pandas releases.
train = shuffled.iloc[:train_end]
valid = shuffled.iloc[train_end:valid_end]
test = shuffled.iloc[valid_end:]

# ExcelWriter.save() no longer exists in pandas >= 2.0; the context manager
# writes and closes the workbook on exit. sheet_name is passed by keyword
# because positional use is deprecated.
for split_name, split_df in (('train', train), ('valid', valid), ('test', test)):
    with pd.ExcelWriter(f'/home/cjw/KGtest/{split_name}.xlsx') as writer:
        split_df.to_excel(writer, sheet_name='Sheet', index=False)


In [None]:
from collections import defaultdict as ddict, Counter
from types import SimpleNamespace

import torch
from ordered_set import OrderedSet
from torch.utils.data import DataLoader, Dataset
# 数据预处理
def load_data(data_dir='/home/cjw/KGtest', batch_size=128, num_workers=0, lbl_smooth=0.0):
    """Read the train/valid/test Excel splits and build PyTorch data loaders.

    Each spreadsheet row is turned into a (subject, relation, object) triple
    taken from columns 0, 4 and 2 respectively.
    NOTE(review): that column layout is carried over unchanged from the
    original code — confirm it matches the spreadsheet schema.

    Args:
        data_dir: Directory containing ``train.xlsx``, ``valid.xlsx`` and
            ``test.xlsx``.
        batch_size: Batch size used for every loader.
        num_workers: Worker processes per loader (negative values become 0).
        lbl_smooth: Value exposed to the dataset classes as ``params.lbl_smooth``.

    Returns:
        A ``SimpleNamespace`` with the attributes ``ent2id``, ``rel2id``,
        ``id2ent``, ``id2rel``, ``num_ent``, ``num_rel``, ``data``, ``sr2o``,
        ``sr2o_all``, ``triples``, ``p`` (the params handed to the datasets)
        and ``data_iter`` (loaders keyed by ``train``, ``valid_head``,
        ``valid_tail``, ``test_head``, ``test_tail``).

    Raises:
        KeyError: If the valid or test split is empty, because its
            head/tail lists are then never created.
    """
    split_names = ['train', 'valid', 'test']

    # Read every spreadsheet exactly once and keep only the three columns used.
    raw_triples = {}
    for split in split_names:
        rows = pd.read_excel(f'{data_dir}/{split}.xlsx').values.tolist()
        raw_triples[split] = [(row[0], row[4], row[2]) for row in rows]

    # Collect entities and relations in first-seen order so ids are stable.
    ent_set, rel_set = OrderedSet(), OrderedSet()
    for split in split_names:
        for sub, rel, obj in raw_triples[split]:
            ent_set.add(sub)
            rel_set.add(rel)
            ent_set.add(obj)

    ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
    rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
    num_ent = len(ent2id)
    num_rel = len(rel_set)
    # Every relation also gets a "<name>_reverse" twin whose id is offset by num_rel.
    rel2id.update({rel + '_reverse': idx + num_rel for idx, rel in enumerate(rel_set)})

    id2ent = {idx: ent for ent, idx in ent2id.items()}
    id2rel = {idx: rel for rel, idx in rel2id.items()}

    data = ddict(list)
    sr2o_sets = ddict(set)
    for split in split_names:
        for sub, rel, obj in raw_triples[split]:
            sub, rel, obj = ent2id[sub], rel2id[rel], ent2id[obj]
            data[split].append((sub, rel, obj))
            if split == 'train':
                sr2o_sets[(sub, rel)].add(obj)
                sr2o_sets[(obj, rel + num_rel)].add(sub)
    data = dict(data)

    # Snapshot of the training-only answers, taken before valid/test are merged in.
    sr2o = {k: list(v) for k, v in sr2o_sets.items()}
    for split in ['test', 'valid']:
        for sub, rel, obj in data[split]:
            sr2o_sets[(sub, rel)].add(obj)
            sr2o_sets[(obj, rel + num_rel)].add(sub)
    sr2o_all = {k: list(v) for k, v in sr2o_sets.items()}

    triples = ddict(list)
    for (sub, rel), objs in sr2o.items():
        triples['train'].append({'triple': (sub, rel, -1), 'label': objs, 'sub_samp': 1})

    # Evaluation queries in both directions; these are the splits the loaders
    # below ask for.
    for split in ['test', 'valid']:
        for sub, rel, obj in data[split]:
            rel_inv = rel + num_rel
            triples['{}_{}'.format(split, 'tail')].append(
                {'triple': (sub, rel, obj), 'label': sr2o_all[(sub, rel)]})
            triples['{}_{}'.format(split, 'head')].append(
                {'triple': (obj, rel_inv, sub), 'label': sr2o_all[(obj, rel_inv)]})
    triples = dict(triples)
    print(triples)

    # The dataset classes read num_ent (and lbl_smooth) from a params object.
    params = SimpleNamespace(
        num_ent=num_ent,
        num_rel=num_rel,
        lbl_smooth=lbl_smooth,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    def get_data_loader(dataset_class, split, batch_size, shuffle=True):
        return DataLoader(
            dataset_class(triples[split], params),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=max(0, num_workers),
            collate_fn=dataset_class.collate_fn
        )

    data_iter = {
        'train': get_data_loader(TrainDataset, 'train', batch_size),
        'valid_head': get_data_loader(TestDataset, 'valid_head', batch_size),
        'valid_tail': get_data_loader(TestDataset, 'valid_tail', batch_size),
        'test_head': get_data_loader(TestDataset, 'test_head', batch_size),
        'test_tail': get_data_loader(TestDataset, 'test_tail', batch_size),
    }

    return SimpleNamespace(
        ent2id=ent2id,
        rel2id=rel2id,
        id2ent=id2ent,
        id2rel=id2rel,
        num_ent=num_ent,
        num_rel=num_rel,
        data=data,
        sr2o=sr2o,
        sr2o_all=sr2o_all,
        triples=triples,
        p=params,
        data_iter=data_iter,
    )


        
# Run the preprocessing defined above.
# NOTE(review): load_data() refers to TrainDataset / TestDataset, which this
# notebook defines in a later cell — on a fresh kernel that cell must run first.
load_data()

In [None]:
class TrainDataset(Dataset):
    """Dataset over training records.

    Each record is a dict with the keys ``'triple'``, ``'label'`` and
    ``'sub_samp'``. ``params`` must expose ``num_ent`` and ``lbl_smooth``.
    """

    def __init__(self, triples, params):
        self.triples = triples
        self.p = params
        # Integer ids 0 .. num_ent-1.
        self.entities = np.arange(self.p.num_ent, dtype=np.int32)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(triple, target, None, None)`` for the record at ``idx``."""
        record = self.triples[idx]
        triple = torch.LongTensor(record['triple'])
        label_ids = np.int32(record['label'])
        _sub_samp = np.float32(record['sub_samp'])  # read from the record but not used further
        target = self.get_label(label_ids)

        smoothing = self.p.lbl_smooth
        if smoothing != 0.0:
            # Scale the multi-hot target and add a uniform 1/num_ent term.
            target = (1.0 - smoothing) * target + (1.0 / self.p.num_ent)

        return triple, target, None, None

    @staticmethod
    def collate_fn(data):
        """Stack the first two fields of every sample along a new batch axis."""
        triples = [sample[0] for sample in data]
        targets = [sample[1] for sample in data]
        return torch.stack(triples, dim=0), torch.stack(targets, dim=0)

    def get_label(self, label):
        """Build a float vector of length ``num_ent`` with 1.0 at every id in ``label``."""
        multi_hot = np.zeros([self.p.num_ent], dtype=np.float32)
        for ent_id in label:
            multi_hot[ent_id] = 1.0
        return torch.FloatTensor(multi_hot)


class TestDataset(Dataset):
    """Dataset over evaluation records.

    Each record is a dict with the keys ``'triple'`` and ``'label'``.
    ``params`` must expose ``num_ent``.
    """

    def __init__(self, triples, params):
        self.triples = triples
        self.p = params

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(triple, target)`` for the record at ``idx``."""
        record = self.triples[idx]
        triple = torch.LongTensor(record['triple'])
        label_ids = np.int32(record['label'])
        return triple, self.get_label(label_ids)

    @staticmethod
    def collate_fn(data):
        """Stack the triples and the targets of every sample along a new batch axis."""
        triples = [sample[0] for sample in data]
        targets = [sample[1] for sample in data]
        return torch.stack(triples, dim=0), torch.stack(targets, dim=0)

    def get_label(self, label):
        """Build a float vector of length ``num_ent`` with 1.0 at every id in ``label``."""
        multi_hot = np.zeros([self.p.num_ent], dtype=np.float32)
        for ent_id in label:
            multi_hot[ent_id] = 1.0
        return torch.FloatTensor(multi_hot)

In [65]:
def load_data(self):
    """Read the tab-separated train/test/valid splits and build data loaders.

    Reads ``../data/<self.p.dataset>/<split>.txt``, where every line holds
    three tab-separated fields (subject, relation, object) that are
    lower-cased before use.

    Reads from ``self.p``: ``dataset``, ``embed_dim``, ``k_w``, ``k_h``,
    ``num_workers``, ``batch_size``.

    Sets on ``self``: ``ent2id``, ``rel2id``, ``id2ent``, ``id2rel``,
    ``data``, ``sr2o``, ``sr2o_all``, ``triples``, ``data_iter``; and on
    ``self.p``: ``num_ent``, ``num_rel``, ``embed_dim``.

    Raises:
        ValueError: If a line does not split into exactly three fields.
    """
    split_names = ['train', 'test', 'valid']

    # Parse each file once. The `with` block closes the handle, which the
    # previous bare `open(...)` inside the for-loop never did.
    raw_triples = {}
    for split in split_names:
        rows = []
        with open('../data/{}/{}.txt'.format(self.p.dataset, split)) as handle:
            for line in handle:
                sub, rel, obj = map(str.lower, line.strip().split('\t'))
                rows.append((sub, rel, obj))
        raw_triples[split] = rows

    # Collect entities and relations in first-seen order so ids are stable.
    ent_set, rel_set = OrderedSet(), OrderedSet()
    for split in split_names:
        for sub, rel, obj in raw_triples[split]:
            ent_set.add(sub)
            rel_set.add(rel)
            ent_set.add(obj)

    self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
    self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
    # Every relation also gets a "<name>_reverse" twin with an offset id.
    self.rel2id.update({rel + '_reverse': idx + len(self.rel2id) for idx, rel in enumerate(rel_set)})

    self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
    self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

    self.p.num_ent = len(self.ent2id)
    self.p.num_rel = len(self.rel2id) // 2
    self.p.embed_dim = self.p.k_w * self.p.k_h if self.p.embed_dim is None else self.p.embed_dim

    self.data = ddict(list)
    sr2o = ddict(set)

    for split in split_names:
        for sub, rel, obj in raw_triples[split]:
            sub, rel, obj = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj]

            self.data[split].append((sub, rel, obj))

            if split == 'train':
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.p.num_rel)].add(sub)
    self.data = dict(self.data)

    # Snapshot of the training-only answers, taken before test/valid are merged in.
    self.sr2o = {k: list(v) for k, v in sr2o.items()}
    for split in ['test', 'valid']:
        for sub, rel, obj in self.data[split]:
            sr2o[(sub, rel)].add(obj)
            sr2o[(obj, rel + self.p.num_rel)].add(sub)

    self.sr2o_all = {k: list(v) for k, v in sr2o.items()}
    self.triples = ddict(list)

    for (sub, rel), obj in self.sr2o.items():
        self.triples['train'].append({'triple': (sub, rel, -1), 'label': self.sr2o[(sub, rel)], 'sub_samp': 1})

    # Evaluation queries in both directions (tail and head).
    for split in ['test', 'valid']:
        for sub, rel, obj in self.data[split]:
            rel_inv = rel + self.p.num_rel
            self.triples['{}_{}'.format(split, 'tail')].append(
                {'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]})
            self.triples['{}_{}'.format(split, 'head')].append(
                {'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]})
    print(self.triples)

    self.triples = dict(self.triples)

    def get_data_loader(dataset_class, split, batch_size, shuffle=True):
        return DataLoader(
            dataset_class(self.triples[split], self.p),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=max(0, self.p.num_workers),
            collate_fn=dataset_class.collate_fn
        )

    self.data_iter = {
        'train': get_data_loader(TrainDataset, 'train', self.p.batch_size),
        'valid_head': get_data_loader(TestDataset, 'valid_head', self.p.batch_size),
        'valid_tail': get_data_loader(TestDataset, 'valid_tail', self.p.batch_size),
        'test_head': get_data_loader(TestDataset, 'test_head', self.p.batch_size),
        'test_tail': get_data_loader(TestDataset, 'test_tail', self.p.batch_size),
    }
