Description
Describe the bug
With ModelTrainer, when the command parameter of SourceCode includes its own arguments (for example, python launcher.py -e test.py), the hyperparameters defined in ModelTrainer are not passed to the training script.
To reproduce
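from sagemaker.modules.configs import (
    Compute,
    OutputDataConfig,
    RemoteDebugConfig,
    SourceCode,
    StoppingCondition,
)
from sagemaker.modules.train import ModelTrainer

# Placeholders (not in the original report): replace with your own values
instance_type = "ml.m5.xlarge"
instance_count = 1
bucket_name = "<your-bucket>"
image_uri = "<your-training-image-uri>"
role = "<your-sagemaker-execution-role-arn>"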
# Define the script to be run
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    command="python launcher.py -e train.py",
)
# Define the compute
compute_configs = Compute(
    instance_type=instance_type,
    instance_count=instance_count,
    keep_alive_period_in_seconds=0,
)
job_name = "train-ray-processing-train"
output_path = f"s3://{bucket_name}/{job_name}"

model_trainer = ModelTrainer(
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    hyperparameters={
        "epochs": 25,
        "learning_rate": 0.001,
        "batch_size": 100,
    },
    stopping_condition=StoppingCondition(max_runtime_in_seconds=18000),
    output_data_config=OutputDataConfig(
        s3_output_path=output_path, compression_type="NONE"
    ),
    role=role,
)

# Start the training job
model_trainer.train()
In launcher.py:
import sys
from argparse import ArgumentParser, Namespace


def __read_params():
    try:
        parser = ArgumentParser()
        parser.add_argument("-e", "--entrypoint", type=str)
        parser.add_argument("--epochs", type=int, default=25)
        parser.add_argument("--learning_rate", type=float, default=0.001)
        parser.add_argument("--batch_size", type=int, default=100)
        # Parse only the arguments we care about and ignore the rest
        args, unknown = parser.parse_known_args()
        return args, unknown
    except Exception as e:
        raise e


if __name__ == "__main__":
    # The defaults above equal the ModelTrainer hyperparameters, so log the raw
    # arguments to see what was actually passed to the script.
    print(f"Received arguments: {sys.argv[1:]}")
    args, _ = __read_params()
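To make the expected behavior concrete, the sketch below feeds the same parser as launcher.py the argument list the script would presumably receive if the hyperparameters were appended to the command as --key value pairs. Both the helper expected_argv and the --key value convention are assumptions for illustration, not confirmed SDK behavior.

```python
import shlex
from argparse import ArgumentParser


def expected_argv(command, hyperparameters):
    """Hypothetical helper: the argv launcher.py would presumably see if the
    hyperparameters were appended to the command as --key value pairs."""
    # Drop the interpreter and script name ("python launcher.py"), keep their arguments.
    argv = shlex.split(command)[2:]
    for key, value in hyperparameters.items():
        argv += [f"--{key}", str(value)]
    return argv


def build_parser():
    # Same arguments as launcher.py above.
    parser = ArgumentParser()
    parser.add_argument("-e", "--entrypoint", type=str)
    parser.add_argument("--epochs", type=int, default=25)
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--batch_size", type=int, default=100)
    return parser


argv = expected_argv(
    "python launcher.py -e train.py",
    {"epochs": 25, "learning_rate": 0.001, "batch_size": 100},
)
args, unknown = build_parser().parse_known_args(argv)
print(argv)
print(args)
```

Because the parser defaults happen to equal the hyperparameter values in this reproduction, the parsed Namespace alone cannot show whether the hyperparameters arrived; the raw argv list is the more reliable signal.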
Expected behavior
Both the arguments included in the SourceCode command and the hyperparameters provided in the ModelTrainer definition should be passed to the training script.
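Until this is resolved, one possible workaround is for launcher.py to read the hyperparameters itself. SageMaker training jobs write them inside the container to /opt/ml/input/config/hyperparameters.json as a JSON object whose values are typically strings. The sketch below assumes that path and the --key value flag format; the function names load_hyperparameters and hyperparameters_to_cli_args are hypothetical.

```python
import json
import os
import tempfile
from argparse import ArgumentParser

# Assumed location of the hyperparameters file inside the training container.
HYPERPARAMETERS_PATH = "/opt/ml/input/config/hyperparameters.json"


def load_hyperparameters(path=HYPERPARAMETERS_PATH):
    """Read the hyperparameters file, returning {} if it does not exist."""
    if not os.path.exists(path):
        return {}
    with open(path) as f:
        return json.load(f)


def hyperparameters_to_cli_args(hyperparameters):
    """Turn {"epochs": "25"} into ["--epochs", "25"] (assumed --key value format)."""
    cli_args = []
    for key, value in hyperparameters.items():
        # Values are usually strings already; serialize anything else as JSON.
        cli_args += [f"--{key}", value if isinstance(value, str) else json.dumps(value)]
    return cli_args


# Demo: a temporary file stands in for the real hyperparameters.json.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "hyperparameters.json")
    with open(path, "w") as f:
        json.dump({"epochs": "25", "learning_rate": "0.001", "batch_size": "100"}, f)
    cli_args = hyperparameters_to_cli_args(load_hyperparameters(path))

# Same parser as launcher.py; in the container this would be
# parser.parse_known_args(sys.argv[1:] + cli_args).
parser = ArgumentParser()
parser.add_argument("-e", "--entrypoint", type=str)
parser.add_argument("--epochs", type=int, default=25)
parser.add_argument("--learning_rate", type=float, default=0.001)
parser.add_argument("--batch_size", type=int, default=100)
args, unknown = parser.parse_known_args(["-e", "train.py"] + cli_args)
print(args)
```

Because argparse keeps the last occurrence of a repeated option, appending the file's values after sys.argv[1:] lets the ModelTrainer hyperparameters override any value hard-coded in the command.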
System information
- SageMaker Python SDK version: 2.247.1
- Framework name (eg. PyTorch) or algorithm (eg. KMeans): Any
- Framework version: Any
- Python version: 3.12
- CPU or GPU: CPU and GPU
- Custom Docker image (Y/N): N