Skip to content

Commit

Permalink
Merge pull request #1197 from vantage6/change/decode-all-envvars-in-w…
Browse files Browse the repository at this point in the history
…rapper

Decoding all environment variables in algorithm tools wrapper, so tha…
  • Loading branch information
bartvanb committed Apr 15, 2024
2 parents 84a873a + 4b4c78f commit b4e5e97
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 37 deletions.
14 changes: 3 additions & 11 deletions docs/algorithms/develop.rst
Expand Up @@ -97,20 +97,12 @@ as follows:

.. code:: python
from vantage6.algorithm.tools.util import get_env_var
def my_function():
input_file = get_env_var("INPUT_FILE")
token_file = get_env_var("DEFAULT_DATABASE_URI")
input_file = os.environ["INPUT_FILE"]
token_file = os.environ["DEFAULT_DATABASE_URI"]
# do something with the input file and database URI
.. note::

The ``get_env_var`` function is used here rather than the standard
``os.environ`` dictionary because the environment variables are encoded
for security purposes. The ``get_env_var`` function will decode the
environment variable for you.
pass
The environment variables that you specify in the node configuration file
can be used in the exact same manner. You can view all environment variables
Expand Down
48 changes: 24 additions & 24 deletions vantage6-algorithm-tools/vantage6/algorithm/tools/decorators.py
Expand Up @@ -10,7 +10,7 @@

from vantage6.algorithm.client import AlgorithmClient
from vantage6.algorithm.tools.mock_client import MockAlgorithmClient
from vantage6.algorithm.tools.util import info, error, warn, get_env_var
from vantage6.algorithm.tools.util import info, error, warn
from vantage6.algorithm.tools.wrappers import load_data
from vantage6.algorithm.tools.preprocessing import preprocess_data

Expand Down Expand Up @@ -93,12 +93,12 @@ def decorator(
if mock_client is not None:
return func(mock_client, *args, **kwargs)
# read server address from the environment
host = get_env_var("HOST")
port = get_env_var("PORT")
api_path = get_env_var("API_PATH")
host = os.environ["HOST"]
port = os.environ["PORT"]
api_path = os.environ["API_PATH"]

# read token from the environment
token_file = get_env_var("TOKEN_FILE")
token_file = os.environ["TOKEN_FILE"]
info("Reading token")
with open(token_file) as fp:
token = fp.read().strip()
Expand Down Expand Up @@ -195,7 +195,7 @@ def decorator(

# do any data preprocessing here
info(f"Applying preprocessing for database '{label}'")
env_prepro = get_env_var(f"{label.upper()}_PREPROCESSING")
env_prepro = os.environ.get(f"{label.upper()}_PREPROCESSING")
if env_prepro is not None:
preprocess = json.loads(env_prepro)
data_ = preprocess_data(data_, preprocess)
Expand Down Expand Up @@ -309,7 +309,7 @@ def decorator(*args, **kwargs) -> callable:
>>> def my_algorithm(metadata: RunMetaData, <other arguments>):
>>> pass
"""
token_file = get_env_var("TOKEN_FILE")
token_file = os.environ["TOKEN_FILE"]
info("Reading token")
with open(token_file) as fp:
token = fp.read().strip()
Expand All @@ -322,10 +322,10 @@ def decorator(*args, **kwargs) -> callable:
node_id=payload["node_id"],
collaboration_id=payload["collaboration_id"],
organization_id=payload["organization_id"],
temporary_directory=Path(get_env_var("TEMPORARY_FOLDER")),
output_file=Path(get_env_var("OUTPUT_FILE")),
input_file=Path(get_env_var("INPUT_FILE")),
token_file=Path(get_env_var("TOKEN_FILE")),
temporary_directory=Path(os.environ["TEMPORARY_FOLDER"]),
output_file=Path(os.environ["OUTPUT_FILE"]),
input_file=Path(os.environ["INPUT_FILE"]),
token_file=Path(os.environ["TOKEN_FILE"]),
)
return func(metadata, *args, **kwargs)

Expand Down Expand Up @@ -355,11 +355,11 @@ def get_ohdsi_metadata(label: str) -> OHDSIMetaData:
for var in expected_env_vars:
_check_environment_var_exists_or_exit(f"{label_}_DB_PARAM_{var}")

tmp = Path(get_env_var("TEMPORARY_FOLDER"))
tmp = Path(os.environ["TEMPORARY_FOLDER"])
metadata = OHDSIMetaData(
database=get_env_var(f"{label_}_DB_PARAM_CDM_DATABASE"),
cdm_schema=get_env_var(f"{label_}_DB_PARAM_CDM_SCHEMA"),
results_schema=get_env_var(f"{label_}_DB_PARAM_RESULTS_SCHEMA"),
database=os.environ[f"{label_}_DB_PARAM_CDM_DATABASE"],
cdm_schema=os.environ[f"{label_}_DB_PARAM_CDM_SCHEMA"],
results_schema=os.environ[f"{label_}_DB_PARAM_RESULTS_SCHEMA"],
incremental_folder=tmp / "incremental",
cohort_statistics_folder=tmp / "cohort_statistics",
export_folder=tmp / "export",
Expand Down Expand Up @@ -417,10 +417,10 @@ def _create_omop_database_connection(label: str) -> callable:
_check_environment_var_exists_or_exit(f"{label_}_DB_PARAM_{var}")

info("Reading OHDSI environment variables")
dbms = get_env_var(f"{label_}_DB_PARAM_DBMS")
uri = get_env_var(f"{label_}_DATABASE_URI")
user = get_env_var(f"{label_}_DB_PARAM_USER")
password = get_env_var(f"{label_}_DB_PARAM_PASSWORD")
dbms = os.environ[f"{label_}_DB_PARAM_DBMS"]
uri = os.environ[f"{label_}_DATABASE_URI"]
user = os.environ[f"{label_}_DB_PARAM_USER"]
password = os.environ[f"{label_}_DB_PARAM_PASSWORD"]
info(f" - dbms: {dbms}")
info(f" - uri: {uri}")
info(f" - user: {user}")
Expand Down Expand Up @@ -460,20 +460,20 @@ def _get_data_from_label(label: str) -> pd.DataFrame:
Data from the database
"""
# Load the input data from the input file - this may e.g. include the
database_uri = get_env_var(f"{label.upper()}_DATABASE_URI")
database_uri = os.environ[f"{label.upper()}_DATABASE_URI"]
info(f"Using '{database_uri}' with label '{label}' as database")

# Get the database type from the environment variable, this variable is
# set by the vantage6 node based on its configuration file.
database_type = get_env_var(f"{label.upper()}_DATABASE_TYPE", "csv").lower()
database_type = os.environ.get(f"{label.upper()}_DATABASE_TYPE", "csv").lower()

# Load the data based on the database type. Try to provide environment
# variables that should be available for some data types.
return load_data(
database_uri,
database_type,
query=get_env_var(f"{label.upper()}_QUERY"),
sheet_name=get_env_var(f"{label.upper()}_SHEET_NAME"),
query=os.environ.get(f"{label.upper()}_QUERY"),
sheet_name=os.environ.get(f"{label.upper()}_SHEET_NAME"),
)


Expand All @@ -488,7 +488,7 @@ def _get_user_database_labels() -> list[str]:
"""
# read the labels that the user requested, which is a comma
# separated list of labels.
labels = get_env_var("USER_REQUESTED_DATABASE_LABELS")
labels = os.environ["USER_REQUESTED_DATABASE_LABELS"]
return labels.split(",")


Expand Down
8 changes: 8 additions & 0 deletions vantage6-algorithm-tools/vantage6/algorithm/tools/util.py
@@ -1,6 +1,8 @@
import sys
import os
import base64
import binascii

from vantage6.common.globals import STRING_ENCODING, ENV_VAR_EQUALS_REPLACEMENT


Expand Down Expand Up @@ -40,6 +42,9 @@ def error(msg: str) -> None:
sys.stdout.write(f"error > {msg}\n")


# TODO v5+: move this function to wrap.py and stop exposing it to algorithms;
# it should then only be used internally by _decode_env_vars. It is kept here
# for now for backwards compatibility with 4.2/4.3 algorithms.
def get_env_var(var_name: str, default: str | None = None) -> str:
"""
Get the value of an environment variable. Environment variables are encoded
Expand Down Expand Up @@ -69,3 +74,6 @@ def get_env_var(var_name: str, default: str | None = None) -> str:
return base64.b32decode(encoded_env_var_value).decode(STRING_ENCODING)
except KeyError:
return default
except binascii.Error:
# If the decoding fails, return the original value
return os.environ[var_name]
20 changes: 18 additions & 2 deletions vantage6-algorithm-tools/vantage6/algorithm/tools/wrap.py
Expand Up @@ -52,8 +52,11 @@ def wrap_algorithm(log_traceback: bool = True) -> None:
exit(1)
info(f"wrapper for {module}")

# Decode environment variables that are encoded by the node.
_decode_env_vars()

# read input from the mounted input file.
input_file = get_env_var("INPUT_FILE")
input_file = os.environ["INPUT_FILE"]
info(f"Reading input file {input_file}")
input_data = load_input(input_file)

Expand All @@ -63,7 +66,7 @@ def wrap_algorithm(log_traceback: bool = True) -> None:

# write output from the method to mounted output file. Which will be
# transferred back to the server by the node-instance.
output_file = get_env_var("OUTPUT_FILE")
output_file = os.environ["OUTPUT_FILE"]
info(f"Writing output to {output_file}")

_write_output(output, output_file)
Expand Down Expand Up @@ -170,3 +173,16 @@ def _write_output(output: Any, output_file: str) -> None:
with open(output_file, "wb") as fp:
serialized = serialization.serialize(output)
fp.write(serialized)


def _decode_env_vars() -> None:
    """
    Decode, in place, the environment variables that were encoded by the node.

    Every variable currently in ``os.environ`` is passed through
    ``get_env_var`` and the result is written back under the same name.

    Note that environment variables may be present that are not specific to
    vantage6, such as HOME, PATH, etc. These are not encoded by the node and
    should not be decoded here. ``get_env_var`` returns the original value when
    base32 decoding fails. In addition, a variable is left untouched when the
    decoded bytes are not valid text or when the decoded text cannot be stored
    in the environment, so that a single unusual non-vantage6 variable cannot
    abort the algorithm wrapper.
    """
    # Work on a snapshot of the names because os.environ is written to while
    # looping.
    for env_var in list(os.environ):
        try:
            decoded = get_env_var(env_var)
            # get_env_var falls back to its default (None) when the variable
            # is missing; there is nothing to write back in that case.
            if decoded is not None:
                os.environ[env_var] = decoded
        except ValueError:
            # UnicodeDecodeError (a ValueError subclass) is raised when a
            # value that was never encoded happens to be valid base32 but does
            # not decode to valid text. Writing back a value that the
            # environment cannot hold (presumably e.g. an embedded NUL byte)
            # also raises ValueError. Keep the original value in both cases.
            continue

0 comments on commit b4e5e97

Please sign in to comment.