Fast dataset/index loading via cached pickle files
Caches generated datasets and indexes as binary files, which significantly
reduces load times.
rgemulla committed Mar 11, 2020
1 parent a5578fd commit cf64dd2
Showing 3 changed files with 138 additions and 24 deletions.
7 changes: 7 additions & 0 deletions kge/config-default.yaml
@@ -83,6 +83,13 @@ dataset:
# Additional files can be added as needed
+++: +++

# Whether to store processed dataset files and indexes as binary files in the
# dataset directory. This enables faster loading at the cost of storage space.
# LibKGE will ensure that pickled files are only used when not outdated. Note
# that the value specified here may be overwritten by dataset-specific choices
# (in dataset.yaml).
pickle: True

# Additional dataset specific keys can be added as needed
+++: +++
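
To override this default for a single dataset, the same key can be set in that dataset's own dataset.yaml, as the comment above indicates. A minimal sketch of such an override (the surrounding layout of dataset.yaml is assumed; only the pickle key itself comes from this commit):

dataset:
  # disable binary caching for this dataset only
  pickle: False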

117 changes: 110 additions & 7 deletions kge/dataset.py
@@ -5,14 +5,17 @@
import torch
from torch import Tensor
import numpy as np
import pickle
import inspect

from kge import Config, Configurable
import kge.indexing
from kge.indexing import create_default_index_functions
from kge.misc import kge_base_dir

from typing import Dict, List, Any, Callable, Union, Optional

# TODO add support to pickle dataset (and indexes) and reload from there

class Dataset(Configurable):
"""Stores information about a dataset.
@@ -56,7 +59,8 @@ def __init__(self, config, folder=None):
self._indexes: Dict[str, Any] = {}

#: functions that compute and add indexes as needed; arguments are dataset and
# key. : Indexed by key (same key as in self._indexes)
#: key. Index functions are expected to not recompute an index that is already
#: present. Indexed by key (same key as in self._indexes)
self.index_functions: Dict[str, Callable] = {}
create_default_index_functions(self)

@@ -85,9 +89,32 @@ def load(config: Config, preload_data=True):
return dataset

@staticmethod
def _load_triples(filename: str, delimiter="\t") -> Tensor:
def _to_valid_filename(s):
invalid_chars = "\n\t\\/"
replacement_chars = "ntbf"
trans = invalid_chars.maketrans(invalid_chars, replacement_chars)
return s.translate(trans)

@staticmethod
def _load_triples(filename: str, delimiter="\t", use_pickle=False) -> Tensor:
if use_pickle:
# check if there is a pickled, up-to-date version of the file
pickle_suffix = Dataset._to_valid_filename(f"-{delimiter}.pckl")
pickle_filename = filename + pickle_suffix
if os.path.isfile(pickle_filename) and os.path.getmtime(
pickle_filename
) > Dataset._get_newest_mtime(
None, filename
): # self=None
with open(pickle_filename, "rb") as f:
return pickle.load(f)

triples = np.loadtxt(filename, usecols=range(0, 3), dtype=int)
return torch.from_numpy(triples)
triples = torch.from_numpy(triples)
if use_pickle:
with open(pickle_filename, "wb") as f:
pickle.dump(triples, f)
return triples

def load_triples(self, key: str) -> Tensor:
"Load or return the triples with the specified key."
@@ -99,7 +126,10 @@ def load_triples(self, key: str) -> Tensor:
"Unexpected file type: "
f"dataset.files.{key}.type='{filetype}', expected 'triples'"
)
triples = Dataset._load_triples(os.path.join(self.folder, filename))
triples = Dataset._load_triples(
os.path.join(self.folder, filename),
use_pickle=self.config.get("dataset.pickle"),
)
self.config.log(f"Loaded {len(triples)} {key} triples")
self._triples[key] = triples

@@ -111,7 +141,22 @@ def _load_map(
as_list: bool = False,
delimiter: str = "\t",
ignore_duplicates=False,
use_pickle=False,
) -> Union[List, Dict]:
if use_pickle:
# check if there is a pickled, up-to-date version of the file
pickle_suffix = Dataset._to_valid_filename(
f"-{as_list}-{delimiter}-{ignore_duplicates}.pckl"
)
pickle_filename = filename + pickle_suffix
if os.path.isfile(pickle_filename) and os.path.getmtime(
pickle_filename
) > Dataset._get_newest_mtime(
None, filename
): # self=None
with open(pickle_filename, "rb") as f:
return pickle.load(f)

n = 0
dictionary = {}
warned_overrides = False
@@ -133,9 +178,14 @@ def _load_map(
array = [None] * n
for index, value in dictionary.items():
array[index] = value
return array, duplicates
result = (array, duplicates)
else:
return dictionary, duplicates
result = (dictionary, duplicates)

if use_pickle:
with open(pickle_filename, "wb") as f:
pickle.dump(result, f)
return result

def load_map(
self,
@@ -177,6 +227,7 @@ def load_map(
os.path.join(self.folder, filename),
as_list=False,
ignore_duplicates=ignore_duplicates,
use_pickle=self.config.get("dataset.pickle"),
)
ids = self.load_map(ids_key, as_list=True)
map_ = [map_.get(ids[i], None) for i in range(len(ids))]
@@ -191,6 +242,7 @@ def load_map(
os.path.join(self.folder, filename),
as_list=as_list,
ignore_duplicates=ignore_duplicates,
use_pickle=self.config.get("dataset.pickle"),
)

if duplicates > 0:
@@ -217,6 +269,36 @@ def shallow_copy(self):
copy.index_functions = self.index_functions
return copy

def _get_newest_mtime(self, filenames=None):
"""Return the timestamp of latest modification of relevant data files.
If `filenames` is `None`, return latest modification of relevant modules or any
of the dataset files given in the configuration.
Otherwise, return latest modification of relevant modules or any of the
specified files.
"""
newest_timestamp = max(
os.path.getmtime(inspect.getfile(Dataset)),
os.path.getmtime(inspect.getfile(kge.indexing)),
)
if filenames is None:
filenames = []
for key, entry in self.config.get("dataset.files").items():
filename = os.path.join(self.folder, entry["filename"])
filenames.append(filename)

if isinstance(filenames, str):
filenames = [filenames]

for filename in filenames:
if os.path.isfile(filename):
timestamp = os.path.getmtime(filename)
newest_timestamp = max(newest_timestamp, timestamp)

return newest_timestamp

## ACCESS ###########################################################################

def num_entities(self) -> int:
@@ -332,7 +414,28 @@ def index(self, key: str) -> Any:
"""
if key not in self._indexes:
use_pickle = self.config.get("dataset.pickle")
if use_pickle:
pickle_filename = os.path.join(
self.folder, Dataset._to_valid_filename(f"index-{key}.pckl")
)
if (
os.path.isfile(pickle_filename)
and os.path.getmtime(pickle_filename) > self._get_newest_mtime()
):
with open(pickle_filename, "rb") as f:
self._indexes[key] = pickle.load(f)
# call index function solely to print log messages. It's
# expected to not recompute the index (which we just loaded)
if key in self.index_functions:
self.index_functions[key](self)
return self._indexes[key]

self.index_functions[key](self)
if use_pickle:
with open(pickle_filename, "wb") as f:
pickle.dump(self._indexes[key], f)

return self._indexes[key]

@staticmethod
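
The additions to dataset.py above all follow the same caching discipline: a pickled copy of a derived object (triples, maps, or indexes) is reused only when it is newer than the files it was derived from, and it is rewritten after any recomputation. A condensed, self-contained sketch of that pattern follows; the function and file names are illustrative and not part of the library.

import os
import pickle


def load_with_cache(source_file, cache_file, compute):
    """Return compute(source_file), reusing cache_file when it is up to date."""
    # Reuse the pickled copy only if it is newer than the file it was derived from.
    if os.path.isfile(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(
        source_file
    ):
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    # Otherwise recompute and refresh the cache for subsequent runs.
    result = compute(source_file)
    with open(cache_file, "wb") as f:
        pickle.dump(result, f)
    return result

In the commit itself, _get_newest_mtime additionally folds in the modification times of dataset.py and indexing.py, so cached files are also invalidated when the loading or indexing code changes, not only when the raw dataset files do.
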
38 changes: 21 additions & 17 deletions kge/indexing.py
@@ -32,7 +32,6 @@ def index_KvsAll(dataset: "Dataset", split: str, key: str):
`split_so`. If this index is already present, does not recompute it.
"""
triples = dataset.split(split)
value = None
if key == "sp":
key_cols = [0, 1]
@@ -51,15 +50,15 @@ def index_KvsAll(dataset: "Dataset", split: str, key: str):

name = split + "_" + key + "_to_" + value
if not dataset._indexes.get(name):
triples = dataset.split(split)
dataset._indexes[name] = _group_by(
triples[:, key_cols], triples[:, value_column]
)
dataset.config.log(
"{} distinct {} pairs in {}".format(
len(dataset._indexes[name]), key, split
),
prefix=" ",
)

dataset.config.log(
"{} distinct {} pairs in {}".format(len(dataset._indexes[name]), key, split),
prefix=" ",
)

return dataset._indexes.get(name)

@@ -118,16 +117,19 @@ def index_relation_types(dataset):
"""
create dictionary mapping from {1-N, M-1, 1-1, M-N} -> set of relations
"""
if "relation_types" in dataset._indexes:
return
relation_types = _get_relation_types(dataset)
relations_per_type = {}
for k, v in relation_types.items():
relations_per_type.setdefault(v, set()).add(k)
for k, v in relations_per_type.items():
if (
"relation_types" not in dataset._indexes
or "relations_per_type" not in dataset._indexes
):
relation_types = _get_relation_types(dataset)
relations_per_type = {}
for k, v in relation_types.items():
relations_per_type.setdefault(v, set()).add(k)
dataset._indexes["relation_types"] = relation_types
dataset._indexes["relations_per_type"] = relations_per_type

for k, v in dataset._indexes["relations_per_type"].items():
dataset.config.log("{} relations of type {}".format(len(v), k), prefix=" ")
dataset._indexes["relation_types"] = relation_types
dataset._indexes["relations_per_type"] = relations_per_type


def index_frequency_percentiles(dataset, recompute=False):
Expand Down Expand Up @@ -211,8 +213,10 @@ def _invert_ids(dataset, obj: str):
if not f"{obj}_id_to_index" in dataset._indexes:
ids = dataset.load_map(f"{obj}_ids")
inv = {v: k for k, v in enumerate(ids)}
dataset.config.log(f"Indexed {len(inv)} {obj} ids")
dataset._indexes[f"{obj}_id_to_index"] = inv
else:
inv = dataset._indexes[f"{obj}_id_to_index"]
dataset.config.log(f"Indexed {len(inv)} {obj} ids", prefix=" ")


def create_default_index_functions(dataset: "Dataset"):
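
Because index() now calls an index function after loading a pickled index only in order to emit its log messages, every index function registered in dataset.index_functions must skip recomputation when its key is already present in dataset._indexes, as the updated functions above do. A minimal sketch of an index function following that contract; the key name and the computed statistic are made up for illustration.

def index_subject_counts(dataset):
    """Hypothetical index: entity id -> number of training triples with that subject."""
    key = "subject_counts"
    if key not in dataset._indexes:
        counts = {}
        for s, p, o in dataset.split("train").tolist():
            counts[s] = counts.get(s, 0) + 1
        dataset._indexes[key] = counts
    # Log in either case, so that loading from a pickled index still produces output.
    dataset.config.log(
        f"{len(dataset._indexes[key])} entities with subject counts", prefix="  "
    )

Such a function would be registered under its key, e.g. dataset.index_functions["subject_counts"] = index_subject_counts, so that dataset.index("subject_counts") can compute, cache, and return it.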
