Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 175 additions & 95 deletions py4DSTEM/utils/configuration_checker.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,96 @@
#### this file contains a function/s that will check if various
# libaries/compute options are available
import importlib
from operator import mod

# list of modules we expect/may expect to be installed
# as part of a standard py4DSTEM installation
# this needs to be the import name e.g. import mp_api not mp-api
modules = [
"crystal4D",
"cupy",
"dask",
"dill",
"distributed",
"gdown",
"h5py",
"ipyparallel",
"jax",
"matplotlib",
"mp_api",
"ncempy",
"numba",
"numpy",
"pymatgen",
"skimage",
"sklearn",
"scipy",
"tensorflow",
"tensorflow-addons",
"tqdm",
]

# currently this was copy and pasted from setup.py,
# hopefully there's a programatic way to do this.
module_depenencies = {
"base": [
"numpy",
"scipy",
"h5py",
"ncempy",
"matplotlib",
"skimage",
"sklearn",
"tqdm",
"dill",
"gdown",
"dask",
"distributed",
],
"ipyparallel": ["ipyparallel", "dill"],
"cuda": ["cupy"],
"acom": ["pymatgen", "mp_api"],
"aiml": ["tensorflow", "tensorflow-addons", "crystal4D"],
"aiml-cuda": ["tensorflow", "tensorflow-addons", "crystal4D", "cupy"],
"numba": ["numba"],
from importlib.metadata import requires
import re
from importlib.util import find_spec

# need a mapping of pypi/conda names to import names
import_mapping_dict = {
"scikit-image": "skimage",
"scikit-learn": "sklearn",
"scikit-optimize": "skopt",
"mp-api": "mp_api",
}


# programatically get all possible requirements in the import name style
def get_modules_list():
# Get the dependencies from the installed distribution
dependencies = requires("py4DSTEM")

# Define a regular expression pattern for splitting on '>', '>=', '='
delimiter_pattern = re.compile(r">=|>|==|<|<=")

# Extract only the module names without versions
module_names = [
delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip()
for dependency in dependencies
]

# translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
for index, module in enumerate(module_names):
if module in import_mapping_dict.keys():
module_names[index] = import_mapping_dict[module]

return module_names


# programatically get all possible requirements in the import name style,
# split into a dict where optional import names are keys
def get_modules_dict():
package_name = "py4DSTEM"
# Get the dependencies from the installed distribution
dependencies = requires(package_name)

# set the dictionary for modules and packages to go into
# optional dependencies will be added as they are discovered
modules_dict = {
"base": [],
}
# loop over the dependencies
for depend in dependencies:
# all the optional have extra in the name
# if its not there append it to base
if "extra" not in depend:
# String looks like: 'numpy>=1.19'
modules_dict["base"].append(depend)

# if it has extra in the string
else:
# get the name of the optional name
# depend looks like this 'numba>=0.49.1; extra == "numba"'
# grab whatever is in the double quotes i.e. numba
optional_name = re.search(r'"(.*?)"', depend).group(1)
# if the optional name is not in the dict as a key i.e. first requirement of hte optional dependency
if optional_name not in modules_dict:
modules_dict[optional_name] = [depend]
# if the optional_name is already in the dict then just append it to the list
else:
modules_dict[optional_name].append(depend)
# STRIP all the versioning and semi-colons
# Define a regular expression pattern for splitting on '>', '>=', '='
delimiter_pattern = re.compile(r">=|>|==|<|<=")
for key, val in modules_dict.items():
# modules_dict[key] = [dependency.split(';')[0].split(' ')[0] for dependency in val]
modules_dict[key] = [
delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip()
for dependency in val
]

# translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
for key, val in modules_dict.items():
for index, module in enumerate(val):
if module in import_mapping_dict.keys():
val[index] = import_mapping_dict[module]

return modules_dict


# module_depenencies = get_modules_dict()
modules = get_modules_list()


#### Class and Functions to Create Coloured Strings ####
class colours:
CEND = "\x1b[0m"
Expand Down Expand Up @@ -140,6 +175,7 @@ def create_underline(s: str) -> str:
### here I use the term state to define a boolean condition as to whether a libary/module was sucessfully imported/can be used


# get the state of each modules as a dict key-val e.g. "numpy" : True
def get_import_states(modules: list = modules) -> dict:
"""
Check the ability to import modules and store the results as a boolean value. Returns as a dict.
Expand All @@ -163,16 +199,17 @@ def get_import_states(modules: list = modules) -> dict:
return import_states_dict


# Check
def get_module_states(state_dict: dict) -> dict:
"""_summary_

Args:
state_dict (dict): _description_
"""
given a state dict for all modules e.g. "numpy" : True,
this parses through and checks if all modules required for a state are true

Returns:
dict: _description_
returns dict "base": True, "ai-ml": False etc.
"""

# get the modules_dict
module_depenencies = get_modules_dict()
# create an empty dict to put module states into:
module_states = {}

Expand All @@ -196,13 +233,12 @@ def get_module_states(state_dict: dict) -> dict:


def print_import_states(import_states: dict) -> None:
"""_summary_

Args:
import_states (dict): _description_
"""
print with colours if the library could be imported or not
takes dict
"numpy" : True -> prints success
"pymatgen" : False -> prints failure

Returns:
_type_: _description_
"""
# m is the name of the import module
# state is whether it was importable
Expand All @@ -223,13 +259,11 @@ def print_import_states(import_states: dict) -> None:


def print_module_states(module_states: dict) -> None:
"""_summary_

Args:
module_states (dict): _description_

Returns:
_type_: _description_
"""
print with colours if all the imports required for module could be imported or not
takes dict
"base" : True -> prints success
"ai-ml" : Fasle -> prints failure
"""
# Print out the state of all the modules in colour code
# key is the name of a py4DSTEM Module
Expand All @@ -248,25 +282,33 @@ def print_module_states(module_states: dict) -> None:
return None


def perfrom_extra_checks(
def perform_extra_checks(
import_states: dict, verbose: bool, gratuitously_verbose: bool, **kwargs
) -> None:
"""_summary_

Args:
import_states (dict): _description_
verbose (bool): _description_
gratuitously_verbose (bool): _description_
import_states (dict): dict of modules and if they could be imported or not
verbose (bool): will show module states and all import states
gratuitously_verbose (bool): will run extra checks - Currently only for cupy

Returns:
_type_: _description_
"""

# print a output module
extra_checks_message = "Running Extra Checks"
extra_checks_message = create_bold(extra_checks_message)
print(f"{extra_checks_message}")
# For modules that import run any extra checks
if gratuitously_verbose:
# print a output module
extra_checks_message = "Running Extra Checks"
extra_checks_message = create_bold(extra_checks_message)
print(f"{extra_checks_message}")
# For modules that import run any extra checks
# get all the dependencies
dependencies = requires("py4DSTEM")
# Extract only the module names with versions
depends_with_requirements = [
dependency.split(";")[0] for dependency in dependencies
]
# print(depends_with_requirements)
# need to go from
for key, val in import_states.items():
if val:
# s = create_underline(key.capitalize())
Expand All @@ -281,7 +323,10 @@ def perfrom_extra_checks(
if gratuitously_verbose:
s = create_underline(key.capitalize())
print(s)
print_no_extra_checks(key)
# check
generic_versions(
key, depends_with_requires=depends_with_requirements
)
else:
pass

Expand All @@ -304,7 +349,7 @@ def import_tester(m: str) -> bool:
# try and import the module
try:
importlib.import_module(m)
except:
except Exception:
state = False

return state
Expand All @@ -324,6 +369,7 @@ def check_module_functionality(state_dict: dict) -> None:

# create an empty dict to put module states into:
module_states = {}
module_depenencies = get_modules_dict()

# key is the name of the module e.g. ACOM
# val is a list of its dependencies
Expand Down Expand Up @@ -359,6 +405,45 @@ def check_module_functionality(state_dict: dict) -> None:
#### ADDTIONAL CHECKS ####


def generic_versions(module: str, depends_with_requires: list[str]) -> None:
# module will be like numpy, skimage
# depends_with_requires look like: numpy >= 19.0, scikit-image
# get module_translated_name
# mapping scikit-image : skimage
for key, value in import_mapping_dict.items():
# if skimage == skimage get scikit-image
# print(f"{key = } - {value = } - {module = }")
if module in value:
module_depend_name = key
break
else:
# if cant find mapping set the search name to the same
module_depend_name = module
# print(f"{module_depend_name = }")
# find the requirement
for depend in depends_with_requires:
if module_depend_name in depend:
spec_required = depend
# print(f"{spec_required = }")
# get the version installed
spec_installed = find_spec(module)
if spec_installed is None:
s = f"{module} unable to import - {spec_required} required"
s = create_failure(s)
s = f"{s: <80}"
print(s)

else:
try:
version = importlib.metadata.version(module_depend_name)
except Exception:
version = "Couldn't test version"
s = f"{module} imported: {version = } - {spec_required} required"
s = create_warning(s)
s = f"{s: <80}"
print(s)


def check_cupy_gpu(gratuitously_verbose: bool, **kwargs):
"""
This function performs some additional tests which may be useful in
Expand All @@ -375,25 +460,18 @@ def check_cupy_gpu(gratuitously_verbose: bool, **kwargs):
# check that CUDA is detected correctly
cuda_availability = cp.cuda.is_available()
if cuda_availability:
s = " CUDA is Available "
s = f" CUDA is Available "
s = create_success(s)
s = f"{s: <80}"
print(s)
else:
s = " CUDA is Unavailable "
s = f" CUDA is Unavailable "
s = create_failure(s)
s = f"{s: <80}"
print(s)

# Count how many GPUs Cupy can detect
# probably should change this to a while loop ...
for i in range(24):
try:
d = cp.cuda.Device(i)
hasattr(d, "attributes")
except:
num_gpus_detected = i
break
num_gpus_detected = cp.cuda.runtime.getDeviceCount()

# print how many GPUs were detected, filter for a couple of special conditons
if num_gpus_detected == 0:
Expand Down Expand Up @@ -448,7 +526,9 @@ def print_no_extra_checks(m: str):


# dict of extra check functions
funcs_dict = {"cupy": check_cupy_gpu}
funcs_dict = {
"cupy": check_cupy_gpu,
}


#### main function used to check the configuration of the installation
Expand Down Expand Up @@ -493,7 +573,7 @@ def check_config(

print_import_states(states_dict)

perfrom_extra_checks(
perform_extra_checks(
import_states=states_dict,
verbose=verbose,
gratuitously_verbose=gratuitously_verbose,
Expand Down