-
Notifications
You must be signed in to change notification settings - Fork 141
/
safety.py
178 lines (150 loc) · 5.57 KB
/
safety.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# -*- coding: utf-8 -*-
import errno
import json
import os
import time
from collections import namedtuple
import requests
from packaging.specifiers import SpecifierSet
from .constants import (API_MIRRORS, CACHE_FILE, CACHE_VALID_SECONDS,
OPEN_MIRRORS, REQUEST_TIMEOUT)
from .errors import (DatabaseFetchError, DatabaseFileNotFoundError,
InvalidKeyError)
from .util import RequirementFile
class Vulnerability(namedtuple("Vulnerability",
                               "name spec version advisory vuln_id")):
    """One vulnerability finding for an installed package.

    Fields, in tuple order:
      name     -- normalised package name
      spec     -- vulnerable version specifier string that matched
      version  -- installed version of the package
      advisory -- advisory text from the full database
      vuln_id  -- vulnerability identifier
    """
def get_from_cache(db_name, cache_file=None, max_age=None):
    """Return the cached copy of ``db_name`` if it is still fresh.

    The cache file holds a JSON object keyed by database name; each value has
    a ``cached_at`` timestamp (seconds, as produced by ``time.time()``) and
    the database payload under ``db``.

    :param db_name: database file name used as the cache key,
        e.g. "insecure.json".
    :param cache_file: path of the cache file; defaults to ``CACHE_FILE``.
    :param max_age: maximum entry age in seconds; defaults to
        ``CACHE_VALID_SECONDS``.
    :return: the cached database, or ``False`` when there is no usable,
        fresh entry.
    """
    if cache_file is None:
        cache_file = CACHE_FILE
    if max_age is None:
        max_age = CACHE_VALID_SECONDS

    try:
        with open(cache_file) as f:
            cache = json.loads(f.read())
    except (OSError, ValueError):
        # Missing, unreadable or corrupt (non-JSON / badly encoded) cache:
        # the cache is best-effort, so report a miss and let the caller fetch.
        return False

    try:
        entry = cache[db_name]
        if entry["cached_at"] + max_age > time.time():
            return entry["db"]
    except (KeyError, TypeError):
        # Unexpected layout (absent key, non-object entry, missing "db",
        # non-numeric timestamp) is treated as a miss rather than a crash.
        return False
    return False
def write_to_cache(db_name, data, cache_file=None):
    """Store ``data`` in the on-disk cache under the key ``db_name``.

    The cache file holds a JSON object of the form::

        {
            "insecure.json": {"cached_at": 12345678, "db": {}},
            "insecure_full.json": {"cached_at": 12345678, "db": {}}
        }

    Existing entries for other databases are preserved; the entry for
    ``db_name`` is replaced and stamped with the current ``time.time()``.

    :param db_name: database file name used as the cache key.
    :param data: JSON-serialisable database payload.
    :param cache_file: path of the cache file; defaults to ``CACHE_FILE``.
    """
    if cache_file is None:
        cache_file = CACHE_FILE

    cache_dir = os.path.dirname(cache_file)
    if cache_dir:
        # exist_ok also tolerates another process creating it concurrently.
        os.makedirs(cache_dir, exist_ok=True)

    try:
        with open(cache_file) as f:
            cache = json.loads(f.read())
    except FileNotFoundError:
        # The directory can exist without the file (deleted file, or another
        # process created the directory first); start with an empty cache.
        cache = {}
    except ValueError:
        # Corrupt / non-JSON content (json.JSONDecodeError is a ValueError).
        cache = {}
    if not isinstance(cache, dict):
        # Valid JSON but not an object: cannot be keyed by db_name.
        cache = {}

    cache[db_name] = {
        "cached_at": time.time(),
        "db": data,
    }
    # Serialise before opening for write so an error cannot leave the
    # existing cache file truncated.
    payload = json.dumps(cache)
    with open(cache_file, "w") as f:
        f.write(payload)
def fetch_database_url(mirror, db_name, key, cached, proxy):
    """Download ``db_name`` from an HTTP(S) mirror, optionally via the cache.

    :param mirror: base URL; ``db_name`` is appended verbatim, so this
        presumably ends with "/" -- TODO confirm against the mirror constants.
    :param db_name: database file name to request.
    :param key: API key sent as the ``X-Api-Key`` header when truthy.
    :param cached: when true, return a fresh cache entry if present and
        store a successful download in the cache.
    :param proxy: proxies mapping handed to ``requests.get``.
    :return: the decoded JSON on HTTP 200; ``None`` for any other status
        except 403, so that ``fetch_database`` moves on to the next mirror.
    :raises InvalidKeyError: on HTTP 403.
    """
    if cached:
        hit = get_from_cache(db_name=db_name)
        if hit:
            return hit

    request_headers = {"X-Api-Key": key} if key else {}
    response = requests.get(
        url=mirror + db_name,
        timeout=REQUEST_TIMEOUT,
        headers=request_headers,
        proxies=proxy,
    )

    status = response.status_code
    if status == 403:
        raise InvalidKeyError()
    if status != 200:
        return None

    payload = response.json()
    if cached:
        write_to_cache(db_name, payload)
    return payload
def fetch_database_file(path, db_name):
    """Load ``db_name`` as JSON from the local directory ``path``.

    :raises DatabaseFileNotFoundError: if ``path/db_name`` does not exist.
    """
    db_path = os.path.join(path, db_name)
    if os.path.exists(db_path):
        with open(db_path) as handle:
            return json.load(handle)
    raise DatabaseFileNotFoundError()
def fetch_database(full=False, key=False, db=False, cached=False, proxy=None):
    """Return the vulnerability database from the first mirror that has it.

    :param full: fetch "insecure_full.json" instead of "insecure.json".
    :param key: API key; when truthy, ``API_MIRRORS`` are tried instead of
        ``OPEN_MIRRORS`` and the key is forwarded to URL mirrors.
    :param db: a single mirror (http(s) URL or local directory) that replaces
        the default mirror list.
    :param cached: forwarded to URL mirrors to read/refresh the disk cache.
    :param proxy: proxies mapping forwarded to URL mirrors; ``None`` (the
        default) is treated as an empty mapping.
    :return: the first truthy database payload obtained.
    :raises DatabaseFetchError: if no mirror produced data.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if proxy is None:
        proxy = {}

    if db:
        mirrors = [db]
    else:
        mirrors = API_MIRRORS if key else OPEN_MIRRORS
    db_name = "insecure_full.json" if full else "insecure.json"

    for mirror in mirrors:
        # A mirror is either an http(s) URL or a path to a local directory.
        if mirror.startswith(("http://", "https://")):
            data = fetch_database_url(mirror, db_name=db_name, key=key,
                                      cached=cached, proxy=proxy)
        else:
            data = fetch_database_file(mirror, db_name=db_name)
        # Falsy results (e.g. a non-200 response) fall through to the next mirror.
        if data:
            return data
    raise DatabaseFetchError()
def get_vulnerabilities(pkg, spec, db):
    """Lazily yield the entries of ``db[pkg]`` whose ``specs`` contain ``spec``.

    An entry is yielded once per matching item in its ``specs`` list.
    ``KeyError`` is raised on first iteration if ``pkg`` is not in ``db``.
    """
    yield from (
        record
        for record in db[pkg]
        for candidate in record["specs"]
        if candidate == spec
    )
def check(packages, key, db_mirror, cached, ignore_ids, proxy):
    """Return the known vulnerabilities affecting ``packages``.

    The short database maps package names to vulnerable specifier strings.
    The full database supplies the advisory entries for those specifiers.
    It is only downloaded once some package version actually matches.

    :param packages: requirement objects exposing ``.key`` and ``.version``;
        ``RequirementFile`` items (unresolved recursive files) are skipped.
    :param key: API key; falls back to the ``SAFETY_API_KEY`` environment
        variable when falsy.
    :param db_mirror: optional mirror override forwarded to ``fetch_database``.
    :param cached: forwarded to ``fetch_database``.
    :param ignore_ids: vulnerability ids (without the "pyup.io-" prefix) to
        leave out of the result; ``None`` is treated as empty.
    :param proxy: proxies mapping forwarded to ``fetch_database``.
    :return: list of ``Vulnerability`` tuples.
    """
    key = key if key else os.environ.get("SAFETY_API_KEY", False)
    ignore_ids = ignore_ids or ()
    db = fetch_database(key=key, db=db_mirror, cached=cached, proxy=proxy)
    db_full = None
    vulnerable = []
    for pkg in packages:
        # Ignore recursive files not resolved.
        if isinstance(pkg, RequirementFile):
            continue
        # Normalise the package name: the safety-db converts underscores to
        # dashes and uses lowercase.
        name = pkg.key.replace("_", "-").lower()
        if name not in db:
            continue
        for specifier in db[name]:
            spec_set = SpecifierSet(specifiers=specifier)
            if not spec_set.contains(pkg.version):
                continue
            # Download the full database lazily, at most once.
            if db_full is None:
                db_full = fetch_database(full=True, key=key, db=db_mirror,
                                         cached=cached, proxy=proxy)
            for data in get_vulnerabilities(pkg=name, spec=specifier, db=db_full):
                # An entry without an "id" yields "" and is skipped by the
                # guard below instead of crashing on None.replace().
                vuln_id = (data.get("id") or "").replace("pyup.io-", "")
                if vuln_id and vuln_id not in ignore_ids:
                    vulnerable.append(
                        Vulnerability(
                            name=name,
                            spec=specifier,
                            version=pkg.version,
                            advisory=data.get("advisory"),
                            vuln_id=vuln_id,
                        )
                    )
    return vulnerable
def review(vulnerabilities):
    """Rebuild ``Vulnerability`` tuples from plain 5-item sequences.

    Each item is indexed positionally as
    (name, spec, version, advisory, vuln_id); extra items are ignored.
    """
    return [
        Vulnerability(
            name=item[0],
            spec=item[1],
            version=item[2],
            advisory=item[3],
            vuln_id=item[4],
        )
        for item in vulnerabilities
    ]