In [1]:
import json
from functools import partial
from pathlib import Path
from typing import Dict, List, Tuple

from sklearn.model_selection import train_test_split

### helpers

In [2]:
Data = List[Dict]

def _write_json_file(data: Data, split: str, path_base: str) -> None:
    path = Path(path_base) / Path(f"{split}.json")
    path.parent.mkdir(exist_ok=True)
    json.dump(data, open(path, 'w'))

def write_splits(
    data: Data,
    split_ratios: Dict[str, float],
    write_dir: str,
) -> Tuple[Data, Data, Data]:
    """Split `data` into train/dev/test and write each split to `write_dir`.

    Args:
        data: Records to split.
        split_ratios: Fractions of `data` per split; must sum to 1. Only the
            "train" and "dev" entries are read directly, the test split
            receives whatever remains after those two.
        write_dir: Directory that receives `train.json`, `dev.json` and
            `test.json`.

    Returns:
        The `(train, dev, test)` splits, in that order.

    Raises:
        AssertionError: If the ratios do not sum to 1, or if the split sizes
            do not add up to `len(data)`.

    Note:
        The seed is read from the module-level `random_state`, which must be
        defined before this function is called.
    """
    # Compare with a tolerance: ratios such as .7 + .2 + .1 do not sum to
    # exactly 1.0 in floating point.
    ratio_sum = sum(split_ratios.values())
    assert abs(ratio_sum - 1) < 1e-9, f'split ratios must sum to 1, got {ratio_sum}'
    split_fn = partial(train_test_split, random_state=random_state)
    n = len(data)
    # The dev size is passed as an absolute count relative to the FULL dataset,
    # because the second split only sees what is left after train was removed.
    # The epsilon keeps a product like 28.999999999999996 from truncating to 28.
    n_dev_target = int(split_ratios["dev"] * n + 1e-9)
    # split
    data_train, data_rest = split_fn(data, train_size=split_ratios["train"])
    data_dev, data_test = split_fn(data_rest, train_size=n_dev_target)
    # sanity checks
    n_train, n_dev, n_test = len(data_train), len(data_dev), len(data_test)
    n_total = n_train + n_dev + n_test
    assert n_total == n
    print(f'train: {n_train}, dev: {n_dev}, test: {n_test}, total = {n_total}')
    # write out (loop variable deliberately does not reuse the `data` parameter name)
    for split_data, split in [(data_train, 'train'), (data_dev, 'dev'), (data_test, 'test')]:
        _write_json_file(split_data, split, write_dir)
    return data_train, data_dev, data_test

### params

In [3]:
# Input file; the loading cell reads its top-level 'data' entry.
path = 'train_preprocessed.json'
# Number of NQ questions we're using (the first `n` records of the input).
n = 7500
# Fraction of the `n` records assigned to each split.
split_ratios = dict(train=0.8, dev=0.1, test=0.1)
# Seed picked up by `write_splits` so the splits are reproducible.
random_state = 0
# Directory that receives train.json / dev.json / test.json.
path_output = 'splits_7500'

### load data and write splits

In [4]:
# Read via Path.read_text so the input file is closed right away instead of
# leaving an open handle behind for the lifetime of the kernel.
data = json.loads(Path(path).read_text())['data'][:n]
# The slice silently comes up short if the file holds fewer than `n` records.
assert len(data) == n, f'expected {n} records, found {len(data)}'
train, dev, test = write_splits(data, split_ratios, path_output)

train: 6000, dev: 750, test: 750, total = 7500


In [5]:
# List the files in the output directory with human-readable sizes.
!ls -lh {path_output}

total 876K
-rw-r--r-- 1 root root  87K May 28 18:50 dev.json
-rw-r--r-- 1 root root  87K May 28 18:50 test.json
-rw-r--r-- 1 root root 698K May 28 18:50 train.json


### sanity check

In [6]:
# Round-trip check: every file on disk must decode to the in-memory split.
# Loop names are chosen so the notebook-level `data` (the loaded dataset) is
# not overwritten by re-running this cell.
for split_name, split_data in [('train', train), ('dev', dev), ('test', test)]:
    # read_text closes the file immediately, unlike a bare open().
    written = json.loads((Path(path_output) / f'{split_name}.json').read_text())
    assert written == split_data, split_name

### sample data

In [7]:
# Show the first three records of the train split.
train[:3]

[{'id': 'train_1328',
  'question': 'when does season 2 episode 3 of escape the night come out',
  'answers': ['June 28 , 2017']},
 {'id': 'train_2368',
  'question': 'who is the actress that plays in the new wonder woman',
  'answers': ['Gal Gadot']},
 {'id': 'train_4113',
  'question': 'what year did us land on the moon',
  'answers': ['20 July 1969']}]

In [8]:
# Show the first three records of the dev split.
dev[:3]

[{'id': 'train_5909',
  'question': 'how did the lost voice guy lose his voice',
  'answers': ['neurological form of cerebral palsy']},
 {'id': 'train_6441',
  'question': 'who took power in england during the glorious revolution brainly',
  'answers': ['William III']},
 {'id': 'train_4802',
  'question': "who sang i've got a brand new pair of roller skates",
  'answers': ['folk music singer Melanie']}]

In [9]:
# Show the first three records of the test split.
test[:3]

[{'id': 'train_5418',
  'question': 'four aspects of being an australian that are tested in the australian citizenship test',
  'answers': ['English language',
   'traditional and national symbols',
   "Australia 's `` values ''",
   'history']},
 {'id': 'train_1933',
  'question': 'when was the last time the us hosted the world cup',
  'answers': ['1994']},
 {'id': 'train_4554',
  'question': 'whose line is it anyway how many seasons',
  'answers': ['14']}]