Skip to content

Commit

Permalink
Support multiple CUDA versions (#33)
Browse files Browse the repository at this point in the history
* Support multiple CUDA versions

* cleanup

* exclude specific windows combinations from tests

* fix cpu backend ordering

* prioritize backend over version

* fix docstring
  • Loading branch information
pmeier committed Jun 29, 2021
1 parent eff3ea9 commit 73e8fc8
Show file tree
Hide file tree
Showing 11 changed files with 524 additions and 415 deletions.
113 changes: 74 additions & 39 deletions light_the_torch/_pip/find.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
import re
from typing import Any, Iterable, List, NoReturn, Optional, Text, Tuple, Union, cast
from typing import (
Any,
Collection,
Iterable,
List,
NoReturn,
Optional,
Set,
Text,
Tuple,
Union,
cast,
)

from pip._internal.index.collector import LinkCollector
from pip._internal.index.package_finder import (
Expand All @@ -13,8 +25,10 @@
from pip._internal.models.search_scope import SearchScope
from pip._internal.req.req_install import InstallRequirement
from pip._internal.req.req_set import RequirementSet
from pip._vendor.packaging.version import Version

import light_the_torch.computation_backend as cb

from ..computation_backend import ComputationBackend, detect_computation_backend
from .common import (
InternalLTTError,
PatchedInstallCommand,
Expand All @@ -29,7 +43,9 @@

def find_links(
pip_install_args: List[str],
computation_backend: Optional[Union[str, ComputationBackend]] = None,
computation_backends: Optional[
Union[cb.ComputationBackend, Collection[cb.ComputationBackend]]
] = None,
channel: str = "stable",
platform: Optional[str] = None,
python_version: Optional[str] = None,
Expand All @@ -41,9 +57,9 @@ def find_links(
Args:
pip_install_args: Arguments passed to ``pip install`` that will be searched for
required PyTorch distributions
computation_backend: Computation backend, for example ``"cpu"`` or ``"cu102"``.
Defaults to the available hardware of the running system preferring CUDA
over CPU.
computation_backends: Collection of supported computation backends, for example
``"cpu"`` or ``"cu102"``. Defaults to the available hardware of the running
system.
channel: Channel of the PyTorch wheels. Can be one of ``"stable"`` (default),
``"test"``, and ``"nightly"``.
platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
Expand All @@ -55,10 +71,12 @@ def find_links(
Returns:
Wheel links with given properties for all required PyTorch distributions.
"""
if computation_backend is None:
computation_backend = detect_computation_backend()
elif isinstance(computation_backend, str):
computation_backend = ComputationBackend.from_str(computation_backend)
if computation_backends is None:
computation_backends = cb.detect_compatible_computation_backends()
elif isinstance(computation_backends, cb.ComputationBackend):
computation_backends = {computation_backends}
else:
computation_backends = set(computation_backends)

if channel not in ("stable", "test", "nightly"):
raise ValueError(
Expand All @@ -69,7 +87,7 @@ def find_links(
dists = extract_dists(pip_install_args)

cmd = StopAfterPytorchLinksFoundCommand(
computation_backend=computation_backend, channel=channel
computation_backends=computation_backends, channel=channel
)
pip_install_args = adjust_pip_install_args(dists, platform, python_version)
options, args = cmd.parser.parse_args(pip_install_args)
Expand Down Expand Up @@ -172,37 +190,43 @@ def extract_computation_backend_from_link(self, link: Link) -> Optional[str]:

class PytorchCandidatePreferences(CandidatePreferences):
def __init__(
self, *args: Any, computation_backend: ComputationBackend, **kwargs: Any,
self,
*args: Any,
computation_backends: Set[cb.ComputationBackend],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.computation_backend = computation_backend
self.computation_backends = computation_backends

@classmethod
def from_candidate_preferences(
cls,
candidate_preferences: CandidatePreferences,
computation_backend: ComputationBackend,
computation_backends: Set[cb.ComputationBackend],
) -> "PytorchCandidatePreferences":
return new_from_similar(
cls,
candidate_preferences,
("prefer_binary", "allow_all_prereleases",),
computation_backend=computation_backend,
computation_backends=computation_backends,
)


class PytorchCandidateEvaluator(CandidateEvaluator):
def __init__(
self, *args: Any, computation_backend: ComputationBackend, **kwargs: Any,
self,
*args: Any,
computation_backends: Set[cb.ComputationBackend],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.computation_backend = computation_backend
self.computation_backends = {cb.AnyBackend(), *computation_backends}

@classmethod
def from_candidate_evaluator(
cls,
candidate_evaluator: CandidateEvaluator,
computation_backend: ComputationBackend,
computation_backends: Set[cb.ComputationBackend],
) -> "PytorchCandidateEvaluator":
return new_from_similar(
cls,
Expand All @@ -215,51 +239,62 @@ def from_candidate_evaluator(
"allow_all_prereleases",
"hashes",
),
computation_backend=computation_backend,
computation_backends=computation_backends,
)

def _sort_key(
self, candidate: InstallationCandidate
) -> Tuple[cb.ComputationBackend, Version]:
version = Version(
f"{candidate.version.major}"
f".{candidate.version.minor}"
f".{candidate.version.micro}"
)
computation_backend = cb.ComputationBackend.from_str(candidate.version.local)
return computation_backend, version

def get_applicable_candidates(
self, candidates: List[InstallationCandidate]
) -> List[InstallationCandidate]:
return [
candidate
for candidate in super().get_applicable_candidates(candidates)
if candidate.version.local == "any"
or candidate.version.local == self.computation_backend
if candidate.version.local in self.computation_backends
]


class PytorchLinkCollector(LinkCollector):
def __init__(
self,
*args: Any,
computation_backend: ComputationBackend,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
if channel == "stable":
url = "https://download.pytorch.org/whl/torch_stable.html"
urls = ["https://download.pytorch.org/whl/torch_stable.html"]
else:
url = (
urls = [
f"https://download.pytorch.org/whl/"
f"{channel}/{computation_backend}/torch_{channel}.html"
)
self.search_scope = SearchScope.create(find_links=[url], index_urls=[])
f"{channel}/{backend}/torch_{channel}.html"
for backend in sorted(computation_backends, key=str)
]
self.search_scope = SearchScope.create(find_links=urls, index_urls=[])

@classmethod
def from_link_collector(
cls,
link_collector: LinkCollector,
computation_backend: ComputationBackend,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
) -> "PytorchLinkCollector":
return new_from_similar(
cls,
link_collector,
("session", "search_scope",),
("session", "search_scope"),
channel=channel,
computation_backend=computation_backend,
computation_backends=computation_backends,
)


Expand All @@ -270,18 +305,18 @@ class PytorchPackageFinder(PackageFinder):
def __init__(
self,
*args: Any,
computation_backend: ComputationBackend,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._candidate_prefs = PytorchCandidatePreferences.from_candidate_preferences(
self._candidate_prefs, computation_backend=computation_backend
self._candidate_prefs, computation_backends=computation_backends
)
self._link_collector = PytorchLinkCollector.from_link_collector(
self._link_collector,
channel=channel,
computation_backend=computation_backend,
computation_backends=computation_backends,
)

def make_candidate_evaluator(
Expand All @@ -290,7 +325,7 @@ def make_candidate_evaluator(
candidate_evaluator = super().make_candidate_evaluator(*args, **kwargs)
return PytorchCandidateEvaluator.from_candidate_evaluator(
candidate_evaluator,
computation_backend=self._candidate_prefs.computation_backend,
computation_backends=self._candidate_prefs.computation_backends,
)

def make_link_evaluator(self, *args: Any, **kwargs: Any) -> PytorchLinkEvaluator:
Expand All @@ -301,7 +336,7 @@ def make_link_evaluator(self, *args: Any, **kwargs: Any) -> PytorchLinkEvaluator
def from_package_finder(
cls,
package_finder: PackageFinder,
computation_backend: ComputationBackend,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
) -> "PytorchPackageFinder":
return new_from_similar(
Expand All @@ -315,7 +350,7 @@ def from_package_finder(
"candidate_prefs",
"ignore_requires_python",
),
computation_backend=computation_backend,
computation_backends=computation_backends,
channel=channel,
)

Expand All @@ -338,19 +373,19 @@ def resolve(
class StopAfterPytorchLinksFoundCommand(PatchedInstallCommand):
def __init__(
self,
computation_backend: ComputationBackend,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.computation_backend = computation_backend
self.computation_backends = computation_backends
self.channel = channel

def _build_package_finder(self, *args: Any, **kwargs: Any) -> PytorchPackageFinder:
package_finder = super()._build_package_finder(*args, **kwargs)
return PytorchPackageFinder.from_package_finder(
package_finder,
computation_backend=self.computation_backend,
computation_backends=self.computation_backends,
channel=self.channel,
)

Expand Down
7 changes: 4 additions & 3 deletions light_the_torch/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def _run(self, pip_install_args: List[str]) -> None:

class FindCommand(Command):
def __init__(self, args: argparse.Namespace) -> None:
self.computation_backend = args.computation_backend
# TODO split by comma
self.computation_backends = args.computation_backend
self.channel = args.channel
self.platform = args.platform
self.python_version = args.python_version
Expand All @@ -63,7 +64,7 @@ def __init__(self, args: argparse.Namespace) -> None:
def _run(self, pip_install_args: List[str]) -> None:
links = ltt.find_links(
pip_install_args,
computation_backend=self.computation_backend,
computation_backends=self.computation_backends,
channel=self.channel,
platform=self.platform,
python_version=self.python_version,
Expand All @@ -88,7 +89,7 @@ def __init__(self, args: argparse.Namespace) -> None:
def _run(self, pip_install_args: List[str]) -> None:
links = ltt.find_links(
pip_install_args,
computation_backend=CPUBackend() if self.force_cpu else None,
computation_backends={CPUBackend()} if self.force_cpu else None,
channel=self.channel,
verbose=self.verbose,
)
Expand Down
Loading

0 comments on commit 73e8fc8

Please sign in to comment.