-
Notifications
You must be signed in to change notification settings - Fork 93
/
util.py
391 lines (296 loc) · 11.2 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
"""
A miscellaneous set of utility functions. Try not to put things in here
unless they really have no other place.
"""
import os
import re
import shutil
from collections import defaultdict
from pkg_resources import Requirement, resource_filename
from pathlib import Path
from urllib.parse import unquote, urlparse
from urllib.request import url2pathname
import h5py as h5
import numpy as np
import attr
import psutil
import json
import rapidjson
import yaml
from typing import Any, Dict, Hashable, Iterable, List, Optional
import sleap.version as sleap_version
def json_loads(json_str: str) -> Dict:
    """
    A simple wrapper around the JSON decoder we are using.

    Decoding is attempted with ``rapidjson`` first; if that raises, the
    standard library ``json`` module is used as a fallback.

    Args:
        json_str: JSON string to decode.

    Raises:
        Whatever ``json.loads`` raises if the fallback decoder also fails
        (e.g. ``json.JSONDecodeError`` for malformed input).

    Returns:
        Result of decoding JSON string.
    """
    try:
        return rapidjson.loads(json_str)
    except Exception:
        # Best-effort fallback to the stdlib decoder. Catch ``Exception``
        # instead of using a bare ``except`` so KeyboardInterrupt and
        # SystemExit are not swallowed.
        return json.loads(json_str)
def json_dumps(d: Dict, filename: str = None):
    """
    Encode a dict as JSON with rapidjson, either into a file or as a string.

    Args:
        d: The dict to encode.
        filename: Optional path of a file to write the JSON into. When falsy
            (None or ""), nothing is written and the JSON string is returned
            instead.

    Returns:
        The JSON string when no `filename` is given, otherwise None.
    """
    if not filename:
        # String output uses rapidjson's default settings.
        return rapidjson.dumps(d)
    # File output writes non-ASCII characters verbatim (ensure_ascii=False).
    with open(filename, "w") as out_file:
        rapidjson.dump(d, out_file, ensure_ascii=False)
def attr_to_dtype(cls: Any):
    """
    Converts classes with basic types to numpy composite dtypes.

    Each attrs field becomes one named entry of the resulting dtype: `str`
    fields map to an h5py variable-length string dtype, while `int`, `float`
    and `bool` fields map to the corresponding type directly.

    Arguments:
        cls: An attrs class whose fields all carry type annotations.

    Raises:
        TypeError: If a field has no type annotation, or its type is not one
            of `str`, `int`, `float`, `bool`.

    Returns:
        numpy dtype.
    """
    dtype_list = []
    for field in attr.fields(cls):
        if field.type is None:
            raise TypeError(
                f"numpy dtype for {cls} cannot be constructed because no "
                + "type information found. Make sure each field is type annotated."
            )
        if field.type == str:
            # Strings have no fixed width, so store them as HDF5 vlen strings.
            dtype_list.append((field.name, h5.special_dtype(vlen=str)))
        elif field.type in (int, float, bool):
            dtype_list.append((field.name, field.type))
        else:
            raise TypeError(
                f"numpy dtype for {cls} cannot be constructed because "
                + f"{field.type} is not supported."
            )
    return np.dtype(dtype_list)
def usable_cpu_count() -> int:
    """
    Gets number of CPUs usable by the current process.

    Takes into consideration cpusets restrictions: the scheduler affinity
    mask from ``os.sched_getaffinity`` is used where the platform provides
    it, then psutil's ``Process.cpu_affinity`` where that exists, and
    finally the total logical CPU count.

    Returns:
        The number of usable CPUs (always at least 1).
    """
    try:
        result = len(os.sched_getaffinity(0))
    except AttributeError:
        try:
            result = len(psutil.Process().cpu_affinity())
        except AttributeError:
            # os.cpu_count() returns None when the count cannot be
            # determined; fall back to 1 so callers always get an int.
            result = os.cpu_count() or 1
    return result
def save_dict_to_hdf5(h5file: "h5.File", path: str, dic: dict):
    """
    Saves dictionary to an HDF5 file.

    Each item is written to ``h5file[path + key]``. Values are converted as
    follows:

    * ``None`` -> empty string
    * ``bool`` -> ``int`` (checked before ``int`` since bool subclasses int)
    * ``list`` -> numpy array, with ``str`` elements UTF-8 encoded
    * ``str`` -> UTF-8 encoded bytes
    * ``np.ndarray``, numpy numeric/bool scalars, ``bytes``, ``int``,
      ``float`` -> stored as-is
    * ``dict`` -> saved recursively under ``path + key + "/"``

    Args:
        h5file: The HDF5 file object to save the data to. Assumed to be open
            and writable; only item assignment is used on it.
        path: The group path prefix to save the dict under. Keys are appended
            directly, so it presumably ends with "/".
        dic: The dict to save. Keys are concatenated to `path`, so they are
            expected to be strings.

    Raises:
        ValueError: If type for item in dict cannot be saved.

    Returns:
        None
    """
    for key, item in list(dic.items()):
        print(f"Saving {key}:")
        if item is None:
            h5file[path + key] = ""
        elif isinstance(item, bool):
            h5file[path + key] = int(item)
        elif isinstance(item, list):
            h5file[path + key] = np.asarray(
                [it.encode("utf8") if isinstance(it, str) else it for it in item]
            )
        elif isinstance(item, str):
            h5file[path + key] = item.encode("utf8")
        elif isinstance(
            item, (np.ndarray, np.number, np.bool_, bytes, int, float)
        ):
            # np.number covers every numpy int/float/complex scalar width
            # (not only int64/float64); np.bool_ is not a subclass of bool.
            h5file[path + key] = item
        elif isinstance(item, dict):
            save_dict_to_hdf5(h5file, path + key + "/", item)
        else:
            raise ValueError("Cannot save %s type" % type(item))
def frame_list(frame_str: str) -> Optional[List[int]]:
    """
    Parse a frame specification string into a list of frame indices.

    Two forms are accepted:

    * a range "n-m" (also written "n,-m"), expanded to every int from n to m
      inclusive;
    * a comma-separated list of ints, e.g. "1,5,9".

    Args:
        frame_str: The frame specification.

    Returns:
        The list of ints, or None if `frame_str` is empty.

    Raises:
        ValueError: If a number in `frame_str` cannot be parsed as an int.
    """
    if "-" not in frame_str:
        if len(frame_str) == 0:
            return None
        return [int(token) for token in frame_str.split(",")]

    # Range form. Anything after a second "-" is ignored; stripping a
    # trailing comma from the start value handles the "n,-m" spelling.
    pieces = frame_str.split("-")
    first = int(pieces[0].rstrip(","))
    last = int(pieces[1])
    return list(range(first, last + 1))
def uniquify(seq: Iterable[Hashable]) -> List:
    """
    Return the distinct elements of `seq` in order of first appearance.

    Relies on dicts preserving insertion order (guaranteed since Python 3.7).

    Args:
        seq: Iterable of hashable items, possibly containing duplicates.

    Returns:
        A new list with duplicates removed, keeping each element's first
        occurrence.
    """
    first_seen = {}
    for element in seq:
        # setdefault never replaces an existing key, so the first occurrence
        # wins and keeps its original position.
        first_seen.setdefault(element, None)
    return list(first_seen)
def weak_filename_match(filename_a: str, filename_b: str) -> bool:
    """
    Check if paths probably point to same file.

    Compares the filename and the names of the two directories above it,
    after normalizing path separators and stripping the unique number from
    temporary directories named like ``tmp_<digits>_...``.

    Args:
        filename_a: first path to check
        filename_b: path to check against first path

    Returns:
        True if the paths probably match.
    """

    def _normalize(filename: str) -> str:
        # Convert all path separators to "/".
        filename = filename.replace("\\", "/")
        # Remove the unique pid so we can match tmp directories for the same
        # zip. Keep the leading "/" so the tmp directory remains its own path
        # component instead of being fused with its parent directory.
        return re.sub(r"/tmp_\d+_", "/tmp_", filename)

    # Check if the last three parts of the paths match.
    return _normalize(filename_a).split("/")[-3:] == _normalize(
        filename_b
    ).split("/")[-3:]
def dict_cut(d: Dict, a: int, b: int) -> Dict:
    """
    Build a sub-dictionary from a positional slice of a dict's items.

    Items are taken in the dict's iteration (insertion) order, and the
    `a`/`b` bounds follow normal list-slice semantics, so negative indices
    are allowed.

    Args:
        d: The dictionary to "split".
        a: Start index of range of items to include in result.
        b: End index (exclusive) of range of items to include in result.

    Returns:
        A new dict holding the items at positions `a` up to (not including)
        `b`.
    """
    ordered_items = list(d.items())
    return {key: value for key, value in ordered_items[a:b]}
def get_package_file(filename: str) -> str:
    """
    Resolve `filename` to a full path inside the installed sleap distribution.

    Args:
        filename: Path relative to the sleap distribution root, e.g.
            "sleap/config/shortcuts.yaml".

    Returns:
        The full path reported by ``pkg_resources.resource_filename``.
    """
    # NOTE(review): pkg_resources is deprecated in setuptools; consider
    # migrating to importlib.resources (behavior must be checked carefully).
    return resource_filename(Requirement.parse("sleap"), filename)
def get_config_file(
    shortname: str, ignore_file_not_found: bool = False, get_defaults: bool = False
) -> str:
    """
    Returns the full path to the specified config file.

    The user config file lives at ~/.sleap/<version>/<shortname>. If that
    file doesn't yet exist, we look for a <shortname> file inside the package
    config directory (sleap/config) and copy it into the user's config
    directory (creating the directory if needed).

    Args:
        shortname: The short filename, e.g., shortcuts.yaml
        ignore_file_not_found: If True, return the user config path regardless
            of whether the file exists. Has no effect when `get_defaults` is
            True.
        get_defaults: If True, just return the path to the package's default
            config file; the user config directory is not touched.

    Raises:
        FileNotFoundError: If the specified config file cannot be found.

    Returns:
        The full path to the specified config file.
    """
    # Stays None when get_defaults is True: no user path is involved then.
    desired_path = None

    if not get_defaults:
        desired_path = os.path.expanduser(
            f"~/.sleap/{sleap_version.__version__}/{shortname}"
        )

        # Make sure there's a ~/.sleap/<version>/ directory to store the user
        # version of the config file.
        os.makedirs(
            os.path.expanduser(f"~/.sleap/{sleap_version.__version__}"),
            exist_ok=True,
        )

        # If we don't care whether the file exists, just return the path.
        if ignore_file_not_found:
            return desired_path

        if os.path.exists(desired_path):
            return desired_path

    # Either defaults were requested or the user copy is missing: fall back to
    # the package version of the config file.
    package_path = get_package_file(f"sleap/config/{shortname}")
    if not os.path.exists(package_path):
        if desired_path is None:
            locations = package_path
        else:
            locations = f"{desired_path} or {package_path}"
        raise FileNotFoundError(
            f"Cannot locate {shortname} config file at {locations}."
        )

    if get_defaults:
        return package_path

    # Copy package version of config file into user config directory.
    shutil.copy(package_path, desired_path)
    return desired_path
def get_config_yaml(shortname: str, get_defaults: bool = False) -> dict:
    """
    Load and return the parsed contents of a YAML config file.

    Args:
        shortname: The short filename of the config, e.g., shortcuts.yaml
        get_defaults: If True, read the package default config file rather
            than the user's copy.

    Returns:
        The object produced by parsing the YAML file.
    """
    path = get_config_file(shortname, get_defaults=get_defaults)
    # NOTE(review): yaml.Loader can construct arbitrary Python objects from
    # tagged YAML. That is acceptable for trusted local config files, but
    # yaml.SafeLoader would be safer if these files could be untrusted.
    with open(path, "r") as config_file:
        parsed = yaml.load(config_file, Loader=yaml.Loader)
    return parsed
def save_config_yaml(shortname: str, data: Any) -> None:
    """
    Write `data` as YAML to the user's copy of a config file.

    The target path comes from ``get_config_file`` with
    ``ignore_file_not_found=True``, so the file is created or overwritten
    whether or not it already exists.

    Args:
        shortname: The short filename of the config, e.g., shortcuts.yaml
        data: The object to serialize with ``yaml.dump``.

    Returns:
        None
    """
    yaml_path = get_config_file(shortname, ignore_file_not_found=True)
    with open(yaml_path, "w") as f:
        print(f"Saving config: {yaml_path}")
        yaml.dump(data, f)
def make_scoped_dictionary(
    flat_dict: Dict[str, Any], exclude_nones: bool = True
) -> Dict[str, Dict[str, Any]]:
    """Converts dictionary with scoped keys to dictionary of dictionaries.

    Args:
        flat_dict: The dictionary to convert. Keys should be strings with
            `scope.foo` format. Keys without a "." are skipped. Only the first
            "." separates the scope, so `scope.foo.bar` becomes the subkey
            `foo.bar` under `scope`.
        exclude_nones: Whether to exclude items where value is None.

    Returns:
        Dictionary (a ``defaultdict(dict)``) in which keys are `scope` and
        values are dictionaries with `foo` (etc) as keys and the original value
        of `scope.foo` as value.
    """
    scoped_dict = defaultdict(dict)
    for key, val in flat_dict.items():
        if "." not in key:
            continue
        if exclude_nones and val is None:
            continue
        # Split only on the first "." so keys with nested dots don't fail to
        # unpack into (scope, subkey).
        scope, subkey = key.split(".", 1)
        scoped_dict[scope][subkey] = val
    return scoped_dict
def find_files_by_suffix(
    root_dir: str, suffix: str, prefix: str = "", depth: int = 0
) -> List[os.DirEntry]:
    """
    Collect files in `root_dir` whose names end in `suffix`.

    Files directly inside `root_dir` come first (in ``os.scandir`` order),
    followed by the matches from each subdirectory, searched recursively
    while `depth` is truthy.

    Args:
        root_dir: Directory to start searching in.
        suffix: Required filename ending (e.g., '.json').
        prefix: Optional required filename start; empty means no constraint.
        depth: How many additional subdirectory levels to descend into.

    Returns:
        List of os.DirEntry objects for the matching files.
    """
    with os.scandir(root_dir) as entries:
        listing = list(entries)

    subdirs = []
    found = []
    for entry in listing:
        if entry.is_dir():
            subdirs.append(entry.path)
        if not entry.is_file():
            continue
        if not entry.name.endswith(suffix):
            continue
        if prefix and not entry.name.startswith(prefix):
            continue
        found.append(entry)

    if depth:
        for subdir in subdirs:
            found += find_files_by_suffix(subdir, suffix, prefix, depth=depth - 1)
    return found
def parse_uri_path(uri: str) -> str:
    """Parse a URI starting with 'file:///' to a posix path.

    The URI is split into components *before* percent-decoding, so encoded
    characters such as "%23" ("#") or "%3F" ("?") remain part of the path
    instead of being re-parsed as a fragment or query. ``url2pathname``
    performs the single percent-decoding step; decoding beforehand as well
    would double-decode sequences like "%2525".

    Args:
        uri: A file URI, e.g. "file:///home/user/my%20video.mp4".

    Returns:
        The decoded filesystem path using "/" separators.
    """
    return Path(url2pathname(urlparse(uri).path)).as_posix()