forked from microsoft/qlib
-
Notifications
You must be signed in to change notification settings - Fork 3
/
__init__.py
859 lines (693 loc) · 26.1 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import pickle
import re
import sys
import copy
import json
import yaml
import redis
import bisect
import shutil
import difflib
import hashlib
import datetime
import requests
import tempfile
import importlib
import contextlib
import collections
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Tuple, Any, Text, Optional
from types import ModuleType
from urllib.parse import urlparse
from ..config import C
from ..log import get_module_logger, set_log_with_config
log = get_module_logger("utils")
#################### Server ####################
def get_redis_connection():
    """Create a StrictRedis connection using host/port/task-db from the global config ``C``."""
    return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db)
#################### Data ####################
def read_bin(file_path, start_index, end_index):
    """Read a slice of a qlib feature ``.bin`` file into a float32 Series.

    File layout: one little-endian float32 holding the calendar index of the
    first stored value, followed by the values themselves as little-endian
    float32 — so a value for calendar index i lives at byte offset
    ``4 * (i - ref_start_index) + 4``.

    Parameters
    ----------
    file_path :
        path of the .bin file
    start_index, end_index :
        inclusive calendar-index range to read

    Returns
    -------
    pd.Series
        values indexed by calendar index; empty when the requested range ends
        before the file's data begins.
    """
    with open(file_path, "rb") as f:
        # read start_index (the calendar index of the first datum in the file)
        ref_start_index = int(np.frombuffer(f.read(4), dtype="<f")[0])
        si = max(ref_start_index, start_index)
        if si > end_index:
            return pd.Series(dtype=np.float32)
        # calculate offset: skip the 4-byte header plus the values before `si`
        f.seek(4 * (si - ref_start_index) + 4)
        # read nbytes (may read fewer values than `count` if the file is short)
        count = end_index - si + 1
        data = np.frombuffer(f.read(4 * count), dtype="<f")
        series = pd.Series(data, index=pd.RangeIndex(si, si + len(data)))
    return series
def np_ffill(arr: np.array):
    """Forward-fill NaNs in a 1-D numpy array.

    Each NaN entry is replaced by the most recent preceding non-NaN value;
    leading NaNs keep the value at position 0 (nothing earlier to fill from).

    Parameters
    ----------
    arr : np.array
        input 1-D array
    """
    nan_mask = np.isnan(arr.astype(float))  # np.isnan requires a float dtype
    # position of the last non-NaN entry at or before each index (0 for leading NaNs)
    fill_pos = np.where(~nan_mask, np.arange(nan_mask.shape[0]), 0)
    np.maximum.accumulate(fill_pos, out=fill_pos)
    return arr[fill_pos]
#################### Search ####################
def lower_bound(data, val, level=0):
    """Binary-search the first index whose ``level``-th field is >= ``val``.

    ``data`` is a sequence of tuples/lists sorted ascending by field ``level``.
    For a flat (single-field) list use ``bisect.bisect_left`` instead.
    """
    lo, hi = 0, len(data)
    while lo < hi:
        mid = (lo + hi) // 2
        if data[mid][level] < val:
            lo = mid + 1
        else:
            hi = mid
    return lo
def upper_bound(data, val, level=0):
    """Binary-search the first index whose ``level``-th field is > ``val``.

    ``data`` is a sequence of tuples/lists sorted ascending by field ``level``.
    For a flat (single-field) list use ``bisect.bisect_right`` instead.
    """
    lo, hi = 0, len(data)
    while lo < hi:
        mid = (lo + hi) // 2
        if data[mid][level] <= val:
            lo = mid + 1
        else:
            hi = mid
    return lo
#################### HTTP ####################
def requests_with_retry(url, retry=5, **kwargs):
    """GET ``url``, retrying up to ``retry`` times on failure or bad status.

    Each attempt uses a 1-second timeout; extra kwargs are forwarded to
    ``requests.get``. Status 200 and 206 (Partial Content, for ranged
    downloads) count as success. Raises a generic ``Exception`` once all
    attempts are exhausted.
    """
    while retry > 0:
        retry -= 1
        try:
            res = requests.get(url, timeout=1, **kwargs)
            assert res.status_code in {200, 206}
            return res
        except AssertionError:
            # unacceptable status code — retry silently
            continue
        except Exception as e:
            # network errors etc. — log and retry
            log.warning("exception encountered {}".format(e))
            continue
    raise Exception("ERROR: requests failed!")
#################### Parse ####################
def parse_config(config):
    """Parse a config that may be an object, a YAML file path, or a YAML string.

    Parameters
    ----------
    config :
        - any non-str object is returned unchanged (already parsed)
        - a str naming an existing file is loaded from disk as YAML
        - any other str is parsed directly as YAML

    Raises
    ------
    ValueError
        if the string cannot be parsed as YAML.
    """
    # Check whether need parse, all object except str do not need to be parsed
    if not isinstance(config, str):
        return config
    # Check whether config is file
    if os.path.exists(config):
        with open(config, "r") as f:
            return yaml.safe_load(f)
    # Check whether the str can be parsed
    try:
        return yaml.safe_load(config)
    except yaml.YAMLError as e:
        # was `except BaseException`, which also swallowed KeyboardInterrupt /
        # SystemExit; narrow to parse errors and keep the original cause.
        raise ValueError("cannot parse config!") from e
#################### Other ####################
def drop_nan_by_y_index(x, y, weight=None):
    """Drop rows whose ``y`` entry contains any NaN, from x/y/weight alike.

    x, y, weight are index-aligned DataFrames; ``weight`` may be None and is
    then passed through as None.
    """
    # keep rows where every y column is non-NaN
    keep = y.notna().all(axis=1)
    filtered_x = x[keep]
    filtered_y = y[keep]
    filtered_weight = None if weight is None else weight[keep]
    return filtered_x, filtered_y, filtered_weight
def hash_args(*args):
    """Return a deterministic md5 hex digest of the positional arguments.

    Arguments are serialized with ``json.dumps(sort_keys=True)`` so equal
    dicts hash identically regardless of key order; non-JSON-serializable
    values (e.g. frozenset) fall back to ``str()``.
    """
    serialized = json.dumps(args, sort_keys=True, default=str)
    return hashlib.md5(serialized.encode()).hexdigest()
def parse_field(field):
    """Translate a qlib expression string into evaluable Python source.

    Rewrites, for example:
        $close       -> Feature("close")
        Abs($close)  -> Operators.Abs(Feature("close"))

    Non-str inputs are first converted with ``str()``.
    """
    if not isinstance(field, str):
        field = str(field)
    # Qualify operator calls first, then replace $name with Feature("name").
    with_ops = re.sub(r"(\w+\s*)\(", r"Operators.\1(", field)
    return re.sub(r"\$(\w+)", r'Feature("\1")', with_ops)
def get_module_by_module_path(module_path: Union[str, ModuleType]):
    """Resolve a module from a module object, a dotted name, or a ``.py`` path.

    :param module_path: an already-imported module (returned as-is), a dotted
        module name importable from ``sys.path``, or a path to a ``.py`` file.
    :return: the resolved module object
    """
    if isinstance(module_path, ModuleType):
        return module_path
    if not module_path.endswith(".py"):
        return importlib.import_module(module_path)
    # Loading from a source file: derive a valid identifier from the path.
    sanitized = re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))
    module_name = re.sub("^[^a-zA-Z_]+", "", sanitized)
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # Register before exec so the module can be found during its own execution.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
    """Extract a class object and its constructor kwargs from a config.

    Parameters
    ----------
    config : [dict, str]
        either ``{"class": name, "module_path": ..., "kwargs": {...}}``
        or a bare class-name string looked up in `default_module`.
    default_module : Python module or str
        fallback module when the config carries no "module_path".
        This function will load class from the config['module_path'] first.
        If config['module_path'] doesn't exists, it will load the class from
        default_module.

    Returns
    -------
    (type, dict):
        the class object and it's arguments.
    """
    if isinstance(config, dict):
        # "module_path" takes precedence over default_module; getattr raises
        # AttributeError when the class is missing from the module.
        module = get_module_by_module_path(config.get("module_path", default_module))
        return getattr(module, config["class"]), config.get("kwargs", {})
    if isinstance(config, str):
        # a bare class name resolved against the default module, no kwargs
        module = get_module_by_module_path(default_module)
        return getattr(module, config), {}
    raise NotImplementedError(f"This type of input is not supported")
def init_instance_by_config(
    config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
) -> Any:
    """
    get initialized instance with config

    Parameters
    ----------
    config : Union[str, dict, object]
        dict example.
            {
                'class': 'ClassName',
                'kwargs': dict,  #  It is optional. {} will be used if not given
                'model_path': path,  # It is optional if module is given
            }
        str example.
            1) specify a pickle object
                - path like 'file:///<path to pickle file>/obj.pkl'
            2) specify a class name
                - "ClassName":  getattr(module, config)() will be used.
        object example:
            instance of accept_types
    default_module : Python module
        Optional. It should be a python module.
        NOTE: the "module_path" will be override by `module` arguments
        This function will load class from the config['module_path'] first.
        If config['module_path'] doesn't exists, it will load the class from default_module.
    accept_types: Union[type, Tuple[type]]
        Optional. If the config is a instance of specific type, return the config directly.
        This will be passed into the second parameter of isinstance.

    Returns
    -------
    object:
        An initialized object based on the config info
    """
    # already-constructed objects of an accepted type pass straight through
    if isinstance(config, accept_types):
        return config
    if isinstance(config, str):
        # path like 'file:///<path to pickle file>/obj.pkl'
        pr = urlparse(config)
        if pr.scheme == "file":
            # NOTE(review): pickle.load executes arbitrary code — only load
            # trusted files here.
            with open(os.path.join(pr.netloc, pr.path), "rb") as f:
                return pickle.load(f)
    # otherwise treat the config as a class spec; `kwargs` overrides/extends
    # the kwargs carried by the config itself
    klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
    return klass(**cls_kwargs, **kwargs)
def compare_dict_value(src_data: dict, dst_data: dict):
    """Diff two dicts via their sorted, indented JSON dumps.

    :param src_data: source dict
    :param dst_data: destination dict
    :return: list of differing diff entries (those starting with '+ ' or '- ')
    """

    class DateEncoder(json.JSONEncoder):
        # FIXME: This class can only be accurate to the day. If it is a minute,
        # there may be a bug
        def default(self, o):
            if isinstance(o, (datetime.datetime, datetime.date)):
                return o.strftime("%Y-%m-%d %H:%M:%S")
            return json.JSONEncoder.default(self, o)

    src_text = json.dumps(src_data, indent=4, sort_keys=True, cls=DateEncoder)
    dst_text = json.dumps(dst_data, indent=4, sort_keys=True, cls=DateEncoder)
    # NOTE: ndiff over raw strings compares character-by-character.
    return [entry for entry in difflib.ndiff(src_text, dst_text) if entry.startswith(("+ ", "- "))]
def get_or_create_path(path: Optional[Text] = None, return_dir: bool = False):
    """Create (if needed) and return a file or directory path.

    Parameters
    ----------
    path : Optional[Text]
        the desired path; when None, a temporary path under ``~/tmp`` is created.
    return_dir : bool
        if True, ensure/return a directory; otherwise ensure the parent
        directory exists and return a file path.
    """
    if path:
        if return_dir and not os.path.exists(path):
            os.makedirs(path)
        elif not return_dir:  # return a file, thus we need to create its parent directory
            xpath = os.path.abspath(os.path.join(path, ".."))
            if not os.path.exists(xpath):
                os.makedirs(xpath)
    else:
        temp_dir = os.path.expanduser("~/tmp")
        if not os.path.exists(temp_dir):
            os.makedirs(temp_dir)
        if return_dir:
            # BUG FIX: mkdtemp returns a single path string — the old code
            # unpacked it as a 2-tuple and always raised ValueError.
            path = tempfile.mkdtemp(dir=temp_dir)
        else:
            # mkstemp returns (os-level fd, path); close the fd to avoid a leak.
            fd, path = tempfile.mkstemp(dir=temp_dir)
            os.close(fd)
    return path
@contextlib.contextmanager
def save_multiple_parts_file(filename, format="gztar"):
    """Save multiple parts file

    Implementation process:
        1. get the absolute path to 'filename'
        2. create a 'filename' directory
        3. user does something with file_path('filename/')
        4. remove 'filename' directory
        5. make_archive 'filename' directory, and rename 'archive file' to filename

    :param filename: result model path
    :param format: archive format: one of "zip", "tar", "gztar", "bztar", or "xztar"
    :return: real model path

    Usage::

        >>> # The following code will create an archive file('~/tmp/test_file') containing 'test_doc_i'(i is 0-10) files.
        >>> with save_multiple_parts_file('~/tmp/test_file') as filename_dir:
        ...     for i in range(10):
        ...         temp_path = os.path.join(filename_dir, 'test_doc_{}'.format(str(i)))
        ...         with open(temp_path) as fp:
        ...             fp.write(str(i))
        ...
    """
    if filename.startswith("~"):
        filename = os.path.expanduser(filename)
    file_path = os.path.abspath(filename)
    # Create model dir
    if os.path.exists(file_path):
        raise FileExistsError("ERROR: file exists: {}, cannot be create the directory.".format(file_path))
    os.makedirs(file_path)
    # return model dir for the caller to fill with part files
    yield file_path
    # NOTE(review): no try/finally here — if the caller's with-block raises,
    # the directory is left behind and no archive is produced; confirm whether
    # that is acceptable.
    # filename dir to filename.tar.gz file
    tar_file = shutil.make_archive(file_path, format=format, root_dir=file_path)
    # Remove filename dir
    if os.path.exists(file_path):
        shutil.rmtree(file_path)
    # filename.tar.gz rename to filename
    os.rename(tar_file, file_path)
@contextlib.contextmanager
def unpack_archive_with_buffer(buffer, format="gztar"):
    """Unpack an in-memory archive buffer into a temp directory and yield its path.

    After the call is finished, the archive file and directory will be deleted.

    Implementation process:
        1. create 'tempfile' in '~/tmp/' and write 'buffer' to it
        2. rename it to '<tempfile>.tar.gz' and unpack into a '<tempfile>/' dir
        3. user does something with file_path('tempfile/')
        4. remove 'tempfile' and 'tempfile directory'

    :param buffer: bytes
    :param format: archive format: one of "zip", "tar", "gztar", "bztar", or "xztar"
    :return: unpack archive directory path

    Usage::

        >>> # The following code is to print all the file names in 'test_unpack.tar.gz'
        >>> with open('test_unpack.tar.gz', 'rb') as fp:
        ...     buffer = fp.read()
        ...
        >>> with unpack_archive_with_buffer(buffer) as temp_dir:
        ...     for f_n in os.listdir(temp_dir):
        ...         print(f_n)
        ...
    """
    temp_dir = os.path.expanduser("~/tmp")
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)
    # delete=False: the file must survive the `with` so it can be renamed below
    with tempfile.NamedTemporaryFile("wb", delete=False, dir=temp_dir) as fp:
        fp.write(buffer)
        file_path = fp.name
    try:
        tar_file = file_path + ".tar.gz"
        os.rename(file_path, tar_file)
        # Create dir to unpack into (reuses the original temp name)
        os.makedirs(file_path)
        shutil.unpack_archive(tar_file, format=format, extract_dir=file_path)
        # Return temp dir
        yield file_path
    except Exception as e:
        # NOTE(review): errors (including ones raised by the caller inside the
        # with-block) are only logged here; if `yield` never runs, the
        # contextmanager machinery raises "generator didn't yield" — confirm
        # whether swallowing instead of re-raising is intentional.
        log.error(str(e))
    finally:
        # Remove temp tar file
        if os.path.exists(tar_file):
            os.unlink(tar_file)
        # Remove temp model dir
        if os.path.exists(file_path):
            shutil.rmtree(file_path)
@contextlib.contextmanager
def get_tmp_file_with_buffer(buffer):
    """Context manager yielding a temp file path (under ``~/tmp``) filled with `buffer`.

    The file is removed automatically when the context exits.

    :param buffer: bytes to write into the temporary file
    """
    temp_dir = os.path.expanduser("~/tmp")
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)
    with tempfile.NamedTemporaryFile("wb", delete=True, dir=temp_dir) as fp:
        fp.write(buffer)
        # BUG FIX: without flushing, the data may still sit in the userspace
        # buffer and readers of the yielded path would see a partial/empty file.
        fp.flush()
        yield fp.name
def remove_repeat_field(fields):
    """Drop duplicate fields while keeping first-occurrence order.

    :param fields: list; features fields
    :return: list of unique fields, ordered by first appearance
    """
    # dict preserves insertion order (Python 3.7+), so this dedups stably.
    return list(dict.fromkeys(fields))
def remove_fields_space(fields: [list, str, tuple]):
    """Strip all spaces from a field or a sequence of fields.

    :param fields: a single field string, or a list/tuple of fields
    :return: the cleaned string, or a list of cleaned strings
        (non-str items in a sequence are silently dropped)
    """
    if isinstance(fields, str):
        return fields.replace(" ", "")
    cleaned = []
    for item in fields:
        if isinstance(item, str):
            cleaned.append(item.replace(" ", ""))
    return cleaned
def normalize_cache_fields(fields: [list, tuple]):
    """Normalize fields for use as a cache key: strip spaces, dedupe, sort.

    :param fields: features fields
    :return: sorted list of unique, space-free field strings
    """
    return sorted(remove_repeat_field(remove_fields_space(fields)))
def normalize_cache_instruments(instruments):
    """Normalize an instruments spec for use as a cache key.

    List-like inputs become a sorted list. A dict stockpool containing a
    "market" key is passed through untouched; otherwise every value list is
    sorted.

    :return: list or dict
    """
    if isinstance(instruments, (list, tuple, pd.Index, np.ndarray)):
        return sorted(instruments)
    # dict type stockpool
    if "market" not in instruments:
        instruments = {code: sorted(values) for code, values in instruments.items()}
    return instruments
def is_tradable_date(cur_date):
    """Judge whether `cur_date` is a tradable date on the (future) qlib calendar.

    Parameters
    ----------
    cur_date : pandas.Timestamp
        current date
    """
    from ..data import D

    # the first calendar entry at/after cur_date equals cur_date iff it trades
    return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())
def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
    """Return the trading calendar between `trading_date` shifted left and right.

    Parameters
    ----------
    trading_date: pd.Timestamp
        anchor trading date
    left_shift: int
        shift (in trading days) applied to obtain the range start;
        NOTE(review): applied as given — pass a negative value to move earlier
    right_shift: int
        shift (in trading days) applied to obtain the range end
    future: bool
        whether to use the future calendar
    """
    from ..data import D

    start = get_date_by_shift(trading_date, left_shift, future=future)
    end = get_date_by_shift(trading_date, right_shift, future=future)
    calendar = D.calendar(start, end, future=future)
    return calendar
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
    """Shift `trading_date` by `shift` trading days on the qlib calendar.

    e.g. shift == 1 returns the next trading date,
         shift == -1 returns the previous trading date.

    Parameters
    ----------
    trading_date : pandas.Timestamp
        current date (must be a trading day, else ValueError)
    shift : int
        number of trading days to move (may be negative)
    future : bool
        whether to use the future calendar
    clip_shift : bool
        if True, clamp an out-of-range result to the calendar ends;
        otherwise raise IndexError
    freq : str
        calendar frequency
    """
    # consistent with the other helpers in this module (was `from qlib.data import D`)
    from ..data import D

    cal = D.calendar(future=future, freq=freq)
    if pd.to_datetime(trading_date) not in list(cal):
        raise ValueError("{} is not trading day!".format(str(trading_date)))
    _index = bisect.bisect_left(cal, trading_date)
    shift_index = _index + shift
    if shift_index < 0 or shift_index >= len(cal):
        if clip_shift:
            shift_index = np.clip(shift_index, 0, len(cal) - 1)
        else:
            raise IndexError(f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range")
    return cal[shift_index]
def get_next_trading_date(trading_date, future=False):
    """Return the next trading date after `trading_date`.

    Parameters
    ----------
    trading_date : pandas.Timestamp
        current date
    future : bool
        whether to use the future calendar
    """
    return get_date_by_shift(trading_date, 1, future=future)
def get_pre_trading_date(trading_date, future=False):
    """Return the previous trading date before `trading_date`.

    Parameters
    ----------
    trading_date : pandas.Timestamp
        current date
    future : bool
        whether to use the future calendar
    """
    return get_date_by_shift(trading_date, -1, future=future)
def transform_end_date(end_date=None, freq="day"):
    """Normalize `end_date` against the qlib trading calendar.

    If end_date is None, "-1", or later than the last trading day in the
    calendar, the last trading day is returned (with a warning); otherwise
    end_date is returned unchanged.

    Parameters
    ----------
    end_date: str
        end trading date
    freq: str
        calendar frequency
    """
    from ..data import D

    last_date = D.calendar(freq=freq)[-1]
    if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)):
        log.warning(
            "\nInfo: the end_date in the configuration file is {}, "
            "so the default last date {} is used.".format(end_date, last_date)
        )
        end_date = last_date
    return end_date
def get_date_in_file_name(file_name):
    """Extract the first 'YYYY-MM-DD' date found in a file name.

    Parameter
        file_name : str (or anything str()-able)
    :return
        date : str, 'YYYY-MM-DD'; raises AttributeError when no date is present
    """
    match = re.search(r"[0-9]{4}-[0-9]{2}-[0-9]{2}", str(file_name))
    return match.group()
def split_pred(pred, number=None, split_date=None):
    """Split a score DataFrame into two parts along the datetime index level.

    Parameter
    ---------
    pred : pd.DataFrame (index:<instrument, datetime>)
        A score file of stocks
    number: the number of dates for pred_left
    split_date: the last date of the pred_left

    Return
    -------
    pred_left : pd.DataFrame (index:<instrument, datetime>)
        The first part of original score file
    pred_right : pd.DataFrame (index:<instrument, datetime>)
        The second part of original score file
    """
    if number is None and split_date is None:
        raise ValueError("`number` and `split date` cannot both be None")
    dates = sorted(pred.index.get_level_values("datetime").unique())
    dates = list(map(pd.Timestamp, dates))
    if split_date is None:
        # split purely by count: the first `number` dates go left
        date_left_end = dates[number - 1]
        date_right_begin = dates[number]
        date_left_start = None
    else:
        split_date = pd.Timestamp(split_date)
        date_left_end = split_date
        date_right_begin = split_date + pd.Timedelta(days=1)
        if number is None:
            date_left_start = None
        else:
            # keep only the last `number` dates up to (and including) split_date
            end_idx = bisect.bisect_right(dates, split_date)
            date_left_start = dates[end_idx - number]
    # sort so that the label-based datetime slices below are valid
    pred_temp = pred.sort_index()
    pred_left = pred_temp.loc(axis=0)[:, date_left_start:date_left_end]
    pred_right = pred_temp.loc(axis=0)[:, date_right_begin:]
    return pred_left, pred_right
def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
    """Normalize a user-supplied time into a value usable for time slicing.

    None stays None (it represents an unbounded slice endpoint in Qlib or
    Pandas, e.g. ``df.loc[slice(None, "20210303")]``); anything else is
    coerced to ``pd.Timestamp``.

    Parameters
    ----------
    t : Union[None, str, pd.Timestamp]
        original time

    Returns
    -------
    Union[None, pd.Timestamp]:
    """
    return None if t is None else pd.Timestamp(t)
def can_use_cache():
    """Return True if the redis server configured in ``C`` is reachable.

    Opens a throw-away connection and calls ``client()`` on it; a
    ConnectionError means the redis-backed cache cannot be used.
    """
    res = True
    r = get_redis_connection()
    try:
        r.client()
    except redis.exceptions.ConnectionError:
        res = False
    finally:
        # always release the probe connection
        r.close()
    return res
def exists_qlib_data(qlib_dir):
    """Check whether `qlib_dir` looks like a complete qlib data directory.

    Verifies that:
    - ``calendars/``, ``instruments/`` and ``features/`` exist and are non-empty
    - every non-future calendar has at least one matching feature ``.bin`` file
    - every instrument listed in ``instruments/all.txt`` has a features entry
      (missing codes containing "sht" are tolerated)

    Returns True only if all checks pass.
    """
    qlib_dir = Path(qlib_dir).expanduser()
    if not qlib_dir.exists():
        return False
    calendars_dir = qlib_dir.joinpath("calendars")
    instruments_dir = qlib_dir.joinpath("instruments")
    features_dir = qlib_dir.joinpath("features")
    # check dir
    for _dir in [calendars_dir, instruments_dir, features_dir]:
        if not (_dir.exists() and list(_dir.iterdir())):
            return False
    # check calendar bin: each calendar frequency must have feature data
    for _calendar in calendars_dir.iterdir():
        if ("_future" not in _calendar.name) and (
            not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin"))
        ):
            return False
    # check instruments: codes in all.txt must appear under features/
    code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir()))
    _instrument = instruments_dir.joinpath("all.txt")
    miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
    if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
        return False
    return True
def check_qlib_data(qlib_config):
    """Validate that every instruments file in the qlib data has 3 columns.

    Parameters
    ----------
    qlib_config : dict
        must contain "provider_uri" pointing at the qlib data directory.

    Raises
    ------
    AssertionError
        if an instruments file does not have exactly 3 tab-separated columns.
    """
    inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments")
    for _p in inst_dir.glob("*.txt"):
        # Read one row to infer the column count (nrows=0 with header=None
        # cannot infer columns, so the old check never saw the real width).
        if len(pd.read_csv(_p, sep="\t", nrows=1, header=None).columns) != 3:
            # Raise explicitly instead of `assert` (asserts vanish under -O);
            # AssertionError is kept for backward compatibility with callers.
            raise AssertionError(
                f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
                f"\n\tIf you are using the data provided by qlib: "
                f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
                f"\n\tIf you are using your own data, please dump the data again: "
                f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
            )
def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
    """Return `df` with a sorted index, skipping the sort when already sorted.

    ``df.sort_index()`` can be costly even when the index is already sorted,
    so we check monotonicity first and return the frame untouched if possible.

    Parameters
    ----------
    df : pd.DataFrame
    axis : 0 to sort the row index, otherwise the columns

    Returns
    -------
    pd.DataFrame:
        a frame whose chosen axis is sorted ascending
    """
    target = df.index if axis == 0 else df.columns
    return df if target.is_monotonic_increasing else df.sort_index(axis=axis)
FLATTEN_TUPLE = "_FLATTEN_TUPLE"


def flatten_dict(d, parent_key="", sep=".") -> dict:
    """Flatten a nested dict into a single level.

    >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y': 10}}, 'd': [1, 2, 3]})
    {'a': 1, 'c.a': 2, 'c.b.x': 5, 'c.b.y': 10, 'd': [1, 2, 3]}
    >>> flatten_dict({'a': 1, 'c': {'a': 2}}, sep=FLATTEN_TUPLE)
    {'a': 1, ('c', 'a'): 2}

    Args:
        d (dict): the dict waiting for flatting
        parent_key (str, optional): prefix carried into recursive calls. Defaults to "".
        sep (str, optional): separator joining nested keys; pass FLATTEN_TUPLE
            to build tuple keys instead of joined strings.

    Returns:
        dict: flatten dict
    """
    flat = {}
    for key, value in d.items():
        if sep == FLATTEN_TUPLE:
            child_key = (parent_key, key) if parent_key else key
        else:
            child_key = parent_key + sep + key if parent_key else key
        if isinstance(value, collections.abc.MutableMapping):
            # recurse into nested mappings, prefixing with the child key
            flat.update(flatten_dict(value, child_key, sep=sep))
        else:
            flat[child_key] = value
    return flat
#################### Wrapper #####################
class Wrapper:
    """Lazy proxy for a provider object that is registered during qlib.init.

    Attribute access is forwarded to the registered provider; touching any
    attribute before registration raises AttributeError.
    """

    def __init__(self):
        self._provider = None

    def register(self, provider):
        # install the concrete provider; all future attribute access hits it
        self._provider = provider

    def __repr__(self):
        return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider)

    def __getattr__(self, key):
        # only invoked for attributes not found normally (i.e. everything
        # except _provider itself)
        provider = self._provider
        if provider is None:
            raise AttributeError("Please run qlib.init() first using qlib")
        return getattr(provider, key)
def register_wrapper(wrapper, cls_or_obj, module_path=None):
    """Register a provider on a Wrapper.

    :param wrapper: the Wrapper instance to register into.
    :param cls_or_obj: a class, a class-name string (resolved against
        `module_path`), or an already-constructed object instance.
    :param module_path: module used to resolve a class-name string.
    """
    if isinstance(cls_or_obj, str):
        cls_or_obj = getattr(get_module_by_module_path(module_path), cls_or_obj)
    # instantiate classes with no arguments; pass instances through untouched
    provider = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
    wrapper.register(provider)
def load_dataset(path_or_obj):
    """Load a dataset from a DataFrame passthrough or an .h5/.pkl/.csv file.

    A DataFrame is returned as-is; otherwise the extension selects the pandas
    reader. Raises ValueError for a missing file or an unsupported extension.
    """
    if isinstance(path_or_obj, pd.DataFrame):
        return path_or_obj
    if not os.path.exists(path_or_obj):
        raise ValueError(f"file {path_or_obj} doesn't exist")
    _, extension = os.path.splitext(path_or_obj)
    loaders = {
        ".h5": pd.read_hdf,
        ".pkl": pd.read_pickle,
        # csv datasets carry a (instrument, datetime) MultiIndex in cols 0-1
        ".csv": lambda p: pd.read_csv(p, parse_dates=True, index_col=[0, 1]),
    }
    loader = loaders.get(extension)
    if loader is None:
        raise ValueError(f"unsupported file type `{extension}`")
    return loader(path_or_obj)
def code_to_fname(code: str):
    """Map a stock code to a safe file name.

    Windows reserves device names (CON, PRN, AUX, NUL, COM0-9, LPT0-9) that
    cannot be used as file names, so such codes get a "_qlib_" prefix.
    reference: https://superuser.com/questions/86999/why-cant-i-name-a-folder-or-file-con-in-windows

    Parameters
    ----------
    code: str
    """
    reserved = {"CON", "PRN", "AUX", "NUL"}
    reserved.update(f"COM{i}" for i in range(10))
    reserved.update(f"LPT{i}" for i in range(10))
    if str(code).upper() in reserved:
        code = "_qlib_" + str(code)
    return code
def fname_to_code(fname: str):
    """Map a file name back to the stock code (inverse of `code_to_fname`).

    Parameters
    ----------
    fname: str
        file name, possibly carrying the "_qlib_" safety prefix
    """
    prefix = "_qlib_"
    if fname.startswith(prefix):
        # BUG FIX: str.lstrip(prefix) strips any leading chars from the SET
        # {_, q, l, i, b} — e.g. "_qlib_bill" became "" — so slice instead.
        fname = fname[len(prefix):]
    return fname