Skip to content

Commit

Permalink
Merge aaf2064 into 51e870e
Browse files Browse the repository at this point in the history
  • Loading branch information
nestormh committed Apr 6, 2021
2 parents 51e870e + aaf2064 commit f12349a
Showing 1 changed file with 162 additions and 34 deletions.
196 changes: 162 additions & 34 deletions medikit/feature/python.py
Expand Up @@ -33,14 +33,17 @@
import os
import tempfile
from getpass import getuser
from types import SimpleNamespace

from pip._vendor.distlib.util import parse_requirement
from pip._internal.req.req_install import InstallRequirement
from piptools._compat import parse_requirements
from piptools.cache import DependencyCache
from piptools.locations import CACHE_DIR
from piptools.repositories import PyPIRepository
from piptools.resolver import Resolver
from piptools.utils import format_requirement
from piptools.exceptions import IncompatibleRequirements

import medikit
from medikit.events import subscribe
Expand All @@ -52,11 +55,27 @@


def _normalize_requirement(req):
bits = req.requirement.split()
""" Normalizes the requirement string. It considers the case of having or not an URL """

if req.constraints and not req.url:
bits = req.requirement.split()
else:
bits = [req.requirement, '@', req.url]

if req.extras:
bits = [bits[0] + "[{}]".format(",".join(req.extras))] + bits[1:]

return " ".join(bits)

def _get_valid_link_req(req):
""" Formats the repo based dependencies, which can be dirty in case of collision between repo based and package
based declarations (for similar packages) """

tokens = str(req).split("@")
req_name = parse_requirement(tokens[0])

return req_name.name + "@ " + "@".join(tokens[1:])


class PythonConfig(Feature.Config):
""" Configuration API for the «python» feature. """
Expand All @@ -73,6 +92,10 @@ def __init__(self):
self._create_packages = True
self.override_requirements = False
self.use_wheelhouse = False
# Use the same requirement versions among all the extras, when requirements coincide.
self.use_uniform_requirements = False
# Print the information of the "parent" requirement in the requirements*.txt files.
self.show_comes_from_info = False

@property
def package_dir(self):
Expand Down Expand Up @@ -217,6 +240,66 @@ def __add_vendors(self, reqs, extra=None):
for req in reqs:
self._vendors[extra].append(req)

def get_requirement_info_by_name(self, req, requirements_by_name=dict()):
""" Given a requirement, it provides its valid information to be included in the final file """

if self.use_uniform_requirements:
# If it is a repo and not a package, this will be True
if req.link:
requirement_name = parse_requirement(_get_valid_link_req(req.req)).name
else:
requirement_name = parse_requirement(str(req.req)).name

# If the requirement is not in the dict, it is because it was not needed as a dependency in the original
# set containing all requirements
if requirement_name in requirements_by_name.keys():
# In case we want to show the source for inherited dependencies
if self.show_comes_from_info and type(req.comes_from) == InstallRequirement:
return "{}\t\t\t# From: {}".format(requirements_by_name[requirement_name].requirement,
str(req.comes_from.req))
else:
return requirements_by_name[requirement_name].requirement

else:
return None
else:
# If not using uniform versions, we just need to provide the information based on wether it is
# a repository or a package
if self.show_comes_from_info and type(req.comes_from) == InstallRequirement:
return "{}\t\t\t# From: {}".format(format_requirement(req) if not req.link else str(req.req),
str(req.comes_from.req))
else:
return format_requirement(req) if not req.link else str(req.req)

def _check_duplicate_dependencies_by_extra(self, extra, requirements_by_name):
""" Checks there are not duplicate dependencies, in the case of private repositories"""

for name, req in sorted(self._requirements[extra].items()):
requirement_str = req.requirement if not req.url else req.url.strip().replace(" ", "")

if name not in requirements_by_name.keys():
# This can happen if additional requirements are included from the outside, for instance with pytest
continue

if (req.url or requirements_by_name[name].url) and requirements_by_name[name].requirement != requirement_str:
raise IncompatibleRequirements(req, requirements_by_name[name].url)

def check_duplicate_dependencies_uniform(self, requirements_by_name):
""" Checks there are not duplicate dependencies, when use_uniform_requirements==True """
for extra in itertools.chain((None,), self.get_extras()):
self._check_duplicate_dependencies_by_extra(extra, requirements_by_name)

def check_duplicate_dependencies_nonuniform(self, extra, resolver):
""" Checks there are not duplicate dependencies, when use_uniform_requirements==False """
requirements_by_name = {}
for req in resolver.resolve(max_rounds=10):
requirements_by_name[parse_requirement(str(req.req)).name] = SimpleNamespace(
requirement=format_requirement(req).strip().replace(" ", ""),
url=req.link
)

self._check_duplicate_dependencies_by_extra(extra, requirements_by_name)


class PythonFeature(Feature):
"""
Expand Down Expand Up @@ -414,48 +497,59 @@ def on_start(self, event):
version = "0.0.0"
self.render_file_inline(python_config.version_file, "__version__ = '{}'".format(version))

setup = python_config.get_setup()

context = {
"url": setup.pop("url", "http://example.com/"),
"download_url": setup.pop("download_url", "http://example.com/"),
}

for k, v in context.items():
context[k] = context[k].format(name=setup["name"], user=getuser(), version="{version}")

context.update(
{
"entry_points": setup.pop("entry_points", {}),
"extras_require": python_config.get("extras_require"),
"install_requires": python_config.get("install_requires"),
"python": python_config,
"setup": setup,
"banner": get_override_warning_banner(),
}
)

# Render (with overwriting) the allmighty setup.py
self.render_file("setup.py", "python/setup.py.j2", context, override=True)

@subscribe(medikit.on_end, priority=ABSOLUTE_PRIORITY)
def on_end(self, event):

# Our config object
python_config = event.config["python"]

# Pip / PyPI
repository = PyPIRepository([], cache_dir=CACHE_DIR)

# We just need to construct this structure if use_uniform_requirements == True
requirements_by_name = {}

if python_config.use_uniform_requirements:
tmpfile = tempfile.NamedTemporaryFile(mode="wt", delete=False)
for extra in itertools.chain((None,), python_config.get_extras()):
tmpfile.write("\n".join(python_config.get_requirements(extra=extra)) + "\n")
tmpfile.flush()

constraints = list(
parse_requirements(
tmpfile.name, finder=repository.finder, session=repository.session, options=repository.options
)
)

# This resolver is able to evaluate ALL the dependencies along the extras
resolver = Resolver(
constraints,
repository,
cache=DependencyCache(CACHE_DIR),
# cache=DependencyCache(tempfile.tempdir),
prereleases=False,
clear_caches=False,
allow_unsafe=False,
)

for req in resolver.resolve(max_rounds=10):
requirements_by_name[parse_requirement(str(req.req)).name] = SimpleNamespace(
requirement=format_requirement(req).strip().replace(" ", ""),
url=req.link
)

python_config.check_duplicate_dependencies_uniform(requirements_by_name)

# Now it iterates along the versions in extras and looks for the requirements and its dependencies, using the
# structure created above to select the unified versions (unless the flag indicates otherwise).
for extra in itertools.chain((None,), python_config.get_extras()):
requirements_file = "requirements{}.txt".format("-" + extra if extra else "")

if python_config.override_requirements or not os.path.exists(requirements_file):
tmpfile = tempfile.NamedTemporaryFile(mode="wt", delete=False)
if extra:
tmpfile.write("\n".join(python_config.get_requirements(extra=extra)))
else:
tmpfile.write("\n".join(python_config.get_requirements()))
tmpfile.write("\n".join(python_config.get_requirements(extra=extra)) + "\n")
tmpfile.flush()

constraints = list(
parse_requirements(
tmpfile.name, finder=repository.finder, session=repository.session, options=repository.options
Expand All @@ -470,19 +564,53 @@ def on_end(self, event):
allow_unsafe=False,
)

if not python_config.use_uniform_requirements:
python_config.check_duplicate_dependencies_nonuniform(extra, resolver)

requirements_list = []
for req in resolver.resolve(max_rounds=10):
if req.name != python_config.get("name"):
requirement = python_config.get_requirement_info_by_name(req, requirements_by_name)
if requirement:
requirements_list.append(requirement)

self.render_file_inline(
requirements_file,
"\n".join(
(
"-e .{}".format("[" + extra + "]" if extra else ""),
*(("-r requirements.txt",) if extra else ()),
*python_config.get_vendors(extra=extra),
*sorted(
format_requirement(req)
for req in resolver.resolve(max_rounds=10)
if req.name != python_config.get("name")
),
*sorted(requirements_list),
)
),
override=python_config.override_requirements,
)

# Updates setup file
setup = python_config.get_setup()

context = {
"url": setup.pop("url", ""),
"download_url": setup.pop("download_url", ""),
}

for k, v in context.items():
context[k] = context[k].format(name=setup["name"], user=getuser(), version="{version}")

context.update(
{
"entry_points": setup.pop("entry_points", {}),
"extras_require": python_config.get("extras_require"),
"install_requires": python_config.get("install_requires"),
"python": python_config,
"setup": setup,
"banner": get_override_warning_banner(),
}
)

from pprint import pprint
pprint(context)

# Render (with overwriting) the allmighty setup.py
self.render_file("setup.py", "python/setup.py.j2", context, override=True)

0 comments on commit f12349a

Please sign in to comment.