# req_files.py
import os
import re
import click
from .exceptions import ReqFileNotFound, ReqFileNotReadable, ReqFileNotWritable
# A requirements line consisting only of a package name (no version specifier),
# e.g. "requests" or "zope.interface". PEP 508 names may contain letters,
# digits, "-", "_" and "."; surrounding whitespace (including a trailing
# newline) is tolerated so raw file lines can be matched directly.
UNPINNED_RE = re.compile(r"^\s*[0-9a-zA-Z_.\-]+\s*$")
class ReqFile(object):
    """
    Class to manage a requirements file.

    On construction the file is located, either from the explicit ``path`` or
    as the first ``file_name`` found under the current working directory. When
    ``auto_read`` is true, the file is then parsed.

    Attributes:
        file_name: Base name searched for when no explicit path is given.
        file_path: Path of the requirements file, or None if none was found.
        exists: True when ``file_path`` is not None (not a disk check).
        lines: Stripped text of every line of the file, in order.
        packages: Maps the package text before ``==`` to its pinned version.
    """

    file_path = None

    # Leading distribution name of a requirement line. PEP 508 names start with
    # a letter or digit, so comments ("#") and options ("-e", "-r") never match.
    # Example: "requests" in "requests[security]==2.0".
    _NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")

    def __init__(self, path=None, file_name="requirements.txt", auto_read=True):
        """
        Args:
            path: Explicit path to the requirements file. When None, the
                current working directory tree is searched for ``file_name``.
            file_name: File name to search for when ``path`` is None.
            auto_read: Parse the file immediately.

        Raises:
            ReqFileNotFound, ReqFileNotReadable, ReqFileNotWritable: Raised by
                ``read`` when ``auto_read`` is true and the file is unusable.
        """
        self.file_name = file_name
        self.exists = False
        if path is None:
            self.file_path = self.find_requirements_file()
        else:
            self.file_path = path
        if self.file_path is not None:
            self.exists = True
        # Store requirements lines
        self.lines = []
        self.packages = {}
        if auto_read:
            self.read(self.file_path)

    def find_requirements_file(self):
        """
        Return the path of the first file named ``file_name`` under the current
        working directory, or None if there is none.

        The current directory is checked first. Subdirectories are visited in
        sorted order, so the result does not depend on filesystem listing order.
        """
        for dirname, subdirs, files in os.walk(os.getcwd()):
            # Sorting in place steers the order in which os.walk descends
            subdirs.sort()
            if self.file_name in files:
                return os.path.join(dirname, self.file_name)
        return None

    def read(self, path):
        """
        Read in the requirements file at *path*, replacing any parsed state.

        Does nothing when no requirements file was located.

        Raises:
            ReqFileNotFound: *path* does not exist.
            ReqFileNotReadable: *path* cannot be read.
            ReqFileNotWritable: *path* cannot be written (presumably required
                up front because ``save`` rewrites the file).
        """
        # File doesn't exist, so just move along
        if not self.exists:
            return
        if not os.path.exists(path):
            raise ReqFileNotFound("{} not found".format(path))
        if not os.access(path, os.R_OK):
            raise ReqFileNotReadable("{} not readable".format(path))
        if not os.access(path, os.W_OK):
            raise ReqFileNotWritable("{} not writeable".format(path))
        # Clear out existing state so a re-read never keeps stale entries
        self.lines = []
        self.packages = {}
        with open(path) as f:
            # Number lines from 1 so warnings match what editors display
            for line_number, line in enumerate(f, 1):
                self.parse_line(line, line_number)

    def parse_line(self, line, line_number):
        """
        Record *line* and, if it pins a version with ``==``, its package.

        Prints a red warning when the line is a bare, unpinned package name.

        Args:
            line: Raw line from the file (may include the trailing newline).
            line_number: Line number reported in the warning.
        """
        text = line.strip()
        # Save the stripped line so save() can rewrite the file line by line
        self.lines.append(text)
        if "==" in text and not text.startswith("#"):
            # maxsplit=1 so a malformed "a==1==2" line cannot raise ValueError
            package, version = text.split("==", 1)
            # Drop any inline "# comment" so it does not leak into the version
            self.packages[package.strip()] = version.split("#", 1)[0].strip()
        if UNPINNED_RE.match(text):
            click.secho(
                "WARNING: Found unpinned package '{}' at line {}.".format(
                    text, line_number
                ),
                fg="red",
            )

    @classmethod
    def _package_name(cls, requirement):
        """Return the PEP 503-normalized package name of *requirement*, or None."""
        match = cls._NAME_RE.match(requirement)
        if match is None:
            return None
        return re.sub(r"[-_.]+", "-", match.group(1)).lower()

    @classmethod
    def _merge_lines(cls, existing, updates):
        """
        Return a copy of *existing* with each requirement in *updates* applied.

        Updates are applied in order. An update replaces every pinned (``==``)
        line naming the same package; if no such line exists, it is appended.
        Lines without ``==`` (blanks, comments, unpinned requirements) and
        pinned lines for other packages are kept unchanged.
        """
        merged = [line.strip() for line in existing]
        for requirement in updates:
            requirement = requirement.strip()
            name = cls._package_name(requirement)
            found = False
            for index, line in enumerate(merged):
                # Skip lines we can't handle
                if "==" not in line:
                    continue
                # Exact (normalized) name match: "flask" must not match "flask-cors"
                if name is not None and cls._package_name(line) == name:
                    merged[index] = requirement
                    found = True
            if not found:
                merged.append(requirement)
        return merged

    def save(self, lines):
        """
        Apply requirement lines to the requirements file and rewrite it.

        The file is re-read first. Each entry of *lines* replaces every pinned
        line for the same package, comparing PEP 503-normalized names. Entries
        for packages that are not already pinned are appended. All other lines
        are kept.

        Returns:
            True if the file was rewritten, False if *lines* is empty or there
            is no requirements file to update.
        """
        # Don't do anything if there isn't anything to do
        if not lines:
            return False
        # Don't do anything if there isn't a file to update
        if not self.exists:
            return False
        # Always re-read in case something has changed
        self.read(self.file_path)
        new_lines = self._merge_lines(self.lines, lines)
        with open(self.file_path, "w") as f:
            f.writelines("{}\n".format(line) for line in new_lines)
        return True