Skip to content

Commit

Permalink
[WIP] Fixing data cache (#1314)
Browse files Browse the repository at this point in the history
  • Loading branch information
parmeet committed May 20, 2021
1 parent 36e33e2 commit 99557ef
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
21 changes: 21 additions & 0 deletions .circleci/cached_datasets_list.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
IMDB
AG_NEWS
SogouNews
DBpedia
YelpReviewPolarity
YelpReviewFull
YahooAnswers
AmazonReviewPolarity
AmazonReviewFull
UDPOS
CoNLL2000Chunking
Multi30k
IWSLT2016
IWSLT2017
WMT14
WikiText2
WikiText103
PennTreebank
SQuAD1
SQuAD2
EnWik9
4 changes: 3 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ commands:
steps:
- run:
name: Generate CCI cache key
command: echo "$(date "+%D")" > .cachekey
command:
echo "$(date "+%D")" > .cachekey
cat cached_datasets_list.txt >> .cachekey
- persist_to_workspace:
root: .
paths:
Expand Down
4 changes: 3 additions & 1 deletion .circleci/config.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ commands:
steps:
- run:
name: Generate CCI cache key
command: echo "$(date "+%D")" > .cachekey
command:
echo "$(date "+%D")" > .cachekey
cat cached_datasets_list.txt >> .cachekey
- persist_to_workspace:
root: .
paths:
Expand Down
7 changes: 5 additions & 2 deletions test/common/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
def check_cache_status():
    """Validate the raw-dataset cache and fail loudly if any split is missing.

    Reads CACHE_STATUS_FILE (a JSON mapping of dataset name -> split ->
    status dict) and collects every dataset/split whose cached download is
    marked "fail". Raising once at the end — instead of on the first
    failure — lets the error message report *all* missing datasets, which
    is the point of this check.

    Raises:
        AssertionError: if the cache status file does not exist.
        FileNotFoundError: if at least one dataset split failed to cache;
            the message lists every missing ``<dataset>_<split>`` pair.
    """
    assert os.path.exists(CACHE_STATUS_FILE), "Cache status file does not exists"
    with open(CACHE_STATUS_FILE, 'r') as f:
        cache_status = json.load(f)
    # Accumulate every failing split so the exception names them all at once.
    missing_datasets = []
    for dataset_name in cache_status:
        for split in cache_status[dataset_name]:
            if cache_status[dataset_name][split]['status'] == "fail":
                missing_datasets.append(dataset_name + '_' + split)
    if missing_datasets:
        raise FileNotFoundError("Failing all raw dataset unit tests as cache is missing {} datasets".format(missing_datasets))


def generate_data_cache():
Expand All @@ -30,7 +33,7 @@ def generate_data_cache():
if dataset_name not in cache_status:
cache_status[dataset_name] = {}
try:
if dataset_name == "Multi30k" or dataset_name == 'WMT14':
if dataset_name == 'WMT14':
_ = torchtext.experimental.datasets.raw.DATASETS[dataset_name](split=split)
else:
_ = torchtext.datasets.DATASETS[dataset_name](split=split)
Expand Down

0 comments on commit 99557ef

Please sign in to comment.