Skip to content
Branch: master
Find file Copy path
Find file Copy path
1 contributor

Users who have contributed to this file

66 lines (52 sloc) 2.37 KB
from typing import Dict
import logging
import os.path as osp
from pathlib import Path
import tarfile
from itertools import chain

from overrides import overrides
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField, Field
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Tokenizer, WordTokenizer

logger = logging.getLogger(__name__)


logger = logging.getLogger(__name__)
class ImdbDatasetReader(DatasetReader):
    """Read the ``aclImdb`` review archive and yield one ``Instance`` per review.

    The archive named by ``TAR_URL`` is fetched through ``cached_path`` and
    unpacked next to the cached tarball.  Each ``*.txt`` file under the
    ``pos`` / ``neg`` sub-directory of the requested split becomes an
    instance with a ``tokens`` text field and an integer ``label`` field
    (``0`` for ``pos``, ``1`` for ``neg``).
    """

    # NOTE(review): the URL is empty in this copy of the file, so
    # ``cached_path`` has nothing to fetch.  It presumably pointed at the
    # aclImdb tarball -- restore it before using the reader.
    TAR_URL = ''
    # Split directories, relative to the directory the archive is unpacked in.
    TRAIN_DIR = 'aclImdb/train'
    TEST_DIR = 'aclImdb/test'

    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 tokenizer: Tokenizer = None,
                 lazy: bool = False) -> None:
        """Configure tokenisation and indexing.

        Parameters
        ----------
        token_indexers:
            Indexers attached to the ``tokens`` field; defaults to a single
            ``SingleIdTokenIndexer`` under the key ``'tokens'``.
        tokenizer:
            Tokenizer applied to each review; defaults to ``WordTokenizer()``.
        lazy:
            Forwarded to ``DatasetReader`` so the base class can honour it.
        """
        # Without this call the ``lazy`` flag is silently dropped and the
        # base class is left uninitialised.
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}

    def _read(self, file_path):
        """Yield an ``Instance`` for every review in the requested split.

        Parameters
        ----------
        file_path:
            Not a filesystem path: the split name, either ``'train'`` or
            ``'test'``.

        Raises
        ------
        ValueError
            If ``file_path`` is neither ``'train'`` nor ``'test'``.
        """
        # Validate first so a bad split name never triggers a download.
        if file_path == 'train':
            split_dir = self.TRAIN_DIR
        elif file_path == 'test':
            split_dir = self.TEST_DIR
        else:
            raise ValueError(f"only 'train' and 'test' are valid for 'file_path', but '{file_path}' is given.")

        tar_path = cached_path(self.TAR_URL)
        cache_dir = Path(tar_path).parent

        # Unpack when either split is missing, so a half-extracted cache is
        # repaired instead of silently yielding an empty dataset.
        train_root = cache_dir / self.TRAIN_DIR
        test_root = cache_dir / self.TEST_DIR
        if not (train_root.exists() and test_root.exists()):
            # The context manager closes the archive even if extraction fails.
            with tarfile.open(tar_path, 'r') as archive:
                archive.extractall(cache_dir)

        # The label comes from the directory being walked rather than from a
        # substring test on the full path, which would mislabel every file
        # whenever the cache location itself contains "pos".
        for label, sentiment in enumerate(('pos', 'neg')):
            for review_path in (cache_dir / split_dir / sentiment).glob('*.txt'):
                # Explicit encoding: do not depend on the platform default.
                yield self.text_to_instance(review_path.read_text(encoding='utf-8'), label)

    def text_to_instance(self, string: str, label: int) -> Instance:
        """Build an ``Instance`` from raw review text and its integer label.

        Parameters
        ----------
        string:
            Raw review text; tokenised with the configured tokenizer.
        label:
            Integer class id, stored as-is (``skip_indexing=True``).
        """
        fields: Dict[str, Field] = {}
        tokens = self._tokenizer.tokenize(string)
        fields['tokens'] = TextField(tokens, self._token_indexers)
        fields['label'] = LabelField(label, skip_indexing=True)
        return Instance(fields)
You can’t perform that action at this time.