Skip to content

Commit 20079bd

Browse files
committed
Starter project files, async db chapter.
1 parent e4e0249 commit 20079bd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1610
-0
lines changed
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
import json
2+
import os
3+
import sys
4+
import time
5+
from typing import List, Optional, Dict
6+
7+
# noinspection PyPackageRequirements
8+
import progressbar
9+
from dateutil.parser import parse
10+
11+
sys.path.insert(0, os.path.abspath(os.path.join(
12+
os.path.dirname(__file__), "..")))
13+
14+
import data.db_session as db_session
15+
from data.package import Package
16+
from data.release import Release
17+
from data.user import User
18+
19+
20+
def main():
    """Entry point: initialize the DB, import data only when it is empty, then summarize."""
    init_db()

    session = db_session.create_session()
    is_empty = session.query(User).count() == 0
    session.close()

    if is_empty:
        raw_records = do_load_files()
        discovered_users = find_users(raw_records)

        imported_users = do_user_import(discovered_users)
        do_import_packages(raw_records, imported_users)

    do_summary()
33+
34+
35+
def do_summary():
    """Print final row counts for users, packages, and releases."""
    session = db_session.create_session()
    try:
        print("Final numbers:")
        print("Users: {:,}".format(session.query(User).count()))
        print("Packages: {:,}".format(session.query(Package).count()))
        print("Releases: {:,}".format(session.query(Release).count()))
    finally:
        # Fix: the original never closed this session (connection leak).
        session.close()
42+
43+
44+
def do_user_import(user_lookup: Dict[str, str]) -> Dict[str, User]:
    """Insert one User row per (email, name) pair and return {email: User}.

    :param user_lookup: mapping of e-mail address -> display name.
    :return: mapping of e-mail address -> persisted User entity.

    Fix: the original created a brand-new session for every single user and
    never closed any of them (nor the final query session) — a connection
    leak proportional to the user count. One session now handles all inserts
    and both sessions are closed deterministically.
    """
    print("Importing users ... ", flush=True)

    session = db_session.create_session()
    session.expire_on_commit = False
    try:
        with progressbar.ProgressBar(max_value=len(user_lookup)) as bar:
            for idx, (email, name) in enumerate(user_lookup.items()):
                user = User()
                user.email = email
                user.name = name
                session.add(user)

                # Per-user commit kept so a crash preserves prior progress.
                session.commit()
                bar.update(idx)
    finally:
        session.close()

    print()
    sys.stderr.flush()
    sys.stdout.flush()

    session = db_session.create_session()
    try:
        # expire_on_commit=False (set in create_session) keeps the loaded
        # attributes readable after the session is closed.
        return {u.email: u for u in session.query(User)}
    finally:
        session.close()
65+
66+
67+
def do_import_packages(file_data: List[dict], user_lookup: Dict[str, User]):
    """Import every package (and its releases), collecting per-package errors.

    :param file_data: parsed JSON dicts, one per package.
    :param user_lookup: mapping of e-mail -> User (passed through to load_package).

    Fix: the original re-raised immediately after recording an error, which
    made the error list and the "Completed packages with N errors" summary
    below unreachable after the first failure. Errors are now recorded and
    the import continues, matching the summary code's clear intent.
    """
    errored_packages = []
    print("Importing packages and releases ... ", flush=True)
    with progressbar.ProgressBar(max_value=len(file_data)) as bar:
        for idx, p in enumerate(file_data):
            try:
                load_package(p, user_lookup)
            except Exception as x:
                errored_packages.append(
                    (p, " *** Errored out for package {}, {}".format(p.get('package_name'), x)))
            # Advance the bar whether or not this package errored.
            bar.update(idx)

    sys.stderr.flush()
    sys.stdout.flush()
    print()
    print("Completed packages with {} errors.".format(len(errored_packages)))
    for (p, txt) in errored_packages:
        print(txt)
84+
85+
86+
def do_load_files() -> List[dict]:
    """Locate every top-100 JSON dump on disk and parse each file into a dict."""
    data_path = os.path.abspath(os.path.join(
        os.path.dirname(__file__), '../../../data/pypi-top-100'))
    print("Loading files from {}".format(data_path))

    files = get_file_names(data_path)
    print("Found {:,} files, loading ...".format(len(files)), flush=True)
    # Brief pause so the progress bar doesn't interleave with the prints above.
    time.sleep(.1)

    with progressbar.ProgressBar(max_value=len(files)) as bar:
        file_data = []
        for position, name in enumerate(files):
            file_data.append(load_file_data(name))
            bar.update(position)

    sys.stderr.flush()
    sys.stdout.flush()
    print()
    return file_data
103+
104+
105+
def find_users(data: List[dict]) -> dict:
    """Collect {email: name} for every author and maintainer in the dumps."""
    print("Discovering users...", flush=True)
    found_users = {}

    with progressbar.ProgressBar(max_value=len(data)) as bar:
        for position, package in enumerate(data):
            info = package.get('info')
            # Authors and maintainers contribute users the same way.
            for name_key, email_key in (('author', 'author_email'),
                                        ('maintainer', 'maintainer_email')):
                found_users.update(
                    get_email_and_name_from_text(info.get(name_key), info.get(email_key)))
            bar.update(position)

    sys.stderr.flush()
    sys.stdout.flush()
    print()
    print("Discovered {:,} users".format(len(found_users)))
    print()

    return found_users
123+
124+
125+
def get_email_and_name_from_text(name: str, email: str) -> dict:
    """Parse possibly comma-separated name/email strings into {email: name}.

    Both inputs may contain several comma-separated values; they are paired
    positionally (extras on either side are dropped by zip).

    :param name: one display name, or several separated by commas.
    :param email: one e-mail address, or several separated by commas.
    :return: {stripped_email: stripped_name}; empty dict if either input is falsy.

    Fix: the original tested ``len(email) > 1`` (the *string* length) where it
    meant the number of addresses. It only worked because nearly every e-mail
    string is longer than one character; with a one-character email the name
    string was zipped character-by-character. Names are now always split.
    """
    data = {}

    if not name or not email:
        return data

    emails = email.strip().lower().split(',')
    names = name.strip().split(',')

    for n, e in zip(names, emails):
        if not n or not e:
            continue
        data[e.strip()] = n.strip()

    return data
143+
144+
145+
def load_file_data(filename: str) -> dict:
    """Parse one JSON dump file; log which file failed and re-raise on error."""
    try:
        with open(filename, 'r', encoding='utf-8') as fin:
            return json.load(fin)
    except Exception as x:
        print("ERROR in file: {}, details: {}".format(filename, x), flush=True)
        raise
154+
155+
156+
def load_package(data: dict, user_lookup: Dict[str, User]):
    """Insert one package and its releases into the DB.

    :param data: parsed JSON dump for one package.
    :param user_lookup: e-mail -> User map. NOTE(review): currently unused —
        the original computed ``maintainers_lookup`` but never mapped it to
        Users, so ``maintainers`` always stayed empty. Kept as-is to preserve
        behavior; left a TODO below.
    """
    try:
        info = data.get('info', {})

        p = Package()
        p.id = data.get('package_name', '').strip()
        if not p.id:
            # No usable package name: nothing to persist.
            return

        releases = build_releases(p.id, data.get("releases", {}))
        if releases:
            p.created_date = releases[0].created_date

        maintainers_lookup = get_email_and_name_from_text(
            info.get('maintainer'), info.get('maintainer_email'))
        maintainers = []  # TODO: map maintainers_lookup e-mails onto user_lookup.

        p.summary = info.get('summary')
        p.description = info.get('description')

        p.home_page = info.get('home_page')
        p.docs_url = info.get('docs_url')
        p.package_url = info.get('package_url')

        # Fix: the Package model's columns are author_name/author_email; the
        # original assigned a non-column attribute ``p.author`` (twice), so
        # the author name was silently never persisted.
        p.author_name = info.get('author')
        p.author_email = info.get('author_email')
        p.license = detect_license(info.get('license'))

        session = db_session.create_session()
        try:
            session.add(p)
            session.add_all(releases)
            if maintainers:
                session.add_all(maintainers)
            session.commit()
        finally:
            # Fix: close even when commit raises (the original leaked here).
            session.close()
    except OverflowError:
        # Some dumps contain absurd fake sizes (terabytes...); skip those.
        pass
200+
201+
202+
def detect_license(license_text: str) -> Optional[str]:
    """Normalize a raw license string to a short label.

    Returns None for missing input, "CUSTOM" when the text looks like a full
    license body (long or multi-line), otherwise a name with boilerplate
    words and trove-classifier prefixes stripped.
    """
    if not license_text:
        return None

    text = license_text.strip()

    # Long or multi-line values are full license texts, not names.
    if '\n' in text or len(text) > 100:
        return "CUSTOM"

    text = text.replace('Software License', '').replace('License', '')

    if '::' in text:
        # E.g. 'License :: OSI Approved :: Apache Software License'
        text = text.split(':')[-1].replace('  ', ' ')

    return text.strip()
223+
224+
225+
def build_releases(package_id: str, releases: dict) -> List[Release]:
    """Create one Release row per version, using that version's last upload.

    :param package_id: the owning package's primary key.
    :param releases: {version_string: [upload dicts]} straight from the dump.
    :return: list of unsaved Release entities (versions with no uploads skipped).
    """
    db_releases = []
    for version_text, uploads in releases.items():
        if not uploads:
            continue

        newest = uploads[-1]

        release = Release()
        release.package_id = package_id
        release.major_ver, release.minor_ver, release.build_ver = make_version_num(version_text)
        release.created_date = parse(newest.get('upload_time'))
        release.comment = newest.get('comment_text')
        release.url = newest.get('url')
        release.size = int(newest.get('size', 0))

        db_releases.append(release)

    return db_releases
245+
246+
247+
def make_version_num(version_text):
    """Split a version string like '1.2.3' into an (major, minor, build) int tuple.

    Anything after a 'b' (beta markers such as '1.2b3') is dropped before
    parsing; non-numeric components become 0; a missing/empty string or a
    string with more than three dotted parts yields (0, 0, 0), exactly as
    the original elif chain did.
    """
    def _as_int(text):
        # Inlined equivalent of try_int(); keeps this function self-contained.
        try:
            return int(text)
        except (ValueError, TypeError):
            return 0

    if not version_text:
        return 0, 0, 0

    parts = version_text.split('b')[0].split('.')
    if len(parts) > 3:
        return 0, 0, 0

    padded = [_as_int(part) for part in parts] + [0, 0]
    return padded[0], padded[1], padded[2]
263+
264+
265+
def try_int(text) -> int:
    """Convert *text* to int, returning 0 for anything unparsable.

    Fix: narrowed the bare ``except:`` to (TypeError, ValueError) — a bare
    clause also swallowed KeyboardInterrupt and SystemExit.
    """
    try:
        return int(text)
    except (TypeError, ValueError):
        return 0
270+
271+
272+
def init_db():
    """Point the session factory at ../db/pypi.sqlite, relative to this script."""
    here = os.path.dirname(__file__)
    db_file = os.path.abspath(os.path.join(here, '..', 'db', 'pypi.sqlite'))
    db_session.global_init(db_file)
277+
278+
279+
def get_file_names(data_path: str) -> List[str]:
    """Return the sorted absolute paths of every .json file in *data_path*."""
    return sorted(
        os.path.abspath(os.path.join(data_path, entry))
        for entry in os.listdir(data_path)
        if entry.endswith('.json')
    )
289+
290+
291+
if __name__ == '__main__':
292+
main()
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# noinspection PyUnresolvedReferences
2+
from data.package import Package
3+
# noinspection PyUnresolvedReferences
4+
from data.user import User
5+
# noinspection PyUnresolvedReferences
6+
from data.release import Release
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from pathlib import Path
2+
from typing import Callable, Optional
3+
4+
import sqlalchemy as sa
5+
import sqlalchemy.orm as orm
6+
from sqlalchemy.orm import Session
7+
8+
from data.modelbase import SqlAlchemyBase
9+
10+
__factory: Optional[Callable[[], Session]] = None
11+
12+
13+
def global_init(db_file: str):
    """Create the engine and session factory exactly once; later calls no-op.

    :param db_file: path to the sqlite file; its parent folder is created
        if missing, and all registered model tables are created.
    :raises Exception: when db_file is missing or blank.
    """
    global __factory

    if __factory:
        return

    if not db_file or not db_file.strip():
        raise Exception("You must specify a db file.")

    Path(db_file).parent.mkdir(parents=True, exist_ok=True)

    conn_str = 'sqlite:///' + db_file.strip()
    print("Connecting to DB with {}".format(conn_str))

    # Adding check_same_thread = False after the recording. This can be an issue about
    # creating / owner thread when cleaning up sessions, etc. This is a sqlite restriction
    # that we probably don't care about in this example.
    engine = sa.create_engine(conn_str, echo=False,
                              connect_args={"check_same_thread": False})
    __factory = orm.sessionmaker(bind=engine)

    # Importing the models module registers every model class on
    # SqlAlchemyBase before create_all runs.
    # noinspection PyUnresolvedReferences
    import data.__all_models

    SqlAlchemyBase.metadata.create_all(engine)
38+
39+
40+
def create_session() -> Session:
    """Return a new Session from the factory prepared by global_init().

    :raises Exception: when global_init() has not been called yet.
    """
    # Note: no ``global`` statement needed — __factory is only read here.
    if not __factory:
        raise Exception("You must call global_init() before using this method.")

    session: Session = __factory()
    # Keep loaded attributes readable after commit/close.
    session.expire_on_commit = False

    return session
50+
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import sqlalchemy.ext.declarative
2+
3+
SqlAlchemyBase = sqlalchemy.ext.declarative.declarative_base()
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import datetime
2+
from typing import List
3+
4+
import sqlalchemy as sa
5+
import sqlalchemy.orm as orm
6+
from data.modelbase import SqlAlchemyBase
7+
from data.release import Release
8+
9+
10+
class Package(SqlAlchemyBase):
    """ORM model for one PyPI package (table 'packages')."""
    __tablename__ = 'packages'

    # The package name itself is the primary key (e.g. 'requests').
    id: str = sa.Column(sa.String, primary_key=True)
    created_date: datetime.datetime = sa.Column(sa.DateTime, default=datetime.datetime.now, index=True)
    last_updated: datetime.datetime = sa.Column(sa.DateTime, default=datetime.datetime.now, index=True)
    summary: str = sa.Column(sa.String, nullable=False)
    description: str = sa.Column(sa.String, nullable=True)

    home_page: str = sa.Column(sa.String)
    docs_url: str = sa.Column(sa.String)
    package_url: str = sa.Column(sa.String)

    author_name: str = sa.Column(sa.String)
    author_email: str = sa.Column(sa.String, index=True)

    license: str = sa.Column(sa.String, index=True)

    # releases relationship: ordered newest-first by (major, minor, build).
    releases: List[Release] = orm.relation("Release", order_by=[
        Release.major_ver.desc(),
        Release.minor_ver.desc(),
        Release.build_ver.desc(),
    ], back_populates='package')

    def __repr__(self):
        return '<Package {}>'.format(self.id)


# Usage sketch (kept from the original):
# p = Package() # one query
#
# print(p.id)
# print("All releases")
# for r in p.releases:
#     print("{}.{}.{}".format(r.major_ver, r.minor_ver, r.build_ver))
37+
38+
39+
# p = Package() # one query
40+
#
41+
# print(p.id)
42+
# print("All releases")
43+
# for r in p.releases:
44+
# print("{}.{}.{}".format(r.major_ver, r.minor_ver, r.build_ver))

0 commit comments

Comments
 (0)