forked from castorini/castor
/
trecqa.py
76 lines (59 loc) · 3.29 KB
/
trecqa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import torch
from torchtext.data.dataset import Dataset
from torchtext.data.example import Example
from torchtext.data.field import Field
from torchtext.data.iterator import BucketIterator
from torchtext.vocab import Vectors
from datasets.idf_utils import get_pairwise_word_to_doc_freq, get_pairwise_overlap_features
class TRECQA(Dataset):
    """TREC Question Answering dataset: pairs of (question, candidate answer)
    sentences with binary relevance labels plus precomputed pairwise overlap
    features used as external features by the model."""
    NAME = 'trecqa'
    NUM_CLASSES = 2
    ID_FIELD = Field(sequential=False, tensor_type=torch.FloatTensor, use_vocab=False, batch_first=True)
    TEXT_FIELD = Field(batch_first=True, tokenize=lambda x: x)  # tokenizer is identity since we already tokenized it to compute external features
    EXT_FEATS_FIELD = Field(tensor_type=torch.FloatTensor, use_vocab=False, batch_first=True, tokenize=lambda x: x)
    LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True)

    @staticmethod
    def sort_key(ex):
        # Bucket examples by question length so padded batches are tight.
        return len(ex.a)

    def __init__(self, path):
        """
        Create a TRECQA dataset instance.

        :param path: directory containing the parallel files a.toks, b.toks,
            id.txt and sim.txt (one example per line across all four files)
        """
        fields = [('id', self.ID_FIELD), ('a', self.TEXT_FIELD), ('b', self.TEXT_FIELD), ('ext_feats', self.EXT_FEATS_FIELD), ('label', self.LABEL_FIELD)]
        examples = []
        # BUGFIX: the original closed files via map(lambda f: f.close(), ...),
        # which is lazy in Python 3 and never executed (and omitted id_file);
        # context managers guarantee all four files are closed.
        with open(os.path.join(path, 'a.toks'), 'r') as f1, \
                open(os.path.join(path, 'b.toks'), 'r') as f2, \
                open(os.path.join(path, 'id.txt'), 'r') as id_file, \
                open(os.path.join(path, 'sim.txt'), 'r') as label_file:
            # NOTE: rstrip('.\n') strips trailing periods as well as newlines —
            # preserved from the original preprocessing convention.
            sent_list_1 = [l.rstrip('.\n').split(' ') for l in f1]
            sent_list_2 = [l.rstrip('.\n').split(' ') for l in f2]

            word_to_doc_cnt = get_pairwise_word_to_doc_freq(sent_list_1, sent_list_2)
            overlap_feats = get_pairwise_overlap_features(sent_list_1, sent_list_2, word_to_doc_cnt)

            for pair_id, l1, l2, ext_feats, label in zip(id_file, sent_list_1, sent_list_2, overlap_feats, label_file):
                pair_id = pair_id.rstrip('.\n')
                label = label.rstrip('.\n')
                example = Example.fromlist([pair_id, l1, l2, ext_feats, label], fields)
                examples.append(example)

        super(TRECQA, self).__init__(examples, fields)

    @classmethod
    def splits(cls, path, train='train-all', validation='raw-dev', test='raw-test', **kwargs):
        # Default split directory names follow the standard TrecQA layout.
        return super(TRECQA, cls).splits(path, train=train, validation=validation, test=test, **kwargs)

    @classmethod
    def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None, unk_init=torch.Tensor.zero_):
        """
        Build BucketIterators over the train/validation/test splits.

        :param path: directory containing train, test, dev files
        :param vectors_name: name of word vectors file
        :param vectors_cache: path to word vectors file
        :param batch_size: batch size
        :param shuffle: whether to shuffle training batches each epoch
        :param device: GPU device
        :param vectors: custom vectors - either predefined torchtext vectors or your own custom Vector classes
        :param unk_init: function used to generate vector for OOV words
        :return: (train, validation, test) BucketIterators
        """
        if vectors is None:
            vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init)

        train, validation, test = cls.splits(path)
        # Vocabulary is built over all splits so test-time tokens get the
        # pretrained vector rather than falling back to unk_init.
        cls.TEXT_FIELD.build_vocab(train, validation, test, vectors=vectors)
        return BucketIterator.splits((train, validation, test), batch_size=batch_size, repeat=False, shuffle=shuffle, device=device)