added MultiNLI (#326)
kahne authored and mttk committed Sep 11, 2018
1 parent 1008644 commit 300a378
Showing 4 changed files with 68 additions and 12 deletions.
README.rst: 2 changes (1 addition, 1 deletion)
@@ -112,7 +112,7 @@ The datasets module currently contains:

* Sentiment analysis: SST and IMDb
* Question classification: TREC
-* Entailment: SNLI
+* Entailment: SNLI, MultiNLI
* Language modeling: abstract class + WikiText-2
* Machine translation: abstract class + Multi30k, IWSLT, WMT14
* Sequence tagging (e.g. POS/NER): abstract class + UDPOS
test/snli.py → test/nli.py: 30 changes (30 additions, 0 deletions)
@@ -1,6 +1,7 @@
from torchtext import data
from torchtext import datasets

+# Testing SNLI
TEXT = data.Field()
LABEL = data.Field(sequential=False)

@@ -27,3 +28,32 @@
print(batch.premise)
print(batch.hypothesis)
print(batch.label)


+# Testing MultiNLI
+TEXT = data.Field()
+LABEL = data.Field(sequential=False)
+
+train, val, test = datasets.MultiNLI.splits(TEXT, LABEL)
+
+print(train.fields)
+print(len(train))
+print(vars(train[0]))
+
+TEXT.build_vocab(train)
+LABEL.build_vocab(train)
+
+train_iter, val_iter, test_iter = data.BucketIterator.splits(
+    (train, val, test), batch_size=3)
+
+batch = next(iter(train_iter))
+print(batch.premise)
+print(batch.hypothesis)
+print(batch.label)
+
+train_iter, val_iter, test_iter = datasets.MultiNLI.iters(batch_size=4)
+
+batch = next(iter(train_iter))
+print(batch.premise)
+print(batch.hypothesis)
+print(batch.label)
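The test above exercises only the plain-text path. splits() also accepts a parse_field for shift-reduce parser transitions, read from the sentence1_binary_parse / sentence2_binary_parse keys of the JSONL files (see nli.py below). A minimal sketch of that path, assuming ParsedTextField and its companion ShiftReduceField keep their existing definitions in torchtext/datasets/nli.py (they are not re-exported from the package root):

from torchtext import data
from torchtext import datasets
from torchtext.datasets.nli import ParsedTextField, ShiftReduceField

# Parsed premise/hypothesis text plus shift-reduce transition sequences
TEXT = ParsedTextField(lower=True)
TRANSITIONS = ShiftReduceField()
LABEL = data.Field(sequential=False)

train, val, test = datasets.MultiNLI.splits(TEXT, LABEL, TRANSITIONS)

# Each example now also carries premise_transitions / hypothesis_transitions.
print(vars(train[0]).keys())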
torchtext/datasets/__init__.py: 3 changes (2 additions, 1 deletion)
@@ -1,5 +1,5 @@
from .language_modeling import LanguageModelingDataset, WikiText2, PennTreebank # NOQA
-from .snli import SNLI
+from .nli import SNLI, MultiNLI
from .sst import SST
from .translation import TranslationDataset, Multi30k, IWSLT, WMT14 # NOQA
from .sequence_tagging import SequenceTaggingDataset, UDPOS, CoNLL2000Chunking # NOQA
@@ -10,6 +10,7 @@

__all__ = ['LanguageModelingDataset',
           'SNLI',
+           'MultiNLI',
           'SST',
           'TranslationDataset',
           'Multi30k',
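With this change both corpora are importable from the package root; anything that previously imported from torchtext.datasets.snli must switch to torchtext.datasets.nli, since the module was renamed. A quick sanity check of the new surface (NLIDataset itself is only exposed via the nli module):

from torchtext.datasets import SNLI, MultiNLI
from torchtext.datasets.nli import NLIDataset

# Both concrete classes share the NLIDataset base introduced in nli.py below.
assert issubclass(SNLI, NLIDataset) and issubclass(MultiNLI, NLIDataset)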
torchtext/datasets/snli.py → torchtext/datasets/nli.py: 45 changes (35 additions, 10 deletions)
@@ -22,11 +22,11 @@ def __init__(self, eos_token='<pad>', lower=False):
                list(reversed(p)) for p in parse])


-class SNLI(data.TabularDataset):
+class NLIDataset(data.TabularDataset):

-    urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip']
-    dirname = 'snli_1.0'
-    name = 'snli'
+    urls = []
+    dirname = ''
+    name = 'nli'

    @staticmethod
    def sort_key(ex):
@@ -35,8 +35,7 @@ def sort_key(ex):

    @classmethod
    def splits(cls, text_field, label_field, parse_field=None, root='.data',
-               train='snli_1.0_train.jsonl', validation='snli_1.0_dev.jsonl',
-               test='snli_1.0_test.jsonl'):
+               train='train.jsonl', validation='val.jsonl', test='test.jsonl'):
        """Create dataset objects for splits of the SNLI dataset.
        This is the most flexible way to use the dataset.
@@ -48,8 +47,7 @@ def splits(cls, text_field, label_field, parse_field=None, root='.data',
            parse_field: The field that will be used for shift-reduce parser
                transitions, or None to not include them.
            root: The root directory that the dataset's zip archive will be
-                expanded into; therefore the directory in whose snli_1.0
-                subdirectory the data files will be stored.
+                expanded into.
            train: The filename of the train data. Default: 'train.jsonl'.
            validation: The filename of the validation data, or None to not
                load the validation set. Default: 'dev.jsonl'.
@@ -59,13 +57,13 @@ def splits(cls, text_field, label_field, parse_field=None, root='.data',
        path = cls.download(root)

        if parse_field is None:
-            return super(SNLI, cls).splits(
+            return super(NLIDataset, cls).splits(
                path, root, train, validation, test,
                format='json', fields={'sentence1': ('premise', text_field),
                                       'sentence2': ('hypothesis', text_field),
                                       'gold_label': ('label', label_field)},
                filter_pred=lambda ex: ex.label != '-')
-        return super(SNLI, cls).splits(
+        return super(NLIDataset, cls).splits(
            path, root, train, validation, test,
            format='json', fields={'sentence1_binary_parse':
                                   [('premise', text_field),
@@ -113,3 +111,30 @@ def iters(cls, batch_size=32, device=0, root='.data',

        return data.BucketIterator.splits(
            (train, val, test), batch_size=batch_size, device=device)


+class SNLI(NLIDataset):
+    urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip']
+    dirname = 'snli_1.0'
+    name = 'snli'
+
+    @classmethod
+    def splits(cls, text_field, label_field, parse_field=None, root='.data',
+               train='snli_1.0_train.jsonl', validation='snli_1.0_dev.jsonl',
+               test='snli_1.0_test.jsonl'):
+        return super(SNLI, cls).splits(text_field, label_field, parse_field,
+                                       root, train, validation, test)
+
+
+class MultiNLI(NLIDataset):
+    urls = ['http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip']
+    dirname = 'multinli_1.0'
+    name = 'multinli'
+
+    @classmethod
+    def splits(cls, text_field, label_field, parse_field=None, root='.data',
+               train='multinli_1.0_train.jsonl',
+               validation='multinli_1.0_dev_matched.jsonl',
+               test='multinli_1.0_dev_mismatched.jsonl'):
+        return super(MultiNLI, cls).splits(text_field, label_field, parse_field,
+                                           root, train, validation, test)
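With the shared NLIDataset base, supporting another NLI corpus reduces to overriding the download attributes and the split filenames. A hypothetical sketch; the class name, URL, and filenames are placeholders, not a real dataset:

class MyNLI(NLIDataset):
    # Placeholder values for illustration only, not a real release.
    urls = ['http://example.com/mynli_1.0.zip']
    dirname = 'mynli_1.0'
    name = 'mynli'

    @classmethod
    def splits(cls, text_field, label_field, parse_field=None, root='.data',
               train='mynli_1.0_train.jsonl',
               validation='mynli_1.0_dev.jsonl',
               test='mynli_1.0_test.jsonl'):
        return super(MyNLI, cls).splits(text_field, label_field, parse_field,
                                        root, train, validation, test)

This works as long as the archive unpacks into dirname and the JSONL files use the sentence1 / sentence2 / gold_label keys that NLIDataset.splits expects.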
