/
utils.py
135 lines (107 loc) · 4.16 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from __future__ import print_function, unicode_literals
import logging
import os
import subprocess
import sys
from enum import Enum, unique
import requests
import snips_nlu
from snips_nlu import __about__
from snips_nlu.common.utils import parse_version
try:
    from importlib import invalidate_caches
except ImportError:
    # ``importlib.invalidate_caches`` is not available on older interpreters
    # (it was added in Python 3.3); provide a stand-in with the same name.
    def invalidate_caches():
        """Fallback for ``importlib.invalidate_caches``: block for one second.

        NOTE(review): presumably the pause gives the import system /
        filesystem time to notice freshly installed modules -- confirm.
        """
        import time
        time.sleep(1)
@unique
class PrettyPrintLevel(Enum):
    """Severity levels accepted by ``pretty_print``.

    The level only selects the ANSI colour code used for the printed title
    (see ``_color_from_level``); values are unique integers starting at 0.
    """

    INFO = 0
    WARNING = 1
    ERROR = 2
    SUCCESS = 3
# Layout of each record written by the handler built in set_nlu_logger:
# "[<level>][<HH:MM:SS>.<ms>][<logger name>]: <message>".
FMT = ("[%(levelname)s][%(asctime)s.%(msecs)03d]"
       "[%(name)s]: %(message)s")
# Clock-only timestamp; milliseconds are appended by %(msecs)03d in FMT.
DATE_FMT = "%H:%M:%S"
def pretty_print(*texts, **kwargs):
    """Print formatted message.

    Args:
        *texts (str): Texts to print. Each argument is rendered as a
            paragraph; paragraphs are separated by a blank line.
        **kwargs:
            title (str, optional): Rendered as a coloured headline above the
                message. Omitted when empty or missing.
            level (PrettyPrintLevel, optional): Selects the title colour.
                Defaults to ``PrettyPrintLevel.INFO``.
            exits (int, optional): When given (i.e. not ``None``), the
                process exits via ``sys.exit(exits)`` after printing, so the
                value is the exit status. Note that ``exits=0`` (or
                ``False``) also exits, with a success status.

    Raises:
        ValueError: If ``level`` is not a known ``PrettyPrintLevel``.
        SystemExit: If ``exits`` is not ``None``.
    """
    exits = kwargs.get("exits")
    title = kwargs.get("title")
    level = kwargs.get("level", PrettyPrintLevel.INFO)
    # Resolved even when there is no title, so an invalid level is always
    # reported.
    title_color = _color_from_level(level)
    if title:
        # ANSI escape: switch to the level colour, then reset with "\033[0m".
        title = "\033[{color}m{title}\033[0m\n".format(title=title,
                                                       color=title_color)
    else:
        title = ""
    message = "\n\n".join(texts)
    print("\n{title}{message}\n".format(title=title, message=message))
    if exits is not None:
        sys.exit(exits)
def _color_from_level(level):
    """Return the ANSI SGR colour code used for a ``PrettyPrintLevel`` title.

    Args:
        level (PrettyPrintLevel): Severity of the message.

    Returns:
        str: "92" (bright green) for INFO and SUCCESS, "93" (bright yellow)
        for WARNING, "91" (bright red) for ERROR.

    Raises:
        ValueError: If ``level`` is none of the members above.
    """
    if level == PrettyPrintLevel.INFO:
        return "92"
    if level == PrettyPrintLevel.WARNING:
        return "93"
    if level == PrettyPrintLevel.ERROR:
        return "91"
    if level == PrettyPrintLevel.SUCCESS:
        return "92"
    # Wrap in a 1-tuple so a tuple-valued ``level`` is formatted instead of
    # being unpacked as %-arguments (which would raise TypeError instead).
    raise ValueError("Unknown PrettyPrintLevel: %s" % (level,))
def get_json(url, desc, timeout=None):
    """Fetch ``url`` with an HTTP GET and return the decoded JSON body.

    Args:
        url (str): Resource to fetch.
        desc (str): Human-readable name of the resource, used as the prefix
            of the error message.
        timeout (float or tuple, optional): Forwarded to ``requests.get``.
            ``None`` (the default, identical to the previous behaviour)
            waits indefinitely; pass a value to bound how long a stalled
            server can block the caller.

    Returns:
        The payload decoded by ``requests.Response.json()``.

    Raises:
        OSError: If the response status code is anything other than 200.
        Errors raised by ``requests`` itself propagate unchanged.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise OSError("%s: Received status code %s when fetching the resource"
                      % (desc, response.status_code))
    return response.json()
def get_compatibility():
    """Return the resources compatibility entry for the installed snips-nlu.

    The remote table at ``__about__.__compatibility__`` is read under its
    ``"snips-nlu"`` key. The entry for the exact installed version is used
    when that key exists; otherwise the ``"<major>.<minor>"`` entry is used.
    If the resulting value is ``None``, an error is printed and
    ``pretty_print`` terminates the process with status 1.
    """
    installed = __about__.__version__
    version_parts = parse_version(installed)
    major_minor = "%s.%s" % (version_parts["major"], version_parts["minor"])

    compat_table = get_json(__about__.__compatibility__,
                            "Compatibility table")
    per_version = compat_table["snips-nlu"]
    fallback = per_version.get(major_minor)
    compatibility = per_version.get(installed, fallback)

    if compatibility is None:
        message = "No compatible resources found for version %s" % installed
        pretty_print(message, title="Resources compatibility error", exits=1,
                     level=PrettyPrintLevel.ERROR)
    return compatibility
def get_resources_version(resource_fullname, resource_alias, compatibility):
    """Return the first listed compatible version of a resource package.

    Args:
        resource_fullname (str): Key of the resource in ``compatibility``.
        resource_alias (str): Short name shown to the user in the error
            message.
        compatibility (dict): Maps resource full names to sequences of
            versions; the first item is returned. Presumably the entry
            produced by ``get_compatibility`` -- verify against callers.

    If ``resource_fullname`` is missing, an error is printed and
    ``pretty_print`` terminates the process with status 1.
    """
    if resource_fullname in compatibility:
        return compatibility[resource_fullname][0]
    pretty_print("No compatible resources found for '%s'" % resource_alias,
                 title="Resources compatibility error", exits=1,
                 level=PrettyPrintLevel.ERROR)
    # Only reached if pretty_print returns instead of exiting.
    return compatibility[resource_fullname][0]
def install_remote_package(download_url, user_pip_args=None):
    """Install a package from ``download_url`` with pip in a subprocess.

    pip runs under the current interpreter (``sys.executable -m pip``) with
    ``--no-cache-dir --no-deps``, followed by any ``user_pip_args`` and then
    ``download_url``. The child process receives a copy of ``os.environ``.

    Args:
        download_url (str): Package location handed to ``pip install``.
        user_pip_args (list of str, optional): Extra pip arguments, placed
            before ``download_url``.

    Returns:
        int: The exit code returned by the pip process.
    """
    command = [sys.executable, "-m", "pip", "install",
               "--no-cache-dir", "--no-deps"]
    command.extend(user_pip_args or [])
    command.append(download_url)
    status = subprocess.call(command, env=os.environ.copy())
    # Dynamically installed modules are only importable in this process
    # once the import finders' caches are invalidated:
    # https://docs.python.org/3/library/importlib.html#importlib.import_module
    invalidate_caches()
    return status
def check_resources_alias(resource_name, shortcuts):
    """Exit with an error unless ``resource_name`` is a known alias.

    Only ``resource_name`` is lower-cased before the membership test.
    NOTE(review): this assumes ``shortcuts`` are already lower-case;
    otherwise mixed-case shortcuts can never match -- confirm with callers.

    Args:
        resource_name (str): Alias requested by the user.
        shortcuts (iterable of str): Known aliases.
    """
    known_aliases = set(shortcuts)
    if resource_name.lower() in known_aliases:
        return
    listing = ", ".join(sorted(known_aliases))
    message = ("No resources found for {r}, available resource aliases are "
               "(case insensitive):\n{a}").format(r=resource_name, a=listing)
    pretty_print(message, title="Unknown language resources", exits=1,
                 level=PrettyPrintLevel.ERROR)
def set_nlu_logger(level=logging.INFO):
    """Send ``snips_nlu`` log records to stdout at ``level``.

    Sets the level of the ``snips_nlu`` package logger and attaches a stdout
    handler formatted with ``FMT`` / ``DATE_FMT``. A repeated call replaces
    the handler installed by the previous call instead of stacking another
    one, so each record is printed once. Handlers added by other code are
    left untouched.

    Args:
        level (int): Logging level applied to both the logger and the handler.
    """
    # Attribute used to recognise handlers created by this function.
    marker = "_snips_nlu_cli_handler"
    logger = logging.getLogger(snips_nlu.__name__)
    logger.setLevel(level)
    for existing in list(logger.handlers):
        if getattr(existing, marker, False):
            logger.removeHandler(existing)
            existing.close()
    formatter = logging.Formatter(FMT, DATE_FMT)
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    handler.setLevel(level)
    setattr(handler, marker, True)
    logger.addHandler(handler)